1diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake 2index 2254c2a7..f15724f1 100644 3--- a/cmake/package_lite.cmake 4+++ b/cmake/package_lite.cmake 5@@ -474,7 +474,7 @@ if(PLATFORM_ARM64) 6 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE) 7 install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api 8 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") 9- if(ANDROID_NDK_TOOLCHAIN_INCLUDED OR MSLITE_ENABLE_CONVERTER OR TARGET_HIMIX) 10+ if(ANDROID_NDK_TOOLCHAIN_INCLUDED OR MSLITE_ENABLE_CONVERTER OR TARGET_HIMIX OR TARGET_OHOS) 11 __install_micro_wrapper() 12 endif() 13 if(MSLITE_ENABLE_RUNTIME_GLOG) 14diff --git a/mindspore/ccsrc/backend/common/optimizer/pass.h b/mindspore/ccsrc/backend/common/optimizer/pass.h 15new file mode 100644 16index 00000000..8d396164 17--- /dev/null 18+++ b/mindspore/ccsrc/backend/common/optimizer/pass.h 19@@ -0,0 +1,48 @@ 20+/** 21+ * Copyright 2023 Huawei Technologies Co., Ltd 22+ * 23+ * Licensed under the Apache License, Version 2.0 (the "License"); 24+ * you may not use this file except in compliance with the License. 25+ * You may obtain a copy of the License at 26+ * 27+ * http://www.apache.org/licenses/LICENSE-2.0 28+ * 29+ * Unless required by applicable law or agreed to in writing, software 30+ * distributed under the License is distributed on an "AS IS" BASIS, 31+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 32+ * See the License for the specific language governing permissions and 33+ * limitations under the License. 
34+ */ 35+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_H_ 36+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_H_ 37+#include <memory> 38+#include <string> 39+#include "ir/anf.h" 40+#include "mindspore/core/ops/array_ops.h" 41+#include "mindspore/core/ops/lite_ops.h" 42+#include "utils/trace_base.h" 43+ 44+namespace mindspore { 45+namespace opt { 46+class CacheManager; 47+using CacheManagerPtr = std::shared_ptr<CacheManager>; 48+ 49+// @brief ANF Graph level optimization base pass 50+class Pass { 51+public: 52+ explicit Pass(const std::string &name = "pass") : name_(name) {} 53+ virtual ~Pass() = default; 54+ virtual bool Run(const FuncGraphPtr &fun_graph) = 0; 55+ const std::string &name() const { return name_;} 56+ void SetCacheManager(const CacheManagerPtr &cm) { cache_manager_ = cm;} 57+ const CacheManagerPtr &GetCacheManager() const {return cache_manager_;} 58+ 59+private: 60+ const std::string name_; 61+ CacheManagerPtr cache_manager_; 62+}; 63+using PassPtr = std::shared_ptr<Pass>; 64+} // namespace opt 65+} // namespace mindspore 66+ 67+#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_H_ 68diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.cc 69index 55bbddac..378ef00c 100644 70--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.cc 71+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.cc 72@@ -60,6 +60,8 @@ bool LstmCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec 73 hidden_size_ = kernel_ptr->get_hidden_size(); 74 num_layers_ = kernel_ptr->get_num_layers(); 75 has_bias_ = kernel_ptr->get_has_bias(); 76+ proj_size_ = kernel_ptr->get_proj_size(); 77+ real_hidden_size_ = proj_size_ > 0 ? 
proj_size_ : hidden_size_; 78 constexpr int kBidirectional = 2; 79 num_directions_ = 1; 80 if (bidirectional_) { 81@@ -73,14 +75,20 @@ bool LstmCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec 82 MS_LOG(EXCEPTION) << "Layers must be lower than 100!"; 83 } 84 85+ weight_size_ = 0; 86+ weight_h_size_ = 0; 87+ weight_r_size_ = 0; 88 for (int i = 0; i < num_layers_; ++i) { 89 weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); 90- weight_h_size_ += gate_size * hidden_size_; 91+ weight_h_size_ += gate_size * real_hidden_size_; 92+ weight_r_size_ += hidden_size_ * proj_size_; 93 } 94 weight_size_ = weight_size_ * num_directions_; 95 weight_h_size_ = weight_h_size_ * num_directions_; 96+ weight_r_size_ = weight_r_size_ * num_directions_; 97 weights_dims_ = {num_layers_, num_directions_, input_size_, kGateNum, hidden_size_}; 98- weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, kGateNum, hidden_size_}; 99+ weights_h_dims_ = {num_layers_, num_directions_, real_hidden_size_, kGateNum, hidden_size_}; 100+ weights_r_dims_ = {num_layers_, num_directions_, hidden_size_, proj_size_}; 101 bias_dims_ = {num_layers_, num_directions_, kGateNum, hidden_size_}; 102 is_training_ = 103 base_operator->HasAttr(kAttrIsTraining) ? 
GetValue<bool>(base_operator->GetAttr(kAttrIsTraining)) : true; 104@@ -110,10 +118,10 @@ int LstmCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve 105 direction = dnnl::rnn_direction::bidirectional_concat; 106 } 107 dim src_dims = {seq_len_, batch_size_, input_size_}; 108- dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 109+ dim src_h_dims = {num_layers_, num_directions_, batch_size_, real_hidden_size_}; 110 dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 111- dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_}; 112- dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 113+ dim dst_dims = {seq_len_, batch_size_, real_hidden_size_ * num_directions_}; 114+ dim dst_h_dims = {num_layers_, num_directions_, batch_size_, real_hidden_size_}; 115 dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 116 dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); 117 dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); 118@@ -126,13 +134,16 @@ int LstmCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve 119 auto prop_kind = is_training_ ? dnnl::prop_kind::forward_training : dnnl::prop_kind::forward_inference; 120 auto weights_desc = formatted_md(weights_dims_, tag::any); 121 auto weights_h_desc = formatted_md(weights_h_dims_, tag::any); 122- auto desc = 123- CreatePrimitive<dnnl::lstm_forward::desc>(prop_kind, direction, src_desc, src_h_desc, src_c_desc, weights_desc, 124- weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc); 125+ auto weights_r_desc = proj_size_ > 0 ? 
formatted_md(weights_r_dims_, tag::any) : dnnl::memory::desc(); 126+ auto peephole_desc = dnnl::memory::desc(); 127+ auto desc = CreatePrimitive<dnnl::lstm_forward::desc>(prop_kind, direction, src_desc, src_h_desc, src_c_desc, 128+ weights_desc, weights_h_desc, peephole_desc, weights_r_desc, 129+ bias_desc, dst_desc, dst_h_desc, dst_c_desc); 130 prim_desc_ = CreateDesc<dnnl::lstm_forward::primitive_desc>(*desc, engine_); 131 primitive_ = CreatePrimitive<dnnl::lstm_forward>(prim_desc_); 132 auto weights_layer = GetWeightsLayerDesc(prim_desc_); 133 auto weights_iter = GetWeightsIterDesc(prim_desc_); 134+ auto weights_proj = GetWeightsProjectionDesc(prim_desc_); 135 bias_desc_ = GetBiasDesc(prim_desc_); 136 if (is_training_) { 137 auto wksp_desc = GetWorkspaceDesc(prim_desc_); 138@@ -144,6 +155,7 @@ int LstmCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve 139 AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); 140 AddArgument(DNNL_ARG_WEIGHTS_LAYER, weights_layer); 141 AddArgument(DNNL_ARG_WEIGHTS_ITER, weights_iter); 142+ AddArgument(DNNL_ARG_WEIGHTS_PROJECTION, weights_proj); 143 AddArgument(DNNL_ARG_BIAS, bias_desc); 144 AddArgument(DNNL_ARG_DST_LAYER, dst_desc); 145 AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); 146@@ -151,10 +163,13 @@ int LstmCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve 147 148 auto weights_dims_desc = CreateDesc<dnnl::memory::desc>(weights_dims_, dt::f32, tag::ldgoi); 149 auto weights_h_dims_desc = CreateDesc<dnnl::memory::desc>(weights_h_dims_, dt::f32, tag::ldgoi); 150+ auto weights_r_dims_desc = CreateDesc<dnnl::memory::desc>(weights_r_dims_, dt::f32, tag::ldoi); 151 user_weights_memory_ = CreateDesc<dnnl::memory>(weights_dims_desc, engine_); 152 user_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_dims_desc, engine_); 153+ user_weights_r_memory_ = CreateDesc<dnnl::memory>(weights_r_dims_desc, engine_); 154 weights_memory_ = CreateDesc<dnnl::memory>(weights_layer, engine_); 155 
weights_h_memory_ = CreateDesc<dnnl::memory>(weights_iter, engine_); 156+ weights_r_memory_ = CreateDesc<dnnl::memory>(weights_proj, engine_); 157 bias_memory_ = CreateDesc<dnnl::memory>(bias_desc_, engine_); 158 159 InitOutputSize(outputs); 160@@ -163,13 +178,20 @@ int LstmCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve 161 162 bool LstmCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, 163 const std::vector<kernel::AddressPtr> &outputs) { 164+ size_t offset = 0; 165 SetDataHandle(user_weights_memory_, inputs[kInputWeightIndex]->addr); 166- SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_); 167+ offset += weight_size_; 168+ SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 169+ offset += weight_h_size_; 170 Reorder(&user_weights_memory_, &weights_memory_); 171 Reorder(&user_weights_h_memory_, &weights_h_memory_); 172+ if (proj_size_ > 0) { 173+ SetDataHandle(user_weights_r_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 174+ Reorder(&user_weights_r_memory_, &weights_r_memory_); 175+ offset += weight_r_size_; 176+ } 177 if (has_bias_) { 178- SetDataHandle(bias_memory_, 179- reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_ + weight_h_size_); 180+ SetDataHandle(bias_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 181 } else { 182 auto size = GetSize(bias_desc_); 183 if (memset_s(GetDataHandle(bias_memory_), size, 0, size) != EOK) { 184@@ -182,6 +204,7 @@ bool LstmCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, con 185 SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[kInputCIndex]->addr); 186 SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, GetDataHandle(weights_memory_)); 187 SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, GetDataHandle(weights_h_memory_)); 188+ 
SetArgumentHandle(DNNL_ARG_WEIGHTS_PROJECTION, GetDataHandle(weights_r_memory_)); 189 SetArgumentHandle(DNNL_ARG_BIAS, GetDataHandle(bias_memory_)); 190 SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr); 191 SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr); 192diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.h 193index 42609eed..a0241c16 100644 194--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.h 195+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.h 196@@ -58,14 +58,17 @@ class LstmCpuKernelMod : public MKLCpuKernelMod { 197 private: 198 void InitOutputSize(const std::vector<KernelTensorPtr> &outputs); 199 200- int weight_size_{0}; 201- int weight_h_size_{0}; 202- int input_size_{0}; 203- int hidden_size_{0}; 204- int num_layers_{0}; 205- int batch_size_{0}; 206- int seq_len_{0}; 207- int num_directions_{0}; 208+ int64_t weight_size_{0}; 209+ int64_t weight_h_size_{0}; 210+ int64_t weight_r_size_{0}; 211+ int64_t input_size_{0}; 212+ int64_t hidden_size_{0}; 213+ int64_t num_layers_{0}; 214+ int64_t batch_size_{0}; 215+ int64_t seq_len_{0}; 216+ int64_t num_directions_{0}; 217+ int64_t proj_size_{0}; 218+ int64_t real_hidden_size_{0}; 219 bool bidirectional_{false}; 220 bool has_bias_{false}; 221 bool is_training_{false}; 222@@ -73,13 +76,16 @@ class LstmCpuKernelMod : public MKLCpuKernelMod { 223 224 dnnl::memory::dims weights_dims_; 225 dnnl::memory::dims weights_h_dims_; 226+ dnnl::memory::dims weights_r_dims_; 227 dnnl::memory::dims bias_dims_; 228 dnnl::lstm_forward::primitive_desc prim_desc_; 229 dnnl::memory::desc bias_desc_; 230 dnnl::memory user_weights_memory_; 231 dnnl::memory user_weights_h_memory_; 232+ dnnl::memory user_weights_r_memory_; 233 dnnl::memory weights_memory_; 234 dnnl::memory weights_h_memory_; 235+ dnnl::memory weights_r_memory_; 236 dnnl::memory bias_memory_; 237 }; 238 } // namespace 
kernel 239diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.cc 240index aa1f8b44..0b5d09c1 100644 241--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.cc 242+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.cc 243@@ -62,6 +62,8 @@ bool LSTMGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std: 244 hidden_size_ = op_prim->get_hidden_size(); 245 num_layers_ = op_prim->get_num_layers(); 246 has_bias_ = op_prim->get_has_bias(); 247+ proj_size_ = op_prim->get_proj_size(); 248+ real_hidden_size_ = proj_size_ > 0 ? proj_size_ : hidden_size_; 249 auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); 250 auto match = MatchKernelAttr(kernel_attr, GetOpSupport()); 251 if (!match.first) { 252@@ -103,12 +105,15 @@ int LSTMGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std 253 } 254 weight_size_ = 0; 255 weight_h_size_ = 0; 256+ weight_r_size_ = 0; 257 for (int64_t i = 0; i < num_layers_; ++i) { 258 weight_size_ += gate_size * (i == 0 ? 
input_size_ : hidden_size_ * num_directions_); 259- weight_h_size_ += gate_size * hidden_size_; 260+ weight_h_size_ += gate_size * real_hidden_size_; 261+ weight_r_size_ += proj_size_ * hidden_size_; 262 } 263 weight_size_ = weight_size_ * num_directions_; 264 weight_h_size_ = weight_h_size_ * num_directions_; 265+ weight_r_size_ = weight_r_size_ * num_directions_; 266 if (num_directions_ * num_layers_ != src_h_shape[0]) { 267 MS_LOG(ERROR) << "Error iteration shape!"; 268 return KRET_RESIZE_FAILED; 269@@ -124,13 +129,14 @@ void LSTMGradCpuKernelMod::InitDnnl() { 270 direction = dnnl::rnn_direction::bidirectional_concat; 271 } 272 dim src_dims = {seq_len_, batch_size_, input_size_}; 273- dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 274+ dim src_h_dims = {num_layers_, num_directions_, batch_size_, real_hidden_size_}; 275 dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 276 weights_dims_ = {num_layers_, num_directions_, input_size_, kNumberFour, hidden_size_}; 277- weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, kNumberFour, hidden_size_}; 278+ weights_h_dims_ = {num_layers_, num_directions_, real_hidden_size_, kNumberFour, hidden_size_}; 279+ weights_r_dims_ = {num_layers_, num_directions_, hidden_size_, proj_size_}; 280 bias_dims_ = {num_layers_, num_directions_, kNumberFour, hidden_size_}; 281- dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_}; 282- dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 283+ dim dst_dims = {seq_len_, batch_size_, real_hidden_size_ * num_directions_}; 284+ dim dst_h_dims = {num_layers_, num_directions_, batch_size_, real_hidden_size_}; 285 dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 286 dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); 287 dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); 288@@ -141,15 +147,17 @@ void 
LSTMGradCpuKernelMod::InitDnnl() { 289 dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); 290 auto weights_desc = formatted_md(weights_dims_, tag::any); 291 auto weights_h_desc = formatted_md(weights_h_dims_, tag::any); 292+ auto weights_r_desc = proj_size_ > 0 ? formatted_md(weights_r_dims_, tag::any) : dnnl::memory::desc(); 293+ auto peepole_desc = dnnl::memory::desc(); 294 295- auto forward_desc = CreatePrimitive<dnnl::lstm_forward::desc>(dnnl::prop_kind::forward_training, direction, src_desc, 296- src_h_desc, src_c_desc, weights_desc, weights_h_desc, 297- bias_desc, dst_desc, dst_h_desc, dst_c_desc); 298+ auto forward_desc = CreatePrimitive<dnnl::lstm_forward::desc>( 299+ dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, 300+ peepole_desc, weights_r_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc); 301 auto prim_forward_desc = CreateDesc<dnnl::lstm_forward::primitive_desc>(*forward_desc, eng); 302 auto backward_desc = CreatePrimitive<dnnl::lstm_backward::desc>( 303- dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, bias_desc, 304- dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, bias_desc, 305- dst_desc, dst_h_desc, dst_c_desc); 306+ dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, peepole_desc, 307+ weights_r_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, weights_desc, 308+ weights_h_desc, peepole_desc, weights_r_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc); 309 prim_backward_desc_ = CreateDesc<dnnl::lstm_backward::primitive_desc>(*backward_desc, eng, prim_forward_desc); 310 primitive_ = CreatePrimitive<dnnl::lstm_backward>(prim_backward_desc_); 311 auto wksp_desc = GetWorkspaceDesc(prim_forward_desc); 312@@ -159,24 +167,31 @@ void LSTMGradCpuKernelMod::InitDnnl() { 313 // construct fw 
memory 314 weights_layer_desc_ = GetWeightsLayerDesc(prim_backward_desc_); 315 weights_iter_desc_ = GetWeightsIterDesc(prim_backward_desc_); 316+ weights_proj_desc_ = GetWeightsProjectionDesc(prim_backward_desc_); 317 bias_desc_ = GetBiasDesc(prim_backward_desc_); 318 auto weights_mem_desc = CreateDesc<dnnl::memory::desc>(weights_dims_, dt::f32, tag::ldgoi); 319 auto weights_h_mem_desc = CreateDesc<dnnl::memory::desc>(weights_h_dims_, dt::f32, tag::ldgoi); 320+ auto weights_r_mem_desc = CreateDesc<dnnl::memory::desc>(weights_r_dims_, dt::f32, tag::ldoi); 321 user_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng); 322 user_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng); 323+ user_weights_r_memory_ = CreateDesc<dnnl::memory>(weights_r_mem_desc, eng); 324 weights_memory_ = CreateDesc<dnnl::memory>(weights_layer_desc_, eng); 325 weights_h_memory_ = CreateDesc<dnnl::memory>(weights_iter_desc_, eng); 326+ weights_r_memory_ = CreateDesc<dnnl::memory>(weights_proj_desc_, eng); 327 bias_memory_ = CreateDesc<dnnl::memory>(bias_desc_, eng); 328 329 // construct bw memory 330 diff_weights_layer_desc_ = GetDiffWeightsLayerDesc(prim_backward_desc_); 331 diff_weights_iter_desc_ = GetDiffWeightsIterDesc(prim_backward_desc_); 332+ diff_weights_proj_desc_ = GetDiffWeightsProjectionDesc(prim_backward_desc_); 333 diff_bias_desc_ = GetDiffBiasDesc(prim_backward_desc_); 334 diff_weights_memory_ = CreateDesc<dnnl::memory>(diff_weights_layer_desc_, eng); 335 diff_weights_h_memory_ = CreateDesc<dnnl::memory>(diff_weights_iter_desc_, eng); 336+ diff_weights_r_memory_ = CreateDesc<dnnl::memory>(diff_weights_proj_desc_, eng); 337 diff_bias_memory_ = CreateDesc<dnnl::memory>(diff_bias_desc_, eng); 338 user_diff_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng); 339 user_diff_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng); 340+ user_diff_weights_r_memory_ = CreateDesc<dnnl::memory>(weights_r_mem_desc, eng); 341 } 342 
343 void LSTMGradCpuKernelMod::AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc, 344@@ -188,6 +203,7 @@ void LSTMGradCpuKernelMod::AddArgumentOp(const dnnl::memory::desc &src_desc, con 345 AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); 346 AddArgument(DNNL_ARG_WEIGHTS_LAYER, weights_layer_desc_); 347 AddArgument(DNNL_ARG_WEIGHTS_ITER, weights_iter_desc_); 348+ AddArgument(DNNL_ARG_WEIGHTS_PROJECTION, weights_proj_desc_); 349 AddArgument(DNNL_ARG_BIAS, bias_desc); 350 AddArgument(DNNL_ARG_DST_LAYER, dst_desc); 351 AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); 352@@ -197,6 +213,7 @@ void LSTMGradCpuKernelMod::AddArgumentOp(const dnnl::memory::desc &src_desc, con 353 AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc); 354 AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_layer_desc_); 355 AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_iter_desc_); 356+ AddArgument(DNNL_ARG_DIFF_WEIGHTS_PROJECTION, diff_weights_proj_desc_); 357 AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc); 358 AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc); 359 AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc); 360@@ -211,6 +228,7 @@ void LSTMGradCpuKernelMod::SetArgumentHandleOp(const std::vector<kernel::Address 361 SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[kSrcIterCIdx]->addr); 362 SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, GetDataHandle(weights_memory_)); 363 SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, GetDataHandle(weights_h_memory_)); 364+ SetArgumentHandle(DNNL_ARG_WEIGHTS_PROJECTION, GetDataHandle(weights_r_memory_)); 365 SetArgumentHandle(DNNL_ARG_BIAS, GetDataHandle(bias_memory_)); 366 SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[kDstLayerIdx]->addr); 367 SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[kDstIterIdx]->addr); 368@@ -221,6 +239,7 @@ void LSTMGradCpuKernelMod::SetArgumentHandleOp(const std::vector<kernel::Address 369 SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[kSrcIterCIdx]->addr); 370 SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, 
GetDataHandle(diff_weights_memory_)); 371 SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, GetDataHandle(diff_weights_h_memory_)); 372+ SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_PROJECTION, GetDataHandle(diff_weights_r_memory_)); 373 SetArgumentHandle(DNNL_ARG_DIFF_BIAS, GetDataHandle(diff_bias_memory_)); 374 SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[kDiffDstLayerIdx]->addr); 375 SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[kDiffDstIterIdx]->addr); 376@@ -241,13 +260,20 @@ bool LSTMGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, 377 const std::vector<kernel::AddressPtr> &outputs) { 378 CHECK_KERNEL_INPUTS_NUM(inputs.size(), kLstmGradInputsNum, kernel_name_); 379 CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kLstmGradOutputsNum, kernel_name_); 380+ size_t offset = 0; 381 SetDataHandle(user_weights_memory_, inputs[kInputWeightIndex]->addr); 382- SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_); 383+ offset += weight_size_; 384+ SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 385+ offset += weight_h_size_; 386 Reorder(&user_weights_memory_, &weights_memory_); 387 Reorder(&user_weights_h_memory_, &weights_h_memory_); 388+ if (proj_size_ > 0) { 389+ SetDataHandle(user_weights_r_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 390+ Reorder(&user_weights_r_memory_, &weights_r_memory_); 391+ offset += weight_r_size_; 392+ } 393 if (has_bias_) { 394- SetDataHandle(bias_memory_, 395- reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_ + weight_h_size_); 396+ SetDataHandle(bias_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 397 } else { 398 auto dst_ptr = GetDataHandle(bias_memory_); 399 auto size = GetSize(bias_desc_); 400@@ -256,16 +282,23 @@ bool LSTMGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, 401 } 402 } 403 404+ 
offset = 0; 405 SetDataHandle(user_diff_weights_memory_, outputs[kOutputWeightIndex]->addr); 406- SetDataHandle(user_diff_weights_h_memory_, 407- reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + weight_size_); 408+ offset += weight_size_; 409+ SetDataHandle(user_diff_weights_h_memory_, reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + offset); 410+ offset += weight_h_size_; 411 ResetMemory(user_diff_weights_memory_, "user weights grad"); 412 ResetMemory(user_diff_weights_h_memory_, "user weights iter grad"); 413 ResetMemory(diff_weights_memory_, "weights grad"); 414 ResetMemory(diff_weights_h_memory_, "weights iter grad"); 415+ if (proj_size_ > 0) { 416+ SetDataHandle(user_diff_weights_r_memory_, reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + offset); 417+ ResetMemory(user_diff_weights_r_memory_, "user weights projection grad"); 418+ ResetMemory(diff_weights_r_memory_, "weights projection grad"); 419+ offset += weight_r_size_; 420+ } 421 if (has_bias_) { 422- SetDataHandle(diff_bias_memory_, 423- reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + weight_size_ + weight_h_size_); 424+ SetDataHandle(diff_bias_memory_, reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + offset); 425 } 426 auto dst_ptr = GetDataHandle(diff_bias_memory_); 427 auto size = GetSize(diff_bias_desc_); 428@@ -276,6 +309,9 @@ bool LSTMGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, 429 ExecutePrimitive(); 430 Reorder(&diff_weights_memory_, &user_diff_weights_memory_); 431 Reorder(&diff_weights_h_memory_, &user_diff_weights_h_memory_); 432+ if (proj_size_ > 0) { 433+ Reorder(&diff_weights_r_memory_, &user_diff_weights_r_memory_); 434+ } 435 return true; 436 } 437 438diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.h 439index f47bafc0..9768464d 100644 440--- 
a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.h 441+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.h 442@@ -75,34 +75,44 @@ class LSTMGradCpuKernelMod : public MKLCpuKernelMod { 443 bool has_bias_{false}; 444 int64_t weight_size_{0}; 445 int64_t weight_h_size_{0}; 446+ int64_t weight_r_size_{0}; 447 int64_t input_size_{0}; 448 int64_t hidden_size_{0}; 449 int64_t num_layers_{0}; 450 int64_t batch_size_{0}; 451 int64_t seq_len_{0}; 452+ int64_t proj_size_{0}; 453+ int64_t real_hidden_size_{0}; 454 size_t reserve_size_{0}; 455 456 dnnl::memory::dims weights_dims_; 457 dnnl::memory::dims weights_h_dims_; 458+ dnnl::memory::dims weights_r_dims_; 459 dnnl::memory::dims bias_dims_; 460 dnnl::lstm_backward::primitive_desc prim_backward_desc_; 461 462 dnnl::memory::desc weights_layer_desc_; 463 dnnl::memory::desc weights_iter_desc_; 464+ dnnl::memory::desc weights_proj_desc_; 465 dnnl::memory::desc bias_desc_; 466 dnnl::memory::desc diff_weights_layer_desc_; 467 dnnl::memory::desc diff_weights_iter_desc_; 468+ dnnl::memory::desc diff_weights_proj_desc_; 469 dnnl::memory::desc diff_bias_desc_; 470 dnnl::memory user_weights_memory_; 471 dnnl::memory user_weights_h_memory_; 472+ dnnl::memory user_weights_r_memory_; 473 dnnl::memory weights_memory_; 474 dnnl::memory weights_h_memory_; 475+ dnnl::memory weights_r_memory_; 476 dnnl::memory bias_memory_; 477 dnnl::memory diff_weights_memory_; 478 dnnl::memory diff_weights_h_memory_; 479+ dnnl::memory diff_weights_r_memory_; 480 dnnl::memory diff_bias_memory_; 481 dnnl::memory user_diff_weights_memory_; 482 dnnl::memory user_diff_weights_h_memory_; 483+ dnnl::memory user_diff_weights_r_memory_; 484 }; 485 } // namespace kernel 486 } // namespace mindspore 487diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h 488index 7c8292df..0c98f8f6 100644 489--- 
a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h 490+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h 491@@ -89,6 +89,14 @@ auto GetWeightsIterDesc(const T &prim_desc) { 492 return desc; 493 } 494 495+template <class T> 496+auto GetWeightsProjectionDesc(const T &prim_desc) { 497+ MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::weights_projection_desc()"; 498+ auto desc = prim_desc.weights_projection_desc(); 499+ MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::weights_projection_desc()"; 500+ return desc; 501+} 502+ 503 template <class T> 504 auto GetBiasDesc(const T &prim_desc) { 505 MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::bias_desc()"; 506@@ -113,6 +121,14 @@ auto GetDiffWeightsIterDesc(const T &prim_desc) { 507 return desc; 508 } 509 510+template <class T> 511+auto GetDiffWeightsProjectionDesc(const T &prim_desc) { 512+ MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::diff_weights_projection_desc()"; 513+ auto desc = prim_desc.diff_weights_projection_desc(); 514+ MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::diff_weights_projection_desc()"; 515+ return desc; 516+} 517+ 518 template <class T> 519 auto GetDiffBiasDesc(const T &prim_desc) { 520 MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::diff_bias_desc()"; 521diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 522index 103e53b7..d27817be 100644 523--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 524+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 525@@ -501,6 +501,7 @@ infer_shape_sources = [ 526 "infer/custom_masked_fill_infer.c", 527 "infer/custom_is_inf_infer.c", 528 "infer/custom_tensor_scatter_max_infer.c", 529+ "infer/custom_gather_d_grad_v2_infer.c", 530 "infer/decoder_layer_infer.c", 531 "infer/deconv2d_infer.c", 532 
"infer/depth_to_space_infer.c", 533@@ -740,6 +741,7 @@ arm64_fp16_assembly_sources = [ 534 "assembly/fp16/Matmul12X16Fp16.S", 535 "assembly/fp16/MatmulBaseFp16Neon.S", 536 "assembly/fp16/MatmulFp16Opt.S", 537+ "assembly/fp16/MatmulFp16OptV2.S", 538 "assembly/fp16/MatmulFp16.S", 539 "assembly/fp16/MatmulWinogradFp16.S", 540 "assembly/fp16/MatVecMulFp16.S", 541diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/fp16/MatmulFp16OptV2.S b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/fp16/MatmulFp16OptV2.S 542new file mode 100644 543index 00000000..2d901a3d 544--- /dev/null 545+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/fp16/MatmulFp16OptV2.S 546@@ -0,0 +1,2966 @@ 547+/** 548+ * Copyright 2023 Huawei Technologies Co., Ltd 549+ * 550+ * Licensed under the Apache License, Version 2.0 (the "License"); 551+ * you may not use this file except in compliance with the License. 552+ * You may obtain a copy of the License at 553+ * 554+ * http://www.apache.org/licenses/LICENSE-2.0 555+ * 556+ * Unless required by applicable law or agreed to in writing, software 557+ * distributed under the License is distributed on an "AS IS" BASIS, 558+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 559+ * See the License for the specific language governing permissions and 560+ * limitations under the License. 
561+ */ 562+#ifdef ENABLE_ARM64 563+#include "nnacl/assembly_global.h" 564+ 565+.text 566+.align 5 567+ 568+// void MatmulFp16OptV2(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, 569+// size_t depth, size_t row, size_t col, size_t stride, size_t writeMode) 570+// x0: a 571+// x1: b 572+// x2: c 573+// x3: bias 574+// x4: act_type 575+// x5: depth 576+// x6: row 577+// x7: col 578+// x8: stride 579+// x9: writeMode 580+ 581+asm_function MatmulFp16OptV2 582+ sub sp, sp, #192 583+ st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 584+ st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 585+ stp x19, x20, [sp], #16 586+ stp x21, x22, [sp], #16 587+ stp x23, x24, [sp], #16 588+ stp x29, x30, [sp], #16 589+ 590+ ldr x8, [sp] 591+ ldr x9, [sp, #8] // writeMode 592+ lsl x8, x8, #1 // stride * sizeof(float16_t) 593+ 594+ lsl x15, x7, #1 // col * sizeof(float16_t) 595+ lsl x16, x5, #1 // depth * sizeof(float16_t) 596+ mov x11, x2 597+ movi v7.8h, #0x46, lsl #8 598+ subs x6, x6, #12 599+ blt LoopRow8 600+LoopRow12: 601+ mov x11, x1 // reload matrixB 602+ mov x12, x3 // reload bias 603+ mov x13, x7 // reload col 604+ mov x21, x2 // relocate output 605+ subs x13, x13, #16 606+ blt LoopCol12x8 607+ LoopCol12x16: 608+ mov x10, x0 // update matrixA 609+ ld1 {v0.8h}, [x10], #16 610+ mov x14, x5 // reload depth 611+ prfm pldl1strm, [x11, #632] 612+ ld1 {v3.8h}, [x11], #16 613+ cbnz x12, InitFromBias12x16 614+ dup v8.2d, xzr 615+ dup v9.2d, xzr 616+ dup v10.2d, xzr 617+ dup v11.2d, xzr 618+ dup v12.2d, xzr 619+ dup v13.2d, xzr 620+ dup v14.2d, xzr 621+ dup v15.2d, xzr 622+ dup v16.2d, xzr 623+ dup v17.2d, xzr 624+ dup v18.2d, xzr 625+ dup v19.2d, xzr 626+ dup v20.2d, xzr 627+ dup v21.2d, xzr 628+ dup v22.2d, xzr 629+ dup v23.2d, xzr 630+ dup v24.2d, xzr 631+ dup v25.2d, xzr 632+ dup v26.2d, xzr 633+ dup v27.2d, xzr 634+ dup v28.2d, xzr 635+ dup v29.2d, xzr 636+ dup v30.2d, xzr 637+ dup v31.2d, xzr 638+ b Compute12x16Enter 639+ InitFromBias12x16: 
640+ ld1 {v8.8h, v9.8h}, [x12] 641+ ld1 {v10.8h, v11.8h}, [x12] 642+ ld1 {v12.8h, v13.8h}, [x12] 643+ ld1 {v14.8h, v15.8h}, [x12] 644+ ld1 {v16.8h, v17.8h}, [x12] 645+ ld1 {v18.8h, v19.8h}, [x12] 646+ ld1 {v20.8h, v21.8h}, [x12] 647+ ld1 {v22.8h, v23.8h}, [x12] 648+ ld1 {v24.8h, v25.8h}, [x12] 649+ ld1 {v26.8h, v27.8h}, [x12] 650+ ld1 {v28.8h, v29.8h}, [x12] 651+ ld1 {v30.8h, v31.8h}, [x12] 652+ add x12, x12, #32 653+ Compute12x16Enter: 654+ bl Compute12x16Unit 655+ Activation12x16: 656+ cmp x4, #3 657+ beq Relu612x16 658+ cmp x4, #1 659+ beq Relu12x16 660+ b Write12x16 661+ 662+ Relu612x16: 663+ fmin v8.8h, v8.8h, v7.8h 664+ fmin v9.8h, v9.8h, v7.8h 665+ fmin v10.8h, v10.8h, v7.8h 666+ fmin v11.8h, v11.8h, v7.8h 667+ fmin v12.8h, v12.8h, v7.8h 668+ fmin v13.8h, v13.8h, v7.8h 669+ fmin v14.8h, v14.8h, v7.8h 670+ fmin v15.8h, v15.8h, v7.8h 671+ fmin v16.8h, v16.8h, v7.8h 672+ fmin v17.8h, v17.8h, v7.8h 673+ fmin v18.8h, v18.8h, v7.8h 674+ fmin v19.8h, v19.8h, v7.8h 675+ fmin v20.8h, v20.8h, v7.8h 676+ fmin v21.8h, v21.8h, v7.8h 677+ fmin v22.8h, v22.8h, v7.8h 678+ fmin v23.8h, v23.8h, v7.8h 679+ fmin v24.8h, v24.8h, v7.8h 680+ fmin v25.8h, v25.8h, v7.8h 681+ fmin v26.8h, v26.8h, v7.8h 682+ fmin v27.8h, v27.8h, v7.8h 683+ fmin v28.8h, v28.8h, v7.8h 684+ fmin v29.8h, v29.8h, v7.8h 685+ fmin v30.8h, v30.8h, v7.8h 686+ fmin v31.8h, v31.8h, v7.8h 687+ 688+ Relu12x16: 689+ dup v6.8h, wzr 690+ fmax v8.8h, v8.8h, v6.8h 691+ fmax v9.8h, v9.8h, v6.8h 692+ fmax v10.8h, v10.8h, v6.8h 693+ fmax v11.8h, v11.8h, v6.8h 694+ fmax v12.8h, v12.8h, v6.8h 695+ fmax v13.8h, v13.8h, v6.8h 696+ fmax v14.8h, v14.8h, v6.8h 697+ fmax v15.8h, v15.8h, v6.8h 698+ fmax v16.8h, v16.8h, v6.8h 699+ fmax v17.8h, v17.8h, v6.8h 700+ fmax v18.8h, v18.8h, v6.8h 701+ fmax v19.8h, v19.8h, v6.8h 702+ fmax v20.8h, v20.8h, v6.8h 703+ fmax v21.8h, v21.8h, v6.8h 704+ fmax v22.8h, v22.8h, v6.8h 705+ fmax v23.8h, v23.8h, v6.8h 706+ fmax v24.8h, v24.8h, v6.8h 707+ fmax v25.8h, v25.8h, v6.8h 708+ fmax v26.8h, 
v26.8h, v6.8h 709+ fmax v27.8h, v27.8h, v6.8h 710+ fmax v28.8h, v28.8h, v6.8h 711+ fmax v29.8h, v29.8h, v6.8h 712+ fmax v30.8h, v30.8h, v6.8h 713+ fmax v31.8h, v31.8h, v6.8h 714+ Write12x16: 715+ mov x22, x21 716+ add x23, x21, x8, lsl #2 717+ add x24, x21, x8, lsl #3 718+ st1 {v8.8h, v9.8h}, [x22], x8 719+ st1 {v10.8h, v11.8h}, [x22], x8 720+ st1 {v12.8h, v13.8h}, [x22], x8 721+ st1 {v14.8h, v15.8h}, [x22] 722+ st1 {v16.8h, v17.8h}, [x23], x8 723+ st1 {v18.8h, v19.8h}, [x23], x8 724+ st1 {v20.8h, v21.8h}, [x23], x8 725+ st1 {v22.8h, v23.8h}, [x23] 726+ st1 {v24.8h, v25.8h}, [x24], x8 727+ st1 {v26.8h, v27.8h}, [x24], x8 728+ st1 {v28.8h, v29.8h}, [x24], x8 729+ st1 {v30.8h, v31.8h}, [x24] 730+ add x21, x21, #32 731+ subs x13, x13, #16 732+ bge LoopCol12x16 733+ 734+ LoopCol12x8: 735+ adds x13, x13, #16 736+ cbz x13, LoopRow12End 737+ subs x13, x13, #8 738+ blt LoopCol12x4 739+ mov x10, x0 // update matrixA 740+ ld1 {v0.8h}, [x10], #16 741+ mov x14, x5 // reload depth 742+ prfm pldl1strm, [x11, #632] 743+ ld1 {v3.8h}, [x11], #16 744+ cbnz x12, InitFromBias12x8 745+ dup v8.2d, xzr 746+ dup v10.2d, xzr 747+ dup v12.2d, xzr 748+ dup v14.2d, xzr 749+ dup v16.2d, xzr 750+ dup v18.2d, xzr 751+ dup v20.2d, xzr 752+ dup v22.2d, xzr 753+ dup v24.2d, xzr 754+ dup v26.2d, xzr 755+ dup v28.2d, xzr 756+ dup v30.2d, xzr 757+ b Compute12x8Enter 758+ InitFromBias12x8: 759+ ld1 {v8.8h}, [x12] 760+ ld1 {v10.8h}, [x12] 761+ ld1 {v12.8h}, [x12] 762+ ld1 {v14.8h}, [x12] 763+ ld1 {v16.8h}, [x12] 764+ ld1 {v18.8h}, [x12] 765+ ld1 {v20.8h}, [x12] 766+ ld1 {v22.8h}, [x12] 767+ ld1 {v24.8h}, [x12] 768+ ld1 {v26.8h}, [x12] 769+ ld1 {v28.8h}, [x12] 770+ ld1 {v30.8h}, [x12] 771+ add x12, x12, #16 772+ Compute12x8Enter: 773+ bl Compute12x8Unit 774+ Activation12x8: 775+ cmp x4, #3 776+ beq Relu612x8 777+ cmp x4, #1 778+ beq Relu12x8 779+ b Write12x8 780+ 781+ Relu612x8: 782+ fmin v8.8h, v8.8h, v7.8h 783+ fmin v10.8h, v10.8h, v7.8h 784+ fmin v12.8h, v12.8h, v7.8h 785+ fmin v14.8h, v14.8h, v7.8h 
786+ fmin v16.8h, v16.8h, v7.8h 787+ fmin v18.8h, v18.8h, v7.8h 788+ fmin v20.8h, v20.8h, v7.8h 789+ fmin v22.8h, v22.8h, v7.8h 790+ fmin v24.8h, v24.8h, v7.8h 791+ fmin v26.8h, v26.8h, v7.8h 792+ fmin v28.8h, v28.8h, v7.8h 793+ fmin v30.8h, v30.8h, v7.8h 794+ 795+ Relu12x8: 796+ dup v6.8h, wzr 797+ fmax v8.8h, v8.8h, v6.8h 798+ fmax v10.8h, v10.8h, v6.8h 799+ fmax v12.8h, v12.8h, v6.8h 800+ fmax v14.8h, v14.8h, v6.8h 801+ fmax v16.8h, v16.8h, v6.8h 802+ fmax v18.8h, v18.8h, v6.8h 803+ fmax v20.8h, v20.8h, v6.8h 804+ fmax v22.8h, v22.8h, v6.8h 805+ fmax v24.8h, v24.8h, v6.8h 806+ fmax v26.8h, v26.8h, v6.8h 807+ fmax v28.8h, v28.8h, v6.8h 808+ fmax v30.8h, v30.8h, v6.8h 809+ Write12x8: 810+ mov x22, x21 811+ add x23, x21, x8, lsl #2 812+ add x24, x21, x8, lsl #3 813+ st1 {v8.8h}, [x22], x8 814+ st1 {v10.8h}, [x22], x8 815+ st1 {v12.8h}, [x22], x8 816+ st1 {v14.8h}, [x22] 817+ st1 {v16.8h}, [x23], x8 818+ st1 {v18.8h}, [x23], x8 819+ st1 {v20.8h}, [x23], x8 820+ st1 {v22.8h}, [x23] 821+ st1 {v24.8h}, [x24], x8 822+ st1 {v26.8h}, [x24], x8 823+ st1 {v28.8h}, [x24], x8 824+ st1 {v30.8h}, [x24] 825+ add x21, x21, #16 826+ subs x13, x13, #8 827+ 828+ LoopCol12x4: 829+ adds x13, x13, #8 830+ cbz x13, LoopRow12End 831+ LoopCol12x4Core: 832+ mov x10, x0 // update matrixA 833+ ld1 {v0.8h}, [x10], #16 834+ mov x14, x5 // reload depth 835+ prfm pldl1strm, [x11, #632] 836+ ld1 {v3.4h}, [x11], #8 837+ cbnz x12, InitFromBias12x4 838+ dup v8.2s, wzr 839+ dup v10.2s, wzr 840+ dup v12.2s, wzr 841+ dup v14.2s, wzr 842+ dup v16.2s, wzr 843+ dup v18.2s, wzr 844+ dup v20.2s, wzr 845+ dup v22.2s, wzr 846+ dup v24.2s, wzr 847+ dup v26.2s, wzr 848+ dup v28.2s, wzr 849+ dup v30.2s, wzr 850+ b Compute12x4Enter 851+ InitFromBias12x4: 852+ ld1 {v8.4h}, [x12] 853+ ld1 {v10.4h}, [x12] 854+ ld1 {v12.4h}, [x12] 855+ ld1 {v14.4h}, [x12] 856+ ld1 {v16.4h}, [x12] 857+ ld1 {v18.4h}, [x12] 858+ ld1 {v20.4h}, [x12] 859+ ld1 {v22.4h}, [x12] 860+ ld1 {v24.4h}, [x12] 861+ ld1 {v26.4h}, [x12] 862+ ld1 
{v28.4h}, [x12] 863+ ld1 {v30.4h}, [x12] 864+ add x12, x12, #8 865+ Compute12x4Enter: 866+ bl Compute12x4Unit 867+ Activation12x4: 868+ cmp x4, #3 869+ beq Relu612x4 870+ cmp x4, #1 871+ beq Relu12x4 872+ b Write12x4 873+ 874+ Relu612x4: 875+ fmin v8.4h, v8.4h, v7.4h 876+ fmin v10.4h, v10.4h, v7.4h 877+ fmin v12.4h, v12.4h, v7.4h 878+ fmin v14.4h, v14.4h, v7.4h 879+ fmin v16.4h, v16.4h, v7.4h 880+ fmin v18.4h, v18.4h, v7.4h 881+ fmin v20.4h, v20.4h, v7.4h 882+ fmin v22.4h, v22.4h, v7.4h 883+ fmin v24.4h, v24.4h, v7.4h 884+ fmin v26.4h, v26.4h, v7.4h 885+ fmin v28.4h, v28.4h, v7.4h 886+ fmin v30.4h, v30.4h, v7.4h 887+ 888+ Relu12x4: 889+ dup v6.4h, wzr 890+ fmax v8.4h, v8.4h, v6.4h 891+ fmax v10.4h, v10.4h, v6.4h 892+ fmax v12.4h, v12.4h, v6.4h 893+ fmax v14.4h, v14.4h, v6.4h 894+ fmax v16.4h, v16.4h, v6.4h 895+ fmax v18.4h, v18.4h, v6.4h 896+ fmax v20.4h, v20.4h, v6.4h 897+ fmax v22.4h, v22.4h, v6.4h 898+ fmax v24.4h, v24.4h, v6.4h 899+ fmax v26.4h, v26.4h, v6.4h 900+ fmax v28.4h, v28.4h, v6.4h 901+ fmax v30.4h, v30.4h, v6.4h 902+ Write12x4: 903+ mov x22, x21 904+ add x23, x21, x8, lsl #2 905+ add x24, x21, x8, lsl #3 906+ cmp x13, #1 907+ beq Write12x1 908+ cmp x13, #2 909+ beq Write12x2 910+ cmp x13, #3 911+ beq Write12x3 912+ st1 {v8.4h}, [x22], x8 913+ st1 {v10.4h}, [x22], x8 914+ st1 {v12.4h}, [x22], x8 915+ st1 {v14.4h}, [x22] 916+ st1 {v16.4h}, [x23], x8 917+ st1 {v18.4h}, [x23], x8 918+ st1 {v20.4h}, [x23], x8 919+ st1 {v22.4h}, [x23] 920+ st1 {v24.4h}, [x24], x8 921+ st1 {v26.4h}, [x24], x8 922+ st1 {v28.4h}, [x24], x8 923+ st1 {v30.4h}, [x24] 924+ add x21, x21, #8 925+ subs x13, x13, #4 926+ bgt LoopCol12x4Core 927+ b LoopRow12End 928+ Write12x1: 929+ st1 {v8.h}[0], [x22], x8 930+ st1 {v10.h}[0], [x22], x8 931+ st1 {v12.h}[0], [x22], x8 932+ st1 {v14.h}[0], [x22] 933+ st1 {v16.h}[0], [x23], x8 934+ st1 {v18.h}[0], [x23], x8 935+ st1 {v20.h}[0], [x23], x8 936+ st1 {v22.h}[0], [x23] 937+ st1 {v24.h}[0], [x24], x8 938+ st1 {v26.h}[0], [x24], x8 939+ st1 
{v28.h}[0], [x24], x8 940+ st1 {v30.h}[0], [x24] 941+ b LoopRow12End 942+ Write12x2: 943+ st1 {v8.s}[0], [x22], x8 944+ st1 {v10.s}[0], [x22], x8 945+ st1 {v12.s}[0], [x22], x8 946+ st1 {v14.s}[0], [x22] 947+ st1 {v16.s}[0], [x23], x8 948+ st1 {v18.s}[0], [x23], x8 949+ st1 {v20.s}[0], [x23], x8 950+ st1 {v22.s}[0], [x23] 951+ st1 {v24.s}[0], [x24], x8 952+ st1 {v26.s}[0], [x24], x8 953+ st1 {v28.s}[0], [x24], x8 954+ st1 {v30.s}[0], [x24] 955+ b LoopRow12End 956+ Write12x3: 957+ add x23, x22, #4 958+ st1 {v8.s}[0], [x22], x8 959+ st1 {v8.h}[2], [x23], x8 960+ st1 {v10.s}[0], [x22], x8 961+ st1 {v10.h}[2], [x23], x8 962+ st1 {v12.s}[0], [x22], x8 963+ st1 {v12.h}[2], [x23], x8 964+ st1 {v14.s}[0], [x22], x8 965+ st1 {v14.h}[2], [x23], x8 966+ st1 {v16.s}[0], [x22], x8 967+ st1 {v16.h}[2], [x23], x8 968+ st1 {v18.s}[0], [x22], x8 969+ st1 {v18.h}[2], [x23], x8 970+ st1 {v20.s}[0], [x22], x8 971+ st1 {v20.h}[2], [x23], x8 972+ st1 {v22.s}[0], [x22], x8 973+ st1 {v22.h}[2], [x23], x8 974+ st1 {v24.s}[0], [x22], x8 975+ st1 {v24.h}[2], [x23], x8 976+ st1 {v26.s}[0], [x22], x8 977+ st1 {v26.h}[2], [x23], x8 978+ st1 {v28.s}[0], [x22], x8 979+ st1 {v28.h}[2], [x23], x8 980+ st1 {v30.s}[0], [x22] 981+ st1 {v30.h}[2], [x23] 982+ LoopRow12End: 983+ add x0, x0, x16, lsl #3 984+ add x0, x0, x16, lsl #2 985+ add x2, x2, x8, lsl #3 986+ add x2, x2, x8, lsl #2 987+ subs x6, x6, #12 988+ bge LoopRow12 989+ 990+LoopRow8: 991+ adds x6, x6,#12 992+ cbz x6, End 993+ subs x6, x6, #8 994+ blt LoopRow4 995+ mov x11, x1 // reload matrixB 996+ mov x12, x3 // reload bias 997+ mov x13, x7 // reload col 998+ mov x21, x2 // relocate output 999+ subs x13, x13, #16 1000+ blt LoopCol8x8 1001+ LoopCol8x16: 1002+ mov x10, x0 // update matrixA 1003+ ld1 {v0.8h}, [x10], #16 1004+ mov x14, x5 // reload depth 1005+ prfm pldl1strm, [x11, #632] 1006+ ld1 {v3.8h}, [x11], #16 1007+ cbnz x12, InitFromBias8x16 1008+ dup v8.2d, xzr 1009+ dup v9.2d, xzr 1010+ dup v10.2d, xzr 1011+ dup v11.2d, xzr 1012+ dup 
v12.2d, xzr 1013+ dup v13.2d, xzr 1014+ dup v14.2d, xzr 1015+ dup v15.2d, xzr 1016+ dup v16.2d, xzr 1017+ dup v17.2d, xzr 1018+ dup v18.2d, xzr 1019+ dup v19.2d, xzr 1020+ dup v20.2d, xzr 1021+ dup v21.2d, xzr 1022+ dup v22.2d, xzr 1023+ dup v23.2d, xzr 1024+ b Compute8x16Enter 1025+ InitFromBias8x16: 1026+ ld1 {v8.8h, v9.8h}, [x12] 1027+ ld1 {v10.8h, v11.8h}, [x12] 1028+ ld1 {v12.8h, v13.8h}, [x12] 1029+ ld1 {v14.8h, v15.8h}, [x12] 1030+ ld1 {v16.8h, v17.8h}, [x12] 1031+ ld1 {v18.8h, v19.8h}, [x12] 1032+ ld1 {v20.8h, v21.8h}, [x12] 1033+ ld1 {v22.8h, v23.8h}, [x12] 1034+ add x12, x12, #32 1035+ Compute8x16Enter: 1036+ bl Compute8x16Unit 1037+ Activation8x16: 1038+ cmp x4, #3 1039+ beq Relu68x16 1040+ cmp x4, #1 1041+ beq Relu8x16 1042+ b Write8x16 1043+ 1044+ Relu68x16: 1045+ fmin v8.8h, v8.8h, v7.8h 1046+ fmin v9.8h, v9.8h, v7.8h 1047+ fmin v10.8h, v10.8h, v7.8h 1048+ fmin v11.8h, v11.8h, v7.8h 1049+ fmin v12.8h, v12.8h, v7.8h 1050+ fmin v13.8h, v13.8h, v7.8h 1051+ fmin v14.8h, v14.8h, v7.8h 1052+ fmin v15.8h, v15.8h, v7.8h 1053+ fmin v16.8h, v16.8h, v7.8h 1054+ fmin v17.8h, v17.8h, v7.8h 1055+ fmin v18.8h, v18.8h, v7.8h 1056+ fmin v19.8h, v19.8h, v7.8h 1057+ fmin v20.8h, v20.8h, v7.8h 1058+ fmin v21.8h, v21.8h, v7.8h 1059+ fmin v22.8h, v22.8h, v7.8h 1060+ fmin v23.8h, v23.8h, v7.8h 1061+ 1062+ Relu8x16: 1063+ dup v6.8h, wzr 1064+ fmax v8.8h, v8.8h, v6.8h 1065+ fmax v9.8h, v9.8h, v6.8h 1066+ fmax v10.8h, v10.8h, v6.8h 1067+ fmax v11.8h, v11.8h, v6.8h 1068+ fmax v12.8h, v12.8h, v6.8h 1069+ fmax v13.8h, v13.8h, v6.8h 1070+ fmax v14.8h, v14.8h, v6.8h 1071+ fmax v15.8h, v15.8h, v6.8h 1072+ fmax v16.8h, v16.8h, v6.8h 1073+ fmax v17.8h, v17.8h, v6.8h 1074+ fmax v18.8h, v18.8h, v6.8h 1075+ fmax v19.8h, v19.8h, v6.8h 1076+ fmax v20.8h, v20.8h, v6.8h 1077+ fmax v21.8h, v21.8h, v6.8h 1078+ fmax v22.8h, v22.8h, v6.8h 1079+ fmax v23.8h, v23.8h, v6.8h 1080+ Write8x16: 1081+ mov x22, x21 1082+ add x23, x21, x8, lsl #2 1083+ st1 {v8.8h, v9.8h}, [x22], x8 1084+ st1 {v10.8h, 
v11.8h}, [x22], x8 1085+ st1 {v12.8h, v13.8h}, [x22], x8 1086+ st1 {v14.8h, v15.8h}, [x22] 1087+ st1 {v16.8h, v17.8h}, [x23], x8 1088+ st1 {v18.8h, v19.8h}, [x23], x8 1089+ st1 {v20.8h, v21.8h}, [x23], x8 1090+ st1 {v22.8h, v23.8h}, [x23] 1091+ add x21, x21, #32 1092+ subs x13, x13, #16 1093+ bge LoopCol8x16 1094+ 1095+ LoopCol8x8: 1096+ adds x13, x13, #16 1097+ cbz x13, LoopRow8End 1098+ subs x13, x13, #8 1099+ blt LoopCol8x4 1100+ mov x10, x0 // update matrixA 1101+ ld1 {v0.8h}, [x10], #16 1102+ mov x14, x5 // reload depth 1103+ prfm pldl1strm, [x11, #632] 1104+ ld1 {v3.8h}, [x11], #16 1105+ cbnz x12, InitFromBias8x8 1106+ dup v8.2d, xzr 1107+ dup v10.2d, xzr 1108+ dup v12.2d, xzr 1109+ dup v14.2d, xzr 1110+ dup v16.2d, xzr 1111+ dup v18.2d, xzr 1112+ dup v20.2d, xzr 1113+ dup v22.2d, xzr 1114+ b Compute8x8Enter 1115+ InitFromBias8x8: 1116+ ld1 {v8.8h}, [x12] 1117+ ld1 {v10.8h}, [x12] 1118+ ld1 {v12.8h}, [x12] 1119+ ld1 {v14.8h}, [x12] 1120+ ld1 {v16.8h}, [x12] 1121+ ld1 {v18.8h}, [x12] 1122+ ld1 {v20.8h}, [x12] 1123+ ld1 {v22.8h}, [x12] 1124+ add x12, x12, #16 1125+ Compute8x8Enter: 1126+ bl Compute8x8Unit 1127+ Activation8x8: 1128+ cmp x4, #3 1129+ beq Relu68x8 1130+ cmp x4, #1 1131+ beq Relu8x8 1132+ b Write8x8 1133+ 1134+ Relu68x8: 1135+ fmin v8.8h, v8.8h, v7.8h 1136+ fmin v10.8h, v10.8h, v7.8h 1137+ fmin v12.8h, v12.8h, v7.8h 1138+ fmin v14.8h, v14.8h, v7.8h 1139+ fmin v16.8h, v16.8h, v7.8h 1140+ fmin v18.8h, v18.8h, v7.8h 1141+ fmin v20.8h, v20.8h, v7.8h 1142+ fmin v22.8h, v22.8h, v7.8h 1143+ 1144+ Relu8x8: 1145+ dup v6.8h, wzr 1146+ fmax v8.8h, v8.8h, v6.8h 1147+ fmax v10.8h, v10.8h, v6.8h 1148+ fmax v12.8h, v12.8h, v6.8h 1149+ fmax v14.8h, v14.8h, v6.8h 1150+ fmax v16.8h, v16.8h, v6.8h 1151+ fmax v18.8h, v18.8h, v6.8h 1152+ fmax v20.8h, v20.8h, v6.8h 1153+ fmax v22.8h, v22.8h, v6.8h 1154+ Write8x8: 1155+ mov x22, x21 1156+ add x23, x21, x8, lsl #2 1157+ st1 {v8.8h}, [x22], x8 1158+ st1 {v10.8h}, [x22], x8 1159+ st1 {v12.8h}, [x22], x8 1160+ st1 {v14.8h}, 
[x22] 1161+ st1 {v16.8h}, [x23], x8 1162+ st1 {v18.8h}, [x23], x8 1163+ st1 {v20.8h}, [x23], x8 1164+ st1 {v22.8h}, [x23] 1165+ add x21, x21, #16 1166+ subs x13, x13, #8 1167+ 1168+ LoopCol8x4: 1169+ adds x13, x13, #8 1170+ cbz x13, LoopRow8End 1171+ LoopCol8x4Core: 1172+ mov x10, x0 // update matrixA 1173+ ld1 {v0.8h}, [x10], #16 1174+ mov x14, x5 // reload depth 1175+ prfm pldl1strm, [x11, #632] 1176+ ld1 {v3.4h}, [x11], #8 1177+ cbnz x12, InitFromBias8x4 1178+ dup v8.2s, wzr 1179+ dup v10.2s, wzr 1180+ dup v12.2s, wzr 1181+ dup v14.2s, wzr 1182+ dup v16.2s, wzr 1183+ dup v18.2s, wzr 1184+ dup v20.2s, wzr 1185+ dup v22.2s, wzr 1186+ b Compute8x4Enter 1187+ InitFromBias8x4: 1188+ ld1 {v8.4h}, [x12] 1189+ ld1 {v10.4h}, [x12] 1190+ ld1 {v12.4h}, [x12] 1191+ ld1 {v14.4h}, [x12] 1192+ ld1 {v16.4h}, [x12] 1193+ ld1 {v18.4h}, [x12] 1194+ ld1 {v20.4h}, [x12] 1195+ ld1 {v22.4h}, [x12] 1196+ add x12, x12, #8 1197+ Compute8x4Enter: 1198+ bl Compute8x4Unit 1199+ Activation8x4: 1200+ cmp x4, #3 1201+ beq Relu68x4 1202+ cmp x4, #1 1203+ beq Relu8x4 1204+ b Write8x4 1205+ 1206+ Relu68x4: 1207+ fmin v8.4h, v8.4h, v7.4h 1208+ fmin v10.4h, v10.4h, v7.4h 1209+ fmin v12.4h, v12.4h, v7.4h 1210+ fmin v14.4h, v14.4h, v7.4h 1211+ fmin v16.4h, v16.4h, v7.4h 1212+ fmin v18.4h, v18.4h, v7.4h 1213+ fmin v20.4h, v20.4h, v7.4h 1214+ fmin v22.4h, v22.4h, v7.4h 1215+ 1216+ Relu8x4: 1217+ dup v6.4h, wzr 1218+ fmax v8.4h, v8.4h, v6.4h 1219+ fmax v10.4h, v10.4h, v6.4h 1220+ fmax v12.4h, v12.4h, v6.4h 1221+ fmax v14.4h, v14.4h, v6.4h 1222+ fmax v16.4h, v16.4h, v6.4h 1223+ fmax v18.4h, v18.4h, v6.4h 1224+ fmax v20.4h, v20.4h, v6.4h 1225+ fmax v22.4h, v22.4h, v6.4h 1226+ Write8x4: 1227+ mov x22, x21 1228+ add x23, x21, x8, lsl #2 1229+ cmp x13, #1 1230+ beq Write8x1 1231+ cmp x13, #2 1232+ beq Write8x2 1233+ cmp x13, #3 1234+ beq Write8x3 1235+ st1 {v8.4h}, [x22], x8 1236+ st1 {v10.4h}, [x22], x8 1237+ st1 {v12.4h}, [x22], x8 1238+ st1 {v14.4h}, [x22] 1239+ st1 {v16.4h}, [x23], x8 1240+ st1 {v18.4h}, 
[x23], x8 1241+ st1 {v20.4h}, [x23], x8 1242+ st1 {v22.4h}, [x23] 1243+ add x21, x21, #8 1244+ subs x13, x13, #4 1245+ bgt LoopCol8x4Core 1246+ b LoopRow8End 1247+ Write8x1: 1248+ st1 {v8.h}[0], [x22], x8 1249+ st1 {v10.h}[0], [x22], x8 1250+ st1 {v12.h}[0], [x22], x8 1251+ st1 {v14.h}[0], [x22] 1252+ st1 {v16.h}[0], [x23], x8 1253+ st1 {v18.h}[0], [x23], x8 1254+ st1 {v20.h}[0], [x23], x8 1255+ st1 {v22.h}[0], [x23] 1256+ b LoopRow8End 1257+ Write8x2: 1258+ st1 {v8.s}[0], [x22], x8 1259+ st1 {v10.s}[0], [x22], x8 1260+ st1 {v12.s}[0], [x22], x8 1261+ st1 {v14.s}[0], [x22] 1262+ st1 {v16.s}[0], [x23], x8 1263+ st1 {v18.s}[0], [x23], x8 1264+ st1 {v20.s}[0], [x23], x8 1265+ st1 {v22.s}[0], [x23] 1266+ b LoopRow8End 1267+ Write8x3: 1268+ add x23, x22, #4 1269+ st1 {v8.s}[0], [x22], x8 1270+ st1 {v8.h}[2], [x23], x8 1271+ st1 {v10.s}[0], [x22], x8 1272+ st1 {v10.h}[2], [x23], x8 1273+ st1 {v12.s}[0], [x22], x8 1274+ st1 {v12.h}[2], [x23], x8 1275+ st1 {v14.s}[0], [x22], x8 1276+ st1 {v14.h}[2], [x23], x8 1277+ st1 {v16.s}[0], [x22], x8 1278+ st1 {v16.h}[2], [x23], x8 1279+ st1 {v18.s}[0], [x22], x8 1280+ st1 {v18.h}[2], [x23], x8 1281+ st1 {v20.s}[0], [x22], x8 1282+ st1 {v20.h}[2], [x23], x8 1283+ st1 {v22.s}[0], [x22], x8 1284+ st1 {v22.h}[2], [x23], x8 1285+ LoopRow8End: 1286+ add x0, x0, x16, lsl #3 1287+ add x2, x2, x8, lsl #3 1288+ subs x6, x6, #8 1289+ 1290+LoopRow4: 1291+ adds x6, x6, #8 1292+ cbz x6, End 1293+ subs x6, x6, #4 1294+ blt LoopRowTail 1295+ mov x11, x1 // reload matrixB 1296+ mov x12, x3 // reload bias 1297+ mov x13, x7 // reload col 1298+ mov x21, x2 // relocate output 1299+ subs x13, x13, #16 1300+ blt LoopCol4x8 1301+ LoopCol4x16: 1302+ mov x10, x0 // update matrixA 1303+ ld1 {v0.4h}, [x10], #8 1304+ mov x14, x5 // reload depth 1305+ prfm pldl1strm, [x11, #632] 1306+ ld1 {v3.8h}, [x11], #16 1307+ cbnz x12, InitFromBias4x16 1308+ dup v8.2d, xzr 1309+ dup v9.2d, xzr 1310+ dup v10.2d, xzr 1311+ dup v11.2d, xzr 1312+ dup v12.2d, xzr 1313+ dup 
v13.2d, xzr 1314+ dup v14.2d, xzr 1315+ dup v15.2d, xzr 1316+ b Compute4x16Enter 1317+ InitFromBias4x16: 1318+ ld1 {v8.8h, v9.8h}, [x12] 1319+ ld1 {v10.8h, v11.8h}, [x12] 1320+ ld1 {v12.8h, v13.8h}, [x12] 1321+ ld1 {v14.8h, v15.8h}, [x12] 1322+ add x12, x12, #32 1323+ Compute4x16Enter: 1324+ bl Compute4x16Unit 1325+ Activation4x16: 1326+ cmp x4, #3 1327+ beq Relu64x16 1328+ cmp x4, #1 1329+ beq Relu4x16 1330+ b Write4x16 1331+ 1332+ Relu64x16: 1333+ fmin v8.8h, v8.8h, v7.8h 1334+ fmin v9.8h, v9.8h, v7.8h 1335+ fmin v10.8h, v10.8h, v7.8h 1336+ fmin v11.8h, v11.8h, v7.8h 1337+ fmin v12.8h, v12.8h, v7.8h 1338+ fmin v13.8h, v13.8h, v7.8h 1339+ fmin v14.8h, v14.8h, v7.8h 1340+ fmin v15.8h, v15.8h, v7.8h 1341+ 1342+ Relu4x16: 1343+ dup v6.8h, wzr 1344+ fmax v8.8h, v8.8h, v6.8h 1345+ fmax v9.8h, v9.8h, v6.8h 1346+ fmax v10.8h, v10.8h, v6.8h 1347+ fmax v11.8h, v11.8h, v6.8h 1348+ fmax v12.8h, v12.8h, v6.8h 1349+ fmax v13.8h, v13.8h, v6.8h 1350+ fmax v14.8h, v14.8h, v6.8h 1351+ fmax v15.8h, v15.8h, v6.8h 1352+ Write4x16: 1353+ mov x22, x21 1354+ st1 {v8.8h, v9.8h}, [x22], x8 1355+ st1 {v10.8h, v11.8h}, [x22], x8 1356+ st1 {v12.8h, v13.8h}, [x22], x8 1357+ st1 {v14.8h, v15.8h}, [x22] 1358+ add x21, x21, #32 1359+ subs x13, x13, #16 1360+ bge LoopCol4x16 1361+ 1362+ LoopCol4x8: 1363+ adds x13, x13, #16 1364+ cbz x13, LoopRow4End 1365+ subs x13, x13, #8 1366+ blt LoopCol4x4 1367+ mov x10, x0 // update matrixA 1368+ ld1 {v0.4h}, [x10], #8 1369+ mov x14, x5 // reload depth 1370+ prfm pldl1strm, [x11, #632] 1371+ ld1 {v3.8h}, [x11], #16 1372+ cbnz x12, InitFromBias4x8 1373+ dup v8.2d, xzr 1374+ dup v10.2d, xzr 1375+ dup v12.2d, xzr 1376+ dup v14.2d, xzr 1377+ b Compute4x8Enter 1378+ InitFromBias4x8: 1379+ ld1 {v8.8h}, [x12] 1380+ ld1 {v10.8h}, [x12] 1381+ ld1 {v12.8h}, [x12] 1382+ ld1 {v14.8h}, [x12] 1383+ add x12, x12, #16 1384+ Compute4x8Enter: 1385+ bl Compute4x8Unit 1386+ Activation4x8: 1387+ cmp x4, #3 1388+ beq Relu64x8 1389+ cmp x4, #1 1390+ beq Relu4x8 1391+ b Write4x8 
1392+ 1393+ Relu64x8: 1394+ fmin v8.8h, v8.8h, v7.8h 1395+ fmin v10.8h, v10.8h, v7.8h 1396+ fmin v12.8h, v12.8h, v7.8h 1397+ fmin v14.8h, v14.8h, v7.8h 1398+ 1399+ Relu4x8: 1400+ dup v6.8h, wzr 1401+ fmax v8.8h, v8.8h, v6.8h 1402+ fmax v10.8h, v10.8h, v6.8h 1403+ fmax v12.8h, v12.8h, v6.8h 1404+ fmax v14.8h, v14.8h, v6.8h 1405+ Write4x8: 1406+ mov x22, x21 1407+ st1 {v8.8h}, [x22], x8 1408+ st1 {v10.8h}, [x22], x8 1409+ st1 {v12.8h}, [x22], x8 1410+ st1 {v14.8h}, [x22] 1411+ add x21, x21, #16 1412+ subs x13, x13, #8 1413+ 1414+ LoopCol4x4: 1415+ adds x13, x13, #8 1416+ cbz x13, LoopRow4End 1417+ LoopCol4x4Core: 1418+ mov x10, x0 // update matrixA 1419+ ld1 {v0.4h}, [x10], #8 1420+ mov x14, x5 // reload depth 1421+ prfm pldl1strm, [x11, #632] 1422+ ld1 {v3.4h}, [x11], #8 1423+ cbnz x12, InitFromBias4x4 1424+ dup v8.2s, wzr 1425+ dup v10.2s, wzr 1426+ dup v12.2s, wzr 1427+ dup v14.2s, wzr 1428+ b Compute4x4Enter 1429+ InitFromBias4x4: 1430+ ld1 {v8.4h}, [x12] 1431+ ld1 {v10.4h}, [x12] 1432+ ld1 {v12.4h}, [x12] 1433+ ld1 {v14.4h}, [x12] 1434+ add x12, x12, #8 1435+ Compute4x4Enter: 1436+ bl Compute4x4Unit 1437+ Activation4x4: 1438+ cmp x4, #3 1439+ beq Relu64x4 1440+ cmp x4, #1 1441+ beq Relu4x4 1442+ b Write4x4 1443+ 1444+ Relu64x4: 1445+ fmin v8.4h, v8.4h, v7.4h 1446+ fmin v10.4h, v10.4h, v7.4h 1447+ fmin v12.4h, v12.4h, v7.4h 1448+ fmin v14.4h, v14.4h, v7.4h 1449+ 1450+ Relu4x4: 1451+ dup v6.4h, wzr 1452+ fmax v8.4h, v8.4h, v6.4h 1453+ fmax v10.4h, v10.4h, v6.4h 1454+ fmax v12.4h, v12.4h, v6.4h 1455+ fmax v14.4h, v14.4h, v6.4h 1456+ Write4x4: 1457+ mov x22, x21 1458+ cmp x13, #1 1459+ beq Write4x1 1460+ cmp x13, #2 1461+ beq Write4x2 1462+ cmp x13, #3 1463+ beq Write4x3 1464+ st1 {v8.4h}, [x22], x8 1465+ st1 {v10.4h}, [x22], x8 1466+ st1 {v12.4h}, [x22], x8 1467+ st1 {v14.4h}, [x22] 1468+ add x21, x21, #8 1469+ subs x13, x13, #4 1470+ bgt LoopCol4x4Core 1471+ b LoopRow4End 1472+ Write4x1: 1473+ st1 {v8.h}[0], [x22], x8 1474+ st1 {v10.h}[0], [x22], x8 1475+ st1 
{v12.h}[0], [x22], x8 1476+ st1 {v14.h}[0], [x22] 1477+ b LoopRow4End 1478+ Write4x2: 1479+ st1 {v8.s}[0], [x22], x8 1480+ st1 {v10.s}[0], [x22], x8 1481+ st1 {v12.s}[0], [x22], x8 1482+ st1 {v14.s}[0], [x22] 1483+ b LoopRow4End 1484+ Write4x3: 1485+ add x23, x22, #4 1486+ st1 {v8.s}[0], [x22], x8 1487+ st1 {v8.h}[2], [x23], x8 1488+ st1 {v10.s}[0], [x22], x8 1489+ st1 {v10.h}[2], [x23], x8 1490+ st1 {v12.s}[0], [x22], x8 1491+ st1 {v12.h}[2], [x23], x8 1492+ st1 {v14.s}[0], [x22], x8 1493+ st1 {v14.h}[2], [x23], x8 1494+ LoopRow4End: 1495+ add x0, x0, x16, lsl #2 1496+ add x2, x2, x8, lsl #2 1497+ subs x6, x6, #4 1498+ 1499+LoopRowTail: 1500+ adds x6, x6, #4 1501+ cbz x6, End 1502+ cmp x6, #1 1503+ beq LoopRow1 1504+ cmp x6, #2 1505+ beq LoopRow2 1506+ // LoopRow3 1507+ mov x11, x1 // reload matrixB 1508+ mov x12, x3 // reload bias 1509+ mov x13, x7 // reload col 1510+ mov x21, x2 // relocate output 1511+ subs x13, x13, #16 1512+ blt LoopCol3x8 1513+ LoopCol3x16: 1514+ mov x10, x0 // update matrixA 1515+ mov x14, x5 // reload depth 1516+ cbnz x12, InitFromBias3x16 1517+ dup v8.2d, xzr 1518+ dup v9.2d, xzr 1519+ dup v10.2d, xzr 1520+ dup v11.2d, xzr 1521+ dup v12.2d, xzr 1522+ dup v13.2d, xzr 1523+ b Compute3x16Enter 1524+ InitFromBias3x16: 1525+ ld1 {v8.8h, v9.8h}, [x12] 1526+ ld1 {v10.8h, v11.8h}, [x12] 1527+ ld1 {v12.8h, v13.8h}, [x12] 1528+ add x12, x12, #32 1529+ Compute3x16Enter: 1530+ bl Compute3x16Unit 1531+ Activation3x16: 1532+ cmp x4, #3 1533+ beq Relu63x16 1534+ cmp x4, #1 1535+ beq Relu3x16 1536+ b Write3x16 1537+ 1538+ Relu63x16: 1539+ fmin v8.8h, v8.8h, v7.8h 1540+ fmin v9.8h, v9.8h, v7.8h 1541+ fmin v10.8h, v10.8h, v7.8h 1542+ fmin v11.8h, v11.8h, v7.8h 1543+ fmin v12.8h, v12.8h, v7.8h 1544+ fmin v13.8h, v13.8h, v7.8h 1545+ 1546+ Relu3x16: 1547+ dup v6.8h, wzr 1548+ fmax v8.8h, v8.8h, v6.8h 1549+ fmax v9.8h, v9.8h, v6.8h 1550+ fmax v10.8h, v10.8h, v6.8h 1551+ fmax v11.8h, v11.8h, v6.8h 1552+ fmax v12.8h, v12.8h, v6.8h 1553+ fmax v13.8h, v13.8h, 
v6.8h 1554+ Write3x16: 1555+ mov x22, x21 1556+ st1 {v8.8h, v9.8h}, [x22], x8 1557+ st1 {v10.8h, v11.8h}, [x22], x8 1558+ st1 {v12.8h, v13.8h}, [x22] 1559+ add x21, x21, #32 1560+ subs x13, x13, #16 1561+ bge LoopCol3x16 1562+ 1563+ LoopCol3x8: 1564+ adds x13, x13, #16 1565+ cbz x13, End 1566+ subs x13, x13, #8 1567+ blt LoopCol3x4 1568+ mov x10, x0 // update matrixA 1569+ mov x14, x5 // reload depth 1570+ cbnz x12, InitFromBias3x8 1571+ dup v8.2d, xzr 1572+ dup v10.2d, xzr 1573+ dup v12.2d, xzr 1574+ b Compute3x8Enter 1575+ InitFromBias3x8: 1576+ ld1 {v8.8h}, [x12] 1577+ ld1 {v10.8h}, [x12] 1578+ ld1 {v12.8h}, [x12] 1579+ add x12, x12, #16 1580+ Compute3x8Enter: 1581+ bl Compute3x8Unit 1582+ Activation3x8: 1583+ cmp x4, #3 1584+ beq Relu63x8 1585+ cmp x4, #1 1586+ beq Relu3x8 1587+ b Write3x8 1588+ 1589+ Relu63x8: 1590+ fmin v8.8h, v8.8h, v7.8h 1591+ fmin v10.8h, v10.8h, v7.8h 1592+ fmin v12.8h, v12.8h, v7.8h 1593+ 1594+ Relu3x8: 1595+ dup v6.8h, wzr 1596+ fmax v8.8h, v8.8h, v6.8h 1597+ fmax v10.8h, v10.8h, v6.8h 1598+ fmax v12.8h, v12.8h, v6.8h 1599+ Write3x8: 1600+ mov x22, x21 1601+ st1 {v8.8h}, [x22], x8 1602+ st1 {v10.8h}, [x22], x8 1603+ st1 {v12.8h}, [x22] 1604+ add x21, x21, #16 1605+ subs x13, x13, #8 1606+ 1607+ LoopCol3x4: 1608+ adds x13, x13, #8 1609+ cbz x13, End 1610+ LoopCol3x4Core: 1611+ mov x10, x0 // update matrixA 1612+ mov x14, x5 // reload depth 1613+ cbnz x12, InitFromBias3x4 1614+ dup v8.2s, wzr 1615+ dup v10.2s, wzr 1616+ dup v12.2s, wzr 1617+ b Compute3x4Enter 1618+ InitFromBias3x4: 1619+ ld1 {v8.4h}, [x12] 1620+ ld1 {v10.4h}, [x12] 1621+ ld1 {v12.4h}, [x12] 1622+ add x12, x12, #8 1623+ Compute3x4Enter: 1624+ bl Compute3x4Unit 1625+ Activation3x4: 1626+ cmp x4, #3 1627+ beq Relu63x4 1628+ cmp x4, #1 1629+ beq Relu3x4 1630+ b Write3x4 1631+ 1632+ Relu63x4: 1633+ fmin v8.4h, v8.4h, v7.4h 1634+ fmin v10.4h, v10.4h, v7.4h 1635+ fmin v12.4h, v12.4h, v7.4h 1636+ 1637+ Relu3x4: 1638+ dup v6.4h, wzr 1639+ fmax v8.4h, v8.4h, v6.4h 1640+ fmax 
v10.4h, v10.4h, v6.4h 1641+ fmax v12.4h, v12.4h, v6.4h 1642+ Write3x4: 1643+ mov x22, x21 1644+ cmp x13, #1 1645+ beq Write3x1 1646+ cmp x13, #2 1647+ beq Write3x2 1648+ cmp x13, #3 1649+ beq Write3x3 1650+ st1 {v8.4h}, [x22], x8 1651+ st1 {v10.4h}, [x22], x8 1652+ st1 {v12.4h}, [x22] 1653+ add x21, x21, #8 1654+ subs x13, x13, #4 1655+ bgt LoopCol3x4Core 1656+ b End 1657+ Write3x1: 1658+ st1 {v8.h}[0], [x22], x8 1659+ st1 {v10.h}[0], [x22], x8 1660+ st1 {v12.h}[0], [x22] 1661+ b End 1662+ Write3x2: 1663+ st1 {v8.s}[0], [x22], x8 1664+ st1 {v10.s}[0], [x22], x8 1665+ st1 {v12.s}[0], [x22] 1666+ b End 1667+ Write3x3: 1668+ add x23, x22, #4 1669+ st1 {v8.s}[0], [x22], x8 1670+ st1 {v8.h}[2], [x23], x8 1671+ st1 {v10.s}[0], [x22], x8 1672+ st1 {v10.h}[2], [x23], x8 1673+ st1 {v12.s}[0], [x22], x8 1674+ st1 {v12.h}[2], [x23], x8 1675+ b End 1676+ 1677+LoopRow2: 1678+ mov x11, x1 // reload matrixB 1679+ mov x12, x3 // reload bias 1680+ mov x13, x7 // reload col 1681+ mov x21, x2 // relocate output 1682+ subs x13, x13, #16 1683+ blt LoopCol2x8 1684+ LoopCol2x16: 1685+ mov x10, x0 // update matrixA 1686+ mov x14, x5 // reload depth 1687+ cbnz x12, InitFromBias2x16 1688+ dup v8.2d, xzr 1689+ dup v9.2d, xzr 1690+ dup v10.2d, xzr 1691+ dup v11.2d, xzr 1692+ b Compute2x16Enter 1693+ InitFromBias2x16: 1694+ ld1 {v8.8h, v9.8h}, [x12] 1695+ ld1 {v10.8h, v11.8h}, [x12] 1696+ add x12, x12, #32 1697+ Compute2x16Enter: 1698+ bl Compute2x16Unit 1699+ Activation2x16: 1700+ cmp x4, #3 1701+ beq Relu62x16 1702+ cmp x4, #1 1703+ beq Relu2x16 1704+ b Write2x16 1705+ 1706+ Relu62x16: 1707+ fmin v8.8h, v8.8h, v7.8h 1708+ fmin v9.8h, v9.8h, v7.8h 1709+ fmin v10.8h, v10.8h, v7.8h 1710+ fmin v11.8h, v11.8h, v7.8h 1711+ 1712+ Relu2x16: 1713+ dup v6.8h, wzr 1714+ fmax v8.8h, v8.8h, v6.8h 1715+ fmax v9.8h, v9.8h, v6.8h 1716+ fmax v10.8h, v10.8h, v6.8h 1717+ fmax v11.8h, v11.8h, v6.8h 1718+ Write2x16: 1719+ mov x22, x21 1720+ st1 {v8.8h, v9.8h}, [x22], x8 1721+ st1 {v10.8h, v11.8h}, [x22] 1722+ 
add x21, x21, #32 1723+ subs x13, x13, #16 1724+ bge LoopCol2x16 1725+ 1726+ LoopCol2x8: 1727+ adds x13, x13, #16 1728+ cbz x13, End 1729+ subs x13, x13, #8 1730+ blt LoopCol2x4 1731+ mov x10, x0 // update matrixA 1732+ mov x14, x5 // reload depth 1733+ cbnz x12, InitFromBias2x8 1734+ dup v8.2d, xzr 1735+ dup v10.2d, xzr 1736+ b Compute2x8Enter 1737+ InitFromBias2x8: 1738+ ld1 {v8.8h}, [x12] 1739+ ld1 {v10.8h}, [x12] 1740+ add x12, x12, #16 1741+ Compute2x8Enter: 1742+ bl Compute2x8Unit 1743+ Activation2x8: 1744+ cmp x4, #3 1745+ beq Relu62x8 1746+ cmp x4, #1 1747+ beq Relu2x8 1748+ b Write2x8 1749+ 1750+ Relu62x8: 1751+ fmin v8.8h, v8.8h, v7.8h 1752+ fmin v10.8h, v10.8h, v7.8h 1753+ 1754+ Relu2x8: 1755+ dup v6.8h, wzr 1756+ fmax v8.8h, v8.8h, v6.8h 1757+ fmax v10.8h, v10.8h, v6.8h 1758+ Write2x8: 1759+ mov x22, x21 1760+ st1 {v8.8h}, [x22], x8 1761+ st1 {v10.8h}, [x22] 1762+ add x21, x21, #16 1763+ subs x13, x13, #8 1764+ 1765+ LoopCol2x4: 1766+ adds x13, x13, #8 1767+ cbz x13, End 1768+ LoopCol2x4Core: 1769+ mov x10, x0 // update matrixA 1770+ mov x14, x5 // reload depth 1771+ cbnz x12, InitFromBias2x4 1772+ dup v8.2s, wzr 1773+ dup v10.2s, wzr 1774+ b Compute2x4Enter 1775+ InitFromBias2x4: 1776+ ld1 {v8.4h}, [x12] 1777+ ld1 {v10.4h}, [x12] 1778+ add x12, x12, #8 1779+ Compute2x4Enter: 1780+ bl Compute2x4Unit 1781+ Activation2x4: 1782+ cmp x4, #3 1783+ beq Relu62x4 1784+ cmp x4, #1 1785+ beq Relu2x4 1786+ b Write2x4 1787+ 1788+ Relu62x4: 1789+ fmin v8.4h, v8.4h, v7.4h 1790+ fmin v10.4h, v10.4h, v7.4h 1791+ Relu2x4: 1792+ dup v6.4h, wzr 1793+ fmax v8.4h, v8.4h, v6.4h 1794+ fmax v10.4h, v10.4h, v6.4h 1795+ Write2x4: 1796+ mov x22, x21 1797+ cmp x13, #1 1798+ beq Write2x1 1799+ cmp x13, #2 1800+ beq Write2x2 1801+ cmp x13, #3 1802+ beq Write2x3 1803+ st1 {v8.4h}, [x22], x8 1804+ st1 {v10.4h}, [x22] 1805+ add x21, x21, #8 1806+ subs x13, x13, #4 1807+ bgt LoopCol2x4Core 1808+ b End 1809+ Write2x1: 1810+ st1 {v8.h}[0], [x22], x8 1811+ st1 {v10.h}[0], [x22] 1812+ b End 
1813+ Write2x2: 1814+ st1 {v8.s}[0], [x22], x8 1815+ st1 {v10.s}[0], [x22] 1816+ b End 1817+ Write2x3: 1818+ add x23, x22, #4 1819+ st1 {v8.s}[0], [x22], x8 1820+ st1 {v8.h}[2], [x23], x8 1821+ st1 {v10.s}[0], [x22], x8 1822+ st1 {v10.h}[2], [x23], x8 1823+ b End 1824+ 1825+LoopRow1: 1826+ mov x11, x1 // reload matrixB 1827+ mov x12, x3 // reload bias 1828+ mov x13, x7 // reload col 1829+ mov x21, x2 // relocate output 1830+ subs x13, x13, #16 1831+ blt LoopCol1x8 1832+ LoopCol1x16: 1833+ mov x10, x0 // update matrixA 1834+ mov x14, x5 // reload depth 1835+ cbnz x12, InitFromBias1x16 1836+ dup v8.2d, xzr 1837+ dup v9.2d, xzr 1838+ b Compute1x16Enter 1839+ InitFromBias1x16: 1840+ ld1 {v8.8h, v9.8h}, [x12], #32 1841+ Compute1x16Enter: 1842+ bl Compute1x16Unit 1843+ Activation1x16: 1844+ cmp x4, #3 1845+ beq Relu61x16 1846+ cmp x4, #1 1847+ beq Relu1x16 1848+ b Write1x16 1849+ 1850+ Relu61x16: 1851+ fmin v8.8h, v8.8h, v7.8h 1852+ fmin v9.8h, v9.8h, v7.8h 1853+ 1854+ Relu1x16: 1855+ dup v6.8h, wzr 1856+ fmax v8.8h, v8.8h, v6.8h 1857+ fmax v9.8h, v9.8h, v6.8h 1858+ Write1x16: 1859+ st1 {v8.8h, v9.8h}, [x21], #32 1860+ subs x13, x13, #16 1861+ bge LoopCol1x16 1862+ 1863+ LoopCol1x8: 1864+ adds x13, x13, #16 1865+ cbz x13, End 1866+ subs x13, x13, #8 1867+ blt LoopCol1x4 1868+ mov x10, x0 // update matrixA 1869+ mov x14, x5 // reload depth 1870+ cbnz x12, InitFromBias1x8 1871+ dup v8.2d, xzr 1872+ b Compute1x8Enter 1873+ InitFromBias1x8: 1874+ ld1 {v8.8h}, [x12], #16 1875+ Compute1x8Enter: 1876+ bl Compute1x8Unit 1877+ Activation1x8: 1878+ cmp x4, #3 1879+ beq Relu61x8 1880+ cmp x4, #1 1881+ beq Relu1x8 1882+ b Write1x8 1883+ 1884+ Relu61x8: 1885+ fmin v8.8h, v8.8h, v7.8h 1886+ 1887+ Relu1x8: 1888+ dup v6.8h, wzr 1889+ fmax v8.8h, v8.8h, v6.8h 1890+ Write1x8: 1891+ st1 {v8.8h}, [x21], #16 1892+ subs x13, x13, #8 1893+ 1894+ LoopCol1x4: 1895+ adds x13, x13, #8 1896+ cbz x13, End 1897+ LoopCol1x4Core: 1898+ mov x10, x0 // update matrixA 1899+ mov x14, x5 // reload depth 
1900+ cbnz x12, InitFromBias1x4 1901+ dup v8.2s, wzr 1902+ b Compute1x4Enter 1903+ InitFromBias1x4: 1904+ ld1 {v8.4h}, [x12], #8 1905+ Compute1x4Enter: 1906+ bl Compute1x4Unit 1907+ Activation1x4: 1908+ cmp x4, #3 1909+ beq Relu61x4 1910+ cmp x4, #1 1911+ beq Relu1x4 1912+ b Write1x4 1913+ 1914+ Relu61x4: 1915+ fmin v8.4h, v8.4h, v7.4h 1916+ Relu1x4: 1917+ dup v6.4h, wzr 1918+ fmax v8.4h, v8.4h, v6.4h 1919+ Write1x4: 1920+ cmp x13, #1 1921+ beq Write1x1 1922+ cmp x13, #2 1923+ beq Write1x2 1924+ cmp x13, #3 1925+ beq Write1x3 1926+ st1 {v8.4h}, [x21], #8 1927+ subs x13, x13, #4 1928+ bgt LoopCol1x4Core 1929+ b End 1930+ Write1x1: 1931+ st1 {v8.h}[0], [x21] 1932+ b End 1933+ Write1x2: 1934+ st1 {v8.s}[0], [x21] 1935+ b End 1936+ Write1x3: 1937+ add x22, x21, #4 1938+ st1 {v8.s}[0], [x21] 1939+ st1 {v8.h}[2], [x22] 1940+ b End 1941+ 1942+Compute12x16Unit: 1943+ subs x14, x14, #2 1944+ ble Compute12x16End 1945+ Compute12x16: 1946+ prfm pldl1keep, [x10, #632] 1947+ ld1 {v1.8h, v2.8h}, [x10], #32 1948+ ld1 {v4.8h, v5.8h}, [x11], #32 1949+ fmla v8.8h, v3.8h, v0.h[0] 1950+ fmla v10.8h, v3.8h, v0.h[1] 1951+ fmla v12.8h, v3.8h, v0.h[2] 1952+ fmla v14.8h, v3.8h, v0.h[3] 1953+ fmla v16.8h, v3.8h, v0.h[4] 1954+ fmla v18.8h, v3.8h, v0.h[5] 1955+ fmla v20.8h, v3.8h, v0.h[6] 1956+ fmla v22.8h, v3.8h, v0.h[7] 1957+ fmla v24.8h, v3.8h, v1.h[0] 1958+ fmla v26.8h, v3.8h, v1.h[1] 1959+ fmla v28.8h, v3.8h, v1.h[2] 1960+ fmla v30.8h, v3.8h, v1.h[3] 1961+ prfm pldl1strm, [x11, #632] 1962+ ld1 {v6.8h}, [x11], #16 1963+ fmla v9.8h, v4.8h, v0.h[0] 1964+ fmla v11.8h, v4.8h, v0.h[1] 1965+ fmla v13.8h, v4.8h, v0.h[2] 1966+ fmla v15.8h, v4.8h, v0.h[3] 1967+ fmla v17.8h, v4.8h, v0.h[4] 1968+ fmla v19.8h, v4.8h, v0.h[5] 1969+ fmla v21.8h, v4.8h, v0.h[6] 1970+ fmla v23.8h, v4.8h, v0.h[7] 1971+ fmla v25.8h, v4.8h, v1.h[0] 1972+ fmla v27.8h, v4.8h, v1.h[1] 1973+ fmla v29.8h, v4.8h, v1.h[2] 1974+ fmla v31.8h, v4.8h, v1.h[3] 1975+ 1976+ fmla v8.8h, v5.8h, v1.h[4] 1977+ fmla v10.8h, v5.8h, v1.h[5] 
1978+ fmla v12.8h, v5.8h, v1.h[6] 1979+ fmla v14.8h, v5.8h, v1.h[7] 1980+ fmla v16.8h, v5.8h, v2.h[0] 1981+ fmla v18.8h, v5.8h, v2.h[1] 1982+ fmla v20.8h, v5.8h, v2.h[2] 1983+ fmla v22.8h, v5.8h, v2.h[3] 1984+ fmla v24.8h, v5.8h, v2.h[4] 1985+ fmla v26.8h, v5.8h, v2.h[5] 1986+ fmla v28.8h, v5.8h, v2.h[6] 1987+ fmla v30.8h, v5.8h, v2.h[7] 1988+ prfm pldl1strm, [x11, #632] 1989+ ld1 {v3.8h}, [x11], #16 1990+ fmla v9.8h, v6.8h, v1.h[4] 1991+ fmla v11.8h, v6.8h, v1.h[5] 1992+ fmla v13.8h, v6.8h, v1.h[6] 1993+ fmla v15.8h, v6.8h, v1.h[7] 1994+ prfm pldl1keep, [x10, #632] 1995+ ld1 {v0.8h}, [x10], #16 1996+ fmla v17.8h, v6.8h, v2.h[0] 1997+ fmla v19.8h, v6.8h, v2.h[1] 1998+ fmla v21.8h, v6.8h, v2.h[2] 1999+ fmla v23.8h, v6.8h, v2.h[3] 2000+ fmla v25.8h, v6.8h, v2.h[4] 2001+ fmla v27.8h, v6.8h, v2.h[5] 2002+ fmla v29.8h, v6.8h, v2.h[6] 2003+ fmla v31.8h, v6.8h, v2.h[7] 2004+ 2005+ subs x14, x14, #2 2006+ bgt Compute12x16 2007+ Compute12x16End: 2008+ cbnz x14, Compute12x16End1 2009+ prfm pldl1keep, [x10, #632] 2010+ ld1 {v1.4h}, [x10], #8 2011+ ld1 {v4.8h}, [x11], #16 2012+ fmla v8.8h, v3.8h, v0.h[0] 2013+ fmla v10.8h, v3.8h, v0.h[1] 2014+ fmla v12.8h, v3.8h, v0.h[2] 2015+ fmla v14.8h, v3.8h, v0.h[3] 2016+ fmla v16.8h, v3.8h, v0.h[4] 2017+ fmla v18.8h, v3.8h, v0.h[5] 2018+ fmla v20.8h, v3.8h, v0.h[6] 2019+ fmla v22.8h, v3.8h, v0.h[7] 2020+ fmla v24.8h, v3.8h, v1.h[0] 2021+ fmla v26.8h, v3.8h, v1.h[1] 2022+ fmla v28.8h, v3.8h, v1.h[2] 2023+ fmla v30.8h, v3.8h, v1.h[3] 2024+ prfm pldl1strm, [x11, #632] 2025+ ld1 {v3.8h}, [x11], #16 2026+ fmla v9.8h, v4.8h, v0.h[0] 2027+ fmla v11.8h, v4.8h, v0.h[1] 2028+ fmla v13.8h, v4.8h, v0.h[2] 2029+ fmla v15.8h, v4.8h, v0.h[3] 2030+ ld1 {v2.8h}, [x10], #16 2031+ fmla v17.8h, v4.8h, v0.h[4] 2032+ fmla v19.8h, v4.8h, v0.h[5] 2033+ fmla v21.8h, v4.8h, v0.h[6] 2034+ fmla v23.8h, v4.8h, v0.h[7] 2035+ fmla v25.8h, v4.8h, v1.h[0] 2036+ fmla v27.8h, v4.8h, v1.h[1] 2037+ fmla v29.8h, v4.8h, v1.h[2] 2038+ fmla v31.8h, v4.8h, v1.h[3] 2039+ mov 
v0.16b, v2.16b 2040+ Compute12x16End1: 2041+ ld1 {v1.4h}, [x10] 2042+ ld1 {v4.8h}, [x11], #16 2043+ fmla v8.8h, v3.8h, v0.h[0] 2044+ fmla v10.8h, v3.8h, v0.h[1] 2045+ fmla v12.8h, v3.8h, v0.h[2] 2046+ fmla v14.8h, v3.8h, v0.h[3] 2047+ fmla v16.8h, v3.8h, v0.h[4] 2048+ fmla v18.8h, v3.8h, v0.h[5] 2049+ fmla v20.8h, v3.8h, v0.h[6] 2050+ fmla v22.8h, v3.8h, v0.h[7] 2051+ fmla v24.8h, v3.8h, v1.h[0] 2052+ fmla v26.8h, v3.8h, v1.h[1] 2053+ fmla v28.8h, v3.8h, v1.h[2] 2054+ fmla v30.8h, v3.8h, v1.h[3] 2055+ fmla v9.8h, v4.8h, v0.h[0] 2056+ fmla v11.8h, v4.8h, v0.h[1] 2057+ fmla v13.8h, v4.8h, v0.h[2] 2058+ fmla v15.8h, v4.8h, v0.h[3] 2059+ fmla v17.8h, v4.8h, v0.h[4] 2060+ fmla v19.8h, v4.8h, v0.h[5] 2061+ fmla v21.8h, v4.8h, v0.h[6] 2062+ fmla v23.8h, v4.8h, v0.h[7] 2063+ fmla v25.8h, v4.8h, v1.h[0] 2064+ fmla v27.8h, v4.8h, v1.h[1] 2065+ fmla v29.8h, v4.8h, v1.h[2] 2066+ fmla v31.8h, v4.8h, v1.h[3] 2067+ ret 2068+ 2069+Compute12x8Unit: 2070+ subs x14, x14, #2 2071+ ble Compute12x8End 2072+ Compute12x8: 2073+ prfm pldl1keep, [x10, #632] 2074+ ld1 {v1.8h, v2.8h}, [x10], #32 2075+ ld1 {v4.8h}, [x11], #16 2076+ fmla v8.8h, v3.8h, v0.h[0] 2077+ fmla v10.8h, v3.8h, v0.h[1] 2078+ fmla v12.8h, v3.8h, v0.h[2] 2079+ fmla v14.8h, v3.8h, v0.h[3] 2080+ fmla v16.8h, v3.8h, v0.h[4] 2081+ fmla v18.8h, v3.8h, v0.h[5] 2082+ fmla v20.8h, v3.8h, v0.h[6] 2083+ fmla v22.8h, v3.8h, v0.h[7] 2084+ fmla v24.8h, v3.8h, v1.h[0] 2085+ fmla v26.8h, v3.8h, v1.h[1] 2086+ fmla v28.8h, v3.8h, v1.h[2] 2087+ fmla v30.8h, v3.8h, v1.h[3] 2088+ prfm pldl1strm, [x11, #632] 2089+ ld1 {v3.8h}, [x11], #16 2090+ fmla v8.8h, v4.8h, v1.h[4] 2091+ fmla v10.8h, v4.8h, v1.h[5] 2092+ fmla v12.8h, v4.8h, v1.h[6] 2093+ fmla v14.8h, v4.8h, v1.h[7] 2094+ ld1 {v0.8h}, [x10], #16 2095+ fmla v16.8h, v4.8h, v2.h[0] 2096+ fmla v18.8h, v4.8h, v2.h[1] 2097+ fmla v20.8h, v4.8h, v2.h[2] 2098+ fmla v22.8h, v4.8h, v2.h[3] 2099+ fmla v24.8h, v4.8h, v2.h[4] 2100+ fmla v26.8h, v4.8h, v2.h[5] 2101+ fmla v28.8h, v4.8h, v2.h[6] 2102+ fmla 
v30.8h, v4.8h, v2.h[7] 2103+ 2104+ subs x14, x14, #2 2105+ bgt Compute12x8 2106+ Compute12x8End: 2107+ cbnz x14, Compute12x8End1 2108+ prfm pldl1keep, [x10, #632] 2109+ ld1 {v1.4h}, [x10], #8 2110+ ld1 {v4.8h}, [x11], #16 2111+ fmla v8.8h, v3.8h, v0.h[0] 2112+ fmla v10.8h, v3.8h, v0.h[1] 2113+ fmla v12.8h, v3.8h, v0.h[2] 2114+ fmla v14.8h, v3.8h, v0.h[3] 2115+ fmla v16.8h, v3.8h, v0.h[4] 2116+ fmla v18.8h, v3.8h, v0.h[5] 2117+ fmla v20.8h, v3.8h, v0.h[6] 2118+ fmla v22.8h, v3.8h, v0.h[7] 2119+ fmla v24.8h, v3.8h, v1.h[0] 2120+ fmla v26.8h, v3.8h, v1.h[1] 2121+ fmla v28.8h, v3.8h, v1.h[2] 2122+ fmla v30.8h, v3.8h, v1.h[3] 2123+ ld1 {v0.8h}, [x10], #16 2124+ mov v3.16b, v4.16b 2125+ Compute12x8End1: 2126+ ld1 {v1.4h}, [x10] 2127+ fmla v8.8h, v3.8h, v0.h[0] 2128+ fmla v10.8h, v3.8h, v0.h[1] 2129+ fmla v12.8h, v3.8h, v0.h[2] 2130+ fmla v14.8h, v3.8h, v0.h[3] 2131+ fmla v16.8h, v3.8h, v0.h[4] 2132+ fmla v18.8h, v3.8h, v0.h[5] 2133+ fmla v20.8h, v3.8h, v0.h[6] 2134+ fmla v22.8h, v3.8h, v0.h[7] 2135+ fmla v24.8h, v3.8h, v1.h[0] 2136+ fmla v26.8h, v3.8h, v1.h[1] 2137+ fmla v28.8h, v3.8h, v1.h[2] 2138+ fmla v30.8h, v3.8h, v1.h[3] 2139+ ret 2140+ 2141+Compute12x4Unit: 2142+ subs x14, x14, #2 2143+ ble Compute12x4End 2144+ Compute12x4: 2145+ prfm pldl1keep, [x10, #632] 2146+ ld1 {v1.8h, v2.8h}, [x10], #32 2147+ ld1 {v4.4h}, [x11], #8 2148+ fmla v8.4h, v3.4h, v0.h[0] 2149+ fmla v10.4h, v3.4h, v0.h[1] 2150+ fmla v12.4h, v3.4h, v0.h[2] 2151+ fmla v14.4h, v3.4h, v0.h[3] 2152+ fmla v16.4h, v3.4h, v0.h[4] 2153+ fmla v18.4h, v3.4h, v0.h[5] 2154+ fmla v20.4h, v3.4h, v0.h[6] 2155+ fmla v22.4h, v3.4h, v0.h[7] 2156+ fmla v24.4h, v3.4h, v1.h[0] 2157+ fmla v26.4h, v3.4h, v1.h[1] 2158+ fmla v28.4h, v3.4h, v1.h[2] 2159+ fmla v30.4h, v3.4h, v1.h[3] 2160+ prfm pldl1strm, [x11, #632] 2161+ ld1 {v3.4h}, [x11], #8 2162+ fmla v8.4h, v4.4h, v1.h[4] 2163+ fmla v10.4h, v4.4h, v1.h[5] 2164+ fmla v12.4h, v4.4h, v1.h[6] 2165+ fmla v14.4h, v4.4h, v1.h[7] 2166+ ld1 {v0.8h}, [x10], #16 2167+ fmla v16.4h, 
v4.4h, v2.h[0] 2168+ fmla v18.4h, v4.4h, v2.h[1] 2169+ fmla v20.4h, v4.4h, v2.h[2] 2170+ fmla v22.4h, v4.4h, v2.h[3] 2171+ fmla v24.4h, v4.4h, v2.h[4] 2172+ fmla v26.4h, v4.4h, v2.h[5] 2173+ fmla v28.4h, v4.4h, v2.h[6] 2174+ fmla v30.4h, v4.4h, v2.h[7] 2175+ 2176+ subs x14, x14, #2 2177+ bgt Compute12x4 2178+ Compute12x4End: 2179+ cbnz x14, Compute12x4End1 2180+ prfm pldl1keep, [x10, #632] 2181+ ld1 {v1.4h}, [x10], #8 2182+ ld1 {v4.4h}, [x11], #8 2183+ fmla v8.4h, v3.4h, v0.h[0] 2184+ fmla v10.4h, v3.4h, v0.h[1] 2185+ fmla v12.4h, v3.4h, v0.h[2] 2186+ fmla v14.4h, v3.4h, v0.h[3] 2187+ fmla v16.4h, v3.4h, v0.h[4] 2188+ fmla v18.4h, v3.4h, v0.h[5] 2189+ fmla v20.4h, v3.4h, v0.h[6] 2190+ fmla v22.4h, v3.4h, v0.h[7] 2191+ fmla v24.4h, v3.4h, v1.h[0] 2192+ fmla v26.4h, v3.4h, v1.h[1] 2193+ fmla v28.4h, v3.4h, v1.h[2] 2194+ fmla v30.4h, v3.4h, v1.h[3] 2195+ ld1 {v0.8h}, [x10], #16 2196+ mov v3.8b, v4.8b 2197+ Compute12x4End1: 2198+ ld1 {v1.4h}, [x10] 2199+ fmla v8.4h, v3.4h, v0.h[0] 2200+ fmla v10.4h, v3.4h, v0.h[1] 2201+ fmla v12.4h, v3.4h, v0.h[2] 2202+ fmla v14.4h, v3.4h, v0.h[3] 2203+ fmla v16.4h, v3.4h, v0.h[4] 2204+ fmla v18.4h, v3.4h, v0.h[5] 2205+ fmla v20.4h, v3.4h, v0.h[6] 2206+ fmla v22.4h, v3.4h, v0.h[7] 2207+ fmla v24.4h, v3.4h, v1.h[0] 2208+ fmla v26.4h, v3.4h, v1.h[1] 2209+ fmla v28.4h, v3.4h, v1.h[2] 2210+ fmla v30.4h, v3.4h, v1.h[3] 2211+ ret 2212+ 2213+Compute8x16Unit: 2214+ subs x14, x14, #2 2215+ ble Compute8x16End 2216+ Compute8x16: 2217+ prfm pldl1keep, [x10, #632] 2218+ ld1 {v1.8h}, [x10], #16 2219+ ld1 {v4.8h, v5.8h}, [x11], #32 2220+ fmla v8.8h, v3.8h, v0.h[0] 2221+ fmla v10.8h, v3.8h, v0.h[1] 2222+ fmla v12.8h, v3.8h, v0.h[2] 2223+ fmla v14.8h, v3.8h, v0.h[3] 2224+ fmla v16.8h, v3.8h, v0.h[4] 2225+ fmla v18.8h, v3.8h, v0.h[5] 2226+ fmla v20.8h, v3.8h, v0.h[6] 2227+ fmla v22.8h, v3.8h, v0.h[7] 2228+ prfm pldl1strm, [x11, #632] 2229+ ld1 {v6.8h}, [x11], #16 2230+ fmla v9.8h, v4.8h, v0.h[0] 2231+ fmla v11.8h, v4.8h, v0.h[1] 2232+ fmla v13.8h, 
v4.8h, v0.h[2] 2233+ fmla v15.8h, v4.8h, v0.h[3] 2234+ fmla v17.8h, v4.8h, v0.h[4] 2235+ fmla v19.8h, v4.8h, v0.h[5] 2236+ fmla v21.8h, v4.8h, v0.h[6] 2237+ fmla v23.8h, v4.8h, v0.h[7] 2238+ 2239+ fmla v8.8h, v5.8h, v1.h[0] 2240+ fmla v10.8h, v5.8h, v1.h[1] 2241+ fmla v12.8h, v5.8h, v1.h[2] 2242+ fmla v14.8h, v5.8h, v1.h[3] 2243+ fmla v16.8h, v5.8h, v1.h[4] 2244+ fmla v18.8h, v5.8h, v1.h[5] 2245+ fmla v20.8h, v5.8h, v1.h[6] 2246+ fmla v22.8h, v5.8h, v1.h[7] 2247+ prfm pldl1strm, [x11, #632] 2248+ ld1 {v3.8h}, [x11], #16 2249+ fmla v9.8h, v6.8h, v1.h[0] 2250+ fmla v11.8h, v6.8h, v1.h[1] 2251+ fmla v13.8h, v6.8h, v1.h[2] 2252+ fmla v15.8h, v6.8h, v1.h[3] 2253+ prfm pldl1keep, [x10, #632] 2254+ ld1 {v0.8h}, [x10], #16 2255+ fmla v17.8h, v6.8h, v1.h[4] 2256+ fmla v19.8h, v6.8h, v1.h[5] 2257+ fmla v21.8h, v6.8h, v1.h[6] 2258+ fmla v23.8h, v6.8h, v1.h[7] 2259+ 2260+ subs x14, x14, #2 2261+ bgt Compute8x16 2262+ Compute8x16End: 2263+ cbnz x14, Compute8x16End1 2264+ prfm pldl1keep, [x10, #632] 2265+ ld1 {v1.8h}, [x10] 2266+ ld1 {v4.8h}, [x11], #16 2267+ fmla v8.8h, v3.8h, v0.h[0] 2268+ fmla v10.8h, v3.8h, v0.h[1] 2269+ fmla v12.8h, v3.8h, v0.h[2] 2270+ fmla v14.8h, v3.8h, v0.h[3] 2271+ fmla v16.8h, v3.8h, v0.h[4] 2272+ fmla v18.8h, v3.8h, v0.h[5] 2273+ fmla v20.8h, v3.8h, v0.h[6] 2274+ fmla v22.8h, v3.8h, v0.h[7] 2275+ prfm pldl1strm, [x11, #632] 2276+ ld1 {v3.8h}, [x11], #16 2277+ fmla v9.8h, v4.8h, v0.h[0] 2278+ fmla v11.8h, v4.8h, v0.h[1] 2279+ fmla v13.8h, v4.8h, v0.h[2] 2280+ fmla v15.8h, v4.8h, v0.h[3] 2281+ fmla v17.8h, v4.8h, v0.h[4] 2282+ fmla v19.8h, v4.8h, v0.h[5] 2283+ fmla v21.8h, v4.8h, v0.h[6] 2284+ fmla v23.8h, v4.8h, v0.h[7] 2285+ mov v0.16b, v1.16b 2286+ Compute8x16End1: 2287+ ld1 {v4.8h}, [x11], #16 2288+ fmla v8.8h, v3.8h, v0.h[0] 2289+ fmla v10.8h, v3.8h, v0.h[1] 2290+ fmla v12.8h, v3.8h, v0.h[2] 2291+ fmla v14.8h, v3.8h, v0.h[3] 2292+ fmla v16.8h, v3.8h, v0.h[4] 2293+ fmla v18.8h, v3.8h, v0.h[5] 2294+ fmla v20.8h, v3.8h, v0.h[6] 2295+ fmla v22.8h, 
v3.8h, v0.h[7] 2296+ fmla v9.8h, v4.8h, v0.h[0] 2297+ fmla v11.8h, v4.8h, v0.h[1] 2298+ fmla v13.8h, v4.8h, v0.h[2] 2299+ fmla v15.8h, v4.8h, v0.h[3] 2300+ fmla v17.8h, v4.8h, v0.h[4] 2301+ fmla v19.8h, v4.8h, v0.h[5] 2302+ fmla v21.8h, v4.8h, v0.h[6] 2303+ fmla v23.8h, v4.8h, v0.h[7] 2304+ ret 2305+ 2306+Compute8x8Unit: 2307+ subs x14, x14, #2 2308+ ble Compute8x8End 2309+ Compute8x8: 2310+ prfm pldl1keep, [x10, #632] 2311+ ld1 {v1.8h}, [x10], #16 2312+ ld1 {v4.8h}, [x11], #16 2313+ fmla v8.8h, v3.8h, v0.h[0] 2314+ fmla v10.8h, v3.8h, v0.h[1] 2315+ fmla v12.8h, v3.8h, v0.h[2] 2316+ fmla v14.8h, v3.8h, v0.h[3] 2317+ fmla v16.8h, v3.8h, v0.h[4] 2318+ fmla v18.8h, v3.8h, v0.h[5] 2319+ fmla v20.8h, v3.8h, v0.h[6] 2320+ fmla v22.8h, v3.8h, v0.h[7] 2321+ prfm pldl1strm, [x11, #632] 2322+ ld1 {v3.8h}, [x11], #16 2323+ fmla v8.8h, v4.8h, v1.h[0] 2324+ fmla v10.8h, v4.8h, v1.h[1] 2325+ fmla v12.8h, v4.8h, v1.h[2] 2326+ fmla v14.8h, v4.8h, v1.h[3] 2327+ ld1 {v0.8h}, [x10], #16 2328+ fmla v16.8h, v4.8h, v1.h[4] 2329+ fmla v18.8h, v4.8h, v1.h[5] 2330+ fmla v20.8h, v4.8h, v1.h[6] 2331+ fmla v22.8h, v4.8h, v1.h[7] 2332+ 2333+ subs x14, x14, #2 2334+ bgt Compute8x8 2335+ Compute8x8End: 2336+ cbnz x14, Compute8x8End1 2337+ prfm pldl1keep, [x10, #632] 2338+ ld1 {v1.8h}, [x10] 2339+ ld1 {v4.8h}, [x11], #16 2340+ fmla v8.8h, v3.8h, v0.h[0] 2341+ fmla v10.8h, v3.8h, v0.h[1] 2342+ fmla v12.8h, v3.8h, v0.h[2] 2343+ fmla v14.8h, v3.8h, v0.h[3] 2344+ fmla v16.8h, v3.8h, v0.h[4] 2345+ fmla v18.8h, v3.8h, v0.h[5] 2346+ fmla v20.8h, v3.8h, v0.h[6] 2347+ fmla v22.8h, v3.8h, v0.h[7] 2348+ mov v0.16b, v1.16b 2349+ mov v3.16b, v4.16b 2350+ Compute8x8End1: 2351+ fmla v8.8h, v3.8h, v0.h[0] 2352+ fmla v10.8h, v3.8h, v0.h[1] 2353+ fmla v12.8h, v3.8h, v0.h[2] 2354+ fmla v14.8h, v3.8h, v0.h[3] 2355+ fmla v16.8h, v3.8h, v0.h[4] 2356+ fmla v18.8h, v3.8h, v0.h[5] 2357+ fmla v20.8h, v3.8h, v0.h[6] 2358+ fmla v22.8h, v3.8h, v0.h[7] 2359+ ret 2360+ 2361+Compute8x4Unit: 2362+ subs x14, x14, #2 2363+ ble 
Compute8x4End 2364+ Compute8x4: 2365+ prfm pldl1keep, [x10, #632] 2366+ ld1 {v1.8h}, [x10], #16 2367+ ld1 {v4.4h}, [x11], #8 2368+ fmla v8.4h, v3.4h, v0.h[0] 2369+ fmla v10.4h, v3.4h, v0.h[1] 2370+ fmla v12.4h, v3.4h, v0.h[2] 2371+ fmla v14.4h, v3.4h, v0.h[3] 2372+ fmla v16.4h, v3.4h, v0.h[4] 2373+ fmla v18.4h, v3.4h, v0.h[5] 2374+ fmla v20.4h, v3.4h, v0.h[6] 2375+ fmla v22.4h, v3.4h, v0.h[7] 2376+ prfm pldl1strm, [x11, #632] 2377+ ld1 {v3.4h}, [x11], #8 2378+ fmla v8.4h, v4.4h, v1.h[0] 2379+ fmla v10.4h, v4.4h, v1.h[1] 2380+ fmla v12.4h, v4.4h, v1.h[2] 2381+ fmla v14.4h, v4.4h, v1.h[3] 2382+ ld1 {v0.8h}, [x10], #16 2383+ fmla v16.4h, v4.4h, v1.h[4] 2384+ fmla v18.4h, v4.4h, v1.h[5] 2385+ fmla v20.4h, v4.4h, v1.h[6] 2386+ fmla v22.4h, v4.4h, v1.h[7] 2387+ 2388+ subs x14, x14, #2 2389+ bgt Compute8x4 2390+ Compute8x4End: 2391+ cbnz x14, Compute8x4End1 2392+ prfm pldl1keep, [x10, #632] 2393+ ld1 {v1.8h}, [x10] 2394+ ld1 {v4.4h}, [x11], #8 2395+ fmla v8.4h, v3.4h, v0.h[0] 2396+ fmla v10.4h, v3.4h, v0.h[1] 2397+ fmla v12.4h, v3.4h, v0.h[2] 2398+ fmla v14.4h, v3.4h, v0.h[3] 2399+ fmla v16.4h, v3.4h, v0.h[4] 2400+ fmla v18.4h, v3.4h, v0.h[5] 2401+ fmla v20.4h, v3.4h, v0.h[6] 2402+ fmla v22.4h, v3.4h, v0.h[7] 2403+ mov v0.16b, v1.16b 2404+ mov v3.8b, v4.8b 2405+ Compute8x4End1: 2406+ fmla v8.4h, v3.4h, v0.h[0] 2407+ fmla v10.4h, v3.4h, v0.h[1] 2408+ fmla v12.4h, v3.4h, v0.h[2] 2409+ fmla v14.4h, v3.4h, v0.h[3] 2410+ fmla v16.4h, v3.4h, v0.h[4] 2411+ fmla v18.4h, v3.4h, v0.h[5] 2412+ fmla v20.4h, v3.4h, v0.h[6] 2413+ fmla v22.4h, v3.4h, v0.h[7] 2414+ ret 2415+ 2416+Compute4x16Unit: 2417+ subs x14, x14, #2 2418+ ble Compute4x16End 2419+ Compute4x16: 2420+ prfm pldl1keep, [x10, #632] 2421+ ld1 {v1.4h}, [x10], #8 2422+ ld1 {v4.8h, v5.8h}, [x11], #32 2423+ fmla v8.8h, v3.8h, v0.h[0] 2424+ fmla v10.8h, v3.8h, v0.h[1] 2425+ fmla v12.8h, v3.8h, v0.h[2] 2426+ fmla v14.8h, v3.8h, v0.h[3] 2427+ prfm pldl1strm, [x11, #632] 2428+ ld1 {v6.8h}, [x11], #16 2429+ fmla v9.8h, v4.8h, 
v0.h[0] 2430+ fmla v11.8h, v4.8h, v0.h[1] 2431+ fmla v13.8h, v4.8h, v0.h[2] 2432+ fmla v15.8h, v4.8h, v0.h[3] 2433+ 2434+ fmla v8.8h, v5.8h, v1.h[0] 2435+ fmla v10.8h, v5.8h, v1.h[1] 2436+ fmla v12.8h, v5.8h, v1.h[2] 2437+ fmla v14.8h, v5.8h, v1.h[3] 2438+ prfm pldl1strm, [x11, #632] 2439+ ld1 {v3.8h}, [x11], #16 2440+ fmla v9.8h, v6.8h, v1.h[0] 2441+ fmla v11.8h, v6.8h, v1.h[1] 2442+ fmla v13.8h, v6.8h, v1.h[2] 2443+ fmla v15.8h, v6.8h, v1.h[3] 2444+ ld1 {v0.4h}, [x10], #8 2445+ 2446+ subs x14, x14, #2 2447+ bgt Compute4x16 2448+ Compute4x16End: 2449+ cbnz x14, Compute4x16End1 2450+ prfm pldl1keep, [x10, #632] 2451+ ld1 {v1.4h}, [x10] 2452+ ld1 {v4.8h}, [x11], #16 2453+ fmla v8.8h, v3.8h, v0.h[0] 2454+ fmla v10.8h, v3.8h, v0.h[1] 2455+ fmla v12.8h, v3.8h, v0.h[2] 2456+ fmla v14.8h, v3.8h, v0.h[3] 2457+ prfm pldl1strm, [x11, #632] 2458+ ld1 {v3.8h}, [x11], #16 2459+ fmla v9.8h, v4.8h, v0.h[0] 2460+ fmla v11.8h, v4.8h, v0.h[1] 2461+ fmla v13.8h, v4.8h, v0.h[2] 2462+ fmla v15.8h, v4.8h, v0.h[3] 2463+ mov v0.8b, v1.8b 2464+ Compute4x16End1: 2465+ ld1 {v4.8h}, [x11], #16 2466+ fmla v8.8h, v3.8h, v0.h[0] 2467+ fmla v10.8h, v3.8h, v0.h[1] 2468+ fmla v12.8h, v3.8h, v0.h[2] 2469+ fmla v14.8h, v3.8h, v0.h[3] 2470+ fmla v9.8h, v4.8h, v0.h[0] 2471+ fmla v11.8h, v4.8h, v0.h[1] 2472+ fmla v13.8h, v4.8h, v0.h[2] 2473+ fmla v15.8h, v4.8h, v0.h[3] 2474+ ret 2475+ 2476+Compute4x8Unit: 2477+ subs x14, x14, #2 2478+ ble Compute4x8End 2479+ Compute4x8: 2480+ prfm pldl1keep, [x10, #632] 2481+ ld1 {v1.4h}, [x10], #8 2482+ ld1 {v4.8h}, [x11], #16 2483+ fmla v8.8h, v3.8h, v0.h[0] 2484+ fmla v10.8h, v3.8h, v0.h[1] 2485+ fmla v12.8h, v3.8h, v0.h[2] 2486+ fmla v14.8h, v3.8h, v0.h[3] 2487+ prfm pldl1strm, [x11, #632] 2488+ ld1 {v3.8h}, [x11], #16 2489+ fmla v8.8h, v4.8h, v1.h[0] 2490+ fmla v10.8h, v4.8h, v1.h[1] 2491+ fmla v12.8h, v4.8h, v1.h[2] 2492+ fmla v14.8h, v4.8h, v1.h[3] 2493+ ld1 {v0.4h}, [x10], #8 2494+ 2495+ subs x14, x14, #2 2496+ bgt Compute4x8 2497+ Compute4x8End: 2498+ cbnz 
x14, Compute4x8End1 2499+ prfm pldl1keep, [x10, #632] 2500+ ld1 {v1.4h}, [x10] 2501+ ld1 {v4.8h}, [x11], #16 2502+ fmla v8.8h, v3.8h, v0.h[0] 2503+ fmla v10.8h, v3.8h, v0.h[1] 2504+ fmla v12.8h, v3.8h, v0.h[2] 2505+ fmla v14.8h, v3.8h, v0.h[3] 2506+ mov v0.8b, v1.8b 2507+ mov v3.16b, v4.16b 2508+ Compute4x8End1: 2509+ fmla v8.8h, v3.8h, v0.h[0] 2510+ fmla v10.8h, v3.8h, v0.h[1] 2511+ fmla v12.8h, v3.8h, v0.h[2] 2512+ fmla v14.8h, v3.8h, v0.h[3] 2513+ ret 2514+ 2515+Compute4x4Unit: 2516+ subs x14, x14, #2 2517+ ble Compute4x4End 2518+ Compute4x4: 2519+ prfm pldl1keep, [x10, #632] 2520+ ld1 {v1.4h}, [x10], #8 2521+ ld1 {v4.4h}, [x11], #8 2522+ fmla v8.4h, v3.4h, v0.h[0] 2523+ fmla v10.4h, v3.4h, v0.h[1] 2524+ fmla v12.4h, v3.4h, v0.h[2] 2525+ fmla v14.4h, v3.4h, v0.h[3] 2526+ prfm pldl1strm, [x11, #632] 2527+ ld1 {v3.4h}, [x11], #8 2528+ fmla v8.4h, v4.4h, v1.h[0] 2529+ fmla v10.4h, v4.4h, v1.h[1] 2530+ fmla v12.4h, v4.4h, v1.h[2] 2531+ fmla v14.4h, v4.4h, v1.h[3] 2532+ ld1 {v0.4h}, [x10], #8 2533+ 2534+ subs x14, x14, #2 2535+ bgt Compute4x4 2536+ Compute4x4End: 2537+ cbnz x14, Compute4x4End1 2538+ prfm pldl1keep, [x10, #632] 2539+ ld1 {v1.4h}, [x10] 2540+ ld1 {v4.4h}, [x11], #8 2541+ fmla v8.4h, v3.4h, v0.h[0] 2542+ fmla v10.4h, v3.4h, v0.h[1] 2543+ fmla v12.4h, v3.4h, v0.h[2] 2544+ fmla v14.4h, v3.4h, v0.h[3] 2545+ mov v0.8b, v1.8b 2546+ mov v3.8b, v4.8b 2547+ Compute4x4End1: 2548+ fmla v8.4h, v3.4h, v0.h[0] 2549+ fmla v10.4h, v3.4h, v0.h[1] 2550+ fmla v12.4h, v3.4h, v0.h[2] 2551+ fmla v14.4h, v3.4h, v0.h[3] 2552+ ret 2553+ 2554+Compute3x16Unit: 2555+ add x19, x10, x16 2556+ add x20, x10, x16, lsl #1 2557+ subs x14, x14, #8 2558+ blt Compute3x16End4 2559+ Compute3x16: 2560+ ld1 {v0.8h}, [x10], #16 2561+ ld1 {v1.8h}, [x19], #16 2562+ ld1 {v2.8h}, [x20], #16 2563+ prfm pldl1strm, [x11, #632] 2564+ ld1 {v3.8h, v4.8h}, [x11], #32 2565+ fmla v8.8h, v3.8h, v0.h[0] 2566+ fmla v10.8h, v3.8h, v1.h[0] 2567+ fmla v12.8h, v3.8h, v2.h[0] 2568+ ld1 {v5.8h, v6.8h}, [x11], #32 
2569+ fmla v9.8h, v4.8h, v0.h[0] 2570+ fmla v11.8h, v4.8h, v1.h[0] 2571+ fmla v13.8h, v4.8h, v2.h[0] 2572+ fmla v8.8h, v5.8h, v0.h[1] 2573+ fmla v10.8h, v5.8h, v1.h[1] 2574+ fmla v12.8h, v5.8h, v2.h[1] 2575+ ld1 {v3.8h, v4.8h}, [x11], #32 2576+ fmla v9.8h, v6.8h, v0.h[1] 2577+ fmla v11.8h, v6.8h, v1.h[1] 2578+ fmla v13.8h, v6.8h, v2.h[1] 2579+ fmla v8.8h, v3.8h, v0.h[2] 2580+ fmla v10.8h, v3.8h, v1.h[2] 2581+ fmla v12.8h, v3.8h, v2.h[2] 2582+ ld1 {v5.8h, v6.8h}, [x11], #32 2583+ fmla v9.8h, v4.8h, v0.h[2] 2584+ fmla v11.8h, v4.8h, v1.h[2] 2585+ fmla v13.8h, v4.8h, v2.h[2] 2586+ fmla v8.8h, v5.8h, v0.h[3] 2587+ fmla v10.8h, v5.8h, v1.h[3] 2588+ fmla v12.8h, v5.8h, v2.h[3] 2589+ prfm pldl1strm, [x11, #632] 2590+ ld1 {v3.8h, v4.8h}, [x11], #32 2591+ fmla v9.8h, v6.8h, v0.h[3] 2592+ fmla v11.8h, v6.8h, v1.h[3] 2593+ fmla v13.8h, v6.8h, v2.h[3] 2594+ 2595+ fmla v8.8h, v3.8h, v0.h[4] 2596+ fmla v10.8h, v3.8h, v1.h[4] 2597+ fmla v12.8h, v3.8h, v2.h[4] 2598+ ld1 {v5.8h, v6.8h}, [x11], #32 2599+ fmla v9.8h, v4.8h, v0.h[4] 2600+ fmla v11.8h, v4.8h, v1.h[4] 2601+ fmla v13.8h, v4.8h, v2.h[4] 2602+ fmla v8.8h, v5.8h, v0.h[5] 2603+ fmla v10.8h, v5.8h, v1.h[5] 2604+ fmla v12.8h, v5.8h, v2.h[5] 2605+ ld1 {v3.8h, v4.8h}, [x11], #32 2606+ fmla v9.8h, v6.8h, v0.h[5] 2607+ fmla v11.8h, v6.8h, v1.h[5] 2608+ fmla v13.8h, v6.8h, v2.h[5] 2609+ fmla v8.8h, v3.8h, v0.h[6] 2610+ fmla v10.8h, v3.8h, v1.h[6] 2611+ fmla v12.8h, v3.8h, v2.h[6] 2612+ ld1 {v5.8h, v6.8h}, [x11], #32 2613+ fmla v9.8h, v4.8h, v0.h[6] 2614+ fmla v11.8h, v4.8h, v1.h[6] 2615+ fmla v13.8h, v4.8h, v2.h[6] 2616+ fmla v8.8h, v5.8h, v0.h[7] 2617+ fmla v10.8h, v5.8h, v1.h[7] 2618+ fmla v12.8h, v5.8h, v2.h[7] 2619+ fmla v9.8h, v6.8h, v0.h[7] 2620+ fmla v11.8h, v6.8h, v1.h[7] 2621+ fmla v13.8h, v6.8h, v2.h[7] 2622+ 2623+ subs x14, x14, #8 2624+ bge Compute3x16 2625+ Compute3x16End4: 2626+ adds x14, x14, #8 2627+ cbz x14, Compute3x16Return 2628+ subs x14, x14, #4 2629+ blt Compute3x16EndTail 2630+ ld1 {v0.4h}, [x10], #8 2631+ 
ld1 {v1.4h}, [x19], #8 2632+ ld1 {v2.4h}, [x20], #8 2633+ prfm pldl1strm, [x11, #632] 2634+ ld1 {v3.8h, v4.8h}, [x11], #32 2635+ fmla v8.8h, v3.8h, v0.h[0] 2636+ fmla v10.8h, v3.8h, v1.h[0] 2637+ fmla v12.8h, v3.8h, v2.h[0] 2638+ ld1 {v5.8h, v6.8h}, [x11], #32 2639+ fmla v9.8h, v4.8h, v0.h[0] 2640+ fmla v11.8h, v4.8h, v1.h[0] 2641+ fmla v13.8h, v4.8h, v2.h[0] 2642+ fmla v8.8h, v5.8h, v0.h[1] 2643+ fmla v10.8h, v5.8h, v1.h[1] 2644+ fmla v12.8h, v5.8h, v2.h[1] 2645+ ld1 {v3.8h, v4.8h}, [x11], #32 2646+ fmla v9.8h, v6.8h, v0.h[1] 2647+ fmla v11.8h, v6.8h, v1.h[1] 2648+ fmla v13.8h, v6.8h, v2.h[1] 2649+ fmla v8.8h, v3.8h, v0.h[2] 2650+ fmla v10.8h, v3.8h, v1.h[2] 2651+ fmla v12.8h, v3.8h, v2.h[2] 2652+ ld1 {v5.8h, v6.8h}, [x11], #32 2653+ fmla v9.8h, v4.8h, v0.h[2] 2654+ fmla v11.8h, v4.8h, v1.h[2] 2655+ fmla v13.8h, v4.8h, v2.h[2] 2656+ fmla v8.8h, v5.8h, v0.h[3] 2657+ fmla v10.8h, v5.8h, v1.h[3] 2658+ fmla v12.8h, v5.8h, v2.h[3] 2659+ fmla v9.8h, v6.8h, v0.h[3] 2660+ fmla v11.8h, v6.8h, v1.h[3] 2661+ fmla v13.8h, v6.8h, v2.h[3] 2662+ subs x14, x14, #4 2663+ Compute3x16EndTail: 2664+ adds x14, x14, #4 2665+ cbz x14, Compute3x16Return 2666+ cmp x14, #1 2667+ beq Compute3x16EndTail1 2668+ cmp x14, #2 2669+ beq Compute3x16EndTail2 2670+ ld1 {v0.4h}, [x10] 2671+ ld1 {v1.4h}, [x19] 2672+ ld1 {v2.s}[0], [x20], #4 2673+ ld1 {v2.h}[2], [x20] 2674+ prfm pldl1strm, [x11, #632] 2675+ ld1 {v3.8h, v4.8h}, [x11], #32 2676+ fmla v8.8h, v3.8h, v0.h[0] 2677+ fmla v10.8h, v3.8h, v1.h[0] 2678+ fmla v12.8h, v3.8h, v2.h[0] 2679+ ld1 {v5.8h, v6.8h}, [x11], #32 2680+ fmla v9.8h, v4.8h, v0.h[0] 2681+ fmla v11.8h, v4.8h, v1.h[0] 2682+ fmla v13.8h, v4.8h, v2.h[0] 2683+ fmla v8.8h, v5.8h, v0.h[1] 2684+ fmla v10.8h, v5.8h, v1.h[1] 2685+ fmla v12.8h, v5.8h, v2.h[1] 2686+ ld1 {v3.8h, v4.8h}, [x11], #32 2687+ fmla v9.8h, v6.8h, v0.h[1] 2688+ fmla v11.8h, v6.8h, v1.h[1] 2689+ fmla v13.8h, v6.8h, v2.h[1] 2690+ fmla v8.8h, v3.8h, v0.h[2] 2691+ fmla v10.8h, v3.8h, v1.h[2] 2692+ fmla v12.8h, v3.8h, 
v2.h[2] 2693+ fmla v9.8h, v4.8h, v0.h[2] 2694+ fmla v11.8h, v4.8h, v1.h[2] 2695+ fmla v13.8h, v4.8h, v2.h[2] 2696+ b Compute3x16Return 2697+ Compute3x16EndTail2: 2698+ ld1 {v0.4h}, [x10] 2699+ ld1 {v1.4h}, [x19] 2700+ ld1 {v2.s}[0], [x20] 2701+ prfm pldl1strm, [x11, #632] 2702+ ld1 {v3.8h, v4.8h}, [x11], #32 2703+ fmla v8.8h, v3.8h, v0.h[0] 2704+ fmla v10.8h, v3.8h, v1.h[0] 2705+ fmla v12.8h, v3.8h, v2.h[0] 2706+ ld1 {v5.8h, v6.8h}, [x11], #32 2707+ fmla v9.8h, v4.8h, v0.h[0] 2708+ fmla v11.8h, v4.8h, v1.h[0] 2709+ fmla v13.8h, v4.8h, v2.h[0] 2710+ fmla v8.8h, v5.8h, v0.h[1] 2711+ fmla v10.8h, v5.8h, v1.h[1] 2712+ fmla v12.8h, v5.8h, v2.h[1] 2713+ fmla v9.8h, v6.8h, v0.h[1] 2714+ fmla v11.8h, v6.8h, v1.h[1] 2715+ fmla v13.8h, v6.8h, v2.h[1] 2716+ b Compute3x16Return 2717+ Compute3x16EndTail1: 2718+ ld1 {v0.h}[0], [x10] 2719+ ld1 {v1.h}[0], [x19] 2720+ ld1 {v2.h}[0], [x20] 2721+ prfm pldl1strm, [x11, #632] 2722+ ld1 {v3.8h, v4.8h}, [x11], #32 2723+ fmla v8.8h, v3.8h, v0.h[0] 2724+ fmla v10.8h, v3.8h, v1.h[0] 2725+ fmla v12.8h, v3.8h, v2.h[0] 2726+ fmla v9.8h, v4.8h, v0.h[0] 2727+ fmla v11.8h, v4.8h, v1.h[0] 2728+ fmla v13.8h, v4.8h, v2.h[0] 2729+ Compute3x16Return: 2730+ ret 2731+ 2732+Compute3x8Unit: 2733+ add x19, x10, x16 2734+ add x20, x10, x16, lsl #1 2735+ subs x14, x14, #8 2736+ blt Compute3x8End4 2737+ Compute3x8: 2738+ ld1 {v0.8h}, [x10], #16 2739+ ld1 {v1.8h}, [x19], #16 2740+ ld1 {v2.8h}, [x20], #16 2741+ prfm pldl1strm, [x11, #632] 2742+ ld1 {v3.8h, v4.8h}, [x11], #32 2743+ fmla v8.8h, v3.8h, v0.h[0] 2744+ fmla v10.8h, v3.8h, v1.h[0] 2745+ fmla v12.8h, v3.8h, v2.h[0] 2746+ ld1 {v5.8h, v6.8h}, [x11], #32 2747+ fmla v8.8h, v4.8h, v0.h[1] 2748+ fmla v10.8h, v4.8h, v1.h[1] 2749+ fmla v12.8h, v4.8h, v2.h[1] 2750+ fmla v8.8h, v5.8h, v0.h[2] 2751+ fmla v10.8h, v5.8h, v1.h[2] 2752+ fmla v12.8h, v5.8h, v2.h[2] 2753+ prfm pldl1strm, [x11, #632] 2754+ ld1 {v3.8h, v4.8h}, [x11], #32 2755+ fmla v8.8h, v6.8h, v0.h[3] 2756+ fmla v10.8h, v6.8h, v1.h[3] 2757+ fmla 
v12.8h, v6.8h, v2.h[3] 2758+ fmla v8.8h, v3.8h, v0.h[4] 2759+ fmla v10.8h, v3.8h, v1.h[4] 2760+ fmla v12.8h, v3.8h, v2.h[4] 2761+ ld1 {v5.8h, v6.8h}, [x11], #32 2762+ fmla v8.8h, v4.8h, v0.h[5] 2763+ fmla v10.8h, v4.8h, v1.h[5] 2764+ fmla v12.8h, v4.8h, v2.h[5] 2765+ fmla v8.8h, v5.8h, v0.h[6] 2766+ fmla v10.8h, v5.8h, v1.h[6] 2767+ fmla v12.8h, v5.8h, v2.h[6] 2768+ fmla v8.8h, v6.8h, v0.h[7] 2769+ fmla v10.8h, v6.8h, v1.h[7] 2770+ fmla v12.8h, v6.8h, v2.h[7] 2771+ 2772+ subs x14, x14, #8 2773+ bge Compute3x8 2774+ Compute3x8End4: 2775+ adds x14, x14, #8 2776+ cbz x14, Compute3x8Return 2777+ subs x14, x14, #4 2778+ blt Compute3x8EndTail 2779+ ld1 {v0.4h}, [x10], #8 2780+ ld1 {v1.4h}, [x19], #8 2781+ ld1 {v2.4h}, [x20], #8 2782+ prfm pldl1strm, [x11, #632] 2783+ ld1 {v3.8h, v4.8h}, [x11], #32 2784+ fmla v8.8h, v3.8h, v0.h[0] 2785+ fmla v10.8h, v3.8h, v1.h[0] 2786+ fmla v12.8h, v3.8h, v2.h[0] 2787+ ld1 {v5.8h, v6.8h}, [x11], #32 2788+ fmla v8.8h, v4.8h, v0.h[1] 2789+ fmla v10.8h, v4.8h, v1.h[1] 2790+ fmla v12.8h, v4.8h, v2.h[1] 2791+ fmla v8.8h, v5.8h, v0.h[2] 2792+ fmla v10.8h, v5.8h, v1.h[2] 2793+ fmla v12.8h, v5.8h, v2.h[2] 2794+ fmla v8.8h, v6.8h, v0.h[3] 2795+ fmla v10.8h, v6.8h, v1.h[3] 2796+ fmla v12.8h, v6.8h, v2.h[3] 2797+ subs x14, x14, #4 2798+ Compute3x8EndTail: 2799+ adds x14, x14, #4 2800+ cbz x14, Compute3x8Return 2801+ cmp x14, #1 2802+ beq Compute3x8EndTail1 2803+ cmp x14, #2 2804+ beq Compute3x8EndTail2 2805+ ld1 {v0.4h}, [x10] 2806+ ld1 {v1.4h}, [x19] 2807+ ld1 {v2.s}[0], [x20], #4 2808+ ld1 {v2.h}[2], [x20] 2809+ prfm pldl1strm, [x11, #632] 2810+ ld1 {v3.8h, v4.8h}, [x11], #32 2811+ fmla v8.8h, v3.8h, v0.h[0] 2812+ fmla v10.8h, v3.8h, v1.h[0] 2813+ fmla v12.8h, v3.8h, v2.h[0] 2814+ ld1 {v5.8h}, [x11], #16 2815+ fmla v8.8h, v4.8h, v0.h[1] 2816+ fmla v10.8h, v4.8h, v1.h[1] 2817+ fmla v12.8h, v4.8h, v2.h[1] 2818+ fmla v8.8h, v5.8h, v0.h[2] 2819+ fmla v10.8h, v5.8h, v1.h[2] 2820+ fmla v12.8h, v5.8h, v2.h[2] 2821+ b Compute3x8Return 2822+ 
Compute3x8EndTail2: 2823+ ld1 {v0.4h}, [x10] 2824+ ld1 {v1.4h}, [x19] 2825+ ld2 {v2.h, v3.h}[0], [x20] 2826+ prfm pldl1strm, [x11, #632] 2827+ ld1 {v5.8h, v6.8h}, [x11], #32 2828+ fmla v8.8h, v5.8h, v0.h[0] 2829+ fmla v10.8h, v5.8h, v1.h[0] 2830+ fmla v12.8h, v5.8h, v2.h[0] 2831+ fmla v8.8h, v6.8h, v0.h[1] 2832+ fmla v10.8h, v6.8h, v1.h[1] 2833+ fmla v12.8h, v6.8h, v3.h[0] 2834+ b Compute3x8Return 2835+ Compute3x8EndTail1: 2836+ ld1 {v0.h}[0], [x10] 2837+ ld1 {v1.h}[0], [x19] 2838+ ld1 {v2.h}[0], [x20] 2839+ prfm pldl1strm, [x11, #632] 2840+ ld1 {v3.8h}, [x11], #16 2841+ fmla v8.8h, v3.8h, v0.h[0] 2842+ fmla v10.8h, v3.8h, v1.h[0] 2843+ fmla v12.8h, v3.8h, v2.h[0] 2844+ Compute3x8Return: 2845+ ret 2846+ 2847+Compute3x4Unit: 2848+ add x19, x10, x16 2849+ add x20, x10, x16, lsl #1 2850+ subs x14, x14, #8 2851+ blt Compute3x4End4 2852+ Compute3x4: 2853+ ld1 {v0.8h}, [x10], #16 2854+ ld1 {v1.8h}, [x19], #16 2855+ ld1 {v2.8h}, [x20], #16 2856+ prfm pldl1strm, [x11, #632] 2857+ ld1 {v3.4h, v4.4h}, [x11], #16 2858+ fmla v8.4h, v3.4h, v0.h[0] 2859+ fmla v10.4h, v3.4h, v1.h[0] 2860+ fmla v12.4h, v3.4h, v2.h[0] 2861+ ld1 {v5.4h, v6.4h}, [x11], #16 2862+ fmla v8.4h, v4.4h, v0.h[1] 2863+ fmla v10.4h, v4.4h, v1.h[1] 2864+ fmla v12.4h, v4.4h, v2.h[1] 2865+ fmla v8.4h, v5.4h, v0.h[2] 2866+ fmla v10.4h, v5.4h, v1.h[2] 2867+ fmla v12.4h, v5.4h, v2.h[2] 2868+ prfm pldl1strm, [x11, #632] 2869+ ld1 {v3.4h, v4.4h}, [x11], #16 2870+ fmla v8.4h, v6.4h, v0.h[3] 2871+ fmla v10.4h, v6.4h, v1.h[3] 2872+ fmla v12.4h, v6.4h, v2.h[3] 2873+ fmla v8.4h, v3.4h, v0.h[4] 2874+ fmla v10.4h, v3.4h, v1.h[4] 2875+ fmla v12.4h, v3.4h, v2.h[4] 2876+ ld1 {v5.4h, v6.4h}, [x11], #16 2877+ fmla v8.4h, v4.4h, v0.h[5] 2878+ fmla v10.4h, v4.4h, v1.h[5] 2879+ fmla v12.4h, v4.4h, v2.h[5] 2880+ fmla v8.4h, v5.4h, v0.h[6] 2881+ fmla v10.4h, v5.4h, v1.h[6] 2882+ fmla v12.4h, v5.4h, v2.h[6] 2883+ fmla v8.4h, v6.4h, v0.h[7] 2884+ fmla v10.4h, v6.4h, v1.h[7] 2885+ fmla v12.4h, v6.4h, v2.h[7] 2886+ 2887+ subs x14, x14, 
#8 2888+ bge Compute3x4 2889+ Compute3x4End4: 2890+ adds x14, x14, #8 2891+ cbz x14, Compute3x4Return 2892+ subs x14, x14, #4 2893+ blt Compute3x4EndTail 2894+ ld1 {v0.4h}, [x10], #8 2895+ ld1 {v1.4h}, [x19], #8 2896+ ld1 {v2.4h}, [x20], #8 2897+ prfm pldl1strm, [x11, #632] 2898+ ld1 {v3.4h, v4.4h}, [x11], #16 2899+ fmla v8.4h, v3.4h, v0.h[0] 2900+ fmla v10.4h, v3.4h, v1.h[0] 2901+ fmla v12.4h, v3.4h, v2.h[0] 2902+ ld1 {v5.4h, v6.4h}, [x11], #16 2903+ fmla v8.4h, v4.4h, v0.h[1] 2904+ fmla v10.4h, v4.4h, v1.h[1] 2905+ fmla v12.4h, v4.4h, v2.h[1] 2906+ fmla v8.4h, v5.4h, v0.h[2] 2907+ fmla v10.4h, v5.4h, v1.h[2] 2908+ fmla v12.4h, v5.4h, v2.h[2] 2909+ fmla v8.4h, v6.4h, v0.h[3] 2910+ fmla v10.4h, v6.4h, v1.h[3] 2911+ fmla v12.4h, v6.4h, v2.h[3] 2912+ subs x14, x14, #4 2913+ Compute3x4EndTail: 2914+ adds x14, x14, #4 2915+ cbz x14, Compute3x4Return 2916+ cmp x14, #1 2917+ beq Compute3x4EndTail1 2918+ cmp x14, #2 2919+ beq Compute3x4EndTail2 2920+ ld1 {v0.4h}, [x10] 2921+ ld1 {v1.4h}, [x19] 2922+ ld1 {v2.s}[0], [x20], #4 2923+ ld1 {v2.h}[2], [x20] 2924+ prfm pldl1strm, [x11, #632] 2925+ ld1 {v3.4h, v4.4h}, [x11], #16 2926+ fmla v8.4h, v3.4h, v0.h[0] 2927+ fmla v10.4h, v3.4h, v1.h[0] 2928+ fmla v12.4h, v3.4h, v2.h[0] 2929+ ld1 {v5.4h}, [x11], #8 2930+ fmla v8.4h, v4.4h, v0.h[1] 2931+ fmla v10.4h, v4.4h, v1.h[1] 2932+ fmla v12.4h, v4.4h, v2.h[1] 2933+ fmla v8.4h, v5.4h, v0.h[2] 2934+ fmla v10.4h, v5.4h, v1.h[2] 2935+ fmla v12.4h, v5.4h, v2.h[2] 2936+ b Compute3x4Return 2937+ Compute3x4EndTail2: 2938+ ld1 {v0.4h}, [x10] 2939+ ld1 {v1.4h}, [x19] 2940+ ld2 {v2.h, v3.h}[0], [x20] 2941+ prfm pldl1strm, [x11, #632] 2942+ ld1 {v5.4h, v6.4h}, [x11], #16 2943+ fmla v8.4h, v5.4h, v0.h[0] 2944+ fmla v10.4h, v5.4h, v1.h[0] 2945+ fmla v12.4h, v5.4h, v2.h[0] 2946+ fmla v8.4h, v6.4h, v0.h[1] 2947+ fmla v10.4h, v6.4h, v1.h[1] 2948+ fmla v12.4h, v6.4h, v3.h[0] 2949+ b Compute3x4Return 2950+ Compute3x4EndTail1: 2951+ ld1 {v0.h}[0], [x10] 2952+ ld1 {v1.h}[0], [x19] 2953+ ld1 {v2.h}[0], 
[x20] 2954+ prfm pldl1strm, [x11, #632] 2955+ ld1 {v3.4h}, [x11], #8 2956+ fmla v8.4h, v3.4h, v0.h[0] 2957+ fmla v10.4h, v3.4h, v1.h[0] 2958+ fmla v12.4h, v3.4h, v2.h[0] 2959+ Compute3x4Return: 2960+ ret 2961+ 2962+Compute2x16Unit: 2963+ add x19, x10, x16 2964+ subs x14, x14, #8 2965+ blt Compute2x16End4 2966+ Compute2x16: 2967+ ld1 {v0.8h}, [x10], #16 2968+ ld1 {v1.8h}, [x19], #16 2969+ prfm pldl1strm, [x11, #632] 2970+ ld1 {v3.8h, v4.8h}, [x11], #32 2971+ fmla v8.8h, v3.8h, v0.h[0] 2972+ fmla v10.8h, v3.8h, v1.h[0] 2973+ ld1 {v5.8h, v6.8h}, [x11], #32 2974+ fmla v9.8h, v4.8h, v0.h[0] 2975+ fmla v11.8h, v4.8h, v1.h[0] 2976+ fmla v8.8h, v5.8h, v0.h[1] 2977+ fmla v10.8h, v5.8h, v1.h[1] 2978+ ld1 {v3.8h, v4.8h}, [x11], #32 2979+ fmla v9.8h, v6.8h, v0.h[1] 2980+ fmla v11.8h, v6.8h, v1.h[1] 2981+ fmla v8.8h, v3.8h, v0.h[2] 2982+ fmla v10.8h, v3.8h, v1.h[2] 2983+ ld1 {v5.8h, v6.8h}, [x11], #32 2984+ fmla v9.8h, v4.8h, v0.h[2] 2985+ fmla v11.8h, v4.8h, v1.h[2] 2986+ fmla v8.8h, v5.8h, v0.h[3] 2987+ fmla v10.8h, v5.8h, v1.h[3] 2988+ prfm pldl1strm, [x11, #632] 2989+ ld1 {v3.8h, v4.8h}, [x11], #32 2990+ fmla v9.8h, v6.8h, v0.h[3] 2991+ fmla v11.8h, v6.8h, v1.h[3] 2992+ 2993+ fmla v8.8h, v3.8h, v0.h[4] 2994+ fmla v10.8h, v3.8h, v1.h[4] 2995+ ld1 {v5.8h, v6.8h}, [x11], #32 2996+ fmla v9.8h, v4.8h, v0.h[4] 2997+ fmla v11.8h, v4.8h, v1.h[4] 2998+ fmla v8.8h, v5.8h, v0.h[5] 2999+ fmla v10.8h, v5.8h, v1.h[5] 3000+ ld1 {v3.8h, v4.8h}, [x11], #32 3001+ fmla v9.8h, v6.8h, v0.h[5] 3002+ fmla v11.8h, v6.8h, v1.h[5] 3003+ fmla v8.8h, v3.8h, v0.h[6] 3004+ fmla v10.8h, v3.8h, v1.h[6] 3005+ ld1 {v5.8h, v6.8h}, [x11], #32 3006+ fmla v9.8h, v4.8h, v0.h[6] 3007+ fmla v11.8h, v4.8h, v1.h[6] 3008+ fmla v8.8h, v5.8h, v0.h[7] 3009+ fmla v10.8h, v5.8h, v1.h[7] 3010+ fmla v9.8h, v6.8h, v0.h[7] 3011+ fmla v11.8h, v6.8h, v1.h[7] 3012+ 3013+ subs x14, x14, #8 3014+ bge Compute2x16 3015+ Compute2x16End4: 3016+ adds x14, x14, #8 3017+ cbz x14, Compute2x16Return 3018+ subs x14, x14, #4 3019+ blt 
Compute2x16EndTail 3020+ ld1 {v0.4h}, [x10], #8 3021+ ld1 {v1.4h}, [x19], #8 3022+ prfm pldl1strm, [x11, #632] 3023+ ld1 {v3.8h, v4.8h}, [x11], #32 3024+ fmla v8.8h, v3.8h, v0.h[0] 3025+ fmla v10.8h, v3.8h, v1.h[0] 3026+ ld1 {v5.8h, v6.8h}, [x11], #32 3027+ fmla v9.8h, v4.8h, v0.h[0] 3028+ fmla v11.8h, v4.8h, v1.h[0] 3029+ fmla v8.8h, v5.8h, v0.h[1] 3030+ fmla v10.8h, v5.8h, v1.h[1] 3031+ ld1 {v3.8h, v4.8h}, [x11], #32 3032+ fmla v9.8h, v6.8h, v0.h[1] 3033+ fmla v11.8h, v6.8h, v1.h[1] 3034+ fmla v8.8h, v3.8h, v0.h[2] 3035+ fmla v10.8h, v3.8h, v1.h[2] 3036+ ld1 {v5.8h, v6.8h}, [x11], #32 3037+ fmla v9.8h, v4.8h, v0.h[2] 3038+ fmla v11.8h, v4.8h, v1.h[2] 3039+ fmla v8.8h, v5.8h, v0.h[3] 3040+ fmla v10.8h, v5.8h, v1.h[3] 3041+ fmla v9.8h, v6.8h, v0.h[3] 3042+ fmla v11.8h, v6.8h, v1.h[3] 3043+ subs x14, x14, #4 3044+ Compute2x16EndTail: 3045+ adds x14, x14, #4 3046+ cbz x14, Compute2x16Return 3047+ cmp x14, #1 3048+ beq Compute2x16EndTail1 3049+ cmp x14, #2 3050+ beq Compute2x16EndTail2 3051+ ld1 {v0.4h}, [x10] 3052+ ld1 {v1.s}[0], [x19], #4 3053+ ld1 {v1.h}[2], [x19] 3054+ prfm pldl1strm, [x11, #632] 3055+ ld1 {v3.8h, v4.8h}, [x11], #32 3056+ fmla v8.8h, v3.8h, v0.h[0] 3057+ fmla v10.8h, v3.8h, v1.h[0] 3058+ ld1 {v5.8h, v6.8h}, [x11], #32 3059+ fmla v9.8h, v4.8h, v0.h[0] 3060+ fmla v11.8h, v4.8h, v1.h[0] 3061+ fmla v8.8h, v5.8h, v0.h[1] 3062+ fmla v10.8h, v5.8h, v1.h[1] 3063+ ld1 {v3.8h, v4.8h}, [x11], #32 3064+ fmla v9.8h, v6.8h, v0.h[1] 3065+ fmla v11.8h, v6.8h, v1.h[1] 3066+ fmla v8.8h, v3.8h, v0.h[2] 3067+ fmla v10.8h, v3.8h, v1.h[2] 3068+ fmla v9.8h, v4.8h, v0.h[2] 3069+ fmla v11.8h, v4.8h, v1.h[2] 3070+ b Compute2x16Return 3071+ Compute2x16EndTail2: 3072+ ld1 {v0.4h}, [x10] 3073+ ld2 {v1.h, v2.h}[0], [x19] 3074+ prfm pldl1strm, [x11, #632] 3075+ ld1 {v3.8h, v4.8h}, [x11], #32 3076+ fmla v8.8h, v3.8h, v0.h[0] 3077+ fmla v10.8h, v3.8h, v1.h[0] 3078+ ld1 {v5.8h, v6.8h}, [x11], #32 3079+ fmla v9.8h, v4.8h, v0.h[0] 3080+ fmla v11.8h, v4.8h, v1.h[0] 3081+ fmla v8.8h, 
v5.8h, v0.h[1] 3082+ fmla v10.8h, v5.8h, v2.h[0] 3083+ fmla v9.8h, v6.8h, v0.h[1] 3084+ fmla v11.8h, v6.8h, v2.h[0] 3085+ b Compute2x16Return 3086+ Compute2x16EndTail1: 3087+ ld1 {v0.h}[0], [x10] 3088+ ld1 {v1.h}[0], [x19] 3089+ prfm pldl1strm, [x11, #632] 3090+ ld1 {v3.8h, v4.8h}, [x11], #32 3091+ fmla v8.8h, v3.8h, v0.h[0] 3092+ fmla v10.8h, v3.8h, v1.h[0] 3093+ fmla v9.8h, v4.8h, v0.h[0] 3094+ fmla v11.8h, v4.8h, v1.h[0] 3095+ Compute2x16Return: 3096+ ret 3097+ 3098+Compute2x8Unit: 3099+ add x19, x10, x16 3100+ subs x14, x14, #8 3101+ blt Compute2x8End4 3102+ Compute2x8: 3103+ ld1 {v0.8h}, [x10], #16 3104+ ld1 {v1.8h}, [x19], #16 3105+ prfm pldl1strm, [x11, #632] 3106+ ld1 {v3.8h, v4.8h}, [x11], #32 3107+ fmla v8.8h, v3.8h, v0.h[0] 3108+ fmla v10.8h, v3.8h, v1.h[0] 3109+ ld1 {v5.8h, v6.8h}, [x11], #32 3110+ fmla v8.8h, v4.8h, v0.h[1] 3111+ fmla v10.8h, v4.8h, v1.h[1] 3112+ fmla v8.8h, v5.8h, v0.h[2] 3113+ fmla v10.8h, v5.8h, v1.h[2] 3114+ prfm pldl1strm, [x11, #632] 3115+ ld1 {v3.8h, v4.8h}, [x11], #32 3116+ fmla v8.8h, v6.8h, v0.h[3] 3117+ fmla v10.8h, v6.8h, v1.h[3] 3118+ fmla v8.8h, v3.8h, v0.h[4] 3119+ fmla v10.8h, v3.8h, v1.h[4] 3120+ ld1 {v5.8h, v6.8h}, [x11], #32 3121+ fmla v8.8h, v4.8h, v0.h[5] 3122+ fmla v10.8h, v4.8h, v1.h[5] 3123+ fmla v8.8h, v5.8h, v0.h[6] 3124+ fmla v10.8h, v5.8h, v1.h[6] 3125+ fmla v8.8h, v6.8h, v0.h[7] 3126+ fmla v10.8h, v6.8h, v1.h[7] 3127+ 3128+ subs x14, x14, #8 3129+ bge Compute2x8 3130+ Compute2x8End4: 3131+ adds x14, x14, #8 3132+ cbz x14, Compute2x8Return 3133+ subs x14, x14, #4 3134+ blt Compute2x8EndTail 3135+ ld1 {v0.4h}, [x10], #8 3136+ ld1 {v1.4h}, [x19], #8 3137+ prfm pldl1strm, [x11, #632] 3138+ ld1 {v3.8h, v4.8h}, [x11], #32 3139+ fmla v8.8h, v3.8h, v0.h[0] 3140+ fmla v10.8h, v3.8h, v1.h[0] 3141+ ld1 {v5.8h, v6.8h}, [x11], #32 3142+ fmla v8.8h, v4.8h, v0.h[1] 3143+ fmla v10.8h, v4.8h, v1.h[1] 3144+ fmla v8.8h, v5.8h, v0.h[2] 3145+ fmla v10.8h, v5.8h, v1.h[2] 3146+ fmla v8.8h, v6.8h, v0.h[3] 3147+ fmla v10.8h, v6.8h, 
v1.h[3] 3148+ subs x14, x14, #4 3149+ Compute2x8EndTail: 3150+ adds x14, x14, #4 3151+ cbz x14, Compute2x8Return 3152+ cmp x14, #1 3153+ beq Compute2x8EndTail1 3154+ cmp x14, #2 3155+ beq Compute2x8EndTail2 3156+ ld1 {v0.4h}, [x10] 3157+ ld3 {v1.h, v2.h, v3.h}[0], [x19] 3158+ prfm pldl1strm, [x11, #632] 3159+ ld1 {v4.8h, v5.8h}, [x11], #32 3160+ fmla v8.8h, v4.8h, v0.h[0] 3161+ fmla v10.8h, v4.8h, v1.h[0] 3162+ ld1 {v6.8h}, [x11], #16 3163+ fmla v8.8h, v5.8h, v0.h[1] 3164+ fmla v10.8h, v5.8h, v2.h[0] 3165+ fmla v8.8h, v6.8h, v0.h[2] 3166+ fmla v10.8h, v6.8h, v3.h[0] 3167+ b Compute2x8Return 3168+ Compute2x8EndTail2: 3169+ ld1 {v0.4h}, [x10] 3170+ ld2 {v1.h, v2.h}[0], [x19] 3171+ prfm pldl1strm, [x11, #632] 3172+ ld1 {v3.8h, v4.8h}, [x11], #32 3173+ fmla v8.8h, v3.8h, v0.h[0] 3174+ fmla v10.8h, v3.8h, v1.h[0] 3175+ fmla v8.8h, v4.8h, v0.h[1] 3176+ fmla v10.8h, v4.8h, v2.h[0] 3177+ b Compute2x8Return 3178+ Compute2x8EndTail1: 3179+ ld1 {v0.h}[0], [x10] 3180+ ld1 {v1.h}[0], [x19] 3181+ prfm pldl1strm, [x11, #632] 3182+ ld1 {v3.8h}, [x11], #16 3183+ fmla v8.8h, v3.8h, v0.h[0] 3184+ fmla v10.8h, v3.8h, v1.h[0] 3185+ Compute2x8Return: 3186+ ret 3187+ 3188+Compute2x4Unit: 3189+ add x19, x10, x16 3190+ subs x14, x14, #8 3191+ blt Compute2x4End4 3192+ Compute2x4: 3193+ ld1 {v0.8h}, [x10], #16 3194+ ld1 {v1.8h}, [x19], #16 3195+ prfm pldl1strm, [x11, #632] 3196+ ld1 {v3.4h, v4.4h}, [x11], #16 3197+ fmla v8.4h, v3.4h, v0.h[0] 3198+ fmla v10.4h, v3.4h, v1.h[0] 3199+ ld1 {v5.4h, v6.4h}, [x11], #16 3200+ fmla v8.4h, v4.4h, v0.h[1] 3201+ fmla v10.4h, v4.4h, v1.h[1] 3202+ fmla v8.4h, v5.4h, v0.h[2] 3203+ fmla v10.4h, v5.4h, v1.h[2] 3204+ prfm pldl1strm, [x11, #632] 3205+ ld1 {v3.4h, v4.4h}, [x11], #16 3206+ fmla v8.4h, v6.4h, v0.h[3] 3207+ fmla v10.4h, v6.4h, v1.h[3] 3208+ fmla v8.4h, v3.4h, v0.h[4] 3209+ fmla v10.4h, v3.4h, v1.h[4] 3210+ ld1 {v5.4h, v6.4h}, [x11], #16 3211+ fmla v8.4h, v4.4h, v0.h[5] 3212+ fmla v10.4h, v4.4h, v1.h[5] 3213+ fmla v8.4h, v5.4h, v0.h[6] 3214+ fmla 
v10.4h, v5.4h, v1.h[6] 3215+ fmla v8.4h, v6.4h, v0.h[7] 3216+ fmla v10.4h, v6.4h, v1.h[7] 3217+ 3218+ subs x14, x14, #8 3219+ bge Compute2x4 3220+ Compute2x4End4: 3221+ adds x14, x14, #8 3222+ cbz x14, Compute2x4Return 3223+ subs x14, x14, #4 3224+ blt Compute2x4EndTail 3225+ ld1 {v0.4h}, [x10], #8 3226+ ld1 {v1.4h}, [x19], #8 3227+ prfm pldl1strm, [x11, #632] 3228+ ld1 {v3.4h, v4.4h}, [x11], #16 3229+ fmla v8.4h, v3.4h, v0.h[0] 3230+ fmla v10.4h, v3.4h, v1.h[0] 3231+ ld1 {v5.4h, v6.4h}, [x11], #16 3232+ fmla v8.4h, v4.4h, v0.h[1] 3233+ fmla v10.4h, v4.4h, v1.h[1] 3234+ fmla v8.4h, v5.4h, v0.h[2] 3235+ fmla v10.4h, v5.4h, v1.h[2] 3236+ fmla v8.4h, v6.4h, v0.h[3] 3237+ fmla v10.4h, v6.4h, v1.h[3] 3238+ subs x14, x14, #4 3239+ Compute2x4EndTail: 3240+ adds x14, x14, #4 3241+ cbz x14, Compute2x4Return 3242+ cmp x14, #1 3243+ beq Compute2x4EndTail1 3244+ cmp x14, #2 3245+ beq Compute2x4EndTail2 3246+ ld1 {v0.4h}, [x10] 3247+ ld3 {v1.h, v2.h, v3.h}[0], [x19] 3248+ prfm pldl1strm, [x11, #632] 3249+ ld1 {v4.4h, v5.4h}, [x11], #16 3250+ fmla v8.4h, v4.4h, v0.h[0] 3251+ fmla v10.4h, v4.4h, v1.h[0] 3252+ ld1 {v6.4h}, [x11], #8 3253+ fmla v8.4h, v5.4h, v0.h[1] 3254+ fmla v10.4h, v5.4h, v2.h[0] 3255+ fmla v8.4h, v6.4h, v0.h[2] 3256+ fmla v10.4h, v6.4h, v3.h[0] 3257+ b Compute2x4Return 3258+ Compute2x4EndTail2: 3259+ ld1 {v0.4h}, [x10] 3260+ ld2 {v1.h, v2.h}[0], [x19] 3261+ prfm pldl1strm, [x11, #632] 3262+ ld1 {v3.4h, v4.4h}, [x11], #16 3263+ fmla v8.4h, v3.4h, v0.h[0] 3264+ fmla v10.4h, v3.4h, v1.h[0] 3265+ fmla v8.4h, v4.4h, v0.h[1] 3266+ fmla v10.4h, v4.4h, v2.h[0] 3267+ b Compute2x4Return 3268+ Compute2x4EndTail1: 3269+ ld1 {v0.h}[0], [x10] 3270+ ld1 {v1.h}[0], [x19] 3271+ prfm pldl1strm, [x11, #632] 3272+ ld1 {v3.4h}, [x11], #8 3273+ fmla v8.4h, v3.4h, v0.h[0] 3274+ fmla v10.4h, v3.4h, v1.h[0] 3275+ Compute2x4Return: 3276+ ret 3277+ 3278+Compute1x16Unit: 3279+ subs x14, x14, #8 3280+ blt Compute1x16End4 3281+ Compute1x16: 3282+ ld1 {v0.8h}, [x10], #16 3283+ prfm 
pldl1strm, [x11, #632] 3284+ ld1 {v3.8h, v4.8h}, [x11], #32 3285+ fmla v8.8h, v3.8h, v0.h[0] 3286+ ld1 {v5.8h, v6.8h}, [x11], #32 3287+ fmla v9.8h, v4.8h, v0.h[0] 3288+ fmla v8.8h, v5.8h, v0.h[1] 3289+ ld1 {v3.8h, v4.8h}, [x11], #32 3290+ fmla v9.8h, v6.8h, v0.h[1] 3291+ fmla v8.8h, v3.8h, v0.h[2] 3292+ ld1 {v5.8h, v6.8h}, [x11], #32 3293+ fmla v9.8h, v4.8h, v0.h[2] 3294+ fmla v8.8h, v5.8h, v0.h[3] 3295+ prfm pldl1strm, [x11, #632] 3296+ ld1 {v3.8h, v4.8h}, [x11], #32 3297+ fmla v9.8h, v6.8h, v0.h[3] 3298+ 3299+ fmla v8.8h, v3.8h, v0.h[4] 3300+ ld1 {v5.8h, v6.8h}, [x11], #32 3301+ fmla v9.8h, v4.8h, v0.h[4] 3302+ fmla v8.8h, v5.8h, v0.h[5] 3303+ ld1 {v3.8h, v4.8h}, [x11], #32 3304+ fmla v9.8h, v6.8h, v0.h[5] 3305+ fmla v8.8h, v3.8h, v0.h[6] 3306+ ld1 {v5.8h, v6.8h}, [x11], #32 3307+ fmla v9.8h, v4.8h, v0.h[6] 3308+ fmla v8.8h, v5.8h, v0.h[7] 3309+ fmla v9.8h, v6.8h, v0.h[7] 3310+ 3311+ subs x14, x14, #8 3312+ bge Compute1x16 3313+ Compute1x16End4: 3314+ adds x14, x14, #8 3315+ cbz x14, Compute1x16Return 3316+ subs x14, x14, #4 3317+ blt Compute1x16EndTail 3318+ ld1 {v0.4h}, [x10], #8 3319+ prfm pldl1strm, [x11, #632] 3320+ ld1 {v3.8h, v4.8h}, [x11], #32 3321+ fmla v8.8h, v3.8h, v0.h[0] 3322+ ld1 {v5.8h, v6.8h}, [x11], #32 3323+ fmla v9.8h, v4.8h, v0.h[0] 3324+ fmla v8.8h, v5.8h, v0.h[1] 3325+ ld1 {v3.8h, v4.8h}, [x11], #32 3326+ fmla v9.8h, v6.8h, v0.h[1] 3327+ fmla v8.8h, v3.8h, v0.h[2] 3328+ ld1 {v5.8h, v6.8h}, [x11], #32 3329+ fmla v9.8h, v4.8h, v0.h[2] 3330+ fmla v8.8h, v5.8h, v0.h[3] 3331+ fmla v9.8h, v6.8h, v0.h[3] 3332+ subs x14, x14, #4 3333+ Compute1x16EndTail: 3334+ adds x14, x14, #4 3335+ cbz x14, Compute1x16Return 3336+ cmp x14, #1 3337+ beq Compute1x16EndTail1 3338+ cmp x14, #2 3339+ beq Compute1x16EndTail2 3340+ ld3 {v0.h, v1.h, v2.h}[0], [x10] 3341+ prfm pldl1strm, [x11, #632] 3342+ ld1 {v3.8h, v4.8h}, [x11], #32 3343+ fmla v8.8h, v3.8h, v0.h[0] 3344+ ld1 {v5.8h, v6.8h}, [x11], #32 3345+ fmla v9.8h, v4.8h, v0.h[0] 3346+ fmla v8.8h, v5.8h, v1.h[0] 
3347+ ld1 {v3.8h, v4.8h}, [x11], #32 3348+ fmla v9.8h, v6.8h, v1.h[0] 3349+ fmla v8.8h, v3.8h, v2.h[0] 3350+ fmla v9.8h, v4.8h, v2.h[0] 3351+ b Compute1x16Return 3352+ Compute1x16EndTail2: 3353+ ld2 {v0.h, v1.h}[0], [x10] 3354+ prfm pldl1strm, [x11, #632] 3355+ ld1 {v3.8h, v4.8h}, [x11], #32 3356+ fmla v8.8h, v3.8h, v0.h[0] 3357+ ld1 {v5.8h, v6.8h}, [x11], #32 3358+ fmla v9.8h, v4.8h, v0.h[0] 3359+ fmla v8.8h, v5.8h, v1.h[0] 3360+ fmla v9.8h, v6.8h, v1.h[0] 3361+ b Compute1x16Return 3362+ Compute1x16EndTail1: 3363+ ld1 {v0.h}[0], [x10] 3364+ prfm pldl1strm, [x11, #632] 3365+ ld1 {v3.8h, v4.8h}, [x11], #32 3366+ fmla v8.8h, v3.8h, v0.h[0] 3367+ fmla v9.8h, v4.8h, v0.h[0] 3368+ Compute1x16Return: 3369+ ret 3370+ 3371+Compute1x8Unit: 3372+ subs x14, x14, #8 3373+ blt Compute1x8End4 3374+ Compute1x8: 3375+ ld1 {v0.8h}, [x10], #16 3376+ prfm pldl1strm, [x11, #632] 3377+ ld1 {v3.8h, v4.8h}, [x11], #32 3378+ fmla v8.8h, v3.8h, v0.h[0] 3379+ ld1 {v5.8h, v6.8h}, [x11], #32 3380+ fmla v8.8h, v4.8h, v0.h[1] 3381+ fmla v8.8h, v5.8h, v0.h[2] 3382+ prfm pldl1strm, [x11, #632] 3383+ ld1 {v3.8h, v4.8h}, [x11], #32 3384+ fmla v8.8h, v6.8h, v0.h[3] 3385+ fmla v8.8h, v3.8h, v0.h[4] 3386+ ld1 {v5.8h, v6.8h}, [x11], #32 3387+ fmla v8.8h, v4.8h, v0.h[5] 3388+ fmla v8.8h, v5.8h, v0.h[6] 3389+ fmla v8.8h, v6.8h, v0.h[7] 3390+ 3391+ subs x14, x14, #8 3392+ bge Compute1x8 3393+ Compute1x8End4: 3394+ adds x14, x14, #8 3395+ cbz x14, Compute1x8Return 3396+ subs x14, x14, #4 3397+ blt Compute1x8EndTail 3398+ ld1 {v0.4h}, [x10], #8 3399+ prfm pldl1strm, [x11, #632] 3400+ ld1 {v3.8h, v4.8h}, [x11], #32 3401+ fmla v8.8h, v3.8h, v0.h[0] 3402+ ld1 {v5.8h, v6.8h}, [x11], #32 3403+ fmla v8.8h, v4.8h, v0.h[1] 3404+ fmla v8.8h, v5.8h, v0.h[2] 3405+ fmla v8.8h, v6.8h, v0.h[3] 3406+ subs x14, x14, #4 3407+ Compute1x8EndTail: 3408+ adds x14, x14, #4 3409+ cbz x14, Compute1x8Return 3410+ cmp x14, #1 3411+ beq Compute1x8EndTail1 3412+ cmp x14, #2 3413+ beq Compute1x8EndTail2 3414+ ld3 {v0.h, v1.h, v2.h}[0], 
[x10] 3415+ prfm pldl1strm, [x11, #632] 3416+ ld1 {v3.8h, v4.8h}, [x11], #32 3417+ fmla v8.8h, v3.8h, v0.h[0] 3418+ ld1 {v5.8h}, [x11], #16 3419+ fmla v8.8h, v4.8h, v1.h[0] 3420+ fmla v8.8h, v5.8h, v2.h[0] 3421+ b Compute1x8Return 3422+ Compute1x8EndTail2: 3423+ ld2 {v0.h, v1.h}[0], [x10] 3424+ prfm pldl1strm, [x11, #632] 3425+ ld1 {v3.8h, v4.8h}, [x11], #32 3426+ fmla v8.8h, v3.8h, v0.h[0] 3427+ fmla v8.8h, v4.8h, v1.h[0] 3428+ b Compute1x8Return 3429+ Compute1x8EndTail1: 3430+ ld1 {v0.h}[0], [x10] 3431+ prfm pldl1strm, [x11, #632] 3432+ ld1 {v3.8h}, [x11], #16 3433+ fmla v8.8h, v3.8h, v0.h[0] 3434+ Compute1x8Return: 3435+ ret 3436+ 3437+Compute1x4Unit: 3438+ subs x14, x14, #8 3439+ blt Compute1x4End4 3440+ Compute1x4: 3441+ ld1 {v0.8h}, [x10], #16 3442+ prfm pldl1strm, [x11, #632] 3443+ ld1 {v3.4h, v4.4h}, [x11], #16 3444+ fmla v8.4h, v3.4h, v0.h[0] 3445+ ld1 {v5.4h, v6.4h}, [x11], #16 3446+ fmla v8.4h, v4.4h, v0.h[1] 3447+ fmla v8.4h, v5.4h, v0.h[2] 3448+ prfm pldl1strm, [x11, #632] 3449+ ld1 {v3.4h, v4.4h}, [x11], #16 3450+ fmla v8.4h, v6.4h, v0.h[3] 3451+ fmla v8.4h, v3.4h, v0.h[4] 3452+ ld1 {v5.4h, v6.4h}, [x11], #16 3453+ fmla v8.4h, v4.4h, v0.h[5] 3454+ fmla v8.4h, v5.4h, v0.h[6] 3455+ fmla v8.4h, v6.4h, v0.h[7] 3456+ 3457+ subs x14, x14, #8 3458+ bge Compute1x4 3459+ Compute1x4End4: 3460+ adds x14, x14, #8 3461+ cbz x14, Compute1x4Return 3462+ subs x14, x14, #4 3463+ blt Compute1x4EndTail 3464+ ld1 {v0.4h}, [x10], #8 3465+ prfm pldl1strm, [x11, #632] 3466+ ld1 {v3.4h, v4.4h}, [x11], #16 3467+ fmla v8.4h, v3.4h, v0.h[0] 3468+ ld1 {v5.4h, v6.4h}, [x11], #16 3469+ fmla v8.4h, v4.4h, v0.h[1] 3470+ fmla v8.4h, v5.4h, v0.h[2] 3471+ fmla v8.4h, v6.4h, v0.h[3] 3472+ subs x14, x14, #4 3473+ Compute1x4EndTail: 3474+ adds x14, x14, #4 3475+ cbz x14, Compute1x4Return 3476+ cmp x14, #1 3477+ beq Compute1x4EndTail1 3478+ cmp x14, #2 3479+ beq Compute1x4EndTail2 3480+ ld3 {v0.h, v1.h, v2.h}[0], [x10] 3481+ prfm pldl1strm, [x11, #632] 3482+ ld1 {v3.4h, v4.4h}, [x11], #16 
3483+ fmla v8.4h, v3.4h, v0.h[0] 3484+ ld1 {v5.4h}, [x11], #8 3485+ fmla v8.4h, v4.4h, v1.h[0] 3486+ fmla v8.4h, v5.4h, v2.h[0] 3487+ b Compute1x4Return 3488+ Compute1x4EndTail2: 3489+ ld2 {v0.h, v1.h}[0], [x10] 3490+ prfm pldl1strm, [x11, #632] 3491+ ld1 {v3.4h, v4.4h}, [x11], #16 3492+ fmla v8.4h, v3.4h, v0.h[0] 3493+ fmla v8.4h, v4.4h, v1.h[0] 3494+ b Compute1x4Return 3495+ Compute1x4EndTail1: 3496+ ld1 {v0.h}[0], [x10] 3497+ prfm pldl1strm, [x11, #632] 3498+ ld1 {v3.4h}, [x11], #8 3499+ fmla v8.4h, v3.4h, v0.h[0] 3500+ Compute1x4Return: 3501+ ret 3502+ 3503+End: 3504+ sub sp, sp, #192 3505+ ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 3506+ ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 3507+ ldp x19, x20, [sp], #16 3508+ ldp x21, x22, [sp], #16 3509+ ldp x23, x24, [sp], #16 3510+ ldp x29, x30, [sp], #16 3511+ ret 3512+#endif 3513diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gather_d_grad_v2_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gather_d_grad_v2_parameter.h 3514new file mode 100644 3515index 00000000..541c7ff1 3516--- /dev/null 3517+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gather_d_grad_v2_parameter.h 3518@@ -0,0 +1,28 @@ 3519+/** 3520+ * Copyright 2023 Huawei Technologies Co., Ltd 3521+ * 3522+ * Licensed under the Apache License, Version 2.0 (the "License"); 3523+ * you may not use this file except in compliance with the License. 3524+ * You may obtain a copy of the License at 3525+ * 3526+ * http://www.apache.org/licenses/LICENSE-2.0 3527+ * 3528+ * Unless required by applicable law or agreed to in writing, software 3529+ * distributed under the License is distributed on an "AS IS" BASIS, 3530+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3531+ * See the License for the specific language governing permissions and 3532+ * limitations under the License. 
3533+ */ 3534+#ifndef MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_PARAMETER_H_ 3535+#define MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_PARAMETER_H_ 3536+ 3537+#include "nnacl/op_base.h" 3538+ 3539+typedef struct CustomGatherGradV2Parameter { 3540+ // Primitive parameter 3541+ OpParameter op_parameter_; 3542+ // shape correlative 3543+ int dim; 3544+} CustomGatherGradV2Parameter; 3545+ 3546+#endif // MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_PARAMETER_H_ 3547diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c 3548index 6e754569..72391811 100644 3549--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c 3550+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c 3551@@ -35,13 +35,13 @@ void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *w 3552 float16_t *hidden_gate = buffer[C3NUM]; 3553 for (int i = 0; i < num_step; ++i) { 3554 if (batch_size != 1) { 3555- RowMajor2ColNMajorFp16(input + i * batch_size * input_size, buffer[0], batch_size, input_size); 3556+ RowMajor2ColNMajorFp16(input + i * batch_size * input_size, buffer[0], batch_size, input_size, false); 3557 for (int j = 0; j < C3NUM; ++j) { 3558 MatmulBaseFp16Neon(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size, 3559 bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size, 3560 OutType_Nhwc); 3561 } 3562- RowMajor2ColNMajorFp16(init_h, buffer[C2NUM], batch_size, hidden_size); 3563+ RowMajor2ColNMajorFp16(init_h, buffer[C2NUM], batch_size, hidden_size, false); 3564 for (int j = 0; j < C3NUM; ++j) { 3565 MatmulBaseFp16Neon(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, 3566 bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size, 3567diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/exp_fp16.c 
b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/exp_fp16.c 3568index d1555953..93f005c8 100644 3569--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/exp_fp16.c 3570+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/exp_fp16.c 3571@@ -20,8 +20,10 @@ 3572 3573 #if defined(ENABLE_NEON) 3574 static inline void simd_exp_fp16(float16x8_t input, float16_t *dst) { 3575- static float16x8_t maxv = {88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f}; 3576- static float16x8_t minv = {-88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f}; 3577+ static float16x8_t maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 3578+ 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; 3579+ static float16x8_t minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, 3580+ -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; 3581 input = vmaxq_f16(minv, vminq_f16(input, maxv)); 3582 vst1q_f16(dst, VexpFp16(input)); 3583 } 3584diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.c 3585index 813237fa..614842a1 100644 3586--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.c 3587+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.c 3588@@ -23,28 +23,38 @@ 3589 #include "nnacl/fp16/cast_fp16.h" 3590 #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" 3591 3592-void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) { 3593+void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align, 3594+ const int32_t *order) { 3595 for (int i = 0; i < batch; i++) { 3596 const float *src_batch = src + i * col * deep; 3597- float16_t *dst_batch = dst + i * col_align * deep; 3598+ float16_t *dst_batch = dst + (order == NULL ? 
i : order[i]) * col_align * deep; 3599+#ifdef ENABLE_ARM64 3600+ RowMajor2ColNMajorFp16(src_batch, dst_batch, col, deep, true); 3601+#else 3602 RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, true); 3603+#endif 3604 } 3605 } 3606 3607-void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align) { 3608+void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align, 3609+ const int32_t *order) { 3610 for (int i = 0; i < batch; i++) { 3611 const float16_t *src_batch = src + i * col * deep; 3612- float16_t *dst_batch = dst + i * col_align * deep; 3613+ float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align * deep; 3614+#ifdef ENABLE_ARM64 3615+ RowMajor2ColNMajorFp16(src_batch, dst_batch, col, deep, false); 3616+#else 3617 RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, false); 3618+#endif 3619 } 3620 } 3621 3622-void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, 3623- bool is_bidirectional) { 3624+void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, 3625+ const int32_t *order) { 3626 int unidirectional_batch = is_bidirectional ? batch / 2 : batch; 3627 for (int i = 0; i < unidirectional_batch; i++) { 3628 const float *src_batch = src + i * col; 3629- float16_t *dst_batch = dst + i * col_align; 3630+ float16_t *dst_batch = dst + (order == NULL ? 
i : order[i]) * col_align; 3631 Float32ToFloat16(src_batch, dst_batch, col); 3632 } 3633 if (is_bidirectional) { 3634@@ -52,17 +62,18 @@ void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col 3635 float16_t *backward_dst = dst + unidirectional_batch * col_align; 3636 for (int i = 0; i < unidirectional_batch; i++) { 3637 const float *backward_src_batch = backward_src + i * col; 3638- float16_t *backward_dst_batch = backward_dst + i * col_align; 3639+ float16_t *backward_dst_batch = backward_dst + (order == NULL ? i : order[i]) * col_align; 3640 Float32ToFloat16(backward_src_batch, backward_dst_batch, col); 3641 } 3642 } 3643 } 3644 3645-void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional) { 3646+void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional, 3647+ const int32_t *order) { 3648 int unidirectional_batch = is_bidirectional ? batch / 2 : batch; 3649 for (int i = 0; i < unidirectional_batch; i++) { 3650 const float16_t *src_batch = src + i * col; 3651- float16_t *dst_batch = dst + i * col_align; 3652+ float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align; 3653 (void)memcpy(dst_batch, src_batch, col * sizeof(float16_t)); 3654 } 3655 if (is_bidirectional) { 3656@@ -70,7 +81,7 @@ void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, 3657 float16_t *backward_dst = dst + unidirectional_batch * col_align; 3658 for (int i = 0; i < unidirectional_batch; i++) { 3659 const float16_t *backward_src_batch = backward_src + i * col; 3660- float16_t *backward_dst_batch = backward_dst + i * col_align; 3661+ float16_t *backward_dst_batch = backward_dst + (order == NULL ? 
i : order[i]) * col_align; 3662 (void)memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float16_t)); 3663 } 3664 } 3665@@ -152,13 +163,13 @@ void UpdateOutputFp16(float16_t *hidden_state, float16_t *output, const float16_ 3666 const LstmParameter *lstm_param) { 3667 int batch = lstm_param->batch_; 3668 int hidden_size = lstm_param->hidden_size_; 3669- int project_size = lstm_param->project_size_; 3670+ int output_size = lstm_param->output_size_; 3671 float16_t *state_buffer = buffer[C5NUM]; 3672 float16_t *hidden_buffer = weight_project ? buffer[C3NUM] : hidden_state; 3673 float16_t zoneout = lstm_param->zoneout_hidden_; 3674 if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { 3675- (void)memcpy(state_buffer, hidden_state, batch * project_size * sizeof(float16_t)); 3676- ElementOptMulFp16(state_buffer, &zoneout, state_buffer, batch * project_size, false); 3677+ (void)memcpy(state_buffer, hidden_state, batch * output_size * sizeof(float16_t)); 3678+ ElementOptMulFp16(state_buffer, &zoneout, state_buffer, batch * output_size, false); 3679 } 3680 3681 TanhFp16(cell_state, hidden_buffer, batch * hidden_size); 3682@@ -166,19 +177,32 @@ void UpdateOutputFp16(float16_t *hidden_state, float16_t *output, const float16_ 3683 3684 if (weight_project) { 3685 float16_t *left_matrix = hidden_buffer; 3686+#ifdef ENABLE_ARM64 3687+ if (batch >= C4NUM) { 3688+ left_matrix = buffer[C6NUM]; 3689+ RowMajor2ColLadder12MajorFp16(hidden_buffer, left_matrix, batch, hidden_size); 3690+ } 3691+#else 3692 if (batch != 1) { 3693 left_matrix = buffer[C6NUM]; 3694 RowMajor2Col16MajorFp16(hidden_buffer, left_matrix, batch, hidden_size, false); 3695 } 3696- LstmMatMulFp16(hidden_state, left_matrix, weight_project, project_bias, batch, hidden_size, project_size, 3697+#endif 3698+ LstmMatMulFp16(hidden_state, left_matrix, weight_project, project_bias, batch, hidden_size, output_size, 3699 batch == 1); 3700 } 3701 if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { 3702- 
ElementOptMulAccFp16(hidden_state, 1 - zoneout, state_buffer, batch * project_size); 3703+ ElementOptMulAccFp16(hidden_state, 1 - zoneout, state_buffer, batch * output_size); 3704 } 3705- (void)memcpy(output, hidden_state, batch * project_size * sizeof(float16_t)); 3706+ (void)memcpy(output, hidden_state, batch * output_size * sizeof(float16_t)); 3707 } 3708 3709+#ifdef ENABLE_ARM64 3710+void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, 3711+ int col, bool is_vec) { 3712+ MatmulFp16OptV2(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); 3713+} 3714+#else 3715 void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, 3716 int col, bool is_vec) { 3717 if (is_vec) { 3718@@ -188,11 +212,12 @@ void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const 3719 MatMulFp16(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); 3720 } 3721 } 3722+#endif 3723 3724 void UpdateLstmGateFp16(float16_t *gate_buffer, const float16_t *input, const float16_t *weight, const float16_t *bias, 3725 int row, int deep, int col, int col_align, bool is_vec) { 3726 for (int i = 0; i < 4; i++) { 3727- const float16_t *weight_i = weight + deep * col * i; 3728+ const float16_t *weight_i = weight + deep * col_align * i; 3729 const float16_t *bias_i = bias + col_align * i; 3730 float16_t *gate = gate_buffer + row * col * i; 3731 LstmMatMulFp16(gate, input, weight_i, bias_i, row, deep, col, is_vec); 3732@@ -207,16 +232,26 @@ void LstmStepUnitFp16(float16_t *output, float16_t *input_gate, float16_t *forge 3733 float16_t *state_gate = buffer[C3NUM]; 3734 float16_t *cell_buffer = buffer[C4NUM]; 3735 float16_t *hidden_buffer = buffer[C5NUM]; 3736+#ifdef ENABLE_ARM64 3737+ if (lstm_param->batch_ <= C3NUM) { 3738+ UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, 3739+ 
lstm_param->hidden_size_, lstm_param->state_col_align_, false); 3740+ } else { 3741+ RowMajor2ColLadder12MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_); 3742+ UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, 3743+ lstm_param->hidden_size_, lstm_param->state_col_align_, false); 3744+ } 3745+#else 3746 bool is_vec = lstm_param->batch_ == 1; 3747 if (is_vec) { 3748- UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, 3749- lstm_param->project_size_, lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); 3750+ UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, 3751+ lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); 3752 } else { 3753- // pack state for matmul 3754- RowMajor2Col16MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->project_size_, false); 3755- UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, 3756- lstm_param->project_size_, lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); 3757+ RowMajor2Col16MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_, false); 3758+ UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, 3759+ lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); 3760 } 3761+#endif 3762 ElementAddFp16(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); 3763 ElementAddFp16(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate, 3764 lstm_param->batch_ * lstm_param->hidden_size_); 3765@@ -247,24 +282,43 @@ void LstmStepUnitFp16(float16_t *output, float16_t *input_gate, float16_t *forge 3766 } 3767 3768 if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && 
lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { 3769- (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->project_size_ * sizeof(float16_t)); 3770+ (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->output_size_ * sizeof(float16_t)); 3771 } 3772 } 3773 3774-void LstmUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_i, 3775- const float16_t *weight_h, const float16_t *input_bias, const float16_t *state_bias, 3776- const float16_t *weight_project, const float16_t *project_bias, float16_t *hidden_state, 3777- float16_t *cell_state, float16_t *buffer[C7NUM], const LstmParameter *lstm_param, 3778- bool is_backward) { 3779- float16_t *gate = buffer[1]; 3780+#ifdef ENABLE_ARM64 3781+void LstmGateCompute(float16_t *gate, const float16_t *input, const float16_t *weight_i, const float16_t *input_bias, 3782+ const LstmParameter *lstm_param) { 3783+ int row_input = lstm_param->seq_len_ * lstm_param->batch_; 3784+ for (int i = 0; i < C4NUM; i++) { 3785+ const float16_t *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i; 3786+ const float16_t *bias_loop = input_bias + lstm_param->input_col_align_ * i; 3787+ float16_t *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i; 3788+ MatmulFp16OptV2(input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, row_input, 3789+ lstm_param->hidden_size_, lstm_param->hidden_size_, OutType_Nhwc); 3790+ } 3791+} 3792+#else 3793+void LstmGateCompute(float16_t *gate, const float16_t *input, const float16_t *weight_i, const float16_t *input_bias, 3794+ const LstmParameter *lstm_param) { 3795 for (int i = 0; i < C4NUM; i++) { 3796 const float16_t *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i; 3797 const float16_t *bias_loop = input_bias + lstm_param->input_col_align_ * i; 3798 float16_t *gate_loop = gate + lstm_param->seq_len_ * 
lstm_param->batch_ * lstm_param->hidden_size_ * i; 3799- MatMulFp16(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, 3800+ MatMulFp16(input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, 3801 lstm_param->seq_len_ * lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, 3802 OutType_Nhwc); 3803 } 3804+} 3805+#endif 3806+ 3807+void LstmUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_i, 3808+ const float16_t *weight_h, const float16_t *input_bias, const float16_t *state_bias, 3809+ const float16_t *weight_project, const float16_t *project_bias, float16_t *hidden_state, 3810+ float16_t *cell_state, float16_t *buffer[C7NUM], const LstmParameter *lstm_param, 3811+ bool is_backward) { 3812+ float16_t *gate = buffer[1]; 3813+ LstmGateCompute(gate, packed_input, weight_i, input_bias, lstm_param); 3814 3815 float16_t *input_gate = gate; 3816 float16_t *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2; 3817@@ -287,26 +341,33 @@ void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight 3818 const float16_t *project_bias, float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[C7NUM], 3819 const LstmParameter *lstm_param) { 3820 // forward 3821+#ifdef ENABLE_ARM64 3822+ const float16_t *packed_input = input; 3823+ if (lstm_param->batch_ * lstm_param->seq_len_ >= C4NUM) { 3824+ float16_t *temp_input = buffer[0]; 3825+ RowMajor2ColLadder12MajorFp16(input, temp_input, lstm_param->seq_len_ * lstm_param->batch_, 3826+ lstm_param->input_size_); 3827+ packed_input = temp_input; 3828+ } 3829+#else 3830 float16_t *packed_input = buffer[0]; 3831 RowMajor2Col16MajorFp16(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_, 3832 false); 3833+#endif 3834 LstmUnidirectionalFp16(output, packed_input, weight_i, weight_h, input_bias, state_bias, weight_project, 
project_bias, 3835 hidden_state, cell_state, buffer, lstm_param, false); 3836 3837 // backward 3838 if (lstm_param->bidirectional_) { 3839 const float16_t *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_; 3840- const float16_t *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->hidden_size_; 3841+ const float16_t *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->output_size_; 3842 const float16_t *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; 3843 const float16_t *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; 3844 const float16_t *backward_weight_project = 3845- weight_project ? weight_project + lstm_param->hidden_size_ * (lstm_param->batch_ == 1 3846- ? lstm_param->project_size_ 3847- : UP_ROUND(lstm_param->project_size_, C8NUM)) 3848- : NULL; 3849- float16_t *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; 3850+ weight_project ? 
weight_project + lstm_param->hidden_size_ * lstm_param->proj_col_align_ : NULL; 3851+ float16_t *backward_output = output + lstm_param->batch_ * lstm_param->output_size_; 3852 float16_t *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; 3853- float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; 3854+ float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->output_size_; 3855 3856 LstmUnidirectionalFp16(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, 3857 backward_state_bias, backward_weight_project, project_bias, backward_hidden_state, 3858diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.h 3859index f6f853b4..d6af9c78 100644 3860--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.h 3861+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.h 3862@@ -21,13 +21,17 @@ 3863 #ifdef __cplusplus 3864 extern "C" { 3865 #endif 3866-void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align); 3867+void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align, 3868+ const int32_t *order); 3869 3870-void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align); 3871+void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align, 3872+ const int32_t *order); 3873 3874-void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional); 3875+void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, 3876+ const int32_t *order); 3877 3878-void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int 
col_align, bool is_bidirectional); 3879+void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional, 3880+ const int32_t *order); 3881 3882 void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, 3883 int col, bool is_vec); 3884diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.c 3885index 1aefbaf5..39dcb9ee 100644 3886--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.c 3887+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.c 3888@@ -16,7 +16,7 @@ 3889 3890 #include "nnacl/fp16/matmul_fp16.h" 3891 3892-static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { 3893+static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, int row, int col) { 3894 int row_c8 = row / C8NUM * C8NUM; 3895 int col_c8 = col / C8NUM * C8NUM; 3896 const float16_t *src = (const float16_t *)src_ptr; 3897@@ -108,7 +108,7 @@ static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t 3898 } 3899 } 3900 3901-static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { 3902+static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, int row, int col) { 3903 int row_c8 = row / C8NUM * C8NUM; 3904 int col_c8 = col / C8NUM * C8NUM; 3905 int ci = 0; 3906@@ -410,17 +410,14 @@ void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const f 3907 int di = 0; 3908 for (; di < depth - C8NUM + 1; di += C8NUM) { 3909 float16x8_t av = vld1q_f16(a + di); 3910- float16x8_t bv_0; 3911- float16x8_t bv_1; 3912- for (int i = 0; i < C8NUM; i += C2NUM) { 3913- bv_0 = vld1q_f16(bv_base); // bv_i为一行,8列数据 3914- acc_0 = vfmaq_n_f16(acc_0, bv_0, av[i]); // av[i]为向量中的一个值 3915- bv_base += C8NUM; 3916- 3917- bv_1 = vld1q_f16(bv_base); // 
bv_i为一行,8列数据 3918- acc_0 = vfmaq_n_f16(acc_0, bv_1, av[i + 1]); // av[i]为向量中的一个值 3919+ float16x8_t bv_0[C8NUM]; 3920+ for (int i = 0; i < C8NUM; ++i) { 3921+ bv_0[i] = vld1q_f16(bv_base); 3922 bv_base += C8NUM; 3923 } 3924+ for (int i = 0; i < C8NUM; ++i) { 3925+ acc_0 = vfmaq_n_f16(acc_0, bv_0[i], av[i]); 3926+ } 3927 } 3928 if (di < depth) { 3929 for (; di < depth; ++di) { 3930@@ -636,8 +633,94 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si 3931 } 3932 3933 #ifdef ENABLE_ARM64 3934-void RowMajor2ColNMajorFp16(const float16_t *src_ptr, float16_t *dst_ptr, int row, int col) { 3935- // Col16Major ==> Col8Major ==> Col4Major 3936+void RowMajor2ColLadder12MajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col) { 3937+ // Col12Major ==> Col8Major ==> Col4Major 3938+ const float16_t *src_r = src; 3939+ float16_t *dst_r = dst_ptr; 3940+ int ri = 0; 3941+ size_t col8 = col / C8NUM * C8NUM; 3942+ // find 16 block unit 3943+ for (; ri <= row - C12NUM; ri += C12NUM) { 3944+ size_t ci = 0; 3945+ for (; ci < col8; ci += C8NUM) { 3946+ const float16_t *src_c = src_r + ci; 3947+ float16_t *dst_c = dst_r + ci * C12NUM; 3948+ Transpose12x8ARM64Fp16(src_c, dst_c, col * C2NUM, C24NUM); 3949+ } 3950+ for (; ci < col; ci++) { 3951+ const float16_t *src_c = src_r + ci; 3952+ float16_t *dst_c = dst_r + ci * C12NUM; 3953+ for (size_t i = 0; i < C12NUM; i++) { 3954+ dst_c[i] = src_c[i * col]; 3955+ } 3956+ } 3957+ src_r += C12NUM * col; 3958+ dst_r += C12NUM * col; 3959+ } 3960+ for (; ri <= row - C8NUM; ri += C8NUM) { 3961+ size_t ci = 0; 3962+ for (; ci < col8; ci += C8NUM) { 3963+ const float16_t *src_c = src_r + ci; 3964+ float16_t *dst_c = dst_r + ci * C8NUM; 3965+ Transpose8x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C8NUM * sizeof(float16_t)); 3966+ } 3967+ for (; ci < col; ci++) { 3968+ const float16_t *src_c = src_r + ci; 3969+ float16_t *dst_c = dst_r + ci * C8NUM; 3970+ for (size_t i = 0; i < C8NUM; i++) { 3971+ dst_c[i] 
= src_c[i * col]; 3972+ } 3973+ } 3974+ src_r += C8NUM * col; 3975+ dst_r += C8NUM * col; 3976+ } 3977+ for (; ri <= row - C4NUM; ri += C4NUM) { 3978+ size_t ci = 0; 3979+ for (; ci < col8; ci += C8NUM) { 3980+ const float16_t *src_c = src_r + ci; 3981+ float16_t *dst_c = dst_r + ci * C4NUM; 3982+ Transpose4x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C4NUM * sizeof(float16_t)); 3983+ } 3984+ for (; ci < col; ci++) { 3985+ const float16_t *src_c = src_r + ci; 3986+ float16_t *dst_c = dst_r + ci * C4NUM; 3987+ for (size_t i = 0; i < C4NUM; i++) { 3988+ dst_c[i] = src_c[i * col]; 3989+ } 3990+ } 3991+ src_r += C4NUM * col; 3992+ dst_r += C4NUM * col; 3993+ } 3994+ if (ri < row) { 3995+ memcpy(dst_r, src_r, (row - ri) * col * C2NUM); 3996+ } 3997+} 3998+ 3999+void RowMajor2RowLadder12MajorFp16(const float16_t *src, float16_t *dst, int row, int col) { 4000+ // Row12 ==> Row8 ==> Row4 4001+ for (int r = 0; r < row; r++) { 4002+ int c = 0; 4003+ for (; c <= col - C12NUM; c += C12NUM) { 4004+ MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4005+ MS_FLOAT16X4 src_data1 = MS_LD_F16(src + r * col + c + C8NUM); 4006+ MS_STQ_F16(dst + c / C12NUM * C12NUM * row + r * C12NUM, src_data); 4007+ MS_ST_F16(dst + c / C12NUM * C12NUM * row + r * C12NUM + C8NUM, src_data1); 4008+ } 4009+ for (; c <= col - C8NUM; c += C8NUM) { 4010+ MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4011+ MS_STQ_F16(dst + c / C12NUM * C12NUM * row + r * C8NUM, src_data); 4012+ } 4013+ for (; c <= col - C4NUM; c += C4NUM) { 4014+ MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); 4015+ MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); 4016+ } 4017+ for (; c < col; ++c) { 4018+ dst[c / C4NUM * C4NUM * row + r + c % C4NUM * row] = src[r * col + c]; 4019+ } 4020+ } 4021+} 4022+ 4023+void RowMajor2ColNMajorFp16srcFp16(const float16_t *src_ptr, float16_t *dst_ptr, int row, int col) { 4024 const float16_t *src_r = src_ptr; 4025 float16_t *dst_r = dst_ptr; 4026 int ri = 0; 
4027@@ -702,6 +785,112 @@ void RowMajor2ColNMajorFp16(const float16_t *src_ptr, float16_t *dst_ptr, int ro 4028 dst_r += 1; 4029 } 4030 } 4031+ 4032+void RowMajor2ColNMajorFp16(const void *src_ptr, float16_t *dst_ptr, int row, int col, bool is_fp32_src) { 4033+ // Col16Major ==> Col8Major ==> Col4Major 4034+ if (!is_fp32_src) { 4035+ RowMajor2ColNMajorFp16srcFp16((const float16_t *)src_ptr, dst_ptr, row, col); 4036+ return; 4037+ } 4038+ const float *src_r = src_ptr; 4039+ float16_t *dst_r = dst_ptr; 4040+ int ri = 0; 4041+ // find 16 block unit 4042+ for (; ri <= row - C16NUM; ri += C16NUM) { 4043+ for (int r = 0; r < C16NUM; ++r) { 4044+ for (int c = 0; c < col; ++c) { 4045+ dst_r[c * C16NUM + r % C16NUM] = src_r[r * col + c]; 4046+ } 4047+ } 4048+ src_r += C16NUM * col; 4049+ dst_r += C16NUM * col; 4050+ } 4051+ for (; ri <= row - C8NUM; ri += C8NUM) { 4052+ for (int r = 0; r < C8NUM; ++r) { 4053+ for (int c = 0; c < col; ++c) { 4054+ dst_r[c * C8NUM + r % C8NUM] = src_r[r * col + c]; 4055+ } 4056+ } 4057+ src_r += C8NUM * col; 4058+ dst_r += C8NUM * col; 4059+ } 4060+ for (; ri <= row - C4NUM; ri += C4NUM) { 4061+ for (int r = 0; r < C4NUM; ++r) { 4062+ for (int c = 0; c < col; ++c) { 4063+ dst_r[c * C4NUM + r % C4NUM] = src_r[r * col + c]; 4064+ } 4065+ } 4066+ src_r += C4NUM * col; 4067+ dst_r += C4NUM * col; 4068+ } 4069+ for (; ri < row; ++ri) { 4070+ for (size_t i = 0; i < col; ++i) { 4071+ dst_r[i * C4NUM] = src_r[i]; 4072+ } 4073+ src_r += col; 4074+ dst_r += 1; 4075+ } 4076+} 4077+ 4078+void RowMajor2RowNMajorFp16(const void *src_ptr, float16_t *dst, int row, int col, bool is_fp32_src) { 4079+ // Row16 ==> Row8 ==> Row4 4080+ if (is_fp32_src) { 4081+ const float *src = (const float *)src_ptr; 4082+ for (int r = 0; r < row; r++) { 4083+ int c = 0; 4084+ for (; c <= col - C16NUM; c += C16NUM) { 4085+ const float *cur_src = src + r * col + c; 4086+ MS_FLOAT32X4X4 src_f32_data = {MS_LDQ_F32(cur_src), MS_LDQ_F32(cur_src + C4NUM), MS_LDQ_F32(cur_src + C8NUM), 
4087+ MS_LDQ_F32(cur_src + C12NUM)}; 4088+ MS_FLOAT16X4X4 res = { 4089+ MS_CVT_F16_F32(src_f32_data.val[0]), 4090+ MS_CVT_F16_F32(src_f32_data.val[1]), 4091+ MS_CVT_F16_F32(src_f32_data.val[2]), 4092+ MS_CVT_F16_F32(src_f32_data.val[3]), 4093+ }; 4094+ MS_ST4_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, res); 4095+ } 4096+ for (; c <= col - C8NUM; c += C8NUM) { 4097+ const float *cur_src = src + r * col + c; 4098+ MS_FLOAT32X4X2 src_f32_data = {MS_LDQ_F32(cur_src), MS_LDQ_F32(cur_src + C4NUM)}; 4099+ MS_FLOAT16X4X2 res = { 4100+ MS_CVT_F16_F32(src_f32_data.val[0]), 4101+ MS_CVT_F16_F32(src_f32_data.val[1]), 4102+ }; 4103+ MS_ST2_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, res); 4104+ } 4105+ for (; c <= col - C4NUM; c += C4NUM) { 4106+ MS_FLOAT16X4 src_data = MS_CVT_F16_F32(MS_LDQ_F32(src + r * col + c)); 4107+ MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); 4108+ } 4109+ for (; c < col; ++c) { 4110+ dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; 4111+ } 4112+ } 4113+ return; 4114+ } 4115+ const float16_t *src = (const float16_t *)src_ptr; 4116+ for (int r = 0; r < row; r++) { 4117+ int c = 0; 4118+ for (; c <= col - C16NUM; c += C16NUM) { 4119+ MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4120+ MS_FLOAT16X8 src_data1 = MS_LDQ_F16(src + r * col + c + C8NUM); 4121+ MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, src_data); 4122+ MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM + C8NUM, src_data1); 4123+ } 4124+ for (; c <= col - C8NUM; c += C8NUM) { 4125+ MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4126+ MS_STQ_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, src_data); 4127+ } 4128+ for (; c <= col - C4NUM; c += C4NUM) { 4129+ MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); 4130+ MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); 4131+ } 4132+ for (; c < col; ++c) { 4133+ dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; 4134+ } 4135+ } 
4136+} 4137 #endif 4138 4139 void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { 4140@@ -802,32 +991,6 @@ void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col, 4141 } 4142 } 4143 4144-#ifdef ENABLE_ARM64 4145-void RowMajor2RowNMajorFp16(const float16_t *src, float16_t *dst, int row, int col) { 4146- // Row16 ==> Row8 ==> Row4 4147- for (int r = 0; r < row; r++) { 4148- int c = 0; 4149- for (; c <= col - C16NUM; c += C16NUM) { 4150- MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4151- MS_FLOAT16X8 src_data1 = MS_LDQ_F16(src + r * col + c + C8NUM); 4152- MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, src_data); 4153- MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM + C8NUM, src_data1); 4154- } 4155- for (; c <= col - C8NUM; c += C8NUM) { 4156- MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4157- MS_STQ_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, src_data); 4158- } 4159- for (; c <= col - C4NUM; c += C4NUM) { 4160- MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); 4161- MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); 4162- } 4163- for (; c < col; ++c) { 4164- dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; 4165- } 4166- } 4167-} 4168-#endif 4169- 4170 void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col) { 4171 int col_align = UP_ROUND(col, C16NUM); 4172 for (int r = 0; r < row; r++) { 4173diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.h 4174index be7f8443..7acef622 100644 4175--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.h 4176+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.h 4177@@ -14,8 +14,8 @@ 4178 * limitations under the License. 
4179 */ 4180 4181-#ifndef NNACL_FP16_MATMUL_FP16_H_ 4182-#define NNACL_FP16_MATMUL_FP16_H_ 4183+#ifndef MINDSPORE_NNACL_FP16_MATMUL_H_ 4184+#define MINDSPORE_NNACL_FP16_MATMUL_H_ 4185 4186 #include <float.h> 4187 #include <string.h> 4188@@ -45,9 +45,13 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons 4189 int deep, int row, int col, int stride, int write_mode); 4190 4191 #ifdef ENABLE_ARM64 4192-void RowMajor2ColNMajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col); 4193+void RowMajor2ColLadder12MajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col); 4194 4195-void RowMajor2RowNMajorFp16(const float16_t *src, float16_t *dst, int row, int col); 4196+void RowMajor2RowLadder12MajorFp16(const float16_t *src, float16_t *dst, int row, int col); 4197+ 4198+void RowMajor2ColNMajorFp16(const void *src, float16_t *dst_ptr, int row, int col, bool is_fp32_src); 4199+ 4200+void RowMajor2RowNMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); 4201 4202 void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, 4203 int deep, int row, int col, size_t stride, size_t out_type); 4204@@ -60,6 +64,9 @@ void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, c 4205 void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, 4206 size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); 4207 4208+void MatmulFp16OptV2(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, 4209+ size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); 4210+ 4211 #ifdef ENABLE_DEBUG 4212 void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, 4213 size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); 4214@@ -118,4 +125,4 @@ void 
RowMajor2ColMajorFp16(const void *src, float16_t *dst, int row, int col, bo 4215 } 4216 #endif 4217 4218-#endif // NNACL_FP16_MATMUL_FP16_H_ 4219+#endif // MINDSPORE_NNACL_FP16_MATMUL_H_ 4220diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.c 4221index 74e75115..da9f6bef 100644 4222--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.c 4223+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.c 4224@@ -33,7 +33,7 @@ static void PackLstmMatrix(const float *src_batch, float *dst_batch, int col, in 4225 } 4226 4227 static void PackLstmWeightBatch(float *dst, const float *src, int batch, int deep, int col, int col_align, 4228- const int32_t *order) { 4229+ const int *order) { 4230 for (int i = 0; i < batch; i++) { 4231 const float *src_batch = src + i * col * deep; 4232 float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align * deep; 4233@@ -41,12 +41,12 @@ static void PackLstmWeightBatch(float *dst, const float *src, int batch, int dee 4234 } 4235 } 4236 4237-void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int32_t *order) { 4238+void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int *order) { 4239 PackLstmWeightBatch(dst, src, batch, deep, col, col_align, order); 4240 } 4241 4242 void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align, 4243- bool is_bidirectional, int stride, const int32_t *order) { 4244+ bool is_bidirectional, int stride, const int *order) { 4245 int unidirectional_batch = is_bidirectional ? 
batch / 2 : batch; 4246 PackLstmWeightBatch(dst, src, unidirectional_batch, deep, col, col_align, order); 4247 src += stride; 4248@@ -57,7 +57,7 @@ void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, 4249 } 4250 4251 void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, 4252- const int32_t *order) { 4253+ const int *order) { 4254 int unidirectional_batch = is_bidirectional ? batch / 2 : batch; 4255 for (int i = 0; i < unidirectional_batch; i++) { 4256 const float *src_batch = src + i * col; 4257@@ -76,7 +76,7 @@ void PackLstmBias(float *dst, const float *src, int batch, int col, int col_alig 4258 } 4259 4260 void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, 4261- int b_stride, const int32_t *order) { 4262+ int b_stride, const int *order) { 4263 int unidirectional_batch = is_bidirectional ? batch / 2 : batch; 4264 for (int i = 0; i < unidirectional_batch; i++) { 4265 const float *src_batch = src + i * col; 4266@@ -175,13 +175,13 @@ void UpdateOutput(float *hidden_state, float *output, const float *cell_state, c 4267 const float *weight_project, float *buffer[C8NUM], const LstmParameter *lstm_param) { 4268 int batch = lstm_param->batch_; 4269 int hidden_size = lstm_param->hidden_size_; 4270- int project_size = lstm_param->project_size_; 4271+ int output_size = lstm_param->output_size_; 4272 float *state_buffer = buffer[C4NUM]; 4273 float *hidden_buffer = weight_project ? 
buffer[C2NUM] : hidden_state; 4274 float zoneout = lstm_param->zoneout_hidden_; 4275 if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { 4276- (void)memcpy(state_buffer, hidden_state, batch * project_size * sizeof(float)); 4277- ElementOptMul(state_buffer, &zoneout, state_buffer, batch * project_size, false); 4278+ (void)memcpy(state_buffer, hidden_state, batch * hidden_size * sizeof(float)); 4279+ ElementOptMul(state_buffer, &zoneout, state_buffer, batch * hidden_size, false); 4280 } 4281 4282 Tanh(cell_state, batch * hidden_size, hidden_buffer); 4283@@ -193,20 +193,13 @@ void UpdateOutput(float *hidden_state, float *output, const float *cell_state, c 4284 left_matrix = buffer[C6NUM]; 4285 PackLstmInput(hidden_buffer, left_matrix, batch, hidden_size); 4286 } 4287-#ifdef ENABLE_AVX 4288- int col_tile = batch == 1 ? C8NUM : C16NUM; 4289-#elif defined(ENABLE_ARM32) 4290- int col_tile = C4NUM; 4291-#else 4292- int col_tile = C8NUM; 4293-#endif 4294- LstmMatMul(hidden_state, left_matrix, weight_project, NULL, batch, hidden_size, project_size, 4295- UP_ROUND(project_size, col_tile), batch == 1, buffer[C7NUM]); 4296+ LstmMatMul(hidden_state, left_matrix, weight_project, NULL, batch, hidden_size, output_size, 4297+ lstm_param->proj_col_align_, batch == 1, buffer[C7NUM]); 4298 } 4299 if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { 4300- ElementOptMulAcc(hidden_state, 1 - zoneout, state_buffer, batch * project_size); 4301+ ElementOptMulAcc(hidden_state, 1 - zoneout, state_buffer, batch * output_size); 4302 } 4303- (void)memcpy(output, hidden_state, batch * project_size * sizeof(float)); 4304+ (void)memcpy(output, hidden_state, batch * output_size * sizeof(float)); 4305 } 4306 4307 void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep, 4308@@ -238,12 +231,12 @@ void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *c 4309 bool is_vec = lstm_param->batch_ == 1; 4310 // 
state * weight 4311 if (is_vec) { 4312- UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->project_size_, 4313+ UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, 4314 lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output); 4315 } else { 4316 // pack state for matmul 4317- PackLstmInput(hidden_state, packed_state, lstm_param->batch_, lstm_param->project_size_); 4318- UpdateLstmGate(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->project_size_, 4319+ PackLstmInput(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_); 4320+ UpdateLstmGate(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, 4321 lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output); 4322 } 4323 ElementAdd(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); 4324@@ -276,7 +269,7 @@ void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *c 4325 } 4326 4327 if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { 4328- (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->project_size_ * sizeof(float)); 4329+ (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->output_size_ * sizeof(float)); 4330 } 4331 } 4332 4333@@ -322,12 +315,12 @@ void Lstm(float *output, const float *input, const float *weight_i, const float 4334 // backward 4335 if (lstm_param->bidirectional_) { 4336 const float *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_; 4337- const float *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->hidden_size_; 4338+ const float *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->output_size_; 4339 const float 
*backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; 4340 const float *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; 4341- float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; 4342+ float *backward_output = output + lstm_param->batch_ * lstm_param->output_size_; 4343 float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; 4344- float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; 4345+ float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->output_size_; 4346 4347 LstmUnidirectional(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, 4348 backward_state_bias, backward_hidden_state, backward_cell_state, buffer, lstm_param, true); 4349diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h 4350index 88dd9d16..f94f0bb7 100644 4351--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h 4352+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h 4353@@ -21,16 +21,16 @@ 4354 #ifdef __cplusplus 4355 extern "C" { 4356 #endif 4357-void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int32_t *order); 4358+void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int *order); 4359 4360 void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align, 4361- bool is_bidirectional, int stride, const int32_t *order); 4362+ bool is_bidirectional, int stride, const int *order); 4363 4364 void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, 4365- const int32_t *order); 4366+ const int *order); 4367 4368 void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int 
col_align, bool is_bidirectional, 4369- int b_stride, const int32_t *order); 4370+ int b_stride, const int *order); 4371 4372 void PackLstmInput(const float *src, float *dst, int row, int deep); 4373 4374diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c 4375index 308419fb..1898ffd4 100644 4376--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c 4377+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c 4378@@ -440,8 +440,8 @@ void MatVecMulNoPackFp32(const float *a, const float *b, float *c, const float * 4379 } 4380 c[oc_index] = dst; 4381 } 4382- a += k; 4383- b += k * col; 4384+ a += C1500NUM; 4385+ b += C1500NUM * col; 4386 } 4387 if (k == depth) { 4388 return; 4389diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.c 4390new file mode 100644 4391index 00000000..ad1cac2e 4392--- /dev/null 4393+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.c 4394@@ -0,0 +1,36 @@ 4395+/** 4396+ * Copyright 2023 Huawei Technologies Co., Ltd 4397+ * 4398+ * Licensed under the Apache License, Version 2.0 (the "License"); 4399+ * you may not use this file except in compliance with the License. 4400+ * You may obtain a copy of the License at 4401+ * 4402+ * http://www.apache.org/licenses/LICENSE-2.0 4403+ * 4404+ * Unless required by applicable law or agreed to in writing, software 4405+ * distributed under the License is distributed on an "AS IS" BASIS, 4406+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4407+ * See the License for the specific language governing permissions and 4408+ * limitations under the License. 
4409+ */ 4410+ 4411+#include "nnacl/infer/custom_gather_d_grad_v2_infer.h" 4412+#include "nnacl/infer/infer_register.h" 4413+ 4414+int CustomGatherDGradV2InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, 4415+ size_t outputs_size, OpParameter *parameter) { 4416+ int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); 4417+ if (check_ret != NNACL_OK) { 4418+ return check_ret; 4419+ } 4420+ const TensorC *input = inputs[0]; 4421+ TensorC *output = outputs[0]; 4422+ SetDataTypeFormat(output, input); 4423+ if (!InferFlag(inputs, inputs_size)) { 4424+ return NNACL_INFER_INVALID; 4425+ } 4426+ SetShapeTensor(output, input); 4427+ return NNACL_OK; 4428+} 4429+ 4430+REG_INFER(CustomGatherDGradV2, PrimType_Inner_CustomGatherDGradV2, CustomGatherDGradV2InferShape) 4431diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.h 4432new file mode 100644 4433index 00000000..68d85d20 4434--- /dev/null 4435+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.h 4436@@ -0,0 +1,30 @@ 4437+/** 4438+ * Copyright 2023 Huawei Technologies Co., Ltd 4439+ * 4440+ * Licensed under the Apache License, Version 2.0 (the "License"); 4441+ * you may not use this file except in compliance with the License. 4442+ * You may obtain a copy of the License at 4443+ * 4444+ * http://www.apache.org/licenses/LICENSE-2.0 4445+ * 4446+ * Unless required by applicable law or agreed to in writing, software 4447+ * distributed under the License is distributed on an "AS IS" BASIS, 4448+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4449+ * See the License for the specific language governing permissions and 4450+ * limitations under the License. 
4451+ */ 4452+#ifndef MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_INFER_H 4453+#define MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_INFER_H 4454+#include "nnacl/infer/common_infer.h" 4455+ 4456+#ifdef __cplusplus 4457+extern "C" { 4458+#endif 4459+ 4460+int CustomGatherDGradV2InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, 4461+ size_t outputs_size, OpParameter *parameter); 4462+ 4463+#ifdef __cplusplus 4464+} 4465+#endif 4466+#endif // MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_INFER_H 4467diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_infer.c 4468index 9892ef0b..391e2522 100644 4469--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_infer.c 4470+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_infer.c 4471@@ -17,41 +17,81 @@ 4472 #include "nnacl/infer/lstm_infer.h" 4473 #include "nnacl/infer/infer_register.h" 4474 4475-static const int num_of_gates = 4; 4476-static const int no_of_recorde_values = 6; 4477+static const int no_of_recorde_values = 5; 4478 4479 int CheckInputShapeValid(const TensorC *const *inputs, size_t inputs_size, const LstmParameter *parameter) { 4480+ if (inputs_size < C6NUM) { 4481+ return NNACL_INPUT_TENSOR_ERROR; 4482+ } 4483 const TensorC *input = inputs[FIRST_INPUT]; 4484 const TensorC *weight_i = inputs[SECOND_INPUT]; 4485 const TensorC *weight_g = inputs[THIRD_INPUT]; 4486 const TensorC *bias = inputs[FOURTH_INPUT]; 4487- const TensorC *cell = inputs[FIFTH_INPUT]; 4488+ const TensorC *hidden_init = inputs[FIFTH_INPUT]; 4489+ const TensorC *cell_init = inputs[SIXTH_INPUT]; 4490+ 4491+ NNACL_CHECK_TRUE_RET(input->shape_size_ == DIMENSION_3D && weight_i->shape_size_ == DIMENSION_3D && 4492+ weight_g->shape_size_ == DIMENSION_3D && bias->shape_size_ == DIMENSION_2D, 4493+ NNACL_ERR); 4494 int batch = input->shape_[kNHWC_H]; 4495 int input_size = input->shape_[kNHWC_W]; 4496 int hidden_size = 
weight_i->shape_[kNHWC_H] / C4NUM; 4497- int project_size = inputs_size == C7NUM ? inputs[C6NUM]->shape_[kNHWC_H] : hidden_size; 4498- bool bidirectional = parameter->bidirectional_; 4499- if (input->shape_size_ != DIMENSION_3D || weight_i->shape_size_ != DIMENSION_3D) { 4500- return NNACL_ERR; 4501+ int out_size = hidden_size; 4502+ if (inputs_size == C7NUM) { 4503+ NNACL_CHECK_TRUE_RET(inputs[SEVENTH_INPUT]->shape_size_ == DIMENSION_3D, NNACL_INPUT_TENSOR_ERROR); 4504+ out_size = inputs[SEVENTH_INPUT]->shape_[kNHWC_H]; 4505 } 4506+ bool bidirectional = parameter->bidirectional_; 4507 int bidirection = bidirectional ? C2NUM : C1NUM; 4508 NNACL_CHECK_TRUE_RET(weight_i->shape_[kNHWC_N] == bidirection && weight_i->shape_[kNHWC_H] == hidden_size * C4NUM && 4509 weight_i->shape_[kNHWC_W] == input_size, 4510 NNACL_ERR); 4511 NNACL_CHECK_TRUE_RET(weight_g->shape_[kNHWC_N] == bidirection && weight_g->shape_[kNHWC_H] == hidden_size * C4NUM && 4512- weight_g->shape_[kNHWC_W] == project_size, 4513+ weight_g->shape_[kNHWC_W] == out_size, 4514 NNACL_ERR); 4515 NNACL_CHECK_TRUE_RET(bias->shape_[kNHWC_N] == bidirection && bias->shape_[kNHWC_H] == hidden_size * C8NUM, NNACL_ERR); 4516- if (!bidirectional && cell->shape_size_ == DIMENSION_2D) { 4517- NNACL_CHECK_TRUE_RET(cell->shape_[kNHWC_N] == batch && cell->shape_[kNHWC_H] == hidden_size, NNACL_ERR); 4518+ if (!bidirectional && hidden_init->shape_size_ == DIMENSION_2D) { 4519+ NNACL_CHECK_TRUE_RET(hidden_init->shape_[kNHWC_N] == batch && hidden_init->shape_[kNHWC_H] == out_size, NNACL_ERR); 4520 } else { 4521- NNACL_CHECK_TRUE_RET( 4522- cell->shape_[kNHWC_N] == bidirection && cell->shape_[kNHWC_H] == batch && cell->shape_[kNHWC_W] == project_size, 4523- NNACL_ERR); 4524+ NNACL_CHECK_TRUE_RET(hidden_init->shape_size_ == DIMENSION_3D && hidden_init->shape_[kNHWC_N] == bidirection && 4525+ hidden_init->shape_[kNHWC_H] == batch && hidden_init->shape_[kNHWC_W] == out_size, 4526+ NNACL_ERR); 4527+ } 4528+ if (!bidirectional && 
cell_init->shape_size_ == DIMENSION_2D) { 4529+ NNACL_CHECK_TRUE_RET(cell_init->shape_[kNHWC_N] == batch && cell_init->shape_[kNHWC_H] == hidden_size, NNACL_ERR); 4530+ } else { 4531+ NNACL_CHECK_TRUE_RET(cell_init->shape_size_ == DIMENSION_3D && cell_init->shape_[kNHWC_N] == bidirection && 4532+ cell_init->shape_[kNHWC_H] == batch && cell_init->shape_[kNHWC_W] == hidden_size, 4533+ NNACL_ERR); 4534 } 4535 return NNACL_OK; 4536 } 4537 4538+int InferFirstOutputMindir(const TensorC *const *inputs, size_t inputs_size, TensorC *output, LstmParameter *param) { 4539+ for (size_t i = 0; i < inputs_size; ++i) { 4540+ if (inputs[i]->shape_size_ != C3NUM) { 4541+ return NNACL_INPUT_TENSOR_ERROR; 4542+ } 4543+ } 4544+ ShapeSet(output->shape_, &output->shape_size_, inputs[0]->shape_, inputs[0]->shape_size_); 4545+ int out_size = inputs[SECOND_INPUT]->shape_[THIRD_INPUT]; 4546+ output->shape_[THIRD_INPUT] = (param->bidirectional_ ? C2NUM : 1) * out_size; 4547+ return NNACL_OK; 4548+} 4549+ 4550+int InferFirstOutputNonMindir(const TensorC *const *inputs, size_t inputs_size, TensorC *output, LstmParameter *param) { 4551+ if (CheckInputShapeValid(inputs, inputs_size, param) != NNACL_OK) { 4552+ return NNACL_ERR; 4553+ } 4554+ ShapeSet(output->shape_, &output->shape_size_, inputs[0]->shape_, inputs[0]->shape_size_); 4555+ const TensorC *hidden_init = inputs[FIFTH_INPUT]; 4556+ int out_size = hidden_init->shape_[hidden_init->shape_size_ - 1]; 4557+ output->shape_[THIRD_INPUT] = out_size; 4558+ int direction = param->bidirectional_ ? 
C2NUM : C1NUM; 4559+ int ret = ShapeInsert(output->shape_, &output->shape_size_, 1, direction); 4560+ return ret; 4561+} 4562+ 4563 int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, 4564 OpParameter *parameter) { 4565 int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 4, 3); 4566@@ -60,9 +100,8 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o 4567 } 4568 4569 const TensorC *input = inputs[0]; 4570- const TensorC *weight_i = inputs[1]; 4571 TensorC *output = outputs[0]; 4572- for (int i = 0; i < 3; i++) { 4573+ for (int i = 0; i < outputs_size; i++) { 4574 SetDataTypeFormat(outputs[i], input); 4575 } 4576 4577@@ -71,42 +110,31 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o 4578 if (!InferFlag(inputs, inputs_size)) { 4579 return NNACL_INFER_INVALID; 4580 } 4581- int dir_multiplier = param->bidirectional_ ? 2 : 1; 4582- int out_shape[MAX_SHAPE_SIZE]; 4583- size_t out_shape_size = 0; 4584- int hidden_size = 1; 4585- int project_size = 1; 4586- ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_); 4587- if (inputs_size == DIMENSION_4D) { // if input from MINDIR 4588- hidden_size = weight_i->shape_[THIRD_INPUT]; 4589- project_size = hidden_size; 4590- out_shape[THIRD_INPUT] = hidden_size * dir_multiplier; 4591- } else { 4592- if (CheckInputShapeValid(inputs, inputs_size, param) != NNACL_OK) { 4593- return NNACL_ERR; 4594+ int hidden_size = 0; 4595+ int out_size = 0; 4596+ if (inputs_size == C4NUM) { 4597+ int ret = InferFirstOutputMindir(inputs, inputs_size, output, param); 4598+ if (ret != NNACL_OK) { 4599+ return ret; 4600 } 4601- hidden_size = weight_i->shape_[1] / num_of_gates; 4602- project_size = inputs_size == C7NUM ? 
inputs[C6NUM]->shape_[kNHWC_H] : hidden_size; 4603- out_shape[THIRD_INPUT] = project_size; 4604- if (param->bidirectional_) { 4605- int ret = ShapeInsert(out_shape, &out_shape_size, 1, 2); 4606- if (ret != NNACL_OK) { 4607- return NNACL_ERR; 4608- } 4609- } else { 4610- int ret = ShapeInsert(out_shape, &out_shape_size, 1, 1); 4611- if (ret != NNACL_OK) { 4612- return NNACL_ERR; 4613- } 4614+ hidden_size = inputs[THIRD_INPUT]->shape_[THIRD_INPUT]; 4615+ out_size = inputs[SECOND_INPUT]->shape_[THIRD_INPUT]; 4616+ } else { 4617+ int ret = InferFirstOutputNonMindir(inputs, inputs_size, output, param); 4618+ if (ret != NNACL_OK) { 4619+ return ret; 4620 } 4621+ hidden_size = inputs[SIXTH_INPUT]->shape_[inputs[SIXTH_INPUT]->shape_size_ - 1]; 4622+ out_size = inputs[FIFTH_INPUT]->shape_[inputs[FIFTH_INPUT]->shape_size_ - 1]; 4623 } 4624- SetShapeArray(output, out_shape, out_shape_size); 4625+ 4626+ int dir_multiplier = param->bidirectional_ ? C2NUM : C1NUM; 4627 int state_shape[MAX_SHAPE_SIZE]; 4628 size_t state_shape_size = 0; 4629 4630 ShapeSet(state_shape, &state_shape_size, input->shape_, input->shape_size_); 4631 state_shape[FIRST_INPUT] = dir_multiplier; 4632- state_shape[THIRD_INPUT] = project_size; 4633+ state_shape[THIRD_INPUT] = out_size; 4634 SetShapeArray(outputs[SECOND_INPUT], state_shape, state_shape_size); 4635 state_shape[THIRD_INPUT] = hidden_size; 4636 SetShapeArray(outputs[THIRD_INPUT], state_shape, state_shape_size); 4637@@ -116,11 +144,9 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o 4638 const size_t intermediate_states_shape_size = 1; 4639 int batch_size = input->shape_[SECOND_INPUT]; 4640 int seq_len = input->shape_[FIRST_INPUT]; 4641- intermediate_states_shape[FIRST_INPUT] = no_of_recorde_values * batch_size * hidden_size * seq_len * dir_multiplier; 4642- SetDataTypeFormat(outputs[FOURTH_INPUT], inputs[FIRST_INPUT]); 4643+ intermediate_states_shape[FIRST_INPUT] = 4644+ batch_size * seq_len * dir_multiplier * 
(out_size + no_of_recorde_values * hidden_size); 4645 SetShapeArray(outputs[FOURTH_INPUT], intermediate_states_shape, intermediate_states_shape_size); 4646- 4647- SetDataTypeFormat(outputs[FIFTH_INPUT], inputs[FIRST_INPUT]); 4648 SetShapeArray(outputs[FIFTH_INPUT], state_shape, state_shape_size); 4649 } 4650 4651diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c 4652index 287e9de3..3c192df7 100644 4653--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c 4654+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c 4655@@ -33,12 +33,14 @@ int CalShape(const int *data, const TensorC *const *inputs, int *out_shape, size 4656 } 4657 ShapePush(out_shape, out_shape_size, data[i]); 4658 } 4659- 4660+ if (size == 0) { 4661+ return NNACL_ERR; 4662+ } 4663 if ((int)(data[index]) == -1) { 4664 if (index >= MAX_SHAPE_SIZE) { 4665 return NNACL_ERR; 4666 } 4667- out_shape[index] = size == 0 ? 
0 : input_count / size; 4668+ out_shape[index] = input_count / size; 4669 } 4670 return NNACL_OK; 4671 } 4672diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h 4673index 377993cd..6a933785 100644 4674--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h 4675+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h 4676@@ -308,7 +308,7 @@ static inline float simd_exp32_f32(float data) { 4677 #else 4678 data = MS_MAX32_F32(-88.0f, MS_MIN32_F32(88.0f, data)); // clamp(-88, 88) 4679 #endif 4680- int integer = floor(data * 1.44269504088896341f + 0.5f); 4681+ int integer = data / param[0]; 4682 float decimal = data - integer * param[0]; 4683 fi int_exp; 4684 int_exp.i = (integer + 127) << 23; // Approximate calculation : (integer + 127) << 23 4685@@ -324,14 +324,19 @@ static inline void simd_exp32(float src, float *dst) { 4686 int i; 4687 } fi; 4688 static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f}; // log(2.0f) 4689- src = MS_MAX32_F32(-88.0f, MS_MIN32_F32(88.0f, src)); // clamp(-88.0f, 88.0f) 4690+ src = MS_MAX32_F32(-87.3365478515625f, MS_MIN32_F32(88.72283935546875f, src)); // clamp(logf(FLT_MIN), logf(FLT_MAX)) 4691 int integer = floor(src * 1.44269504088896341f + 0.5f); 4692 float decimal = src - integer * param[0]; 4693 fi int_exp; 4694- int_exp.i = (integer + 127) << 23; // integer num approximate calculation : (x + 127) << 23 4695+ const int shift = 23; 4696+ const int bias = 126; 4697+ const float factor = 2; 4698+ // 2^n * exp(r) should be counted 2 * 2^(n - 1) * exp(r), 4699+ // because n may be 128, and it is not representable by fp32. 
4700+ int_exp.i = (integer + bias) << shift; // integer num 2^(n - 1) approximate calculation : ((x - 1) + 127) << 23 4701 const float decimal_exp = 4702 1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); 4703- *dst = int_exp.f * decimal_exp; 4704+ *dst = factor * int_exp.f * decimal_exp; 4705 } 4706 4707 // define (float/int) data 4708diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions_fp16.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions_fp16.h 4709index a29c4dbb..94ed4b89 100644 4710--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions_fp16.h 4711+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions_fp16.h 4712@@ -94,9 +94,13 @@ static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) { 4713 4714 #define MS_FLOAT16X8 float16x8_t 4715 #define MS_FLOAT16X4 float16x4_t 4716+#define MS_FLOAT16X4X4 float16x4x4_t 4717+#define MS_FLOAT16X4X2 float16x4x2_t 4718 #define MS_MOVQ_F16 vmovq_n_f16 4719 #define MS_STQ_F16(ptr, val) vst1q_f16(ptr, val) 4720 #define MS_ST_F16 vst1_f16 4721+#define MS_ST2_F16 vst2_f16 4722+#define MS_ST4_F16 vst4_f16 4723 #define MS_MINQ_F16 vminq_f16 4724 #define MS_MAXQ_F16 vmaxq_f16 4725 #define MS_LDQ_F16(ptr) vld1q_f16(ptr) 4726diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_neon_instructions.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_neon_instructions.h 4727index c4bc34d9..fb38b452 100644 4728--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_neon_instructions.h 4729+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_neon_instructions.h 4730@@ -25,6 +25,8 @@ 4731 #define MS128_F32_GETI(src, i) src[i] 4732 #define MS_FLOAT32X4 float32x4_t 4733 #define MS_FLOAT128_F32 float32x4_t 4734+#define MS_FLOAT32X4X2 float32x4x2_t 4735+#define MS_FLOAT32X4X4 
float32x4x4_t 4736 #define MS_INT32X4 int32x4_t 4737 #define MS_INT128_EPI32 int32x4_t 4738 #define MS_UINT32X4 uint32x4_t 4739@@ -222,29 +224,30 @@ static inline MS_FLOAT32X4 VexpFp32(MS_FLOAT32X4 input) { 4740 {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, 4741 {0.5f, 0.5f, 0.5f, 0.5f}, 4742 {1.0f, 1.0f, 1.0f, 1.0f}, 4743- {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f}}; 4744+ {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f}, 4745+ {2.0f, 2.0f, 2.0f, 2.0f}}; 4746 static MS_FLOAT32X4 negative_flag = {-0.0f, -0.0f, -0.0f, -0.0f}; 4747 4748 MS_INT32X4 integer = 4749 MS_CVTQPS_EPI32(MS_FMADD128_F32(input, param[6], MS_OR128_F32(MS_AND128_F32(input, negative_flag), param[4]))); 4750 MS_FLOAT32X4 decimal = MS_SUBQ_F32(input, MS_MULQ_F32(MS_CVTQEPI32_PS(integer), param[0])); 4751- MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(127)), 23); 4752+ MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(126)), 23); 4753 MS_FLOAT32X4 tmp = MS_MULQ_F32(decimal, (MS_ADDQ_F32(param[2], MS_MULQ_F32(decimal, param[1])))); 4754 tmp = MS_MULQ_F32(decimal, MS_ADDQ_F32(param[4], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[3], tmp)))); 4755 MS_FLOAT32X4 decimal_exp = MS_ADDQ_F32(param[5], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[5], tmp))); 4756- return MS_MULQ_F32(decimal_exp, MS_CAST128_F32_S32(int_exp)); 4757+ return MS_MULQ_F32(param[7], MS_MULQ_F32(decimal_exp, MS_CAST128_F32_S32(int_exp))); 4758 } 4759 4760 static inline void simd_exp128(MS_FLOAT32X4 input, float *dst) { 4761- static MS_FLOAT32X4 maxv = {88.0f, 88.0f, 88.0f, 88.0f}; 4762- static MS_FLOAT32X4 minv = {-88.0f, -88.0f, -88.0f, -88.0f}; 4763+ static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; 4764+ static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; 4765 input = MS_MAXQ_F32(minv, 
MS_MINQ_F32(input, maxv)); 4766 MS_STQ_F32(dst, VexpFp32(input)); 4767 } 4768 4769 static inline MS_FLOAT32X4 simd_exp128_f32(MS_FLOAT32X4 input) { 4770- static MS_FLOAT32X4 maxv = {88.0f, 88.0f, 88.0f, 88.0f}; 4771- static MS_FLOAT32X4 minv = {-88.0f, -88.0f, -88.0f, -88.0f}; 4772+ static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; 4773+ static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; 4774 input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); 4775 return VexpFp32(input); 4776 } 4777@@ -286,18 +289,6 @@ static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) { 4778 return res; 4779 } 4780 4781-static inline MS_FLOAT128_F32 SIMD_SIGN128_F32(MS_FLOAT128_F32 src) { 4782- MS_FLOAT128_F32 abs_src = MS_ABS128_F32(src); 4783- MS_FLOAT128_F32 src_tmp = MS_OR128_F32(src, MS_MOV128_F32(1.0f)); 4784- MS_FLOAT128_F32 sign = MS_DIV128_F32(abs_src, src_tmp); 4785- return sign; 4786-} 4787- 4788-static inline MS_FLOAT128_F32 SIMD_SIGNABS128_F32(MS_FLOAT128_F32 src, MS_FLOAT128_F32 abs_src) { 4789- MS_FLOAT128_F32 src_tmp = MS_OR128_F32(src, MS_MOV128_F32(1.0f)); 4790- return MS_DIV128_F32(abs_src, src_tmp); 4791-} 4792- 4793 #define MS_TANH128_F32 MS_TANHX4_F32 4794 4795 static inline MS_FLOAT32X4 MS128_ERF_F32(MS_FLOAT32X4 src) { 4796diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/lstm_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/lstm_parameter.h 4797index 9ecd8409..5baf10fa 100644 4798--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/lstm_parameter.h 4799+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/lstm_parameter.h 4800@@ -25,6 +25,7 @@ typedef struct LstmParameter { 4801 int input_size_; 4802 int hidden_size_; 4803 int project_size_; 4804+ int output_size_; 4805 int seq_len_; 4806 int batch_; 4807 // other parameter 4808@@ -36,6 +37,8 @@ typedef struct LstmParameter { 4809 int input_col_align_; 4810 int 
state_row_align_; 4811 int state_col_align_; 4812+ int proj_col_align_; 4813+ bool has_bias_; 4814 } LstmParameter; 4815 4816 #endif // NNACL_LSTM_PARAMETER_H_ 4817diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 4818index 895f7e3d..bd0d152c 100644 4819--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 4820+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 4821@@ -562,6 +562,7 @@ enum PrimType { 4822 PrimType_Inner_CustomMaskedFill = 10014, 4823 PrimType_Inner_CustomTensorScatterMax = 10015, 4824 PrimType_Inner_CustomIsInf = 10016, 4825+ PrimType_Inner_CustomGatherDGradV2 = 10017, 4826 PrimType_InnerOpMax, 4827 PrimType_InnerOpMin = PrimType_Inner_ToFormat 4828 }; 4829diff --git a/mindspore/core/mindrt/src/thread/threadpool.cc b/mindspore/core/mindrt/src/thread/threadpool.cc 4830index 2301be8c..342ffb7f 100644 4831--- a/mindspore/core/mindrt/src/thread/threadpool.cc 4832+++ b/mindspore/core/mindrt/src/thread/threadpool.cc 4833@@ -53,7 +53,7 @@ Worker::~Worker() { 4834 void Worker::CreateThread() { thread_ = std::make_unique<std::thread>(&Worker::Run, this); } 4835 4836 void Worker::ReinitAfterFork() { 4837- THREAD_INFO("worker %ld recreate thread after fork in child process", worker_id_); 4838+ THREAD_INFO("worker %zu recreate thread after fork in child process", worker_id_); 4839 if (cond_var_ != nullptr) { 4840 (void)cond_var_.release(); 4841 cond_var_ = std::make_unique<std::condition_variable>(); 4842diff --git a/mindspore/core/ops/base_operator.h b/mindspore/core/ops/base_operator.h 4843index 811a6000..23652e8e 100644 4844--- a/mindspore/core/ops/base_operator.h 4845+++ b/mindspore/core/ops/base_operator.h 4846@@ -75,7 +75,7 @@ class MIND_API OperatorRegisterHelper { 4847 public: 4848 OperatorRegisterHelper(const std::string &kname, const OperatorDefineFunc &fn) { 4849 OperatorRegister::GetInstance().SetOperatorMap(kname, fn); 4850- (void)id_; // make 
compiler happy on macos 4851+ // (void)id_; // make compiler happy on macos 4852 } 4853 4854 ~OperatorRegisterHelper() = default; 4855diff --git a/mindspore/core/ops/grad/gather_d_grad_v2.cc b/mindspore/core/ops/grad/gather_d_grad_v2.cc 4856index 3ce5f887..c999ca88 100644 4857--- a/mindspore/core/ops/grad/gather_d_grad_v2.cc 4858+++ b/mindspore/core/ops/grad/gather_d_grad_v2.cc 4859@@ -75,6 +75,11 @@ TypePtr GatherDGradV2InferType(const PrimitivePtr &prim, const std::vector<Abstr 4860 } 4861 } // namespace 4862 4863+int64_t GatherDGradV2::get_dim() const { 4864+ auto value_ptr = this->GetAttr(kDim); 4865+ return GetValue<int64_t>(value_ptr); 4866+} 4867+ 4868 AbstractBasePtr GatherDGradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, 4869 const std::vector<AbstractBasePtr> &input_args) { 4870 auto infer_type = GatherDGradV2InferType(primitive, input_args); 4871diff --git a/mindspore/core/ops/grad/gather_d_grad_v2.h b/mindspore/core/ops/grad/gather_d_grad_v2.h 4872index 94274e3b..40a6e412 100644 4873--- a/mindspore/core/ops/grad/gather_d_grad_v2.h 4874+++ b/mindspore/core/ops/grad/gather_d_grad_v2.h 4875@@ -25,6 +25,7 @@ class MIND_API GatherDGradV2 : public BaseOperator { 4876 public: 4877 MIND_API_BASE_MEMBER(GatherDGradV2); 4878 GatherDGradV2() : BaseOperator(kNameGatherDGradV2) { InitIOName({"x", "dim", "index", "grad"}, {"output"}); } 4879+ int64_t get_dim() const; 4880 }; 4881 MIND_API abstract::AbstractBasePtr GatherDGradV2Infer(const abstract::AnalysisEnginePtr &, 4882 const PrimitivePtr &primitive, 4883diff --git a/mindspore/core/ops/grad/lstm_grad.cc b/mindspore/core/ops/grad/lstm_grad.cc 4884index d51c4882..c25e0379 100644 4885--- a/mindspore/core/ops/grad/lstm_grad.cc 4886+++ b/mindspore/core/ops/grad/lstm_grad.cc 4887@@ -98,15 +98,22 @@ void LSTMGrad::set_zoneout_hidden(float zoneout_hidden) { 4888 4889 float LSTMGrad::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); } 4890 4891+void 
LSTMGrad::set_proj_size(const int64_t proj_size) { 4892+ (void)CheckAndConvertUtils::CheckInteger(kProjection_size, proj_size, kGreaterThan, 0, this->name()); 4893+ (void)AddAttr(kProjection_size, api::MakeValue(proj_size)); 4894+} 4895+int64_t LSTMGrad::get_proj_size() const { return GetValue<int64_t>(GetAttr(kProjection_size)); } 4896+ 4897 void LSTMGrad::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, 4898- const float dropout, const bool bidirectional, const float zoneout_cell, 4899- const float zoneout_hidden) { 4900+ const float dropout, const bool bidirectional, const float zoneout_cell, const float zoneout_hidden, 4901+ const int64_t proj_size) { 4902 this->set_input_size(input_size); 4903 this->set_hidden_size(hidden_size); 4904 this->set_num_layers(num_layers); 4905 this->set_has_bias(has_bias); 4906 this->set_dropout(dropout); 4907 this->set_bidirectional(bidirectional); 4908+ this->set_proj_size(proj_size); 4909 if (bidirectional) { 4910 constexpr int k2Directions = 2; 4911 this->set_num_directions(k2Directions); 4912diff --git a/mindspore/core/ops/grad/lstm_grad.h b/mindspore/core/ops/grad/lstm_grad.h 4913index 73272d55..f6eba32c 100644 4914--- a/mindspore/core/ops/grad/lstm_grad.h 4915+++ b/mindspore/core/ops/grad/lstm_grad.h 4916@@ -31,7 +31,7 @@ class MIND_API LSTMGrad : public BaseOperator { 4917 LSTMGrad() : BaseOperator(kNameLSTMGrad) {} 4918 void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, 4919 const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f, 4920- const float zoneout_hidden = 0.0f); 4921+ const float zoneout_hidden = 0.0f, const int64_t proj_size = 0); 4922 void set_input_size(const int64_t input_size); 4923 int64_t get_input_size() const; 4924 void set_hidden_size(const int64_t hidden_size); 4925@@ -51,6 +51,8 @@ class MIND_API LSTMGrad : public BaseOperator { 4926 void 
set_zoneout_hidden(float zoneout_hidden); 4927 float get_zoneout_hidden() const; 4928 int64_t get_good_ld(const int64_t dim, const int64_t type_size); 4929+ void set_proj_size(const int64_t proj_size); 4930+ int64_t get_proj_size() const; 4931 }; 4932 } // namespace ops 4933 } // namespace mindspore 4934diff --git a/mindspore/core/ops/grad/lstm_grad_data.cc b/mindspore/core/ops/grad/lstm_grad_data.cc 4935index 573d26f4..2b25282c 100644 4936--- a/mindspore/core/ops/grad/lstm_grad_data.cc 4937+++ b/mindspore/core/ops/grad/lstm_grad_data.cc 4938@@ -91,15 +91,23 @@ void LSTMGradData::set_zoneout_hidden(float zoneout_hidden) { 4939 4940 float LSTMGradData::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); } 4941 4942+void LSTMGradData::set_proj_size(const int64_t proj_size) { 4943+ (void)CheckAndConvertUtils::CheckInteger(kProjection_size, proj_size, kGreaterThan, 0, this->name()); 4944+ (void)AddAttr(kProjection_size, api::MakeValue(proj_size)); 4945+} 4946+ 4947+int64_t LSTMGradData::get_proj_size() const { return GetValue<int64_t>(GetAttr(kProjection_size)); } 4948+ 4949 void LSTMGradData::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, 4950 const bool has_bias, const float dropout, const bool bidirectional, const float zoneout_cell, 4951- const float zoneout_hidden) { 4952+ const float zoneout_hidden, const int64_t proj_size) { 4953 this->set_input_size(input_size); 4954 this->set_hidden_size(hidden_size); 4955 this->set_num_layers(num_layers); 4956 this->set_has_bias(has_bias); 4957 this->set_dropout(dropout); 4958 this->set_bidirectional(bidirectional); 4959+ this->set_proj_size(proj_size); 4960 if (bidirectional) { 4961 constexpr int k2Directions = 2; 4962 this->set_num_directions(k2Directions); 4963diff --git a/mindspore/core/ops/grad/lstm_grad_data.h b/mindspore/core/ops/grad/lstm_grad_data.h 4964index adcf2ee7..f93e3260 100644 4965--- a/mindspore/core/ops/grad/lstm_grad_data.h 4966+++ 
b/mindspore/core/ops/grad/lstm_grad_data.h 4967@@ -32,7 +32,7 @@ class MIND_API LSTMGradData : public BaseOperator { 4968 LSTMGradData() : BaseOperator(kNameLSTMGradData) {} 4969 void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, 4970 const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f, 4971- const float zoneout_hidden = 0.0f); 4972+ const float zoneout_hidden = 0.0f, const int64_t proj_size = 0); 4973 void set_input_size(const int64_t input_size); 4974 int64_t get_input_size() const; 4975 void set_hidden_size(const int64_t hidden_size); 4976@@ -52,6 +52,8 @@ class MIND_API LSTMGradData : public BaseOperator { 4977 void set_zoneout_hidden(float zoneout_hidden); 4978 float get_zoneout_hidden() const; 4979 int64_t get_good_ld(const int64_t dim, const int64_t type_size); 4980+ void set_proj_size(const int64_t proj_size); 4981+ int64_t get_proj_size() const; 4982 }; 4983 MIND_API abstract::AbstractBasePtr LstmGradDataInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, 4984 const std::vector<abstract::AbstractBasePtr> &input_args); 4985diff --git a/mindspore/core/ops/grad/lstm_grad_weight.cc b/mindspore/core/ops/grad/lstm_grad_weight.cc 4986index 22b519c3..ce0aca94 100644 4987--- a/mindspore/core/ops/grad/lstm_grad_weight.cc 4988+++ b/mindspore/core/ops/grad/lstm_grad_weight.cc 4989@@ -88,15 +88,23 @@ void LSTMGradWeight::set_zoneout_hidden(float zoneout_hidden) { 4990 4991 float LSTMGradWeight::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); } 4992 4993+void LSTMGradWeight::set_proj_size(const int64_t proj_size) { 4994+ (void)CheckAndConvertUtils::CheckInteger(kProjection_size, proj_size, kGreaterThan, 0, this->name()); 4995+ (void)AddAttr(kProjection_size, api::MakeValue(proj_size)); 4996+} 4997+ 4998+int64_t LSTMGradWeight::get_proj_size() const { return GetValue<int64_t>(GetAttr(kProjection_size)); } 4999+ 5000 void 
LSTMGradWeight::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, 5001 const bool has_bias, const float dropout, const bool bidirectional, const float zoneout_cell, 5002- const float zoneout_hidden) { 5003+ const float zoneout_hidden, const int64_t proj_size) { 5004 this->set_input_size(input_size); 5005 this->set_hidden_size(hidden_size); 5006 this->set_num_layers(num_layers); 5007 this->set_has_bias(has_bias); 5008 this->set_dropout(dropout); 5009 this->set_bidirectional(bidirectional); 5010+ this->set_proj_size(proj_size); 5011 if (bidirectional) { 5012 constexpr int k2Directions = 2; 5013 this->set_num_directions(k2Directions); 5014diff --git a/mindspore/core/ops/grad/lstm_grad_weight.h b/mindspore/core/ops/grad/lstm_grad_weight.h 5015index c2ca6b5e..add816d3 100644 5016--- a/mindspore/core/ops/grad/lstm_grad_weight.h 5017+++ b/mindspore/core/ops/grad/lstm_grad_weight.h 5018@@ -32,7 +32,7 @@ class MIND_API LSTMGradWeight : public BaseOperator { 5019 LSTMGradWeight() : BaseOperator(kNameLSTMGradWeight) {} 5020 void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, 5021 const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f, 5022- const float zoneout_hidden = 0.0f); 5023+ const float zoneout_hidden = 0.0f, const int64_t proj_size = 0); 5024 void set_input_size(const int64_t input_size); 5025 int64_t get_input_size() const; 5026 void set_hidden_size(const int64_t hidden_size); 5027@@ -52,6 +52,8 @@ class MIND_API LSTMGradWeight : public BaseOperator { 5028 void set_zoneout_hidden(float zoneout_hidden); 5029 float get_zoneout_hidden() const; 5030 int64_t get_good_ld(const int64_t dim, const int64_t type_size); 5031+ void set_proj_size(const int64_t proj_size); 5032+ int64_t get_proj_size() const; 5033 }; 5034 MIND_API abstract::AbstractBasePtr LstmGradWeightInfer(const abstract::AnalysisEnginePtr &, 5035 const PrimitivePtr &primitive, 5036diff 
--git a/mindspore/core/ops/lstm.cc b/mindspore/core/ops/lstm.cc 5037index 43b9241c..937207df 100644 5038--- a/mindspore/core/ops/lstm.cc 5039+++ b/mindspore/core/ops/lstm.cc 5040@@ -68,6 +68,7 @@ abstract::TupleShapePtr LSTMInferShape(const PrimitivePtr &primitive, const std: 5041 int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size)); 5042 int64_t num_layers = GetValue<int64_t>(primitive->GetAttr(kNumLayers)); 5043 bool bidirectional = GetValue<bool>(primitive->GetAttr(kBidirectional)); 5044+ int64_t proj_size = GetValue<int64_t>(primitive->GetAttr(kProjection_size)); 5045 int64_t num_directions = 1; 5046 if (bidirectional) { 5047 num_directions = 2; 5048@@ -90,7 +91,8 @@ abstract::TupleShapePtr LSTMInferShape(const PrimitivePtr &primitive, const std: 5049 (void)CheckAndConvertUtils::CheckInteger("h_shape[1]", h_input_shape[1], kEqual, x_input_shape[1], prim_name); 5050 } 5051 5052- std::vector<int64_t> y_shape = {x_input_shape[0], x_input_shape[1], hidden_size * num_directions}; 5053+ auto real_hidden_size = proj_size > 0 ? 
proj_size : hidden_size; 5054+ std::vector<int64_t> y_shape = {x_input_shape[0], x_input_shape[1], real_hidden_size * num_directions}; 5055 std::vector<int64_t> h_shape = {h_input_shape}; 5056 std::vector<int64_t> c_shape = {c_input_shape}; 5057 std::vector<int64_t> reverse_shape = {1, 1}; 5058@@ -135,6 +137,11 @@ void LSTM::set_hidden_size(const int64_t hidden_size) { 5059 (void)AddAttr(kHidden_size, api::MakeValue(hidden_size)); 5060 } 5061 int64_t LSTM::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); } 5062+void LSTM::set_proj_size(const int64_t proj_size) { 5063+ (void)CheckAndConvertUtils::CheckInteger(kProjection_size, proj_size, kGreaterThan, 0, this->name()); 5064+ (void)AddAttr(kProjection_size, api::MakeValue(proj_size)); 5065+} 5066+int64_t LSTM::get_proj_size() const { return GetValue<int64_t>(GetAttr(kProjection_size)); } 5067 void LSTM::set_num_layers(const int64_t num_layers) { 5068 (void)CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name()); 5069 (void)AddAttr(kNumLayers, api::MakeValue(num_layers)); 5070diff --git a/mindspore/core/ops/lstm.h b/mindspore/core/ops/lstm.h 5071index 4d3c8756..e32c5781 100644 5072--- a/mindspore/core/ops/lstm.h 5073+++ b/mindspore/core/ops/lstm.h 5074@@ -51,6 +51,12 @@ class MIND_API LSTM : public BaseOperator { 5075 /// 5076 /// \return hidden_size. 5077 int64_t get_hidden_size() const; 5078+ /// \brief Set proj_size. 5079+ void set_proj_size(const int64_t proj_size); 5080+ /// \brief Get proj_size. 5081+ /// 5082+ /// \return proj_size. 5083+ int64_t get_proj_size() const; 5084 /// \brief Set num_layers. 5085 void set_num_layers(const int64_t num_layers); 5086 /// \brief Get num_layers. 
5087diff --git a/mindspore/core/ops/op_name.h b/mindspore/core/ops/op_name.h 5088index ce68079f..ad9066e7 100644 5089--- a/mindspore/core/ops/op_name.h 5090+++ b/mindspore/core/ops/op_name.h 5091@@ -268,6 +268,7 @@ constexpr auto kWindowSize = "window_size"; 5092 constexpr auto kPaddings = "paddings"; 5093 constexpr auto kInput_size = "input_size"; 5094 constexpr auto kHidden_size = "hidden_size"; 5095+constexpr auto kProjection_size = "proj_size"; 5096 constexpr auto kChannelShared = "channel_shared"; 5097 constexpr auto kSlope = "slope"; 5098 constexpr auto kBase = "base"; 5099diff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn 5100index f7e465e2..9318d54e 100644 5101--- a/mindspore/lite/BUILD.gn 5102+++ b/mindspore/lite/BUILD.gn 5103@@ -602,6 +602,8 @@ all_train_sources = [ 5104 "src/train/optimizer/fusion/matmul_activation_fusion_pass.cc", 5105 "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc", 5106 "src/train/optimizer/fusion/gru_fusion_pass.cc", 5107+ "src/train/optimizer/fusion/matmul_add_fusion_pass.cc", 5108+ "src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc", 5109 "src/common/storage.cc", 5110 "tools/converter/optimizer.cc", 5111 "tools/converter/legacy_optimizer/fusion/fusion_pass.cc", 5112@@ -646,6 +648,7 @@ fp32_train_kernel_sources = [ 5113 "src/litert/kernel/cpu/fp32_grad/convolution.cc", 5114 "src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc", 5115 "src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc", 5116+ "src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.cc", 5117 "src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc", 5118 "src/litert/kernel/cpu/fp32_grad/dropout.cc", 5119 "src/litert/kernel/cpu/fp32_grad/dropout_grad.cc", 5120diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt 5121index 1faf2f38..f2b5809f 100644 5122--- a/mindspore/lite/CMakeLists.txt 5123+++ b/mindspore/lite/CMakeLists.txt 5124@@ -977,7 +977,7 @@ if(MSLITE_MINDDATA_IMPLEMENT 
STREQUAL "lite" OR MSLITE_MINDDATA_IMPLEMENT STREQU 5125 endif() 5126 5127 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/common/ops) 5128-if(ANDROID_NDK_TOOLCHAIN_INCLUDED OR TARGET_OHOS_LITE OR TARGET_HIMIX) 5129+if(ANDROID_NDK_TOOLCHAIN_INCLUDED OR TARGET_OHOS_LITE OR TARGET_HIMIX OR TARGET_OHOS) 5130 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/coder) 5131 endif() 5132 5133diff --git a/mindspore/lite/schema/inner/ops_generated.h b/mindspore/lite/schema/inner/ops_generated.h 5134index c4fd8c15..6c861aa5 100644 5135--- a/mindspore/lite/schema/inner/ops_generated.h 5136+++ b/mindspore/lite/schema/inner/ops_generated.h 5137@@ -11338,6 +11338,7 @@ struct LSTMT : public flatbuffers::NativeTable { 5138 float dropout = 0.0f; 5139 float zoneout_cell = 0.0f; 5140 float zoneout_hidden = 0.0f; 5141+ int64_t proj_size = 0; 5142 }; 5143 5144 struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5145@@ -11355,7 +11356,8 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5146 VT_NUM_DIRECTIONS = 14, 5147 VT_DROPOUT = 16, 5148 VT_ZONEOUT_CELL = 18, 5149- VT_ZONEOUT_HIDDEN = 20 5150+ VT_ZONEOUT_HIDDEN = 20, 5151+ VT_PROJ_SIZE = 22 5152 }; 5153 bool bidirectional() const { 5154 return GetField<uint8_t>(VT_BIDIRECTIONAL, 0) != 0; 5155@@ -11411,6 +11413,12 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5156 bool mutate_zoneout_hidden(float _zoneout_hidden) { 5157 return SetField<float>(VT_ZONEOUT_HIDDEN, _zoneout_hidden, 0.0f); 5158 } 5159+ int64_t proj_size() const { 5160+ return GetField<int64_t>(VT_PROJ_SIZE, 0); 5161+ } 5162+ bool mutate_proj_size(int64_t _proj_size) { 5163+ return SetField<int64_t>(VT_PROJ_SIZE, _proj_size, 0); 5164+ } 5165 bool Verify(flatbuffers::Verifier &verifier) const { 5166 return VerifyTableStart(verifier) && 5167 VerifyField<uint8_t>(verifier, VT_BIDIRECTIONAL) && 5168@@ -11422,6 +11430,7 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5169 
VerifyField<float>(verifier, VT_DROPOUT) && 5170 VerifyField<float>(verifier, VT_ZONEOUT_CELL) && 5171 VerifyField<float>(verifier, VT_ZONEOUT_HIDDEN) && 5172+ VerifyField<int64_t>(verifier, VT_PROJ_SIZE) && 5173 verifier.EndTable(); 5174 } 5175 LSTMT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; 5176@@ -11460,6 +11469,9 @@ struct LSTMBuilder { 5177 void add_zoneout_hidden(float zoneout_hidden) { 5178 fbb_.AddElement<float>(LSTM::VT_ZONEOUT_HIDDEN, zoneout_hidden, 0.0f); 5179 } 5180+ void add_proj_size(int64_t proj_size) { 5181+ fbb_.AddElement<int64_t>(LSTM::VT_PROJ_SIZE, proj_size, 0); 5182+ } 5183 explicit LSTMBuilder(flatbuffers::FlatBufferBuilder &_fbb) 5184 : fbb_(_fbb) { 5185 start_ = fbb_.StartTable(); 5186@@ -11481,8 +11493,10 @@ inline flatbuffers::Offset<LSTM> CreateLSTM( 5187 int64_t num_directions = 0, 5188 float dropout = 0.0f, 5189 float zoneout_cell = 0.0f, 5190- float zoneout_hidden = 0.0f) { 5191+ float zoneout_hidden = 0.0f, 5192+ int64_t proj_size = 0) { 5193 LSTMBuilder builder_(_fbb); 5194+ builder_.add_proj_size(proj_size); 5195 builder_.add_num_directions(num_directions); 5196 builder_.add_num_layers(num_layers); 5197 builder_.add_hidden_size(hidden_size); 5198@@ -23571,6 +23585,7 @@ inline void LSTM::UnPackTo(LSTMT *_o, const flatbuffers::resolver_function_t *_r 5199 { auto _e = dropout(); _o->dropout = _e; } 5200 { auto _e = zoneout_cell(); _o->zoneout_cell = _e; } 5201 { auto _e = zoneout_hidden(); _o->zoneout_hidden = _e; } 5202+ { auto _e = proj_size(); _o->proj_size = _e; } 5203 } 5204 5205 inline flatbuffers::Offset<LSTM> LSTM::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMT* _o, const flatbuffers::rehasher_function_t *_rehasher) { 5206@@ -23590,6 +23605,7 @@ inline flatbuffers::Offset<LSTM> CreateLSTM(flatbuffers::FlatBufferBuilder &_fbb 5207 auto _dropout = _o->dropout; 5208 auto _zoneout_cell = _o->zoneout_cell; 5209 auto _zoneout_hidden = _o->zoneout_hidden; 5210+ auto _proj_size = 
_o->proj_size; 5211 return mindspore::schema::CreateLSTM( 5212 _fbb, 5213 _bidirectional, 5214@@ -23600,7 +23616,8 @@ inline flatbuffers::Offset<LSTM> CreateLSTM(flatbuffers::FlatBufferBuilder &_fbb 5215 _num_directions, 5216 _dropout, 5217 _zoneout_cell, 5218- _zoneout_hidden); 5219+ _zoneout_hidden, 5220+ _proj_size); 5221 } 5222 5223 inline LSTMGradT *LSTMGrad::UnPack(const flatbuffers::resolver_function_t *_resolver) const { 5224diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs 5225index 76caf810..920c0d31 100644 5226--- a/mindspore/lite/schema/ops.fbs 5227+++ b/mindspore/lite/schema/ops.fbs 5228@@ -688,6 +688,7 @@ table LSTM { 5229 dropout: float; 5230 zoneout_cell: float = 0; 5231 zoneout_hidden: float = 0; 5232+ proj_size: long = 0; 5233 } 5234 5235 table LSTMGrad { 5236diff --git a/mindspore/lite/schema/ops_generated.h b/mindspore/lite/schema/ops_generated.h 5237index 2f792706..8d387e9d 100644 5238--- a/mindspore/lite/schema/ops_generated.h 5239+++ b/mindspore/lite/schema/ops_generated.h 5240@@ -7046,7 +7046,8 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5241 VT_NUM_DIRECTIONS = 14, 5242 VT_DROPOUT = 16, 5243 VT_ZONEOUT_CELL = 18, 5244- VT_ZONEOUT_HIDDEN = 20 5245+ VT_ZONEOUT_HIDDEN = 20, 5246+ VT_PROJ_SIZE = 22 5247 }; 5248 bool bidirectional() const { 5249 return GetField<uint8_t>(VT_BIDIRECTIONAL, 0) != 0; 5250@@ -7075,6 +7076,9 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5251 float zoneout_hidden() const { 5252 return GetField<float>(VT_ZONEOUT_HIDDEN, 0.0f); 5253 } 5254+ int64_t proj_size() const { 5255+ return GetField<int64_t>(VT_PROJ_SIZE, 0); 5256+ } 5257 bool Verify(flatbuffers::Verifier &verifier) const { 5258 return VerifyTableStart(verifier) && 5259 VerifyField<uint8_t>(verifier, VT_BIDIRECTIONAL) && 5260@@ -7086,6 +7090,7 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5261 VerifyField<float>(verifier, VT_DROPOUT) && 5262 
VerifyField<float>(verifier, VT_ZONEOUT_CELL) && 5263 VerifyField<float>(verifier, VT_ZONEOUT_HIDDEN) && 5264+ VerifyField<int64_t>(verifier, VT_PROJ_SIZE) && 5265 verifier.EndTable(); 5266 } 5267 }; 5268@@ -7121,6 +7126,9 @@ struct LSTMBuilder { 5269 void add_zoneout_hidden(float zoneout_hidden) { 5270 fbb_.AddElement<float>(LSTM::VT_ZONEOUT_HIDDEN, zoneout_hidden, 0.0f); 5271 } 5272+ void add_proj_size(int64_t proj_size) { 5273+ fbb_.AddElement<int64_t>(LSTM::VT_PROJ_SIZE, proj_size, 0); 5274+ } 5275 explicit LSTMBuilder(flatbuffers::FlatBufferBuilder &_fbb) 5276 : fbb_(_fbb) { 5277 start_ = fbb_.StartTable(); 5278@@ -7142,8 +7150,10 @@ inline flatbuffers::Offset<LSTM> CreateLSTM( 5279 int64_t num_directions = 0, 5280 float dropout = 0.0f, 5281 float zoneout_cell = 0.0f, 5282- float zoneout_hidden = 0.0f) { 5283+ float zoneout_hidden = 0.0f, 5284+ int64_t proj_size = 0) { 5285 LSTMBuilder builder_(_fbb); 5286+ builder_.add_proj_size(proj_size); 5287 builder_.add_num_directions(num_directions); 5288 builder_.add_num_layers(num_layers); 5289 builder_.add_hidden_size(hidden_size); 5290diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt 5291index de1781cd..469bcb6b 100644 5292--- a/mindspore/lite/src/CMakeLists.txt 5293+++ b/mindspore/lite/src/CMakeLists.txt 5294@@ -337,6 +337,8 @@ set(TRAIN_SRC 5295 ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/common/fusion_utils.cc 5296 ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/gru_fusion_pass.cc 5297 ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_activation_fusion_pass.cc 5298+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_add_fusion_pass.cc 5299+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc 5300 ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc 5301 ${TOOLS_DIR}/converter/optimizer.cc 5302 ${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc 5303diff --git 
a/mindspore/lite/src/common/ops/ops_def.cc b/mindspore/lite/src/common/ops/ops_def.cc 5304index e5c7f5ca..baa2497a 100644 5305--- a/mindspore/lite/src/common/ops/ops_def.cc 5306+++ b/mindspore/lite/src/common/ops/ops_def.cc 5307@@ -688,6 +688,7 @@ OP_ATTR(num_directions, long) 5308 OP_ATTR(dropout, float) 5309 OP_ATTR_WITH_VALUE(zoneout_cell, float, 0) 5310 OP_ATTR_WITH_VALUE(zoneout_hidden, float, 0) 5311+OP_ATTR_WITH_VALUE(proj_size, long, 0) 5312 OP_SCHEMA_DEF_END(LSTM) 5313 5314 OP_SCHEMA_DEF(LSTMGrad) 5315diff --git a/mindspore/lite/src/common/ops/populate/custom_populate.cc b/mindspore/lite/src/common/ops/populate/custom_populate.cc 5316index 13957ed7..6c490130 100644 5317--- a/mindspore/lite/src/common/ops/populate/custom_populate.cc 5318+++ b/mindspore/lite/src/common/ops/populate/custom_populate.cc 5319@@ -22,6 +22,7 @@ 5320 #include "nnacl/custom_masked_fill_parameter.h" 5321 #include "nnacl/custom_is_inf_parameter.h" 5322 #include "nnacl/custom_tensor_scatter_max_parameter.h" 5323+#include "nnacl/custom_gather_d_grad_v2_parameter.h" 5324 using mindspore::schema::PrimitiveType_Custom; 5325 5326 namespace mindspore { 5327@@ -128,6 +129,33 @@ OpParameter *CreateCustomMaskedFillParameter() { 5328 return reinterpret_cast<OpParameter *>(param); 5329 } 5330 5331+OpParameter *CreateCustomGatherDGradV2Parameter(const schema::Custom *value) { 5332+ if (value->attr()->size() < 1) { 5333+ return nullptr; 5334+ } 5335+ auto *param = static_cast<CustomGatherGradV2Parameter *>(malloc(sizeof(CustomGatherGradV2Parameter))); 5336+ if (param == nullptr) { 5337+ MS_LOG(ERROR) << "malloc CustomGruParameter failed."; 5338+ return nullptr; 5339+ } 5340+ 5341+ std::string dim_str; 5342+ auto attrs = value->attr(); 5343+ for (size_t i = 0; i < attrs->size(); i++) { 5344+ auto attr = attrs->Get(i); 5345+ if (attr->name()->str() == "dim") { 5346+ auto data = attr->data(); 5347+ dim_str = std::string(reinterpret_cast<const char *>(data->Data()), data->size()); 5348+ break; 5349+ } 
5350+ } 5351+ 5352+ memset(param, 0, sizeof(CustomGatherGradV2Parameter)); 5353+ param->dim = std::stoi(dim_str.c_str()); 5354+ param->op_parameter_.type_ = PrimType_Inner_CustomGatherDGradV2; 5355+ return reinterpret_cast<OpParameter *>(param); 5356+} 5357+ 5358 OpParameter *PopulateCustomParameter(const void *prim) { 5359 MS_CHECK_TRUE_RET(prim != nullptr, nullptr); 5360 auto primitive = static_cast<const schema::Primitive *>(prim); 5361@@ -167,6 +195,8 @@ OpParameter *PopulateCustomParameter(const void *prim) { 5362 return CreateCustomGruParameter(); 5363 } else if (type == "CastGatherReduceFusion") { 5364 return CreateParam(PrimType_Inner_CastGatherReduceFusion); 5365+ } else if (type == "GatherDGradV2") { 5366+ return CreateCustomGatherDGradV2Parameter(value); 5367 } else if (type == "ThirdPartyModel") { 5368 auto *param = static_cast<CustomParameter *>(malloc(sizeof(CustomParameter))); 5369 if (param == nullptr) { 5370diff --git a/mindspore/lite/src/common/ops/populate/lstm_populate.cc b/mindspore/lite/src/common/ops/populate/lstm_populate.cc 5371index 522da7ad..b3a85b64 100644 5372--- a/mindspore/lite/src/common/ops/populate/lstm_populate.cc 5373+++ b/mindspore/lite/src/common/ops/populate/lstm_populate.cc 5374@@ -37,8 +37,12 @@ OpParameter *PopulateLstmParameter(const void *prim) { 5375 5376 param->op_parameter_.type_ = primitive->value_type(); 5377 param->bidirectional_ = value->bidirectional(); 5378+ param->has_bias_ = value->has_bias(); 5379+ param->input_size_ = value->input_size(); 5380+ param->hidden_size_ = value->hidden_size(); 5381 param->zoneout_cell_ = value->zoneout_cell(); 5382 param->zoneout_hidden_ = value->zoneout_hidden(); 5383+ param->project_size_ = value->proj_size(); 5384 return reinterpret_cast<OpParameter *>(param); 5385 } 5386 5387diff --git a/mindspore/lite/src/common/prim_util.cc b/mindspore/lite/src/common/prim_util.cc 5388index 5ded05e9..7263775a 100644 5389--- a/mindspore/lite/src/common/prim_util.cc 5390+++ 
b/mindspore/lite/src/common/prim_util.cc 5391@@ -29,11 +29,25 @@ static std::set<schema::PrimitiveType> kTensorListOps = { 5392 schema::PrimitiveType_TensorListReserve, schema::PrimitiveType_TensorListSetItem, 5393 schema::PrimitiveType_TensorListStack}; 5394 5395-static const char *const kInnerOpNames[C10NUM] = {"Inner_ToFormat", "Inner_GltextureToOpencl", 5396- "Inner_Identity", "Inner_ShapeFusion", 5397- "Inner_GraphKernel", "Inner_SplitReduceConcatFusion", 5398- "Inner_EncoderLayer", "Inner_DecoderLayer", 5399- "Inner_UsePastEmbedding", "Inner_CustomGru"}; 5400+static const char *const kInnerOpNames[C20NUM] = {"Inner_ToFormat", 5401+ "Inner_GltextureToOpencl", 5402+ "Inner_Identity", 5403+ "Inner_ShapeFusion", 5404+ "Inner_GraphKernel", 5405+ "Inner_SplitReduceConcatFusion", 5406+ "Inner_EncoderLayer", 5407+ "PrimType_Inner_FseDecode", 5408+ "Inner_DecoderLayer", 5409+ "Inner_UsePastEmbedding", 5410+ "Inner_CustomGru", 5411+ "PrimType_Inner_CastGatherReduceFusion", 5412+ "PrimType_Inner_ReduceConcatFusion", 5413+ "PrimType_Inner_ThirdPartyModel", 5414+ "PrimType_Inner_CustomMaskedFill", 5415+ "PrimType_Inner_CustomTensorScatterMax", 5416+ "PrimType_Inner_CustomIsInf", 5417+ "PrimType_Inner_CustomGatherDGradV2"}; 5418+ 5419 int GetPrimitiveType(const void *primitive, int schema_version) { 5420 if (primitive == nullptr) { 5421 return -1; 5422diff --git a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 5423index 65065b5b..7b813314 100644 5424--- a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 5425+++ b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 5426@@ -85,6 +85,9 @@ cpu_kernel_sources = [ 5427 "fp32/invert_permutation_fp32.cc", 5428 "fp32/l2_norm_fp32.cc", 5429 "fp32/lstm_fp32.cc", 5430+ "fp32/lstm_fp32_base.cc", 5431+ "fp32/lstm_mindir_fp32.cc", 5432+ "fp32/lstm_non_mindir_fp32.cc", 5433 "fp32/matmul_fp32_arm32.cc", 5434 "fp32/matmul_fp32_arm64.cc", 5435 "fp32/matmul_fp32_avx512.cc", 5436@@ -174,6 +177,9 @@ 
fp16_kernel_sources = [ 5437 "fp16/instance_norm_fp16.cc", 5438 "fp16/layout_transform_fp16.cc", 5439 "fp16/lstm_fp16.cc", 5440+ "fp16/lstm_fp16_base.cc", 5441+ "fp16/lstm_mindir_fp16.cc", 5442+ "fp16/lstm_non_mindir_fp16.cc", 5443 "fp16/matmul_base_fp16.cc", 5444 "fp16/matmul_fp16.cc", 5445 "fp16/power_fp16.cc", 5446diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/gru_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/gru_fp16.cc 5447index 232bbe44..89945e1c 100644 5448--- a/mindspore/lite/src/litert/kernel/cpu/fp16/gru_fp16.cc 5449+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/gru_fp16.cc 5450@@ -100,10 +100,10 @@ int GruFp16CPUKernel::InitInputWeightBias() { 5451 } 5452 if (weight_g->data_type() == kNumberTypeFloat32) { 5453 PackLstmWeightFp32ToFp16(weight_g_ptr_, reinterpret_cast<float *>(weight_g->data()), weight_batch_, 5454- gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_); 5455+ gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_, nullptr); 5456 } else if (weight_g->data_type() == kNumberTypeFloat16) { 5457 PackLstmWeightFp16(weight_g_ptr_, reinterpret_cast<float16_t *>(weight_g->data()), weight_batch_, 5458- gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_); 5459+ gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_, nullptr); 5460 } else { 5461 MS_LOG(ERROR) << "Unsupported data type of weight_g tensor for gru."; 5462 return RET_ERROR; 5463@@ -120,10 +120,10 @@ int GruFp16CPUKernel::InitInputWeightBias() { 5464 memset(input_bias_, 0, weight_batch_ * gru_param_->input_col_align_ * sizeof(float16_t)); 5465 if (bias->data_type() == kNumberTypeFloat32) { 5466 PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast<float *>(bias->data()), weight_batch_, 5467- gru_param_->hidden_size_, gru_param_->input_col_align_, gru_param_->bidirectional_); 5468+ gru_param_->hidden_size_, gru_param_->input_col_align_, gru_param_->bidirectional_, 
nullptr); 5469 } else if (bias->data_type() == kNumberTypeFloat16) { 5470 PackLstmBiasFp16(input_bias_, reinterpret_cast<float16_t *>(bias->data()), weight_batch_, gru_param_->hidden_size_, 5471- gru_param_->input_col_align_, gru_param_->bidirectional_); 5472+ gru_param_->input_col_align_, gru_param_->bidirectional_, nullptr); 5473 } else { 5474 MS_LOG(ERROR) << "Unsupported data type of bias tensor for gru."; 5475 return RET_ERROR; 5476@@ -148,10 +148,10 @@ int GruFp16CPUKernel::InitStateWeightBias() { 5477 if (!is_vec_) { 5478 if (weight_r->data_type() == kNumberTypeFloat32) { 5479 PackLstmWeightFp32ToFp16(weight_r_ptr_, reinterpret_cast<float *>(weight_r->data()), weight_batch_, 5480- gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_); 5481+ gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_, nullptr); 5482 } else if (weight_r->data_type() == kNumberTypeFloat16) { 5483 PackLstmWeightFp16(weight_r_ptr_, reinterpret_cast<float16_t *>(weight_r->data()), weight_batch_, 5484- gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_); 5485+ gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_, nullptr); 5486 } else { 5487 MS_LOG(ERROR) << "Unsupported data type of weight_r tensor for gru."; 5488 return RET_ERROR; 5489@@ -179,11 +179,11 @@ int GruFp16CPUKernel::InitStateWeightBias() { 5490 if (bias->data_type() == kNumberTypeFloat32) { 5491 auto state_bias_data = reinterpret_cast<float *>(bias->data()) + gate_num * gru_param_->hidden_size_; 5492 PackLstmBiasFp32ToFp16(state_bias_, state_bias_data, weight_batch_, gru_param_->hidden_size_, 5493- gru_param_->state_col_align_, gru_param_->bidirectional_); 5494+ gru_param_->state_col_align_, gru_param_->bidirectional_, nullptr); 5495 } else if (bias->data_type() == kNumberTypeFloat16) { 5496 auto state_bias_data = reinterpret_cast<float16_t *>(bias->data()) + gate_num * gru_param_->hidden_size_; 5497 
PackLstmBiasFp16(state_bias_, state_bias_data, weight_batch_, gru_param_->hidden_size_, 5498- gru_param_->state_col_align_, gru_param_->bidirectional_); 5499+ gru_param_->state_col_align_, gru_param_->bidirectional_, nullptr); 5500 } else { 5501 MS_LOG(ERROR) << "Unsupported data type of bias tensor for gru."; 5502 return RET_ERROR; 5503diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16.cc 5504index b583358a..bd99b791 100644 5505--- a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16.cc 5506+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16.cc 5507@@ -1,5 +1,5 @@ 5508 /** 5509- * Copyright 2021 Huawei Technologies Co., Ltd 5510+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 5511 * 5512 * Licensed under the Apache License, Version 2.0 (the "License"); 5513 * you may not use this file except in compliance with the License. 5514@@ -16,13 +16,9 @@ 5515 5516 #include "src/litert/kernel/cpu/fp16/lstm_fp16.h" 5517 #include <vector> 5518-#include <cfloat> 5519-#include "schema/model_generated.h" 5520+#include "src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h" 5521+#include "src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h" 5522 #include "src/litert/kernel_registry.h" 5523-#include "include/errorcode.h" 5524-#include "nnacl/fp16/lstm_fp16.h" 5525-#include "nnacl/fp16/cast_fp16.h" 5526-#include "nnacl/errorcode.h" 5527 5528 using mindspore::kernel::KERNEL_ARCH; 5529 using mindspore::lite::KernelRegistrar; 5530@@ -31,389 +27,34 @@ using mindspore::lite::RET_OK; 5531 using mindspore::schema::PrimitiveType_LSTM; 5532 5533 namespace mindspore::kernel { 5534-void LstmFp16CPUKernel::FreeTmpBuffer() { 5535- if (weight_i_ptr_ != nullptr) { 5536- free(weight_i_ptr_); 5537- weight_i_ptr_ = nullptr; 5538- } 5539- if (input_bias_ != nullptr) { 5540- free(input_bias_); 5541- input_bias_ = nullptr; 5542- } 5543- if (weight_h_ptr_ != nullptr) { 5544- free(weight_h_ptr_); 5545- weight_h_ptr_ = nullptr; 5546- 
} 5547- if (state_bias_ != nullptr) { 5548- free(state_bias_); 5549- state_bias_ = nullptr; 5550- } 5551- if (weight_project_ptr_ != nullptr) { 5552- free(weight_project_ptr_); 5553- weight_project_ptr_ = nullptr; 5554- } 5555- if (project_bias_ != nullptr) { 5556- free(project_bias_); 5557- project_bias_ = nullptr; 5558- } 5559-} 5560- 5561-void LstmFp16CPUKernel::FreeRunBuffer() { 5562- ms_context_->allocator->Free(buffer_[packed_input_index]); 5563- ms_context_->allocator->Free(buffer_[input_gate_index]); 5564- if (!is_vec_) { 5565- ms_context_->allocator->Free(buffer_[packed_state_index]); 5566- } 5567- ms_context_->allocator->Free(buffer_[state_gate_index]); 5568- if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { 5569- ms_context_->allocator->Free(buffer_[cell_state_index]); 5570- } 5571- if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { 5572- ms_context_->allocator->Free(buffer_[hidden_state_index]); 5573- } 5574-} 5575- 5576-int LstmFp16CPUKernel::InitParam() { 5577- auto input = in_tensors_.front(); 5578- std::vector<int> in_shape = input->shape(); 5579- lstm_param_->seq_len_ = in_shape.at(0); 5580- lstm_param_->batch_ = in_shape.at(1); 5581- lstm_param_->input_size_ = in_shape.at(kNHWC_W); 5582- 5583- auto weight_i = in_tensors_.at(1); 5584- std::vector<int> w_shape = weight_i->shape(); 5585- NNACL_CHECK_ZERO_RETURN_ERR(gate_num); 5586- lstm_param_->hidden_size_ = w_shape.at(1) / gate_num; 5587- 5588- auto weight_h = in_tensors_.at(C2NUM); 5589- auto h_shape = weight_h->shape(); 5590- lstm_param_->project_size_ = h_shape.back(); 5591- 5592- const int twice = 2; 5593- lstm_param_->output_step_ = lstm_param_->bidirectional_ ? twice * lstm_param_->batch_ * lstm_param_->hidden_size_ 5594- : lstm_param_->batch_ * lstm_param_->hidden_size_; 5595- weight_batch_ = lstm_param_->bidirectional_ ? 
twice * gate_num : gate_num; 5596- lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, C16NUM); 5597- lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C8NUM); 5598- 5599- is_vec_ = lstm_param_->batch_ == 1; 5600- lstm_param_->state_row_align_ = is_vec_ ? lstm_param_->batch_ : UP_ROUND(lstm_param_->batch_, C16NUM); 5601- lstm_param_->state_col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, C8NUM); 5602- return RET_OK; 5603-} 5604- 5605-int LstmFp16CPUKernel::InitInputWeightBias() { 5606- // malloc and init input * weight right matrix buffer 5607- // input -- row: seq_len * batch; col: input_size 5608- // weight -- row: hidden_size; col: input_size, need transpose 5609- // result -- row: seq_len * batch; col: hidden_size 5610- auto weight_i = in_tensors_.at(1); 5611- auto weight_i_data = weight_i->data(); 5612- CHECK_NULL_RETURN(weight_i_data); 5613- weight_i_ptr_ = reinterpret_cast<float16_t *>( 5614- malloc(weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float16_t))); 5615- if (weight_i_ptr_ == nullptr) { 5616- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_i_ptr_ error."; 5617- return RET_ERROR; 5618- } 5619- if (weight_i->data_type() == kNumberTypeFloat32) { 5620- PackLstmWeightFp32ToFp16(weight_i_ptr_, reinterpret_cast<float *>(weight_i_data), weight_batch_, 5621- lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_); 5622- } else if (weight_i->data_type() == kNumberTypeFloat16) { 5623- PackLstmWeightFp16(weight_i_ptr_, reinterpret_cast<float16_t *>(weight_i_data), weight_batch_, 5624- lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_); 5625+namespace { 5626+constexpr size_t kMindirInputTensorNum = 4; 5627+} // namespace 5628+ 5629+LiteKernel *LstmFp16KernelCreator(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, 5630+ OpParameter 
*parameter, const lite::InnerContext *ctx, const kernel::KernelKey &desc) { 5631+ if (parameter == nullptr) { 5632+ MS_LOG(ERROR) << "parameter is nullptr."; 5633+ return nullptr; 5634+ } 5635+ if (desc.data_type == kTypeUnknown) { 5636+ MS_LOG(WARNING) << "desc data_type is unknown."; 5637+ } 5638+ LiteKernel *kernel{nullptr}; 5639+ if (inputs.size() == kMindirInputTensorNum) { 5640+ kernel = new (std::nothrow) 5641+ LstmMindirFp16CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx)); 5642 } else { 5643- MS_LOG(ERROR) << "Unsupported data type of weight_i tensor for lstm."; 5644- return RET_ERROR; 5645- } 5646- 5647- // input bias 5648- auto bias = in_tensors_.at(FOURTH_INPUT); 5649- auto bias_data = bias->data(); 5650- CHECK_NULL_RETURN(bias_data); 5651- input_bias_ = 5652- reinterpret_cast<float16_t *>(malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float16_t))); 5653- if (input_bias_ == nullptr) { 5654- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input_bias_ error."; 5655- return RET_ERROR; 5656- } 5657- memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float16_t)); 5658- if (bias->data_type() == kNumberTypeFloat32) { 5659- PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast<float *>(bias_data), weight_batch_, lstm_param_->hidden_size_, 5660- lstm_param_->input_col_align_, lstm_param_->bidirectional_); 5661- } else if (bias->data_type() == kNumberTypeFloat16) { 5662- PackLstmBiasFp16(input_bias_, reinterpret_cast<float16_t *>(bias_data), weight_batch_, lstm_param_->hidden_size_, 5663- lstm_param_->input_col_align_, lstm_param_->bidirectional_); 5664- } else { 5665- MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; 5666- return RET_ERROR; 5667- } 5668- return RET_OK; 5669-} 5670- 5671-int LstmFp16CPUKernel::InitStateWeightBias() { 5672- // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. 
5673- // state -- row: batch; col: hidden_size 5674- // weight -- row: hidden_size; col: hidden_size, need transpose 5675- // result -- row: batch; col: hidden_size 5676- auto weight_h = in_tensors_.at(THIRD_INPUT); 5677- auto weight_h_data = weight_h->data(); 5678- CHECK_NULL_RETURN(weight_h_data); 5679- weight_h_ptr_ = reinterpret_cast<float16_t *>( 5680- malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->project_size_ * sizeof(float16_t))); 5681- if (weight_h_ptr_ == nullptr) { 5682- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_h_ptr_ error."; 5683- return RET_ERROR; 5684- } 5685- 5686- if (!is_vec_) { 5687- if (weight_h->data_type() == kNumberTypeFloat32) { 5688- PackLstmWeightFp32ToFp16(weight_h_ptr_, reinterpret_cast<float *>(weight_h_data), weight_batch_, 5689- lstm_param_->project_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_); 5690- } else if (weight_h->data_type() == kNumberTypeFloat16) { 5691- PackLstmWeightFp16(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h_data), weight_batch_, 5692- lstm_param_->project_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_); 5693- } else { 5694- MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm."; 5695- return RET_ERROR; 5696- } 5697- } else { 5698- if (weight_h->data_type() == kNumberTypeFloat32) { 5699- Float32ToFloat16(reinterpret_cast<float *>(weight_h_data), weight_h_ptr_, weight_h->ElementsNum()); 5700- } else if (weight_h->data_type() == kNumberTypeFloat16) { 5701- memcpy(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h_data), weight_h->Size()); 5702- } else { 5703- MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm."; 5704- return RET_ERROR; 5705- } 5706- } 5707- 5708- // state bias 5709- auto bias = in_tensors_.at(FOURTH_INPUT); 5710- auto bias_data = bias->data(); 5711- CHECK_NULL_RETURN(bias_data); 5712- state_bias_ = 5713- reinterpret_cast<float16_t *>(malloc(weight_batch_ * lstm_param_->state_col_align_ * 
sizeof(float16_t))); 5714- if (state_bias_ == nullptr) { 5715- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_bias_ error."; 5716- return RET_ERROR; 5717- } 5718- memset(state_bias_, 0, weight_batch_ * lstm_param_->state_col_align_ * sizeof(float16_t)); 5719- if (bias->data_type() == kNumberTypeFloat32) { 5720- auto state_bias_data = reinterpret_cast<float *>(bias_data) + gate_num * lstm_param_->hidden_size_; 5721- PackLstmBiasFp32ToFp16(state_bias_, state_bias_data, weight_batch_, lstm_param_->hidden_size_, 5722- lstm_param_->state_col_align_, lstm_param_->bidirectional_); 5723- } else if (bias->data_type() == kNumberTypeFloat16) { 5724- auto state_bias_data = reinterpret_cast<float16_t *>(bias_data) + gate_num * lstm_param_->hidden_size_; 5725- PackLstmBiasFp16(state_bias_, state_bias_data, weight_batch_, lstm_param_->hidden_size_, 5726- lstm_param_->state_col_align_, lstm_param_->bidirectional_); 5727- } else { 5728- MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; 5729- return RET_ERROR; 5730- } 5731- return RET_OK; 5732-} 5733- 5734-int LstmFp16CPUKernel::InitProjectWeight() { 5735- if (in_tensors_.size() < C7NUM) { 5736- return RET_OK; 5737- } 5738- auto weight_pro = in_tensors_.at(SEVENTH_INPUT); 5739- auto shape = weight_pro->shape(); 5740- if (shape.size() != C3NUM) { 5741- MS_LOG(ERROR) << "Project-weight's shape must be 3D."; 5742- return RET_ERROR; 5743- } 5744- auto weight_pro_data = weight_pro->data(); 5745- CHECK_NULL_RETURN(weight_pro_data); 5746- int batch = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 5747- if (shape[0] != batch) { 5748- MS_LOG(ERROR) << "Project-weight's shape[0] must be 1(bidirectional=false) or 2(bidirectional=true)."; 5749- return RET_ERROR; 5750+ kernel = new (std::nothrow) 5751+ LstmNonMindirFp16CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx)); 5752 } 5753- int pro_col_align = is_vec_ ? 
lstm_param_->project_size_ : UP_ROUND(lstm_param_->project_size_, C8NUM); 5754- weight_project_ptr_ = 5755- reinterpret_cast<float16_t *>(malloc(batch * lstm_param_->hidden_size_ * pro_col_align * sizeof(float16_t))); 5756- if (weight_project_ptr_ == nullptr) { 5757- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_project_ptr_ error."; 5758- return RET_ERROR; 5759- } 5760- 5761- if (!is_vec_) { 5762- if (weight_pro->data_type() == kNumberTypeFloat32) { 5763- PackLstmWeightFp32ToFp16(weight_project_ptr_, reinterpret_cast<float *>(weight_pro_data), batch, 5764- lstm_param_->hidden_size_, lstm_param_->project_size_, pro_col_align); 5765- } else if (weight_pro->data_type() == kNumberTypeFloat16) { 5766- PackLstmWeightFp16(weight_project_ptr_, reinterpret_cast<float16_t *>(weight_pro_data), batch, 5767- lstm_param_->hidden_size_, lstm_param_->project_size_, pro_col_align); 5768- } else { 5769- MS_LOG(ERROR) << "Unsupported data type of weight_project tensor for lstm."; 5770- return RET_ERROR; 5771- } 5772- } else { 5773- if (weight_pro->data_type() == kNumberTypeFloat32) { 5774- Float32ToFloat16(reinterpret_cast<float *>(weight_pro_data), weight_project_ptr_, weight_pro->ElementsNum()); 5775- } else if (weight_pro->data_type() == kNumberTypeFloat16) { 5776- memcpy(weight_project_ptr_, weight_pro_data, weight_pro->Size()); 5777- } else { 5778- MS_LOG(ERROR) << "Unsupported data type of weight_project tensor for lstm."; 5779- return RET_ERROR; 5780- } 5781- } 5782- size_t bias_size = UP_ROUND(lstm_param_->project_size_, C8NUM) * sizeof(float16_t); 5783- project_bias_ = reinterpret_cast<float16_t *>(malloc(bias_size)); 5784- if (project_bias_ == nullptr) { 5785- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc project_bias_ error."; 5786- return RET_ERROR; 5787- } 5788- (void)memset(project_bias_, 0, bias_size); 5789- return RET_OK; 5790-} 5791- 5792-int LstmFp16CPUKernel::Prepare() { 5793- CHECK_LESS_RETURN(in_tensors_.size(), C6NUM); 5794- for (size_t i = 0; i < 
in_tensors_.size(); i++) { 5795- CHECK_NULL_RETURN(in_tensors_.at(i)); 5796- } 5797- CHECK_LESS_RETURN(out_tensors_.size(), C3NUM); 5798- for (size_t i = 0; i < out_tensors_.size(); i++) { 5799- CHECK_NULL_RETURN(out_tensors_.at(i)); 5800- } 5801- CHECK_NULL_RETURN(lstm_param_); 5802- if (!InferShapeDone()) { 5803- return RET_OK; 5804- } 5805- return ReSize(); 5806-} 5807- 5808-int LstmFp16CPUKernel::ReSize() { 5809- auto ret = InitParam(); 5810- if (ret != RET_OK) { 5811- MS_LOG(ERROR) << "Lstm fp16 InitParam error."; 5812- return RET_ERROR; 5813- } 5814- 5815- FreeTmpBuffer(); 5816- ret = InitInputWeightBias(); 5817- if (ret != RET_OK) { 5818- MS_LOG(ERROR) << "Lstm fp16 InitInputWeightBias error."; 5819- FreeTmpBuffer(); 5820- return RET_ERROR; 5821- } 5822- 5823- ret = InitStateWeightBias(); 5824- if (ret != RET_OK) { 5825- MS_LOG(ERROR) << "Lstm fp16 InitStateWeightBias error."; 5826- FreeTmpBuffer(); 5827- return RET_ERROR; 5828- } 5829- 5830- ret = InitProjectWeight(); 5831- if (ret != RET_OK) { 5832- MS_LOG(ERROR) << "Lstm fp16 InitProjectWeight error."; 5833- FreeTmpBuffer(); 5834- return RET_ERROR; 5835- } 5836- return RET_OK; 5837-} 5838- 5839-int LstmFp16CPUKernel::MallocRunBuffer() { 5840- for (int i = 0; i < C7NUM; i++) { 5841- buffer_[i] = nullptr; 5842- } 5843- buffer_[packed_input_index] = reinterpret_cast<float16_t *>( 5844- ms_context_->allocator->Malloc(lstm_param_->input_row_align_ * lstm_param_->input_size_ * sizeof(float16_t))); 5845- if (buffer_[packed_input_index] == nullptr) { 5846- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input * weight left matirx error."; 5847- return RET_ERROR; 5848- } 5849- 5850- buffer_[input_gate_index] = reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc( 5851- gate_num * lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 5852- if (buffer_[input_gate_index] == nullptr) { 5853- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error."; 5854- 
return RET_ERROR; 5855- } 5856- 5857- if (!is_vec_) { 5858- buffer_[packed_state_index] = reinterpret_cast<float16_t *>( 5859- ms_context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->project_size_ * sizeof(float16_t))); 5860- if (buffer_[packed_state_index] == nullptr) { 5861- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error."; 5862- return RET_ERROR; 5863- } 5864- } 5865- 5866- buffer_[state_gate_index] = reinterpret_cast<float16_t *>( 5867- ms_context_->allocator->Malloc(gate_num * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 5868- if (buffer_[state_gate_index] == nullptr) { 5869- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state gate buffer_ error."; 5870- return RET_ERROR; 5871- } 5872- 5873- if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { 5874- int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t); 5875- buffer_[cell_state_index] = reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(buffer_size)); 5876- if (buffer_[cell_state_index] == nullptr) { 5877- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for cell error."; 5878- return RET_ERROR; 5879- } 5880- } 5881- if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { 5882- int buffer_size = lstm_param_->batch_ * lstm_param_->project_size_ * sizeof(float16_t); 5883- buffer_[hidden_state_index] = reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(buffer_size)); 5884- if (buffer_[hidden_state_index] == nullptr) { 5885- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for hidden error."; 5886- return RET_ERROR; 5887- } 5888- } 5889- if (!is_vec_ && in_tensors_.size() == C7NUM) { 5890- buffer_[project_input_index] = reinterpret_cast<float16_t *>( 5891- ms_context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 5892- if 
(buffer_[project_input_index] == nullptr) { 5893- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc project_buffer for hidden error."; 5894- return RET_ERROR; 5895- } 5896- } 5897- return RET_OK; 5898-} 5899- 5900-int LstmFp16CPUKernel::Run() { 5901- auto input = in_tensors_.at(0); 5902- auto input_ptr = reinterpret_cast<float16_t *>(input->data()); 5903- CHECK_NULL_RETURN(input_ptr); 5904- auto output = out_tensors_.at(0); 5905- auto output_ptr = reinterpret_cast<float16_t *>(output->data()); 5906- CHECK_NULL_RETURN(output_ptr); 5907- 5908- auto hidden_state = in_tensors_.at(FIFTH_INPUT); 5909- CHECK_NULL_RETURN(hidden_state->data()); 5910- auto cell_state = in_tensors_.at(SIXTH_INPUT); 5911- CHECK_NULL_RETURN(cell_state->data()); 5912- 5913- auto output_hidden_state = out_tensors_[1]; 5914- CHECK_NULL_RETURN(output_hidden_state->data()); 5915- memcpy(output_hidden_state->data(), hidden_state->data(), hidden_state->ElementsNum() * sizeof(float16_t)); 5916- auto output_cell_state = out_tensors_[THIRD_INPUT]; 5917- CHECK_NULL_RETURN(output_cell_state->data()); 5918- memcpy(output_cell_state->data(), cell_state->data(), cell_state->ElementsNum() * sizeof(float16_t)); 5919- 5920- auto ret = MallocRunBuffer(); 5921- if (ret != RET_OK) { 5922- MS_LOG(ERROR) << "LstmFp16CPUKernel MallocRunBuffer error."; 5923- FreeRunBuffer(); 5924- return RET_ERROR; 5925+ if (kernel == nullptr) { 5926+ MS_LOG(ERROR) << "kernel: " << parameter->name_ << "is nullptr."; 5927+ free(parameter); 5928+ return nullptr; 5929 } 5930- CHECK_NULL_RETURN(weight_i_ptr_); 5931- CHECK_NULL_RETURN(weight_h_ptr_); 5932- CHECK_NULL_RETURN(input_bias_); 5933- CHECK_NULL_RETURN(state_bias_); 5934- LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, weight_project_ptr_, 5935- project_bias_, reinterpret_cast<float16_t *>(output_hidden_state->data()), 5936- reinterpret_cast<float16_t *>(output_cell_state->data()), buffer_, lstm_param_); 5937- FreeRunBuffer(); 5938- return RET_OK; 
5939+ return kernel; 5940 } 5941 5942-REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LSTM, LiteKernelCreator<LstmFp16CPUKernel>) 5943+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LSTM, LstmFp16KernelCreator) 5944 } // namespace mindspore::kernel 5945diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc 5946new file mode 100644 5947index 00000000..767fdef3 5948--- /dev/null 5949+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc 5950@@ -0,0 +1,270 @@ 5951+/** 5952+ * Copyright 2023 Huawei Technologies Co., Ltd 5953+ * 5954+ * Licensed under the Apache License, Version 2.0 (the "License"); 5955+ * you may not use this file except in compliance with the License. 5956+ * You may obtain a copy of the License at 5957+ * 5958+ * http://www.apache.org/licenses/LICENSE-2.0 5959+ * 5960+ * Unless required by applicable law or agreed to in writing, software 5961+ * distributed under the License is distributed on an "AS IS" BASIS, 5962+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5963+ * See the License for the specific language governing permissions and 5964+ * limitations under the License. 
5965+ */ 5966+ 5967+#include "src/litert/kernel/cpu/fp16/lstm_fp16_base.h" 5968+#include <cfloat> 5969+#include "nnacl/fp16/lstm_fp16.h" 5970+ 5971+using mindspore::lite::RET_ERROR; 5972+using mindspore::lite::RET_OK; 5973+ 5974+namespace mindspore::kernel { 5975+namespace { 5976+constexpr int kGateNum = 4; 5977+constexpr int kTempInputBufferIndex = 0; 5978+constexpr int kTempInputGateBufferIndex = 1; 5979+constexpr int kTempStateBufferIndex = 2; 5980+constexpr int kTempStateGateBufferIndex = 3; 5981+constexpr int kTempCellStateBufferIndex = 4; 5982+constexpr int kTempHiddenStateBufferIndex = 5; 5983+constexpr int kTempProjectInputBufferIndex = 6; 5984+} // namespace 5985+ 5986+LstmFp16BaseCPUKernel::~LstmFp16BaseCPUKernel() { FreePackBuffer(); } 5987+ 5988+int LstmFp16BaseCPUKernel::Prepare() { 5989+ for (size_t i = 0; i < in_tensors_.size(); ++i) { 5990+ CHECK_NULL_RETURN(in_tensors_[i]); 5991+ } 5992+ CHECK_LESS_RETURN(out_tensors_.size(), C3NUM); 5993+ for (size_t i = 0; i < out_tensors_.size(); ++i) { 5994+ CHECK_NULL_RETURN(out_tensors_[i]); 5995+ } 5996+ CHECK_NULL_RETURN(lstm_param_); 5997+ if (!InferShapeDone()) { 5998+ return RET_OK; 5999+ } 6000+ return ReSize(); 6001+} 6002+ 6003+int LstmFp16BaseCPUKernel::ReSize() { 6004+ auto ret = InitParam(); 6005+ if (ret != RET_OK) { 6006+ MS_LOG(ERROR) << "LstmFp16 InitParam failed."; 6007+ return RET_ERROR; 6008+ } 6009+ if (running_pack_) { 6010+ return RET_OK; 6011+ } 6012+ return PackWeightAndBias(); 6013+} 6014+ 6015+int LstmFp16BaseCPUKernel::Run() { 6016+ auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_[FIRST_INPUT]->data()); 6017+ CHECK_NULL_RETURN(input_ptr); 6018+ auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_[FIRST_INPUT]->data()); 6019+ CHECK_NULL_RETURN(output_ptr); 6020+ 6021+ auto hidden_init = in_tensors_[hidden_init_index_]->data(); 6022+ CHECK_NULL_RETURN(hidden_init); 6023+ auto cell_init = in_tensors_[cell_init_index_]->data(); 6024+ CHECK_NULL_RETURN(cell_init); 6025+ 
6026+ auto output_hidden = out_tensors_[SECOND_INPUT]->data(); 6027+ CHECK_NULL_RETURN(output_hidden); 6028+ (void)memcpy(output_hidden, hidden_init, in_tensors_[hidden_init_index_]->ElementsNum() * sizeof(float16_t)); 6029+ auto output_cell = out_tensors_[THIRD_INPUT]->data(); 6030+ CHECK_NULL_RETURN(output_cell); 6031+ (void)memcpy(output_cell, cell_init, in_tensors_[cell_init_index_]->ElementsNum() * sizeof(float16_t)); 6032+ 6033+ if (running_pack_) { 6034+ auto ret = PackWeightAndBias(); 6035+ if (ret != lite::RET_OK) { 6036+ MS_LOG(ERROR) << "LstmFp16 PackWeightAndBias failed."; 6037+ return ret; 6038+ } 6039+ } 6040+ auto ret = MallocRunBuffer(); 6041+ if (ret != RET_OK) { 6042+ MS_LOG(ERROR) << "LstmFp16CPUKernel MallocRunBuffer error."; 6043+ FreeRunBuffer(); 6044+ if (running_pack_) { 6045+ FreePackBuffer(); 6046+ } 6047+ return RET_ERROR; 6048+ } 6049+ LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, weight_project_ptr_, 6050+ project_bias_, reinterpret_cast<float16_t *>(output_hidden), reinterpret_cast<float16_t *>(output_cell), 6051+ running_buffer_, lstm_param_); 6052+ FreeRunBuffer(); 6053+ if (running_pack_) { 6054+ FreePackBuffer(); 6055+ } 6056+ return RET_OK; 6057+} 6058+ 6059+int LstmFp16BaseCPUKernel::InitParam() { 6060+ auto in_shape = in_tensors_[FIRST_INPUT]->shape(); 6061+ MS_CHECK_TRUE_MSG(in_shape.size() == C3NUM, lite::RET_INPUT_TENSOR_ERROR, 6062+ "The dims of LSTM's first input must be 3."); 6063+ lstm_param_->seq_len_ = in_shape[0]; 6064+ lstm_param_->batch_ = in_shape[1]; 6065+ lstm_param_->input_size_ = in_shape.back(); 6066+ 6067+ auto h_init_shape = in_tensors_.at(hidden_init_index_)->shape(); 6068+ auto c_init_shape = in_tensors_.at(cell_init_index_)->shape(); 6069+ lstm_param_->hidden_size_ = c_init_shape.back(); 6070+ lstm_param_->output_size_ = h_init_shape.back(); 6071+ 6072+ lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 
C2NUM * lstm_param_->batch_ * lstm_param_->output_size_ 6073+ : lstm_param_->batch_ * lstm_param_->output_size_; 6074+ weight_segment_num_ = lstm_param_->bidirectional_ ? C2NUM * kGateNum : kGateNum; 6075+#ifdef ENABLE_ARM64 6076+ lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, C1NUM); 6077+ lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C4NUM); 6078+ 6079+ lstm_param_->state_row_align_ = UP_ROUND(lstm_param_->batch_, C1NUM); 6080+ lstm_param_->state_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C4NUM); 6081+ lstm_param_->proj_col_align_ = UP_ROUND(lstm_param_->output_size_, C4NUM); 6082+ weight_need_pack_ = true; 6083+#else 6084+ lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, C16NUM); 6085+ lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C8NUM); 6086+ 6087+ lstm_param_->state_row_align_ = 6088+ lstm_param_->batch_ == 1 ? lstm_param_->batch_ : UP_ROUND(lstm_param_->batch_, C16NUM); 6089+ lstm_param_->state_col_align_ = 6090+ lstm_param_->batch_ == 1 ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, C8NUM); 6091+ lstm_param_->proj_col_align_ = 6092+ lstm_param_->batch_ == 1 ? 
lstm_param_->output_size_ : UP_ROUND(lstm_param_->output_size_, C8NUM); 6093+ weight_need_pack_ = lstm_param_->batch_ != 1; 6094+#endif 6095+ return RET_OK; 6096+} 6097+ 6098+int LstmFp16BaseCPUKernel::PackWeightAndBias() { 6099+ FreePackBuffer(); 6100+ auto ret = InitInputWeightBias(); 6101+ if (ret != RET_OK) { 6102+ MS_LOG(ERROR) << "LstmFp16 InitInputWeightBias failed."; 6103+ FreePackBuffer(); 6104+ return RET_ERROR; 6105+ } 6106+ 6107+ ret = InitStateWeightBias(); 6108+ if (ret != RET_OK) { 6109+ MS_LOG(ERROR) << "LstmFp16 InitStateWeightBias failed."; 6110+ FreePackBuffer(); 6111+ return RET_ERROR; 6112+ } 6113+ 6114+ ret = InitProjectWeight(); 6115+ if (ret != RET_OK) { 6116+ MS_LOG(ERROR) << "LstmFp16 InitProjectWeight failed."; 6117+ FreePackBuffer(); 6118+ return RET_ERROR; 6119+ } 6120+ return RET_OK; 6121+} 6122+ 6123+void LstmFp16BaseCPUKernel::FreePackBuffer() { 6124+ for (auto buffer : pack_buffer_) { 6125+ if (buffer) { 6126+ free(buffer); 6127+ } 6128+ } 6129+ pack_buffer_.clear(); 6130+} 6131+ 6132+int LstmFp16BaseCPUKernel::MallocRunBuffer() { 6133+ for (int i = 0; i < C7NUM; i++) { 6134+ running_buffer_[i] = nullptr; 6135+ } 6136+ bool need_pack_input = true; 6137+#ifdef ENABLE_ARM64 6138+ need_pack_input = lstm_param_->seq_len_ * lstm_param_->batch_ >= C4NUM; 6139+#endif 6140+ if (need_pack_input) { 6141+ running_buffer_[kTempInputBufferIndex] = reinterpret_cast<float16_t *>( 6142+ ms_context_->allocator->Malloc(lstm_param_->input_row_align_ * lstm_param_->input_size_ * sizeof(float16_t))); 6143+ if (running_buffer_[kTempInputBufferIndex] == nullptr) { 6144+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input * weight left matirx error."; 6145+ return RET_ERROR; 6146+ } 6147+ } 6148+ 6149+ running_buffer_[kTempInputGateBufferIndex] = reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc( 6150+ kGateNum * lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 6151+ if 
(running_buffer_[kTempInputGateBufferIndex] == nullptr) { 6152+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error."; 6153+ return RET_ERROR; 6154+ } 6155+ 6156+ need_pack_input = lstm_param_->batch_ != 1; 6157+#ifdef ENABLE_ARM64 6158+ need_pack_input = lstm_param_->batch_ >= C4NUM; 6159+#endif 6160+ if (need_pack_input) { 6161+ running_buffer_[kTempStateBufferIndex] = reinterpret_cast<float16_t *>( 6162+ ms_context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->output_size_ * sizeof(float16_t))); 6163+ if (running_buffer_[kTempStateBufferIndex] == nullptr) { 6164+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error."; 6165+ return RET_ERROR; 6166+ } 6167+ } 6168+ 6169+ running_buffer_[kTempStateGateBufferIndex] = reinterpret_cast<float16_t *>( 6170+ ms_context_->allocator->Malloc(kGateNum * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 6171+ if (running_buffer_[kTempStateGateBufferIndex] == nullptr) { 6172+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state gate buffer_ error."; 6173+ return RET_ERROR; 6174+ } 6175+ 6176+ if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { 6177+ int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t); 6178+ running_buffer_[kTempCellStateBufferIndex] = 6179+ reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(buffer_size)); 6180+ if (running_buffer_[kTempCellStateBufferIndex] == nullptr) { 6181+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for cell error."; 6182+ return RET_ERROR; 6183+ } 6184+ } 6185+ if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { 6186+ int buffer_size = lstm_param_->batch_ * lstm_param_->output_size_ * sizeof(float16_t); 6187+ running_buffer_[kTempHiddenStateBufferIndex] = 6188+ reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(buffer_size)); 6189+ if 
(running_buffer_[kTempHiddenStateBufferIndex] == nullptr) { 6190+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for hidden error."; 6191+ return RET_ERROR; 6192+ } 6193+ } 6194+ 6195+ if (need_pack_input && in_tensors_.size() == C7NUM) { 6196+ running_buffer_[kTempProjectInputBufferIndex] = reinterpret_cast<float16_t *>( 6197+ ms_context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 6198+ if (running_buffer_[kTempProjectInputBufferIndex] == nullptr) { 6199+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc project_buffer for hidden error."; 6200+ return RET_ERROR; 6201+ } 6202+ } 6203+ return RET_OK; 6204+} 6205+ 6206+void LstmFp16BaseCPUKernel::FreeRunBuffer() { 6207+ ms_context_->allocator->Free(running_buffer_[kTempInputBufferIndex]); 6208+ ms_context_->allocator->Free(running_buffer_[kTempInputGateBufferIndex]); 6209+ if (lstm_param_->batch_ != 1) { 6210+ ms_context_->allocator->Free(running_buffer_[kTempStateBufferIndex]); 6211+ } 6212+ ms_context_->allocator->Free(running_buffer_[kTempStateGateBufferIndex]); 6213+ if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { 6214+ ms_context_->allocator->Free(running_buffer_[kTempCellStateBufferIndex]); 6215+ } 6216+ if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { 6217+ ms_context_->allocator->Free(running_buffer_[kTempHiddenStateBufferIndex]); 6218+ } 6219+} 6220+} // namespace mindspore::kernel 6221diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h 6222new file mode 100644 6223index 00000000..0bcb9e94 6224--- /dev/null 6225+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h 6226@@ -0,0 +1,68 @@ 6227+/** 6228+ * Copyright 2023 Huawei Technologies Co., Ltd 6229+ * 6230+ * Licensed under the Apache License, Version 2.0 (the "License"); 6231+ * you may not use this file 
except in compliance with the License. 6232+ * You may obtain a copy of the License at 6233+ * 6234+ * http://www.apache.org/licenses/LICENSE-2.0 6235+ * 6236+ * Unless required by applicable law or agreed to in writing, software 6237+ * distributed under the License is distributed on an "AS IS" BASIS, 6238+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6239+ * See the License for the specific language governing permissions and 6240+ * limitations under the License. 6241+ */ 6242+ 6243+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_FP16_BASE_H_ 6244+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_FP16_BASE_H_ 6245+ 6246+#include <vector> 6247+#include "src/litert/lite_kernel.h" 6248+#include "nnacl/lstm_parameter.h" 6249+ 6250+namespace mindspore::kernel { 6251+class LstmFp16BaseCPUKernel : public LiteKernel { 6252+ public: 6253+ LstmFp16BaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 6254+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 6255+ : LiteKernel(parameter, inputs, outputs, ctx) { 6256+ lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_); 6257+ } 6258+ 6259+ ~LstmFp16BaseCPUKernel() override; 6260+ 6261+ int Prepare() override; 6262+ int ReSize() override; 6263+ int Run() override; 6264+ 6265+ protected: 6266+ virtual int InitInputWeightBias() = 0; 6267+ virtual int InitStateWeightBias() = 0; 6268+ virtual int InitProjectWeight() = 0; 6269+ 6270+ bool running_pack_{false}; 6271+ bool weight_need_pack_{false}; 6272+ int hidden_init_index_{0}; 6273+ int cell_init_index_{0}; 6274+ int weight_segment_num_{0}; 6275+ float16_t *weight_i_ptr_{nullptr}; 6276+ float16_t *weight_h_ptr_{nullptr}; 6277+ float16_t *weight_project_ptr_{nullptr}; 6278+ float16_t *input_bias_{nullptr}; 6279+ float16_t *state_bias_{nullptr}; 6280+ float16_t *project_bias_{nullptr}; 6281+ LstmParameter *lstm_param_{nullptr}; 6282+ float16_t *running_buffer_[C7NUM] = 
{nullptr}; 6283+ std::vector<void *> pack_buffer_; 6284+ 6285+ private: 6286+ int PackWeightAndBias(); 6287+ int InitParam(); 6288+ void FreePackBuffer(); 6289+ void FreeRunBuffer(); 6290+ int MallocRunBuffer(); 6291+}; 6292+} // namespace mindspore::kernel 6293+ 6294+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_FP16_BASE_H_ 6295diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc 6296new file mode 100644 6297index 00000000..cf4071eb 6298--- /dev/null 6299+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc 6300@@ -0,0 +1,35 @@ 6301+/** 6302+ * Copyright 2023 Huawei Technologies Co., Ltd 6303+ * 6304+ * Licensed under the Apache License, Version 2.0 (the "License"); 6305+ * you may not use this file except in compliance with the License. 6306+ * You may obtain a copy of the License at 6307+ * 6308+ * http://www.apache.org/licenses/LICENSE-2.0 6309+ * 6310+ * Unless required by applicable law or agreed to in writing, software 6311+ * distributed under the License is distributed on an "AS IS" BASIS, 6312+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6313+ * See the License for the specific language governing permissions and 6314+ * limitations under the License. 
6315+ */ 6316+ 6317+#include "src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h" 6318+ 6319+namespace mindspore::kernel { 6320+namespace { 6321+constexpr size_t kMindirInputTensorNum = 4; 6322+} // namespace 6323+ 6324+int LstmMindirFp16CPUKernel::Prepare() { 6325+ CHECK_NOT_EQUAL_RETURN(in_tensors_.size(), kMindirInputTensorNum); 6326+ running_pack_ = trainable_ || !in_tensors_[FOURTH_INPUT]->IsConst(); 6327+ return LstmFp16BaseCPUKernel::Prepare(); 6328+} 6329+ 6330+int LstmMindirFp16CPUKernel::InitInputWeightBias() { return lite::RET_NOT_SUPPORT; } 6331+ 6332+int LstmMindirFp16CPUKernel::InitStateWeightBias() { return lite::RET_NOT_SUPPORT; } 6333+ 6334+int LstmMindirFp16CPUKernel::InitProjectWeight() { return lite::RET_NOT_SUPPORT; } 6335+} // namespace mindspore::kernel 6336diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h 6337new file mode 100644 6338index 00000000..bd8500d0 6339--- /dev/null 6340+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h 6341@@ -0,0 +1,56 @@ 6342+/** 6343+ * Copyright 2023 Huawei Technologies Co., Ltd 6344+ * 6345+ * Licensed under the Apache License, Version 2.0 (the "License"); 6346+ * you may not use this file except in compliance with the License. 6347+ * You may obtain a copy of the License at 6348+ * 6349+ * http://www.apache.org/licenses/LICENSE-2.0 6350+ * 6351+ * Unless required by applicable law or agreed to in writing, software 6352+ * distributed under the License is distributed on an "AS IS" BASIS, 6353+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6354+ * See the License for the specific language governing permissions and 6355+ * limitations under the License. 
6356+ */ 6357+ 6358+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_MINDIR_FP16_H_ 6359+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_MINDIR_FP16_H_ 6360+ 6361+#include <vector> 6362+#include "src/litert/kernel/cpu/fp16/lstm_fp16_base.h" 6363+ 6364+namespace mindspore::kernel { 6365+/* 6366+ * 1. LSTM without project, output_size = hidden_size 6367+ * h_init: second input, shape is [bidirectional, batch_size, hidden_size] 6368+ * c_init: third input, shape is [bidirectional, batch_size, hidden_size] 6369+ * weight_bias: fourth input, weight_ih + weight_hh + bias, the gate order is IFGO 6370+ * 6371+ * 2. LSTM with project, output_size = project_size 6372+ * don't support 6373+ * h_init: second input, shape is [bidirectional, batch_size, hidden_size] 6374+ * c_init: third input, shape is [bidirectional, batch_size, hidden_size] 6375+ * weight_bias: fourth input, weight_ih + weight_hh + proj + bias, the gate order is IFGO 6376+ */ 6377+class LstmMindirFp16CPUKernel : public LstmFp16BaseCPUKernel { 6378+ public: 6379+ LstmMindirFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 6380+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 6381+ : LstmFp16BaseCPUKernel(parameter, inputs, outputs, ctx) { 6382+ hidden_init_index_ = SECOND_INPUT; 6383+ cell_init_index_ = THIRD_INPUT; 6384+ } 6385+ 6386+ ~LstmMindirFp16CPUKernel() override = default; 6387+ 6388+ int Prepare() override; 6389+ 6390+ protected: 6391+ int InitInputWeightBias() override; 6392+ int InitStateWeightBias() override; 6393+ int InitProjectWeight() override; 6394+}; 6395+} // namespace mindspore::kernel 6396+ 6397+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_MINDIR_FP16_H_ 6398diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc 6399new file mode 100644 6400index 00000000..473fe9b0 6401--- /dev/null 6402+++ 
b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc 6403@@ -0,0 +1,194 @@ 6404+/** 6405+ * Copyright 2023 Huawei Technologies Co., Ltd 6406+ * 6407+ * Licensed under the Apache License, Version 2.0 (the "License"); 6408+ * you may not use this file except in compliance with the License. 6409+ * You may obtain a copy of the License at 6410+ * 6411+ * http://www.apache.org/licenses/LICENSE-2.0 6412+ * 6413+ * Unless required by applicable law or agreed to in writing, software 6414+ * distributed under the License is distributed on an "AS IS" BASIS, 6415+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6416+ * See the License for the specific language governing permissions and 6417+ * limitations under the License. 6418+ */ 6419+ 6420+#include "src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h" 6421+#include "nnacl/fp16/lstm_fp16.h" 6422+#include "nnacl/fp16/cast_fp16.h" 6423+ 6424+using mindspore::lite::RET_ERROR; 6425+using mindspore::lite::RET_OK; 6426+ 6427+namespace mindspore::kernel { 6428+namespace { 6429+constexpr int kGateNum = 4; 6430+constexpr size_t kInputTensorNumMin = 6; 6431+} // namespace 6432+ 6433+int LstmNonMindirFp16CPUKernel::Prepare() { 6434+ CHECK_LESS_RETURN(in_tensors_.size(), kInputTensorNumMin); 6435+ running_pack_ = train_mode_; 6436+ for (size_t i = 1; i <= FOURTH_INPUT; ++i) { 6437+ running_pack_ = running_pack_ || !in_tensors_[i]->IsConst(); 6438+ } 6439+ return LstmFp16BaseCPUKernel::Prepare(); 6440+} 6441+ 6442+int LstmNonMindirFp16CPUKernel::InitInputWeightBias() { 6443+ // malloc and init input * weight right matrix buffer 6444+ // input -- row: seq_len * batch; col: input_size 6445+ // weight -- row: hidden_size; col: input_size, need transpose 6446+ // result -- row: seq_len * batch; col: hidden_size 6447+ auto weight_i = in_tensors_.at(1); 6448+ auto weight_i_data = weight_i->data(); 6449+ CHECK_NULL_RETURN(weight_i_data); 6450+ weight_i_ptr_ = reinterpret_cast<float16_t *>( 6451+ 
malloc(weight_segment_num_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float16_t))); 6452+ MS_CHECK_TRUE_MSG(weight_i_ptr_ != nullptr, lite::RET_NULL_PTR, 6453+ "LstmNonMindirCPUKernel malloc weight_i_ptr_ failed."); 6454+ pack_buffer_.push_back(weight_i_ptr_); 6455+ if (weight_i->data_type() == kNumberTypeFloat32) { 6456+ PackLstmWeightFp32ToFp16(weight_i_ptr_, reinterpret_cast<float *>(weight_i_data), weight_segment_num_, 6457+ lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, 6458+ nullptr); 6459+ } else if (weight_i->data_type() == kNumberTypeFloat16) { 6460+ PackLstmWeightFp16(weight_i_ptr_, reinterpret_cast<float16_t *>(weight_i_data), weight_segment_num_, 6461+ lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, nullptr); 6462+ } else { 6463+ MS_LOG(ERROR) << "Unsupported data type of weight_i tensor for lstm."; 6464+ return RET_ERROR; 6465+ } 6466+ 6467+ // input bias 6468+ auto bias = in_tensors_.at(FOURTH_INPUT); 6469+ auto bias_data = bias->data(); 6470+ CHECK_NULL_RETURN(bias_data); 6471+ input_bias_ = 6472+ reinterpret_cast<float16_t *>(malloc(weight_segment_num_ * lstm_param_->input_col_align_ * sizeof(float16_t))); 6473+ MS_CHECK_TRUE_MSG(input_bias_ != nullptr, lite::RET_NULL_PTR, "LstmNonMindirCPUKernel malloc input_bias_ failed."); 6474+ pack_buffer_.push_back(input_bias_); 6475+ (void)memset(input_bias_, 0, weight_segment_num_ * lstm_param_->input_col_align_ * sizeof(float16_t)); 6476+ if (bias->data_type() == kNumberTypeFloat32) { 6477+ PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast<float *>(bias_data), weight_segment_num_, 6478+ lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_, 6479+ nullptr); 6480+ } else if (bias->data_type() == kNumberTypeFloat16) { 6481+ PackLstmBiasFp16(input_bias_, reinterpret_cast<float16_t *>(bias_data), weight_segment_num_, 6482+ lstm_param_->hidden_size_, lstm_param_->input_col_align_, 
lstm_param_->bidirectional_, nullptr); 6483+ } else { 6484+ MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; 6485+ return RET_ERROR; 6486+ } 6487+ return RET_OK; 6488+} 6489+ 6490+int LstmNonMindirFp16CPUKernel::InitStateWeightBias() { 6491+ // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. 6492+ // state -- row: batch; col: hidden_size 6493+ // weight -- row: hidden_size; col: hidden_size, need transpose 6494+ // result -- row: batch; col: hidden_size 6495+ auto weight_h = in_tensors_.at(THIRD_INPUT); 6496+ auto weight_h_data = weight_h->data(); 6497+ CHECK_NULL_RETURN(weight_h_data); 6498+ weight_h_ptr_ = reinterpret_cast<float16_t *>( 6499+ malloc(weight_segment_num_ * lstm_param_->state_col_align_ * lstm_param_->output_size_ * sizeof(float16_t))); 6500+ MS_CHECK_TRUE_MSG(weight_h_ptr_ != nullptr, lite::RET_NULL_PTR, 6501+ "LstmNonMindirCPUKernel malloc weight_h_ptr_ failed."); 6502+ 6503+ if (weight_need_pack_) { 6504+ if (weight_h->data_type() == kNumberTypeFloat32) { 6505+ PackLstmWeightFp32ToFp16(weight_h_ptr_, reinterpret_cast<float *>(weight_h_data), weight_segment_num_, 6506+ lstm_param_->output_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_, 6507+ nullptr); 6508+ } else if (weight_h->data_type() == kNumberTypeFloat16) { 6509+ PackLstmWeightFp16(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h_data), weight_segment_num_, 6510+ lstm_param_->output_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_, nullptr); 6511+ } else { 6512+ MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm."; 6513+ return RET_ERROR; 6514+ } 6515+ } else { 6516+ if (weight_h->data_type() == kNumberTypeFloat32) { 6517+ Float32ToFloat16(reinterpret_cast<float *>(weight_h_data), weight_h_ptr_, weight_h->ElementsNum()); 6518+ } else if (weight_h->data_type() == kNumberTypeFloat16) { 6519+ (void)memcpy(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h_data), 
weight_h->Size()); 6520+ } else { 6521+ MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm."; 6522+ return RET_ERROR; 6523+ } 6524+ } 6525+ 6526+ // state bias 6527+ auto bias = in_tensors_[FOURTH_INPUT]; 6528+ auto bias_data = bias->data(); 6529+ CHECK_NULL_RETURN(bias_data); 6530+ state_bias_ = 6531+ reinterpret_cast<float16_t *>(malloc(weight_segment_num_ * lstm_param_->state_col_align_ * sizeof(float16_t))); 6532+ MS_CHECK_TRUE_MSG(state_bias_ != nullptr, lite::RET_NULL_PTR, "LstmNonMindirCPUKernel malloc state_bias_ failed."); 6533+ (void)memset(state_bias_, 0, weight_segment_num_ * lstm_param_->state_col_align_ * sizeof(float16_t)); 6534+ if (bias->data_type() == kNumberTypeFloat32) { 6535+ auto state_bias_data = reinterpret_cast<float *>(bias_data) + kGateNum * lstm_param_->hidden_size_; 6536+ PackLstmBiasFp32ToFp16(state_bias_, state_bias_data, weight_segment_num_, lstm_param_->hidden_size_, 6537+ lstm_param_->state_col_align_, lstm_param_->bidirectional_, nullptr); 6538+ } else if (bias->data_type() == kNumberTypeFloat16) { 6539+ auto state_bias_data = reinterpret_cast<float16_t *>(bias_data) + kGateNum * lstm_param_->hidden_size_; 6540+ PackLstmBiasFp16(state_bias_, state_bias_data, weight_segment_num_, lstm_param_->hidden_size_, 6541+ lstm_param_->state_col_align_, lstm_param_->bidirectional_, nullptr); 6542+ } else { 6543+ MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; 6544+ return RET_ERROR; 6545+ } 6546+ return RET_OK; 6547+} 6548+ 6549+int LstmNonMindirFp16CPUKernel::InitProjectWeight() { 6550+ if (in_tensors_.size() < C7NUM) { 6551+ return RET_OK; 6552+ } 6553+ auto weight_pro = in_tensors_[SEVENTH_INPUT]; 6554+ auto shape = weight_pro->shape(); 6555+ MS_CHECK_TRUE_MSG(shape.size() == C3NUM, lite::RET_ERROR, "Project-weight's shape must be 3D."); 6556+ auto weight_pro_data = weight_pro->data(); 6557+ CHECK_NULL_RETURN(weight_pro_data); 6558+ int batch = lstm_param_->bidirectional_ ? 
C2NUM : C1NUM; 6559+ if (shape[0] != batch) { 6560+ MS_LOG(ERROR) << "Project-weight's shape[0] must be 1(bidirectional=false) or 2(bidirectional=true)."; 6561+ return RET_ERROR; 6562+ } 6563+ int pro_col_align = lstm_param_->proj_col_align_; 6564+ weight_project_ptr_ = 6565+ reinterpret_cast<float16_t *>(malloc(batch * lstm_param_->hidden_size_ * pro_col_align * sizeof(float16_t))); 6566+ MS_CHECK_TRUE_MSG(weight_project_ptr_ != nullptr, lite::RET_NULL_PTR, 6567+ "LstmNonMindirCPUKernel malloc weight_project_ptr_ failed."); 6568+ 6569+ if (weight_need_pack_) { 6570+ if (weight_pro->data_type() == kNumberTypeFloat32) { 6571+ PackLstmWeightFp32ToFp16(weight_project_ptr_, reinterpret_cast<float *>(weight_pro_data), batch, 6572+ lstm_param_->hidden_size_, lstm_param_->output_size_, pro_col_align, nullptr); 6573+ } else if (weight_pro->data_type() == kNumberTypeFloat16) { 6574+ PackLstmWeightFp16(weight_project_ptr_, reinterpret_cast<float16_t *>(weight_pro_data), batch, 6575+ lstm_param_->hidden_size_, lstm_param_->output_size_, pro_col_align, nullptr); 6576+ } else { 6577+ MS_LOG(ERROR) << "Unsupported data type of weight_project tensor for lstm."; 6578+ return RET_ERROR; 6579+ } 6580+ } else { 6581+ if (weight_pro->data_type() == kNumberTypeFloat32) { 6582+ Float32ToFloat16(reinterpret_cast<float *>(weight_pro_data), weight_project_ptr_, weight_pro->ElementsNum()); 6583+ } else if (weight_pro->data_type() == kNumberTypeFloat16) { 6584+ (void)memcpy(weight_project_ptr_, weight_pro_data, weight_pro->Size()); 6585+ } else { 6586+ MS_LOG(ERROR) << "Unsupported data type of weight_project tensor for lstm."; 6587+ return RET_ERROR; 6588+ } 6589+ } 6590+ size_t bias_size = UP_ROUND(lstm_param_->output_size_, C8NUM) * sizeof(float16_t); 6591+ project_bias_ = reinterpret_cast<float16_t *>(malloc(bias_size)); 6592+ MS_CHECK_TRUE_MSG(project_bias_ != nullptr, lite::RET_NULL_PTR, 6593+ "LstmNonMindirCPUKernel malloc project_bias_ failed."); 6594+ (void)memset(project_bias_, 0, 
bias_size); 6595+ return RET_OK; 6596+} 6597+} // namespace mindspore::kernel 6598diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h 6599new file mode 100644 6600index 00000000..132ef1cf 6601--- /dev/null 6602+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h 6603@@ -0,0 +1,59 @@ 6604+/** 6605+ * Copyright 2023 Huawei Technologies Co., Ltd 6606+ * 6607+ * Licensed under the Apache License, Version 2.0 (the "License"); 6608+ * you may not use this file except in compliance with the License. 6609+ * You may obtain a copy of the License at 6610+ * 6611+ * http://www.apache.org/licenses/LICENSE-2.0 6612+ * 6613+ * Unless required by applicable law or agreed to in writing, software 6614+ * distributed under the License is distributed on an "AS IS" BASIS, 6615+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6616+ * See the License for the specific language governing permissions and 6617+ * limitations under the License. 6618+ */ 6619+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_NON_MINDIR_FP16_H_ 6620+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_NON_MINDIR_FP16_H_ 6621+ 6622+#include <vector> 6623+#include "src/litert/kernel/cpu/fp16/lstm_fp16_base.h" 6624+ 6625+namespace mindspore::kernel { 6626+/* 6627+ * 1. LSTM without project, output_size = hidden_size 6628+ * weight_ih: second input, shape is [bidirectional, 4 * hidden_size, input_size] 6629+ * weight_hh: third input, shape is [bidirectional, 4 * hidden_size, hidden_size] 6630+ * bias: forth input, shape is [bidirectional, 8 * hidden_size] 6631+ * h_init: fifth input, shape is [bidirectional, batch_size, hidden_size] 6632+ * c_init: sixth input, shape is [bidirectional, batch_size, hidden_size] 6633+ * 6634+ * 2. 
LSTM with project, output_size = project_size 6635+ * weight_ih: second input, shape is [bidirectional, 4 * hidden_size, input_size] 6636+ * weight_hh: third input, shape is [bidirectional, 4 * hidden_size, project_size] 6637+ * bias: forth input, shape is [bidirectional, 8 * hidden_size] 6638+ * h_init: fifth input, shape is [bidirectional, batch_size, project_size] 6639+ * c_init: sixth input, shape is [bidirectional, batch_size, hidden_size] 6640+ * weight_pro: seventh input, shape is [bidirectional, project_size, hidden_size] 6641+ */ 6642+class LstmNonMindirFp16CPUKernel : public LstmFp16BaseCPUKernel { 6643+ public: 6644+ LstmNonMindirFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 6645+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 6646+ : LstmFp16BaseCPUKernel(parameter, inputs, outputs, ctx) { 6647+ hidden_init_index_ = FIFTH_INPUT; 6648+ cell_init_index_ = SIXTH_INPUT; 6649+ } 6650+ 6651+ ~LstmNonMindirFp16CPUKernel() override = default; 6652+ 6653+ int Prepare() override; 6654+ 6655+ protected: 6656+ int InitInputWeightBias() override; 6657+ int InitStateWeightBias() override; 6658+ int InitProjectWeight() override; 6659+}; 6660+} // namespace mindspore::kernel 6661+ 6662+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_NON_MINDIR_FP16_H_ 6663diff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc 6664index 8adb97b9..d6f94fd9 100644 6665--- a/mindspore/lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc 6666+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc 6667@@ -187,13 +187,13 @@ void MatmulBaseFP16CPUKernel::InitMatrixA(const void *src_ptr) { 6668 float16_t *dst = a_pack_ptr_ + i * params_->deep_ * params_->row_align_; 6669 if (params_->a_transpose_) { 6670 #ifdef ENABLE_ARM64 6671- RowMajor2RowNMajorFp16((const float16_t *)src, dst, params_->deep_, params_->row_); 6672+ 
RowMajor2RowNMajorFp16(src, dst, params_->deep_, params_->row_, src_data_type == kNumberTypeFloat32); 6673 #else 6674 RowMajor2Row12MajorFp16(src, dst, params_->deep_, params_->row_, src_data_type == kNumberTypeFloat32); 6675 #endif 6676 } else { 6677 #ifdef ENABLE_ARM64 6678- RowMajor2ColNMajorFp16((const float16_t *)src, dst, params_->row_, params_->deep_); 6679+ RowMajor2ColNMajorFp16(src, dst, params_->row_, params_->deep_, src_data_type == kNumberTypeFloat32); 6680 #else 6681 RowMajor2Col12MajorFp16(src, dst, params_->row_, params_->deep_, src_data_type == kNumberTypeFloat32); 6682 #endif 6683diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32.cc 6684index 0b67f2c2..67f42265 100644 6685--- a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32.cc 6686+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32.cc 6687@@ -1,5 +1,5 @@ 6688 /** 6689- * Copyright 2020 Huawei Technologies Co., Ltd 6690+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 6691 * 6692 * Licensed under the Apache License, Version 2.0 (the "License"); 6693 * you may not use this file except in compliance with the License. 6694@@ -14,14 +14,11 @@ 6695 * limitations under the License. 
6696 */ 6697 6698-#include "src/litert/kernel/cpu/fp32/lstm_fp32.h" 6699-#include <cfloat> 6700 #include <vector> 6701-#include "schema/model_generated.h" 6702+#include "src/litert//kernel/cpu/fp32/lstm_mindir_fp32.h" 6703+#include "src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h" 6704 #include "src/litert/kernel_registry.h" 6705 #include "include/errorcode.h" 6706-#include "nnacl/fp32/pack_fp32.h" 6707-#include "nnacl/fp32/matmul_fp32.h" 6708 6709 using mindspore::kernel::KERNEL_ARCH; 6710 using mindspore::lite::KernelRegistrar; 6711@@ -32,664 +29,31 @@ using mindspore::schema::PrimitiveType_LSTM; 6712 6713 namespace mindspore::kernel { 6714 namespace { 6715-constexpr int kOutputHiddenStatusIndex = 1; 6716-constexpr int kOutputCellStatusIndex = 2; 6717-} // namespace 6718- 6719-int LstmInputMulWeightRun(void *cdata, int task_id, float, float) { 6720- auto kernel = reinterpret_cast<const LstmCPUKernel *>(cdata); 6721- CHECK_NULL_RETURN(kernel); 6722- kernel->InputWeightMatMul(task_id); 6723- return RET_OK; 6724-} 6725- 6726-int LstmSequenceLoopRun(void *cdata, int task_id, float, float) { 6727- auto kernel = reinterpret_cast<LstmCPUKernel *>(cdata); 6728- CHECK_NULL_RETURN(kernel); 6729- auto ret = kernel->DoSequenceLoop(task_id); 6730- if (ret != RET_OK) { 6731- MS_LOG(ERROR) << "LSTM: Do Sequence-loop failed."; 6732- } 6733- return ret; 6734-} 6735- 6736-void LstmCPUKernel::FreeRunBuffer() { 6737- for (auto data : buffer_running_malloc_) { 6738- ms_context_->allocator->Free(data); 6739- } 6740- buffer_running_malloc_.clear(); 6741-} 6742- 6743-int LstmCPUKernel::InitInputWeightBias() { 6744- // malloc and init input * weight right matrix buffer 6745- // input -- row: seq_len * batch; col: input_size 6746- // weight -- row: hidden_size; col: input_size, need transpose 6747- // result -- row: seq_len * batch; col: hidden_size 6748- weight_i_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc( 6749- weight_batch_ * lstm_param_->input_col_align_ * 
lstm_param_->input_size_ * sizeof(float))); 6750- if (weight_i_ptr_ == nullptr) { 6751- MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error."; 6752- return RET_ERROR; 6753- } 6754- buffer_running_malloc_.push_back(weight_i_ptr_); 6755- int i_index = (in_tensors_.size() == mindir_input_tensors) ? combined_weights_index : onnx_weight_i_index; 6756- const int *weights_order = (in_tensors_.size() == mindir_input_tensors) ? weights_order_IFOG : nullptr; 6757- auto weight_i = in_tensors_.at(i_index); 6758- auto weight_i_data = reinterpret_cast<float *>(weight_i->data()); 6759- 6760- CHECK_NULL_RETURN(weight_i_data); 6761- int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_); 6762- int hh_size = (lstm_param_->hidden_size_ * lstm_param_->hidden_size_); 6763- int b_size = (lstm_param_->hidden_size_); 6764- bool has_bias = (weight_batch_ * (cw_size + hh_size) < weight_i->ElementsNum()) ? true : false; 6765- int stride = (gpu_orig_state_) ? gate_num * (cw_size + hh_size) : gate_num * (cw_size); 6766- PackLstmWeightWithStride(weight_i_ptr_, weight_i_data, weight_batch_, lstm_param_->input_size_, 6767- lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_, 6768- stride, weights_order); 6769- // input bias 6770- input_bias_ = reinterpret_cast<float *>( 6771- ms_context_->allocator->Malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float))); 6772- if (input_bias_ == nullptr) { 6773- MS_LOG(ERROR) << "LstmCPUKernel malloc input_bias_ error."; 6774- return RET_ERROR; 6775- } 6776- memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float)); 6777- buffer_running_malloc_.push_back(input_bias_); 6778- 6779- int offset = weight_batch_ * (cw_size + hh_size); 6780- float *bias_data = (has_bias) ? weight_i_data + offset : nullptr; 6781- int dir_mul = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 6782- int b_stride = (gpu_orig_state_) ? 
gate_num * (dir_mul * b_size) : gate_num * (b_size); 6783- if (in_tensors_.size() > mindir_input_tensors) { 6784- bias_data = reinterpret_cast<float *>(in_tensors_.at(onnx_bias_index)->data()); 6785- CHECK_NULL_RETURN(bias_data); 6786- PackLstmBias(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, 6787- lstm_param_->bidirectional_, weights_order); 6788- } else { 6789- if (bias_data != nullptr) { 6790- PackLstmBiasWithStride(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_, 6791- lstm_param_->input_col_align_, lstm_param_->bidirectional_, b_stride, weights_order); 6792- } 6793- } 6794- return RET_OK; 6795-} 6796- 6797-int LstmCPUKernel::InitStateWeightBias() { 6798- // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. 6799- // state -- row: batch; col: hidden_size 6800- // weight -- row: hidden_size; col: hidden_size, need transpose 6801- // result -- row: batch; col: hidden_size 6802- int weight_i_size = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; 6803- int h_index = (in_tensors_.size() == mindir_input_tensors) ? combined_weights_index : onnx_weight_h_index; 6804- auto weight_h = in_tensors_.at(h_index); 6805- auto weight_h_data = (reinterpret_cast<float *>(weight_h->data())); 6806- 6807- int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_); 6808- int hh_size = (lstm_param_->hidden_size_ * lstm_param_->project_size_); 6809- int b_size = (lstm_param_->hidden_size_); 6810- int stride = (gpu_orig_state_) ? 
gate_num * (cw_size + hh_size) : gate_num * (hh_size); 6811- 6812- if (in_tensors_.size() == mindir_input_tensors) { 6813- if (gpu_orig_state_) { 6814- weight_h_data += gate_num * cw_size; 6815- } else { 6816- weight_h_data += weight_i_size; 6817- } 6818- } 6819- CHECK_NULL_RETURN(weight_h_data); 6820- if (!state_is_vec_) { 6821- weight_h_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc( 6822- weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->project_size_ * sizeof(float))); 6823- if (weight_h_ptr_ == nullptr) { 6824- MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ptr_ error."; 6825- return RET_ERROR; 6826- } 6827- buffer_running_malloc_.push_back(weight_h_ptr_); 6828- const int *weights_order = (in_tensors_.size() == mindir_input_tensors) ? weights_order_IFOG : nullptr; 6829- PackLstmWeightWithStride(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->project_size_, 6830- lstm_param_->hidden_size_, lstm_param_->state_col_align_, lstm_param_->bidirectional_, 6831- stride, weights_order); 6832- } else { 6833-#ifdef ENABLE_AVX 6834- weight_h_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc( 6835- weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->project_size_ * sizeof(float))); 6836- if (weight_h_ptr_ == nullptr) { 6837- MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ptr_ error."; 6838- return RET_ERROR; 6839- } 6840- buffer_running_malloc_.push_back(weight_h_ptr_); 6841- for (int i = 0; i < weight_batch_; i++) { 6842- const float *src_batch = weight_h_data + i * lstm_param_->hidden_size_ * lstm_param_->project_size_; 6843- float *dst_batch = weight_h_ptr_ + i * lstm_param_->state_col_align_ * lstm_param_->project_size_; 6844- RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->hidden_size_, lstm_param_->project_size_); 6845- } 6846-#else 6847- weight_h_ptr_ = weight_h_data; 6848-#endif 6849- } 6850- 6851- // state bias 6852- int weight_h_size = weight_batch_ * lstm_param_->hidden_size_ * 
lstm_param_->hidden_size_; 6853- int bias_size = weight_batch_ * lstm_param_->hidden_size_; 6854- state_bias_ = reinterpret_cast<float *>( 6855- ms_context_->allocator->Malloc(weight_batch_ * lstm_param_->state_col_align_ * sizeof(float))); 6856- if (state_bias_ == nullptr) { 6857- MS_LOG(ERROR) << "LstmCPUKernel malloc state_bias_ error."; 6858- return RET_ERROR; 6859- } 6860- memset(state_bias_, 0, weight_batch_ * lstm_param_->state_col_align_ * sizeof(float)); 6861- buffer_running_malloc_.push_back(state_bias_); 6862- // if ONNX, secend bias is also present order IOFG 6863- if (in_tensors_.size() > mindir_input_tensors) { 6864- float *state_bias = 6865- reinterpret_cast<float *>(in_tensors_.at(onnx_bias_index)->data()) + gate_num * lstm_param_->hidden_size_; 6866- CHECK_NULL_RETURN(state_bias); 6867- PackLstmBias(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, lstm_param_->state_col_align_, 6868- lstm_param_->bidirectional_, nullptr); 6869- } else if (weight_h->ElementsNum() - weight_i_size - weight_h_size - C2NUM * bias_size == 0) { 6870- // mindir from device "GPU", secend bias is also present order IFOG 6871- int dir_mul = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 6872- int bias_offset = (gpu_orig_state_) ? gate_num * ((dir_mul - C1NUM) * cw_size + dir_mul * hh_size + b_size) 6873- : weight_h_size + bias_size; 6874- float *state_bias = weight_h_data + bias_offset; 6875- int b_stride = (gpu_orig_state_) ? 
gate_num * (b_size * C2NUM) : gate_num * b_size; 6876- PackLstmBiasWithStride(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, 6877- lstm_param_->state_col_align_, lstm_param_->bidirectional_, b_stride, weights_order_IFOG); 6878- } 6879- return RET_OK; 6880-} 6881- 6882-int LstmCPUKernel::InitProjectWeight() { 6883- if (in_tensors_.size() < C7NUM) { 6884- return RET_OK; 6885- } 6886- auto weight_pro = in_tensors_.at(SEVENTH_INPUT); 6887- auto shape = weight_pro->shape(); 6888- if (shape.size() != C3NUM) { 6889- MS_LOG(ERROR) << "Project-weight's shape must be 3D."; 6890- return RET_ERROR; 6891- } 6892- auto weight_pro_data = reinterpret_cast<float *>(weight_pro->data()); 6893- CHECK_NULL_RETURN(weight_pro_data); 6894- int batch = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 6895- if (shape[0] != batch) { 6896- MS_LOG(ERROR) << "Project-weight's shape[0] must be 1(bidirectional=false) or 2(bidirectional=true)."; 6897- return RET_ERROR; 6898- } 6899- int col_align = UP_ROUND(lstm_param_->project_size_, col_tile_); 6900- if (!state_is_vec_) { 6901- weight_project_ptr_ = reinterpret_cast<float *>( 6902- ms_context_->allocator->Malloc(batch * lstm_param_->hidden_size_ * col_align * sizeof(float))); 6903- if (weight_project_ptr_ == nullptr) { 6904- MS_LOG(ERROR) << "LstmCPUKernel malloc weight_project_ptr_ error."; 6905- return RET_ERROR; 6906- } 6907- buffer_running_malloc_.push_back(weight_project_ptr_); 6908- PackLstmWeightWithStride(weight_project_ptr_, weight_pro_data, batch, lstm_param_->hidden_size_, 6909- lstm_param_->project_size_, col_align, lstm_param_->bidirectional_, 6910- lstm_param_->hidden_size_ * lstm_param_->project_size_, nullptr); 6911- } else { 6912-#ifdef ENABLE_AVX 6913- weight_project_ptr_ = reinterpret_cast<float *>( 6914- ms_context_->allocator->Malloc(batch * lstm_param_->hidden_size_ * col_align * sizeof(float))); 6915- if (weight_project_ptr_ == nullptr) { 6916- MS_LOG(ERROR) << "LstmCPUKernel malloc weight_project_ptr_ 
error."; 6917- return RET_ERROR; 6918- } 6919- buffer_running_malloc_.push_back(weight_project_ptr_); 6920- for (int i = 0; i < batch; ++i) { 6921- const float *src_batch = weight_pro_data + i * lstm_param_->hidden_size_ * lstm_param_->project_size_; 6922- float *dst_batch = weight_project_ptr_ + i * lstm_param_->hidden_size_ * col_align; 6923- RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->project_size_, lstm_param_->hidden_size_); 6924- } 6925-#else 6926- weight_project_ptr_ = weight_pro_data; 6927-#endif 6928- } 6929- return RET_OK; 6930-} 6931- 6932-int LstmCPUKernel::InitParam() { 6933- auto input = in_tensors_.front(); 6934- std::vector<int> in_shape = input->shape(); 6935- lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT); 6936- lstm_param_->batch_ = in_shape.at(SECOND_INPUT); 6937- lstm_param_->input_size_ = in_shape.at(THIRD_INPUT); 6938- 6939- auto weight_i = in_tensors_.at(onnx_weight_i_index); 6940- std::vector<int> w_shape = weight_i->shape(); 6941- if (in_tensors_.size() == mindir_input_tensors) { 6942- hidden_state_input_index_ = mindir_hidden_state_input_index; 6943- cell_state_input_index_ = mindir_cell_state_input_index; 6944- lstm_param_->hidden_size_ = w_shape.at(THIRD_INPUT); 6945- lstm_param_->project_size_ = lstm_param_->hidden_size_; 6946- } else { 6947- lstm_param_->hidden_size_ = w_shape.at(SECOND_INPUT) / gate_num; 6948- auto weight_h = in_tensors_[THIRD_INPUT]; 6949- auto h_shape = weight_h->shape(); 6950- lstm_param_->project_size_ = h_shape.back(); 6951- } 6952- 6953- lstm_param_->output_step_ = lstm_param_->bidirectional_ ? C2NUM * lstm_param_->batch_ * lstm_param_->hidden_size_ 6954- : lstm_param_->batch_ * lstm_param_->hidden_size_; 6955- weight_batch_ = lstm_param_->bidirectional_ ? 
C2NUM * gate_num : gate_num; 6956- state_is_vec_ = lstm_param_->batch_ == 1; 6957- // determine FB origin 6958- gpu_orig_state_ = false; 6959- if (in_tensors_.size() == mindir_input_tensors) { 6960- gpu_orig_state_ = gpu_orig_cfg_; 6961- auto weight_t = in_tensors_.at(combined_weights_index); 6962- int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_); 6963- int hh_size = (lstm_param_->hidden_size_ * lstm_param_->hidden_size_); 6964- int b_size = (lstm_param_->hidden_size_); 6965- bool has_bias = (weight_batch_ * (cw_size + hh_size) < weight_t->ElementsNum()) ? true : false; 6966- // if bias exist we can determine the gpu_orig_state_ 6967- if (has_bias) { 6968- gpu_orig_state_ = 6969- (weight_batch_ * (cw_size + hh_size + C2NUM * b_size) == weight_t->ElementsNum()) ? true : false; 6970- } 6971- } 6972- 6973-#ifdef ENABLE_AVX 6974- row_tile_ = C6NUM; 6975- col_tile_ = C16NUM; 6976-#elif defined(ENABLE_ARM32) 6977- row_tile_ = C12NUM; 6978- col_tile_ = C4NUM; 6979-#elif defined(ENABLE_SSE) 6980- row_tile_ = C4NUM; 6981- col_tile_ = C8NUM; 6982-#else 6983- row_tile_ = C12NUM; 6984- col_tile_ = C8NUM; 6985-#endif 6986- lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, row_tile_); 6987- lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, col_tile_); 6988- input_thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(lstm_param_->input_col_align_, col_tile_)); 6989- MS_CHECK_FALSE(input_thread_count_ == 0, RET_ERROR); 6990- input_thread_stride_ = UP_DIV(UP_DIV(lstm_param_->input_col_align_, col_tile_), input_thread_count_); 6991- 6992- state_row_tile_ = row_tile_; 6993- state_col_tile_ = col_tile_; 6994-#ifdef ENABLE_AVX 6995- if (state_is_vec_) { 6996- state_row_tile_ = 1; 6997- state_col_tile_ = C8NUM; 6998- } 6999-#endif 7000- 7001- lstm_param_->state_row_align_ = state_is_vec_ ? 
1 : UP_ROUND(lstm_param_->batch_, state_row_tile_); 7002-#ifdef ENABLE_AVX 7003- lstm_param_->state_col_align_ = UP_ROUND(lstm_param_->hidden_size_, state_col_tile_); 7004-#else 7005- lstm_param_->state_col_align_ = 7006- state_is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, state_col_tile_); 7007-#endif 7008- return RET_OK; 7009+constexpr size_t kMindirInputTensorNum = 4; 7010 } 7011- 7012-int LstmCPUKernel::Prepare() { 7013- CHECK_LESS_RETURN(in_tensors_.size(), mindir_input_tensors); 7014- for (size_t i = 0; i < in_tensors_.size(); i++) { 7015- CHECK_NULL_RETURN(in_tensors_.at(i)); 7016- } 7017- CHECK_LESS_RETURN(out_tensors_.size(), DIMENSION_3D); 7018- for (size_t i = 0; i < out_tensors_.size(); i++) { 7019- CHECK_NULL_RETURN(out_tensors_.at(i)); 7020- } 7021- CHECK_NULL_RETURN(lstm_param_); 7022- if (!InferShapeDone()) { 7023- return RET_OK; 7024- } 7025- return ReSize(); 7026-} 7027- 7028-int LstmCPUKernel::ReSize() { 7029- auto ret = InitParam(); 7030- if (ret != RET_OK) { 7031- MS_LOG(ERROR) << "LstmCPUKernel InitParam error."; 7032- return RET_ERROR; 7033+LiteKernel *LstmFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, 7034+ OpParameter *parameter, const lite::InnerContext *ctx, const kernel::KernelKey &desc) { 7035+ if (parameter == nullptr) { 7036+ MS_LOG(ERROR) << "parameter is nullptr."; 7037+ return nullptr; 7038 } 7039- 7040- return RET_OK; 7041-} 7042- 7043-int LstmCPUKernel::MallocRunBuffer(bool is_double) { 7044- bool need_zone = lstm_param_->zoneout_cell_ < -FLT_EPSILON || lstm_param_->zoneout_cell_ > FLT_EPSILON; 7045- size_t whole_size = 0; 7046- std::vector<size_t> segments; 7047- int scale = is_double ? 
C2NUM : 1; 7048- size_t segment = gate_num * lstm_param_->seq_len_ * lstm_param_->batch_ * 7049- lstm_param_->hidden_size_; // 0: input * weight for result matrix 7050- segments.push_back(segment); 7051- whole_size += segment * scale; 7052- 7053- segment = state_is_vec_ 7054- ? 0 7055- : lstm_param_->state_row_align_ * lstm_param_->project_size_; // 1: state * weight for left matirx 7056- segments.push_back(segment); 7057- whole_size += segment * scale; 7058- 7059- segment = gate_num * lstm_param_->batch_ * lstm_param_->hidden_size_; // 2: state gate buffer 7060- segments.push_back(segment); 7061- whole_size += segment * scale; 7062- 7063- segment = need_zone ? lstm_param_->batch_ * lstm_param_->hidden_size_ : 0; // 3: state_buffer for cell 7064- segments.push_back(segment); 7065- whole_size += segment * scale; 7066- 7067- segment = need_zone ? lstm_param_->batch_ * lstm_param_->project_size_ : 0; // 4: state_buffer for hidden 7068- segments.push_back(segment); 7069- whole_size += segment * scale; 7070- 7071- segment = 0; 7072-#ifdef ENABLE_AVX 7073- bool output_need_packed = lstm_param_->hidden_size_ % state_col_tile_; 7074- if (state_is_vec_ && output_need_packed) { // vec matmul need to malloc dst 7075- int out_channel = lstm_param_->hidden_size_; 7076- int oc_block_num = UP_DIV(out_channel, state_col_tile_); 7077- MS_ASSERT(ms_context_->allocator != nullptr); 7078- segment = lstm_param_->batch_ * oc_block_num * state_col_tile_; // 5: tmp output data 7079+ if (desc.data_type == kTypeUnknown) { 7080+ MS_LOG(WARNING) << "desc data_type is unknown."; 7081 } 7082-#endif 7083- segments.push_back(segment); 7084- whole_size += segment * scale; 7085- 7086- if (in_tensors_.size() == C7NUM) { 7087- segment = state_is_vec_ ? 0 : lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * scale; 7088- segments.push_back(segment); // 6: project-layer input 7089- whole_size += segment; 7090- segment = 0; 7091-#ifdef ENABLE_AVX 7092- segment = 7093- output_need_packed ? 
lstm_param_->batch_ * UP_ROUND(lstm_param_->project_size_, state_col_tile_) * scale : 0; 7094-#endif 7095- segments.push_back(segment); // 7: project-layer output 7096- whole_size += segment; 7097+ LiteKernel *kernel{nullptr}; 7098+ if (inputs.size() == kMindirInputTensorNum) { 7099+ kernel = new (std::nothrow) 7100+ LstmMindirFp32CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx)); 7101 } else { 7102- (void)segments.insert(segments.end(), C2NUM, 0); 7103- } 7104- 7105- segment = 0; 7106- if (!(in_tensors_.size() > mindir_input_tensors)) { 7107- segment = lstm_param_->batch_ * lstm_param_->hidden_size_; 7108- } 7109- segments.push_back(segment); 7110- whole_size += segment * scale; 7111- 7112- segment = 7113- lstm_param_->input_row_align_ * lstm_param_->input_size_; // input * weight for left matrix, which only once 7114- whole_size += segment; 7115- 7116- auto whole_memory = reinterpret_cast<float *>(ms_context_->allocator->Malloc(whole_size * sizeof(float))); 7117- MS_CHECK_TRUE_MSG(whole_memory != nullptr, RET_ERROR, "LSTM: malloc failed."); 7118- buffer_running_malloc_.push_back(whole_memory); 7119- MS_ASSERT(segments.size() == C9NUM); 7120- auto Allocate = [&whole_memory, &segments](float **buffer) mutable { 7121- for (int i = 0; i < C9NUM; ++i) { 7122- buffer[i] = nullptr; 7123- if (segments[i] == 0) { 7124- continue; 7125- } 7126- buffer[i] = whole_memory; 7127- whole_memory += segments[i]; 7128- } 7129- }; 7130- Allocate(buffer_forward_); 7131- if (is_double) { 7132- Allocate(buffer_backward_); 7133- } 7134- packed_input_ = whole_memory; 7135- return RET_OK; 7136-} 7137- 7138-void LstmCPUKernel::InputWeightMatMul(int task_id) const { 7139- int current_start_oc = task_id * input_thread_stride_ * col_tile_; 7140- int current_rest_oc = 0; 7141- current_rest_oc = lstm_param_->hidden_size_ - current_start_oc; 7142- int cur_oc = MSMIN(input_thread_stride_ * col_tile_, current_rest_oc); 7143- if (cur_oc <= 0) { 7144- return; 7145- } 
7146- 7147- auto b = weight_loop_ + current_start_oc * lstm_param_->input_size_; 7148- auto c = gate_loop_ + current_start_oc; 7149- auto bias = (bias_loop_ == nullptr) ? nullptr : bias_loop_ + current_start_oc; 7150- MatMulOpt(packed_input_, b, c, bias, ActType_No, lstm_param_->input_size_, 7151- lstm_param_->seq_len_ * lstm_param_->batch_, cur_oc, lstm_param_->hidden_size_, OutType_Nhwc); 7152-} 7153- 7154-int LstmCPUKernel::DoSequenceLoop(int task_id) { 7155- if (task_id == 0) { 7156- LstmForwardLoop(buffer_forward_); 7157- return RET_OK; 7158- } 7159- if (task_id == 1) { 7160- LstmBackwardLoop(buffer_backward_); 7161- return RET_OK; 7162- } 7163- return RET_ERROR; 7164-} 7165- 7166-int LstmCPUKernel::LstmPreProcessWithInput(const float *weight_i, const float *input_bias, float *dst) { 7167- for (int i = 0; i < gate_num; i++) { 7168- weight_loop_ = weight_i + lstm_param_->input_size_ * lstm_param_->input_col_align_ * i; 7169- bias_loop_ = input_bias + lstm_param_->input_col_align_ * i; 7170- gate_loop_ = dst + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * i; 7171- auto ret = ParallelLaunch(this->ms_context_, LstmInputMulWeightRun, this, input_thread_count_); 7172- if (ret != RET_OK) { 7173- return RET_ERROR; 7174- } 7175- } 7176- return RET_OK; 7177-} 7178- 7179-void LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, 7180- float *hidden_state, float *cell_state, const float *weight_project, 7181- float *intermediate_states, float *buffer[], bool is_backward) { 7182- float *gate = buffer[input_gate_index]; 7183- float *input_gate = gate; 7184- float *forget_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C2NUM; 7185- float *cell_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C3NUM; 7186- float *output_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_; 7187- float *tmp = 
buffer[tmp_hidden_output_index]; 7188- int dir_mult = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 7189- for (int t = 0; t < lstm_param_->seq_len_; t++) { 7190- int real_t = is_backward ? lstm_param_->seq_len_ - t - C1NUM : t; 7191- float *input_gate_t = input_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 7192- float *forget_gate_t = forget_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 7193- float *cell_gate_t = cell_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 7194- float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 7195- // if ONNX 7196- if (in_tensors_.size() > mindir_input_tensors) { 7197- // Sequence, DirMul, Batch, Hidden 7198- float *output_ptr = output + real_t * lstm_param_->output_step_; 7199- 7200- LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, 7201- weight_project, hidden_state, cell_state, buffer, lstm_param_); 7202- } else { 7203- // Sequence, Batch, DirMul, Hidden 7204- LstmStepUnit(tmp, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, nullptr, 7205- hidden_state, cell_state, buffer, lstm_param_); 7206- int seq_offset = real_t * lstm_param_->batch_ * dir_mult * lstm_param_->hidden_size_; 7207- for (int b = 0; b < lstm_param_->batch_; b++) { 7208- int batch_offset = b * dir_mult * lstm_param_->hidden_size_; 7209- float *output_ptr = output + seq_offset + batch_offset; 7210- memcpy(output_ptr, tmp + b * lstm_param_->hidden_size_, lstm_param_->hidden_size_ * sizeof(float)); 7211- } 7212- } 7213- if (intermediate_states) { 7214- RecordStates(hidden_state, cell_state, input_gate_t, output_gate_t, forget_gate_t, cell_gate_t, 7215- intermediate_states, real_t); 7216- } 7217- } 7218-} 7219- 7220-void LstmCPUKernel::RecordStates(const float *hidden_state, float *cell_state, float *input_gate, 7221- const float *output_gate, float *forget_gate, const float *cell_gate, 7222- 
float *intermediate_states, int step) { 7223- float *states = intermediate_states; 7224- auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_; 7225- if (state_size < 0) { 7226- MS_LOG(ERROR) << "state size should be greater than or equal to zero."; 7227- return; 7228- } 7229- auto stride = step * lstm_param_->output_step_; 7230- auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_; 7231- memcpy(states + stride, hidden_state, state_size * sizeof(float)); 7232- stride += seq_stride; 7233- memcpy(states + stride, cell_state, state_size * sizeof(float)); 7234- stride += seq_stride; 7235- memcpy(states + stride, input_gate, state_size * sizeof(float)); 7236- stride += seq_stride; 7237- memcpy(states + stride, output_gate, state_size * sizeof(float)); 7238- stride += seq_stride; 7239- memcpy(states + stride, forget_gate, state_size * sizeof(float)); 7240- stride += seq_stride; 7241- memcpy(states + stride, cell_gate, state_size * sizeof(float)); 7242-} 7243- 7244-void LstmCPUKernel::LstmForwardLoop(float *buffer[]) { 7245- auto *output = reinterpret_cast<float *>(out_tensors_.at(0)->data()); 7246- auto *hidden_state = reinterpret_cast<float *>(out_tensors_.at(1)->data()); 7247- auto *cell_state = reinterpret_cast<float *>(out_tensors_.at(C2NUM)->data()); 7248- LstmUnidirectional(output, weight_h_ptr_, state_bias_, hidden_state, cell_state, weight_project_ptr_, 7249- intermediate_states_, buffer, false); 7250-} 7251- 7252-void LstmCPUKernel::LstmBackwardLoop(float *buffer[]) { 7253- auto *output = reinterpret_cast<float *>(out_tensors_.at(0)->data()); 7254- auto *hidden_state = reinterpret_cast<float *>(out_tensors_.at(1)->data()); 7255- auto *cell_state = reinterpret_cast<float *>(out_tensors_.at(C2NUM)->data()); 7256- const float *backward_weight_h = weight_h_ptr_ + gate_num * lstm_param_->state_col_align_ * lstm_param_->hidden_size_; 7257- const float *backward_state_bias = state_bias_ + gate_num * lstm_param_->state_col_align_; 7258- float 
*backward_output = output + lstm_param_->batch_ * lstm_param_->hidden_size_; 7259- if (in_tensors_.size() == mindir_input_tensors) { 7260- backward_output = output + lstm_param_->hidden_size_; 7261- } 7262- float *backward_cell_state = cell_state + lstm_param_->batch_ * lstm_param_->hidden_size_; 7263- float *backward_hidden_state = hidden_state + lstm_param_->batch_ * lstm_param_->hidden_size_; 7264- float *intermediate_states = nullptr; 7265- if (intermediate_states_) { 7266- intermediate_states = intermediate_states_ + lstm_param_->batch_ * lstm_param_->hidden_size_; 7267- } 7268- float *backward_weight_project = 7269- weight_project_ptr_ 7270- ? weight_project_ptr_ + lstm_param_->hidden_size_ * UP_ROUND(lstm_param_->project_size_, col_tile_) 7271- : nullptr; 7272- LstmUnidirectional(backward_output, backward_weight_h, backward_state_bias, backward_hidden_state, 7273- backward_cell_state, backward_weight_project, intermediate_states, buffer, true); 7274-} 7275- 7276-int LstmCPUKernel::ExecuteUnidirectionalOrSingleThread() { 7277- auto ret = LstmPreProcessWithInput(weight_i_ptr_, input_bias_, buffer_forward_[input_gate_index]); 7278- if (ret != RET_OK) { 7279- MS_LOG(ERROR) << "LSTM Forward: Input-MatMul running failed."; 7280- return RET_ERROR; 7281- } 7282- LstmForwardLoop(buffer_forward_); 7283- 7284- // backward 7285- if (lstm_param_->bidirectional_) { 7286- const float *backward_weight_i = 7287- weight_i_ptr_ + gate_num * lstm_param_->input_col_align_ * lstm_param_->input_size_; 7288- const float *backward_input_bias = input_bias_ + gate_num * lstm_param_->input_col_align_; 7289- ret = LstmPreProcessWithInput(backward_weight_i, backward_input_bias, buffer_forward_[input_gate_index]); 7290- if (ret != RET_OK) { 7291- MS_LOG(ERROR) << "LSTM Backward: Input-MatMul running failed."; 7292- return RET_ERROR; 7293- } 7294- LstmBackwardLoop(buffer_forward_); 7295- } 7296- return RET_OK; 7297-} 7298- 7299-int LstmCPUKernel::ExecuteBidirectionalWithMultiThread() { 
7300- auto ret = LstmPreProcessWithInput(weight_i_ptr_, input_bias_, buffer_forward_[input_gate_index]); 7301- if (ret != RET_OK) { 7302- MS_LOG(ERROR) << "LSTM Forward: Input-MatMul running failed."; 7303- return RET_ERROR; 7304- } 7305- const float *backward_weight_i = weight_i_ptr_ + gate_num * lstm_param_->input_col_align_ * lstm_param_->input_size_; 7306- const float *backward_input_bias = input_bias_ + gate_num * lstm_param_->input_col_align_; 7307- ret = LstmPreProcessWithInput(backward_weight_i, backward_input_bias, buffer_backward_[input_gate_index]); 7308- if (ret != RET_OK) { 7309- MS_LOG(ERROR) << "LSTM Backward: Input-MatMul running failed."; 7310- return RET_ERROR; 7311- } 7312- ret = ParallelLaunch(this->ms_context_, LstmSequenceLoopRun, this, C2NUM); 7313- if (ret != RET_OK) { 7314- MS_LOG(ERROR) << "LSTM: Do sequence-loop failed."; 7315- } 7316- return ret; 7317-} 7318- 7319-int LstmCPUKernel::Run() { 7320- auto input = in_tensors_.at(0); 7321- auto output = out_tensors_.at(0); 7322- CHECK_NULL_RETURN(input); 7323- CHECK_NULL_RETURN(output); 7324- auto input_ptr = reinterpret_cast<float *>(input->data()); 7325- CHECK_NULL_RETURN(input_ptr); 7326- auto output_ptr = reinterpret_cast<float *>(output->data()); 7327- CHECK_NULL_RETURN(output_ptr); 7328- 7329- auto hidden_state = in_tensors_.at(hidden_state_input_index_); 7330- CHECK_NULL_RETURN(hidden_state->data()); 7331- auto cell_state = in_tensors_.at(cell_state_input_index_); 7332- CHECK_NULL_RETURN(cell_state->data()); 7333- 7334- auto output_hidden_state = out_tensors_[kOutputHiddenStatusIndex]; 7335- CHECK_NULL_RETURN(output_hidden_state->data()); 7336- (void)memcpy(output_hidden_state->data(), hidden_state->data(), hidden_state->ElementsNum() * sizeof(float)); 7337- auto output_cell_state = out_tensors_[kOutputCellStatusIndex]; 7338- CHECK_NULL_RETURN(output_cell_state->data()); 7339- (void)memcpy(output_cell_state->data(), cell_state->data(), cell_state->ElementsNum() * sizeof(float)); 7340- 
7341- auto ret = InitInputWeightBias(); 7342- if (ret != RET_OK) { 7343- MS_LOG(ERROR) << "LstmCPUKernel InitInputWeightBias error."; 7344- FreeRunBuffer(); 7345- return RET_ERROR; 7346- } 7347- 7348- ret = InitStateWeightBias(); 7349- if (ret != RET_OK) { 7350- MS_LOG(ERROR) << "LstmCPUKernel InitStateWeightBias error."; 7351- FreeRunBuffer(); 7352- return RET_ERROR; 7353+ kernel = new (std::nothrow) 7354+ LstmNonMindirFp32CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx)); 7355 } 7356- 7357- ret = InitProjectWeight(); 7358- if (ret != RET_OK) { 7359- MS_LOG(ERROR) << "LstmCPUKernel InitProjectWeight error."; 7360- FreeRunBuffer(); 7361- return RET_ERROR; 7362- } 7363- bool is_bidirectional_with_multi_thread = thread_num_ != 1 && lstm_param_->bidirectional_; 7364- ret = MallocRunBuffer(is_bidirectional_with_multi_thread); 7365- if (ret != RET_OK) { 7366- MS_LOG(ERROR) << "LstmCPUKernel MallocRunBuffer Error."; 7367- FreeRunBuffer(); 7368- return RET_ERROR; 7369- } 7370- 7371- PackLstmInput(input_ptr, packed_input_, lstm_param_->seq_len_ * lstm_param_->batch_, lstm_param_->input_size_); 7372- if (IsTrain() && IsTrainable()) { 7373- intermediate_states_ = reinterpret_cast<float *>(out_tensors_[out_intermediate_states_index]->data()); 7374+ if (kernel == nullptr) { 7375+ MS_LOG(ERROR) << "kernel: " << parameter->name_ << "is nullptr."; 7376+ free(parameter); 7377+ return nullptr; 7378 } 7379- CHECK_NULL_RETURN(weight_h_ptr_); 7380- CHECK_NULL_RETURN(weight_i_ptr_); 7381- CHECK_NULL_RETURN(input_bias_); 7382- CHECK_NULL_RETURN(state_bias_); 7383- if (is_bidirectional_with_multi_thread) { 7384- ret = ExecuteBidirectionalWithMultiThread(); 7385- } else { 7386- ret = ExecuteUnidirectionalOrSingleThread(); 7387- } 7388- FreeRunBuffer(); 7389- return ret; 7390+ return kernel; 7391 } 7392- 7393-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LSTM, LiteKernelCreator<LstmCPUKernel>) 7394+REG_KERNEL(kCPU, kNumberTypeFloat32, 
PrimitiveType_LSTM, LstmFp32KernelCreator) 7395 } // namespace mindspore::kernel 7396diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc 7397new file mode 100644 7398index 00000000..bd0f0e7d 7399--- /dev/null 7400+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc 7401@@ -0,0 +1,398 @@ 7402+/** 7403+ * Copyright 2023 Huawei Technologies Co., Ltd 7404+ * 7405+ * Licensed under the Apache License, Version 2.0 (the "License"); 7406+ * you may not use this file except in compliance with the License. 7407+ * You may obtain a copy of the License at 7408+ * 7409+ * http://www.apache.org/licenses/LICENSE-2.0 7410+ * 7411+ * Unless required by applicable law or agreed to in writing, software 7412+ * distributed under the License is distributed on an "AS IS" BASIS, 7413+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 7414+ * See the License for the specific language governing permissions and 7415+ * limitations under the License. 
7416+ */ 7417+ 7418+#include "src/litert/kernel/cpu/fp32/lstm_fp32_base.h" 7419+#include <vector> 7420+#include "include/errorcode.h" 7421+#include "nnacl/fp32/pack_fp32.h" 7422+#include "nnacl/fp32/matmul_fp32.h" 7423+ 7424+using mindspore::lite::RET_ERROR; 7425+using mindspore::lite::RET_MEMORY_FAILED; 7426+using mindspore::lite::RET_OK; 7427+ 7428+namespace mindspore::kernel { 7429+namespace { 7430+constexpr size_t kMindirInputTensorNum = 4; 7431+constexpr int kGateNum = 4; 7432+constexpr int kOutIntermediateStatesIndex = 3; 7433+constexpr int kInputGateIndex = 0; 7434+} // namespace 7435+ 7436+int LstmSequenceLoopRun(void *cdata, int task_id, float, float) { 7437+ auto kernel = reinterpret_cast<LstmFp32BaseCPUKernel *>(cdata); 7438+ CHECK_NULL_RETURN(kernel); 7439+ auto ret = kernel->DoSequenceLoop(task_id); 7440+ if (ret != RET_OK) { 7441+ MS_LOG(ERROR) << "LSTM: Do Sequence-loop failed."; 7442+ } 7443+ return ret; 7444+} 7445+ 7446+int LstmFp32BaseCPUKernel::Prepare() { 7447+ MS_CHECK_TRUE_MSG(in_tensors_.size() == kMindirInputTensorNum || in_tensors_.size() >= C6NUM, 7448+ lite::RET_INPUT_TENSOR_ERROR, "Lstm's input-num is invalid."); 7449+ for (size_t i = 0; i < in_tensors_.size(); i++) { 7450+ CHECK_NULL_RETURN(in_tensors_.at(i)); 7451+ } 7452+ CHECK_LESS_RETURN(out_tensors_.size(), DIMENSION_3D); 7453+ for (size_t i = 0; i < out_tensors_.size(); i++) { 7454+ CHECK_NULL_RETURN(out_tensors_.at(i)); 7455+ } 7456+ CHECK_NULL_RETURN(lstm_param_); 7457+ if (!InferShapeDone()) { 7458+ return RET_OK; 7459+ } 7460+ return ReSize(); 7461+} 7462+ 7463+int LstmFp32BaseCPUKernel::ReSize() { 7464+ auto input = in_tensors_.front(); 7465+ std::vector<int> in_shape = input->shape(); 7466+ MS_CHECK_TRUE_MSG(in_shape.size() == C3NUM, lite::RET_INPUT_TENSOR_ERROR, 7467+ "The dims of LSTM's first input must be 3."); 7468+ lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT); 7469+ lstm_param_->batch_ = in_shape.at(SECOND_INPUT); 7470+ lstm_param_->input_size_ = 
in_shape.at(THIRD_INPUT); 7471+ 7472+ auto h_init_shape = in_tensors_.at(hidden_init_index_)->shape(); 7473+ auto c_init_shape = in_tensors_.at(cell_init_index_)->shape(); 7474+ lstm_param_->hidden_size_ = c_init_shape.back(); 7475+ lstm_param_->output_size_ = h_init_shape.back(); 7476+ 7477+ lstm_param_->output_step_ = lstm_param_->bidirectional_ ? C2NUM * lstm_param_->batch_ * lstm_param_->output_size_ 7478+ : lstm_param_->batch_ * lstm_param_->output_size_; 7479+ weight_segment_num_ = lstm_param_->bidirectional_ ? C2NUM * kGateNum : kGateNum; 7480+ 7481+#ifdef ENABLE_AVX 7482+ row_tile_ = C6NUM; 7483+ col_tile_ = C16NUM; 7484+#elif defined(ENABLE_ARM32) 7485+ row_tile_ = C12NUM; 7486+ col_tile_ = C4NUM; 7487+#elif defined(ENABLE_SSE) 7488+ row_tile_ = C4NUM; 7489+ col_tile_ = C8NUM; 7490+#else 7491+ row_tile_ = C12NUM; 7492+ col_tile_ = C8NUM; 7493+#endif 7494+ lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, row_tile_); 7495+ lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, col_tile_); 7496+ 7497+ state_row_tile_ = row_tile_; 7498+ state_col_tile_ = col_tile_; 7499+#ifdef ENABLE_AVX 7500+ if (lstm_param_->batch_ == 1) { 7501+ state_row_tile_ = 1; 7502+ state_col_tile_ = C8NUM; 7503+ } 7504+#endif 7505+ 7506+ lstm_param_->state_row_align_ = lstm_param_->batch_ == 1 ? 1 : UP_ROUND(lstm_param_->batch_, state_row_tile_); 7507+#ifdef ENABLE_AVX 7508+ lstm_param_->state_col_align_ = UP_ROUND(lstm_param_->hidden_size_, state_col_tile_); 7509+ lstm_param_->proj_col_align_ = UP_ROUND(lstm_param_->output_size_, state_col_tile_); 7510+#else 7511+ lstm_param_->state_col_align_ = 7512+ lstm_param_->batch_ == 1 ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, state_col_tile_); 7513+ lstm_param_->proj_col_align_ = 7514+ lstm_param_->batch_ == 1 ? 
lstm_param_->output_size_ : UP_ROUND(lstm_param_->output_size_, state_col_tile_); 7515+#endif 7516+ return RET_OK; 7517+} 7518+ 7519+int LstmFp32BaseCPUKernel::Run() { 7520+ auto input = in_tensors_.at(FIRST_INPUT); 7521+ auto output = out_tensors_.at(FIRST_INPUT); 7522+ auto input_ptr = reinterpret_cast<float *>(input->data()); 7523+ CHECK_NULL_RETURN(input_ptr); 7524+ auto output_ptr = reinterpret_cast<float *>(output->data()); 7525+ CHECK_NULL_RETURN(output_ptr); 7526+ 7527+ auto hidden_state = in_tensors_.at(hidden_init_index_); 7528+ CHECK_NULL_RETURN(hidden_state->data()); 7529+ auto cell_state = in_tensors_.at(cell_init_index_); 7530+ CHECK_NULL_RETURN(cell_state->data()); 7531+ 7532+ auto output_hidden_state = out_tensors_[SECOND_INPUT]; 7533+ CHECK_NULL_RETURN(output_hidden_state->data()); 7534+ (void)memcpy(output_hidden_state->data(), hidden_state->data(), hidden_state->ElementsNum() * sizeof(float)); 7535+ auto output_cell_state = out_tensors_[THIRD_INPUT]; 7536+ CHECK_NULL_RETURN(output_cell_state->data()); 7537+ (void)memcpy(output_cell_state->data(), cell_state->data(), cell_state->ElementsNum() * sizeof(float)); 7538+ 7539+ auto ret = InitInputWeightBias(); 7540+ if (ret != RET_OK) { 7541+ MS_LOG(ERROR) << "LstmCPUKernel InitInputWeightBias error."; 7542+ FreeRunBuffer(); 7543+ return RET_ERROR; 7544+ } 7545+ 7546+ ret = InitStateWeightBias(); 7547+ if (ret != RET_OK) { 7548+ MS_LOG(ERROR) << "LstmCPUKernel InitStateWeightBias error."; 7549+ FreeRunBuffer(); 7550+ return RET_ERROR; 7551+ } 7552+ 7553+ ret = InitProjectWeight(); 7554+ if (ret != RET_OK) { 7555+ MS_LOG(ERROR) << "LstmCPUKernel InitProjectWeight error."; 7556+ FreeRunBuffer(); 7557+ return RET_ERROR; 7558+ } 7559+ bool is_bidirectional_with_multi_thread = thread_num_ != 1 && lstm_param_->bidirectional_; 7560+ ret = MallocRunBuffer(is_bidirectional_with_multi_thread); 7561+ if (ret != RET_OK) { 7562+ MS_LOG(ERROR) << "LstmCPUKernel MallocRunBuffer Error."; 7563+ FreeRunBuffer(); 7564+ 
return RET_ERROR; 7565+ } 7566+ 7567+ PackLstmInput(input_ptr, packed_input_, lstm_param_->seq_len_ * lstm_param_->batch_, lstm_param_->input_size_); 7568+ if (IsTrain() && IsTrainable()) { 7569+ intermediate_states_ = reinterpret_cast<float *>(out_tensors_[kOutIntermediateStatesIndex]->data()); 7570+ } 7571+ CHECK_NULL_RETURN(weight_h_ptr_); 7572+ CHECK_NULL_RETURN(weight_i_ptr_); 7573+ CHECK_NULL_RETURN(input_bias_); 7574+ CHECK_NULL_RETURN(state_bias_); 7575+ if (is_bidirectional_with_multi_thread) { 7576+ ret = ExecuteBidirectionalWithMultiThread(); 7577+ } else { 7578+ ret = ExecuteUnidirectionalOrSingleThread(); 7579+ } 7580+ FreeRunBuffer(); 7581+ return ret; 7582+} 7583+ 7584+void LstmFp32BaseCPUKernel::FreeRunBuffer() { 7585+ for (auto data : running_buffer_) { 7586+ ms_context_->allocator->Free(data); 7587+ } 7588+ running_buffer_.clear(); 7589+} 7590+ 7591+int LstmFp32BaseCPUKernel::MallocRunBuffer(bool is_double) { 7592+ bool need_zone = lstm_param_->zoneout_cell_ < -FLT_EPSILON || lstm_param_->zoneout_cell_ > FLT_EPSILON; 7593+ size_t whole_size = 0; 7594+ std::vector<size_t> segments; 7595+ int scale = is_double ? C2NUM : 1; 7596+ size_t segment = kGateNum * lstm_param_->seq_len_ * lstm_param_->batch_ * 7597+ lstm_param_->hidden_size_; // 0: input * weight for result matrix 7598+ segments.push_back(segment); 7599+ whole_size += segment * scale; 7600+ 7601+ segment = lstm_param_->batch_ == 1 7602+ ? 0 7603+ : lstm_param_->state_row_align_ * lstm_param_->output_size_; // 1: state * weight for left matirx 7604+ segments.push_back(segment); 7605+ whole_size += segment * scale; 7606+ 7607+ segment = kGateNum * lstm_param_->batch_ * lstm_param_->hidden_size_; // 2: state gate buffer 7608+ segments.push_back(segment); 7609+ whole_size += segment * scale; 7610+ 7611+ segment = need_zone ? 
lstm_param_->batch_ * lstm_param_->hidden_size_ : 0; // 3: state_buffer for cell 7612+ segments.push_back(segment); 7613+ whole_size += segment * scale; 7614+ 7615+ segment = need_zone ? lstm_param_->batch_ * lstm_param_->output_size_ : 0; // 4: state_buffer for hidden 7616+ segments.push_back(segment); 7617+ whole_size += segment * scale; 7618+ 7619+ segment = 0; 7620+#ifdef ENABLE_AVX 7621+ bool output_need_packed = lstm_param_->hidden_size_ % state_col_tile_; 7622+ if (lstm_param_->batch_ == 1 && output_need_packed) { // vec matmul need to malloc dst 7623+ int out_channel = lstm_param_->hidden_size_; 7624+ int oc_block_num = UP_DIV(out_channel, state_col_tile_); 7625+ MS_ASSERT(ms_context_->allocator != nullptr); 7626+ segment = lstm_param_->batch_ * oc_block_num * state_col_tile_; // 5: tmp output data 7627+ } 7628+#endif 7629+ segments.push_back(segment); 7630+ whole_size += segment * scale; 7631+ 7632+ if (in_tensors_.size() == C7NUM || lstm_param_->project_size_ != 0) { 7633+ segment = lstm_param_->batch_ == 1 ? 0 : lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * scale; 7634+ segments.push_back(segment); // 6: project-layer input 7635+ whole_size += segment; 7636+ segment = 0; 7637+#ifdef ENABLE_AVX 7638+ segment = 7639+ output_need_packed ? 
lstm_param_->batch_ * UP_ROUND(lstm_param_->output_size_, state_col_tile_) * scale : 0; 7640+#endif 7641+ segments.push_back(segment); // 7: project-layer output 7642+ whole_size += segment; 7643+ } else { 7644+ (void)segments.insert(segments.end(), C2NUM, 0); 7645+ } 7646+ 7647+ segment = 0; 7648+ if (in_tensors_.size() == kMindirInputTensorNum) { 7649+ segment = lstm_param_->batch_ * lstm_param_->output_size_; 7650+ } 7651+ segments.push_back(segment); 7652+ whole_size += segment * scale; 7653+ 7654+ segment = 7655+ lstm_param_->input_row_align_ * lstm_param_->input_size_; // input * weight for left matrix, which only once 7656+ whole_size += segment; 7657+ 7658+ auto whole_memory = reinterpret_cast<float *>(ms_context_->allocator->Malloc(whole_size * sizeof(float))); 7659+ MS_CHECK_TRUE_MSG(whole_memory != nullptr, RET_ERROR, "LSTM: malloc failed."); 7660+ running_buffer_.push_back(whole_memory); 7661+ MS_ASSERT(segments.size() == C9NUM); 7662+ auto Allocate = [&whole_memory, &segments](float **buffer) mutable { 7663+ for (int i = 0; i < C9NUM; ++i) { 7664+ buffer[i] = nullptr; 7665+ if (segments[i] == 0) { 7666+ continue; 7667+ } 7668+ buffer[i] = whole_memory; 7669+ whole_memory += segments[i]; 7670+ } 7671+ }; 7672+ Allocate(buffer_forward_); 7673+ if (is_double) { 7674+ Allocate(buffer_backward_); 7675+ } 7676+ packed_input_ = whole_memory; 7677+ return RET_OK; 7678+} 7679+ 7680+int LstmFp32BaseCPUKernel::ExecuteBidirectionalWithMultiThread() { 7681+ auto ret = LstmPreProcessWithInput(weight_i_ptr_, input_bias_, buffer_forward_[kInputGateIndex]); 7682+ if (ret != RET_OK) { 7683+ MS_LOG(ERROR) << "LSTM Forward: Input-MatMul running failed."; 7684+ return RET_ERROR; 7685+ } 7686+ const float *backward_weight_i = weight_i_ptr_ + kGateNum * lstm_param_->input_col_align_ * lstm_param_->input_size_; 7687+ const float *backward_input_bias = input_bias_ + kGateNum * lstm_param_->input_col_align_; 7688+ ret = LstmPreProcessWithInput(backward_weight_i, 
backward_input_bias, buffer_backward_[kInputGateIndex]); 7689+ if (ret != RET_OK) { 7690+ MS_LOG(ERROR) << "LSTM Backward: Input-MatMul running failed."; 7691+ return RET_ERROR; 7692+ } 7693+ ret = ParallelLaunch(this->ms_context_, LstmSequenceLoopRun, this, C2NUM); 7694+ if (ret != RET_OK) { 7695+ MS_LOG(ERROR) << "LSTM: Do sequence-loop failed."; 7696+ } 7697+ return ret; 7698+} 7699+ 7700+int LstmFp32BaseCPUKernel::ExecuteUnidirectionalOrSingleThread() { 7701+ auto ret = LstmPreProcessWithInput(weight_i_ptr_, input_bias_, buffer_forward_[kInputGateIndex]); 7702+ if (ret != RET_OK) { 7703+ MS_LOG(ERROR) << "LSTM Forward: Input-MatMul running failed."; 7704+ return RET_ERROR; 7705+ } 7706+ LstmForwardLoop(buffer_forward_); 7707+ 7708+ // backward 7709+ if (lstm_param_->bidirectional_) { 7710+ const float *backward_weight_i = 7711+ weight_i_ptr_ + kGateNum * lstm_param_->input_col_align_ * lstm_param_->input_size_; 7712+ const float *backward_input_bias = input_bias_ + kGateNum * lstm_param_->input_col_align_; 7713+ ret = LstmPreProcessWithInput(backward_weight_i, backward_input_bias, buffer_forward_[kInputGateIndex]); 7714+ if (ret != RET_OK) { 7715+ MS_LOG(ERROR) << "LSTM Backward: Input-MatMul running failed."; 7716+ return RET_ERROR; 7717+ } 7718+ LstmBackwardLoop(buffer_forward_); 7719+ } 7720+ return RET_OK; 7721+} 7722+ 7723+int LstmFp32BaseCPUKernel::LstmPreProcessWithInput(const float *weight_i, const float *input_bias, float *dst) { 7724+ const float *weight{nullptr}; 7725+ const float *bias{nullptr}; 7726+ float *gate{nullptr}; 7727+ int thread_num = MSMIN(op_parameter_->thread_num_, UP_DIV(lstm_param_->input_col_align_, col_tile_)); 7728+ MS_CHECK_FALSE(thread_num == 0, RET_ERROR); 7729+ int stride = UP_DIV(UP_DIV(lstm_param_->input_col_align_, col_tile_), thread_num); 7730+ auto MatMulCoreFunc = [this, &weight, &bias, &gate, &stride](void *, int task_id, float, float) { 7731+ int current_start_oc = task_id * stride * col_tile_; 7732+ int 
current_rest_oc = 0; 7733+ current_rest_oc = lstm_param_->hidden_size_ - current_start_oc; 7734+ int cur_oc = MSMIN(stride * col_tile_, current_rest_oc); 7735+ if (cur_oc <= 0) { 7736+ return RET_OK; 7737+ } 7738+ 7739+ auto b = weight + current_start_oc * lstm_param_->input_size_; 7740+ auto c = gate + current_start_oc; 7741+ auto bias_ = (bias == nullptr) ? nullptr : bias + current_start_oc; 7742+ MatMulOpt(packed_input_, b, c, bias_, ActType_No, lstm_param_->input_size_, 7743+ lstm_param_->seq_len_ * lstm_param_->batch_, cur_oc, lstm_param_->hidden_size_, OutType_Nhwc); 7744+ return RET_OK; 7745+ }; 7746+ for (int i = 0; i < kGateNum; i++) { 7747+ weight = weight_i + lstm_param_->input_size_ * lstm_param_->input_col_align_ * i; 7748+ bias = input_bias + lstm_param_->input_col_align_ * i; 7749+ gate = dst + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * i; 7750+ auto ret = ParallelLaunch(this->ms_context_, MatMulCoreFunc, nullptr, thread_num); 7751+ if (ret != RET_OK) { 7752+ return RET_ERROR; 7753+ } 7754+ } 7755+ return RET_OK; 7756+} 7757+ 7758+int LstmFp32BaseCPUKernel::DoSequenceLoop(int task_id) { 7759+ if (task_id == 0) { 7760+ LstmForwardLoop(buffer_forward_); 7761+ return RET_OK; 7762+ } 7763+ if (task_id == 1) { 7764+ LstmBackwardLoop(buffer_backward_); 7765+ return RET_OK; 7766+ } 7767+ return RET_ERROR; 7768+} 7769+ 7770+void LstmFp32BaseCPUKernel::LstmForwardLoop(float *buffer[]) { 7771+ auto *output = reinterpret_cast<float *>(out_tensors_.at(FIRST_INPUT)->data()); 7772+ auto *hidden_state = reinterpret_cast<float *>(out_tensors_.at(SECOND_INPUT)->data()); 7773+ auto *cell_state = reinterpret_cast<float *>(out_tensors_.at(THIRD_INPUT)->data()); 7774+ LstmUnidirectional(output, weight_h_ptr_, state_bias_, hidden_state, cell_state, weight_project_ptr_, 7775+ intermediate_states_, buffer, false); 7776+} 7777+ 7778+void LstmFp32BaseCPUKernel::LstmBackwardLoop(float *buffer[]) { 7779+ auto *output = reinterpret_cast<float 
*>(out_tensors_.at(0)->data()); 7780+ auto *hidden_state = reinterpret_cast<float *>(out_tensors_.at(1)->data()); 7781+ auto *cell_state = reinterpret_cast<float *>(out_tensors_.at(C2NUM)->data()); 7782+ const float *backward_weight_h = weight_h_ptr_ + kGateNum * lstm_param_->state_col_align_ * lstm_param_->output_size_; 7783+ const float *backward_state_bias = state_bias_ + kGateNum * lstm_param_->state_col_align_; 7784+ float *backward_output = output + lstm_param_->batch_ * lstm_param_->output_size_; 7785+ if (in_tensors_.size() == kMindirInputTensorNum) { 7786+ backward_output = output + lstm_param_->output_size_; 7787+ } 7788+ float *backward_cell_state = cell_state + lstm_param_->batch_ * lstm_param_->hidden_size_; 7789+ float *backward_hidden_state = hidden_state + lstm_param_->batch_ * lstm_param_->output_size_; 7790+ float *intermediate_states = nullptr; 7791+ if (intermediate_states_) { 7792+ intermediate_states = intermediate_states_ + lstm_param_->batch_ * lstm_param_->output_size_; 7793+ } 7794+ float *backward_weight_project = 7795+ weight_project_ptr_ ? weight_project_ptr_ + lstm_param_->hidden_size_ * lstm_param_->proj_col_align_ : nullptr; 7796+ LstmUnidirectional(backward_output, backward_weight_h, backward_state_bias, backward_hidden_state, 7797+ backward_cell_state, backward_weight_project, intermediate_states, buffer, true); 7798+} 7799+} // namespace mindspore::kernel 7800diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h 7801new file mode 100644 7802index 00000000..c3c10cea 7803--- /dev/null 7804+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h 7805@@ -0,0 +1,78 @@ 7806+/** 7807+ * Copyright 2023 Huawei Technologies Co., Ltd 7808+ * 7809+ * Licensed under the Apache License, Version 2.0 (the "License"); 7810+ * you may not use this file except in compliance with the License. 
7811+ * You may obtain a copy of the License at 7812+ * 7813+ * http://www.apache.org/licenses/LICENSE-2.0 7814+ * 7815+ * Unless required by applicable law or agreed to in writing, software 7816+ * distributed under the License is distributed on an "AS IS" BASIS, 7817+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 7818+ * See the License for the specific language governing permissions and 7819+ * limitations under the License. 7820+ */ 7821+ 7822+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_FP32_BASE_H_ 7823+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_FP32_BASE_H_ 7824+ 7825+#include <vector> 7826+#include "src/litert/lite_kernel.h" 7827+#include "nnacl/fp32/lstm_fp32.h" 7828+ 7829+namespace mindspore::kernel { 7830+class LstmFp32BaseCPUKernel : public LiteKernel { 7831+ public: 7832+ LstmFp32BaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 7833+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 7834+ : LiteKernel(parameter, inputs, outputs, ctx) { 7835+ lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_); 7836+ } 7837+ 7838+ ~LstmFp32BaseCPUKernel() override = default; 7839+ 7840+ int Prepare() override; 7841+ int ReSize() override; 7842+ int Run() override; 7843+ int DoSequenceLoop(int task_id); 7844+ 7845+ protected: 7846+ virtual int InitInputWeightBias() = 0; 7847+ virtual int InitStateWeightBias() = 0; 7848+ virtual int InitProjectWeight() = 0; 7849+ virtual void LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, float *hidden_state, 7850+ float *cell_state, const float *weight_project, float *intermediate_states, 7851+ float *buffer[], bool is_backward) = 0; 7852+ 7853+ int hidden_init_index_{0}; 7854+ int cell_init_index_{0}; 7855+ int row_tile_{0}; 7856+ int col_tile_{0}; 7857+ int state_row_tile_{0}; 7858+ int state_col_tile_{0}; 7859+ int weight_segment_num_{0}; 7860+ float *weight_i_ptr_{nullptr}; 7861+ 
float *weight_h_ptr_{nullptr}; 7862+ float *weight_project_ptr_{nullptr}; 7863+ float *input_bias_{nullptr}; 7864+ float *state_bias_{nullptr}; 7865+ LstmParameter *lstm_param_{nullptr}; 7866+ std::vector<void *> running_buffer_; 7867+ 7868+ private: 7869+ void FreeRunBuffer(); 7870+ int MallocRunBuffer(bool is_double); 7871+ int ExecuteBidirectionalWithMultiThread(); 7872+ int ExecuteUnidirectionalOrSingleThread(); 7873+ int LstmPreProcessWithInput(const float *weight_i, const float *input_bias, float *dst); 7874+ void LstmForwardLoop(float *buffer[]); 7875+ void LstmBackwardLoop(float *buffer[]); 7876+ float *packed_input_{nullptr}; 7877+ float *intermediate_states_{nullptr}; 7878+ float *buffer_forward_[C9NUM] = {nullptr}; 7879+ float *buffer_backward_[C9NUM] = {nullptr}; 7880+}; 7881+} // namespace mindspore::kernel 7882+ 7883+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_FP32_BASE_H_ 7884diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc 7885new file mode 100644 7886index 00000000..476d5940 7887--- /dev/null 7888+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc 7889@@ -0,0 +1,266 @@ 7890+/** 7891+ * Copyright 2023 Huawei Technologies Co., Ltd 7892+ * 7893+ * Licensed under the Apache License, Version 2.0 (the "License"); 7894+ * you may not use this file except in compliance with the License. 7895+ * You may obtain a copy of the License at 7896+ * 7897+ * http://www.apache.org/licenses/LICENSE-2.0 7898+ * 7899+ * Unless required by applicable law or agreed to in writing, software 7900+ * distributed under the License is distributed on an "AS IS" BASIS, 7901+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 7902+ * See the License for the specific language governing permissions and 7903+ * limitations under the License. 
7904+ */ 7905+ 7906+#include "src/litert/kernel/cpu/fp32/lstm_mindir_fp32.h" 7907+#include "nnacl/fp32/pack_fp32.h" 7908+ 7909+namespace mindspore::kernel { 7910+namespace { 7911+constexpr int kInputGateIndex = 0; 7912+constexpr int kTempHiddenOutputIndex = 8; 7913+constexpr int kGateNum = 4; 7914+constexpr int kWeightsIndex = 3; 7915+const int kWeightsOrderMap[8] = {0, 2, 3, 1, 4, 6, 7, 5}; // IFGO order to IOFG order 7916+} // namespace 7917+ 7918+int LstmMindirFp32CPUKernel::ReSize() { 7919+ auto ret = LstmFp32BaseCPUKernel::ReSize(); 7920+ if (ret != lite::RET_OK) { 7921+ MS_LOG(ERROR) << "LstmMindirFp32CPUKernel resize failed."; 7922+ return ret; 7923+ } 7924+ // determine FB origin 7925+ gpu_orig_state_ = false; 7926+ auto weight_t = in_tensors_.at(kWeightsIndex); 7927+ MS_CHECK_INT_MUL_NOT_OVERFLOW(lstm_param_->hidden_size_, lstm_param_->input_size_, lite::RET_ERROR); 7928+ int hi_unit_size = lstm_param_->hidden_size_ * lstm_param_->input_size_; 7929+ MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_segment_num_, hi_unit_size, lite::RET_ERROR); 7930+ int hi_whole_size = weight_segment_num_ * hi_unit_size; 7931+ MS_CHECK_INT_MUL_NOT_OVERFLOW(lstm_param_->hidden_size_, lstm_param_->output_size_, lite::RET_ERROR); 7932+ int hh_unit_size = lstm_param_->hidden_size_ * lstm_param_->output_size_; 7933+ MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_segment_num_, hh_unit_size, lite::RET_ERROR); 7934+ int hh_whole_size = weight_segment_num_ * hh_unit_size; 7935+ int scale = lstm_param_->bidirectional_ ? 
C2NUM : C1NUM; 7936+ MS_CHECK_INT_MUL_NOT_OVERFLOW(lstm_param_->hidden_size_, lstm_param_->project_size_, lite::RET_ERROR); 7937+ int hp_unit_size = lstm_param_->hidden_size_ * lstm_param_->project_size_; 7938+ MS_CHECK_INT_MUL_NOT_OVERFLOW(scale, hp_unit_size, lite::RET_ERROR); 7939+ int hp_whole_size = scale * hp_unit_size; 7940+ MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_segment_num_ * C2NUM, lstm_param_->hidden_size_, lite::RET_ERROR); 7941+ int bias_whole_size = weight_segment_num_ * C2NUM * lstm_param_->hidden_size_; 7942+ auto whole_size = weight_t->ElementsNum(); 7943+ bool has_bias = (hi_whole_size + hh_whole_size + hp_whole_size < whole_size) ? true : false; 7944+ // if bias exist we can determine the gpu_orig_state_ 7945+ if (has_bias) { 7946+ gpu_orig_state_ = (hi_whole_size + hh_whole_size + hp_whole_size + bias_whole_size == whole_size) ? true : false; 7947+ } else { 7948+ bias_whole_size = 0; 7949+ } 7950+ if (gpu_orig_state_) { 7951+ return lite::RET_OK; 7952+ } 7953+ bias_whole_size /= C2NUM; 7954+ if (hi_whole_size + hh_whole_size + hp_whole_size + bias_whole_size != whole_size) { 7955+ MS_LOG(ERROR) << "LstmMindir is invalid when original model exports from CPU."; 7956+ return lite::RET_INPUT_TENSOR_ERROR; 7957+ } 7958+ return lite::RET_OK; 7959+} 7960+ 7961+int LstmMindirFp32CPUKernel::InitInputWeightBias() { 7962+ // malloc and init input * weight right matrix buffer 7963+ // input -- row: seq_len * batch; col: input_size 7964+ // weight -- row: hidden_size; col: input_size, need transpose 7965+ // result -- row: seq_len * batch; col: hidden_size 7966+ weight_i_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc( 7967+ weight_segment_num_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float))); 7968+ MS_CHECK_TRUE_MSG(weight_i_ptr_ != nullptr, lite::RET_NULL_PTR, "LstmMindirCPUKernel malloc weight_i_ptr_ failed."); 7969+ running_buffer_.push_back(weight_i_ptr_); 7970+ auto weight_data = reinterpret_cast<float 
*>(in_tensors_.at(kWeightsIndex)->data()); 7971+ CHECK_NULL_RETURN(weight_data); 7972+ 7973+ int hi_unit_size = lstm_param_->input_size_ * lstm_param_->hidden_size_; 7974+ int hh_unit_size = lstm_param_->hidden_size_ * lstm_param_->output_size_; 7975+ int stride = (gpu_orig_state_) ? kGateNum * (hi_unit_size + hh_unit_size) : kGateNum * hi_unit_size; 7976+ PackLstmWeightWithStride(weight_i_ptr_, weight_data, weight_segment_num_, lstm_param_->input_size_, 7977+ lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_, 7978+ stride, kWeightsOrderMap); 7979+ // input bias 7980+ auto bias_size = weight_segment_num_ * lstm_param_->input_col_align_ * sizeof(float); 7981+ input_bias_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(bias_size)); 7982+ MS_CHECK_TRUE_MSG(input_bias_ != nullptr, lite::RET_NULL_PTR, "LstmMindirCPUKernel malloc input_bias_ failed."); 7983+ memset(input_bias_, 0, bias_size); 7984+ running_buffer_.push_back(input_bias_); 7985+ if (!lstm_param_->has_bias_) { 7986+ return RET_OK; 7987+ } 7988+ int scale = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 7989+ int offset = weight_segment_num_ * (hi_unit_size + hh_unit_size) + 7990+ scale * lstm_param_->project_size_ * lstm_param_->hidden_size_; 7991+ float *bias_data = weight_data + offset; 7992+ int b_stride = 7993+ (gpu_orig_state_) ? kGateNum * (scale * lstm_param_->hidden_size_) : kGateNum * (lstm_param_->hidden_size_); 7994+ PackLstmBiasWithStride(input_bias_, bias_data, weight_segment_num_, lstm_param_->hidden_size_, 7995+ lstm_param_->input_col_align_, lstm_param_->bidirectional_, b_stride, kWeightsOrderMap); 7996+ return RET_OK; 7997+} 7998+ 7999+int LstmMindirFp32CPUKernel::InitStateWeightBias() { 8000+ // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. 
8001+ // state -- row: batch; col: hidden_size 8002+ // weight -- row: hidden_size; col: hidden_size, need transpose 8003+ // result -- row: batch; col: hidden_size 8004+ auto weight_data = (reinterpret_cast<float *>(in_tensors_.at(kWeightsIndex)->data())); 8005+ CHECK_NULL_RETURN(weight_data); 8006+ 8007+ int hi_unit_size = lstm_param_->input_size_ * lstm_param_->hidden_size_; 8008+ int hh_unit_size = lstm_param_->hidden_size_ * lstm_param_->output_size_; 8009+ int stride = (gpu_orig_state_) ? kGateNum * (hi_unit_size + hh_unit_size) : kGateNum * hh_unit_size; 8010+ 8011+ auto weight_h_data = weight_data + (gpu_orig_state_ ? kGateNum * hi_unit_size : weight_segment_num_ * hi_unit_size); 8012+ 8013+ auto weight_unit_pack_size = sizeof(float) * lstm_param_->state_col_align_ * lstm_param_->output_size_; 8014+ auto weight_pack_size = weight_segment_num_ * weight_unit_pack_size; 8015+ weight_h_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(weight_pack_size)); 8016+ MS_CHECK_TRUE_MSG(weight_h_ptr_ != nullptr, lite::RET_NULL_PTR, "LstmMindirCPUKernel malloc weight_h_ptr_ failed."); 8017+ running_buffer_.push_back(weight_h_ptr_); 8018+ if (lstm_param_->batch_ != 1) { 8019+ PackLstmWeightWithStride(weight_h_ptr_, weight_h_data, weight_segment_num_, lstm_param_->output_size_, 8020+ lstm_param_->hidden_size_, lstm_param_->state_col_align_, lstm_param_->bidirectional_, 8021+ stride, kWeightsOrderMap); 8022+ } else { 8023+ for (int i = 0; i < weight_segment_num_; i++) { 8024+ const float *src_batch = weight_h_data + i * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8025+ float *dst_batch = 8026+ weight_h_ptr_ + kWeightsOrderMap[i] * lstm_param_->state_col_align_ * lstm_param_->output_size_; 8027+#ifdef ENABLE_AVX 8028+ RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->hidden_size_, lstm_param_->output_size_); 8029+#else 8030+ (void)memcpy(dst_batch, src_batch, weight_unit_pack_size); 8031+#endif 8032+ } 8033+ } 8034+ 8035+ // state bias 8036+ 
auto bias_pack_size = weight_segment_num_ * lstm_param_->state_col_align_ * sizeof(float); 8037+ state_bias_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(bias_pack_size)); 8038+ MS_CHECK_TRUE_MSG(state_bias_ != nullptr, lite::RET_NULL_PTR, "LstmMindirCPUKernel malloc state_bias_ failed."); 8039+ memset(state_bias_, 0, bias_pack_size); 8040+ running_buffer_.push_back(state_bias_); 8041+ if (!lstm_param_->has_bias_ || !gpu_orig_state_) { 8042+ return RET_OK; 8043+ } 8044+ 8045+ int hi_whole_size = weight_segment_num_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; 8046+ int hh_whole_size = weight_segment_num_ * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8047+ int proj_size = 8048+ (lstm_param_->bidirectional_ ? C2NUM : C1NUM) * lstm_param_->project_size_ * lstm_param_->hidden_size_; 8049+ // mindir from device "GPU", the second bias is also present, in IFOG order 8050+ int bias_offset = hi_whole_size + hh_whole_size + proj_size + lstm_param_->hidden_size_ * kGateNum; 8051+ float *state_bias = weight_data + bias_offset; 8052+ int b_stride = kGateNum * lstm_param_->hidden_size_ * C2NUM; 8053+ PackLstmBiasWithStride(state_bias_, state_bias, weight_segment_num_, lstm_param_->hidden_size_, 8054+ lstm_param_->state_col_align_, lstm_param_->bidirectional_, b_stride, kWeightsOrderMap); 8055+ return RET_OK; 8056+} 8057+ 8058+int LstmMindirFp32CPUKernel::InitProjectWeight() { 8059+ if (lstm_param_->project_size_ == 0) { 8060+ return RET_OK; 8061+ } 8062+ auto weight_data = (reinterpret_cast<float *>(in_tensors_.at(kWeightsIndex)->data())); 8063+ CHECK_NULL_RETURN(weight_data); 8064+ int hi_whole_size = weight_segment_num_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; 8065+ int hh_whole_size = weight_segment_num_ * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8066+ auto weight_proj_data = weight_data + hi_whole_size + hh_whole_size; 8067+ int batch = lstm_param_->bidirectional_ ?
C2NUM : C1NUM; 8068+ auto pack_size = batch * lstm_param_->hidden_size_ * lstm_param_->proj_col_align_ * sizeof(float); 8069+ if (lstm_param_->batch_ != 1) { 8070+ weight_project_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(pack_size)); 8071+ MS_CHECK_TRUE_MSG(weight_project_ptr_ != nullptr, lite::RET_NULL_PTR, 8072+ "LstmNonMindirCPUKernel malloc weight_project_ptr_ failed."); 8073+ running_buffer_.push_back(weight_project_ptr_); 8074+ PackLstmWeightWithStride(weight_project_ptr_, weight_proj_data, batch, lstm_param_->hidden_size_, 8075+ lstm_param_->output_size_, lstm_param_->proj_col_align_, lstm_param_->bidirectional_, 8076+ lstm_param_->hidden_size_ * lstm_param_->output_size_, nullptr); 8077+ } else { 8078+#ifdef ENABLE_AVX 8079+ weight_project_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(pack_size)); 8080+ MS_CHECK_TRUE_MSG(weight_project_ptr_ != nullptr, lite::RET_NULL_PTR, 8081+ "LstmNonMindirCPUKernel malloc weight_project_ptr_ failed."); 8082+ running_buffer_.push_back(weight_project_ptr_); 8083+ for (int i = 0; i < batch; ++i) { 8084+ const float *src_batch = weight_proj_data + i * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8085+ float *dst_batch = weight_project_ptr_ + i * lstm_param_->hidden_size_ * lstm_param_->proj_col_align_; 8086+ RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->output_size_, lstm_param_->hidden_size_); 8087+ } 8088+#else 8089+ weight_project_ptr_ = weight_proj_data; 8090+#endif 8091+ } 8092+ return RET_OK; 8093+} 8094+ 8095+void LstmMindirFp32CPUKernel::LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, 8096+ float *hidden_state, float *cell_state, const float *weight_project, 8097+ float *intermediate_states, float **buffer, bool is_backward) { 8098+ float *gate = buffer[kInputGateIndex]; 8099+ float *input_gate = gate; 8100+ float *forget_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C2NUM; 8101+ 
float *cell_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C3NUM; 8102+ float *output_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_; 8103+ float *tmp = buffer[kTempHiddenOutputIndex]; 8104+ int dir_mult = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 8105+ for (int t = 0; t < lstm_param_->seq_len_; t++) { 8106+ int real_t = is_backward ? lstm_param_->seq_len_ - t - C1NUM : t; 8107+ float *input_gate_t = input_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8108+ float *forget_gate_t = forget_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8109+ float *cell_gate_t = cell_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8110+ float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8111+ // Sequence, Batch, DirMul, Hidden 8112+ LstmStepUnit(tmp, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, weight_project, 8113+ hidden_state, cell_state, buffer, lstm_param_); 8114+ int seq_offset = real_t * lstm_param_->batch_ * dir_mult * lstm_param_->output_size_; 8115+ for (int b = 0; b < lstm_param_->batch_; b++) { 8116+ int batch_offset = b * dir_mult * lstm_param_->output_size_; 8117+ float *output_ptr = output + seq_offset + batch_offset; 8118+ memcpy(output_ptr, tmp + b * lstm_param_->output_size_, lstm_param_->output_size_ * sizeof(float)); 8119+ } 8120+ if (intermediate_states) { 8121+ RecordStates(hidden_state, cell_state, input_gate_t, output_gate_t, forget_gate_t, cell_gate_t, 8122+ intermediate_states, real_t); 8123+ } 8124+ } 8125+} 8126+ 8127+void LstmMindirFp32CPUKernel::RecordStates(const float *hidden_state, float *cell_state, float *input_gate, 8128+ const float *output_gate, float *forget_gate, const float *cell_gate, 8129+ float *intermediate_states, int step) { 8130+ float *states = intermediate_states; 8131+ auto hidden_size = lstm_param_->batch_ * 
lstm_param_->output_size_; 8132+ auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_; 8133+ if (state_size < 0) { 8134+ MS_LOG(ERROR) << "state size should be greater than or equal to zero."; 8135+ return; 8136+ } 8137+ auto hidden_stride = step * lstm_param_->output_step_; 8138+ auto hidden_seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_; 8139+ auto other_output_step = lstm_param_->bidirectional_ ? C2NUM * lstm_param_->batch_ * lstm_param_->hidden_size_ 8140+ : lstm_param_->batch_ * lstm_param_->hidden_size_; 8141+ auto stride = step * other_output_step; 8142+ auto seq_stride = lstm_param_->seq_len_ * other_output_step; 8143+ memcpy(states + hidden_stride, hidden_state, hidden_size * sizeof(float)); 8144+ stride += hidden_seq_stride; 8145+ memcpy(states + stride, cell_state, state_size * sizeof(float)); 8146+ stride += seq_stride; 8147+ memcpy(states + stride, input_gate, state_size * sizeof(float)); 8148+ stride += seq_stride; 8149+ memcpy(states + stride, output_gate, state_size * sizeof(float)); 8150+ stride += seq_stride; 8151+ memcpy(states + stride, forget_gate, state_size * sizeof(float)); 8152+ stride += seq_stride; 8153+ memcpy(states + stride, cell_gate, state_size * sizeof(float)); 8154+} 8155+} // namespace mindspore::kernel 8156diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.h b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.h 8157new file mode 100644 8158index 00000000..84cdd38e 8159--- /dev/null 8160+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.h 8161@@ -0,0 +1,63 @@ 8162+/** 8163+ * Copyright 2023 Huawei Technologies Co., Ltd 8164+ * 8165+ * Licensed under the Apache License, Version 2.0 (the "License"); 8166+ * you may not use this file except in compliance with the License. 
8167+ * You may obtain a copy of the License at 8168+ * 8169+ * http://www.apache.org/licenses/LICENSE-2.0 8170+ * 8171+ * Unless required by applicable law or agreed to in writing, software 8172+ * distributed under the License is distributed on an "AS IS" BASIS, 8173+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8174+ * See the License for the specific language governing permissions and 8175+ * limitations under the License. 8176+ */ 8177+ 8178+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_MINDIR_FP32_H_ 8179+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_MINDIR_FP32_H_ 8180+ 8181+#include <vector> 8182+#include "src/litert/kernel/cpu/fp32/lstm_fp32_base.h" 8183+ 8184+namespace mindspore::kernel { 8185+/* 8186+ * 1. LSTM without project, output_size = hidden_size 8187+ * h_init: second input, shape is [bidirectional, batch_size, hidden_size] 8188+ * c_init: third input, shape is [bidirectional, batch_size, hidden_size] 8189+ * weight_bias: forth input, weight_ih + weight_hh + bias, the gate order is IFGO 8190+ * 8191+ * 2. 
LSTM with project, output_size = project_size 8192+ * h_init: second input, shape is [bidirectional, batch_size, project_size] 8193+ * c_init: third input, shape is [bidirectional, batch_size, hidden_size] 8194+ * weight_bias: forth input, weight_ih + weight_hh + proj + bias, the gate order is IFGO 8195+ */ 8196+class LstmMindirFp32CPUKernel : public LstmFp32BaseCPUKernel { 8197+ public: 8198+ LstmMindirFp32CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 8199+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 8200+ : LstmFp32BaseCPUKernel(parameter, inputs, outputs, ctx) { 8201+ hidden_init_index_ = SECOND_INPUT; 8202+ cell_init_index_ = THIRD_INPUT; 8203+ } 8204+ 8205+ ~LstmMindirFp32CPUKernel() override = default; 8206+ 8207+ int ReSize() override; 8208+ 8209+ protected: 8210+ int InitInputWeightBias() override; 8211+ int InitStateWeightBias() override; 8212+ int InitProjectWeight() override; 8213+ void LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, float *hidden_state, 8214+ float *cell_state, const float *weight_project, float *intermediate_states, float *buffer[], 8215+ bool is_backward) override; 8216+ 8217+ private: 8218+ void RecordStates(const float *hidden_state, float *cell_state, float *input_gate, const float *output_gate, 8219+ float *forget_gate, const float *cell_gate, float *intermediate_states, int step); 8220+ bool gpu_orig_state_{false}; 8221+}; 8222+} // namespace mindspore::kernel 8223+ 8224+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_MINDIR_FP32_H_ 8225diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc 8226new file mode 100644 8227index 00000000..62f9f2b7 8228--- /dev/null 8229+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc 8230@@ -0,0 +1,173 @@ 8231+/** 8232+ * Copyright 2023 Huawei Technologies Co., Ltd 8233+ * 8234+ * 
Licensed under the Apache License, Version 2.0 (the "License"); 8235+ * you may not use this file except in compliance with the License. 8236+ * You may obtain a copy of the License at 8237+ * 8238+ * http://www.apache.org/licenses/LICENSE-2.0 8239+ * 8240+ * Unless required by applicable law or agreed to in writing, software 8241+ * distributed under the License is distributed on an "AS IS" BASIS, 8242+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8243+ * See the License for the specific language governing permissions and 8244+ * limitations under the License. 8245+ */ 8246+ 8247+#include "src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h" 8248+#include "nnacl/fp32/pack_fp32.h" 8249+ 8250+namespace mindspore::kernel { 8251+namespace { 8252+constexpr int kInputGateIndex = 0; 8253+constexpr int kGateNum = 4; 8254+constexpr int kWeightInputIndex = 1; 8255+constexpr int kWeightHiddenindex = 2; 8256+constexpr int kCombinedBiasIndex = 3; 8257+} // namespace 8258+ 8259+int LstmNonMindirFp32CPUKernel::InitInputWeightBias() { 8260+ // malloc and init input * weight right matrix buffer 8261+ // input -- row: seq_len * batch; col: input_size 8262+ // weight -- row: hidden_size; col: input_size, need transpose 8263+ // result -- row: seq_len * batch; col: hidden_size 8264+ weight_i_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc( 8265+ weight_segment_num_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float))); 8266+ MS_CHECK_TRUE_MSG(weight_i_ptr_ != nullptr, lite::RET_NULL_PTR, 8267+ "LstmNonMindirCPUKernel malloc weight_i_ptr_ failed."); 8268+ running_buffer_.push_back(weight_i_ptr_); 8269+ auto weight_i = in_tensors_.at(kWeightInputIndex); 8270+ auto weight_i_data = reinterpret_cast<float *>(weight_i->data()); 8271+ CHECK_NULL_RETURN(weight_i_data); 8272+ 8273+ int stride = kGateNum * lstm_param_->input_size_ * lstm_param_->hidden_size_; 8274+ PackLstmWeightWithStride(weight_i_ptr_, weight_i_data, 
weight_segment_num_, lstm_param_->input_size_, 8275+ lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_, 8276+ stride, nullptr); 8277+ // input bias 8278+ input_bias_ = reinterpret_cast<float *>( 8279+ ms_context_->allocator->Malloc(weight_segment_num_ * lstm_param_->input_col_align_ * sizeof(float))); 8280+ MS_CHECK_TRUE_MSG(input_bias_ != nullptr, lite::RET_NULL_PTR, "LstmNonMindirCPUKernel malloc input_bias_ failed."); 8281+ memset(input_bias_, 0, weight_segment_num_ * lstm_param_->input_col_align_ * sizeof(float)); 8282+ running_buffer_.push_back(input_bias_); 8283+ auto bias_data = reinterpret_cast<float *>(in_tensors_.at(kCombinedBiasIndex)->data()); 8284+ CHECK_NULL_RETURN(bias_data); 8285+ PackLstmBias(input_bias_, bias_data, weight_segment_num_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, 8286+ lstm_param_->bidirectional_, nullptr); 8287+ return RET_OK; 8288+} 8289+ 8290+int LstmNonMindirFp32CPUKernel::InitStateWeightBias() { 8291+ // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. 
8292+ // state -- row: batch; col: hidden_size 8293+ // weight -- row: hidden_size; col: hidden_size, need transpose 8294+ // result -- row: batch; col: hidden_size 8295+ auto weight_h = in_tensors_.at(kWeightHiddenindex); 8296+ auto weight_h_data = reinterpret_cast<float *>(weight_h->data()); 8297+ CHECK_NULL_RETURN(weight_h_data); 8298+ 8299+ int stride = kGateNum * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8300+ auto weight_pack_size = 8301+ weight_segment_num_ * lstm_param_->state_col_align_ * lstm_param_->output_size_ * sizeof(float); 8302+ if (lstm_param_->batch_ != 1) { 8303+ weight_h_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(weight_pack_size)); 8304+ MS_CHECK_TRUE_MSG(weight_h_ptr_ != nullptr, lite::RET_NULL_PTR, 8305+ "LstmNonMindirCPUKernel malloc weight_h_ptr_ failed."); 8306+ running_buffer_.push_back(weight_h_ptr_); 8307+ PackLstmWeightWithStride(weight_h_ptr_, weight_h_data, weight_segment_num_, lstm_param_->output_size_, 8308+ lstm_param_->hidden_size_, lstm_param_->state_col_align_, lstm_param_->bidirectional_, 8309+ stride, nullptr); 8310+ } else { 8311+#ifdef ENABLE_AVX 8312+ weight_h_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(weight_pack_size)); 8313+ MS_CHECK_TRUE_MSG(weight_h_ptr_ != nullptr, lite::RET_NULL_PTR, 8314+ "LstmNonMindirCPUKernel malloc weight_h_ptr_ failed."); 8315+ running_buffer_.push_back(weight_h_ptr_); 8316+ for (int i = 0; i < weight_segment_num_; i++) { 8317+ const float *src_batch = weight_h_data + i * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8318+ float *dst_batch = weight_h_ptr_ + i * lstm_param_->state_col_align_ * lstm_param_->output_size_; 8319+ RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->hidden_size_, lstm_param_->output_size_); 8320+ } 8321+#else 8322+ weight_h_ptr_ = weight_h_data; 8323+#endif 8324+ } 8325+ 8326+ // state bias 8327+ auto bias_pack_size = weight_segment_num_ * lstm_param_->state_col_align_ * sizeof(float); 8328+ 
state_bias_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(bias_pack_size)); 8329+ MS_CHECK_TRUE_MSG(state_bias_ != nullptr, lite::RET_NULL_PTR, "LstmNonMindirCPUKernel malloc state_bias_ failed."); 8330+ memset(state_bias_, 0, bias_pack_size); 8331+ running_buffer_.push_back(state_bias_); 8332+ // if ONNX, the second bias is also present, in IOFG order 8333+ auto bias_data = reinterpret_cast<float *>(in_tensors_.at(kCombinedBiasIndex)->data()); 8334+ CHECK_NULL_RETURN(bias_data); 8335+ auto *state_bias = bias_data + kGateNum * lstm_param_->hidden_size_; 8336+ PackLstmBias(state_bias_, state_bias, weight_segment_num_, lstm_param_->hidden_size_, lstm_param_->state_col_align_, 8337+ lstm_param_->bidirectional_, nullptr); 8338+ return RET_OK; 8339+} 8340+ 8341+int LstmNonMindirFp32CPUKernel::InitProjectWeight() { 8342+ if (in_tensors_.size() < C7NUM) { 8343+ return RET_OK; 8344+ } 8345+ auto weight_pro = in_tensors_.at(SEVENTH_INPUT); 8346+ auto shape = weight_pro->shape(); 8347+ MS_CHECK_TRUE_MSG(shape.size() == C3NUM, lite::RET_ERROR, "Project-weight's shape must be 3D."); 8348+ auto weight_pro_data = reinterpret_cast<float *>(weight_pro->data()); 8349+ CHECK_NULL_RETURN(weight_pro_data); 8350+ int batch = lstm_param_->bidirectional_ ?
C2NUM : C1NUM; 8351+ if (shape[0] != batch) { 8352+ MS_LOG(ERROR) << "Project-weight's shape[0] must be 1(bidirectional=false) or 2(bidirectional=true)."; 8353+ return lite::RET_ERROR; 8354+ } 8355+ int col_align = UP_ROUND(lstm_param_->output_size_, col_tile_); 8356+ auto pack_size = batch * lstm_param_->hidden_size_ * col_align * sizeof(float); 8357+ if (lstm_param_->batch_ != 1) { 8358+ weight_project_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(pack_size)); 8359+ MS_CHECK_TRUE_MSG(weight_project_ptr_ != nullptr, lite::RET_NULL_PTR, 8360+ "LstmNonMindirCPUKernel malloc weight_project_ptr_ failed."); 8361+ running_buffer_.push_back(weight_project_ptr_); 8362+ PackLstmWeightWithStride(weight_project_ptr_, weight_pro_data, batch, lstm_param_->hidden_size_, 8363+ lstm_param_->output_size_, col_align, lstm_param_->bidirectional_, 8364+ lstm_param_->hidden_size_ * lstm_param_->output_size_, nullptr); 8365+ } else { 8366+#ifdef ENABLE_AVX 8367+ weight_project_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(pack_size)); 8368+ MS_CHECK_TRUE_MSG(weight_project_ptr_ != nullptr, lite::RET_NULL_PTR, 8369+ "LstmNonMindirCPUKernel malloc weight_project_ptr_ failed."); 8370+ running_buffer_.push_back(weight_project_ptr_); 8371+ for (int i = 0; i < batch; ++i) { 8372+ const float *src_batch = weight_pro_data + i * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8373+ float *dst_batch = weight_project_ptr_ + i * lstm_param_->hidden_size_ * col_align; 8374+ RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->output_size_, lstm_param_->hidden_size_); 8375+ } 8376+#else 8377+ weight_project_ptr_ = weight_pro_data; 8378+#endif 8379+ } 8380+ return RET_OK; 8381+} 8382+ 8383+void LstmNonMindirFp32CPUKernel::LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, 8384+ float *hidden_state, float *cell_state, const float *weight_project, 8385+ float *intermediate_states, float **buffer, bool is_backward) { 8386+ float 
*gate = buffer[kInputGateIndex]; 8387+ float *input_gate = gate; 8388+ float *forget_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C2NUM; 8389+ float *cell_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C3NUM; 8390+ float *output_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_; 8391+ for (int t = 0; t < lstm_param_->seq_len_; t++) { 8392+ int real_t = is_backward ? lstm_param_->seq_len_ - t - C1NUM : t; 8393+ float *input_gate_t = input_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8394+ float *forget_gate_t = forget_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8395+ float *cell_gate_t = cell_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8396+ float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8397+ // Sequence, DirMul, Batch, Hidden 8398+ float *output_ptr = output + real_t * lstm_param_->output_step_; 8399+ LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, 8400+ weight_project, hidden_state, cell_state, buffer, lstm_param_); 8401+ } 8402+} 8403+} // namespace mindspore::kernel 8404diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h 8405new file mode 100644 8406index 00000000..b16e9175 8407--- /dev/null 8408+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h 8409@@ -0,0 +1,61 @@ 8410+/** 8411+ * Copyright 2023 Huawei Technologies Co., Ltd 8412+ * 8413+ * Licensed under the Apache License, Version 2.0 (the "License"); 8414+ * you may not use this file except in compliance with the License. 
8415+ * You may obtain a copy of the License at 8416+ * 8417+ * http://www.apache.org/licenses/LICENSE-2.0 8418+ * 8419+ * Unless required by applicable law or agreed to in writing, software 8420+ * distributed under the License is distributed on an "AS IS" BASIS, 8421+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8422+ * See the License for the specific language governing permissions and 8423+ * limitations under the License. 8424+ */ 8425+ 8426+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_NON_MINDIR_FP32_H_ 8427+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_NON_MINDIR_FP32_H_ 8428+ 8429+#include <vector> 8430+#include "src/litert/kernel/cpu/fp32/lstm_fp32_base.h" 8431+ 8432+namespace mindspore::kernel { 8433+/* 8434+ * 1. LSTM without project, output_size = hidden_size 8435+ * weight_ih: second input, shape is [bidirectional, 4 * hidden_size, input_size] 8436+ * weight_hh: third input, shape is [bidirectional, 4 * hidden_size, hidden_size] 8437+ * bias: forth input, shape is [bidirectional, 8 * hidden_size] 8438+ * h_init: fifth input, shape is [bidirectional, batch_size, hidden_size] 8439+ * c_init: sixth input, shape is [bidirectional, batch_size, hidden_size] 8440+ * 8441+ * 2. 
LSTM with project, output_size = project_size 8442+ * weight_ih: second input, shape is [bidirectional, 4 * hidden_size, input_size] 8443+ * weight_hh: third input, shape is [bidirectional, 4 * hidden_size, project_size] 8444+ * bias: forth input, shape is [bidirectional, 8 * hidden_size] 8445+ * h_init: fifth input, shape is [bidirectional, batch_size, project_size] 8446+ * c_init: sixth input, shape is [bidirectional, batch_size, hidden_size] 8447+ * weight_pro: seventh input, shape is [bidirectional, project_size, hidden_size] 8448+ */ 8449+class LstmNonMindirFp32CPUKernel : public LstmFp32BaseCPUKernel { 8450+ public: 8451+ LstmNonMindirFp32CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 8452+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 8453+ : LstmFp32BaseCPUKernel(parameter, inputs, outputs, ctx) { 8454+ hidden_init_index_ = FIFTH_INPUT; 8455+ cell_init_index_ = SIXTH_INPUT; 8456+ } 8457+ 8458+ ~LstmNonMindirFp32CPUKernel() override = default; 8459+ 8460+ protected: 8461+ int InitInputWeightBias() override; 8462+ int InitStateWeightBias() override; 8463+ int InitProjectWeight() override; 8464+ void LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, float *hidden_state, 8465+ float *cell_state, const float *weight_project, float *intermediate_states, float *buffer[], 8466+ bool is_backward) override; 8467+}; 8468+} // namespace mindspore::kernel 8469+ 8470+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_NON_MINDIR_FP32_H_ 8471diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.cc 8472new file mode 100644 8473index 00000000..60d3f213 8474--- /dev/null 8475+++ b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.cc 8476@@ -0,0 +1,147 @@ 8477+/** 8478+ * Copyright 2023 Huawei Technologies Co., Ltd 8479+ * 8480+ * Licensed 
under the Apache License, Version 2.0 (the "License"); 8481+ * you may not use this file except in compliance with the License. 8482+ * You may obtain a copy of the License at 8483+ * 8484+ * http://www.apache.org/licenses/LICENSE-2.0 8485+ * 8486+ * Unless required by applicable law or agreed to in writing, software 8487+ * distributed under the License is distributed on an "AS IS" BASIS, 8488+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8489+ * See the License for the specific language governing permissions and 8490+ * limitations under the License. 8491+ */ 8492+#include "src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.h" 8493+#include "src/litert//kernel_registry.h" 8494+#include "include/errorcode.h" 8495+#include "src/common/log_adapter.h" 8496+#include "nnacl/custom_gather_d_grad_v2_parameter.h" 8497+ 8498+using mindspore::lite::KernelRegistrar; 8499+using mindspore::lite::RET_ERROR; 8500+using mindspore::lite::RET_NOT_SUPPORT; 8501+using mindspore::lite::RET_OK; 8502+ 8503+namespace mindspore::kernel { 8504+namespace { 8505+constexpr size_t index_idx_{1}; 8506+constexpr size_t grad_idx_{2}; 8507+size_t get_element_num(const std::vector<int> &shape) { 8508+ return std::accumulate(shape.begin(), shape.end(), static_cast<std::size_t>(1), std::multiplies<int>()); 8509+} 8510+ 8511+void GatherDGradCopyTask(size_t cur, std::vector<size_t> *pos, float *input, int *index, const int &dim, float *output, 8512+ const std::vector<int> &output_shape, const std::vector<size_t> &out_cargo_size, 8513+ const std::vector<size_t> &input_cargo_size) { 8514+ for (int i = 0; i < output_shape[cur]; ++i) { 8515+ (*pos)[cur] = i; 8516+ if (cur == output_shape.size() - 1) { 8517+ int input_offset = 0; 8518+ int out_offset = 0; 8519+ // out offset 8520+ for (size_t j = 0; j < output_shape.size(); ++j) { 8521+ out_offset += (*pos)[j] * out_cargo_size[j]; 8522+ } 8523+ // input offset 8524+ int cur_index = (*pos)[dim]; 8525+ (*pos)[dim] = 
index[out_offset]; 8526+ for (size_t j = 0; j < output_shape.size(); ++j) { 8527+ input_offset += (*pos)[j] * input_cargo_size[j]; 8528+ } 8529+ // do copy 8530+ input[input_offset] += output[out_offset]; 8531+ (*pos)[dim] = cur_index; 8532+ } else { 8533+ // CopyTask 8534+ GatherDGradCopyTask(cur + 1, pos, input, index, dim, output, output_shape, out_cargo_size, input_cargo_size); 8535+ } 8536+ } 8537+} 8538+} // namespace 8539+ 8540+CustomGatherDGradV2CPUKernel::~CustomGatherDGradV2CPUKernel() {} 8541+ 8542+int CustomGatherDGradV2CPUKernel::Prepare() { 8543+ CHECK_LESS_RETURN(in_tensors_.size(), C3NUM); 8544+ CHECK_LESS_RETURN(out_tensors_.size(), C1NUM); 8545+ if (InitParamter() != RET_OK) { 8546+ MS_LOG(ERROR) << "Init Built-in CustomGatherGradV2 Parameter failed." << name_; 8547+ return RET_ERROR; 8548+ } 8549+ if (!InferShapeDone()) { 8550+ return RET_OK; 8551+ } 8552+ return ReSize(); 8553+} 8554+ 8555+int CustomGatherDGradV2CPUKernel::InitParamter() { 8556+ auto param = reinterpret_cast<CustomGatherGradV2Parameter *>(op_parameter_); 8557+ axis_ = param->dim; 8558+ return RET_OK; 8559+} 8560+ 8561+int CustomGatherDGradV2CPUKernel::ReSize() { 8562+ index_shape_ = in_tensors_[index_idx_]->shape(); 8563+ grad_shape_ = in_tensors_[grad_idx_]->shape(); 8564+ output_shape_ = out_tensors_[0]->shape(); 8565+ if (grad_shape_.size() != index_shape_.size() || output_shape_.size() != index_shape_.size()) { 8566+ MS_LOG(ERROR) << "For '" << name_ << "', the dimension of grad and output must be the equal to the " 8567+ << "dimension of index: " << index_shape_.size() 8568+ << ", but got the dimension of grad: " << grad_shape_.size() 8569+ << ", the dimension of output: " << output_shape_.size(); 8570+ return RET_ERROR; 8571+ } 8572+ 8573+ return RET_OK; 8574+} 8575+ 8576+int CustomGatherDGradV2CPUKernel::Run() { 8577+ auto *index = reinterpret_cast<int *>(in_tensors_[index_idx_]->data()); 8578+ auto *grad = reinterpret_cast<float *>(in_tensors_[grad_idx_]->data()); 8579+ 
auto out = reinterpret_cast<float *>(out_tensors_[0]->data()); 8580+ int output_rank = output_shape_.size(); 8581+ if (axis_ >= output_rank || axis_ < -output_rank) { 8582+ MS_LOG(ERROR) << "For '" << name_ << "', the value of 'dim' must be in [" << -output_rank << ", " << output_rank 8583+ << "), but got: " << axis_; 8584+ } 8585+ if (axis_ < 0) { 8586+ axis_ = axis_ + output_rank; 8587+ } 8588+ 8589+ // check index 8590+ size_t index_size = get_element_num(index_shape_); 8591+ int max_index = output_shape_[axis_]; 8592+ for (size_t i = 0; i < index_size; ++i) { 8593+ if (index[i] >= max_index || index[i] < -max_index) { 8594+ MS_LOG(ERROR) << "For '" << name_ << "', the value of 'index' must be in [" << -max_index << ", " << max_index 8595+ << "), but got: " << index[i]; 8596+ } 8597+ if (index[i] < 0) { 8598+ index[i] = max_index + index[i]; 8599+ } 8600+ } 8601+ auto out_size = get_element_num(output_shape_); 8602+ memset(out, 0, out_size * sizeof(float)); 8603+ 8604+ // out_cargo_size 8605+ std::vector<size_t> out_cargo_size = std::vector<size_t>(output_shape_.size(), 1); 8606+ for (int i = static_cast<int>(out_cargo_size.size()) - 2; i >= 0; --i) { 8607+ out_cargo_size[i] = output_shape_[i + 1] * out_cargo_size[i + 1]; 8608+ } 8609+ // grad_cargo_size 8610+ std::vector<size_t> grad_cargo_size = std::vector<size_t>(grad_shape_.size(), 1); 8611+ for (int i = static_cast<int>(grad_cargo_size.size()) - 2; i >= 0; --i) { 8612+ grad_cargo_size[i] = grad_shape_[i + 1] * grad_cargo_size[i + 1]; 8613+ } 8614+ 8615+ // copy task 8616+ std::vector<size_t> pos(index_shape_.size(), 0); 8617+ GatherDGradCopyTask(0, &pos, out, index, axis_, grad, index_shape_, grad_cargo_size, out_cargo_size); 8618+ return RET_OK; 8619+} 8620+ 8621+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomGatherDGradV2, 8622+ LiteKernelCreator<CustomGatherDGradV2CPUKernel>) 8623+} // namespace mindspore::kernel 8624diff --git 
a/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.h b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.h 8625new file mode 100644 8626index 00000000..25666023 8627--- /dev/null 8628+++ b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.h 8629@@ -0,0 +1,42 @@ 8630+/** 8631+ * Copyright 2023 Huawei Technologies Co., Ltd 8632+ * 8633+ * Licensed under the Apache License, Version 2.0 (the "License"); 8634+ * you may not use this file except in compliance with the License. 8635+ * You may obtain a copy of the License at 8636+ * 8637+ * http://www.apache.org/licenses/LICENSE-2.0 8638+ * 8639+ * Unless required by applicable law or agreed to in writing, software 8640+ * distributed under the License is distributed on an "AS IS" BASIS, 8641+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8642+ * See the License for the specific language governing permissions and 8643+ * limitations under the License. 
8644+ */ 8645+ 8646+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_CUSTOM_GATHER_D_GRAD_V2_H_ 8647+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_CUSTOM_GATHER_D_GRAD_V2_H_ 8648+#include <vector> 8649+#include "src/litert/lite_kernel.h" 8650+ 8651+namespace mindspore::kernel { 8652+class CustomGatherDGradV2CPUKernel : public LiteKernel { 8653+ public: 8654+ CustomGatherDGradV2CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 8655+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 8656+ : LiteKernel(parameter, inputs, outputs, ctx) {} 8657+ ~CustomGatherDGradV2CPUKernel() override; 8658+ int Prepare() override; 8659+ int ReSize() override; 8660+ int Run() override; 8661+ 8662+ private: 8663+ int InitParamter(); 8664+ 8665+ std::vector<int> index_shape_; 8666+ std::vector<int> grad_shape_; 8667+ std::vector<int> output_shape_; 8668+ int axis_{0}; 8669+}; 8670+} // namespace mindspore::kernel 8671+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_CUSTOM_GATHER_D_GRAD_V2_H_ 8672diff --git a/mindspore/lite/src/train/graph_fusion.cc b/mindspore/lite/src/train/graph_fusion.cc 8673index 48c037b2..7982f818 100644 8674--- a/mindspore/lite/src/train/graph_fusion.cc 8675+++ b/mindspore/lite/src/train/graph_fusion.cc 8676@@ -25,6 +25,8 @@ 8677 #include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h" 8678 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" 8679 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" 8680+#include "src/train/optimizer/fusion/matmul_add_fusion_pass.h" 8681+#include "src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.h" 8682 8683 namespace mindspore { 8684 namespace lite { 8685@@ -52,7 +54,9 @@ STATUS GraphFusion::Run(schema::MetaGraphT *graph) { 8686 Optimizer fusion_optimizer; 8687 fusion_optimizer.AddPass(new (std::nothrow) ReshapeGatherReshapeFusionPass()); 8688 fusion_optimizer.AddPass(new 
(std::nothrow) MatMulBiasAddFusionPass()); 8689+ fusion_optimizer.AddPass(new (std::nothrow) MatMulAddFusionPass()); 8690 fusion_optimizer.AddPass(new (std::nothrow) MatMulActivationFusionPass()); 8691+ fusion_optimizer.AddPass(new (std::nothrow) MatMulMatMulAddFusionPass()); 8692 fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); 8693 fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); 8694 auto status = fusion_optimizer.Run(graph); 8695diff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.cc 8696new file mode 100644 8697index 00000000..34bed911 8698--- /dev/null 8699+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.cc 8700@@ -0,0 +1,127 @@ 8701+/** 8702+ * Copyright 2023 Huawei Technologies Co., Ltd 8703+ * 8704+ * Licensed under the Apache License, Version 2.0 (the "License"); 8705+ * you may not use this file except in compliance with the License. 8706+ * You may obtain a copy of the License at 8707+ * 8708+ * http://www.apache.org/licenses/LICENSE-2.0 8709+ * 8710+ * Unless required by applicable law or agreed to in writing, software 8711+ * distributed under the License is distributed on an "AS IS" BASIS, 8712+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8713+ * See the License for the specific language governing permissions and 8714+ * limitations under the License. 
8715+ */ 8716+#include "src/train/optimizer/fusion/matmul_add_fusion_pass.h" 8717+#include <string> 8718+#include <unordered_map> 8719+#include <vector> 8720+#include <memory> 8721+#include "schema/inner/model_generated.h" 8722+#include "tools/common/meta_graph_utils.h" 8723+namespace { 8724+constexpr int kNumAddMatchPathLen = 2; 8725+constexpr std::string_view MulName = "MATMUL"; 8726+constexpr std::string_view AddName = "ADD"; 8727+} // namespace 8728+namespace mindspore { 8729+namespace lite { 8730+namespace { 8731+int CalNewCnodeBias(const std::unique_ptr<mindspore::schema::TensorT> &add_weight_tensor, 8732+ const std::unique_ptr<mindspore::schema::TensorT> &matmul_bias_tensor) { 8733+ if (add_weight_tensor->dataType != kNumberTypeFloat32 || matmul_bias_tensor->dataType != kNumberTypeFloat32) { 8734+ MS_LOG(INFO) << "only support float32 data type"; 8735+ return RET_ERROR; 8736+ } 8737+ std::vector<int32_t> matmul_bias_shape = matmul_bias_tensor->dims; 8738+ std::vector<int32_t> add_weight_shape = add_weight_tensor->dims; 8739+ MS_CHECK_TRUE_RET(matmul_bias_shape == add_weight_shape, RET_ERROR); 8740+ auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data.data()); 8741+ auto matmul_bias_data = reinterpret_cast<float *>(matmul_bias_tensor->data.data()); 8742+ int num = static_cast<int>(matmul_bias_tensor->data.size() / sizeof(float)); 8743+ for (int i = 0; i < num; ++i) { 8744+ matmul_bias_data[i] += add_weight_data[i]; 8745+ } 8746+ return RET_OK; 8747+} 8748+} // namespace 8749+STATUS MatMulAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } 8750+STATUS MatMulAddFusionPass::DefinePattern() { 8751+ auto mul_op = std::make_shared<PatternOp>(); 8752+ MS_CHECK_TRUE_RET(mul_op != nullptr, RET_NULL_PTR); 8753+ mul_op->id = MulName; 8754+ mul_op->types = {schema::PrimitiveType_MatMulFusion}; 8755+ auto add_op = std::make_shared<PatternOp>(); 8756+ MS_CHECK_TRUE_RET(add_op != nullptr, RET_NULL_PTR); 8757+ add_op->id = AddName; 
8758+ add_op->types = {schema::PrimitiveType_AddFusion}; 8759+ add_op->left = mul_op; 8760+ std::unique_ptr<FusionPattern> fusion_pattern(new (std::nothrow) FusionPattern("MatMulAddFusion")); 8761+ if (fusion_pattern == nullptr) { 8762+ MS_LOG(ERROR) << "new fusion_pattern failed"; 8763+ return RET_ERROR; 8764+ } 8765+ fusion_pattern->AddPatternOp(mul_op); 8766+ fusion_pattern->AddPatternOp(add_op); 8767+ fusion_pattern->Finish(); 8768+ this->patterns.emplace_back(fusion_pattern.release()); 8769+ return RET_OK; 8770+} 8771+STATUS MatMulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pattern_name, 8772+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) { 8773+ MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR); 8774+ if (matched_path.size() != kNumAddMatchPathLen) { 8775+ MS_LOG(ERROR) << "MatMul-Add-Fusion should have two NodeIndex in matchedPair"; 8776+ return RET_PARAM_INVALID; 8777+ } 8778+ auto mul_path_iter = matched_path.find(std::string(MulName)); 8779+ MS_CHECK_TRUE_RET(mul_path_iter != matched_path.end(), RET_NO_CHANGE); 8780+ auto &mul_path = mul_path_iter->second; 8781+ MS_CHECK_TRUE_RET(mul_path != nullptr, RET_NULL_PTR); 8782+ auto add_path_iter = matched_path.find(std::string(AddName)); 8783+ MS_CHECK_TRUE_RET(add_path_iter != matched_path.end(), RET_NO_CHANGE); 8784+ auto &add_path = add_path_iter->second; 8785+ MS_CHECK_TRUE_RET(add_path != nullptr, RET_NULL_PTR); 8786+ auto mul_index = mul_path->nodeIdx; 8787+ auto add_index = add_path->nodeIdx; 8788+ auto &mul_node = graph->nodes.at(mul_index); 8789+ MS_CHECK_TRUE_RET(mul_node != nullptr, RET_NULL_PTR); 8790+ auto &add_node = graph->nodes.at(add_index); 8791+ MS_CHECK_TRUE_RET(add_node != nullptr, RET_NULL_PTR); 8792+ if (mul_node->quantType == schema::QuantType_QUANT_ALL || mul_node->quantType == schema::QuantType_QUANT_DYNAMIC || 8793+ add_node->quantType == schema::QuantType_QUANT_ALL || add_node->quantType == schema::QuantType_QUANT_DYNAMIC) { 8794+ 
MS_LOG(DEBUG) << "cannot fusion."; 8795+ return RET_NO_CHANGE; 8796+ } 8797+ MS_CHECK_TRUE_RET(mul_node->primitive != nullptr, RET_NULL_PTR); 8798+ auto matmul_type = mul_node->primitive->value.AsMatMulFusion(); 8799+ MS_CHECK_TRUE_RET(matmul_type->activation_type == ActivationType::ActivationType_NO_ACTIVATION, RET_NO_CHANGE); 8800+ auto add_param_shape = graph->allTensors.at(add_node->inputIndex.at(SECOND_INPUT))->dims; 8801+ MS_CHECK_TRUE_MSG(add_param_shape.size() == DIMENSION_1D, RET_NO_CHANGE, "only support bias with shape size of 1."); 8802+ if (mul_node->inputIndex.size() == C3NUM) { 8803+ auto &mul_bias_tensor = graph->allTensors.at(mul_node->inputIndex.at(THIRD_INPUT)); 8804+ if (mul_bias_tensor->data.data() == nullptr) { 8805+ MS_LOG(INFO) << mul_node->name << "'s bias is not const"; 8806+ return RET_NO_CHANGE; 8807+ } 8808+ auto &add_weight_tensor = graph->allTensors.at(add_node->inputIndex.at(SECOND_INPUT)); 8809+ if (CalNewCnodeBias(add_weight_tensor, mul_bias_tensor) != RET_OK) { 8810+ MS_LOG(INFO) << add_node->name << " failed to fusion with " << mul_node->name; 8811+ return RET_NO_CHANGE; 8812+ } 8813+ } 8814+ auto add_tensor_index = add_node->inputIndex.at(SECOND_INPUT); 8815+ if (mul_node->inputIndex.size() == C2NUM) { 8816+ mul_node->inputIndex.push_back(add_tensor_index); 8817+ } 8818+ mul_node->outputIndex = {add_node->outputIndex}; 8819+ // cannot delete node here, otherwise will destroy order in other pattern's node index 8820+ // make it an isolated node to be removed in IsolatedNodeRemovePass 8821+ add_node->inputIndex.clear(); 8822+ add_node->outputIndex.clear(); 8823+ return RET_OK; 8824+} 8825+MatMulAddFusionPass::~MatMulAddFusionPass() = default; 8826+} // namespace lite 8827+} // namespace mindspore 8828diff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.h 8829new file mode 100644 8830index 00000000..8eb4ab2e 8831--- /dev/null 8832+++ 
b/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.h 8833@@ -0,0 +1,37 @@ 8834+/** 8835+ * Copyright 2023 Huawei Technologies Co., Ltd 8836+ * 8837+ * Licensed under the Apache License, Version 2.0 (the "License"); 8838+ * you may not use this file except in compliance with the License. 8839+ * You may obtain a copy of the License at 8840+ * 8841+ * http://www.apache.org/licenses/LICENSE-2.0 8842+ * 8843+ * Unless required by applicable law or agreed to in writing, software 8844+ * distributed under the License is distributed on an "AS IS" BASIS, 8845+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8846+ * See the License for the specific language governing permissions and 8847+ * limitations under the License. 8848+ */ 8849+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_FUSION_MATMUL_ADD_FUSION_PASS_H_ 8850+#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_FUSION_MATMUL_ADD_FUSION_PASS_H_ 8851+#include <string> 8852+#include <unordered_map> 8853+#include <memory> 8854+#include <algorithm> 8855+#include <utility> 8856+#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" 8857+namespace mindspore { 8858+namespace lite { 8859+class MatMulAddFusionPass : public FusionPass { 8860+ public: 8861+ MatMulAddFusionPass() = default; 8862+ ~MatMulAddFusionPass() override; 8863+ STATUS DefinePattern() override; 8864+ STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name, 8865+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override; 8866+ STATUS Run(MetaGraphT *graph) override; 8867+}; 8868+} // namespace lite 8869+} // namespace mindspore 8870+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_FUSION_MATMUL_ADD_FUSION_PASS_H_ 8871diff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc 8872new file mode 100644 8873index 00000000..d1a63c2d 8874--- 
/dev/null 8875+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc 8876@@ -0,0 +1,163 @@ 8877+/** 8878+ * Copyright 2023 Huawei Technologies Co., Ltd 8879+ * 8880+ * Licensed under the Apache License, Version 2.0 (the "License"); 8881+ * you may not use this file except in compliance with the License. 8882+ * You may obtain a copy of the License at 8883+ * 8884+ * http://www.apache.org/licenses/LICENSE-2.0 8885+ * 8886+ * Unless required by applicable law or agreed to in writing, software 8887+ * distributed under the License is distributed on an "AS IS" BASIS, 8888+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8889+ * See the License for the specific language governing permissions and 8890+ * limitations under the License. 8891+ */ 8892+ 8893+#include "src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.h" 8894+#include <string> 8895+#include <unordered_map> 8896+#include <vector> 8897+#include <memory> 8898+#include "schema/inner/model_generated.h" 8899+#include "tools/common/meta_graph_utils.h" 8900+#include "src/train/optimizer/common/fusion_utils.h" 8901+namespace { 8902+constexpr std::string_view kFirstMatMulName = "MATMUL1"; 8903+constexpr std::string_view kSecondMatMulName = "MATMUL2"; 8904+constexpr std::string_view kAddName = "ADD"; 8905+} // namespace 8906+namespace mindspore { 8907+namespace lite { 8908+/* 8909+ * The subgraph such as the following. 
8910+ * any any 8911+ * / \ | 8912+ * matmul matmul matmul 8913+ * \ / ----> | 8914+ * add any 8915+ * | 8916+ * any 8917+ */ 8918+namespace { 8919+int CalNewMatMulNode(MetaGraphT *graph, const std::unique_ptr<mindspore::schema::CNodeT> &matmul_node1, 8920+ const std::unique_ptr<mindspore::schema::CNodeT> &matmul_node2) { 8921+ auto &matrix_b_1 = graph->allTensors.at(matmul_node1->inputIndex.at(opt::kInputIndexOne)); 8922+ auto &matrix_b_2 = graph->allTensors.at(matmul_node2->inputIndex.at(opt::kInputIndexOne)); 8923+ if (matrix_b_1->dims != matrix_b_2->dims) { 8924+ MS_LOG(INFO) << "currently, matmul fusion only support the same shape tensor"; 8925+ return RET_ERROR; 8926+ } 8927+ if (matrix_b_1->dataType != kNumberTypeFloat32 || matrix_b_2->dataType != kNumberTypeFloat32) { 8928+ MS_LOG(INFO) << "only support float32 data type"; 8929+ return RET_ERROR; 8930+ } 8931+ auto matrix_b_1_data = reinterpret_cast<float *>(matrix_b_1->data.data()); 8932+ auto matrix_b_2_data = reinterpret_cast<float *>(matrix_b_2->data.data()); 8933+ int num_b = static_cast<int>(matrix_b_1->data.size() / sizeof(float)); 8934+ for (int j = 0; j < num_b; ++j) { 8935+ matrix_b_1_data[j] += matrix_b_2_data[j]; 8936+ } 8937+ return RET_OK; 8938+} 8939+} // namespace 8940+STATUS MatMulMatMulAddFusionPass::DefinePattern() { 8941+ auto matmul_op1 = std::make_shared<PatternOp>(); 8942+ MS_CHECK_TRUE_RET(matmul_op1 != nullptr, RET_NULL_PTR); 8943+ matmul_op1->id = kFirstMatMulName; 8944+ matmul_op1->types = {schema::PrimitiveType_MatMulFusion}; 8945+ auto matmul_op2 = std::make_shared<PatternOp>(); 8946+ MS_CHECK_TRUE_RET(matmul_op2 != nullptr, RET_NULL_PTR); 8947+ matmul_op2->id = kSecondMatMulName; 8948+ matmul_op2->types = {schema::PrimitiveType_MatMulFusion}; 8949+ auto add_op = std::make_shared<PatternOp>(); 8950+ MS_CHECK_TRUE_RET(add_op != nullptr, RET_NULL_PTR); 8951+ add_op->id = kAddName; 8952+ add_op->types = {schema::PrimitiveType_AddFusion}; 8953+ add_op->left = matmul_op1; 8954+ 
add_op->right = matmul_op2; 8955+ auto fusion_pattern = std::make_unique<FusionPattern>("MatMulMatMulAddFusion"); 8956+ MS_CHECK_TRUE_MSG(fusion_pattern != nullptr, RET_NULL_PTR, "new fusion_pattern failed"); 8957+ fusion_pattern->AddPatternOp(matmul_op1); 8958+ fusion_pattern->AddPatternOp(matmul_op2); 8959+ fusion_pattern->AddPatternOp(add_op); 8960+ fusion_pattern->Finish(); 8961+ this->patterns.emplace_back(fusion_pattern.release()); 8962+ return RET_OK; 8963+} 8964+ 8965+STATUS MatMulMatMulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pattern_name, 8966+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) { 8967+ MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR); 8968+ if (matched_path.size() != opt::kMatchPathLenThree) { 8969+ MS_LOG(INFO) << "MatMul-MatMul-Add-Fusion should have three NodeIndex in matchedPair"; 8970+ return RET_PARAM_INVALID; 8971+ } 8972+ 8973+ size_t matmul_index1 = 0; 8974+ auto ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(kFirstMatMulName), &matmul_index1); 8975+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get matmul_index1"); 8976+ auto &matmul_node1 = graph->nodes.at(matmul_index1); 8977+ MS_CHECK_TRUE_MSG(matmul_node1 != nullptr, RET_NULL_PTR, "matmul_node1 is nullptr"); 8978+ size_t matmul_index2 = 0; 8979+ ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(kSecondMatMulName), &matmul_index2); 8980+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get matmul_index2"); 8981+ auto &matmul_node2 = graph->nodes.at(matmul_index2); 8982+ MS_CHECK_TRUE_MSG(matmul_node2 != nullptr, RET_NULL_PTR, "matmul_node2 is nullptr"); 8983+ MS_CHECK_TRUE_MSG(matmul_node1->inputIndex.size() > C1NUM && matmul_node2->inputIndex.size() > C1NUM, 8984+ RET_PARAM_INVALID, "matmul should have two input at least"); 8985+ if (matmul_node1->inputIndex.size() < matmul_node2->inputIndex.size()) { 8986+ matmul_node1.swap(matmul_node2); 8987+ } 8988+ size_t add_index = 0; 8989+ ret = 
opt::GetMatchNodeIndex(graph, matched_path, std::string(kAddName), &add_index); 8990+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get add_index"); 8991+ auto &add_node = graph->nodes.at(add_index); 8992+ MS_CHECK_TRUE_MSG(add_node != nullptr, RET_NULL_PTR, "add_node is nullptr"); 8993+ 8994+ if (matmul_node1->quantType == schema::QuantType_QUANT_ALL || 8995+ matmul_node1->quantType == schema::QuantType_QUANT_DYNAMIC || 8996+ matmul_node2->quantType == schema::QuantType_QUANT_ALL || 8997+ matmul_node2->quantType == schema::QuantType_QUANT_DYNAMIC || 8998+ add_node->quantType == schema::QuantType_QUANT_ALL || add_node->quantType == schema::QuantType_QUANT_DYNAMIC) { 8999+ MS_LOG(DEBUG) << "cannot fusion with quant node"; 9000+ return RET_NO_CHANGE; 9001+ } 9002+ MS_CHECK_TRUE_RET(matmul_node1->primitive != nullptr, RET_NULL_PTR); 9003+ auto matmul_type1 = matmul_node1->primitive->value.AsMatMulFusion()->activation_type; 9004+ MS_CHECK_TRUE_RET(matmul_node2->primitive != nullptr, RET_NULL_PTR); 9005+ auto matmul_type2 = matmul_node2->primitive->value.AsMatMulFusion()->activation_type; 9006+ MS_CHECK_TRUE_RET(add_node->primitive != nullptr, RET_NULL_PTR); 9007+ auto add_type = add_node->primitive->value.AsAddFusion()->activation_type; 9008+ MS_CHECK_TRUE_RET(matmul_type1 == ActivationType::ActivationType_NO_ACTIVATION && 9009+ matmul_type2 == ActivationType::ActivationType_NO_ACTIVATION && 9010+ add_type == ActivationType::ActivationType_NO_ACTIVATION, 9011+ RET_NO_CHANGE); 9012+ 9013+ if (matmul_node1->inputIndex.at(FIRST_INPUT) != matmul_node2->inputIndex.at(FIRST_INPUT)) { 9014+ MS_LOG(INFO) << "matmul should have the same first input"; 9015+ return RET_NO_CHANGE; 9016+ } 9017+ auto &matmul_left_b = graph->allTensors[matmul_node1->inputIndex.at(SECOND_INPUT)]; 9018+ auto &matmul_right_b = graph->allTensors[matmul_node2->inputIndex.at(SECOND_INPUT)]; 9019+ if (matmul_left_b->data.empty() || matmul_right_b->data.empty()) { 9020+ return RET_NO_CHANGE; 9021+ } 9022+ 
if (CalNewMatMulNode(graph, matmul_node1, matmul_node2) != RET_OK) { 9023+ MS_LOG(INFO) << "failed to fusion two matmul"; 9024+ return RET_NO_CHANGE; 9025+ } 9026+ 9027+ matmul_node1->outputIndex = {add_node->outputIndex}; 9028+ // cannot delete node here, otherwise will destroy order in other pattern's node index 9029+ // make it an isolated node to be removed in IsolatedNodeRemovePass 9030+ matmul_node2->inputIndex.clear(); 9031+ matmul_node2->outputIndex.clear(); 9032+ add_node->inputIndex.clear(); 9033+ add_node->outputIndex.clear(); 9034+ return RET_OK; 9035+} 9036+ 9037+MatMulMatMulAddFusionPass::~MatMulMatMulAddFusionPass() = default; 9038+} // namespace lite 9039+} // namespace mindspore 9040diff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.h 9041new file mode 100644 9042index 00000000..9ee6d711 9043--- /dev/null 9044+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.h 9045@@ -0,0 +1,43 @@ 9046+/** 9047+ * Copyright 2023 Huawei Technologies Co., Ltd 9048+ * 9049+ * Licensed under the Apache License, Version 2.0 (the "License"); 9050+ * you may not use this file except in compliance with the License. 9051+ * You may obtain a copy of the License at 9052+ * 9053+ * http://www.apache.org/licenses/LICENSE-2.0 9054+ * 9055+ * Unless required by applicable law or agreed to in writing, software 9056+ * distributed under the License is distributed on an "AS IS" BASIS, 9057+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9058+ * See the License for the specific language governing permissions and 9059+ * limitations under the License. 
9060+ */ 9061+ 9062+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_MATMUL_ADD_FUSION_PASS_H_ 9063+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_MATMUL_ADD_FUSION_PASS_H_ 9064+ 9065+#include <string> 9066+#include <unordered_map> 9067+#include <memory> 9068+#include <algorithm> 9069+#include <utility> 9070+#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" 9071+ 9072+namespace mindspore { 9073+namespace lite { 9074+class MatMulMatMulAddFusionPass : public FusionPass { 9075+ public: 9076+ MatMulMatMulAddFusionPass() = default; 9077+ 9078+ ~MatMulMatMulAddFusionPass() override; 9079+ 9080+ STATUS DefinePattern() override; 9081+ 9082+ STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name, 9083+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override; 9084+}; 9085+} // namespace lite 9086+} // namespace mindspore 9087+ 9088+#endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_MATMUL_ADD_FUSION_PASS_H_ 9089diff --git a/mindspore/lite/src/train/train_export.cc b/mindspore/lite/src/train/train_export.cc 9090index 7534ed2f..5bace006 100644 9091--- a/mindspore/lite/src/train/train_export.cc 9092+++ b/mindspore/lite/src/train/train_export.cc 9093@@ -151,11 +151,18 @@ int TrainExport::QuantTensorData(schema::TensorT *dest_tensor, const lite::Tenso 9094 return RET_OK; 9095 } 9096 9097-std::unique_ptr<schema::TensorT> TrainExport::CreateTensor(const mindspore::lite::Tensor *tensor, 9098- schema::Tensor *scTensor, int preferred_dim, 9099- const int tensor_quant_type) { 9100+std::unique_ptr<schema::TensorT> TrainExport::CreateTensor( 9101+ const mindspore::lite::Tensor *tensor, const std::vector<mindspore::lite::Tensor *> const_folded_output, 9102+ schema::Tensor *scTensor, int preferred_dim, const int tensor_quant_type) { 9103 auto tensorT = std::make_unique<schema::TensorT>(); 9104- tensorT->nodeType = scTensor->nodeType(); 9105+ bool const_fold = false; 9106+ if (quant_type_ == QT_NONE && 
!const_folded_output.empty() && 9107+ std::find(const_folded_output.begin(), const_folded_output.end(), tensor) != const_folded_output.end()) { 9108+ tensorT->nodeType = NodeType_ValueNode; 9109+ const_fold = true; 9110+ } else { 9111+ tensorT->nodeType = scTensor->nodeType(); 9112+ } 9113 tensorT->dims = tensor->shape(); 9114 tensorT->format = static_cast<schema::Format>(tensor->format()); 9115 tensorT->name = tensor->tensor_name(); 9116@@ -163,7 +170,8 @@ std::unique_ptr<schema::TensorT> TrainExport::CreateTensor(const mindspore::lite 9117 tensorT->offset = 0; 9118 tensorT->dataType = tensor->data_type(); 9119 tensorT->enableHuffmanCode = false; 9120- if ((tensorT->nodeType == NodeType_ValueNode) && (scTensor->data() != nullptr) && (scTensor->data()->size() > 0)) { 9121+ if (((tensorT->nodeType == NodeType_ValueNode) && (scTensor->data() != nullptr) && (scTensor->data()->size() > 0)) || 9122+ const_fold) { 9123 if (NeedQuantization(tensor, tensor_quant_type)) { 9124 auto ret = QuantTensorData(tensorT.get(), tensor, preferred_dim); 9125 if (ret != RET_OK) { 9126@@ -392,6 +400,7 @@ int TrainExport::KeepGraphInputsInOrder(const Model *model) { 9127 return RET_OK; 9128 } 9129 int TrainExport::ExportTensor(const Model *model, const std::vector<mindspore::lite::Tensor *> &tensors, int offset, 9130+ const std::vector<mindspore::lite::Tensor *> const_folded_output, 9131 const std::vector<std::pair<size_t, tensor_info>> &map_index, 9132 const std::vector<std::string> &output_names, const std::set<size_t> &out_set) { 9133 std::vector<mindspore::lite::Tensor *> in_tensors; 9134@@ -401,6 +410,7 @@ int TrainExport::ExportTensor(const Model *model, const std::vector<mindspore::l 9135 mindspore::lite::Tensor *tensor = tensors.at(pid); 9136 in_tensors.push_back(tensor); 9137 } 9138+ std::map<std::string, uint32_t> ordered_output_names; 9139 for (auto index : map_index) { 9140 auto id = index.first; 9141 size_t pid = id - static_cast<size_t>(offset); 9142@@ -408,7 +418,8 @@ int 
TrainExport::ExportTensor(const Model *model, const std::vector<mindspore::l 9143 schema::Tensor *scTensor = model->graph_.all_tensors_.at(pid); 9144 auto preferred_dim = WeightDecoder::GetPreferredDim(in_tensors, index.second.op_parameter, index.second.input_index, 9145 tensor->shape(), model->graph_.version_); 9146- auto tensorT = CreateTensor(tensor, scTensor, preferred_dim, index.second.op_parameter->quant_type_); 9147+ auto tensorT = 9148+ CreateTensor(tensor, const_folded_output, scTensor, preferred_dim, index.second.op_parameter->quant_type_); 9149 if (tensorT == nullptr) { 9150 MS_LOG(ERROR) << "error in tensor creation"; 9151 return RET_ERROR; 9152@@ -423,21 +434,27 @@ int TrainExport::ExportTensor(const Model *model, const std::vector<mindspore::l 9153 } 9154 // find output tensor 9155 if (std::find(output_names.begin(), output_names.end(), tensor->tensor_name()) != output_names.end()) { 9156- meta_graph_->outputIndex.push_back(remap_[id]); 9157- if (!meta_graph_->subGraph.empty()) { 9158- meta_graph_->subGraph[0]->outputIndices.push_back(remap_[id]); 9159- } 9160+ ordered_output_names[tensor->tensor_name()] = remap_[id]; 9161 } 9162 meta_graph_->allTensors.emplace_back(std::move(tensorT)); 9163 if (!meta_graph_->subGraph.empty()) { 9164 meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1); 9165 } 9166 } 9167+ for (auto &output_name : output_names) { 9168+ if (ordered_output_names.find(output_name) != ordered_output_names.end()) { 9169+ meta_graph_->outputIndex.push_back(ordered_output_names[output_name]); 9170+ if (!meta_graph_->subGraph.empty()) { 9171+ meta_graph_->subGraph[0]->outputIndices.push_back(ordered_output_names[output_name]); 9172+ } 9173+ } 9174+ } 9175 return RET_OK; 9176 } 9177 9178 int TrainExport::ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels, 9179 const std::vector<mindspore::lite::Tensor *> &tensors, 9180+ const std::vector<mindspore::lite::Tensor *> const_folded_output, 9181 
const std::vector<std::string> &output_names, const Model *model, 9182 QuantizationType quant_type, const Model *bb_model) { 9183 std::vector<std::pair<size_t, tensor_info>> map_index; 9184@@ -498,7 +515,7 @@ int TrainExport::ExportNet(const std::vector<mindspore::kernel::KernelExec *> &k 9185 } 9186 } 9187 9188- auto status = ExportTensor(model, tensors, offset, map_index, output_names, out_set); 9189+ auto status = ExportTensor(model, tensors, offset, const_folded_output, map_index, output_names, out_set); 9190 if (status != RET_OK) { 9191 MS_LOG(ERROR) << "ExportTensor failed."; 9192 return RET_ERROR; 9193diff --git a/mindspore/lite/src/train/train_export.h b/mindspore/lite/src/train/train_export.h 9194index b44f6526..8428c9b9 100644 9195--- a/mindspore/lite/src/train/train_export.h 9196+++ b/mindspore/lite/src/train/train_export.h 9197@@ -47,8 +47,10 @@ class TrainExport { 9198 explicit TrainExport(Buffer *model_buffer) : model_buffer_(model_buffer) {} 9199 virtual ~TrainExport(); 9200 int ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels, 9201- const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names, 9202- const Model *model, QuantizationType quant_type, const Model *bb_model = nullptr); 9203+ const std::vector<mindspore::lite::Tensor *> &tensors, 9204+ const std::vector<mindspore::lite::Tensor *> const_folded_output, 9205+ const std::vector<std::string> &output_names, const Model *model, QuantizationType quant_type, 9206+ const Model *bb_model = nullptr); 9207 int ExportInit(const std::string model_name, std::string version); 9208 int SaveToFile(); 9209 int SaveToBuffer(); 9210@@ -75,7 +77,9 @@ class TrainExport { 9211 int TopologicalSort(); 9212 void PrepareRemap(int offset); 9213 LiteGraph::Node *FindNode(const mindspore::kernel::KernelExec *kernel, const Model *model); 9214- std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, schema::Tensor *scTensor, int preferred_dim, 9215+ 
std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, 9216+ const std::vector<mindspore::lite::Tensor *> const_folded_output, 9217+ schema::Tensor *scTensor, int preferred_dim, 9218 const int tensor_quant_type); 9219 std::unique_ptr<schema::CNodeT> CreateCNode(const mindspore::kernel::KernelExec *kernel, 9220 std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex, 9221@@ -93,6 +97,7 @@ class TrainExport { 9222 size_t *target_index); 9223 int KeepGraphInputsInOrder(const Model *model); 9224 int ExportTensor(const Model *model, const std::vector<mindspore::lite::Tensor *> &tensors, int offset, 9225+ const std::vector<mindspore::lite::Tensor *> const_folded_output, 9226 const std::vector<std::pair<size_t, tensor_info>> &map_index, 9227 const std::vector<std::string> &output_names, const std::set<size_t> &out_set); 9228 virtual int QuantTensorData(schema::TensorT *dest_tensor, const mindspore::lite::Tensor *src_tensor, 9229diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc 9230index b581b389..c123cba8 100644 9231--- a/mindspore/lite/src/train/train_session.cc 9232+++ b/mindspore/lite/src/train/train_session.cc 9233@@ -399,6 +399,8 @@ int TrainSession::CompileTrainGraph(std::shared_ptr<Model> model) { 9234 MS_LOG(ERROR) << "failed to allocate space"; 9235 return RET_ERROR; 9236 } 9237+ // Prepare a list of kernels which are const folded 9238+ MS_CHECK_TRUE_MSG(CompileConstFoldedKernels() == RET_OK, RET_ERROR, "CompileConstFoldedKernels failed."); 9239 return RET_OK; 9240 } 9241 9242@@ -697,20 +699,30 @@ void TrainSession::CompileEvalOutputs() { 9243 } 9244 if (is_loss) continue; 9245 // insert if not already in 9246- if (eval_output_node_map_.find(in_kernel->name()) == eval_output_node_map_.end()) { 9247- auto *ms_tensor = in_kernel->out_tensors().at(0); 9248- if (ms_tensor != nullptr) { 9249- ms_tensor->set_init_ref_count(ms_tensor->init_ref_count() + 1); 9250- 
eval_output_node_map_[in_kernel->name()].emplace_back(ms_tensor); 9251- auto index = TSFindTensor(tensors_, ms_tensor); 9252- if (index != tensors_.size()) { 9253- if (!ms_tensor->tensor_name().empty()) { 9254- eval_output_tensor_map_.insert(std::make_pair(ms_tensor->tensor_name(), ms_tensor)); 9255- eval_output_tensor_names_.emplace_back(ms_tensor->tensor_name()); 9256- } else { 9257- eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor)); 9258- eval_output_tensor_names_.emplace_back(std::to_string(index)); 9259- } 9260+ auto out_tensors = TSFindTensors(in_kernel, kernel); 9261+ if (eval_output_node_map_.find(in_kernel->name()) != eval_output_node_map_.end()) { 9262+ auto exist_out_tensors = eval_output_node_map_[in_kernel->name()]; 9263+ std::vector<Tensor *> all_out_tensors; 9264+ auto kernel_all_out_tensors = in_kernel->out_tensors(); 9265+ eval_output_node_map_[in_kernel->name()] = {}; 9266+ for (auto tensor : kernel_all_out_tensors) { 9267+ if (std::find(out_tensors.begin(), out_tensors.end(), tensor) != out_tensors.end() || 9268+ std::find(exist_out_tensors.begin(), exist_out_tensors.end(), tensor) != exist_out_tensors.end()) { 9269+ eval_output_node_map_[in_kernel->name()].emplace_back(tensor); 9270+ } 9271+ } 9272+ } else { 9273+ eval_output_node_map_[in_kernel->name()] = out_tensors; 9274+ } 9275+ for (auto out_tensor : out_tensors) { 9276+ auto index = TSFindTensor(tensors_, out_tensor); 9277+ if (index != tensors_.size()) { 9278+ if (!out_tensor->tensor_name().empty()) { 9279+ eval_output_tensor_map_.insert(std::make_pair(out_tensor->tensor_name(), out_tensor)); 9280+ eval_output_tensor_names_.emplace_back(out_tensor->tensor_name()); 9281+ } else { 9282+ eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), out_tensor)); 9283+ eval_output_tensor_names_.emplace_back(std::to_string(index)); 9284 } 9285 } 9286 } 9287@@ -863,6 +875,35 @@ void TrainSession::CompileOptimizedKernels() { 9288 } 9289 } 9290 9291+int 
TrainSession::CompileConstFoldedKernels() { 9292+ const_output_tensors_.clear(); 9293+ for (auto kernel : this->inference_kernels_) { 9294+ bool is_input_const = true; 9295+ for (auto input : kernel->in_tensors()) { 9296+ if ((!input->IsConst() || input->IsGraphInput()) && 9297+ std::find(const_output_tensors_.begin(), const_output_tensors_.end(), input) == const_output_tensors_.end()) { 9298+ is_input_const = false; 9299+ } 9300+ if (!is_input_const) { 9301+ const_fold_kernels_.emplace_back(kernel); 9302+ break; 9303+ } 9304+ } 9305+ if (is_input_const) { 9306+ auto ret = kernel->Execute(); 9307+ if (RET_OK != ret) { 9308+ MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); 9309+ return ret; 9310+ } 9311+ for (auto output : kernel->out_tensors()) { 9312+ const_output_tensors_.emplace_back(output); 9313+ output->set_category(Category::CONST_TENSOR); 9314+ } 9315+ } 9316+ } 9317+ return RET_OK; 9318+} 9319+ 9320 void TrainSession::CompileTrainableParams() { 9321 for (auto kernel : this->train_kernels_) { 9322 if (!IsOptimizer(kernel)) { 9323@@ -1214,9 +1255,10 @@ int TrainSession::ExportByDifferentType(DestType destination, ModelType model_ty 9324 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export"); 9325 if (!output_tensor_name.empty() && model_type == MT_INFERENCE) { 9326 std::vector<kernel::KernelExec *> export_kernels = {}; 9327- status = FindExportKernels(&export_kernels, output_tensor_name, inference_kernels_); 9328+ status = FindExportKernels(&export_kernels, output_tensor_name, const_fold_kernels_); 9329 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed."); 9330- status = texport.ExportNet(export_kernels, tensors_, output_tensor_name, model_.get(), quant_type); 9331+ status = 9332+ texport.ExportNet(export_kernels, tensors_, const_output_tensors_, output_tensor_name, model_.get(), quant_type); 9333 } else { 9334 if (!output_tensor_name.empty() && model_type == MT_TRAIN) { 9335 
MS_LOG(WARNING) << "Train model does not support to export selected output tensor, and all of the train kernels " 9336@@ -1234,9 +1276,15 @@ int TrainSession::ExportByDifferentType(DestType destination, ModelType model_ty 9337 } 9338 return status; 9339 } else { 9340- status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_, 9341- (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_, 9342- model_.get(), quant_type); 9343+ if (quant_type == QT_NONE) { 9344+ status = texport.ExportNet( 9345+ (model_type == MT_TRAIN) ? train_kernels_ : const_fold_kernels_, tensors_, const_output_tensors_, 9346+ (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_, model_.get(), quant_type); 9347+ } else { 9348+ status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_, {}, 9349+ (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_, 9350+ model_.get(), quant_type); 9351+ } 9352 } 9353 } 9354 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network."); 9355@@ -1322,14 +1370,13 @@ int TrainSession::ExportWeightsCollaborateWithMicro(const std::string &file_name 9356 MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "File name cannot be empty"); 9357 MS_CHECK_FALSE_MSG(model_type != mindspore::lite::MT_INFERENCE, RET_ERROR, 9358 "Currently, can only export inference-model's weights."); 9359- int status = Eval(); 9360- TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Eval failed"); 9361 9362 TrainExport texport(file_name); 9363- status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_); 9364+ auto status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_); 9365 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export"); 9366 9367- status = texport.ExportNet(inference_kernels_, tensors_, 
eval_output_tensor_names_, model_.get(), QT_DEFAULT); 9368+ status = texport.ExportNet(const_fold_kernels_, tensors_, const_output_tensors_, eval_output_tensor_names_, 9369+ model_.get(), QT_NONE); 9370 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network."); 9371 status = texport.TrainModelDrop(); 9372 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed."); 9373diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h 9374index 24f10065..0bd14b21 100644 9375--- a/mindspore/lite/src/train/train_session.h 9376+++ b/mindspore/lite/src/train/train_session.h 9377@@ -128,6 +128,7 @@ class TrainSession : virtual public lite::LiteSession { 9378 virtual int CompileInferenceKernels(); 9379 virtual void CompileOptimizedKernels(); 9380 virtual void CompileTrainableParams(); 9381+ virtual int CompileConstFoldedKernels(); 9382 virtual void CompileTrainOutputs(); 9383 virtual void CompileEvalOutputs(); 9384 virtual int InitCallBack(); 9385@@ -146,6 +147,8 @@ class TrainSession : virtual public lite::LiteSession { 9386 9387 std::vector<kernel::KernelExec *> inference_kernels_; 9388 std::vector<kernel::KernelExec *> train_kernels_; 9389+ std::vector<kernel::KernelExec *> const_fold_kernels_; 9390+ std::vector<lite::Tensor *> const_output_tensors_; 9391 TrainCfg cfg_; 9392 9393 private: 9394diff --git a/mindspore/lite/src/train/train_utils.cc b/mindspore/lite/src/train/train_utils.cc 9395index 32c4a502..cb7b669a 100644 9396--- a/mindspore/lite/src/train/train_utils.cc 9397+++ b/mindspore/lite/src/train/train_utils.cc 9398@@ -204,5 +204,20 @@ int ScaleTensor(Tensor *tensor, float scale) { 9399 MS_LOG(DEBUG) << "Scale tensor: " << tensor->tensor_name() << " " << scale; 9400 return tensor->Scale<float>(scale); 9401 } 9402+ 9403+std::vector<Tensor *> TSFindTensors(const kernel::KernelExec *pre_kernel, const kernel::KernelExec *post_kernel) { 9404+ MS_ASSERT(pre_kernel != nullptr); 9405+ 
MS_ASSERT(post_kernel != nullptr); 9406+ auto out_tensors = pre_kernel->out_tensors(); 9407+ auto in_tensors = post_kernel->in_tensors(); 9408+ std::vector<Tensor *> res; 9409+ for (auto tensor : out_tensors) { 9410+ if (std::find(in_tensors.begin(), in_tensors.end(), tensor) == in_tensors.end()) { 9411+ continue; 9412+ } 9413+ res.push_back(tensor); 9414+ } 9415+ return res; 9416+} 9417 } // namespace lite 9418 } // namespace mindspore 9419diff --git a/mindspore/lite/src/train/train_utils.h b/mindspore/lite/src/train/train_utils.h 9420index 5c85738f..9b2d62dc 100644 9421--- a/mindspore/lite/src/train/train_utils.h 9422+++ b/mindspore/lite/src/train/train_utils.h 9423@@ -36,6 +36,7 @@ float CalculateSparseClassification(lite::Tensor *input, lite::Tensor *output); 9424 float CalculateOneHotClassification(lite::Tensor *input, lite::Tensor *output); 9425 Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type, bool support_fp16); 9426 int ScaleTensor(Tensor *tensor, float scale); 9427+std::vector<Tensor *> TSFindTensors(const kernel::KernelExec *pre_kernel, const kernel::KernelExec *post_kernel); 9428 } // namespace lite 9429 } // namespace mindspore 9430 #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_UTILS_H_ 9431diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc 9432index 48191b4f..b1cb7b3e 100644 9433--- a/mindspore/lite/src/train/transfer_session.cc 9434+++ b/mindspore/lite/src/train/transfer_session.cc 9435@@ -230,10 +230,10 @@ int TransferSession::ExportInner(DestType destination, ModelType model_type, Qua 9436 MS_LOG(ERROR) << "FindExportKernels failed."; 9437 return RET_ERROR; 9438 } 9439- status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type, 9440+ status = texport.ExportNet(export_kernels, tensors_, {}, out_put_tensor_name, model_.get(), quant_type, 9441 backbone_session_->model_); 9442 } else { 9443- status = texport.ExportNet(inference_kernels_, tensors_, 
GetOutputTensorNames(), model_.get(), quant_type, 9444+ status = texport.ExportNet(inference_kernels_, tensors_, {}, GetOutputTensorNames(), model_.get(), quant_type, 9445 backbone_session_->model_); 9446 } 9447 if (status != RET_OK) { 9448diff --git a/mindspore/lite/tools/common/string_util.cc b/mindspore/lite/tools/common/string_util.cc 9449index 8d7076e5..13cddb3a 100644 9450--- a/mindspore/lite/tools/common/string_util.cc 9451+++ b/mindspore/lite/tools/common/string_util.cc 9452@@ -199,5 +199,9 @@ size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size 9453 } 9454 return byte_len; 9455 } 9456+ 9457+bool IsNumber(const std::string &item) { 9458+ return std::all_of(item.begin(), item.end(), [](char ch) { return ch >= '0' && ch <= '9'; }); 9459+} 9460 } // namespace lite 9461 } // namespace mindspore 9462diff --git a/mindspore/lite/tools/common/string_util.h b/mindspore/lite/tools/common/string_util.h 9463index 0fb9c0b2..95bdd742 100644 9464--- a/mindspore/lite/tools/common/string_util.h 9465+++ b/mindspore/lite/tools/common/string_util.h 9466@@ -45,6 +45,8 @@ bool ConvertBool(std::string str, bool *value); 9467 bool ConvertDoubleVector(const std::string &str, std::vector<double> *value); 9468 9469 size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len); 9470+ 9471+bool IsNumber(const std::string &item); 9472 } // namespace lite 9473 } // namespace mindspore 9474 #endif // MINDSPORE_LITE_TOOLS_COMMON_STRING_UTIL_H_ 9475diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc 9476index c4f84163..b63912fa 100644 9477--- a/mindspore/lite/tools/converter/anf_transform.cc 9478+++ b/mindspore/lite/tools/converter/anf_transform.cc 9479@@ -135,6 +135,7 @@ 9480 #include "tools/common/string_util.h" 9481 #include "src/common/common.h" 9482 #include "tools/optimizer/graph/miniaturization_pass.h" 9483+#include "tools/optimizer/fusion/tile_matmul_fusion.h" 9484 
9485 using std::string; 9486 namespace mindspore::lite { 9487@@ -317,7 +318,8 @@ std::vector<opt::PassPtr> InitFusions(const std::shared_ptr<ConverterPara> ¶ 9488 std::make_shared<opt::MulActivationFusion>(), 9489 std::make_shared<opt::AddActivationFusion>(), 9490 std::make_shared<opt::ExpandDimsReshapeFusion>(), 9491- std::make_shared<opt::SqueezeExpandDimsFusion>()}; 9492+ std::make_shared<opt::SqueezeExpandDimsFusion>(), 9493+ std::make_shared<opt::TileMatMulFusion>()}; 9494 if (param->optimize_transformer) { 9495 fusions.push_back(std::make_shared<opt::MultiHeadAttentionFusion>()); 9496 fusions.push_back(std::make_shared<opt::EncoderLayerFusion>()); 9497diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 9498index 2e7ca749..7b47fb8c 100644 9499--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 9500+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 9501@@ -19,10 +19,10 @@ 9502 #include "include/errorcode.h" 9503 #include "src/common/log_adapter.h" 9504 #include "tools/converter/converter_context.h" 9505- 9506 #include "tools/common/string_util.h" 9507 #include "src/common/config_infos.h" 9508 #include "src/common/common.h" 9509+#include "nnacl/op_base.h" 9510 9511 namespace mindspore { 9512 namespace lite { 9513@@ -208,6 +208,75 @@ void SetDynParams(const std::shared_ptr<mindspore::ConverterPara> ¶m, 9514 } 9515 } 9516 9517+int ParseInputShapeTemplate(const std::string &shape_template, std::set<std::string> *dynamic_symbols) { 9518+ // the inputs_shape config is like: input1:[d0,d1,3];input2:[4,d0] 9519+ auto graph_inputs_shape_vec = SplitStringToVector(shape_template, ';'); 9520+ for (const auto &graph_input_shape : graph_inputs_shape_vec) { 9521+ auto graph_input_shape_info = SplitStringToVector(graph_input_shape, ':'); 9522+ MS_CHECK_TRUE_MSG(graph_input_shape_info.size() == kIndex2, RET_INPUT_PARAM_INVALID, "the 
inputs_shape is invalid"); 9523+ auto input_shape = graph_input_shape_info[1]; 9524+ if (input_shape[0] != '[' || input_shape[input_shape.size() - 1] != ']') { 9525+ MS_LOG(ERROR) << "the inputs_shape is invalid"; 9526+ return RET_INPUT_PARAM_INVALID; 9527+ } 9528+ input_shape = input_shape.substr(1, input_shape.size() - kIndex2); 9529+ auto input_shape_vec = SplitStringToVector(input_shape, ','); 9530+ for (const auto &shape : input_shape_vec) { 9531+ if (!IsNumber(shape)) { 9532+ dynamic_symbols->insert(shape); 9533+ } 9534+ } 9535+ } 9536+ return RET_OK; 9537+} 9538+ 9539+int ParseDynmiacDimTemplate(const std::string &dims_template, std::set<std::string> *dynamic_symbols, 9540+ MicroParamString *micro_param_string) { 9541+ // the dynamic_dim_params config is like: d0:[1,3~6];d1:[1~8] 9542+ auto dim_info_vec = SplitStringToVector(dims_template, ';'); 9543+ MS_CHECK_TRUE_MSG(dim_info_vec.size() <= kIndex2, RET_NOT_SUPPORT, "currently, only support to set two dynamic dims"); 9544+ for (const auto &dim_info : dim_info_vec) { 9545+ auto dim_vec = SplitStringToVector(dim_info, ':'); 9546+ MS_CHECK_TRUE_MSG(dim_vec.size() == kIndex2, RET_INPUT_PARAM_INVALID, "the dynamic_dim_params is invalid"); 9547+ std::string symbol = dim_vec[0]; 9548+ if (dynamic_symbols->find(symbol) == dynamic_symbols->end()) { 9549+ MS_LOG(ERROR) << symbol << "is invalid, because it's not set in the inputs_shape."; 9550+ return RET_INPUT_PARAM_INVALID; 9551+ } 9552+ std::string dim_range = dim_vec[1]; 9553+ if (dim_range[0] != '[' || dim_range[dim_range.size() - 1] != ']') { 9554+ MS_LOG(ERROR) << "the dynamic_dim_params is invalid"; 9555+ return RET_INPUT_PARAM_INVALID; 9556+ } 9557+ dim_range = dim_range.substr(1, dim_range.size() - kIndex2); 9558+ auto discrete_vec = SplitStringToVector(dim_range, ','); 9559+ for (const auto &dim : discrete_vec) { 9560+ auto continuous_dim = SplitStringToVector(dim, '~'); 9561+ MS_CHECK_TRUE_MSG(continuous_dim.size() == C1NUM || continuous_dim.size() == 
kIndex2, RET_INPUT_PARAM_INVALID, 9562+ "the dynamic_dim_params is invalid"); 9563+ if (continuous_dim.size() == C1NUM) { 9564+ if (!IsNumber(continuous_dim[0]) || std::stoi(continuous_dim[0]) <= 0) { 9565+ MS_LOG(ERROR) << "the dynamic_dim_params range value must be greater than 0"; 9566+ return RET_INPUT_PARAM_INVALID; 9567+ } 9568+ micro_param_string->dynamic_symbols_map[symbol] += continuous_dim[0] + ","; 9569+ continue; 9570+ } 9571+ if (!IsNumber(continuous_dim[0]) || std::stoi(continuous_dim[0]) <= 0 || !IsNumber(continuous_dim[1]) || 9572+ std::stoi(continuous_dim[1]) <= 0) { 9573+ MS_LOG(ERROR) << "the dynamic_dim_params range value must be greater than 0"; 9574+ return RET_INPUT_PARAM_INVALID; 9575+ } 9576+ auto start = std::stoi(continuous_dim[0]); 9577+ auto end = std::stoi(continuous_dim[1]); 9578+ for (auto i = start; i <= end; ++i) { 9579+ micro_param_string->dynamic_symbols_map[symbol] += std::to_string(i) + ","; 9580+ } 9581+ } 9582+ } 9583+ return RET_OK; 9584+} 9585+ 9586 void ConfigFileParser::SetParamByConfigfile(const std::shared_ptr<mindspore::ConverterPara> ¶m, 9587 const std::map<std::string, std::string> &ascend_map) { 9588 std::string ascend_string = ""; 9589@@ -377,8 +446,12 @@ int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::strin 9590 } 9591 9592 int ConfigFileParser::SetMapData(const std::map<std::string, std::string> &input_map, 9593- const std::map<std::string, std::string &> &parse_map, const std::string §ion) { 9594+ const std::map<std::string, std::string &> &parse_map, const std::string §ion, 9595+ const std::set<std::string> &dynamic_key) { 9596 for (const auto &map : input_map) { 9597+ if (dynamic_key.find(map.first) != dynamic_key.end()) { 9598+ continue; 9599+ } 9600 if (parse_map.find(map.first) == parse_map.end()) { 9601 MS_LOG(ERROR) << "INPUT ILLEGAL: `" << map.first << "` is not supported in " 9602 << "[" << section << "]"; 9603@@ -511,21 +584,34 @@ int 
ConfigFileParser::ParseAclOptionCfgString(const std::map<std::string, std::m 9604 } 9605 9606 int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps) { 9607- if (maps.find(kMicroParam) != maps.end()) { 9608- const auto &map = maps.at(kMicroParam); 9609- std::map<std::string, std::string &> parse_map{ 9610- {"target", micro_param_string_.target}, 9611- {"codegen_mode", micro_param_string_.codegen_mode}, 9612- {"debug_mode", micro_param_string_.debug_mode}, 9613- {"support_parallel", micro_param_string_.support_parallel}, 9614- {"enable_micro", micro_param_string_.enable_micro}, 9615- {"save_path", micro_param_string_.save_path}, 9616- {"project_name", micro_param_string_.project_name}, 9617- {"keep_original_weight", micro_param_string_.keep_original_weight}, 9618- {"changeable_weights_name", micro_param_string_.changeable_weights_name}}; 9619- return SetMapData(map, parse_map, kMicroParam); 9620+ if (maps.find(kMicroParam) == maps.end()) { 9621+ return RET_OK; 9622 } 9623- return RET_OK; 9624+ const auto &map = maps.at(kMicroParam); 9625+ const std::string graph_inputs_shape_template = "inputs_shape"; 9626+ std::set<std::string> dynamic_symbols; 9627+ if (map.find(graph_inputs_shape_template) != map.end()) { 9628+ const auto &shape_template = map.at(graph_inputs_shape_template); 9629+ ParseInputShapeTemplate(shape_template, &dynamic_symbols); 9630+ } 9631+ const std::string dynamic_dims = "dynamic_dim_params"; 9632+ if (!dynamic_symbols.empty() && map.find(dynamic_dims) != map.end()) { 9633+ const auto &dims_template = map.at(dynamic_dims); 9634+ ParseDynmiacDimTemplate(dims_template, &dynamic_symbols, µ_param_string_); 9635+ } 9636+ std::map<std::string, std::string &> parse_map{ 9637+ {"target", micro_param_string_.target}, 9638+ {"codegen_mode", micro_param_string_.codegen_mode}, 9639+ {"debug_mode", micro_param_string_.debug_mode}, 9640+ {"support_parallel", micro_param_string_.support_parallel}, 9641+ 
{"enable_micro", micro_param_string_.enable_micro}, 9642+ {"save_path", micro_param_string_.save_path}, 9643+ {"project_name", micro_param_string_.project_name}, 9644+ {"keep_original_weight", micro_param_string_.keep_original_weight}, 9645+ {"changeable_weights_name", micro_param_string_.changeable_weights_name}, 9646+ {"inputs_shape", micro_param_string_.inputs_shape}, 9647+ {"dynamic_dim_params", micro_param_string_.dynamic_dim_params}}; 9648+ return SetMapData(map, parse_map, kMicroParam); 9649 } 9650 9651 int ConfigFileParser::ParseWeightQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) { 9652diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h 9653index 6997bac8..163782b7 100644 9654--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h 9655+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h 9656@@ -108,17 +108,20 @@ struct MicroParamString { 9657 std::string project_name; 9658 std::string keep_original_weight; 9659 std::string changeable_weights_name; 9660+ std::string inputs_shape; 9661+ std::string dynamic_dim_params; 9662+ std::map<std::string, std::string> dynamic_symbols_map; 9663 }; 9664 9665 struct ThirdPartyModelString { 9666 std::string input_dtypes; 9667 std::string input_shapes; 9668- std::string input_names; // optional, default: "" 9669+ std::string input_names; // optional, default: "" 9670 std::string input_formats; // optional, default: NHWC 9671 std::string output_dtypes; 9672 std::string output_shapes; 9673- std::string output_names; // optional, default: "" 9674- std::string output_formats; // optional, default: NHWC 9675+ std::string output_names; // optional, default: "" 9676+ std::string output_formats; // optional, default: NHWC 9677 std::string extended_parameters; // format: {key1:value1;ker2:value2} 9678 }; 9679 9680@@ -172,7 +175,8 @@ class ConfigFileParser { 9681 int 
ParseRegistryInfoString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9682 int ParseAclOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9683 int SetMapData(const std::map<std::string, std::string> &input_map, 9684- const std::map<std::string, std::string &> &parse_map, const std::string §ion); 9685+ const std::map<std::string, std::string &> &parse_map, const std::string §ion, 9686+ const std::set<std::string> &dynamic_key = {}); 9687 int ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9688 int ParseThirdPartyParamString(const std::map<std::string, std::map<std::string, std::string>> §ions); 9689 int ParseCpuOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9690diff --git a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc 9691index c9998cc8..903f2863 100644 9692--- a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc 9693+++ b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc 9694@@ -19,6 +19,7 @@ 9695 #include "tools/common/string_util.h" 9696 #include "src/common/log_adapter.h" 9697 #include "src/common/log_util.h" 9698+#include "nnacl/op_base.h" 9699 9700 namespace mindspore { 9701 namespace lite { 9702@@ -115,6 +116,80 @@ STATUS MicroParamParser::ParseChangeableWeightsName(const std::string &changeabl 9703 return RET_OK; 9704 } 9705 9706+STATUS MicroParamParser::ParseGraphInputsShapeTemplate(const std::string &graph_inputs_shape_template, 9707+ const std::map<std::string, std::string> &dynamic_symbols_map, 9708+ micro::MicroParam *micro_param) { 9709+ MS_LOG(DEBUG) << "Micro record inputs shape: " << graph_inputs_shape_template; 9710+ if (!graph_inputs_shape_template.empty()) { 9711+ auto graph_inputs_shape_vec = SplitStringToVector(graph_inputs_shape_template, ';'); 9712+ std::map<std::string, 
std::vector<std::string>> graph_inputs_info; 9713+ std::vector<std::vector<std::string>> graph_inputs_shape; 9714+ std::vector<std::string> inputs_name; 9715+ for (const auto &graph_input_shape : graph_inputs_shape_vec) { 9716+ auto input_shape_info = SplitStringToVector(graph_input_shape, ':'); 9717+ std::string input_name = input_shape_info[0]; 9718+ std::string input_shape = input_shape_info[1].substr(1, input_shape_info[1].size() - C2NUM); 9719+ auto input_shape_vec = SplitStringToVector(input_shape, ','); 9720+ graph_inputs_info[input_name] = input_shape_vec; 9721+ graph_inputs_shape.push_back(input_shape_vec); 9722+ inputs_name.push_back(input_name); 9723+ } 9724+ micro_param->graph_inputs_origin_info = graph_inputs_info; 9725+ micro_param->inputs_shape_by_scenes.clear(); 9726+ std::map<std::string, std::vector<int>> symbols_to_num; 9727+ std::map<std::string, int> symbols_index; 9728+ std::vector<std::string> symbols; 9729+ std::vector<size_t> scene_num_by_symbol; 9730+ int index = 0; 9731+ size_t scene_num = 1; 9732+ for (const auto &item : dynamic_symbols_map) { 9733+ symbols_index[item.first] = index++; 9734+ symbols.push_back(item.first); 9735+ auto num_str_list = SplitStringToVector(item.second, ','); 9736+ for (const auto &num_str : num_str_list) { 9737+ symbols_to_num[item.first].push_back(std::stoi(num_str)); 9738+ } 9739+ if (symbols_to_num[item.first].empty()) { 9740+ MS_LOG(ERROR) << "Micro param invalid, dynamic symbol must have value."; 9741+ return RET_INPUT_PARAM_INVALID; 9742+ } 9743+ scene_num_by_symbol.push_back(symbols_to_num[item.first].size()); 9744+ scene_num *= symbols_to_num[item.first].size(); 9745+ } 9746+ micro_param->dynamic_symbols = symbols; 9747+ micro_param->dynamic_symbols_num = scene_num_by_symbol; 9748+ std::vector<size_t> post_multi(symbols.size(), 1); 9749+ for (int i = static_cast<int>(post_multi.size()) - 2; i >= 0; --i) { 9750+ post_multi[i] = post_multi[i + 1] * scene_num_by_symbol[i + 1]; 9751+ } 9752+ 
std::vector<int> real_num(symbols.size()); 9753+ for (size_t i = 0; i < scene_num; ++i) { 9754+ size_t remain = i; 9755+ for (size_t j = 0; j < symbols.size(); ++j) { 9756+ real_num[j] = remain / post_multi[j]; 9757+ remain %= post_multi[j]; 9758+ } 9759+ for (size_t j = 0; j < graph_inputs_shape.size(); ++j) { 9760+ const auto &input_template = graph_inputs_shape[j]; 9761+ std::vector<int> input_shape; 9762+ for (const auto &dim : input_template) { 9763+ if (IsNumber(dim)) { 9764+ input_shape.push_back(std::stoi(dim)); 9765+ continue; 9766+ } 9767+ if (symbols_index.find(dim) == symbols_index.end()) { 9768+ MS_LOG(ERROR) << "Dynamic symbol cannot find real num."; 9769+ return RET_INPUT_PARAM_INVALID; 9770+ } 9771+ input_shape.push_back(symbols_to_num[dim][real_num[symbols_index[dim]]]); 9772+ } 9773+ micro_param->inputs_shape_by_scenes[inputs_name[j]].push_back(input_shape); 9774+ } 9775+ } 9776+ } 9777+ return RET_OK; 9778+} 9779+ 9780 STATUS MicroParamParser::ParseMicroParam(const MicroParamString µ_param_string, micro::MicroParam *micro_param) { 9781 CHECK_NULL_RETURN(micro_param); 9782 if (ParseTarget(micro_param_string.target, micro_param) != RET_OK) { 9783@@ -145,9 +220,11 @@ STATUS MicroParamParser::ParseMicroParam(const MicroParamString µ_param_str 9784 MS_LOG(ERROR) << "Parse project name val failed: " << micro_param_string.project_name; 9785 return RET_INPUT_PARAM_INVALID; 9786 } 9787- if (ParseKeepOriginalWeight(micro_param_string.keep_original_weight, micro_param) != RET_OK) { 9788- MS_LOG(ERROR) << "Parse keep_original_weight failed, the val: " << micro_param_string.keep_original_weight; 9789- return RET_INPUT_PARAM_INVALID; 9790+ if (!micro_param_string.keep_original_weight.empty()) { 9791+ if (ParseKeepOriginalWeight(micro_param_string.keep_original_weight, micro_param) != RET_OK) { 9792+ MS_LOG(ERROR) << "Parse keep_original_weight val; " << micro_param_string.keep_original_weight; 9793+ return RET_INPUT_PARAM_INVALID; 9794+ } 9795 } 9796 if 
(!micro_param_string.changeable_weights_name.empty() && !micro_param->keep_original_weight) { 9797 MS_LOG(ERROR) << "When changeable_weights_name is set, the keep_original_weight must be true."; 9798@@ -157,6 +234,12 @@ STATUS MicroParamParser::ParseMicroParam(const MicroParamString µ_param_str 9799 MS_LOG(ERROR) << "Parse changeable_weights_name failed, the val: " << micro_param_string.changeable_weights_name; 9800 return RET_INPUT_PARAM_INVALID; 9801 } 9802+ if (ParseGraphInputsShapeTemplate(micro_param_string.inputs_shape, micro_param_string.dynamic_symbols_map, 9803+ micro_param) != RET_OK) { 9804+ MS_LOG(ERROR) << "Parse inputs_shape & dynamic_dim_params failed, the inputs_shape val: " 9805+ << micro_param_string.inputs_shape; 9806+ return RET_INPUT_PARAM_INVALID; 9807+ } 9808 return RET_OK; 9809 } 9810 } // namespace lite 9811diff --git a/mindspore/lite/tools/converter/config_parser/micro_param_parser.h b/mindspore/lite/tools/converter/config_parser/micro_param_parser.h 9812index b6efb4c7..eb95c571 100644 9813--- a/mindspore/lite/tools/converter/config_parser/micro_param_parser.h 9814+++ b/mindspore/lite/tools/converter/config_parser/micro_param_parser.h 9815@@ -37,6 +37,9 @@ class MicroParamParser { 9816 STATUS ParseProjName(const std::string &debug_mode, micro::MicroParam *micro_param); 9817 STATUS ParseKeepOriginalWeight(const std::string &keep_weight, micro::MicroParam *micro_param); 9818 STATUS ParseChangeableWeightsName(const std::string &changeable_weights_name, micro::MicroParam *micro_param); 9819+ STATUS ParseGraphInputsShapeTemplate(const std::string &graph_inputs_shape_template, 9820+ const std::map<std::string, std::string> &dynamic_symbols_map, 9821+ micro::MicroParam *micro_param); 9822 }; 9823 } // namespace lite 9824 } // namespace mindspore 9825diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc 9826index a61bd51c..4703e889 100644 9827--- a/mindspore/lite/tools/converter/converter.cc 9828+++ 
b/mindspore/lite/tools/converter/converter.cc 9829@@ -56,6 +56,7 @@ 9830 #include "src/common/file_utils.h" 9831 #include "ops/dynamic_shape.h" 9832 #include "tools/common/parse_config_utils.h" 9833+#include "src/common/file_utils.h" 9834 #include "tools/converter/converter_packed_node.h" 9835 #include "tools/converter/config_parser/cpu_option_param_parser.h" 9836 #include "tools/converter/export_model.h" 9837@@ -432,54 +433,34 @@ int ConverterImpl::InitConfigParam(const std::shared_ptr<ConverterPara> ¶m, 9838 MS_LOG(ERROR) << "Parse config param failed."; 9839 return ret; 9840 } 9841- ret = ParseParam(&config_parser, param, model_param_infos, maps); 9842- if (ret != RET_OK) { 9843- MS_LOG(ERROR) << "Parse param failed."; 9844- return ret; 9845- } 9846- return RET_OK; 9847-} 9848- 9849-int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std::shared_ptr<ConverterPara> ¶m, 9850- const std::map<int, std::map<std::string, std::string>> *model_param_infos, 9851- const std::map<std::string, std::map<std::string, std::string>> maps) { 9852- param->config_infos = maps; 9853- auto ret = RET_OK; 9854 if (model_param_infos->empty()) { 9855- ret = 9856- lite::PreprocessParser::ParsePreprocess(config_parser->GetDataPreProcessString(), ¶m->dataPreProcessParam); 9857+ ret = lite::PreprocessParser::ParsePreprocess(config_parser.GetDataPreProcessString(), ¶m->dataPreProcessParam); 9858 if (ret != RET_OK) { 9859 MS_LOG(ERROR) << "Parse preprocess failed."; 9860 return ret; 9861 } 9862- ret = lite::QuantParamParser::ParseCommonQuant(config_parser->GetCommonQuantString(), ¶m->commonQuantParam); 9863+ ret = lite::QuantParamParser::ParseCommonQuant(config_parser.GetCommonQuantString(), ¶m->commonQuantParam); 9864 if (ret != RET_OK) { 9865 MS_LOG(ERROR) << "Parse common quant param failed."; 9866 return ret; 9867 } 9868- ret = lite::QuantParamParser::ParseFullQuant(config_parser->GetFullQuantString(), ¶m->fullQuantParam); 9869+ ret = 
lite::QuantParamParser::ParseFullQuant(config_parser.GetFullQuantString(), ¶m->fullQuantParam); 9870 if (ret != RET_OK) { 9871 MS_LOG(ERROR) << "Parse full quant param failed."; 9872 return ret; 9873 } 9874- ret = lite::QuantParamParser::ParseWeightQuant(config_parser->GetWeightQuantString(), ¶m->weightQuantParam); 9875- if (ret != RET_OK) { 9876- MS_LOG(ERROR) << "Parse full quant param failed."; 9877- return ret; 9878- } 9879- ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_parser->GetMixedBitWeightQuantString(), 9880+ ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_parser.GetMixedBitWeightQuantString(), 9881 ¶m->mixedBitWeightQuantParam); 9882 if (ret != RET_OK) { 9883 MS_LOG(ERROR) << "Parse mixed bit weight quant param failed."; 9884 return ret; 9885 } 9886- ret = lite::ThirdPartyParamParser::Parse(config_parser->GetThirdPartyModelString(), 9887- ¶m->thirdPartyModelParam); 9888+ ret = lite::ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), ¶m->thirdPartyModelParam); 9889 if (ret != RET_OK) { 9890 MS_LOG(ERROR) << "Parse third party param failed."; 9891 return ret; 9892 } 9893- ret = InitExtendedIntegrationInfo(param, *config_parser); 9894+ ret = InitExtendedIntegrationInfo(param, config_parser); 9895 if (ret != RET_OK) { 9896 MS_LOG(ERROR) << "Parse extended integration info failed."; 9897 return ret; 9898@@ -490,7 +471,7 @@ int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std:: 9899 param->aclModelOptionCfgParam.dump_model_name = 9900 dir_pos != std::string::npos ? 
output_file.substr(dir_pos + 1) : output_file; 9901 lite::AclOptionParamParser acl_param_parser; 9902- ret = acl_param_parser.ParseAclOptionCfg(config_parser->GetAclOptionCfgString(), ¶m->aclModelOptionCfgParam); 9903+ ret = acl_param_parser.ParseAclOptionCfg(config_parser.GetAclOptionCfgString(), ¶m->aclModelOptionCfgParam); 9904 if (ret != RET_OK) { 9905 MS_LOG(ERROR) << "Parse acl option param failed."; 9906 return ret; 9907@@ -498,14 +479,14 @@ int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std:: 9908 // parse ascend_context in config file, the priority is higher 9909 if (maps.find("ascend_context") != maps.end()) { 9910 auto map = maps.at("ascend_context"); 9911- config_parser->SetParamByConfigfile(param, map); 9912+ config_parser.SetParamByConfigfile(param, map); 9913 } 9914 if (!param->config_file.empty()) { 9915 (void)CheckOfflineParallelConfig(param->config_file, ¶m->parallel_split_config); 9916 } 9917 9918 lite::CpuOptionParamParser cpu_param_parser; 9919- ret = cpu_param_parser.ParseCpuOptionCfg(config_parser->GetCpuOptionCfgString(), ¶m->cpuOptionCfgParam); 9920+ ret = cpu_param_parser.ParseCpuOptionCfg(config_parser.GetCpuOptionCfgString(), ¶m->cpuOptionCfgParam); 9921 if (ret != RET_OK) { 9922 MS_LOG(ERROR) << "Parse cpu option param failed."; 9923 return ret; 9924@@ -515,29 +496,29 @@ int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std:: 9925 << "If there are multi models, only support micro_param and model_param, other configure can not take effect"; 9926 9927 lite::MicroParamParser micro_param_parser; 9928- ret = micro_param_parser.ParseMicroParam(config_parser->GetMicroParamString(), ¶m->microParam); 9929+ ret = micro_param_parser.ParseMicroParam(config_parser.GetMicroParamString(), ¶m->microParam); 9930 if (ret != RET_OK) { 9931 MS_LOG(ERROR) << "Parse micro param failed."; 9932 return ret; 9933 } 9934 ret = 9935- 
lite::QuantParamParser::ParseTransformQuant(config_parser->GetTransformQuantString(), ¶m->transformQuantParam); 9936+ lite::QuantParamParser::ParseTransformQuant(config_parser.GetTransformQuantString(), ¶m->transformQuantParam); 9937 if (ret != RET_OK) { 9938 MS_LOG(ERROR) << "Parse transform quant param failed."; 9939 return ret; 9940 } 9941- ret = lite::QuantParamParser::ParseAscendQuant(config_parser->GetAscendQuantString(), ¶m->ascendQuantParam); 9942+ ret = lite::QuantParamParser::ParseAscendQuant(config_parser.GetAscendQuantString(), ¶m->ascendQuantParam); 9943 if (ret != RET_OK) { 9944 MS_LOG(ERROR) << "Parse ascend quant param failed."; 9945 return ret; 9946 } 9947- ret = lite::QuantParamParser::ParseDynamicQuant(config_parser->GetDynamicQuantString(), ¶m->dynamicQuantParam); 9948+ ret = lite::QuantParamParser::ParseDynamicQuant(config_parser.GetDynamicQuantString(), ¶m->dynamicQuantParam); 9949 if (ret != RET_OK) { 9950 MS_LOG(ERROR) << "Parse dynamic quant param failed."; 9951 return ret; 9952 } 9953 lite::GraphKernelParamParser graph_kernel_parser; 9954- ret = graph_kernel_parser.ParseGraphKernelCfg(config_parser->GetGraphKernelString(), ¶m->graphKernelParam); 9955+ ret = graph_kernel_parser.ParseGraphKernelCfg(config_parser.GetGraphKernelString(), ¶m->graphKernelParam); 9956 if (ret != RET_OK) { 9957 MS_LOG(ERROR) << "Parse graph kernel param failed."; 9958 return ret; 9959@@ -708,9 +689,9 @@ int CheckFmkType(const std::shared_ptr<ConverterPara> ¶m) { 9960 if (param != nullptr) { 9961 return RET_OK; 9962 } 9963- std::set kValidFmkTypes = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, 9964- FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch, 9965- FmkType::kFmkTypeMsLite, FmkType::kFmkTypeThirdParty}; 9966+ std::set kValidFmkTypes = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, 9967+ FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch, 9968+ FmkType::kFmkTypeMsLite, 
FmkType::kFmkTypeThirdParty}; 9969 if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) { 9970 MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be " 9971 "TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|MSLITE|THIRDPARTY" 9972@@ -1010,7 +991,6 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, void **m 9973 model_index++; 9974 } 9975 } 9976- 9977 return RET_OK; 9978 } 9979 9980@@ -1045,7 +1025,6 @@ int ConverterImpl::HandleGraphCommon(const std::shared_ptr<ConverterPara> ¶m 9981 MS_LOG(ERROR) << "Save graph failed: " << ret << " " << GetErrorInfo(ret); 9982 return ret; 9983 } 9984- 9985 return RET_OK; 9986 } 9987 9988@@ -1067,8 +1046,8 @@ int ConverterImpl::ExecuteMicro(const schema::MetaGraphT *meta_graph, const std: 9989 } 9990 auto status = 9991 meta_graph != nullptr 9992- ? micro::Coder::MicroSourceCodeGeneration(*meta_graph, output_path, param->microParam, param->weight_fp16) 9993- : micro::Coder::MicroSourceCodeGeneration(param->model_file, output_path, param->microParam, param->weight_fp16); 9994+ ? 
micro::Coder::MicroSourceCodeGeneration(*meta_graph, output_path, ¶m->microParam, param->weight_fp16) 9995+ : micro::Coder::MicroSourceCodeGeneration(param->model_file, output_path, ¶m->microParam, param->weight_fp16); 9996 if (status != RET_OK) { 9997 MS_LOG(ERROR) << "Execute Micro failed."; 9998 } 9999@@ -1123,7 +1102,6 @@ int ConverterImpl::SaveGraph(FuncGraphPtr graph, const std::shared_ptr<Converter 10000 MS_LOG(ERROR) << "Save failed:" << status << " " << GetErrorInfo(status); 10001 return status; 10002 } 10003- 10004 return RET_OK; 10005 } 10006 10007diff --git a/mindspore/lite/tools/converter/import/mindspore_importer.cc b/mindspore/lite/tools/converter/import/mindspore_importer.cc 10008index 1d5afde4..aee0c854 100644 10009--- a/mindspore/lite/tools/converter/import/mindspore_importer.cc 10010+++ b/mindspore/lite/tools/converter/import/mindspore_importer.cc 10011@@ -39,6 +39,7 @@ 10012 #include "tools/optimizer/graph/redundant_op_remove_pass.h" 10013 #include "nnacl/op_base.h" 10014 #include "src/common/common.h" 10015+#include "tools/converter/import/to_custom_op_pass.h" 10016 10017 namespace mindspore::lite { 10018 namespace { 10019@@ -89,6 +90,13 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, 10020 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); 10021 return RET_ERROR; 10022 } 10023+ auto to_custom_op_pass = std::make_shared<mindspore::opt::ToCustomOpPass>(); 10024+ MS_CHECK_TRUE_MSG(to_custom_op_pass != nullptr, RET_NULL_PTR, "to_custom_op_pass is nullptr."); 10025+ if (!to_custom_op_pass->Run(func_graph)) { 10026+ MS_LOG(ERROR) << "To custom op pass run failed!"; 10027+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); 10028+ return RET_ERROR; 10029+ } 10030 return RET_OK; 10031 } 10032 10033diff --git a/mindspore/lite/tools/converter/import/to_custom_op_pass.cc b/mindspore/lite/tools/converter/import/to_custom_op_pass.cc 10034new file mode 100644 10035index 00000000..55e524e6 10036--- 
/dev/null 10037+++ b/mindspore/lite/tools/converter/import/to_custom_op_pass.cc 10038@@ -0,0 +1,86 @@ 10039+/** 10040+ * Copyright 2023 Huawei Technologies Co., Ltd 10041+ * 10042+ * Licensed under the Apache License, Version 2.0 (the "License"); 10043+ * you may not use this file except in compliance with the License. 10044+ * You may obtain a copy of the License at 10045+ * 10046+ * http://www.apache.org/licenses/LICENSE-2.0 10047+ * 10048+ * Unless required by applicable law or agreed to in writing, software 10049+ * distributed under the License is distributed on an "AS IS" BASIS, 10050+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10051+ * See the License for the specific language governing permissions and 10052+ * limitations under the License. 10053+ */ 10054+ 10055+#include "tools/converter/import/to_custom_op_pass.h" 10056+#include "ops/grad/gather_d_grad_v2.h" 10057+#include "ops/masked_fill.h" 10058+#include "ops/custom.h" 10059+#include "ops/op_utils.h" 10060+#include "mindspore/ccsrc/include/common/utils/utils.h" 10061+#include "nnacl/custom_gather_d_grad_v2_parameter.h" 10062+ 10063+using mindspore::ops::kNameGatherDGradV2; 10064+using mindspore::ops::kNameMaskedFill; 10065+ 10066+namespace mindspore { 10067+namespace opt { 10068+bool ToCustomOpPass::Run(const FuncGraphPtr &graph) { 10069+ MS_ASSERT(graph != nullptr); 10070+ auto manager = graph->manager(); 10071+ MS_ASSERT(manager != nullptr); 10072+ auto node_list = TopoSort(graph->get_return()); 10073+ 10074+ for (auto &node : node_list) { 10075+ if (!utils::isa<CNodePtr>(node)) { 10076+ continue; 10077+ } 10078+ auto cnode = node->cast<CNodePtr>(); 10079+ MS_ASSERT(cnode != nullptr); 10080+ auto value_node = cnode->input(0); 10081+ auto prim = GetValueNode<PrimitivePtr>(value_node); 10082+ if (prim == nullptr) { 10083+ MS_LOG(DEBUG) << "this is a call cnode, which input[0] is fg."; 10084+ continue; 10085+ } 10086+ 10087+ auto func = 
ToCustomOpRegistry::GetInstance()->GetToCustomOpFunc(prim->name()); 10088+ if (func == nullptr) { 10089+ continue; 10090+ } 10091+ 10092+ auto ret = func(cnode); 10093+ if (ret != RET_OK) { 10094+ MS_LOG(ERROR) << "failed to convert normal cnode node to custom cnode"; 10095+ return false; 10096+ } 10097+ } 10098+ return true; 10099+} 10100+ 10101+int GatherDGradV2ToCustomOp(const CNodePtr &cnode) { 10102+ auto ori_prim = ops::GetOperator<ops::GatherDGradV2>(cnode->input(kAnfPrimitiveIndex)); 10103+ auto dim = ori_prim->get_dim(); 10104+ auto dim_str = std::to_string(dim); 10105+ std::map<std::string, std::vector<uint8_t>> attrs; 10106+ attrs["dim"] = std::vector<uint8_t>(dim_str.begin(), dim_str.end()); 10107+ auto custom_prim = std::make_shared<mindspore::ops::Custom>(); 10108+ custom_prim->set_type(kNameGatherDGradV2); 10109+ cnode->set_input(kAnfPrimitiveIndex, NewValueNode(custom_prim->GetPrim())); 10110+ custom_prim->set_attr(attrs); 10111+ return RET_OK; 10112+} 10113+ 10114+int MaskedFillToCustomOp(const CNodePtr &cnode) { 10115+ auto custom_prim = std::make_shared<mindspore::ops::Custom>(); 10116+ custom_prim->set_type(kNameMaskedFill); 10117+ cnode->set_input(kAnfPrimitiveIndex, NewValueNode(custom_prim->GetPrim())); 10118+ return RET_OK; 10119+} 10120+ 10121+REGISTER_TO_CUSTOM_OP(kNameGatherDGradV2, GatherDGradV2ToCustomOp); 10122+REGISTER_TO_CUSTOM_OP(kNameMaskedFill, MaskedFillToCustomOp); 10123+} // namespace opt 10124+} // namespace mindspore 10125diff --git a/mindspore/lite/tools/converter/import/to_custom_op_pass.h b/mindspore/lite/tools/converter/import/to_custom_op_pass.h 10126new file mode 100644 10127index 00000000..7108e48b 10128--- /dev/null 10129+++ b/mindspore/lite/tools/converter/import/to_custom_op_pass.h 10130@@ -0,0 +1,68 @@ 10131+/** 10132+ * Copyright 2023 Huawei Technologies Co., Ltd 10133+ * 10134+ * Licensed under the Apache License, Version 2.0 (the "License"); 10135+ * you may not use this file except in compliance with the 
License. 10136+ * You may obtain a copy of the License at 10137+ * 10138+ * http://www.apache.org/licenses/LICENSE-2.0 10139+ * 10140+ * Unless required by applicable law or agreed to in writing, software 10141+ * distributed under the License is distributed on an "AS IS" BASIS, 10142+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10143+ * See the License for the specific language governing permissions and 10144+ * limitations under the License. 10145+ */ 10146+ 10147+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_TO_CUSTOM_OP_PASS_H_ 10148+#define MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_TO_CUSTOM_OP_PASS_H_ 10149+#include <string> 10150+#include "backend/common/optimizer/pass.h" 10151+#include "tools/optimizer/common/gllo_utils.h" 10152+ 10153+namespace mindspore { 10154+namespace opt { 10155+ 10156+typedef int (*ToCustomOpFunc)(const CNodePtr &cnode); 10157+class ToCustomOpRegistry { 10158+ public: 10159+ static ToCustomOpRegistry *GetInstance() { 10160+ static ToCustomOpRegistry registry; 10161+ return ®istry; 10162+ } 10163+ 10164+ void InsertToCustomOpMap(const std::string &key, ToCustomOpFunc creator) { to_custom_op_funcs_[key] = creator; } 10165+ 10166+ ToCustomOpFunc GetToCustomOpFunc(const std::string &key) { 10167+ if (to_custom_op_funcs_.find(key) != to_custom_op_funcs_.end()) { 10168+ return to_custom_op_funcs_[key]; 10169+ } else { 10170+ MS_LOG(DEBUG) << "Unsupported primitive type : " << key; 10171+ return nullptr; 10172+ } 10173+ } 10174+ 10175+ protected: 10176+ std::map<std::string, ToCustomOpFunc> to_custom_op_funcs_; 10177+}; 10178+ 10179+class RegistryToCustomOp { 10180+ public: 10181+ RegistryToCustomOp(const std::string &key, ToCustomOpFunc creator) { 10182+ ToCustomOpRegistry::GetInstance()->InsertToCustomOpMap(key, creator); 10183+ } 10184+ virtual ~RegistryToCustomOp() = default; 10185+}; 10186+ 10187+#define REGISTER_TO_CUSTOM_OP(type, to_custom_op_func) \ 10188+ RegistryToCustomOp g_##type##_to_custom_op(type, 
to_custom_op_func); 10189+ 10190+class ToCustomOpPass : public Pass { 10191+ public: 10192+ ToCustomOpPass() : Pass("ToCustomOpPass") {} 10193+ ~ToCustomOpPass() = default; 10194+ bool Run(const FuncGraphPtr &graph) override; 10195+}; 10196+} // namespace opt 10197+} // namespace mindspore 10198+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_TO_CUSTOM_OP_PASS_H_ 10199diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc 10200index 8ea838cf..a551196d 100644 10201--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc 10202+++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc 10203@@ -287,7 +287,6 @@ bool FusionPass::MatchTree(const schema::MetaGraphT &graph, size_t nodeIdx, cons 10204 bool FusionPass::CheckMatchParams(const schema::MetaGraphT &graph, size_t nodeIdx, 10205 const std::shared_ptr<PatternOp> &target, const std::vector<size_t> &sinkIdes, 10206 const std::vector<size_t> &pathSinkIdes) { 10207- MS_ASSERT(target != nullptr); 10208 MS_ASSERT(nodeIdx < graph.nodes.size()); 10209 auto &scope = graph.nodes.at(nodeIdx); 10210 MS_CHECK_TRUE_MSG(scope != nullptr, false, "Node in graph is nullptr"); 10211diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc 10212index 371e93fb..ff99f1f4 100644 10213--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc 10214+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc 10215@@ -660,7 +660,9 @@ int InferShapePass::InitSearchTensor(const int64_t &subgraph_index, MetaGraphT * 10216 } 10217 auto &subgraph = graph->subGraph.at(subgraph_index); 10218 for (uint32_t i = 0; i < tensors_.size(); i++) { 10219- if (IsContain(subgraph->inputIndices, i) || !graph->allTensors.at(i)->data.empty()) { 10220+ if (IsContain(subgraph->inputIndices, 
i) || !graph->allTensors.at(i)->data.empty() || 10221+ (graph->allTensors.at(i)->nodeType == NodeType_ValueNode && graph->allTensors.at(i)->dims.size() == 1 && 10222+ graph->allTensors.at(i)->dims[0] == 0)) { 10223 tensors_[i].is_inferred_ = true; 10224 } 10225 } 10226diff --git a/mindspore/lite/tools/converter/micro/cmake/file_list.cmake b/mindspore/lite/tools/converter/micro/cmake/file_list.cmake 10227index c132460e..5dcf0bb7 100644 10228--- a/mindspore/lite/tools/converter/micro/cmake/file_list.cmake 10229+++ b/mindspore/lite/tools/converter/micro/cmake/file_list.cmake 10230@@ -4,6 +4,8 @@ set(CODER_SRC 10231 ${MICRO_DIR}/coder/context.cc 10232 ${MICRO_DIR}/coder/graph.cc 10233 ${MICRO_DIR}/coder/session.cc 10234+ ${MICRO_DIR}/coder/shape_info_container.cc 10235+ ${MICRO_DIR}/coder/dynamic_mem_manager.cc 10236 ${MICRO_DIR}/coder/utils/coder_utils.cc 10237 ${MICRO_DIR}/coder/utils/dir_utils.cc 10238 ${MICRO_DIR}/coder/utils/train_utils.cc 10239@@ -23,6 +25,7 @@ set(CODER_ALLOCATOR_SRC 10240 set(CODER_GENERATOR_SRC 10241 ${MICRO_DIR}/coder/generator/generator.cc 10242 ${MICRO_DIR}/coder/generator/inference/inference_generator.cc 10243+ ${MICRO_DIR}/coder/generator/component/allocator_component.cc 10244 ${MICRO_DIR}/coder/generator/component/common_component.cc 10245 ${MICRO_DIR}/coder/generator/component/weight_component.cc 10246 ${MICRO_DIR}/coder/generator/component/allocator_component.cc 10247@@ -66,6 +69,8 @@ set(CODER_OPCODERS_SRC 10248 ${MICRO_DIR}/coder/opcoders/base/stack_base_coder.cc 10249 ${MICRO_DIR}/coder/opcoders/base/unstack_base_coder.cc 10250 ${MICRO_DIR}/coder/opcoders/base/strided_slice_base_coder.cc 10251+ ${MICRO_DIR}/coder/opcoders/base/reshape_dynamic_base_coder.cc 10252+ ${MICRO_DIR}/coder/opcoders/base/strided_slice_dynamic_base_coder.cc 10253 #### cmsis int8 coder 10254 ${MICRO_DIR}/coder/opcoders/cmsis-nn/int8/add_int8_coder.cc 10255 ${MICRO_DIR}/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.cc 10256@@ -81,23 +86,37 @@ 
set(CODER_OPCODERS_SRC 10257 ${MICRO_DIR}/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc 10258 ${MICRO_DIR}/coder/opcoders/nnacl/fp16/avg_pooling_fp16_coder.cc 10259 ${MICRO_DIR}/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc 10260+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc 10261+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc 10262+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc 10263+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc 10264+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc 10265+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc 10266+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc 10267 ${MICRO_DIR}/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc 10268 ${MICRO_DIR}/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.cc 10269+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc 10270+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 10271+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_coder.cc 10272 ${MICRO_DIR}/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc 10273 ${MICRO_DIR}/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc 10274 ${MICRO_DIR}/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.cc 10275 ${MICRO_DIR}/coder/opcoders/nnacl/fp16/reduce_fp16_coder.cc 10276 ${MICRO_DIR}/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc 10277- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 10278- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_coder.cc 10279- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc 10280- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc 10281- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc 10282- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc 10283- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc 10284- 
${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc 10285- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc 10286- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc 10287+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc 10288+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc 10289+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.cc 10290+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc 10291+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc 10292+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc 10293+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc 10294+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc 10295+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc 10296+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc 10297+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc 10298+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc 10299+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc 10300+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc 10301 #### nnacl fp32 coder 10302 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/activation_fp32_coder.cc 10303 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/addn_fp32_coder.cc 10304@@ -122,6 +141,7 @@ set(CODER_OPCODERS_SRC 10305 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc 10306 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc 10307 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/matmul_fp32_coder.cc 10308+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc 10309 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/pad_fp32_coder.cc 10310 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/pooling_fp32_coder.cc 10311 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/power_fp32_coder.cc 10312@@ 
-133,17 +153,14 @@ set(CODER_OPCODERS_SRC 10313 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc 10314 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc 10315 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/exp_fp32_coder.cc 10316+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc 10317 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.cc 10318 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/prelu_fp32_coder.cc 10319 ${MICRO_DIR}/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.cc 10320- ${MICRO_DIR}/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc 10321- ${MICRO_DIR}/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc 10322- #### nnacl fp32_grad coder 10323- ${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/activation_grad_coder.cc 10324- ${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/adam_coder.cc 10325- ${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/assign_coder.cc 10326- ${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/biasadd_grad_coder.cc 10327- ${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.cc 10328+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc 10329+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.cc 10330+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc 10331+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc 10332 #### nnacl int8 coder 10333 ${MICRO_DIR}/coder/opcoders/nnacl/int8/activation_int8_coder.cc 10334 ${MICRO_DIR}/coder/opcoders/nnacl/int8/affine_int8_coder.cc 10335diff --git a/mindspore/lite/tools/converter/micro/coder/coder.cc b/mindspore/lite/tools/converter/micro/coder/coder.cc 10336index cc224ae5..a502500d 100644 10337--- a/mindspore/lite/tools/converter/micro/coder/coder.cc 10338+++ b/mindspore/lite/tools/converter/micro/coder/coder.cc 10339@@ -42,6 +42,34 @@ std::shared_ptr<CoderSession> CreateCoderSession() { 10340 } 10341 return session; 10342 } 10343+ 10344+int ParseMicroDynamicShape(const schema::MetaGraphT &graph, 
micro::MicroParam *micro_param) { 10345+ for (auto index : graph.inputIndex) { 10346+ auto input_name = graph.allTensors.at(index)->name; 10347+ if (micro_param->graph_inputs_origin_info.find(input_name) == micro_param->graph_inputs_origin_info.end() || 10348+ micro_param->inputs_shape_by_scenes.find(input_name) == micro_param->inputs_shape_by_scenes.end()) { 10349+ MS_LOG(ERROR) << "Micro param: dynamic inputs name is invalid"; 10350+ return RET_INPUT_PARAM_INVALID; 10351+ } 10352+ micro_param->graph_inputs_template.emplace_back(micro_param->graph_inputs_origin_info[input_name]); 10353+ micro_param->graph_inputs_shape_infos.emplace_back(micro_param->inputs_shape_by_scenes[input_name]); 10354+ } 10355+ return RET_OK; 10356+} 10357+ 10358+int ParseMicroDynamicShape(const Model &model, micro::MicroParam *micro_param) { 10359+ for (auto index : model.graph_.input_indices_) { 10360+ auto input_name = model.graph_.all_tensors_.at(index)->name()->str(); 10361+ if (micro_param->graph_inputs_origin_info.find(input_name) == micro_param->graph_inputs_origin_info.end() || 10362+ micro_param->inputs_shape_by_scenes.find(input_name) == micro_param->inputs_shape_by_scenes.end()) { 10363+ MS_LOG(ERROR) << "Micro param: dynamic inputs name is invalid"; 10364+ return RET_INPUT_PARAM_INVALID; 10365+ } 10366+ micro_param->graph_inputs_template.emplace_back(micro_param->graph_inputs_origin_info[input_name]); 10367+ micro_param->graph_inputs_shape_infos.emplace_back(micro_param->inputs_shape_by_scenes[input_name]); 10368+ } 10369+ return RET_OK; 10370+} 10371 } // namespace 10372 int Coder::Run(const void *model_buff, size_t size, const std::string &model_name, bool end_flag, bool enable_fp16) { 10373 session_ = CreateCoderSession(); 10374@@ -109,29 +137,37 @@ bool Coder::InitPath(const std::string &output_path) { 10375 return true; 10376 } 10377 10378-int Coder::MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path, 10379- const MicroParam ¶m, bool 
enable_fp16) { 10380+int Coder::MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path, MicroParam *param, 10381+ bool enable_fp16) { 10382 flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize); 10383 auto offset = schema::MetaGraph::Pack(builder, &graph); 10384 builder.Finish(offset); 10385 schema::FinishMetaGraphBuffer(builder, offset); 10386 size_t size = builder.GetSize(); 10387- if (ExecuteMicroGeneration(builder.GetBufferPointer(), size, output_path, param, enable_fp16) != RET_OK) { 10388+ if (!param->dynamic_symbols.empty()) { 10389+ MS_CHECK_TRUE_MSG(ParseMicroDynamicShape(graph, param) == RET_OK, RET_ERROR, "ParseMicroDynamicShape failed."); 10390+ } 10391+ if (ExecuteMicroGeneration(builder.GetBufferPointer(), size, output_path, *param, enable_fp16) != RET_OK) { 10392 MS_LOG(ERROR) << "Execute Micro failed."; 10393 return RET_ERROR; 10394 } 10395 return RET_OK; 10396 } 10397 10398-int Coder::MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path, 10399- const MicroParam ¶m, bool enable_fp16) { 10400+int Coder::MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path, MicroParam *param, 10401+ bool enable_fp16) { 10402 size_t buffer_size; 10403 auto model_buf = lite::ReadFile(model_file.c_str(), &buffer_size); 10404 if (model_buf == nullptr) { 10405 MS_LOG(ERROR) << "Read model-file failed."; 10406 return RET_NULL_PTR; 10407 } 10408- auto ret = ExecuteMicroGeneration(model_buf, buffer_size, output_path, param, enable_fp16); 10409+ Model *model = lite::Model::Import(model_buf, buffer_size); 10410+ MS_CHECK_PTR(model); 10411+ if (!param->dynamic_symbols.empty()) { 10412+ MS_CHECK_TRUE_MSG(ParseMicroDynamicShape(*model, param) == RET_OK, RET_ERROR, "ParseMicroDynamicShape failed."); 10413+ } 10414+ auto ret = ExecuteMicroGeneration(model_buf, buffer_size, output_path, *param, enable_fp16); 10415 if (ret != RET_OK) { 10416 MS_LOG(ERROR) << "Execute 
Micro failed."; 10417 } 10418@@ -199,6 +235,10 @@ int Coder::Init(const MicroParam ¶m) const { 10419 DirectoryGenerator::GetInstance()->project_name()); 10420 config->set_keep_original_weight(param.keep_original_weight); 10421 config->set_changeable_weights_name(param.changeable_weights_name); 10422+ config->set_graph_inputs_shape_infos(param.graph_inputs_shape_infos); 10423+ config->set_dynamic_symbols(param.dynamic_symbols); 10424+ config->set_dynamic_symbols_num(param.dynamic_symbols_num); 10425+ config->set_user_graph_inputs_template(param.graph_inputs_template); 10426 10427 auto print_parameter = [](auto name, auto value) { 10428 MS_LOG(INFO) << std::setw(20) << std::left << name << "= " << value; 10429@@ -209,6 +249,7 @@ int Coder::Init(const MicroParam ¶m) const { 10430 print_parameter("codePath", config->code_path()); 10431 print_parameter("codeMode", config->code_mode()); 10432 print_parameter("debugMode", config->debug_mode()); 10433+ print_parameter("keepOriginalWeight", config->keep_original_weight()); 10434 return RET_OK; 10435 } 10436 } // namespace mindspore::lite::micro 10437diff --git a/mindspore/lite/tools/converter/micro/coder/coder.h b/mindspore/lite/tools/converter/micro/coder/coder.h 10438index c360f4c1..fad479aa 100644 10439--- a/mindspore/lite/tools/converter/micro/coder/coder.h 10440+++ b/mindspore/lite/tools/converter/micro/coder/coder.h 10441@@ -31,9 +31,9 @@ class Coder final { 10442 10443 ~Coder() = default; 10444 static int MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path, 10445- const MicroParam ¶m, bool enable_fp16); 10446- static int MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path, 10447- const MicroParam ¶m, bool enable_fp16); 10448+ MicroParam *param, bool enable_fp16); 10449+ static int MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path, MicroParam *param, 10450+ bool enable_fp16); 10451 10452 private: 10453 
static int ExecuteMicroGeneration(const void *model_buf, size_t size, const std::string &output_path, 10454diff --git a/mindspore/lite/tools/converter/micro/coder/config.h b/mindspore/lite/tools/converter/micro/coder/config.h 10455index 9be56178..fb90a2fc 100644 10456--- a/mindspore/lite/tools/converter/micro/coder/config.h 10457+++ b/mindspore/lite/tools/converter/micro/coder/config.h 10458@@ -34,6 +34,12 @@ struct MicroParam { 10459 std::string project_name; 10460 bool is_last_model{false}; 10461 bool keep_original_weight{false}; 10462+ std::vector<std::vector<std::string>> graph_inputs_template; 10463+ std::map<std::string, std::vector<std::string>> graph_inputs_origin_info; 10464+ std::vector<std::string> dynamic_symbols; 10465+ std::vector<size_t> dynamic_symbols_num; 10466+ std::vector<std::vector<std::vector<int>>> graph_inputs_shape_infos; 10467+ std::map<std::string, std::vector<std::vector<int>>> inputs_shape_by_scenes; 10468 }; 10469 10470 class Configurator { 10471@@ -67,6 +73,29 @@ class Configurator { 10472 void set_changeable_weights_name(const std::string &weights_name) { changeable_weights_name_ = weights_name; } 10473 const std::string &changeable_weights_name() const { return changeable_weights_name_; } 10474 10475+ void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; } 10476+ bool dynamic_shape() const { return dynamic_shape_; } 10477+ 10478+ void set_dynamic_symbols(const std::vector<std::string> &dynamic_symbols) { dynamic_symbols_ = dynamic_symbols; } 10479+ const std::vector<std::string> &dynamic_symbols() const { return dynamic_symbols_; } 10480+ 10481+ void set_dynamic_symbols_num(const std::vector<size_t> &dynamic_symbols_num) { 10482+ dynamic_symbols_num_ = dynamic_symbols_num; 10483+ } 10484+ const std::vector<size_t> &dynamic_symbols_num() const { return dynamic_symbols_num_; } 10485+ 10486+ void set_user_graph_inputs_template(const std::vector<std::vector<std::string>> &graph_inputs_template) { 10487+ 
user_graph_inputs_template_ = graph_inputs_template; 10488+ } 10489+ const std::vector<std::vector<std::string>> &user_graph_inputs_template() const { 10490+ return user_graph_inputs_template_; 10491+ } 10492+ 10493+ void set_graph_inputs_shape_infos(const std::vector<std::vector<std::vector<int>>> &graph_inputs_shape_infos) { 10494+ graph_inputs_shape_infos_ = graph_inputs_shape_infos; 10495+ } 10496+ const std::vector<std::vector<std::vector<int>>> &graph_inputs_shape_infos() { return graph_inputs_shape_infos_; } 10497+ 10498 private: 10499 Configurator() = default; 10500 ~Configurator() = default; 10501@@ -76,8 +105,13 @@ class Configurator { 10502 bool support_parallel_{false}; 10503 bool debug_mode_{false}; 10504 bool keep_original_weight_{false}; 10505+ bool dynamic_shape_{false}; 10506 std::string proj_dir_; 10507 std::string changeable_weights_name_; 10508+ std::vector<std::string> dynamic_symbols_; 10509+ std::vector<size_t> dynamic_symbols_num_; 10510+ std::vector<std::vector<std::vector<int>>> graph_inputs_shape_infos_; 10511+ std::vector<std::vector<std::string>> user_graph_inputs_template_; 10512 }; 10513 } // namespace mindspore::lite::micro 10514 #endif // MICRO_CODER_CONFIG_H 10515diff --git a/mindspore/lite/tools/converter/micro/coder/context.cc b/mindspore/lite/tools/converter/micro/coder/context.cc 10516index 251b282f..7e7f640e 100644 10517--- a/mindspore/lite/tools/converter/micro/coder/context.cc 10518+++ b/mindspore/lite/tools/converter/micro/coder/context.cc 10519@@ -50,4 +50,17 @@ std::vector<std::string> CoderContext::GetInitWeightSizeCode() const { 10520 } 10521 10522 void CoderContext::AppendInitWeightSizeCode(size_t w_buf_size) { weight_buffer_size_ += w_buf_size; } 10523+ 10524+const std::map<int, std::vector<int>> &CoderContext::shape_all_scenes() const { 10525+ return shape_info_container_->GetShapesWholeScenes(); 10526+} 10527+const std::map<const Tensor *, std::vector<std::string>> &CoderContext::shape_templates() { 10528+ return 
shape_info_container_->GetWholeTemplateShape(); 10529+} 10530+const std::map<int, std::vector<size_t>> &CoderContext::offset_all_scenes() { 10531+ return dynamic_mem_manager_->GetOffsetAllScenes(); 10532+} 10533+const std::vector<size_t> &CoderContext::buffer_sizes() const { return dynamic_mem_manager_->GetBufferSizes(); } 10534+const std::vector<size_t> &CoderContext::workspaces() const { return dynamic_mem_manager_->GetWorkSpaces(); } 10535+std::string CoderContext::tensor_addr(const Tensor *tensor) { return dynamic_mem_manager_->GetVarTensorAddr(tensor); } 10536 } // namespace mindspore::lite::micro 10537diff --git a/mindspore/lite/tools/converter/micro/coder/context.h b/mindspore/lite/tools/converter/micro/coder/context.h 10538index bad4ab40..b511eac1 100644 10539--- a/mindspore/lite/tools/converter/micro/coder/context.h 10540+++ b/mindspore/lite/tools/converter/micro/coder/context.h 10541@@ -25,6 +25,8 @@ 10542 #include <vector> 10543 #include <algorithm> 10544 #include "src/tensor.h" 10545+#include "tools/converter/micro/coder/shape_info_container.h" 10546+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 10547 10548 namespace mindspore::lite::micro { 10549 class CoderContext { 10550@@ -146,6 +148,17 @@ class CoderContext { 10551 10552 bool end_flag() { return end_flag_; } 10553 10554+ void set_shape_info_container(ShapeInfoContainer *shape_info_container) { 10555+ shape_info_container_ = shape_info_container; 10556+ } 10557+ void set_dynamic_mem_manager(DynamicMemManager *dynamic_mem_manager) { dynamic_mem_manager_ = dynamic_mem_manager; } 10558+ const std::map<int, std::vector<int>> &shape_all_scenes() const; 10559+ const std::map<const Tensor *, std::vector<std::string>> &shape_templates(); 10560+ const std::map<int, std::vector<size_t>> &offset_all_scenes(); 10561+ const std::vector<size_t> &buffer_sizes() const; 10562+ const std::vector<size_t> &workspaces() const; 10563+ std::string tensor_addr(const Tensor *tensor); 10564+ 10565 private: 
10566 std::string model_name_; 10567 std::vector<Tensor *> graph_inputs_; 10568@@ -195,6 +208,8 @@ class CoderContext { 10569 // operator C Lang files list, depended by the net.c. it will be add to CMakeLists.txt 10570 static std::set<std::string> c_files_; 10571 static size_t max_buffer_size_; 10572+ ShapeInfoContainer *shape_info_container_; 10573+ DynamicMemManager *dynamic_mem_manager_; 10574 }; 10575 } // namespace mindspore::lite::micro 10576 #endif // MINDSPORE_LITE_MICRO_CODER_CONTEXT_H_ 10577diff --git a/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.cc b/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.cc 10578new file mode 100644 10579index 00000000..976bd852 10580--- /dev/null 10581+++ b/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.cc 10582@@ -0,0 +1,116 @@ 10583+/** 10584+ * Copyright 2023 Huawei Technologies Co., Ltd 10585+ * 10586+ * Licensed under the Apache License, Version 2.0 (the "License"); 10587+ * you may not use this file except in compliance with the License. 10588+ * You may obtain a copy of the License at 10589+ * 10590+ * http://www.apache.org/licenses/LICENSE-2.0 10591+ * 10592+ * Unless required by applicable law or agreed to in writing, software 10593+ * distributed under the License is distributed on an "AS IS" BASIS, 10594+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10595+ * See the License for the specific language governing permissions and 10596+ * limitations under the License. 
10597+ */ 10598+ 10599+#include "coder/dynamic_mem_manager.h" 10600+#include <vector> 10601+#include "coder/allocator/memory_manager.h" 10602+#include "coder/generator/component/component.h" 10603+ 10604+namespace mindspore::lite::micro { 10605+int DynamicMemManager::AllocDynamicMem(const std::vector<std::unique_ptr<OperatorCoder>> &nodes, 10606+ const std::vector<Tensor *> &graph_inputs, 10607+ const std::vector<Tensor *> &graph_outputs, 10608+ const ShapeInfoContainer *shape_info_container) { 10609+ MS_CHECK_TRUE_MSG(shape_info_container, RET_NULL_PTR, "ShapeInfoContainer is a nullptr."); 10610+ for (size_t i = 0; i < graph_inputs.size(); ++i) { 10611+ graph_inputs_.insert(std::make_pair(graph_inputs.at(i), kInputPrefixName + std::to_string(i))); 10612+ } 10613+ auto var_tensor_shapes = shape_info_container->GetVarTensorInfos(); 10614+ MS_CHECK_TRUE_MSG(!var_tensor_shapes.empty(), RET_ERROR, "Cannot get var-tensor's shape-info"); 10615+ auto scene_num = var_tensor_shapes.begin()->second.size(); 10616+ for (const auto &item : var_tensor_shapes) { 10617+ MS_CHECK_TRUE_MSG(item.first, RET_NULL_PTR, "Find a nullptr in shape-infos"); 10618+ MS_CHECK_TRUE_MSG(item.second.size() == scene_num, RET_ERROR, "Shape-info is invalid."); 10619+ } 10620+ for (size_t i = 0; i < scene_num; ++i) { 10621+ for (const auto &item : var_tensor_shapes) { 10622+ item.first->ResetRefCount(); 10623+ item.first->set_shape(item.second[i]); 10624+ } 10625+ auto ret = AllocDynamicMemCore(nodes, graph_outputs, i); 10626+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Alloc dynamic memory failed."); 10627+ } 10628+ return RET_OK; 10629+} 10630+ 10631+int DynamicMemManager::AllocDynamicMemCore(const std::vector<std::unique_ptr<OperatorCoder>> &nodes, 10632+ const std::vector<Tensor *> &graph_outputs, int scene_index) { 10633+ if (offsets_all_scenes_.find(scene_index) != offsets_all_scenes_.end()) { 10634+ MS_LOG(ERROR) << "Current scene has been processed."; 10635+ return RET_ERROR; 10636+ } 10637+ auto 
manager = std::make_unique<MemoryManager>(); 10638+ int ret = manager->AssignMemory(nodes, graph_outputs); 10639+ if (ret != RET_OK) { 10640+ MS_LOG(ERROR) << "assign memory failed"; 10641+ return RET_ERROR; 10642+ } 10643+ std::map<Tensor *, size_t> offsets = manager->variables_offset(); 10644+ if (offset_index_.empty()) { 10645+ int index = 0; 10646+ for (auto &item : offsets) { 10647+ offset_index_[item.first] = index++; 10648+ offsets_all_scenes_[scene_index].push_back(item.second); 10649+ } 10650+ } else { 10651+ MS_CHECK_TRUE_MSG(offsets.size() == offset_index_.size(), RET_ERROR, "Tensors num is not same."); 10652+ for (auto &item : offsets) { 10653+ MS_CHECK_TRUE_MSG(offset_index_.find(item.first) != offset_index_.end(), RET_ERROR, "Tensor cannot be found."); 10654+ offsets_all_scenes_[scene_index].push_back(item.second); 10655+ } 10656+ } 10657+ buffer_sizes_.push_back(manager->GetAllocatedSize()); 10658+ offsets_all_scenes_[scene_index].push_back(manager->GetAllocatedSize()); 10659+ return RET_OK; 10660+} 10661+ 10662+std::string DynamicMemManager::GetVarTensorAddr(const Tensor *tensor) const { 10663+ if (graph_inputs_.find(tensor) != graph_inputs_.end()) { 10664+ return graph_inputs_.at(tensor); 10665+ } 10666+ if (offset_index_.find(tensor) == offset_index_.end()) { 10667+ return ""; 10668+ } 10669+ if (kBufferPrefixName == nullptr || kOffsetPrefixName == nullptr) { 10670+ MS_LOG(ERROR) << "Buffer or Offset is a nullptr."; 10671+ return ""; 10672+ } 10673+ return std::string(kBufferPrefixName) + " + " + kOffsetPrefixName + "[" + std::to_string(offset_index_.at(tensor)) + 10674+ "]"; 10675+} 10676+ 10677+std::string DynamicMemManager::AllocWorkSpace(size_t size, int index) { 10678+ if (index < 0 || static_cast<size_t>(index) >= buffer_sizes_.size()) { 10679+ return ""; 10680+ } 10681+ if (static_cast<size_t>(index) + 1 >= workspaces_.size()) { 10682+ workspaces_.insert(workspaces_.end(), index + 1 - workspaces_.size(), 0); 10683+ } 10684+ if 
(workspaces_[index] < size) { 10685+ workspaces_[index] = size; 10686+ } 10687+ if (kBufferPrefixName == nullptr) { 10688+ MS_LOG(ERROR) << "Buffer is a nullptr."; 10689+ return ""; 10690+ } 10691+ if (kOffsetPrefixName == nullptr) { 10692+ MS_LOG(ERROR) << "Offset is a nullptr."; 10693+ return ""; 10694+ } 10695+ return "(" + std::string(kBufferPrefixName) + " + " + kOffsetPrefixName + "[" + 10696+ std::to_string(offsets_all_scenes_.begin()->second.size() - 1) + "])"; 10697+} 10698+} // namespace mindspore::lite::micro 10699diff --git a/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.h b/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.h 10700new file mode 100644 10701index 00000000..6db7cff5 10702--- /dev/null 10703+++ b/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.h 10704@@ -0,0 +1,53 @@ 10705+/** 10706+ * Copyright 2023 Huawei Technologies Co., Ltd 10707+ * 10708+ * Licensed under the Apache License, Version 2.0 (the "License"); 10709+ * you may not use this file except in compliance with the License. 10710+ * You may obtain a copy of the License at 10711+ * 10712+ * http://www.apache.org/licenses/LICENSE-2.0 10713+ * 10714+ * Unless required by applicable law or agreed to in writing, software 10715+ * distributed under the License is distributed on an "AS IS" BASIS, 10716+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10717+ * See the License for the specific language governing permissions and 10718+ * limitations under the License. 
10719+ */ 10720+ 10721+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_DYNAMIC_MEM_MANAGER_H_ 10722+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_DYNAMIC_MEM_MANAGER_H_ 10723+ 10724+#include <map> 10725+#include <vector> 10726+#include "src/tensor.h" 10727+#include "tools/converter/micro/coder/shape_info_container.h" 10728+ 10729+namespace mindspore::lite::micro { 10730+class OperatorCoder; 10731+class DynamicMemManager { 10732+ public: 10733+ DynamicMemManager() = default; 10734+ virtual ~DynamicMemManager() = default; 10735+ int AllocDynamicMem(const std::vector<std::unique_ptr<OperatorCoder>> &nodes, 10736+ const std::vector<Tensor *> &graph_inputs, const std::vector<Tensor *> &graph_outputs, 10737+ const ShapeInfoContainer *shape_info_container); 10738+ 10739+ std::string GetVarTensorAddr(const Tensor *tensor) const; 10740+ std::string AllocWorkSpace(size_t size, int index); 10741+ 10742+ const std::vector<size_t> &GetBufferSizes() const { return buffer_sizes_; } 10743+ const std::vector<size_t> &GetWorkSpaces() const { return workspaces_; } 10744+ const std::map<int, std::vector<size_t>> &GetOffsetAllScenes() { return offsets_all_scenes_; } 10745+ 10746+ private: 10747+ int AllocDynamicMemCore(const std::vector<std::unique_ptr<OperatorCoder>> &nodes, 10748+ const std::vector<Tensor *> &graph_outputs, int scene_index); 10749+ std::map<int, std::vector<size_t>> offsets_all_scenes_; 10750+ std::map<const Tensor *, int> offset_index_; 10751+ std::map<const Tensor *, std::string> graph_inputs_; 10752+ std::vector<size_t> buffer_sizes_; 10753+ std::vector<size_t> workspaces_; 10754+ int model_id_; 10755+}; 10756+} // namespace mindspore::lite::micro 10757+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_DYNAMIC_MEM_MANAGER_H_ 10758diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/cmake_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/cmake_component.cc 10759index 643cf50b..831d4259 100644 10760--- 
a/mindspore/lite/tools/converter/micro/coder/generator/component/cmake_component.cc 10761+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/cmake_component.cc 10762@@ -5,7 +5,7 @@ 10763 * you may not use this file except in compliance with the License. 10764 * You may obtain a copy of the License at 10765 * 10766- * http://www.apache.org/licenses/LICENSE-2.0 10767+ * http://www.apache.org/licenses/LICENSE-2.0 10768 * 10769 * Unless required by applicable law or agreed to in writing, software 10770 * distributed under the License is distributed on an "AS IS" BASIS, 10771@@ -29,32 +29,32 @@ void CodeCMakeNetLibrary(std::ofstream &ofs, const std::unique_ptr<CoderContext> 10772 } 10773 ofs << "set(OP_SRC\n"; 10774 for (const std::string &c_file : ctx->c_files()) { 10775- ofs << " " << c_file << ".o\n"; 10776+ ofs << " " << c_file << ".obj\n"; 10777 } 10778 for (int i = 0; i <= ctx->GetCurModelIndex(); ++i) { 10779- ofs << " weight" << i << ".c.o\n" 10780- << " net" << i << ".c.o\n" 10781- << " model" << i << ".c.o\n"; 10782+ ofs << " weight" << i << ".c.obj\n" 10783+ << " net" << i << ".c.obj\n" 10784+ << " model" << i << ".c.obj\n"; 10785 } 10786- ofs << " model.c.o\n" 10787- << " context.c.o\n" 10788- << " tensor.c.o\n"; 10789- if (config->target() != kCortex_M) { 10790- ofs << " allocator.c.o\n"; 10791+ ofs << " model.c.obj\n" 10792+ << " context.c.obj\n" 10793+ << " tensor.c.obj\n"; 10794+ if (config->target() != kCortex_M && !config->dynamic_shape()) { 10795+ ofs << " allocator.c.obj\n"; 10796 } 10797 if (config->debug_mode()) { 10798- ofs << " debug_utils.c.o\n"; 10799+ ofs << " debug_utils.c.obj\n"; 10800 } 10801 if (config->support_parallel()) { 10802- ofs << " micro_core_affinity.c.o\n" 10803- " micro_thread_pool.c.o\n"; 10804+ ofs << " micro_core_affinity.c.obj\n" 10805+ " micro_thread_pool.c.obj\n"; 10806 } 10807 ofs << ")\n"; 10808 std::set<std::string> kernel_cmake_asm_set_files = ctx->asm_files(); 10809 if
(!kernel_cmake_asm_set_files.empty() && (config->target() == kARM32 || config->target() == kARM64)) { 10810 ofs << "set(ASSEMBLY_SRC\n"; 10811 for (const std::string &asm_file : kernel_cmake_asm_set_files) { 10812- ofs << " " << asm_file << ".o\n"; 10813+ ofs << " " << asm_file << ".obj\n"; 10814 } 10815 ofs << ")\n" 10816 << "set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)\n" 10817diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc 10818index 774e8353..62c2f668 100644 10819--- a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc 10820+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc 10821@@ -16,6 +16,7 @@ 10822 10823 #include "coder/generator/component/common_component.h" 10824 #include <memory> 10825+#include "coder/generator/component/const_blocks/license.h" 10826 #include "coder/generator/component/component.h" 10827 #include "coder/utils/type_cast.h" 10828 #include "coder/utils/coder_utils.h" 10829@@ -23,36 +24,59 @@ 10830 #include "include/errorcode.h" 10831 #include "nnacl/op_base.h" 10832 #include "include/c_api/model_c.h" 10833+#include "tools/common/string_util.h" 10834 10835 namespace mindspore::lite::micro { 10836-const char handle_array_destroy_state[] = R"RAW( 10837-void MSTensorHandleArrayDestroy(MSTensorHandleArray inputs); 10838+const char model_runtime_init_source[] = R"RAW( 10839+typedef struct { 10840+ void *runtime_buffer; 10841+ OH_AI_TensorHandleArray inputs; 10842+ OH_AI_TensorHandleArray outputs; 10843+} MicroModel; 10844+OH_AI_ModelHandle OH_AI_ModelCreate() { 10845+ MicroModel *micro_model = (MicroModel *)malloc(sizeof(MicroModel)); 10846+ if (micro_model == NULL) { 10847+ return NULL; 10848+ } 10849+)RAW"; 10850+const char model_runtime_malloc_source[] = R"RAW( 10851+ int buffer_size = GetBufferSize(); 10852+ void *runtime_buffer = 
malloc(buffer_size); 10853+ if (runtime_buffer == NULL) { 10854+ return NULL; 10855+ } 10856+ micro_model->runtime_buffer = runtime_buffer; 10857+ int ret = SetBuffer(runtime_buffer); 10858+ if (ret != OH_AI_STATUS_SUCCESS) { 10859+ return NULL; 10860+ } 10861+ 10862 )RAW"; 10863 10864 const char handle_array_destroy[] = R"RAW( 10865-void MSTensorHandleArrayDestroy(MSTensorHandleArray inputs) { 10866- if (inputs.handle_list == NULL) { 10867- return; 10868- } 10869- for (size_t i = 0; i < inputs.handle_num; i++) { 10870- MicroTensor *micro_tensor = inputs.handle_list[i]; 10871- if (micro_tensor == NULL) { 10872- continue; 10873- } 10874- if (micro_tensor->data != NULL && micro_tensor->owned) { 10875- free(micro_tensor->data); 10876- micro_tensor->data = NULL; 10877- micro_tensor->owned = false; 10878- } 10879- if (micro_tensor->shape != NULL) { 10880- free(micro_tensor->shape); 10881- micro_tensor->shape = NULL; 10882- } 10883- free(micro_tensor); 10884- micro_tensor = NULL; 10885- } 10886- free(inputs.handle_list); 10887- inputs.handle_list = NULL; 10888+void OH_AI_TensorHandleArrayDestroy(OH_AI_TensorHandleArray inputs) { 10889+ if (inputs.handle_list == NULL) { 10890+ return; 10891+ } 10892+ for (size_t i = 0; i < inputs.handle_num; i++) { 10893+ MicroTensor *micro_tensor = inputs.handle_list[i]; 10894+ if (micro_tensor == NULL) { 10895+ continue; 10896+ } 10897+ if (micro_tensor->data != NULL && micro_tensor->owned) { 10898+ free(micro_tensor->data); 10899+ micro_tensor->data = NULL; 10900+ micro_tensor->owned = false; 10901+ } 10902+ if (micro_tensor->shape) { 10903+ free(micro_tensor->shape); 10904+ micro_tensor->shape = NULL; 10905+ } 10906+ free(micro_tensor); 10907+ micro_tensor = NULL; 10908+ } 10909+ free(inputs.handle_list); 10910+ inputs.handle_list = NULL; 10911 } 10912 10913 )RAW"; 10914@@ -62,7 +86,7 @@ const char cortex_set_workspace[] = R"RAW( 10915 if (micro_model == NULL) { 10916 return; 10917 } 10918- if (workspace_size < 
MSModelCalcWorkspaceSize(model)) { 10919+ if (workspace_size < OH_AI_ModelCalcWorkspaceSize(model)) { 10920 return; 10921 } 10922 if (micro_model->inputs.handle_num != GRAPH_INPUTS_SIZE) { 10923@@ -75,29 +99,29 @@ const char cortex_set_workspace[] = R"RAW( 10924 )RAW"; 10925 10926 const char micro_model_build_state[] = R"RAW( 10927-typedef MSStatus (*ModelBuild)(MSModelHandle model, const void *model_data, 10928+typedef OH_AI_Status (*ModelBuild)(OH_AI_ModelHandle model, const void *model_data, 10929 size_t data_size, 10930- const MSContextHandle model_context); 10931+ const OH_AI_ContextHandle model_context); 10932 )RAW"; 10933 10934 const char micro_model_build_implement[] = R"RAW( 10935-MSStatus MSModelBuild(MSModelHandle model, const void *model_data, 10936- size_t data_size, MSModelType model_type, 10937- const MSContextHandle model_context) { 10938- if (model_type != kMSModelTypeMindIR) { 10939- return kMSStatusLiteNotSupport; 10940+OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, 10941+ size_t data_size, OH_AI_ModelType model_type, 10942+ const OH_AI_ContextHandle model_context) { 10943+ if (model_type != OH_AI_MODELTYPE_MINDIR) { 10944+ return OH_AI_STATUS_LITE_NOT_SUPPORT; 10945 } 10946 if (model == NULL) { 10947- return kMSStatusLiteParamInvalid; 10948+ return OH_AI_STATUS_LITE_PARAM_INVALID; 10949 } 10950 )RAW"; 10951 10952 const char micro_model_predict_state[] = R"RAW( 10953-typedef MSStatus (*ModelPredict)(MSModelHandle model, 10954- const MSTensorHandleArray inputs, 10955- MSTensorHandleArray *outputs, 10956- const MSKernelCallBackC before, 10957- const MSKernelCallBackC after); 10958+typedef OH_AI_Status (*ModelPredict)(OH_AI_ModelHandle model, 10959+ const OH_AI_TensorHandleArray inputs, 10960+ OH_AI_TensorHandleArray *outputs, 10961+ const OH_AI_KernelCallBack before, 10962+ const OH_AI_KernelCallBack after); 10963 )RAW"; 10964 10965 const char free_resource_state[] = R"RAW( 10966@@ -107,7 +131,7 @@ typedef void 
(*FreeResource)(); 10967 void CodeMSModelCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, 10968 const Configurator &config) { 10969 if (config.target() == kCortex_M) { 10970- ofs << "size_t MSModelCalcWorkspaceSize(MSModelHandle model) {\n" 10971+ ofs << "size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) {\n" 10972 << " MicroModel *micro_model = (MicroModel *)model;\n" 10973 << " if (micro_model == NULL) {\n" 10974 << " return 0;\n" 10975@@ -118,13 +142,13 @@ void CodeMSModelCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<Code 10976 << " return micro_model->calc_work_space(model);\n" 10977 << "}\n"; 10978 } else { 10979- ofs << "size_t MSModelCalcWorkspaceSize(MSModelHandle model) {\n return 0;\n}\n"; 10980+ ofs << "size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) {\n return 0;\n}\n"; 10981 } 10982 ofs << "\n"; 10983 } 10984 10985 void CodeCortexCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) { 10986- ofs << "size_t MSModelCalcWorkspaceSize" << ctx->GetCurModelIndex() << "(MSModelHandle model) {\n" 10987+ ofs << "size_t OH_AI_ModelCalcWorkspaceSize" << ctx->GetCurModelIndex() << "(OH_AI_ModelHandle model) {\n" 10988 << "size_t shape_size = 0;\n"; 10989 std::vector<Tensor *> inputs = ctx->graph_inputs(); 10990 for (size_t i = 0; i < inputs.size(); ++i) { 10991@@ -141,7 +165,7 @@ void CodeCortexCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<Coder 10992 } 10993 10994 void CodeMSModelSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 10995- ofs << "void MSModelSetWorkspace(MSModelHandle model, void *workspace, size_t workspace_size) {"; 10996+ ofs << "void OH_AI_ModelSetWorkspace(OH_AI_ModelHandle model, void *workspace, size_t workspace_size) {"; 10997 if (config.target() == kCortex_M) { 10998 ofs << " MicroModel *micro_model = (MicroModel *)model;\n" 10999 << " if (micro_model == NULL) {\n" 11000@@ -156,8 
+180,8 @@ void CodeMSModelSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderCont 11001 } 11002 11003 void CodeCortexSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) { 11004- ofs << "void MSModelSetWorkspace" << ctx->GetCurModelIndex() 11005- << "(MSModelHandle model, void *workspace, size_t workspace_size) {\n"; 11006+ ofs << "void OH_AI_ModelSetWorkspace" << ctx->GetCurModelIndex() 11007+ << "(OH_AI_ModelHandle model, void *workspace, size_t workspace_size) {\n"; 11008 ofs << cortex_set_workspace; 11009 ofs << " micro_model->runtime_buffer = workspace;\n" 11010 " int buffer_size = GetBufferSize" 11011@@ -173,12 +197,12 @@ void CodeCortexSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderConte 11012 buffer_size += WEIGHT_BUF_SIZE; 11013 buffer_size = UP_ROUND(buffer_size,4); 11014 11015- micro_model->inputs.handle_list = (MSTensorHandle *)&buf[buffer_size]; 11016+ micro_model->inputs.handle_list = (OH_AI_TensorHandle *)&buf[buffer_size]; 11017 buffer_size += GRAPH_INPUTS_SIZE * sizeof(MicroTensor *); 11018 buffer_size = UP_ROUND(buffer_size,4); 11019 MicroTensor **input_tensors = (MicroTensor **)micro_model->inputs.handle_list; 11020 11021- micro_model->outputs.handle_list = (MSTensorHandle *)&buf[buffer_size]; 11022+ micro_model->outputs.handle_list = (OH_AI_TensorHandle *)&buf[buffer_size]; 11023 buffer_size += GRAPH_OUTPUTS_SIZE * sizeof(MicroTensor *); 11024 buffer_size = UP_ROUND(buffer_size,4); 11025 MicroTensor **output_tensors = (MicroTensor **)micro_model->outputs.handle_list; 11026@@ -215,7 +239,7 @@ void CodeCortexSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderConte 11027 auto array_tostring = [&ofs](Tensor *tensor, const std::string &prefix, size_t index) { 11028 ofs << kAlignedString << prefix << "_tensors[" << index << "]->type = " << EnumNameMSDataType(tensor->data_type()) 11029 << ";\n"; 11030- ofs << kAlignedString << prefix << "_tensors[" << index << "]->format = kMSFormatNHWC;\n"; 11031+ 
ofs << kAlignedString << prefix << "_tensors[" << index << "]->format = OH_AI_FORMAT_NHWC;\n"; 11032 ofs << kAlignedString << prefix << "_tensors[" << index << "]->ndim = " << tensor->shape().size() << ";\n"; 11033 size_t shape_size = tensor->shape().size(); 11034 for (size_t i = 0; i < shape_size; i++) { 11035@@ -234,32 +258,31 @@ void CodeCortexSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderConte 11036 ofs << "}\n"; 11037 } 11038 11039-void CodeMSTensorHandleArrayDestroyState(std::ofstream &ofs, const Configurator &config) { 11040- if (config.target() != kCortex_M) { 11041- ofs << handle_array_destroy_state; 11042- } 11043+void CodeMSModelCreateDefault(std::ofstream &ofs) { 11044+ ofs << "OH_AI_ModelHandle OH_AI_ModelCreate() { return model0; }\n"; 11045 } 11046 11047-void CodeMSModelCreateDefault(std::ofstream &ofs) { ofs << "MSModelHandle MSModelCreate() { return model0; }\n"; } 11048- 11049 void CodeMSModelCreate(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 11050 if (config.target() != kCortex_M) { 11051- ofs << "MSStatus MSModelCreate" << ctx->GetCurModelIndex() << "(MicroModel *micro_model) {"; 11052+ ofs << "OH_AI_Status OH_AI_ModelCreate" << ctx->GetCurModelIndex() << "(MicroModel *micro_model) {"; 11053 ofs << R"RAW( 11054 if (micro_model == NULL) { 11055- return kMSStatusLiteNullptr; 11056- } 11057- 11058- void *runtime_buffer = GlobalMemory(); 11059- if (runtime_buffer == NULL) { 11060- return kMSStatusLiteNullptr; 11061+ return OH_AI_STATUS_LITE_NULLPTR; 11062 } 11063- micro_model->runtime_buffer = runtime_buffer; 11064 )RAW"; 11065- ofs << " int ret = SetBuffer" << ctx->GetCurModelIndex() << "(((MemBlock *)runtime_buffer)->addr);\n" 11066- << " if (ret != kMSStatusSuccess) {\n" 11067- << " return kMSStatusLiteMemoryFailed;\n" 11068- << " }\n\n"; 11069+ if (!config.dynamic_shape()) { 11070+ ofs << "void *runtime_buffer = GlobalMemory();\n" 11071+ << "if (runtime_buffer == NULL) {\n" 11072+ << 
" return OH_AI_STATUS_LITE_NULLPTR;\n" 11073+ << " }\n" 11074+ << " micro_model->runtime_buffer = runtime_buffer;\n"; 11075+ ofs << " int ret = SetBuffer" << ctx->GetCurModelIndex() << "(((MemBlock *)runtime_buffer)->addr);\n" 11076+ << " if (ret != OH_AI_STATUS_SUCCESS) {\n" 11077+ << " return OH_AI_STATUS_LITE_MEMORY_FAILED;\n" 11078+ << " }\n\n"; 11079+ } else { 11080+ ofs << " micro_model->runtime_buffer = NULL;\n"; 11081+ } 11082 if (config.code_mode() == CodeMode::Inference) { 11083 ofs << " micro_model->train_mode = false;\n"; 11084 } else if (config.code_mode() == CodeMode::Train) { 11085@@ -269,7 +292,7 @@ void CodeMSModelCreate(std::ofstream &ofs, const std::unique_ptr<CoderContext> & 11086 ofs << kAlignedString << prefix << "_tensors[" << index << "] = malloc(sizeof(MicroTensor));\n"; 11087 ofs << kAlignedString << prefix << "_tensors[" << index << "]->type = " << EnumNameMSDataType(tensor->data_type()) 11088 << ";\n"; 11089- ofs << kAlignedString << prefix << "_tensors[" << index << "]->format = kMSFormatNHWC;\n"; 11090+ ofs << kAlignedString << prefix << "_tensors[" << index << "]->format = OH_AI_FORMAT_NHWC;\n"; 11091 ofs << kAlignedString << prefix << "_tensors[" << index << "]->ndim = " << tensor->shape().size() << ";\n"; 11092 size_t shape_size = tensor->shape().size(); 11093 ofs << kAlignedString << prefix << "_tensors[" << index << "]->shape = " 11094@@ -289,30 +312,30 @@ void CodeMSModelCreate(std::ofstream &ofs, const std::unique_ptr<CoderContext> & 11095 outputs = ctx->graph_train_outputs(); 11096 } 11097 size_t inputs_size = inputs.size(); 11098- ofs << " MSTensorHandleArray model_inputs;\n"; 11099+ ofs << " OH_AI_TensorHandleArray model_inputs;\n"; 11100 ofs << " model_inputs.handle_num = " << inputs_size << ";\n"; 11101 ofs << " MicroTensor **input_tensors = malloc(" << inputs_size << " * sizeof(MicroTensor *));\n"; 11102- ofs << " model_inputs.handle_list = (MSTensorHandle *)(input_tensors);\n"; 11103+ ofs << " model_inputs.handle_list = 
(OH_AI_TensorHandle *)(input_tensors);\n"; 11104 ofs << " micro_model->inputs = model_inputs;\n"; 11105 for (size_t i = 0; i < inputs_size; ++i) { 11106 Tensor *input = inputs[i]; 11107 array_tostring(input, "input", i); 11108 } 11109 size_t outputs_size = outputs.size(); 11110- ofs << " MSTensorHandleArray model_outputs;\n"; 11111+ ofs << " OH_AI_TensorHandleArray model_outputs;\n"; 11112 ofs << " model_outputs.handle_num = " << outputs_size << ";\n"; 11113 ofs << " MicroTensor **output_tensors = malloc(" << outputs_size << " * sizeof(MicroTensor *));\n"; 11114- ofs << " model_outputs.handle_list = (MSTensorHandle *)(output_tensors);\n"; 11115+ ofs << " model_outputs.handle_list = (OH_AI_TensorHandle *)(output_tensors);\n"; 11116 ofs << " micro_model->outputs = model_outputs;\n"; 11117 for (size_t i = 0; i < outputs_size; ++i) { 11118 Tensor *output = outputs[i]; 11119 array_tostring(output, "output", i); 11120 } 11121- ofs << " return kMSStatusSuccess;\n"; 11122+ ofs << " return OH_AI_STATUS_SUCCESS;\n"; 11123 } else { 11124- ofs << "MSStatus MSModelCreate" << ctx->GetCurModelIndex() << "(MicroModel *micro_model) {\n"; 11125+ ofs << "OH_AI_Status OH_AI_ModelCreate" << ctx->GetCurModelIndex() << "(MicroModel *micro_model) {\n"; 11126 ofs << " micro_model->train_mode = false;\n"; 11127- ofs << " return kMSStatusSuccess;\n"; 11128+ ofs << " return OH_AI_STATUS_SUCCESS;\n"; 11129 } 11130 ofs << "}\n\n"; 11131 } 11132@@ -324,20 +347,20 @@ void CodeMSModelBuildCommon(std::ofstream &ofs, const Configurator &config) { 11133 ofs << R"RAW( 11134 MicroModel *micro_model = (MicroModel *)model; 11135 if (micro_model == NULL) { 11136- return kMSStatusLiteNullptr; 11137+ return OH_AI_STATUS_LITE_NULLPTR; 11138 } 11139 if (micro_model->build == NULL) { 11140- return kMSStatusLiteNullptr; 11141+ return OH_AI_STATUS_LITE_NULLPTR; 11142 } 11143 )RAW"; 11144- if (config.target() != kCortex_M) { 11145+ if (config.target() != kCortex_M && !config.dynamic_shape()) { 11146 ofs << " 
IncRefCount();\n"; 11147 } 11148 ofs << R"RAW( 11149- MSStatus ret = 11150+ OH_AI_Status ret = 11151 micro_model->build(model, model_data, data_size, model_context); 11152- if (ret != kMSStatusSuccess) { 11153- MSModelDestroy(model); 11154+ if (ret != OH_AI_STATUS_SUCCESS) { 11155+ OH_AI_ModelDestroy(&model); 11156 } 11157 return ret; 11158 } 11159@@ -345,23 +368,23 @@ void CodeMSModelBuildCommon(std::ofstream &ofs, const Configurator &config) { 11160 } 11161 11162 void CodeMSModelBuild(std::ofstream &ofs, const int model_index, const size_t weight_size, const Configurator &config) { 11163- ofs << "MSStatus MSModelBuild" << model_index 11164- << "(MSModelHandle model, const void *model_data, size_t data_size,\n" 11165- " const MSContextHandle model_context) {\n" 11166+ ofs << "OH_AI_Status OH_AI_ModelBuild" << model_index 11167+ << "(OH_AI_ModelHandle model, const void *model_data, size_t data_size,\n" 11168+ " const OH_AI_ContextHandle model_context) {\n" 11169 " if (model == NULL) {\n" 11170- " return kMSStatusLiteParamInvalid;\n" 11171+ " return OH_AI_STATUS_LITE_PARAM_INVALID;\n" 11172 " }\n"; 11173 if (config.changeable_weights_name().empty()) { 11174 ofs << " if (data_size != " << weight_size 11175 << ") {\n" 11176- " return kMSStatusLiteInputParamInvalid;\n" 11177+ " return OH_AI_STATUS_LITE_INPUT_PARAM_INVALID;\n" 11178 " }\n"; 11179 } 11180 ofs << " MicroModel *micro_model = (MicroModel *)model;\n" 11181- " int ret = MSModelCreate" 11182+ " int ret = OH_AI_ModelCreate" 11183 << model_index 11184 << "(micro_model);\n" 11185- " if (ret != kMSStatusSuccess) {\n" 11186+ " if (ret != OH_AI_STATUS_SUCCESS) {\n" 11187 " return ret;\n" 11188 " }\n"; 11189 if (config.target() != kCortex_M) { 11190@@ -372,7 +395,7 @@ void CodeMSModelBuild(std::ofstream &ofs, const int model_index, const size_t we 11191 if (config.support_parallel()) { 11192 ofs << " MicroContext *micro_context = (MicroContext *)model_context;\n" 11193 " if (micro_context == NULL) {\n" 11194- " 
return kMSStatusLiteNullptr;" 11195+ " return OH_AI_STATUS_LITE_NULLPTR;" 11196 " }\n" 11197 " ret = CreateThreadPool(micro_context->thread_num_);\n" 11198 " if(ret != RET_OK) {\n" 11199@@ -384,35 +407,172 @@ void CodeMSModelBuild(std::ofstream &ofs, const int model_index, const size_t we 11200 ofs << "}\n"; 11201 } 11202 11203+void CodeMSModelResizeInit(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 11204+ auto &dynamic_symbols_num = config.dynamic_symbols_num(); 11205+ std::string array_index; 11206+ for (auto num : dynamic_symbols_num) { 11207+ array_index += "[" + std::to_string(num) + "]"; 11208+ } 11209+ auto shapes = ctx->shape_all_scenes(); 11210+ if (!shapes.empty()) { 11211+ auto num_of_each_scene = shapes.begin()->second.size(); 11212+ ofs << " static int shapes" << array_index << "[" + std::to_string(num_of_each_scene) + "] = {"; 11213+ for (auto &item : shapes) { 11214+ auto &shape_val = item.second; 11215+ for (size_t j = 0; j < shape_val.size(); ++j) { 11216+ ofs << shape_val[j] << ", "; 11217+ } 11218+ } 11219+ ofs << "};\n"; 11220+ } 11221+ auto offsets = ctx->offset_all_scenes(); 11222+ if (!offsets.empty()) { 11223+ auto num_of_each_scene = offsets.begin()->second.size(); 11224+ ofs << " static int offsets" << array_index << "[" + std::to_string(num_of_each_scene) + "] = {"; 11225+ for (auto &item : offsets) { 11226+ auto &offset_val = item.second; 11227+ for (size_t j = 0; j < offset_val.size(); ++j) { 11228+ ofs << offset_val[j] << ", "; 11229+ } 11230+ } 11231+ ofs << "};\n"; 11232+ } 11233+ ofs << " size_t buffer_sizes" << array_index << " = {"; 11234+ auto buffer_size = ctx->buffer_sizes(); 11235+ auto workspace = ctx->workspaces(); 11236+ if (buffer_size.size() != workspace.size()) { 11237+ return; 11238+ } 11239+ for (size_t i = 0; i < buffer_size.size(); i++) { 11240+ ofs << buffer_size[i] + workspace[i] << ", "; 11241+ } 11242+ ofs << "};\n"; 11243+} 11244+ 11245+void 
CodeMSModelResize(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 11246+ auto &shape_templates = ctx->shape_templates(); 11247+ ofs << "OH_AI_Status OH_AI_ModelResize" << ctx->GetCurModelIndex() 11248+ << "(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_ShapeInfo *shape_infos, size_t " 11249+ "shape_info_num) {\n" 11250+ " if (model == NULL) {\n" 11251+ " return OH_AI_STATUS_LITE_PARAM_INVALID;\n" 11252+ " }\n"; 11253+ if (!config.dynamic_shape()) { 11254+ ofs << " return OH_AI_STATUS_LITE_NOT_SUPPORT;\n"; 11255+ } else { 11256+ ofs << " MicroModel *micro_model = (MicroModel *)model;\n" 11257+ << " if (micro_model == NULL) {\n" 11258+ " return OH_AI_STATUS_LITE_NULLPTR;\n" 11259+ " }\n"; 11260+ CodeMSModelResizeInit(ofs, ctx, config); 11261+ std::map<std::string, std::vector<int>> symbol_to_indexes; 11262+ std::map<std::string, std::string> user_to_inner; 11263+ auto &user_graph_inputs_template = config.user_graph_inputs_template(); 11264+ for (size_t i = 0; i < ctx->graph_inputs().size(); ++i) { 11265+ auto cur_tensor = ctx->graph_inputs()[i]; 11266+ auto cur_shapes = shape_templates.at(cur_tensor); 11267+ for (size_t j = 0; j < cur_shapes.size(); ++j) { 11268+ if (IsNumber(cur_shapes.at(j))) { 11269+ continue; 11270+ } 11271+ ofs << " if (shape_infos[" << i << "].shape[" << j << "] <= 0) {\n" 11272+ << " return OH_AI_STATUS_LITE_PARAM_INVALID;\n" 11273+ << " }\n"; 11274+ ofs << " ((MicroTensor *)(inputs.handle_list[" << i << "]))->shape[" << j << "] = shape_infos[" << i 11275+ << "].shape[" << j << "];\n"; 11276+ if (symbol_to_indexes.find(cur_shapes.at(j)) != symbol_to_indexes.end()) { 11277+ continue; 11278+ } 11279+ symbol_to_indexes[cur_shapes.at(j)] = {static_cast<int>(i), static_cast<int>(j)}; 11280+ user_to_inner[user_graph_inputs_template[i][j]] = cur_shapes.at(j); 11281+ } 11282+ } 11283+ int index = 0; 11284+ std::map<std::string, std::string> inner_to_outer; 11285+ for (auto &item : 
symbol_to_indexes) { 11286+ ofs << " int dim" << index << " = shape_infos[" << item.second[0] << "].shape[" << item.second[1] << "];\n"; 11287+ inner_to_outer[item.first] = "dim" + std::to_string(index); 11288+ ++index; 11289+ } 11290+ std::string condition; 11291+ index = 0; 11292+ for (; index < static_cast<int>(symbol_to_indexes.size()) - 1; ++index) { 11293+ condition += "store" + std::to_string(ctx->GetCurModelIndex()) + "_" + std::to_string(index) + " == dim" + 11294+ std::to_string(index) + " && "; 11295+ } 11296+ condition += "store" + std::to_string(ctx->GetCurModelIndex()) + "_" + std::to_string(index) + " == dim" + 11297+ std::to_string(index); 11298+ ofs << " if (" << condition << ") {\n" 11299+ << " return OH_AI_STATUS_SUCCESS;\n" 11300+ << " }\n"; 11301+ for (size_t i = 0; i < symbol_to_indexes.size(); ++i) { 11302+ ofs << " store" + std::to_string(ctx->GetCurModelIndex()) + "_" << i << " = dim" << i << ";\n"; 11303+ } 11304+ ofs << " if (" << kBufferPrefixName << " != NULL) {\n"; 11305+ ofs << " free(" << kBufferPrefixName << ");\n"; 11306+ ofs << " }\n"; 11307+ std::string real_array_index; 11308+ auto &dynamic_symbols = config.dynamic_symbols(); 11309+ for (auto &symbol : dynamic_symbols) { 11310+ real_array_index += "[" + inner_to_outer[user_to_inner[symbol]] + " - 1]"; 11311+ } 11312+ ofs << " " << kBufferPrefixName << " = malloc(buffer_sizes" << real_array_index << ");\n"; 11313+ ofs << " micro_model->runtime_buffer = " << kBufferPrefixName << ";\n"; 11314+ ofs << " " << kShapePrefixName << " = &shapes" << real_array_index << "[0];\n"; 11315+ ofs << " " << kOffsetPrefixName << " = &offsets" << real_array_index << "[0];\n"; 11316+ ofs << " OH_AI_TensorHandleArray outputs = OH_AI_ModelGetOutputs(model);\n"; 11317+ for (size_t i = 0; i < ctx->graph_outputs().size(); ++i) { 11318+ ofs << " OH_AI_TensorSetData(outputs.handle_list[" << i << "], NULL);\n"; 11319+ auto cur_tensor = ctx->graph_outputs()[i]; 11320+ auto cur_shapes = 
shape_templates.at(cur_tensor); 11321+ for (size_t j = 0; j < cur_shapes.size(); ++j) { 11322+ if (IsNumber(cur_shapes.at(j))) { 11323+ continue; 11324+ } 11325+ ofs << " ((MicroTensor *)(outputs.handle_list[" << i << "]))->shape[" << j << "] = " << cur_shapes.at(j) 11326+ << ";\n"; 11327+ } 11328+ } 11329+ ofs << " return OH_AI_STATUS_SUCCESS;\n"; 11330+ } 11331+ ofs << "}\n"; 11332+} 11333+ 11334 void CodeMSModelDestory(std::ofstream &ofs, const Configurator *config) { 11335- if (config->target() != kCortex_M) { 11336+ if (config->code_mode() == CodeMode::Inference && config->target() != kCortex_M) { 11337 ofs << handle_array_destroy; 11338 } 11339- ofs << "void MSModelDestroy(MSModelHandle *model) {\n"; 11340+ ofs << "void OH_AI_ModelDestroy(OH_AI_ModelHandle *model) {\n"; 11341+ ofs << " if (*model) {\n" 11342+ " MicroModel *micro_model = (MicroModel *)*model;\n"; 11343 if (config->target() != kCortex_M) { 11344- ofs << " if (*model) {\n" 11345- " MicroModel *micro_model = (MicroModel *)*model;\n"; 11346- ofs << " if (micro_model->runtime_buffer) {\n" 11347- " micro_model->runtime_buffer = NULL;\n" 11348- " }\n"; 11349- ofs << " MSTensorHandleArrayDestroy(micro_model->inputs);\n" 11350- " MSTensorHandleArrayDestroy(micro_model->outputs);\n" 11351- " micro_model->inputs.handle_list = NULL;\n" 11352+ ofs << " if (micro_model->runtime_buffer) {\n"; 11353+ if (config->dynamic_shape()) { 11354+ ofs << " free(micro_model->runtime_buffer);\n"; 11355+ } else { 11356+ ofs << " micro_model->runtime_buffer = NULL;\n"; 11357+ } 11358+ ofs << " }\n"; 11359+ } 11360+ ofs << " OH_AI_TensorHandleArrayDestroy(micro_model->inputs);\n" 11361+ " OH_AI_TensorHandleArrayDestroy(micro_model->outputs);\n"; 11362+ if (config->code_mode() == CodeMode::Inference) { 11363+ ofs << " micro_model->inputs.handle_list = NULL;\n" 11364 " micro_model->outputs.handle_list = NULL;\n" 11365- " micro_model->free_resource();\n" 11366- " DecRefCount();\n" 11367- " }\n"; 11368- 11369- if 
(config->support_parallel()) { 11370- ofs << " ClearThreadPool();\n"; 11371+ " micro_model->free_resource();\n"; 11372+ if (!config->dynamic_shape()) { 11373+ ofs << " DecRefCount();\n"; 11374 } 11375+ ofs << " }\n"; 11376 } else { 11377- ofs << " if (*model) {\n" 11378- " MicroModel *micro_model = (MicroModel *)*model;\n"; 11379- ofs << " micro_model->runtime_buffer = NULL;\n" 11380+ ofs << " free(*model);\n" 11381 " *model = NULL;\n" 11382 " }\n"; 11383 } 11384+ 11385+ if (config->support_parallel()) { 11386+ ofs << " ClearThreadPool();\n"; 11387+ } 11388 ofs << "}\n"; 11389 } 11390 11391@@ -420,14 +580,14 @@ void CodeMSModelPredictState(std::ofstream &ofs) { ofs << micro_model_predict_st 11392 11393 void CodeMSModelPredictCommon(std::ofstream &ofs) { 11394 ofs << R"RAW( 11395-MSStatus MSModelPredict(MSModelHandle model, const MSTensorHandleArray inputs, MSTensorHandleArray *outputs, 11396- const MSKernelCallBackC before, const MSKernelCallBackC after) { 11397+OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_TensorHandleArray *outputs, 11398+ const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) { 11399 MicroModel *micro_model = (MicroModel *)model; 11400 if (micro_model == NULL) { 11401- return kMSStatusLiteNullptr; 11402+ return OH_AI_STATUS_LITE_NULLPTR; 11403 } 11404 if (micro_model->predict == NULL) { 11405- return kMSStatusLiteNullptr; 11406+ return OH_AI_STATUS_LITE_NULLPTR; 11407 } 11408 return micro_model->predict(model, inputs, outputs, before, after); 11409 } 11410@@ -438,35 +598,35 @@ MSStatus MSModelPredict(MSModelHandle model, const MSTensorHandleArray inputs, M 11411 void CodeMSModelPredict(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 11412 auto inputs_num = ctx->graph_inputs().size(); 11413 auto outputs_num = ctx->graph_outputs().size(); 11414- ofs << "MSStatus MSModelPredict" << ctx->GetCurModelIndex() 11415- << "(MSModelHandle model, 
const MSTensorHandleArray inputs, MSTensorHandleArray *outputs,\n" 11416- << " const MSKernelCallBackC before, const MSKernelCallBackC after) {\n"; 11417+ ofs << "OH_AI_Status OH_AI_ModelPredict" << ctx->GetCurModelIndex() 11418+ << "(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_TensorHandleArray *outputs,\n" 11419+ << " const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) {\n"; 11420 ofs << R"RAW( 11421 MicroModel *micro_model = (MicroModel *)model; 11422 if (micro_model == NULL) { 11423- return kMSStatusLiteNullptr; 11424+ return OH_AI_STATUS_LITE_NULLPTR; 11425 } 11426 if (micro_model->runtime_buffer == NULL) { 11427- return kMSStatusLiteMemoryFailed; 11428+ return OH_AI_STATUS_LITE_MEMORY_FAILED; 11429 } 11430 )RAW"; 11431 ofs << " if (inputs.handle_num != " << inputs_num << ") {\n"; 11432- ofs << " return kMSStatusLiteParamInvalid;\n"; 11433+ ofs << " return OH_AI_STATUS_LITE_PARAM_INVALID;\n"; 11434 ofs << " }\n"; 11435 ofs << " if (outputs->handle_num != " << outputs_num << ") {\n"; 11436- ofs << " return kMSStatusLiteParamInvalid;\n"; 11437+ ofs << " return OH_AI_STATUS_LITE_PARAM_INVALID;\n"; 11438 ofs << " }\n"; 11439- if (config.target() != kCortex_M) { 11440+ if (config.target() != kCortex_M && !config.dynamic_shape()) { 11441 ofs << " if (!LockBuffer(micro_model->runtime_buffer)) {\n" 11442 << " void *buffer = Malloc(GetBufferSize" << ctx->GetCurModelIndex() << "());\n" 11443 << " if (buffer == NULL) {\n" 11444- << " return kMSStatusLiteNullptr;\n" 11445+ << " return OH_AI_STATUS_LITE_NULLPTR;\n" 11446 << " }\n" 11447 << " if (micro_model->runtime_buffer != buffer) {\n" 11448 << " micro_model->runtime_buffer = buffer;\n" 11449 << " int ret = SetBuffer" << ctx->GetCurModelIndex() << "(((MemBlock *)buffer)->addr);\n" 11450- << " if (ret != kMSStatusSuccess) {\n" 11451- << " return kMSStatusLiteMemoryFailed;\n" 11452+ << " if (ret != OH_AI_STATUS_SUCCESS) {\n" 11453+ << " return OH_AI_STATUS_LITE_MEMORY_FAILED;\n" 
11454        <<  "      }\n" 11455        <<  "    }\n" 11456        <<  "  }\n"; 11457@@ -495,8 +655,7 @@ void CodeMSModelPredict(std::ofstream &ofs, const std::unique_ptr<CoderContext> 11458    ofs << "    }\n"; 11459    ofs << "  }\n"; 11460    ofs << "\n"; 11461-  ofs << "  void *outputs_data_array[" << outputs_num << "];\n"; 11462-  ofs << "  int expect_out_types[" << outputs_num << "] = {"; 11463+  ofs << "  int cur_out_types[" << outputs_num << "] = {"; 11464    for (size_t i = 0; i < outputs_num; ++i) { 11465      ofs << ctx->graph_outputs().at(i)->data_type() << ", "; 11466    } 11467@@ -506,21 +665,18 @@ void CodeMSModelPredict(std::ofstream &ofs, const std::unique_ptr<CoderContext> 11468      ofs << "false, "; 11469    } 11470    ofs << "};\n"; 11471-  ofs << "  for (int i = 0; i < " << outputs_num << "; i++) {\n"; 11472-  ofs << "    outputs_data_array[i] = MSTensorGetMutableData(outputs->handle_list[i]);\n"; 11473-  ofs << "  }\n"; 11474-  ofs << "  CopyOutputsData" << ctx->GetCurModelIndex() 11475-      << "(outputs, outputs_data_array, expect_out_types, out_type_changed);\n"; 11476-  if (config.target() != kCortex_M) { 11477+  ofs << "  OH_AI_Status ret = CopyOutputsData" << ctx->GetCurModelIndex() 11478+      << "(outputs, cur_out_types, out_type_changed);\n"; 11479+  if (config.target() != kCortex_M && !config.dynamic_shape()) { 11480    ofs << "  UnLockBuffer(micro_model->runtime_buffer);\n"; 11481  } 11482-  ofs << "  return kMSStatusSuccess;\n"; 11483+  ofs << "  return ret;\n"; 11484  ofs << "}\n"; 11485 } 11486 11487 void CodeCopyOutputsState(std::ofstream &ofs, const int model_index) { 11488-  ofs << "int CopyOutputsData" << model_index 11489-      << "(MSTensorHandleArray *outputs_ori, void **outputs, int *expect_types, bool *type_changed);\n\n"; 11490+  ofs << "OH_AI_Status CopyOutputsData" << model_index 11491+      << "(OH_AI_TensorHandleArray *outputs_ori, int *cur_out_types, bool *type_changed);\n\n"; 11492 } 11493 11494 void CodeCopyOutputsImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) { 11495@@ -528,56 +684,60
@@ void CodeCopyOutputsImplement(std::ofstream &ofs, const std::unique_ptr<CoderCon 11496 std::vector<Tensor *> outputs = ctx->graph_outputs(); 11497 size_t outputs_size = outputs.size(); 11498 11499- ofs << "int CopyOutputsData" << ctx->GetCurModelIndex() 11500- << "(MSTensorHandleArray *outputs_ori, void **outputs, int *expect_types, bool *type_changed) {\n" 11501- " if (outputs_ori == NULL || outputs == NULL) {\n" 11502- " return RET_ERROR;\n" 11503+ ofs << "OH_AI_Status CopyOutputsData" << ctx->GetCurModelIndex() 11504+ << "(OH_AI_TensorHandleArray *outputs_ori, int *cur_out_types, bool *type_changed) {\n" 11505+ " if (outputs_ori == NULL || cur_out_types == NULL || type_changed == NULL) {\n" 11506+ " return OH_AI_STATUS_LITE_NULLPTR;\n" 11507 " }\n"; 11508 ofs << " unsigned char *buffer[" << outputs_size << "] = {"; 11509 for (size_t i = 0; i < outputs_size; ++i) { 11510- ofs << tensor_map[outputs[i]] << ", "; 11511- } 11512- ofs << "};\n"; 11513- ofs << " size_t buffer_size[" << outputs_size << "] = {"; 11514- for (size_t i = 0; i < outputs_size; ++i) { 11515- Tensor *output = outputs[i]; 11516- MS_CHECK_PTR_IF_NULL(output); 11517- ofs << output->Size() << ", "; 11518+ auto out_str = ctx->tensor_addr(outputs[i]); 11519+ if (out_str.empty()) { 11520+ ofs << tensor_map[outputs[i]] << ", "; 11521+ } else { 11522+ ofs << out_str << ", "; 11523+ } 11524 } 11525 ofs << "};\n"; 11526 ofs << " for (int i = 0; i < " << outputs_size << "; i++) {\n" 11527 << " MicroTensor *micro_tensor = (MicroTensor *)outputs_ori->handle_list[i];\n" 11528- << " int cur_type = micro_tensor->type;\n" 11529- << " int expect_type = expect_types[i];\n"; 11530- ofs << " if (cur_type == expect_type) {\n" 11531- << " memcpy(outputs[i], buffer[i], buffer_size[i]);\n" 11532+ << " int expect_type = micro_tensor->type;\n" 11533+ << " int cur_type = cur_out_types[i];\n"; 11534+ ofs << " if (expect_type == cur_type) {\n" 11535+ << " micro_tensor->data = buffer[i];\n" 11536+ << " micro_tensor->owned 
= false;\n" 11537 << " continue;\n" 11538 << " }\n" 11539+ << "#ifdef ENABLE_FP16\n" 11540 << " int shape_size = micro_tensor->ndim;\n" 11541 << " int num = 1;\n" 11542- << " for (int i = 0; i < shape_size; ++i) {\n" 11543- << " num *= micro_tensor->shape[i];\n" 11544+ << " for (int j = 0; j < shape_size; ++j) {\n" 11545+ << " num *= micro_tensor->shape[j];\n" 11546 << " }\n"; 11547- ofs << " int type_trans_mode = TypeTransMode_MAX;\n" 11548- " if (expect_type == kMSDataTypeNumberTypeFloat16 && cur_type == kMSDataTypeNumberTypeFloat32) {\n" 11549- " type_trans_mode = TypeTransMode_FP32_TO_FP16;\n" 11550- " } else if (expect_type == kMSDataTypeNumberTypeFloat32 && cur_type == kMSDataTypeNumberTypeFloat16) {\n" 11551- " type_trans_mode = TypeTransMode_FP16_TO_FP32;\n" 11552- " }\n"; 11553+ ofs 11554+ << " int type_trans_mode = TypeTransMode_MAX;\n" 11555+ " if (expect_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT16 && cur_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT32) {\n" 11556+ " type_trans_mode = TypeTransMode_FP32_TO_FP16;\n" 11557+ " } else if (expect_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT32 && cur_type == " 11558+ "OH_AI_DATATYPE_NUMBERTYPE_FLOAT16) {\n" 11559+ " type_trans_mode = TypeTransMode_FP16_TO_FP32;\n" 11560+ " }\n"; 11561 ofs << " if (type_trans_mode == TypeTransMode_UNSUPPORT) {\n" 11562- << " return kMSStatusLiteNotSupport;\n" 11563+ << " return OH_AI_STATUS_LITE_NOT_SUPPORT;\n" 11564 << " }\n"; 11565- ofs << "#ifdef ENABLE_FP16\n" 11566- << " if (type_trans_mode == TypeTransMode_FP32_TO_FP16) {\n" 11567- << " Fp32CastToFp16((float *)(buffer[i]), (float16_t *)&outputs, num);\n" 11568+ ofs << " void *out_data = OH_AI_TensorGetMutableData(micro_tensor);\n"; 11569+ ofs << " if (type_trans_mode == TypeTransMode_FP32_TO_FP16) {\n" 11570+ << " Fp32CastToFp16((float *)(buffer[i]), (float16_t *)out_data, num);\n" 11571 << " type_changed[i] = true;\n" 11572 << " } else if (type_trans_mode == TypeTransMode_FP16_TO_FP32) {\n" 11573- << " Fp16CastToFp32((float16_t 
*)&outputs, (float *)(buffer[i]), num);\n" 11574+ << " Fp16CastToFp32((float16_t *)(buffer[i]), (float *)out_data, num);\n" 11575 << " type_changed[i] = true;\n" 11576 << " }\n" 11577+ << "#else\n" 11578+ << " return OH_AI_STATUS_LITE_NOT_SUPPORT;\n" 11579 << "#endif\n" 11580 << " }\n"; 11581- ofs << " return RET_OK;\n" 11582+ ofs << " return OH_AI_STATUS_SUCCESS;\n" 11583 "}\n\n"; 11584 } 11585 11586@@ -688,6 +848,16 @@ void CodeInitResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderCo 11587 "}\n"; 11588 } 11589 11590+void CodeResetImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 11591+ ofs << "void Reset" << ctx->GetCurModelIndex() << "() {\n"; 11592+ auto &dynamic_symbols = config.dynamic_symbols(); 11593+ for (size_t i = 0; i < dynamic_symbols.size(); ++i) { 11594+ ofs << " store" << ctx->GetCurModelIndex() << "_" << i << " = -1;\n"; 11595+ } 11596+ ofs << " FreeResource" << ctx->GetCurModelIndex() << "();\n"; 11597+ ofs << "}\n"; 11598+} 11599+ 11600 void CodeFreeResourceState(std::ofstream &ofs) { ofs << free_resource_state; } 11601 11602 void CodeFreeResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, 11603diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.h b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.h 11604index 56209f05..6f0c7736 100644 11605--- a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.h 11606+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.h 11607@@ -32,12 +32,13 @@ void CodeMSModelCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<Code 11608 void CodeCortexCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx); 11609 void CodeMSModelSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config); 11610 void 
CodeCortexSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx); 11611-void CodeMSTensorHandleArrayDestroyState(std::ofstream &ofs, const Configurator &config); 11612 void CodeMSModelCreateDefault(std::ofstream &ofs); 11613 void CodeMSModelCreate(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config); 11614 void CodeMSModelBuildState(std::ofstream &ofs); 11615 void CodeMSModelBuildCommon(std::ofstream &ofs, const Configurator &config); 11616 void CodeMSModelBuild(std::ofstream &ofs, const int model_index, const size_t weight_size, const Configurator &config); 11617+void CodeMSModelResizeInit(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config); 11618+void CodeMSModelResize(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config); 11619 void CodeMSModelDestory(std::ofstream &ofs, const Configurator *config); 11620 void CodeMSModelPredictState(std::ofstream &ofs); 11621 void CodeMSModelPredictCommon(std::ofstream &ofs); 11622@@ -57,6 +58,7 @@ void CodeGraphQuantArgsImplement(std::ofstream &ofs, const std::unique_ptr<Coder 11623 void CodeManageResourceState(std::ofstream &ofs, const int model_index); 11624 void CodeInitResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx); 11625 11626+void CodeResetImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config); 11627 void CodeFreeResourceState(std::ofstream &ofs); 11628 void CodeFreeResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, 11629 const Configurator &config); 11630diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/component.cc 11631index b2ed21be..0ee02e0c 100644 11632--- a/mindspore/lite/tools/converter/micro/coder/generator/component/component.cc 11633+++ 
b/mindspore/lite/tools/converter/micro/coder/generator/component/component.cc 11634@@ -24,6 +24,8 @@ const char *kOutputPrefixName = nullptr; 11635 const char *kWeightPrefixName = nullptr; 11636 const char *kBufferPrefixName = nullptr; 11637 const char *kBufferPrefixNameAdd = nullptr; 11638+const char *kOffsetPrefixName = nullptr; 11639+const char *kShapePrefixName = nullptr; 11640 11641 char *ModifyPrefixName(char *name, int model_index, const std::string &prefix) { 11642   if (name != nullptr) { 11643@@ -57,6 +59,8 @@ void FreeGlobalVariable() { 11644   Free(kWeightPrefixName); 11645   Free(kBufferPrefixName); 11646   Free(kBufferPrefixNameAdd); 11647+  Free(kOffsetPrefixName); 11648+  Free(kShapePrefixName); 11649 } 11650 11651 void InitGlobalVariable(int model_index) { 11652@@ -65,5 +69,7 @@ void InitGlobalVariable(int model_index) { 11653   kWeightPrefixName = ModifyPrefixName(const_cast<char *>(kWeightPrefixName), model_index, "_weight"); 11654   kBufferPrefixName = ModifyPrefixName(const_cast<char *>(kBufferPrefixName), model_index, "_buffer"); 11655   kBufferPrefixNameAdd = ModifyPrefixName(const_cast<char *>(kBufferPrefixNameAdd), model_index, "_buffer + "); 11656+  kOffsetPrefixName = ModifyPrefixName(const_cast<char *>(kOffsetPrefixName), model_index, "_offset"); 11657+  kShapePrefixName = ModifyPrefixName(const_cast<char *>(kShapePrefixName), model_index, "_shape"); 11658 } 11659 }  // namespace mindspore::lite::micro 11660diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/component.h b/mindspore/lite/tools/converter/micro/coder/generator/component/component.h 11661index 0e943317..e084d692 100644 11662--- a/mindspore/lite/tools/converter/micro/coder/generator/component/component.h 11663+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/component.h 11664@@ -16,7 +16,6 @@ 11665 11666 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_COMPONENT_COMPONENT_H_ 11667 #define
MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_COMPONENT_COMPONENT_H_ 11668-#include <string> 11669 11670 namespace mindspore::lite::micro { 11671 extern const char *kInputPrefixName; 11672@@ -26,6 +25,8 @@ constexpr auto kPackWeightOffsetName = "w_offset"; 11673 constexpr auto kPackWeightSizeName = "w_size"; 11674 extern const char *kBufferPrefixName; 11675 extern const char *kBufferPrefixNameAdd; 11676+extern const char *kOffsetPrefixName; 11677+extern const char *kShapePrefixName; 11678 void FreeGlobalVariable(); 11679 void InitGlobalVariable(int model_index); 11680 11681diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/benchmark.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/benchmark.cc 11682index 91f2ca89..ad638276 100644 11683--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/benchmark.cc 11684+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/benchmark.cc 11685@@ -53,7 +53,7 @@ const char benchmark_source[] = R"RAW(/** 11686 11687 void usage() { 11688   printf( 11689-    "-- mindspore benchmark params usage:\n" 11690+    "-- mindspore benchmark params usage:\n" 11691     "args[0]: executable file\n" 11692     "args[1]: inputs binary file\n" 11693     "args[2]: model weight binary file\n" 11694@@ -67,38 +67,38 @@ void usage() { 11695 11696 uint64_t GetTimeUs() { 11697   const int USEC = 1000000; 11698-  const int MSEC = 1000; 11699+  const int MSEC = 1000; 11700   struct timespec ts = {0, 0}; 11701   if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) { 11702     return 0; 11703   } 11704-  uint64_t retval = (uint64_t)((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC)); 11705+  uint64_t retval = (uint64_t)((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC)); 11706   return retval; 11707 } 11708 11709-void PrintTensorHandle(MSTensorHandle tensor) { 11710-  printf("name: %s, ", MSTensorGetName(tensor)); 11711-  MSDataType data_type = MSTensorGetDataType(tensor); 11712+void
PrintTensorHandle(OH_AI_TensorHandle tensor) { 11713+ printf("name: %s, ", OH_AI_TensorGetName(tensor)); 11714+ OH_AI_DataType data_type = OH_AI_TensorGetDataType(tensor); 11715 printf("DataType: %d, ", data_type); 11716- size_t element_num = (size_t)(MSTensorGetElementNum(tensor)); 11717+ size_t element_num = (size_t)(OH_AI_TensorGetElementNum(tensor)); 11718 printf("Elements: %zu, ", element_num); 11719 printf("Shape: ["); 11720 size_t shape_num = 0; 11721- const int64_t *dims = MSTensorGetShape(tensor, &shape_num); 11722+ const int64_t *dims = OH_AI_TensorGetShape(tensor, &shape_num); 11723 for (size_t i = 0; i < shape_num; i++) { 11724 printf("%d ", (int)dims[i]); 11725 } 11726 printf("], Data: \n"); 11727- void *data = MSTensorGetMutableData(tensor); 11728+ void *data = OH_AI_TensorGetMutableData(tensor); 11729 element_num = element_num > 10 ? 10 : element_num; 11730 switch (data_type) { 11731- case kMSDataTypeNumberTypeFloat32: { 11732+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT32: { 11733 for (size_t i = 0; i < element_num; i++) { 11734 printf("%.6f, ", ((float *)data)[i]); 11735 } 11736 printf("\n"); 11737 } break; 11738- case kMSDataTypeNumberTypeFloat16: 11739+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT16: 11740 #ifdef ENABLE_FP16 11741 { 11742 for (size_t i = 0; i < element_num; i++) { 11743@@ -107,25 +107,25 @@ void PrintTensorHandle(MSTensorHandle tensor) { 11744 printf("\n"); 11745 } break; 11746 #endif 11747- case kMSDataTypeNumberTypeInt16: { 11748+ case OH_AI_DATATYPE_NUMBERTYPE_INT16: { 11749 for (size_t i = 0; i < element_num; i++) { 11750 printf("%" PRId16, ((int16_t *)data)[i]); 11751 } 11752 printf("\n"); 11753 } break; 11754- case kMSDataTypeNumberTypeInt32: { 11755+ case OH_AI_DATATYPE_NUMBERTYPE_INT32: { 11756 for (size_t i = 0; i < element_num; i++) { 11757 printf("%" PRId32, ((int32_t *)data)[i]); 11758 } 11759 printf("\n"); 11760 } break; 11761- case kMSDataTypeNumberTypeInt8: { 11762+ case OH_AI_DATATYPE_NUMBERTYPE_INT8: { 11763 for (size_t i = 0; 
i < element_num; i++) { 11764 printf("%" PRIi8, ((int8_t *)data)[i]); 11765 } 11766 printf("\n"); 11767 } break; 11768- case kMSDataTypeNumberTypeUInt8: { 11769+ case OH_AI_DATATYPE_NUMBERTYPE_UINT8: { 11770 for (size_t i = 0; i < element_num; i++) { 11771 printf("%u", ((uint8_t *)data)[i]); 11772 } 11773@@ -141,31 +141,31 @@ int main(int argc, const char **argv) { 11774 if (argc < 2) { 11775 printf("input command is invalid\n"); 11776 usage(); 11777- return kMSStatusLiteError; 11778+ return OH_AI_STATUS_LITE_ERROR; 11779 } 11780 printf("=======run benchmark======\n"); 11781 11782- MSContextHandle ms_context_handle = MSContextCreate(); 11783+ OH_AI_ContextHandle ms_context_handle = OH_AI_ContextCreate(); 11784 if (argc >= 6) { 11785 int thread_num = atoi(argv[5]); 11786 if (thread_num < 1 || thread_num > kMaxThreadNum) { 11787 printf("Thread number error! It should be greater than 0 and less than 5\n"); 11788- return kMSStatusLiteParamInvalid; 11789+ return OH_AI_STATUS_LITE_PARAM_INVALID; 11790 } 11791- MSContextSetThreadNum(ms_context_handle, thread_num); 11792+ OH_AI_ContextSetThreadNum(ms_context_handle, thread_num); 11793 } 11794- printf("ThreadNum: %d.\n", MSContextGetThreadNum(ms_context_handle)); 11795+ printf("ThreadNum: %d.\n", OH_AI_ContextGetThreadNum(ms_context_handle)); 11796 11797 int bind_mode = kBindDefault; 11798 if (argc >= 7) { 11799 bind_mode = atoi(argv[6]); 11800 if (bind_mode < 0 || bind_mode > 2) { 11801 printf("Thread bind mode error! 
0: No bind, 1: Bind hign cpu, 2: Bind mid cpu.\n"); 11802- return kMSStatusLiteParamInvalid; 11803+ return OH_AI_STATUS_LITE_PARAM_INVALID; 11804 } 11805 } 11806- MSContextSetThreadAffinityMode(ms_context_handle, bind_mode); 11807- printf("BindMode: %d.\n", MSContextGetThreadAffinityMode(ms_context_handle)); 11808+ OH_AI_ContextSetThreadAffinityMode(ms_context_handle, bind_mode); 11809+ printf("BindMode: %d.\n", OH_AI_ContextGetThreadAffinityMode(ms_context_handle)); 11810 11811 void *model_buffer = NULL; 11812 int model_size = 0; 11813@@ -174,14 +174,14 @@ int main(int argc, const char **argv) { 11814 model_buffer = ReadInputData(argv[2], &model_size); 11815 if (model_buffer == NULL) { 11816 printf("Read model file failed."); 11817- return kMSStatusLiteParamInvalid; 11818+ return OH_AI_STATUS_LITE_PARAM_INVALID; 11819 } 11820 } 11821- MSModelHandle model_handle = MSModelCreate(); 11822- int ret = MSModelBuild(model_handle, model_buffer, model_size, kMSModelTypeMindIR, ms_context_handle); 11823- MSContextDestroy(&ms_context_handle); 11824- if (ret != kMSStatusSuccess) { 11825- printf("MSModelBuildFromFile failed, ret: %d\n", ret); 11826+ OH_AI_ModelHandle model_handle = OH_AI_ModelCreate(); 11827+ int ret = OH_AI_ModelBuild(model_handle, model_buffer, model_size, OH_AI_MODELTYPE_MINDIR, ms_context_handle); 11828+ OH_AI_ContextDestroy(&ms_context_handle); 11829+ if (ret != OH_AI_STATUS_SUCCESS) { 11830+ printf("OH_AI_ModelBuild failed, ret: %d\n", ret); 11831 free(model_buffer); 11832 model_buffer = NULL; 11833 return ret; 11834@@ -191,33 +191,33 @@ int main(int argc, const char **argv) { 11835 model_buffer = NULL; 11836 } 11837 // set model inputs tensor data 11838- MSTensorHandleArray inputs_handle = MSModelGetInputs(model_handle); 11839+ OH_AI_TensorHandleArray inputs_handle = OH_AI_ModelGetInputs(model_handle); 11840 if (inputs_handle.handle_list == NULL) { 11841- printf("MSModelGetInputs failed, ret: %d", ret); 11842+ printf("OH_AI_ModelGetInputs failed, ret: 
%d", ret); 11843 return ret; 11844 } 11845 size_t inputs_num = inputs_handle.handle_num; 11846 void *inputs_binbuf[inputs_num]; 11847 int inputs_size[inputs_num]; 11848 for (size_t i = 0; i < inputs_num; ++i) { 11849- MSTensorHandle tensor = inputs_handle.handle_list[i]; 11850- inputs_size[i] = (int)MSTensorGetDataSize(tensor); 11851+ OH_AI_TensorHandle tensor = inputs_handle.handle_list[i]; 11852+ inputs_size[i] = (int)OH_AI_TensorGetDataSize(tensor); 11853 } 11854 ret = ReadInputsFile((char *)(argv[1]), inputs_binbuf, inputs_size, (int)inputs_num); 11855 if (ret != 0) { 11856- MSModelDestroy(&model_handle); 11857+ OH_AI_ModelDestroy(&model_handle); 11858 return ret; 11859 } 11860 for (size_t i = 0; i < inputs_num; ++i) { 11861- void *input_data = MSTensorGetMutableData(inputs_handle.handle_list[i]); 11862+ void *input_data = OH_AI_TensorGetMutableData(inputs_handle.handle_list[i]); 11863 memcpy(input_data, inputs_binbuf[i], inputs_size[i]); 11864 free(inputs_binbuf[i]); 11865 inputs_binbuf[i] = NULL; 11866 } 11867 11868- MSTensorHandleArray outputs_handle = MSModelGetOutputs(model_handle); 11869+ OH_AI_TensorHandleArray outputs_handle = OH_AI_ModelGetOutputs(model_handle); 11870 if (!outputs_handle.handle_list) { 11871- printf("MSModelGetOutputs failed, ret: %d", ret); 11872+ printf("OH_AI_ModelGetOutputs failed, ret: %d", ret); 11873 return ret; 11874 } 11875 11876@@ -226,15 +226,15 @@ int main(int argc, const char **argv) { 11877 warm_up_loop_count = atoi(argv[7]); 11878 if (warm_up_loop_count < 0) { 11879 printf("The warm up loop count error! 
Cannot be less than 0.\n"); 11880- return kMSStatusLiteParamInvalid; 11881+ return OH_AI_STATUS_LITE_PARAM_INVALID; 11882 } 11883 } 11884 printf("Running warm up loops..."); 11885 for (int i = 0; i < warm_up_loop_count; ++i) { 11886- ret = MSModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11887- if (ret != kMSStatusSuccess) { 11888- MSModelDestroy(&model_handle); 11889- printf("MSModelPredict failed, ret: %d", ret); 11890+ ret = OH_AI_ModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11891+ if (ret != OH_AI_STATUS_SUCCESS) { 11892+ OH_AI_ModelDestroy(&model_handle); 11893+ printf("OH_AI_ModelPredict failed, ret: %d", ret); 11894 return ret; 11895 } 11896 } 11897@@ -244,10 +244,10 @@ int main(int argc, const char **argv) { 11898 printf("\nloop count: %d\n", loop_count); 11899 uint64_t start_time = GetTimeUs(); 11900 for (int i = 0; i < loop_count; ++i) { 11901- ret = MSModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11902- if (ret != kMSStatusSuccess) { 11903- MSModelDestroy(&model_handle); 11904- printf("MSModelPredict failed, ret: %d", ret); 11905+ ret = OH_AI_ModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11906+ if (ret != OH_AI_STATUS_SUCCESS) { 11907+ OH_AI_ModelDestroy(&model_handle); 11908+ printf("OH_AI_ModelPredict failed, ret: %d", ret); 11909 return ret; 11910 } 11911 } 11912@@ -255,23 +255,23 @@ int main(int argc, const char **argv) { 11913 float total_time = (float)(end_time - start_time) / 1000.0f; 11914 printf("total time: %.5fms, per time: %.5fms\n", total_time, total_time / loop_count); 11915 } 11916- ret = MSModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11917- if (ret != kMSStatusSuccess) { 11918- MSModelDestroy(&model_handle); 11919+ ret = OH_AI_ModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11920+ if (ret != OH_AI_STATUS_SUCCESS) { 11921+ OH_AI_ModelDestroy(&model_handle); 11922 return ret; 11923 } 
11924 printf("========run success=======\n"); 11925 printf("\noutputs: \n"); 11926 for (size_t i = 0; i < outputs_handle.handle_num; i++) { 11927- MSTensorHandle output = outputs_handle.handle_list[i]; 11928+ OH_AI_TensorHandle output = outputs_handle.handle_list[i]; 11929 PrintTensorHandle(output); 11930 } 11931 if (argc >= 5) { 11932 CalibTensor *calib_tensors; 11933 int calib_num = 0; 11934 ret = ReadCalibData(argv[4], &calib_tensors, &calib_num); 11935- if (ret != kMSStatusSuccess) { 11936- MSModelDestroy(&model_handle); 11937+ if (ret != OH_AI_STATUS_SUCCESS) { 11938+ OH_AI_ModelDestroy(&model_handle); 11939 return ret; 11940 } 11941 float cosine_distance_threshold = 0.9999; 11942@@ -279,15 +279,15 @@ int main(int argc, const char **argv) { 11943 cosine_distance_threshold = atof(argv[8]); 11944 } 11945 ret = CompareOutputs(outputs_handle, &calib_tensors, calib_num, cosine_distance_threshold); 11946- if (ret != kMSStatusSuccess) { 11947- MSModelDestroy(&model_handle); 11948+ if (ret != OH_AI_STATUS_SUCCESS) { 11949+ OH_AI_ModelDestroy(&model_handle); 11950 return ret; 11951 } 11952 FreeCalibTensors(&calib_tensors, calib_num); 11953 } 11954 printf("========run success=======\n"); 11955- MSModelDestroy(&model_handle); 11956- return kMSStatusSuccess; 11957+ OH_AI_ModelDestroy(&model_handle); 11958+ return OH_AI_STATUS_SUCCESS; 11959 } 11960 )RAW"; 11961 11962@@ -385,7 +385,7 @@ int benchmark() { 11963 return kMSStatusLiteError; 11964 } 11965 MSModelSetWorkspace(model_handle, g_WorkSpace, WORK_SPACE_SIZE); 11966- ret = MSModelBuild(model_handle, NULL, 0, kMSModelTypeMindIR, NULL); 11967+ ret = OH_AI_ModelBuild(model_handle, NULL, 0, kMSModelTypeMindIR, NULL); 11968 if (ret != kMSStatusSuccess) { 11969 printf("MSModelBuildFromFile failed, ret : %d.\n", ret); 11970 MSModelDestroy(&model_handle); 11971@@ -424,7 +424,7 @@ int benchmark() { 11972 } 11973 11974 printf("========Infer start=======\n"); 11975- ret = MSModelPredict(model_handle, inputs_handle, 
&outputs_handle, NULL, NULL); 11976+ ret = OH_AI_ModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11977 if (ret != kMSStatusSuccess) { 11978 MSModelDestroy(&model_handle); 11979 return ret; 11980diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/calib_output.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/calib_output.cc 11981index 71ca2287..66af9069 100644 11982--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/calib_output.cc 11983+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/calib_output.cc 11984@@ -48,7 +48,7 @@ typedef struct CalibTensor { 11985 float *data_; 11986 } CalibTensor; 11987 int ReadCalibData(const char *calib_data_path, CalibTensor **calib_tensots, int *calib_num); 11988-int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int calib_num, 11989+int CompareOutputs(OH_AI_TensorHandleArray outputs, CalibTensor **calib_tensors, int calib_num, 11990 float cosine_distance_threshold); 11991 void FreeCalibTensors(CalibTensor **calib_tensors, int calib_num); 11992 11993@@ -89,12 +89,12 @@ int ReadCalibData(const char *calib_data_path, CalibTensor **calib_tensor_pointe 11994 FILE *file = fopen(calib_data_path, "r"); 11995 if (!file) { 11996 printf("Unable open %s", calib_data_path); 11997- return kMSStatusLiteError; 11998+ return OH_AI_STATUS_LITE_ERROR; 11999 } 12000 CalibTensor *calib_tensors = (CalibTensor *)malloc(kMaxOutput * sizeof(CalibTensor)); 12001 if(calib_tensors == NULL) { 12002 printf("Malloc calib tensors failed."); 12003- return kMSStatusLiteError; 12004+ return OH_AI_STATUS_LITE_ERROR; 12005 } 12006 // read line by line 12007 char line[kMaxTensorSize]; 12008@@ -111,7 +111,7 @@ int ReadCalibData(const char *calib_data_path, CalibTensor **calib_tensor_pointe 12009 char* tensor_name = (char *)malloc(strlen(p)+1); 12010 if(tensor_name == NULL) { 12011 
printf("Malloc tensor name failed."); 12012- return kMSStatusLiteError; 12013+ return OH_AI_STATUS_LITE_ERROR; 12014 } 12015 (void)strcpy(tensor_name, p); 12016 calib_tensors[*calib_num].tensor_name = tensor_name; 12017@@ -134,7 +134,7 @@ int ReadCalibData(const char *calib_data_path, CalibTensor **calib_tensor_pointe 12018 float *data = (float *)malloc(elements * sizeof(float)); 12019 if(data == NULL) { 12020 printf("Malloc tensor data failed."); 12021- return kMSStatusLiteError; 12022+ return OH_AI_STATUS_LITE_ERROR; 12023 } 12024 p = strtok(line, " "); 12025 int k = 0; 12026@@ -152,43 +152,43 @@ int ReadCalibData(const char *calib_data_path, CalibTensor **calib_tensor_pointe 12027 } 12028 *calib_tensor_pointers = calib_tensors; 12029 fclose(file); 12030- return kMSStatusSuccess; 12031+ return OH_AI_STATUS_SUCCESS; 12032 } 12033 12034-int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int calib_num, 12035+int CompareOutputs(OH_AI_TensorHandleArray outputs, CalibTensor **calib_tensors, int calib_num, 12036 float cosine_distance_threshold) { 12037 if (outputs.handle_num != (size_t)calib_num) { 12038 printf("error, outputs and calibs size is mismatch\n"); 12039- return kMSStatusLiteError; 12040+ return OH_AI_STATUS_LITE_ERROR; 12041 } 12042 size_t outputs_num = outputs.handle_num; 12043 bool is_success = true; 12044 for (size_t i = 0; i < outputs_num; ++i) { 12045 MicroTensor *output = (MicroTensor *)outputs.handle_list[i]; 12046 if (!output || !output->data) { 12047- return kMSStatusLiteError; 12048+ return OH_AI_STATUS_LITE_ERROR; 12049 } 12050 CalibTensor *calib = calib_tensors[0]; 12051 if (!calib || !calib[i].data_) { 12052- return kMSStatusLiteError; 12053+ return OH_AI_STATUS_LITE_ERROR; 12054 } 12055 if (strcmp(output->name, calib[i].tensor_name) != 0) { 12056 printf("warning, output tensor name is not equal to calib\n"); 12057 } 12058- size_t elements = (size_t)MSTensorGetElementNum(output); 12059+ size_t elements = 
(size_t)OH_AI_TensorGetElementNum(output); 12060 if (elements != (size_t)calib[i].elemets_num_) { 12061 printf("error, output elements num is not equal to calib\n"); 12062- return kMSStatusLiteError; 12063+ return OH_AI_STATUS_LITE_ERROR; 12064 } 12065 float cosin = 0.f, dot = 0.f, normx = 0.f, normy = 0.f; 12066 switch (output->type) { 12067- case kMSDataTypeNumberTypeFloat32: { 12068+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT32: { 12069 float *float_output = (float *)output->data; 12070 for (size_t j = 0; j < elements; ++j) { 12071 if (isnan(float_output[j]) || isinf(float_output[j]) || isnan(calib[i].data_[j]) || 12072 isinf(calib[i].data_[j])) { 12073 printf("error, output data is nan or inf\n"); 12074- return kMSStatusLiteError; 12075+ return OH_AI_STATUS_LITE_ERROR; 12076 } 12077 dot += float_output[j] * calib[i].data_[j]; 12078 normx += float_output[j] * float_output[j]; 12079@@ -196,7 +196,7 @@ int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int 12080 } 12081 break; 12082 } 12083- case kMSDataTypeNumberTypeInt8: { 12084+ case OH_AI_DATATYPE_NUMBERTYPE_INT8: { 12085 int8_t *int_output = (int8_t *)output->data; 12086 for (size_t j = 0; j < elements; ++j) { 12087 dot += (float) (int_output[j] * calib[i].data_[j]); 12088@@ -205,7 +205,7 @@ int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int 12089 } 12090 break; 12091 } 12092- case kMSDataTypeNumberTypeUInt8: { 12093+ case OH_AI_DATATYPE_NUMBERTYPE_UINT8: { 12094 uint8_t *int_output = (uint8_t *)output->data; 12095 for (size_t j = 0; j < elements; ++j) { 12096 dot += (float) (int_output[j] * calib[i].data_[j]); 12097@@ -214,8 +214,8 @@ int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int 12098 } 12099 break; 12100 } 12101- case kMSDataTypeNumberTypeInt32: 12102- case kMSDataTypeNumberTypeUInt32: { 12103+ case OH_AI_DATATYPE_NUMBERTYPE_INT32: 12104+ case OH_AI_DATATYPE_NUMBERTYPE_UINT32: { 12105 int32_t *int_output = (int32_t 
*)output->data; 12106 for (size_t j = 0; j < elements; ++j) { 12107 dot += (float) (int_output[j] * calib[i].data_[j]); 12108@@ -238,10 +238,10 @@ int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int 12109 } 12110 if (!is_success) { 12111 printf("compare outputs failed.\n"); 12112- return kMSStatusLiteError; 12113+ return OH_AI_STATUS_LITE_ERROR; 12114 } 12115 printf("compare outputs success.\n"); 12116- return kMSStatusSuccess; 12117+ return OH_AI_STATUS_SUCCESS; 12118 } 12119 12120 void FreeCalibTensors(CalibTensor **calib_tensors_pointers, int calib_num) { 12121@@ -328,7 +328,7 @@ const char *calib_source_cortex = R"RAW(/** 12122 int LoadCalibInputs(MSTensorHandleArray *inputs, TensorArray *tensor_array) { 12123 if (inputs->handle_num != tensor_array->tensors_size_) { 12124 printf("error, inputs and calibs size is mismatch.\n"); 12125- return kMSStatusLiteError; 12126+ return OH_AI_STATUS_LITE_ERROR; 12127 } 12128 Tensor *calib_tensors = tensor_array->tensors_; 12129 if (calib_tensors == NULL) { 12130diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/cmake_lists.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/cmake_lists.cc 12131index 79bfc485..f63e6f9e 100644 12132--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/cmake_lists.cc 12133+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/cmake_lists.cc 12134@@ -127,9 +127,9 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") 12135 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") 12136 else() 12137 message(STATUS "build benchmark release version") 12138- set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -Werror -fstack-protector-strong -Wno-attributes \ 12139+ set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -fstack-protector-strong -Wno-attributes \ 12140 -Wno-deprecated-declarations -Wno-missing-braces ${CMAKE_C_FLAGS}") 
12141- set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -Werror -fstack-protector-strong -Wno-attributes \ 12142+ set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -fstack-protector-strong -Wno-attributes \ 12143 -Wno-deprecated-declarations -Wno-missing-braces -Wno-overloaded-virtual ${CMAKE_CXX_FLAGS}") 12144 string(REPLACE "-g" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") 12145 string(REPLACE "-g" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 12146@@ -211,9 +211,9 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") 12147 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") 12148 else() 12149 message(STATUS "build net library release version") 12150- set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -Werror -fstack-protector-strong -Wno-attributes \ 12151+ set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -fstack-protector-strong -Wno-attributes \ 12152 -Wno-deprecated-declarations -Wno-missing-braces ${CMAKE_C_FLAGS}") 12153- set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -Werror -fstack-protector-strong -Wno-attributes \ 12154+ set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -fstack-protector-strong -Wno-attributes \ 12155 -Wno-deprecated-declarations -Wno-missing-braces -Wno-overloaded-virtual ${CMAKE_CXX_FLAGS}") 12156 string(REPLACE "-g" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") 12157 string(REPLACE "-g" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 12158@@ -241,11 +241,11 @@ function(create_library) 12159 endforeach() 12160 add_custom_command(TARGET net 12161 POST_BUILD 12162- COMMAND ar cr ${library_name} *.o 12163+ COMMAND ar cr ${library_name} *.obj 12164 COMMAND ranlib ${library_name} 12165 COMMAND echo "new static library ${library_name} size:" 12166 COMMAND ls -lh ${library_name} 12167- COMMAND rm -rf tmp && rm -rf *.o 12168+ COMMAND rm -rf tmp && rm -rf *.obj 12169 COMMENT "generate specified static library ${library_name}" 12170 ) 12171 endfunction(create_library) 12172diff --git 
a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/load_input.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/load_input.cc 12173index 9a2aeaa7..669cd8c1 100644 12174--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/load_input.cc 12175+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/load_input.cc 12176@@ -131,7 +131,7 @@ int ReadInputsFile(char *path, void **buffers, const int *inputs_size, int input 12177 while ((token = strtok_r(path, delim, &path))) { 12178 if (i >= inputs_num) { 12179 printf("inputs num is error, need: %d\n", inputs_num); 12180- return kMSStatusLiteParamInvalid; 12181+ return OH_AI_STATUS_LITE_PARAM_INVALID; 12182 } 12183 inputs_path[i] = token; 12184 printf("input %d: %s\n", i, inputs_path[i]); 12185@@ -144,7 +144,7 @@ int ReadInputsFile(char *path, void **buffers, const int *inputs_size, int input 12186 if (size != inputs_size[i] || buffers[i] == NULL) { 12187 printf("size mismatch, %s, input: %d, needed: %d\n", inputs_path[i], size, inputs_size[i]); 12188 free(buffers[i]); 12189- return kMSStatusLiteError; 12190+ return OH_AI_STATUS_LITE_ERROR; 12191 } 12192 } 12193 return 0; 12194diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mcontext.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mcontext.cc 12195index 856de855..d662e3a8 100644 12196--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mcontext.cc 12197+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mcontext.cc 12198@@ -73,24 +73,24 @@ const char context_source_cortex[] = R"RAW( 12199 #include <stdlib.h> 12200 #include <string.h> 12201 12202-MSContextHandle MSContextCreate() { 12203+OH_AI_ContextHandle OH_AI_ContextCreate() { 12204 return NULL; 12205 } 12206 12207-void MSContextDestroy(MSContextHandle *context) { 12208+void 
OH_AI_ContextDestroy(OH_AI_ContextHandle *context) { 12209 } 12210 12211-void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) { 12212+void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) { 12213 } 12214 12215-int32_t MSContextGetThreadNum(const MSContextHandle context) { 12216+int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 12217 return 1; 12218 } 12219 12220-void MSContextSetThreadAffinityMode(MSContextHandle context, int mode) { 12221+void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 12222 } 12223 12224-int MSContextGetThreadAffinityMode(const MSContextHandle context) { 12225+int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) { 12226 return 0; 12227 } 12228 )RAW"; 12229@@ -116,7 +116,7 @@ const char context_source_no_parallel[] = R"RAW( 12230 #include <stdlib.h> 12231 #include <string.h> 12232 12233-MSContextHandle MSContextCreate() { 12234+OH_AI_ContextHandle OH_AI_ContextCreate() { 12235 MicroContext *micro_context = (MicroContext *)malloc(sizeof(MicroContext)); 12236 if (micro_context == NULL) { 12237 return NULL; 12238@@ -129,7 +129,7 @@ MSContextHandle MSContextCreate() { 12239 return micro_context; 12240 } 12241 12242-void MSContextDestroy(MSContextHandle *context) { 12243+void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) { 12244 MicroContext *micro_context = (MicroContext *)(*context); 12245 if (micro_context) { 12246 free(micro_context); 12247@@ -137,17 +137,17 @@ void MSContextDestroy(MSContextHandle *context) { 12248 } 12249 } 12250 12251-void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) { 12252+void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) { 12253 } 12254 12255-int32_t MSContextGetThreadNum(const MSContextHandle context) { 12256+int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 12257 return 1; 12258 } 12259 12260-void 
MSContextSetThreadAffinityMode(MSContextHandle context, int mode) { 12261+void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 12262 } 12263 12264-int MSContextGetThreadAffinityMode(const MSContextHandle context) { 12265+int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) { 12266 return 0; 12267 } 12268 )RAW"; 12269@@ -176,7 +176,7 @@ const char context_source[] = R"RAW( 12270 12271 #define MAX_THREAD_NUM 4 12272 12273-MSContextHandle MSContextCreate() { 12274+OH_AI_ContextHandle OH_AI_ContextCreate() { 12275 MicroContext *micro_context = (MicroContext *)malloc(sizeof(MicroContext)); 12276 if (micro_context == NULL) { 12277 return NULL; 12278@@ -189,7 +189,7 @@ MSContextHandle MSContextCreate() { 12279 return micro_context; 12280 } 12281 12282-void MSContextDestroy(MSContextHandle *context) { 12283+void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) { 12284 MicroContext *micro_context = (MicroContext *)(*context); 12285 if (micro_context) { 12286 if (micro_context->affinity_core_list_) { 12287@@ -201,7 +201,7 @@ void MSContextDestroy(MSContextHandle *context) { 12288 } 12289 } 12290 12291-void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) { 12292+void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) { 12293 MicroContext *micro_context = (MicroContext *)context; 12294 if (micro_context) { 12295 int core_num = GetCpuCoreNum(); 12296@@ -214,7 +214,7 @@ void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) { 12297 } 12298 } 12299 12300-int32_t MSContextGetThreadNum(const MSContextHandle context) { 12301+int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 12302 MicroContext *micro_context = (MicroContext *)context; 12303 if (micro_context) { 12304 return micro_context->thread_num_; 12305@@ -222,7 +222,7 @@ int32_t MSContextGetThreadNum(const MSContextHandle context) { 12306 return 0; 12307 } 12308 12309-void 
MSContextSetThreadAffinityMode(MSContextHandle context, int mode) { 12310+void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 12311 MicroContext *micro_context = (MicroContext *)context; 12312 if (micro_context) { 12313 if (mode >= 0 && mode <= 2) { 12314@@ -233,7 +233,7 @@ void MSContextSetThreadAffinityMode(MSContextHandle context, int mode) { 12315 } 12316 } 12317 12318-int MSContextGetThreadAffinityMode(const MSContextHandle context) { 12319+int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) { 12320 MicroContext *micro_context = (MicroContext *)context; 12321 if (micro_context) { 12322 return micro_context->affinity_mode; 12323diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/msession.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/msession.cc 12324index 44273071..5cbe4507 100644 12325--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/msession.cc 12326+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/msession.cc 12327@@ -18,25 +18,25 @@ 12328 12329 namespace mindspore::lite::micro { 12330 const char model_runtime_other_source[] = R"RAW( 12331-MSTensorHandleArray MSModelGetInputs(const MSModelHandle model) { 12332+OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) { 12333 MicroModel *micro_model = (MicroModel *)model; 12334 if (micro_model == NULL) { 12335- MSTensorHandleArray tmp = {0, NULL}; 12336+ OH_AI_TensorHandleArray tmp = {0, NULL}; 12337 return tmp; 12338 } 12339 return micro_model->inputs; 12340 } 12341 12342-MSTensorHandleArray MSModelGetOutputs(const MSModelHandle model) { 12343+OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) { 12344 MicroModel *micro_model = (MicroModel *)model; 12345 if (micro_model == NULL) { 12346- MSTensorHandleArray tmp = {0, NULL}; 12347+ OH_AI_TensorHandleArray tmp = {0, NULL}; 12348 
return tmp; 12349 } 12350 return micro_model->outputs; 12351 } 12352 12353-MSTensorHandle MSModelGetInputByTensorName(const MSModelHandle model, const char *tensor_name) { 12354+OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) { 12355 MicroModel *micro_model = (MicroModel *)model; 12356 if (micro_model == NULL || micro_model->inputs.handle_list == NULL) { 12357 return NULL; 12358@@ -53,7 +53,7 @@ MSTensorHandle MSModelGetInputByTensorName(const MSModelHandle model, const char 12359 return NULL; 12360 } 12361 12362-MSTensorHandle MSModelGetOutputByTensorName(const MSModelHandle model, const char *tensor_name) { 12363+OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) { 12364 MicroModel *micro_model = (MicroModel *)model; 12365 if (micro_model == NULL || micro_model->outputs.handle_list == NULL) { 12366 return NULL; 12367@@ -70,9 +70,16 @@ MSTensorHandle MSModelGetOutputByTensorName(const MSModelHandle model, const cha 12368 return NULL; 12369 } 12370 12371-MSStatus MSModelResize(MSModelHandle model, const MSTensorHandleArray inputs, MSShapeInfo *shape_infos, 12372+OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_ShapeInfo *shape_infos, 12373 size_t shape_info_num) { 12374- return kMSStatusLiteNotSupport; 12375+ MicroModel *micro_model = (MicroModel *)model; 12376+ if (micro_model == NULL) { 12377+ return OH_AI_STATUS_LITE_NULLPTR; 12378+ } 12379+ if (micro_model->resize == NULL) { 12380+ return OH_AI_STATUS_LITE_NULLPTR; 12381+ } 12382+ return micro_model->resize(model, inputs, shape_infos, shape_info_num); 12383 } 12384 12385 )RAW"; 12386diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mtensor.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mtensor.cc 12387index b125b31d..e4581829 100644 12388--- 
a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mtensor.cc 12389+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mtensor.cc 12390@@ -46,8 +46,8 @@ const char tensor_header[] = R"RAW( 12391 #endif 12392 12393 typedef struct { 12394- enum MSDataType type; 12395- enum MSFormat format; 12396+ enum OH_AI_DataType type; 12397+ enum OH_AI_Format format; 12398 char *name; 12399 int ndim; 12400 int64_t *shape; 12401@@ -76,7 +76,7 @@ enum TypeTransMode { 12402 TypeTransMode_MAX = TypeTransMode_UNSUPPORT 12403 }; 12404 12405-void *TransformInput(MSTensorHandle tensor, int expect_type, bool *type_changed); 12406+void *TransformInput(OH_AI_TensorHandle tensor, int expect_type, bool *type_changed); 12407 12408 #ifdef ENABLE_FP16 12409 void Fp32CastToFp16(const float *input, float16_t *output, int number); 12410@@ -109,37 +109,37 @@ const char tensor_source[] = R"RAW( 12411 #include "string.h" 12412 #include "tensor.h" 12413 12414-size_t DataTypeSize(const MSDataType type) { 12415+size_t DataTypeSize(const OH_AI_DataType type) { 12416 switch (type) { 12417- case kMSDataTypeNumberTypeFloat64: 12418+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT64: 12419 return sizeof(double); 12420- case kMSDataTypeNumberTypeFloat32: 12421+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT32: 12422 return sizeof(float); 12423- case kMSDataTypeNumberTypeInt8: 12424+ case OH_AI_DATATYPE_NUMBERTYPE_INT8: 12425 return sizeof(int8_t); 12426- case kMSDataTypeNumberTypeUInt8: 12427+ case OH_AI_DATATYPE_NUMBERTYPE_UINT8: 12428 return sizeof(uint8_t); 12429- case kMSDataTypeNumberTypeFloat16: 12430- case kMSDataTypeNumberTypeInt16: 12431+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT16: 12432+ case OH_AI_DATATYPE_NUMBERTYPE_INT16: 12433 return sizeof(int16_t); 12434- case kMSDataTypeNumberTypeInt32: 12435+ case OH_AI_DATATYPE_NUMBERTYPE_INT32: 12436 return sizeof(int32_t); 12437- case kMSDataTypeNumberTypeInt64: 12438+ case OH_AI_DATATYPE_NUMBERTYPE_INT64: 12439 return 
sizeof(int64_t); 12440- case kMSDataTypeNumberTypeUInt16: 12441+ case OH_AI_DATATYPE_NUMBERTYPE_UINT16: 12442 return sizeof(uint16_t); 12443- case kMSDataTypeNumberTypeUInt32: 12444+ case OH_AI_DATATYPE_NUMBERTYPE_UINT32: 12445 return sizeof(uint32_t); 12446- case kMSDataTypeNumberTypeUInt64: 12447+ case OH_AI_DATATYPE_NUMBERTYPE_UINT64: 12448 return sizeof(uint64_t); 12449- case kMSDataTypeObjectTypeString: 12450+ case OH_AI_DATATYPE_OBJECTTYPE_STRING: 12451 return sizeof(char); 12452 default: 12453 return 0; 12454 } 12455 } 12456 12457-MSTensorHandle MSTensorCreate(const char *name, MSDataType type, const int64_t *shape, size_t shape_num, 12458+OH_AI_TensorHandle OH_AI_TensorCreate(const char *name, OH_AI_DataType type, const int64_t *shape, size_t shape_num, 12459 const void *data, size_t data_len) { 12460 size_t data_type_len = DataTypeSize(type); 12461 size_t acc_sum = 1; 12462@@ -160,16 +160,16 @@ MSTensorHandle MSTensorCreate(const char *name, MSDataType type, const int64_t * 12463 memcpy(micro_tensor->data, data, data_len); 12464 micro_tensor->shape = malloc(shape_num * sizeof(int64_t)); 12465 memcpy(micro_tensor->shape, shape, shape_num * sizeof(int64_t)); 12466- micro_tensor->format = kMSFormatNHWC; 12467+ micro_tensor->format = OH_AI_FORMAT_NHWC; 12468 return micro_tensor; 12469 } 12470 12471-void MSTensorDestroy(MSTensorHandle *tensor) { 12472+void OH_AI_TensorDestroy(OH_AI_TensorHandle *tensor) { 12473 MicroTensor* micro_tensor = (MicroTensor*)(*tensor); 12474 free(micro_tensor); 12475 } 12476 12477-void MSTensorSetName(MSTensorHandle tensor, const char *name) { 12478+void OH_AI_TensorSetName(OH_AI_TensorHandle tensor, const char *name) { 12479 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12480 if(micro_tensor->name != NULL) { 12481 free(micro_tensor->name); 12482@@ -179,10 +179,10 @@ void MSTensorSetName(MSTensorHandle tensor, const char *name) { 12483 memcpy(micro_tensor->name, name, len + 1); 12484 } 12485 12486-MSTensorHandle 
MSTensorClone(MSTensorHandle tensor) { 12487+OH_AI_TensorHandle OH_AI_TensorClone(OH_AI_TensorHandle tensor) { 12488 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12489 MicroTensor *clone_tensor = malloc( sizeof(MicroTensor)); 12490- size_t tensor_data_size = MSTensorGetDataSize(micro_tensor); 12491+ size_t tensor_data_size = OH_AI_TensorGetDataSize(micro_tensor); 12492 clone_tensor->data = malloc(tensor_data_size); 12493 clone_tensor->owned = true; 12494 memcpy(clone_tensor->data,micro_tensor->data,tensor_data_size); 12495@@ -195,26 +195,26 @@ MSTensorHandle MSTensorClone(MSTensorHandle tensor) { 12496 clone_tensor->shape = clone_shape; 12497 char* clone_name = malloc(strlen(micro_tensor->name)); 12498 strcpy(clone_name,micro_tensor->name); 12499- clone_tensor->format = kMSFormatNHWC; 12500+ clone_tensor->format = OH_AI_FORMAT_NHWC; 12501 return clone_tensor; 12502 } 12503 12504-const char *MSTensorGetName(const MSTensorHandle tensor) { 12505+const char *OH_AI_TensorGetName(const OH_AI_TensorHandle tensor) { 12506 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12507 return micro_tensor->name; 12508 } 12509 12510-void MSTensorSetDataType(MSTensorHandle tensor, MSDataType type) { 12511+void OH_AI_TensorSetDataType(OH_AI_TensorHandle tensor, OH_AI_DataType type) { 12512 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12513 micro_tensor->type = type; 12514 } 12515 12516-MSDataType MSTensorGetDataType(const MSTensorHandle tensor) { 12517+OH_AI_DataType OH_AI_TensorGetDataType(const OH_AI_TensorHandle tensor) { 12518 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12519 return micro_tensor->type; 12520 } 12521 12522-void MSTensorSetShape(MSTensorHandle tensor, const int64_t *shape, size_t shape_num) { 12523+void OH_AI_TensorSetShape(OH_AI_TensorHandle tensor, const int64_t *shape, size_t shape_num) { 12524 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12525 if(micro_tensor->shape != NULL) { 12526 free(micro_tensor->shape); 12527@@ -224,23 
+224,23 @@ void MSTensorSetShape(MSTensorHandle tensor, const int64_t *shape, size_t shape_ 12528 memcpy(micro_tensor->shape, shape, shape_num * sizeof(int64_t)); 12529 } 12530 12531-const int64_t *MSTensorGetShape(const MSTensorHandle tensor, size_t *shape_num) { 12532+const int64_t *OH_AI_TensorGetShape(const OH_AI_TensorHandle tensor, size_t *shape_num) { 12533 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12534 *shape_num = micro_tensor->ndim; 12535 return micro_tensor->shape; 12536 } 12537 12538-void MSTensorSetFormat(MSTensorHandle tensor, MSFormat format) { 12539+void OH_AI_TensorSetFormat(OH_AI_TensorHandle tensor, OH_AI_Format format) { 12540 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12541 micro_tensor->format = format; 12542 } 12543 12544-MSFormat MSTensorGetFormat(const MSTensorHandle tensor) { 12545+OH_AI_Format OH_AI_TensorGetFormat(const OH_AI_TensorHandle tensor) { 12546 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12547 return micro_tensor->format; 12548 } 12549 12550-void MSTensorSetData(MSTensorHandle tensor, void *data) { 12551+void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data) { 12552 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12553 if (micro_tensor->data == data) { 12554 return; 12555@@ -254,23 +254,23 @@ void MSTensorSetData(MSTensorHandle tensor, void *data) { 12556 micro_tensor->data = data; 12557 } 12558 12559-const void *MSTensorGetData(const MSTensorHandle tensor) { 12560+const void *OH_AI_TensorGetData(const OH_AI_TensorHandle tensor) { 12561 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12562 return micro_tensor->data; 12563 } 12564 12565-void *MSTensorGetMutableData(const MSTensorHandle tensor) { 12566+void *OH_AI_TensorGetMutableData(const OH_AI_TensorHandle tensor) { 12567 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12568 if(micro_tensor->data) { 12569 return micro_tensor->data; 12570 } 12571- void* data = malloc(MSTensorGetDataSize(tensor)); 12572+ void* data = 
malloc(OH_AI_TensorGetDataSize(tensor)); 12573 micro_tensor->owned = true; 12574 micro_tensor->data = data; 12575 return data; 12576 } 12577 12578-int64_t MSTensorGetElementNum(const MSTensorHandle tensor) { 12579+int64_t OH_AI_TensorGetElementNum(const OH_AI_TensorHandle tensor) { 12580 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12581 int64_t acc_sum = 1; 12582 for(int i=0;i< micro_tensor->ndim;i++) { 12583@@ -279,10 +279,10 @@ int64_t MSTensorGetElementNum(const MSTensorHandle tensor) { 12584 return acc_sum; 12585 } 12586 12587-size_t MSTensorGetDataSize(const MSTensorHandle tensor) { 12588+size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle tensor) { 12589 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12590 size_t data_type_size = DataTypeSize(micro_tensor->type); 12591- int64_t elements = MSTensorGetElementNum(tensor); 12592+ int64_t elements = OH_AI_TensorGetElementNum(tensor); 12593 return data_type_size * elements; 12594 } 12595 12596@@ -300,16 +300,16 @@ void Fp16CastToFp32(const float16_t *input, float *output, int number) { 12597 } 12598 #endif 12599 12600-void *TransformInput(MSTensorHandle tensor, int expect_type, bool *type_changed) { 12601+void *TransformInput(OH_AI_TensorHandle tensor, int expect_type, bool *type_changed) { 12602 MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12603 int cur_type = micro_tensor->type; 12604 if (cur_type == expect_type) { 12605 return micro_tensor->data; 12606 } 12607 int type_trans_mode = TypeTransMode_MAX; 12608- if (expect_type == kMSDataTypeNumberTypeFloat16 && cur_type == kMSDataTypeNumberTypeFloat32) { 12609+ if (expect_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT16 && cur_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT32) { 12610 type_trans_mode = TypeTransMode_FP32_TO_FP16; 12611- } else if (expect_type == kMSDataTypeNumberTypeFloat32 && cur_type == kMSDataTypeNumberTypeFloat16) { 12612+ } else if (expect_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT32 && cur_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT16) { 
12613 type_trans_mode = TypeTransMode_FP16_TO_FP32; 12614 } 12615 if (type_trans_mode == TypeTransMode_UNSUPPORT) { 12616diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc 12617index ac958750..6a131b52 100644 12618--- a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc 12619+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc 12620@@ -61,6 +61,8 @@ void CodeWeightFileHeader(std::ofstream &ofs, const std::unique_ptr<CoderContext 12621 << "#include <string.h>\n" 12622 << "extern unsigned char *" << ctx->buffer_name() << ";\n" 12623 << "extern uint8_t *" << ctx->weight_name() << ";\n" 12624+ << "extern int *" << kShapePrefixName << ";\n" 12625+ << "extern int *" << kOffsetPrefixName << ";\n" 12626 << "enum STATUS {\n" 12627 " RET_OK = 0,\n" 12628 " RET_ERROR = 1,\n" 12629diff --git a/mindspore/lite/tools/converter/micro/coder/generator/generator.cc b/mindspore/lite/tools/converter/micro/coder/generator/generator.cc 12630index dd66c333..23009e17 100644 12631--- a/mindspore/lite/tools/converter/micro/coder/generator/generator.cc 12632+++ b/mindspore/lite/tools/converter/micro/coder/generator/generator.cc 12633@@ -43,20 +43,28 @@ const char micro_model_define_source[] = R"RAW( 12634 typedef struct { 12635 void *runtime_buffer; 12636 bool train_mode; // true: train mode, false: eval mode 12637- MSTensorHandleArray inputs; 12638- MSTensorHandleArray outputs; 12639+ OH_AI_TensorHandleArray inputs; 12640+ OH_AI_TensorHandleArray outputs; 12641 ModelBuild build; 12642+ ModelResize resize; 12643 ModelSetWorkspace set_work_space; 12644 ModelCalcWorkspaceSize calc_work_space; 12645 FreeResource free_resource; 12646 )RAW"; 12647 12648 const char set_workspace_state[] = R"RAW( 12649-typedef void (*ModelSetWorkspace)(MSModelHandle model, void *workspace, size_t workspace_size); 
12650+typedef void (*ModelSetWorkspace)(OH_AI_ModelHandle model, void *workspace, size_t workspace_size); 12651 )RAW"; 12652 12653 const char calc_workspace_state[] = R"RAW( 12654-typedef size_t (*ModelCalcWorkspaceSize)(MSModelHandle model); 12655+typedef size_t (*ModelCalcWorkspaceSize)(OH_AI_ModelHandle model); 12656+)RAW"; 12657+ 12658+const char model_resize[] = R"RAW( 12659+typedef OH_AI_Status (*ModelResize)(OH_AI_ModelHandle model, 12660+ const OH_AI_TensorHandleArray inputs, 12661+ OH_AI_ShapeInfo *shape_infos, 12662+ size_t shape_info_num); 12663 )RAW"; 12664 12665 int WriteContentToFile(const std::string &file, const std::string &content) { 12666@@ -311,6 +319,7 @@ int Generator::CodeCommonModelFile() { 12667 CodeFreeResourceState(hofs); 12668 hofs << set_workspace_state; 12669 hofs << calc_workspace_state; 12670+ hofs << model_resize; 12671 hofs << micro_model_define_source; 12672 if (config_->code_mode() == CodeMode::Inference) { 12673 hofs << " ModelPredict predict;\n"; 12674@@ -321,7 +330,7 @@ int Generator::CodeCommonModelFile() { 12675 } 12676 hofs << "} MicroModel;\n"; 12677 12678- hofs << "void MSTensorHandleArrayDestroy(MSTensorHandleArray inputs);\n"; 12679+ hofs << "void MSTensorHandleArrayDestroy(OH_AI_TensorHandleArray inputs);\n"; 12680 hofs << "#endif // MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_MODEL_H_\n\n"; 12681 12682 // model source file 12683@@ -340,7 +349,7 @@ int Generator::CodeCommonModelFile() { 12684 if (config_->support_parallel()) { 12685 cofs << "#include \"" << kThreadWrapper << "\"\n"; 12686 } 12687- if (config_->target() != kCortex_M) { 12688+ if (config_->target() != kCortex_M && !config_->dynamic_shape()) { 12689 cofs << "#include \"src/allocator.h\"\n"; 12690 } 12691 CodeMSModelCalcWorkspaceSize(cofs, ctx_, *config_); 12692@@ -369,7 +378,7 @@ int Generator::CodeModelHandleHFile() { 12693 "#define MINDSPORE_LITE_MICRO_LIBRARY_INCLUDE_MODEL_HANDLE_H_\n\n" 12694 << "#include \"c_api/model_c.h\"\n\n"; 12695 for (int i = 0; i <= 
ctx_->GetCurModelIndex(); ++i) { 12696- ofs << "extern MSModelHandle model" << std::to_string(i) << "; // " << ctx_->model_name() << "\n"; 12697+ ofs << "extern OH_AI_ModelHandle model" << std::to_string(i) << "; // " << ctx_->model_name() << "\n"; 12698 } 12699 ofs << "\n#endif // MINDSPORE_LITE_MICRO_LIBRARY_INCLUDE_MODEL_HANDLE_H_\n"; 12700 return RET_OK; 12701@@ -386,7 +395,7 @@ int Generator::CodeMSModelImplement() { 12702 ofs << "#include \"c_api/model_c.h\"\n"; 12703 ofs << "#include \"src/model.h\"\n"; 12704 ofs << "#include \"src/model" << ctx_->GetCurModelIndex() << "/" << net_inc_hfile_ << "\"\n"; 12705- if (config_->target() != kCortex_M) { 12706+ if (config_->target() != kCortex_M && !config_->dynamic_shape()) { 12707 ofs << "#include \"src/allocator.h\"\n"; 12708 } 12709 if (config_->support_parallel()) { 12710@@ -399,33 +408,37 @@ int Generator::CodeMSModelImplement() { 12711 ofs << "#define GRAPH_OUTPUTS_SIZE " << ctx_->graph_outputs().size() << "\n"; 12712 ofs << "#define WEIGHT_BUF_SIZE " << ctx_->weight_buffer_size() << "\n"; 12713 } 12714- ofs << "MSStatus MSModelBuild" << ctx_->GetCurModelIndex() << "(MSModelHandle model, const void *model_data,\n" 12715- << " size_t data_size, const MSContextHandle model_context);\n"; 12716+ ofs << "OH_AI_Status OH_AI_ModelBuild" << ctx_->GetCurModelIndex() << "(OH_AI_ModelHandle model, const void *model_data,\n" 12717+ << " size_t data_size, const OH_AI_ContextHandle model_context);\n"; 12718+ ofs << "OH_AI_Status OH_AI_ModelResize" << ctx_->GetCurModelIndex() << "(OH_AI_ModelHandle model, \n" 12719+ << " const OH_AI_TensorHandleArray inputs, OH_AI_ShapeInfo *shape_infos, size_t shape_info_num);\n"; 12720 if (config_->code_mode() == CodeMode::Inference) { 12721- ofs << "MSStatus MSModelPredict" << ctx_->GetCurModelIndex() 12722- << "(MSModelHandle model, const MSTensorHandleArray inputs,\n" 12723- << " MSTensorHandleArray *output,\n" 12724- << " const MSKernelCallBackC before,\n" 12725- << " const 
MSKernelCallBackC after);\n"; 12726+ ofs << "OH_AI_Status OH_AI_ModelPredict" << ctx_->GetCurModelIndex() 12727+ << "(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,\n" 12728+ << " OH_AI_TensorHandleArray *output,\n" 12729+ << " const OH_AI_KernelCallBack before,\n" 12730+ << " const OH_AI_KernelCallBack after);\n"; 12731 } else { 12732- ofs << "MSStatus MSModelRunStep" << ctx_->GetCurModelIndex() 12733- << "(MSModelHandle model,\n" 12734- " const MSKernelCallBackC before,\n" 12735- " const MSKernelCallBackC after);\n"; 12736- ofs << "MSStatus MSModelSetTrainMode" << ctx_->GetCurModelIndex() << "(MSModelHandle model, bool train);\n"; 12737- ofs << "MSStatus MSModelExportWeight" << ctx_->GetCurModelIndex() 12738- << "(MSModelHandle model, const char *export_path);\n"; 12739- } 12740+ ofs << "OH_AI_Status MSModelRunStep" << ctx_->GetCurModelIndex() 12741+ << "(OH_AI_ModelHandle model,\n" 12742+ " const OH_AI_KernelCallBack before,\n" 12743+ " const OH_AI_KernelCallBack after);\n"; 12744+ ofs << "OH_AI_Status MSModelSetTrainMode" << ctx_->GetCurModelIndex() << "(OH_AI_ModelHandle model, bool train);\n"; 12745+ ofs << "OH_AI_Status MSModelExportWeight" << ctx_->GetCurModelIndex() 12746+ << "(OH_AI_ModelHandle model, const char *export_path);\n"; 12747+ } 12748+ ofs << "void Reset" << ctx_->GetCurModelIndex() << "();\n"; 12749 ofs << "void MSModelSetWorkspace" << ctx_->GetCurModelIndex() 12750- << "(MSModelHandle model, void *workspace, size_t workspace_size);\n"; 12751- ofs << "size_t MSModelCalcWorkspaceSize" << ctx_->GetCurModelIndex() << "(MSModelHandle model);\n"; 12752+ << "(OH_AI_ModelHandle model, void *workspace, size_t workspace_size);\n"; 12753+ ofs << "size_t MSModelCalcWorkspaceSize" << ctx_->GetCurModelIndex() << "(OH_AI_ModelHandle model);\n"; 12754 ofs << "static MicroModel gModel" << ctx_->GetCurModelIndex() << " = {.runtime_buffer = NULL,\n" 12755 << " .train_mode = false,\n" 12756 << " .inputs = {" << ctx_->graph_inputs().size() << ", 
NULL},\n" 12757 << " .outputs = {" << ctx_->graph_outputs().size() << ", NULL},\n" 12758- << " .build = MSModelBuild" << ctx_->GetCurModelIndex() << ",\n"; 12759+ << " .build = OH_AI_ModelBuild" << ctx_->GetCurModelIndex() << ",\n" 12760+ << " .resize = OH_AI_ModelResize" << ctx_->GetCurModelIndex() << ",\n"; 12761 if (config_->code_mode() == CodeMode::Inference) { 12762- ofs << " .predict = MSModelPredict" << ctx_->GetCurModelIndex() << ",\n"; 12763+ ofs << " .predict = OH_AI_ModelPredict" << ctx_->GetCurModelIndex() << ",\n"; 12764 } else { 12765 ofs << " .run_step = MSModelRunStep" << ctx_->GetCurModelIndex() << ",\n" 12766 << " .set_train_mode = MSModelSetTrainMode" << ctx_->GetCurModelIndex() << ",\n" 12767@@ -439,11 +452,16 @@ int Generator::CodeMSModelImplement() { 12768 ofs << " .set_work_space = NULL,\n" 12769 << " .calc_work_space = NULL,\n"; 12770 } 12771- ofs << " .free_resource = FreeResource" << ctx_->GetCurModelIndex() << "};\n"; 12772- ofs << "MSModelHandle model" << ctx_->GetCurModelIndex() << " = &gModel" << ctx_->GetCurModelIndex() << ";\n\n"; 12773- 12774+ ofs << " .free_resource = Reset" << ctx_->GetCurModelIndex() << "};\n"; 12775+ ofs << "OH_AI_ModelHandle model" << ctx_->GetCurModelIndex() << " = &gModel" << ctx_->GetCurModelIndex() << ";\n\n"; 12776+ auto &dynamic_symbols = config_->dynamic_symbols(); 12777+ for (size_t i = 0; i < dynamic_symbols.size(); ++i) { 12778+ ofs << "static int store" << ctx_->GetCurModelIndex() << "_" << i << " = -1;\n"; 12779+ } 12780+ CodeResetImplement(ofs, ctx_, *config_); 12781 CodeMSModelCreate(ofs, ctx_, *config_); 12782 CodeMSModelBuild(ofs, ctx_->GetCurModelIndex(), weight_size_, *config_); 12783+ CodeMSModelResize(ofs, ctx_, *config_); 12784 CodeCopyOutputsImplement(ofs, ctx_); 12785 if (config_->target() == kCortex_M) { 12786 CodeCortexCalcWorkspaceSize(ofs, ctx_); 12787@@ -483,6 +501,8 @@ int Generator::CodeWeightFile() { 12788 if (config_->target() != kCortex_M) { 12789 cofs << "unsigned char *" << 
ctx_->buffer_name() << " = 0; \n"; 12790 cofs << "unsigned char *" << ctx_->weight_name() << " = 0; \n"; 12791+ cofs << "int *" << kShapePrefixName << " = 0; \n"; 12792+ cofs << "int *" << kOffsetPrefixName << " = 0; \n"; 12793 std::string net_file = model_dir_ + "net" + std::to_string(ctx_->GetCurModelIndex()) + ".bin"; 12794 SaveDataToNet(ctx_, net_file, config_->keep_original_weight(), &weight_size_); 12795 } else { 12796@@ -598,8 +618,10 @@ int Generator::CreateCommonFiles() { 12797 MS_CHECK_RET_CODE(CodeStaticContent(), "code static content failed."); 12798 MS_CHECK_RET_CODE(CodeModelHandleHFile(), "code model_handle h file failed."); 12799 MS_CHECK_RET_CODE(CodeCommonModelFile(), "code common model file failed."); 12800+ if (!config_->dynamic_shape()) { 12801+ MS_CHECK_RET_CODE(CodeAllocatorFile(), "code allocator file failed."); 12802+ } 12803 MS_CHECK_RET_CODE(CodeRegKernelHFile(), "code registered kernel header file failed."); 12804- MS_CHECK_RET_CODE(CodeAllocatorFile(), "code allocator file failed."); 12805 MS_CHECK_RET_CODE(CodeSourceCMakeFile(), "code net cmake file failed."); 12806 return RET_OK; 12807 } 12808diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.cc 12809new file mode 100644 12810index 00000000..108ba227 12811--- /dev/null 12812+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.cc 12813@@ -0,0 +1,116 @@ 12814+/** 12815+ * Copyright 2023 Huawei Technologies Co., Ltd 12816+ * 12817+ * Licensed under the Apache License, Version 2.0 (the "License"); 12818+ * you may not use this file except in compliance with the License. 
12819+ * You may obtain a copy of the License at 12820+ * 12821+ * http://www.apache.org/licenses/LICENSE-2.0 12822+ * 12823+ * Unless required by applicable law or agreed to in writing, software 12824+ * distributed under the License is distributed on an "AS IS" BASIS, 12825+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12826+ * See the License for the specific language governing permissions and 12827+ * limitations under the License. 12828+ */ 12829+ 12830+#include "coder/opcoders/base/reshape_dynamic_base_coder.h" 12831+#include <string> 12832+#include "coder/opcoders/serializers/serializer.h" 12833+#include "include/errorcode.h" 12834+#include "tools/common/string_util.h" 12835+#include "coder/utils/coder_utils.h" 12836+ 12837+using mindspore::schema::PrimitiveType_ExpandDims; 12838+using mindspore::schema::PrimitiveType_Flatten; 12839+using mindspore::schema::PrimitiveType_FlattenGrad; 12840+using mindspore::schema::PrimitiveType_Reshape; 12841+using mindspore::schema::PrimitiveType_Squeeze; 12842+using mindspore::schema::PrimitiveType_Unsqueeze; 12843+ 12844+namespace mindspore::lite::micro { 12845+int ReshapeDynamicBaseCoder::Prepare(CoderContext *const context) { 12846+ if (input_tensors_.size() == C2NUM) { 12847+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->IsConst(), RET_NOT_SUPPORT, 12848+ "Currently, only support the first input of reshape is non-const when shape is dynamical."); 12849+ 12850+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt32 || 12851+ input_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt, 12852+ RET_ERROR, "The data-type of Reshape's second input must be int."); 12853+ } 12854+ return RET_OK; 12855+} 12856+ 12857+int ReshapeDynamicBaseCoder::DoCode(CoderContext *const context) { 12858+ Serializer coder; 12859+ 12860+ int data_item_size = static_cast<int>(lite::DataTypeSize(input_tensor_->data_type())); 12861+ auto in_shape = 
shape_info_container_->GetTemplateShape(input_tensor_); 12862+ int64_t const_part = 1; 12863+ std::string non_const_part; 12864+ for (const auto &item : in_shape) { 12865+ if (IsNumber(item)) { 12866+ const_part *= std::stoi(item); 12867+ } else { 12868+ if (!non_const_part.empty()) { 12869+ non_const_part += " * "; 12870+ } 12871+ non_const_part += item; 12872+ } 12873+ } 12874+ std::string size = non_const_part.empty() ? std::to_string(const_part * data_item_size) : std::to_string(const_part * data_item_size) + " * " + non_const_part; 12875+ std::string input_data = dynamic_mem_manager_->GetVarTensorAddr(input_tensor_); 12876+ MS_CHECK_TRUE_MSG(!input_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 12877+ std::string output_data = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_); 12878+ MS_CHECK_TRUE_MSG(!output_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 12879+ coder.CodeFunction("memcpy", output_data, input_data, size); 12880+ 12881+ context->AppendCode(coder.str()); 12882+ return RET_OK; 12883+} 12884+ 12885+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Reshape, 12886+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12887+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Reshape, 12888+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12889+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Reshape, 12890+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12891+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Reshape, 12892+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12893+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Flatten, 12894+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12895+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Flatten, 12896+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12897+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_Flatten, 12898+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>)
12899+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_ExpandDims, 12900+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12901+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_ExpandDims, 12902+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12903+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_ExpandDims, 12904+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12905+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_ExpandDims, 12906+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12907+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_ExpandDims, 12908+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12909+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Squeeze, 12910+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12911+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Squeeze, 12912+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12913+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Squeeze, 12914+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12915+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Squeeze, 12916+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12917+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_Squeeze, 12918+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12919+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Unsqueeze, 12920+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12921+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Unsqueeze, 12922+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12923+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Unsqueeze, 12924+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12925+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Unsqueeze, 12926+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12927+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, 
kNumberTypeInt8, PrimitiveType_Unsqueeze, 12928+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12929+} // namespace mindspore::lite::micro 12930diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.h 12931new file mode 100644 12932index 00000000..aaae22eb 12933--- /dev/null 12934+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.h 12935@@ -0,0 +1,38 @@ 12936+/** 12937+ * Copyright 2023 Huawei Technologies Co., Ltd 12938+ * 12939+ * Licensed under the Apache License, Version 2.0 (the "License"); 12940+ * you may not use this file except in compliance with the License. 12941+ * You may obtain a copy of the License at 12942+ * 12943+ * http://www.apache.org/licenses/LICENSE-2.0 12944+ * 12945+ * Unless required by applicable law or agreed to in writing, software 12946+ * distributed under the License is distributed on an "AS IS" BASIS, 12947+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12948+ * See the License for the specific language governing permissions and 12949+ * limitations under the License. 
12950+ */ 12951+ 12952+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_RESHAPE_DYNAMIC_BASE_CODER_H_ 12953+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_RESHAPE_DYNAMIC_BASE_CODER_H_ 12954+ 12955+#include "tools/converter/micro/coder/opcoders/op_coder.h" 12956+#include "tools/converter/micro/coder/shape_info_container.h" 12957+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 12958+ 12959+namespace mindspore::lite::micro { 12960+class ReshapeDynamicBaseCoder final : public OperatorCoder { 12961+ public: 12962+ ReshapeDynamicBaseCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 12963+ const LiteGraph::Node *node, size_t node_index, Target target) 12964+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 12965+ 12966+ ~ReshapeDynamicBaseCoder() override = default; 12967+ 12968+ int Prepare(CoderContext *const context) override; 12969+ 12970+ int DoCode(CoderContext *const context) override; 12971+}; 12972+} // namespace mindspore::lite::micro 12973+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_RESHAPE_DYNAMIC_BASE_CODER_H_ 12974diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc 12975new file mode 100644 12976index 00000000..4b2b0abe 12977--- /dev/null 12978+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc 12979@@ -0,0 +1,115 @@ 12980+/** 12981+ * Copyright 2023 Huawei Technologies Co., Ltd 12982+ * 12983+ * Licensed under the Apache License, Version 2.0 (the "License"); 12984+ * you may not use this file except in compliance with the License. 
12985+ * You may obtain a copy of the License at 12986+ * 12987+ * http://www.apache.org/licenses/LICENSE-2.0 12988+ * 12989+ * Unless required by applicable law or agreed to in writing, software 12990+ * distributed under the License is distributed on an "AS IS" BASIS, 12991+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12992+ * See the License for the specific language governing permissions and 12993+ * limitations under the License. 12994+ */ 12995+ 12996+#include "coder/opcoders/base/strided_slice_dynamic_base_coder.h" 12997+#include <cmath> 12998+#include <string> 12999+#include "mindspore/lite/src/common/log_util.h" 13000+#include "coder/opcoders/file_collector.h" 13001+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 13002+#include "coder/opcoders/parallel.h" 13003+#include "coder/utils/coder_utils.h" 13004+#include "tools/common/string_util.h" 13005+#include "base/float16.h" 13006+ 13007+using mindspore::schema::PrimitiveType_StridedSlice; 13008+ 13009+namespace mindspore::lite::micro { 13010+namespace { 13011+size_t GetInnerSize(TypeId type_id, size_t inner_elements) { 13012+ switch (type_id) { 13013+ case kNumberTypeInt8: 13014+ return inner_elements * sizeof(int8_t); 13015+ case kNumberTypeFloat32: 13016+ return inner_elements * sizeof(float); 13017+ case kNumberTypeInt32: 13018+ return inner_elements * sizeof(int32_t); 13019+ case kNumberTypeFloat16: 13020+ return inner_elements * sizeof(float16); 13021+ default: 13022+ MS_LOG(ERROR) << "Not supported data type: " << type_id; 13023+ return 0; 13024+ } 13025+} 13026+} // namespace 13027+ 13028+int StridedSliceDynamicBaseCoder::Prepare(CoderContext *context) { 13029+ CHECK_LESS_RETURN(input_tensors_.size(), C2NUM); 13030+ for (size_t i = 1; i < input_tensors_.size(); ++i) { 13031+ MS_CHECK_TRUE_MSG(input_tensors_[i]->IsConst(), RET_PARAM_INVALID, 13032+ "The " << i << " input of strided slice should be const."); 13033+ 
MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeInt32, RET_PARAM_INVALID, 13034+ "The " << i << " input tensor data type should be int32."); 13035+ } 13036+ CHECK_LESS_RETURN(output_tensors_.size(), C1NUM); 13037+ strided_slice_param_ = reinterpret_cast<StridedSliceParameter *>(parameter_); 13038+ CHECK_NULL_RETURN(strided_slice_param_); 13039+ auto begin_tensor = input_tensors_.at(1); 13040+ input_shape_ = shape_info_container_->GetTemplateShape(input_tensor_); 13041+ if (input_shape_.size() > DIMENSION_8D || begin_tensor->shape().size() > DIMENSION_8D) { 13042+ MS_LOG(ERROR) << "StridedSlice not support input rank or begin num exceeds " << DIMENSION_8D; 13043+ return RET_ERROR; 13044+ } 13045+ dynamic_param_.in_shape_ = "{"; 13046+ for (size_t i = 0; i < input_shape_.size(); ++i) { 13047+ dynamic_param_.in_shape_ += input_shape_[i] + ", "; 13048+ } 13049+ dynamic_param_.in_shape_ += "}"; 13050+ return RET_OK; 13051+} 13052+ 13053+int StridedSliceDynamicBaseCoder::DoCode(CoderContext *ctx) { 13054+ inner_size_ = GetInnerSize(input_tensor_->data_type(), inner_); 13055+ Collect(ctx, 13056+ { 13057+ "nnacl/fp32/strided_slice_fp32.h", 13058+ }, 13059+ { 13060+ "strided_slice_fp32.c", 13061+ }); 13062+ switch (input_tensor_->data_type()) { 13063+ case kNumberTypeInt8: 13064+ strided_slice_param_->data_type = ::kNumberTypeInt8; 13065+ break; 13066+ case kNumberTypeFloat32: 13067+ strided_slice_param_->data_type = ::kNumberTypeFloat32; 13068+ break; 13069+ case kNumberTypeInt32: 13070+ strided_slice_param_->data_type = ::kNumberTypeInt32; 13071+ break; 13072+ case kNumberTypeFloat16: 13073+ strided_slice_param_->data_type = ::kNumberTypeFloat16; 13074+ break; 13075+ default: 13076+ MS_LOG(ERROR) << "Not supported data type: " << input_tensor_->data_type(); 13077+ return RET_ERROR; 13078+ } 13079+ nnacl::NNaclFp32Serializer code; 13080+ code.CodeStruct("strided_slice_parameter", *strided_slice_param_, dynamic_param_); 13081+ std::string input_data = 
GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 13082+ std::string output_data = GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 13083+ code.CodeFunction("DoStridedSlice", input_data, output_data, "&strided_slice_parameter"); 13084+ ctx->AppendCode(code.str()); 13085+ return RET_OK; 13086+} 13087+ 13088+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_StridedSlice, 13089+ CPUOpCoderCreator<StridedSliceDynamicBaseCoder>) 13090+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_StridedSlice, 13091+ CPUOpCoderCreator<StridedSliceDynamicBaseCoder>) 13092+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_StridedSlice, 13093+ CPUOpCoderCreator<StridedSliceDynamicBaseCoder>) 13094+} // namespace mindspore::lite::micro 13095diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h 13096new file mode 100644 13097index 00000000..d41cff4f 13098--- /dev/null 13099+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h 13100@@ -0,0 +1,45 @@ 13101+/** 13102+ * Copyright 2023 Huawei Technologies Co., Ltd 13103+ * 13104+ * Licensed under the Apache License, Version 2.0 (the "License"); 13105+ * you may not use this file except in compliance with the License. 13106+ * You may obtain a copy of the License at 13107+ * 13108+ * http://www.apache.org/licenses/LICENSE-2.0 13109+ * 13110+ * Unless required by applicable law or agreed to in writing, software 13111+ * distributed under the License is distributed on an "AS IS" BASIS, 13112+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13113+ * See the License for the specific language governing permissions and 13114+ * limitations under the License. 
13115+ */ 13116+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STRIDED_SLICE_DYNAMIC_BASE_CODER_H_ 13117+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STRIDED_SLICE_DYNAMIC_BASE_CODER_H_ 13118+#include <vector> 13119+#include "coder/opcoders/op_coder.h" 13120+#include "coder/opcoders/nnacl/dynamic_parameter/strided_slice_dynamic_parameter.h" 13121+#include "nnacl/strided_slice_parameter.h" 13122+ 13123+namespace mindspore::lite::micro { 13124+class StridedSliceDynamicBaseCoder final : public OperatorCoder { 13125+ public: 13126+ StridedSliceDynamicBaseCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 13127+ const LiteGraph::Node *node, size_t node_index, Target target) 13128+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 13129+ 13130+ ~StridedSliceDynamicBaseCoder() override = default; 13131+ 13132+ int Prepare(CoderContext *context) override; 13133+ 13134+ int DoCode(CoderContext *context) override; 13135+ 13136+ private: 13137+ StridedSliceParameter *strided_slice_param_{nullptr}; 13138+ StridedSliceDynamicParameter dynamic_param_; 13139+ size_t inner_{1}; 13140+ size_t inner_size_{1}; 13141+ std::vector<std::string> input_shape_; 13142+ std::vector<std::string> output_shape_; 13143+}; 13144+} // namespace mindspore::lite::micro 13145+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STRIDED_SLICE_DYNAMIC_BASE_CODER_H_ 13146diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h 13147new file mode 100644 13148index 00000000..1e9e4f8d 13149--- /dev/null 13150+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h 13151@@ -0,0 +1,43 @@ 13152+/** 13153+ * Copyright 2023 Huawei Technologies Co., Ltd 13154+ * 13155+ * Licensed under the Apache 
License, Version 2.0 (the "License"); 13156+ * you may not use this file except in compliance with the License. 13157+ * You may obtain a copy of the License at 13158+ * 13159+ * http://www.apache.org/licenses/LICENSE-2.0 13160+ * 13161+ * Unless required by applicable law or agreed to in writing, software 13162+ * distributed under the License is distributed on an "AS IS" BASIS, 13163+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13164+ * See the License for the specific language governing permissions and 13165+ * limitations under the License. 13166+ */ 13167+ 13168+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_ARITHMETIC_DYNAMIC_PARAMETER_H_ 13169+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_ARITHMETIC_DYNAMIC_PARAMETER_H_ 13170+#include <string> 13171+ 13172+typedef struct ArithmeticDynamicParameter { 13173+ std::string in_shape0_; 13174+ std::string in_elements_num0_; 13175+ std::string in_shape1_; 13176+ std::string in_elements_num1_; 13177+ 13178+ std::string out_shape_; 13179+ std::string out_elements_num_; 13180+ 13181+ std::string in_strides0_; 13182+ std::string in_strides1_; 13183+ std::string out_strides_; 13184+ 13185+ std::string multiples0_; 13186+ std::string multiples1_; 13187+} ArithmeticDynamicParameter; 13188+ 13189+typedef struct BroadcastDynamicShapeInfo { 13190+ std::string input_shape_; 13191+ std::string output_shape_; 13192+} BroadcastDynamicShapeInfo; 13193+ 13194+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_ARITHMETIC_DYNAMIC_PARAMETER_H_ 13195diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h 13196new file mode 100644 13197index 00000000..a05ab848 13198--- /dev/null 13199+++ 
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_CONV_DYNAMIC_PARAMETER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_CONV_DYNAMIC_PARAMETER_H_
#include <string>

// Convolution attributes whose batch dimensions are only known at inference time;
// each field holds the runtime expression emitted into the generated code.
typedef struct ConvDynamicParameter {
  std::string input_batch_;   // runtime expression for the input batch dimension
  std::string output_batch_;  // runtime expression for the output batch dimension
} ConvDynamicParameter;

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_CONV_DYNAMIC_PARAMETER_H_
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_DYNAMIC_LSTM_PARAMETER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_DYNAMIC_LSTM_PARAMETER_H_
// Fix: the struct members are std::string, so this header must include <string>
// to be self-contained (the original compiled only by accident of include order).
#include <string>

// LSTM dimensions that stay symbolic under dynamic shapes; each field is the
// runtime expression substituted into the generated micro code.
typedef struct DynamicLstmParameter {
  std::string seq_len_;          // sequence length expression
  std::string batch_;            // batch size expression
  std::string input_row_align_;  // row-aligned size of the input matmul
  std::string state_row_align_;  // row-aligned size of the state matmul
  std::string output_step_;      // stride between per-step output slices
} DynamicLstmParameter;

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_DYNAMIC_LSTM_PARAMETER_H_
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_MATMUL_DYNAMIC_PARAMETER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_MATMUL_DYNAMIC_PARAMETER_H_
// Fix: members are std::string, so the header must include <string> itself
// instead of relying on a transitive include from its consumers.
#include <string>

// MatMul dimensions that remain symbolic under dynamic shapes.
typedef struct MatmulDynamicParameter {
  std::string row_;    // runtime expression for the row count (M)
  std::string batch_;  // runtime expression for the batch dimension
} MatmulDynamicParameter;

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_MATMUL_DYNAMIC_PARAMETER_H_
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_POOLING_DYNAMIC_PARAMETER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_POOLING_DYNAMIC_PARAMETER_H_
#include <string>

// Pooling attributes for dynamic-shape code generation: static window/stride
// settings are plain ints, while the batch dimensions stay symbolic strings.
struct PoolingDynamicParameter {
  // Static attributes, fixed at conversion time.
  int avg_mode_;
  bool global_;
  int window_w_;
  int window_h_;
  int stride_w_;
  int stride_h_;

  // Runtime expressions for the batch dimensions.
  std::string input_batch_;
  std::string output_batch_;
};

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_POOLING_DYNAMIC_PARAMETER_H_
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SCALE_DYNAMIC_PARAMETER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SCALE_DYNAMIC_PARAMETER_H_
#include <string>

// Scale-op extents kept symbolic for dynamic-shape code generation:
// the tensor is viewed as [outer, axis, inner] around the scale axis.
struct ScaleDynamicParameter {
  std::string outer_size_;  // product of dims before the scale axis
  std::string axis_size_;   // extent of the scale axis
  std::string inner_size_;  // product of dims after the scale axis
};
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SCALE_DYNAMIC_PARAMETER_H_
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SLICE_DYNAMIC_PARAMETER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SLICE_DYNAMIC_PARAMETER_H_
#include <string>

// Slice-op extents kept as runtime expressions for dynamic-shape generation.
struct SliceDynamicParameter {
  std::string shape_;  // input shape expression
  std::string size_;   // slice size expression per dimension
  std::string end_;    // slice end expression per dimension
};

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SLICE_DYNAMIC_PARAMETER_H_
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SOFTMAX_DYNAMIC_PARAMETER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SOFTMAX_DYNAMIC_PARAMETER_H_
#include <string>

// Softmax extents kept symbolic for dynamic-shape code generation.
struct SoftmaxDynamicParameter {
  std::string input_shape_;   // runtime expression of the input shape
  std::string element_size_;  // runtime expression of the element count
};

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SOFTMAX_DYNAMIC_PARAMETER_H_
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SPLIT_DYNAMIC_PARAMETER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SPLIT_DYNAMIC_PARAMETER_H_
#include <string>

// Split-op extents kept as runtime expressions for dynamic-shape generation.
struct SplitDynamicParameter {
  std::string strides_;      // stride expression between split slices
  std::string split_count_;  // number-of-splits expression
};

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SPLIT_DYNAMIC_PARAMETER_H_
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_STRIDED_SLICE_DYNAMIC_PARAMETER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_STRIDED_SLICE_DYNAMIC_PARAMETER_H_
#include <string>

// Strided-slice input shape kept as a runtime expression for dynamic shapes.
struct StridedSliceDynamicParameter {
  std::string in_shape_;  // runtime expression of the input shape
};

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_STRIDED_SLICE_DYNAMIC_PARAMETER_H_
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_TRANSPOSE_DYNAMIC_PARAMETER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_TRANSPOSE_DYNAMIC_PARAMETER_H_
#include <string>

// Transpose attributes that depend on the runtime shape; each field is a
// C expression rendered as a string for the generated source.
struct TransposeDynamicParameter {
  // Shape-correlative runtime expressions.
  std::string strides_;      // input strides
  std::string out_strides_;  // output strides
  std::string data_num_;     // total element count
};

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_TRANSPOSE_DYNAMIC_PARAMETER_H_
 */
#include "coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.h"
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
#include "coder/opcoders/file_collector.h"
#include "coder/utils/coder_utils.h"
#include "tools/common/string_util.h"

using mindspore::schema::PrimitiveType_Activation;

namespace mindspore::lite::micro::nnacl {
// Validates that both the input and output tensors are fp16 before any code
// is emitted; returns RET_INPUT_PARAM_INVALID otherwise.
int ActivationDynamicFP16Coder::Prepare(CoderContext *const context) {
  MS_CHECK_TRUE_MSG(input_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID,
                    "Input tensor data type is invalid.");
  MS_CHECK_TRUE_MSG(output_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID,
                    "Output tensor data type is invalid.");
  return RET_OK;
}

// Emits the fp16 activation call for the generated inference source. The element
// count and the input/output addresses are runtime expressions (strings) supplied
// by the shape-info container and the dynamic memory manager.
int ActivationDynamicFP16Coder::DoCode(CoderContext *const context) {
  // Pull in the nnacl fp16 activation kernel declaration and implementation.
  Collect(context,
          {
            "nnacl/fp16/activation_fp16.h",
          },
          {
            "activation_fp16.c",
          });
  NNaclFp32Serializer code;
  // attribute
  auto *activation_parameter = reinterpret_cast<ActivationParameter *>(parameter_);
  MS_CHECK_PTR(activation_parameter);
  // count_ becomes a symbolic product of the (possibly dynamic) input dims.
  auto in_shape = shape_info_container_->GetTemplateShape(input_tensor_);
  count_ = AccumulateShape(in_shape, 0, in_shape.size());
  // Addresses are expressions resolved by the dynamic allocator at runtime;
  // empty means the tensor was never registered with the allocator.
  input_data_ = dynamic_mem_manager_->GetVarTensorAddr(input_tensor_);
  MS_CHECK_TRUE_MSG(!input_data_.empty(), RET_ERROR, "pointer is not allocated by the allocator");
  output_data_ = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_);
  MS_CHECK_TRUE_MSG(!output_data_.empty(), RET_ERROR, "pointer is not allocated by the allocator");
  // Cast the raw address expressions to float16_t* in the emitted source.
  input_data_ = "(float16_t *)(" + input_data_ + ")";
  output_data_ = "(float16_t *)(" + output_data_ + ")";

  // Dispatch on the activation type; each branch emits one kernel call.
  switch (activation_parameter->type_) {
    case schema::ActivationType_RELU:
      code.CodeFunction("ReluFp16", input_data_, output_data_, count_);
      break;
    case schema::ActivationType_RELU6:
      code.CodeFunction("Relu6Fp16", input_data_, output_data_, count_);
      break;
    case schema::ActivationType_LEAKY_RELU:
      code.CodeFunction("LReluFp16", input_data_, output_data_, count_, activation_parameter->alpha_);
      break;
    case schema::ActivationType_SIGMOID:
      code.CodeFunction("SigmoidFp16", input_data_, output_data_, count_);
      break;
    case schema::ActivationType_TANH:
      code.CodeFunction("TanhFp16", input_data_, output_data_, count_);
      break;
    case schema::ActivationType_HSWISH:
      code.CodeFunction("HSwishFp16", input_data_, output_data_, count_);
      break;
    case schema::ActivationType_SWISH:
      code.CodeFunction("SwishFp16", input_data_, output_data_, count_);
      break;
    case schema::ActivationType_HSIGMOID:
      code.CodeFunction("HSigmoidFp16", input_data_, output_data_, count_);
      break;
    case schema::ActivationType_ELU:
      code.CodeFunction("EluFp16", input_data_, output_data_, count_, activation_parameter->alpha_);
      break;
    default:
      MS_LOG(ERROR) << "Activation type error";
      return RET_ERROR;
  }
  MS_LOG(DEBUG) << "ActivationFP16Code has been called";
  context->AppendCode(code.str());
  return lite::RET_OK;
}

// Register this coder for fp16 Activation nodes on ARM64 dynamic-shape builds.
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Activation,
                           CPUOpCoderCreator<ActivationDynamicFP16Coder>)
}  // namespace mindspore::lite::micro::nnacl
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ACTIVATION_DYNAMIC_FP16_CODER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ACTIVATION_DYNAMIC_FP16_CODER_H_

#include <vector>
#include "tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.h"

namespace mindspore::lite::micro::nnacl {
// fp16 variant of the dynamic-shape activation coder. Reuses the fp32 coder's
// state/helpers and overrides only type validation and code emission.
class ActivationDynamicFP16Coder final : public ActivationDynamicFP32Coder {
 public:
  ActivationDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                             const LiteGraph::Node *node, size_t node_index, Target target)
      : ActivationDynamicFP32Coder(in_tensors, out_tensors, node, node_index, target) {}

  ~ActivationDynamicFP16Coder() override = default;

  // Checks that input/output tensors are fp16.
  int Prepare(CoderContext *const context) override;

  // Emits the fp16 activation kernel call into the generated source.
  int DoCode(CoderContext *const context) override;
};
}  // namespace mindspore::lite::micro::nnacl
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ACTIVATION_DYNAMIC_FP16_CODER_H_
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h"
#include "coder/opcoders/file_collector.h"
#include "coder/opcoders/parallel.h"
#include "coder/log.h"
#include "coder/utils/coder_utils.h"
#include "tools/common/string_util.h"

namespace mindspore::lite::micro::nnacl {
namespace {
// Wraps an address expression in a void* cast for the emitted source.
std::string wrap_void(const std::string &a) { return "(void *)(" + a + ")"; }
}  // namespace

// Maps (primitive type, fused activation) pairs to the fp16 nnacl kernel names.
// Only the scalar float function slot is filled; int/bool/opt slots stay empty.
void ArithmeticDynamicFP16Coder::InitFunTable() {
  fun_table_ = {
    {PrimitiveType_MulFusion, schema::ActivationType_RELU, "ElementMulReluFp16", "", "", "", ""},
    {PrimitiveType_MulFusion, schema::ActivationType_RELU6, "ElementMulRelu6Fp16", "", "", "", ""},
    {PrimitiveType_MulFusion, schema::ActivationType_NO_ACTIVATION, "ElementMulFp16", "", "", "", ""},
    {PrimitiveType_AddFusion, schema::ActivationType_RELU, "ElementAddReluFp16", "", "", "", ""},
    {PrimitiveType_AddFusion, schema::ActivationType_RELU6, "ElementAddRelu6Fp16", "", "", "", ""},
    {PrimitiveType_AddFusion, schema::ActivationType_NO_ACTIVATION, "ElementAddFp16", "", "", "", ""},
    {PrimitiveType_SubFusion, schema::ActivationType_RELU, "ElementSubReluFp16", "", "", "", ""},
    {PrimitiveType_SubFusion, schema::ActivationType_RELU6, "ElementSubRelu6Fp16", "", "", "", ""},
    {PrimitiveType_SubFusion, schema::ActivationType_NO_ACTIVATION, "ElementSubFp16", "", "", "", ""},
    {PrimitiveType_DivFusion, schema::ActivationType_RELU, "ElementDivReluFp16", "", "", "", ""},
    {PrimitiveType_DivFusion, schema::ActivationType_RELU6, "ElementDivRelu6Fp16", "", "", "", ""},
    {PrimitiveType_DivFusion, schema::ActivationType_NO_ACTIVATION, "ElementDivFp16", "", "", "", ""},
    {PrimitiveType_RealDiv, schema::ActivationType_RELU, "ElementDivReluFp16", "", "", "", ""},
    {PrimitiveType_RealDiv, schema::ActivationType_RELU6, "ElementDivRelu6Fp16", "", "", "", ""},
    {PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, "ElementDivFp16", "", "", "", ""},
    {PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, "ElementLogicalAndFp16", "", "", "", ""},
    {PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, "ElementLogicalOrFp16", "", "", "", ""},
    {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, "ElementMaximumFp16", "", "", "", ""},
    {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, "ElementMinimumFp16", "", "", "", ""},
    {PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, "ElementFloorModFp16", "", "", "", ""},
    {PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, "ElementFloorDivFp16", "", "", "", ""},
    {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, "ElementSquaredDifferenceFp16", "", "", "",
     ""}};
}

// Validates tensor types, resolves Eltwise to a concrete arithmetic primitive,
// then precomputes the run function and all dynamic shape/stride expressions.
int ArithmeticDynamicFP16Coder::Prepare(CoderContext *const context) {
  CHECK_LESS_RETURN(input_tensors_.size(), C2NUM);
  CHECK_LESS_RETURN(output_tensors_.size(), 1);
  for (size_t i = 0; i < input_tensors_.size(); ++i) {
    MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID,
                      "Tensor data type is invalid");
  }
  MS_CHECK_TRUE_MSG(output_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID,
                    "Tensor data type is invalid");
  filter_tensor_ = input_tensors_.at(SECOND_INPUT);
  MS_CHECK_PTR(filter_tensor_);
  param_ = reinterpret_cast<ArithmeticParameter *>(parameter_);
  MS_CHECK_PTR(param_);
  auto primitive_type = param_->op_parameter_.type_;
  // Eltwise is a wrapper primitive: map its mode onto the equivalent fusion op
  // so the shared fun_table_ lookup below can be reused.
  if (primitive_type == schema::PrimitiveType_Eltwise) {
    switch (param_->eltwise_mode_) {
      case schema::EltwiseMode_PROD:
        primitive_type = schema::PrimitiveType_MulFusion;
        break;
      case schema::EltwiseMode_SUM:
        primitive_type = schema::PrimitiveType_AddFusion;
        break;
      case schema::EltwiseMode_MAXIMUM:
        primitive_type = schema::PrimitiveType_Maximum;
        break;
      default:
        MS_LOG(ERROR) << "Eltwise mode not support, mode:" << param_->eltwise_mode_;
        return RET_ERROR;
    }
  }
  InitRunFunction(primitive_type);
  InitDynamicParams();
  ResetStatus();
  CalcMultiplesAndStrides();
  return RET_OK;
}

// Emits broadcast-to-output code for whichever input's template shape differs
// from the output, then emits the element-wise kernel call.
int ArithmeticDynamicFP16Coder::DoCode(CoderContext *const context) {
  input0_ptr_str_ = GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_);
  input1_ptr_str_ = GetTensorAddr(filter_tensor_, filter_tensor_->IsConst(), dynamic_mem_manager_, allocator_);
  output_ptr_str_ = GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_);
  NNaclFp32Serializer code;
  Collect(context,
          {
            "nnacl/fp16/arithmetic_fp16.h",
            "nnacl/base/broadcast_to.h",
          },
          {
            "arithmetic_fp16.c",
            "arithmetic_base.c",
            "broadcast_to.c",
          });

  // all elements eltwise calculation
  arithmetic_func_str_ = wrap_void(arithmetic_run_);
  // run broadcast
  auto in0_shape = shape_info_container_->GetTemplateShape(input_tensor_);
  std::vector<std::string> in1_shape;
  if (filter_tensor_->IsConst()) {
    // Const second input: its shape is concrete, render dims as literals.
    for (auto dim : filter_tensor_->shape()) {
      in1_shape.emplace_back(std::to_string(dim));
    }
  } else {
    in1_shape = shape_info_container_->GetTemplateShape(filter_tensor_);
  }
  auto out_shape = shape_info_container_->GetTemplateShape(output_tensor_);
  // NOTE(review): output_shape_size_ is taken from the member out_shape_ (set in
  // CalcMultiplesAndStrides during Prepare), not the local out_shape — confirm intended.
  broadcast_info_.output_shape_size_ = static_cast<int>(out_shape_.size());
  if (in0_shape != out_shape) {
    // Input 0 needs broadcasting; it is expanded directly into the output buffer.
    broadcast_info_.input_shape_size_ = static_cast<int>(in0_shape.size());
    dynamic_shape_info_.input_shape_ = dynamic_param_.in_shape0_;
    dynamic_shape_info_.output_shape_ = dynamic_param_.out_shape_;
    code.CodeStruct("in0_broadcast_info", broadcast_info_, dynamic_shape_info_);
    code.CodeFunction("BroadcastToSize16", input0_ptr_str_, "&in0_broadcast_info", output_ptr_str_);
    input0_ptr_str_ = output_ptr_str_;
  }
  if (in1_shape != out_shape) {
    broadcast_info_.input_shape_size_ = static_cast<int>(in1_shape.size());
    dynamic_shape_info_.input_shape_ = dynamic_param_.in_shape1_;
    dynamic_shape_info_.output_shape_ = dynamic_param_.out_shape_;
    code.CodeStruct("in1_broadcast_info", broadcast_info_, dynamic_shape_info_);
    auto temp = output_ptr_str_;
    if (input0_ptr_str_ == output_ptr_str_) {
      // The output buffer is already occupied by broadcast input 0, so a scratch
      // workspace sized for the worst-case output element count per shape scene
      // is allocated for broadcast input 1.
      std::map<std::string, std::vector<int>> real_nums;
      size_t scene_num = 0;
      for (auto &dim_template : out_shape) {
        auto dim_nums = shape_info_container_->GetRealNums(dim_template);
        MS_CHECK_TRUE_MSG(!dim_nums.empty(), RET_ERROR, "Dynamic shape's num must be greater than 0.");
        real_nums[dim_template] = dim_nums;
        scene_num = std::max(scene_num, dim_nums.size());
      }
      for (size_t i = 0; i < scene_num; ++i) {
        int out_element_num = 1;
        for (size_t j = 0; j < out_shape.size(); ++j) {
          if (IsNumber(out_shape[j])) {
            out_element_num *= std::stoi(out_shape[j]);
          } else {
            out_element_num *= real_nums[out_shape[j]][i % real_nums[out_shape[j]].size()];
          }
        }
        int workspace = out_element_num * DataTypeSize(kNumberTypeFloat16);
        temp = dynamic_mem_manager_->AllocWorkSpace(workspace, i);
        MS_CHECK_TRUE_MSG(!temp.empty(), RET_ERROR, "Arithmetic cannot alloc workspace.");
      }
    }
    code.CodeFunction("BroadcastToSize16", input1_ptr_str_, "&in1_broadcast_info", temp);
    input1_ptr_str_ = temp;
  }
  return ExecuteCode("(float16_t *)(" + input0_ptr_str_ + ")", "(float16_t *)(" + input1_ptr_str_ + ")",
                     "(float16_t *)(" + output_ptr_str_ + ")", dynamic_param_.out_elements_num_, context, &code);
}

// Renders the template shapes and symbolic element counts of both inputs and the
// output into brace-list string expressions stored in dynamic_param_.
void ArithmeticDynamicFP16Coder::InitDynamicParams() {
  auto in0_shape = shape_info_container_->GetTemplateShape(input_tensor_);
  std::vector<std::string> in1_shape;
  if (filter_tensor_->IsConst()) {
    for (auto dim : filter_tensor_->shape()) {
      in1_shape.emplace_back(std::to_string(dim));
    }
  } else {
    in1_shape = shape_info_container_->GetTemplateShape(filter_tensor_);
  }
  auto out_shape = shape_info_container_->GetTemplateShape(output_tensor_);
  dynamic_param_.in_shape0_ = "{";
  dynamic_param_.in_shape1_ = "{";
  dynamic_param_.out_shape_ = "{";
  for (auto shape : in0_shape) {
    dynamic_param_.in_shape0_ += shape + ", ";
  }
  for (auto shape : in1_shape) {
    dynamic_param_.in_shape1_ += shape + ", ";
  }
  for (auto shape : out_shape) {
    dynamic_param_.out_shape_ += shape + ", ";
  }
  dynamic_param_.in_shape0_ += "}";
  dynamic_param_.in_shape1_ += "}";
  dynamic_param_.out_shape_ += "}";
  dynamic_param_.in_elements_num0_ = AccumulateShape(in0_shape, 0, in0_shape.size());
  dynamic_param_.in_elements_num1_ = AccumulateShape(in1_shape, 0, in1_shape.size());
  dynamic_param_.out_elements_num_ = AccumulateShape(out_shape, 0, out_shape.size());
}

// Looks up the kernel names for the resolved primitive/activation pair; the last
// matching table row wins (rows are unique in practice).
void ArithmeticDynamicFP16Coder::InitRunFunction(int primitive_type) {
  InitFunTable();
  for (size_t i = 0; i < fun_table_.size(); i++) {
    if (fun_table_[i].primitive_type_ == primitive_type && fun_table_[i].activation_type_ == param_->activation_type_) {
      arithmetic_run_ = fun_table_[i].func_;
      arithmetic_run_int_ = fun_table_[i].int_func_;
      arithmetic_run_bool_ = fun_table_[i].bool_func_;
      arithmetic_opt_run_ = fun_table_[i].opt_func_;
      arithmetic_opt_run_int_ = fun_table_[i].opt_int_func_;
    }
  }
  arithmetic_func_type_ = kArithmeticFuncFloat;
}

// Left-pads the shorter input's shape with "1" so in0_shape_/in1_shape_ have
// equal rank, matching numpy-style broadcasting alignment.
void ArithmeticDynamicFP16Coder::ResetStatus() {
  auto input_shape = shape_info_container_->GetTemplateShape(input_tensor_);
  std::vector<std::string> filter_shape;
  if (filter_tensor_->IsConst()) {
    for (auto dim : filter_tensor_->shape()) {
      filter_shape.emplace_back(std::to_string(dim));
    }
  } else {
    filter_shape = shape_info_container_->GetTemplateShape(filter_tensor_);
  }
  auto dim_num = input_shape.size() >= filter_shape.size() ? input_shape.size() : filter_shape.size();
  for (size_t i = 0; i < dim_num - input_shape.size(); ++i) {
    in0_shape_.emplace_back("1");
  }
  in0_shape_.insert(in0_shape_.end(), input_shape.begin(), input_shape.end());
  for (size_t i = 0; i < dim_num - filter_shape.size(); ++i) {
    in1_shape_.emplace_back("1");
  }
  in1_shape_.insert(in1_shape_.end(), filter_shape.begin(), filter_shape.end());
}

// Builds the broadcast-multiple expressions (out/in per dim) and the symbolic
// stride lists for both inputs and the output.
void ArithmeticDynamicFP16Coder::CalcMultiplesAndStrides() {
  out_shape_ = shape_info_container_->GetTemplateShape(output_tensor_);
  dynamic_param_.multiples0_ = "{";
  dynamic_param_.multiples1_ = "{";
  for (size_t i = 0; i < param_->ndim_; i++) {
    // "0" marks an absent dimension; skip it to avoid emitting a division by zero.
    if (in0_shape_[i] != "0") {
      dynamic_param_.multiples0_ += out_shape_[i] + " / " + in0_shape_[i] + ", ";
    }
    if (in1_shape_[i] != "0") {
      dynamic_param_.multiples1_ += out_shape_[i] + " / " + in1_shape_[i] + ", ";
    }
  }
  dynamic_param_.multiples0_ += "}";
  dynamic_param_.multiples1_ += "}";

  // cal strides
  in0_strides_.resize(param_->ndim_);
  in1_strides_.resize(param_->ndim_);
  out_strides_.resize(param_->ndim_);
  ComputeStrides(in0_shape_, in0_strides_);
  ComputeStrides(in1_shape_, in1_strides_);
  ComputeStrides(out_shape_, out_strides_);
  dynamic_param_.in_strides0_ = "{";
  dynamic_param_.in_strides1_ = "{";
  dynamic_param_.out_strides_ = "{";
  for (size_t i = 0; i < param_->ndim_; ++i) {
    dynamic_param_.in_strides0_ += in0_strides_[i] + ", ";
    dynamic_param_.in_strides1_ += in1_strides_[i] + ", ";
    dynamic_param_.out_strides_ += out_strides_[i] + ", ";
  }
  dynamic_param_.in_strides0_ += "}";
  dynamic_param_.in_strides1_ += "}";
  dynamic_param_.out_strides_ += "}";
}

void ArithmeticDynamicFP16Coder::ComputeStrides(const std::vector<std::string> &shape,
13941+ std::vector<std::string> &strides) { 13942+ std::string stride = "1"; 13943+ for (int i = param_->ndim_ - 1; i >= 0; i--) { 13944+ strides[i] = stride; 13945+ stride += "*=" + shape[i]; 13946+ } 13947+} 13948+ 13949+int ArithmeticDynamicFP16Coder::ExecuteCode(const std::string &input0, const std::string &input1, 13950+ const std::string &output, const std::string size, 13951+ CoderContext *const context, NNaclFp32Serializer *const code) { 13952+ if (arithmetic_func_str_.empty()) { 13953+ return RET_ERROR; 13954+ } 13955+ for (size_t i = 0; i < fun_table_.size(); i++) { 13956+ if (fun_table_[i].primitive_type_ == param_->op_parameter_.type_ && 13957+ fun_table_[i].activation_type_ == param_->activation_type_) { 13958+ code->CodeFunction(fun_table_[i].func_, input0, input1, output, size); 13959+ break; 13960+ } 13961+ } 13962+ context->AppendCode(code->str()); 13963+ return RET_OK; 13964+} 13965+ 13966+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_AddFusion, 13967+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13968+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_MulFusion, 13969+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13970+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_SubFusion, 13971+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13972+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_DivFusion, 13973+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13974+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_RealDiv, 13975+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13976+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_LogicalAnd, 13977+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13978+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_LogicalOr, 13979+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13980+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Maximum, 13981+ 
CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
// The registrations below bind every supported binary elementwise primitive
// to the dynamic-shape FP16 arithmetic coder, for both kARM32 and kARM64.
REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Minimum,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_FloorDiv,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_FloorMod,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_SquaredDifference,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Equal,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_NotEqual,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Less,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_LessEqual,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Greater,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_GreaterEqual,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Eltwise,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_AddFusion,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_MulFusion,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_SubFusion,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_DivFusion,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_RealDiv,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_LogicalAnd,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_LogicalOr,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Maximum,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Minimum,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_FloorDiv,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_FloorMod,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_SquaredDifference,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Equal,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_NotEqual,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Less,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_LessEqual,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Greater,
                           CPUOpCoderCreator<ArithmeticDynamicFP16Coder>)
REG_DYNAMIC_OPERATOR_CODER(kARM64,
kNumberTypeFloat16, PrimitiveType_GreaterEqual, 14039+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14040+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Eltwise, 14041+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14042+} // namespace mindspore::lite::micro::nnacl 14043diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h 14044new file mode 100644 14045index 00000000..87e43687 14046--- /dev/null 14047+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h 14048@@ -0,0 +1,132 @@ 14049+/** 14050+ * Copyright 2023 Huawei Technologies Co., Ltd 14051+ * 14052+ * Licensed under the Apache License, Version 2.0 (the "License"); 14053+ * you may not use this file except in compliance with the License. 14054+ * You may obtain a copy of the License at 14055+ * 14056+ * http://www.apache.org/licenses/LICENSE-2.0 14057+ * 14058+ * Unless required by applicable law or agreed to in writing, software 14059+ * distributed under the License is distributed on an "AS IS" BASIS, 14060+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14061+ * See the License for the specific language governing permissions and 14062+ * limitations under the License. 
14063+ */ 14064+ 14065+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ARITHMETIC_DYNAMIC_FP16_CODER_H_ 14066+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ARITHMETIC_DYNAMIC_FP16_CODER_H_ 14067+ 14068+#include <vector> 14069+#include <string> 14070+#include "coder/opcoders/op_coder.h" 14071+#include "nnacl/base/cast_base.h" 14072+#include "nnacl/arithmetic_parameter.h" 14073+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 14074+#include "coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h" 14075+#include "nnacl/broadcast_to_parameter.h" 14076+ 14077+namespace mindspore::lite::micro::nnacl { 14078+using mindspore::schema::PrimitiveType_AddFusion; 14079+using mindspore::schema::PrimitiveType_DivFusion; 14080+using mindspore::schema::PrimitiveType_Eltwise; 14081+using mindspore::schema::PrimitiveType_Equal; 14082+using mindspore::schema::PrimitiveType_FloorDiv; 14083+using mindspore::schema::PrimitiveType_FloorMod; 14084+using mindspore::schema::PrimitiveType_Greater; 14085+using mindspore::schema::PrimitiveType_GreaterEqual; 14086+using mindspore::schema::PrimitiveType_Less; 14087+using mindspore::schema::PrimitiveType_LessEqual; 14088+using mindspore::schema::PrimitiveType_LogicalAnd; 14089+using mindspore::schema::PrimitiveType_LogicalOr; 14090+using mindspore::schema::PrimitiveType_Maximum; 14091+using mindspore::schema::PrimitiveType_Minimum; 14092+using mindspore::schema::PrimitiveType_Mod; 14093+using mindspore::schema::PrimitiveType_MulFusion; 14094+using mindspore::schema::PrimitiveType_NotEqual; 14095+using mindspore::schema::PrimitiveType_RealDiv; 14096+using mindspore::schema::PrimitiveType_SquaredDifference; 14097+using mindspore::schema::PrimitiveType_SubFusion; 14098+ 14099+class ArithmeticDynamicFP16Coder final : public OperatorCoder { 14100+ typedef struct { 14101+ int primitive_type_; 14102+ int activation_type_; 14103+ std::string func_; 14104+ 
std::string int_func_; 14105+ std::string bool_func_; 14106+ std::string opt_func_; 14107+ std::string opt_int_func_; 14108+ } ARITHMETIC_FUNC_INFO_FP16; 14109+ 14110+ // typedef struct MATRIC_INFO { 14111+ // bool is_const{false}; 14112+ // bool is_valid{false}; 14113+ // void *data{nullptr}; 14114+ // int64_t inner_size{1}; // the element num of once batch 14115+ // std::vector<int64_t> shape; 14116+ // std::vector<int64_t> batch_post_sum; 14117+ // void Reset() { 14118+ // is_valid = false; 14119+ // data = nullptr; 14120+ // inner_size = 1; 14121+ // shape.clear(); 14122+ // batch_post_sum.clear(); 14123+ // } 14124+ // } MATRIC_INFO; 14125+ 14126+ public: 14127+ ArithmeticDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 14128+ const LiteGraph::Node *node, size_t node_index, Target target) 14129+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 14130+ 14131+ ~ArithmeticDynamicFP16Coder() override = default; 14132+ 14133+ int DoCode(CoderContext *const context) override; 14134+ 14135+ private: 14136+ int Prepare(CoderContext *const context) override; 14137+ 14138+ void InitFunTable(); 14139+ 14140+ void InitRunFunction(int primitive_type); 14141+ 14142+ void InitDynamicParams(); 14143+ 14144+ void ResetStatus(); 14145+ 14146+ void CalcMultiplesAndStrides(); 14147+ 14148+ void ComputeStrides(const std::vector<std::string> &shape, std::vector<std::string> &strides); 14149+ 14150+ int ExecuteCode(const std::string &input0, const std::string &input1, const std::string &output, 14151+ const std::string size, CoderContext *const context, NNaclFp32Serializer *const code); 14152+ 14153+ std::vector<ARITHMETIC_FUNC_INFO_FP16> fun_table_; 14154+ ArithmeticFuncType arithmetic_func_type_{kArithmeticFuncUnknow}; 14155+ ArithmeticParameter *param_{nullptr}; 14156+ ArithmeticDynamicParameter dynamic_param_; 14157+ BroadcastShapeInfo broadcast_info_; 14158+ BroadcastDynamicShapeInfo dynamic_shape_info_; 
14159+ Tensor *filter_tensor_{nullptr}; 14160+ std::string input0_ptr_str_; 14161+ std::string input1_ptr_str_; 14162+ std::string output_ptr_str_; 14163+ std::string arithmetic_run_; 14164+ std::string arithmetic_run_int_; 14165+ std::string arithmetic_opt_run_; 14166+ std::string arithmetic_opt_run_int_; 14167+ std::string arithmetic_run_bool_; 14168+ std::string arithmetic_func_str_; 14169+ std::vector<std::string> in0_shape_; 14170+ std::vector<std::string> in1_shape_; 14171+ std::vector<std::string> out_shape_; 14172+ std::vector<std::string> in0_strides_; 14173+ std::vector<std::string> in1_strides_; 14174+ std::vector<std::string> out_strides_; 14175+ // MATRIC_INFO a_matric_; 14176+ // MATRIC_INFO b_matric_; 14177+ // MATRIC_INFO c_matric_; 14178+}; 14179+} // namespace mindspore::lite::micro::nnacl 14180+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ARITHMETIC_DYNAMIC_FP16_CODER_H_ 14181diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc 14182new file mode 100644 14183index 00000000..bf8bd06b 14184--- /dev/null 14185+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc 14186@@ -0,0 +1,92 @@ 14187+/** 14188+ * Copyright 2023 Huawei Technologies Co., Ltd 14189+ * 14190+ * Licensed under the Apache License, Version 2.0 (the "License"); 14191+ * you may not use this file except in compliance with the License. 14192+ * You may obtain a copy of the License at 14193+ * 14194+ * http://www.apache.org/licenses/LICENSE-2.0 14195+ * 14196+ * Unless required by applicable law or agreed to in writing, software 14197+ * distributed under the License is distributed on an "AS IS" BASIS, 14198+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14199+ * See the License for the specific language governing permissions and 14200+ * limitations under the License. 14201+ */ 14202+#include "coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h" 14203+#include <string> 14204+#include <vector> 14205+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 14206+#include "coder/opcoders/file_collector.h" 14207+#include "coder/opcoders/parallel.h" 14208+#include "coder/utils/coder_utils.h" 14209+ 14210+using mindspore::schema::PrimitiveType_Concat; 14211+ 14212+namespace mindspore::lite::micro::nnacl { 14213+int ConcatDynamicFP16Coder::Prepare(CoderContext *const context) { 14214+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 14215+ MS_CHECK_TRUE_MSG(input_tensors_.at(i)->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 14216+ "input tensor data type is invalid."); 14217+ } 14218+ concat_param_ = reinterpret_cast<ConcatParameter *>(parameter_); 14219+ MS_CHECK_PTR(concat_param_); 14220+ auto input_shape = shape_info_container_->GetTemplateShape(input_tensor_); 14221+ axis_ = 14222+ concat_param_->axis_ >= 0 ? 
concat_param_->axis_ : static_cast<int>(input_shape.size()) + concat_param_->axis_; 14223+ return RET_OK; 14224+} 14225+ 14226+int ConcatDynamicFP16Coder::DoCode(CoderContext *const context) { 14227+ Collect(context, 14228+ { 14229+ "nnacl/base/concat_base.h", 14230+ }, 14231+ { 14232+ "concat_base.c", 14233+ }); 14234+ 14235+ size_t input_num = input_tensors_.size(); 14236+ 14237+ NNaclFp32Serializer code; 14238+ code << "\t\tvoid *inputs_addr[] = {"; 14239+ for (size_t i = 0; i < input_num; ++i) { 14240+ code << "(void *)(" 14241+ << GetTensorAddr(input_tensors_.at(i), input_tensors_.at(i)->IsConst(), dynamic_mem_manager_, allocator_) 14242+ << "), "; 14243+ } 14244+ code << "};\n"; 14245+ 14246+ size_t i; 14247+ for (i = 0; i < input_num; ++i) { 14248+ code << "\t\tint shape_" << i << "[] = {"; 14249+ auto in_shape = shape_info_container_->GetTemplateShape(input_tensors_.at(i)); 14250+ for (auto &shape : in_shape) { 14251+ code << shape << ", "; 14252+ } 14253+ code << "};\n"; 14254+ } 14255+ 14256+ auto out_shape = shape_info_container_->GetTemplateShape(output_tensor_); 14257+ code << "\t\tint shape_" << i << "[] = {"; 14258+ for (auto &shape : out_shape) { 14259+ code << shape << ", "; 14260+ } 14261+ code << "};\n"; 14262+ 14263+ code << "\t\tint *inputs_output_shape[] = {"; 14264+ for (i = 0; i <= input_num; ++i) { 14265+ code << "shape_" << i << ", "; 14266+ } 14267+ code << "};\n"; 14268+ std::string output_data = GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 14269+ code.CodeFunction("Concat", "inputs_addr", input_num, axis_, "inputs_output_shape", out_shape.size(), output_data, 0, 14270+ 1, sizeof(uint16_t)); 14271+ context->AppendCode(code.str()); 14272+ return RET_OK; 14273+} 14274+ 14275+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Concat, CPUOpCoderCreator<ConcatDynamicFP16Coder>) 14276+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Concat, 
CPUOpCoderCreator<ConcatDynamicFP16Coder>) 14277+ 14278+} // namespace mindspore::lite::micro::nnacl 14279diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h 14280new file mode 100644 14281index 00000000..bd1b7ff6 14282--- /dev/null 14283+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h 14284@@ -0,0 +1,40 @@ 14285+/** 14286+ * Copyright 2023 Huawei Technologies Co., Ltd 14287+ * 14288+ * Licensed under the Apache License, Version 2.0 (the "License"); 14289+ * you may not use this file except in compliance with the License. 14290+ * You may obtain a copy of the License at 14291+ * 14292+ * http://www.apache.org/licenses/LICENSE-2.0 14293+ * 14294+ * Unless required by applicable law or agreed to in writing, software 14295+ * distributed under the License is distributed on an "AS IS" BASIS, 14296+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14297+ * See the License for the specific language governing permissions and 14298+ * limitations under the License. 
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONCAT_DYNAMIC_FP16_CODER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONCAT_DYNAMIC_FP16_CODER_H_

#include <vector>
#include "coder/opcoders/op_coder.h"
#include "nnacl/concat_parameter.h"

namespace mindspore::lite::micro::nnacl {
// Micro-coder that emits C source for FP16 Concat over tensors whose shapes
// are only known symbolically (dynamic/template shapes) at code-gen time.
class ConcatDynamicFP16Coder final : public OperatorCoder {
 public:
  ConcatDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                         const LiteGraph::Node *node, size_t node_index, Target target)
      : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
  ~ConcatDynamicFP16Coder() override = default;

  int Prepare(CoderContext *const context) override;
  int DoCode(CoderContext *const context) override;

 private:
  // Concat axis normalized to a non-negative index in Prepare().
  int axis_{0};
  // Borrowed view of parameter_; not owned by this coder.
  ConcatParameter *concat_param_{nullptr};
};
}  // namespace mindspore::lite::micro::nnacl
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONCAT_DYNAMIC_FP16_CODER_H_
14336+ * You may obtain a copy of the License at 14337+ * 14338+ * http://www.apache.org/licenses/LICENSE-2.0 14339+ * 14340+ * Unless required by applicable law or agreed to in writing, software 14341+ * distributed under the License is distributed on an "AS IS" BASIS, 14342+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14343+ * See the License for the specific language governing permissions and 14344+ * limitations under the License. 14345+ */ 14346+ 14347+#include "coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h" 14348+#include "src/common/version_manager.h" 14349+#include "src/common/tensor_util.h" 14350+#include "src/common/ops/populate/populate_register.h" 14351+#include "nnacl/fp32/winograd_utils.h" 14352+#include "nnacl/base/conv_common_base.h" 14353+#include "nnacl/infer/conv2d_infer.h" 14354+#include "coder/shape_info_container.h" 14355+#include "coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h" 14356+#include "coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h" 14357+ 14358+using mindspore::schema::PrimitiveType_Conv2DFusion; 14359+namespace mindspore::lite::micro::nnacl { 14360+int ConvDelegateDynamicFP16Coder::Prepare(CoderContext *const context) { 14361+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 14362+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 14363+ "Input tensor data type is invalid"); 14364+ } 14365+ for (size_t i = 0; i < output_tensors_.size(); ++i) { 14366+ MS_CHECK_TRUE_MSG(output_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 14367+ "Output tensor data type is invalid"); 14368+ } 14369+ // Update shape info of input and output 14370+ ConvDynamicParameter dynamic_param; 14371+ SetInputOutputShapeInfo(reinterpret_cast<ConvParameter *>(parameter_), dynamic_param, input_tensor_, output_tensor_); 14372+ if (conv_coder_ == nullptr) { 14373+ // need to select actual execute coder here 14374+ 
conv_coder_ = 14375+ CPUConvFP16DynamicCoderSelect(input_tensors_, output_tensors_, node_, node_index(), target_, schema_version_); 14376+ MS_CHECK_PTR(conv_coder_); 14377+ ConvParameter *op_parameter = static_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); 14378+ if (op_parameter == nullptr) { 14379+ MS_LOG(ERROR) << "malloc ConvParameter failed."; 14380+ return RET_ERROR; 14381+ } 14382+ if (memcpy_s(op_parameter, sizeof(ConvParameter), parameter_, sizeof(ConvParameter)) != EOK) { 14383+ MS_LOG(ERROR) << "memcpy_s failed."; 14384+ free(op_parameter); 14385+ return RET_ERROR; 14386+ } 14387+ conv_coder_->set_type(GetPrimitiveType(node_->primitive_, schema_version_)); 14388+ conv_coder_->set_thread_num(thread_num_); 14389+ conv_coder_->set_parameter(reinterpret_cast<OpParameter *>(op_parameter)); 14390+ conv_coder_->set_shape_info_container(shape_info_container_); 14391+ conv_coder_->set_dynamic_mem_manager(dynamic_mem_manager_); 14392+ } 14393+ return conv_coder_->Prepare(context); 14394+} 14395+ 14396+int ConvDelegateDynamicFP16Coder::DoCode(CoderContext *const context) { return conv_coder_->DoCode(context); } 14397+ 14398+void ConvDelegateDynamicFP16Coder::SetInputOutputShapeInfo(ConvParameter *conv_param, 14399+ ConvDynamicParameter &dynamic_param, 14400+ const lite::Tensor *input, const lite::Tensor *output) { 14401+ dynamic_param.input_batch_ = shape_info_container_->GetTemplateShape(input_tensor_).at(0); 14402+ conv_param->input_h_ = input->Height(); 14403+ conv_param->input_w_ = input->Width(); 14404+ conv_param->input_channel_ = input->Channel(); 14405+ dynamic_param.output_batch_ = shape_info_container_->GetTemplateShape(output_tensor_).at(0); 14406+ conv_param->output_h_ = output->Height(); 14407+ conv_param->output_w_ = output->Width(); 14408+ conv_param->output_channel_ = output->Channel(); 14409+} 14410+ 14411+std::unique_ptr<OperatorCoder> CPUConvFP16DynamicCoderSelect(const std::vector<lite::Tensor *> &in_tensors, 14412+ const 
std::vector<lite::Tensor *> &out_tensors, 14413+ const LiteGraph::Node *node, size_t node_index, 14414+ Target target, int schema_version) { 14415+ const void *primitive = node->primitive_; 14416+ if (primitive == nullptr) { 14417+ return nullptr; 14418+ } 14419+ ParameterGen paramGen = PopulateRegistry::GetInstance()->GetParameterCreator( 14420+ GetPrimitiveType(node->primitive_, schema_version), schema_version); 14421+ MS_CHECK_PTR_RET_NULL(paramGen); 14422+ auto conv_param = reinterpret_cast<ConvParameter *>(paramGen(node->primitive_)); 14423+ MS_CHECK_PTR_RET_NULL(conv_param); 14424+ int kernel_h = conv_param->kernel_h_; 14425+ int kernel_w = conv_param->kernel_w_; 14426+ conv_param->input_h_ = in_tensors.at(kInputIndex)->Height(); 14427+ conv_param->input_w_ = in_tensors.at(kInputIndex)->Width(); 14428+ conv_param->input_channel_ = in_tensors.at(kInputIndex)->Channel(); 14429+ conv_param->output_h_ = out_tensors.at(kOutputIndex)->Height(); 14430+ conv_param->output_w_ = out_tensors.at(kOutputIndex)->Width(); 14431+ conv_param->output_channel_ = out_tensors.at(kOutputIndex)->Channel(); 14432+ conv_param->op_parameter_.thread_num_ = 1; 14433+ free(conv_param); 14434+ std::unique_ptr<OperatorCoder> coder; 14435+ if (kernel_h == 1 && kernel_w == 1) { 14436+ MS_LOG(DEBUG) << "create Convolution1x1DynamicFP16CPUKernel"; 14437+ coder = CPUOpCoderCreator<Convolution1x1DynamicFP16Coder>(in_tensors, out_tensors, node, node_index, target, 14438+ schema_version); 14439+ } else { 14440+ MS_LOG(DEBUG) << "create ConvolutionDynamicFP16Coder"; 14441+ coder = 14442+ CPUOpCoderCreator<ConvolutionDynamicFP16Coder>(in_tensors, out_tensors, node, node_index, target, schema_version); 14443+ } 14444+ return coder; 14445+} 14446+ 14447+std::unique_ptr<OperatorCoder> CreateConvDelegateFp16(const std::vector<lite::Tensor *> &in_tensors, 14448+ const std::vector<lite::Tensor *> &out_tensors, 14449+ const LiteGraph::Node *node, size_t node_index, Target target, 14450+ int schema_version) 
{ 14451+ return CPUOpCoderCreator<ConvDelegateDynamicFP16Coder>(in_tensors, out_tensors, node, node_index, target, 14452+ schema_version); 14453+} 14454+ 14455+std::unique_ptr<OperatorCoder> CPUConv2DFusionDynamicFP16CoderCreator(const std::vector<lite::Tensor *> &in_tensors, 14456+ const std::vector<lite::Tensor *> &out_tensors, 14457+ const LiteGraph::Node *node, size_t node_index, 14458+ Target target, int schema_version) { 14459+ const void *primitive = node->primitive_; 14460+ if (primitive == nullptr) { 14461+ return nullptr; 14462+ } 14463+ ParameterGen param_gen = PopulateRegistry::GetInstance()->GetParameterCreator( 14464+ GetPrimitiveType(node->primitive_, schema_version), schema_version); 14465+ if (param_gen == nullptr) { 14466+ MS_LOG(ERROR) << "parameter generator is null"; 14467+ return nullptr; 14468+ } 14469+ auto conv_param = reinterpret_cast<ConvParameter *>(param_gen(node->primitive_)); 14470+ std::unique_ptr<OperatorCoder> coder; 14471+ if (conv_param->group_ == 1) { 14472+ coder = CreateConvDelegateFp16(in_tensors, out_tensors, node, node_index, target, schema_version); 14473+ } else { 14474+ // GroupConv 14475+ MS_LOG(ERROR) << "currently, only support conv_param->group_ == 1 in dynamic coder scene"; 14476+ return nullptr; 14477+ } 14478+ return coder; 14479+} 14480+ 14481+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Conv2DFusion, 14482+ CPUConv2DFusionDynamicFP16CoderCreator) 14483+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Conv2DFusion, 14484+ CPUConv2DFusionDynamicFP16CoderCreator) 14485+} // namespace mindspore::lite::micro::nnacl 14486diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h 14487new file mode 100644 14488index 00000000..c352c469 14489--- /dev/null 14490+++ 
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONV2D_DELEGATE_DYNAMIC_FP16_CODER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONV2D_DELEGATE_DYNAMIC_FP16_CODER_H_
#include <vector>
#include <memory>
#include "coder/opcoders/op_coder.h"
#include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h"
#include "nnacl/conv_parameter.h"

namespace mindspore::lite::micro::nnacl {
// Delegate coder for dynamic-shape FP16 Conv2D: selects a concrete
// convolution coder (1x1 or generic) in Prepare() and forwards
// Prepare/DoCode to it.
class ConvDelegateDynamicFP16Coder : public OperatorCoder {
 public:
  ConvDelegateDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                               const LiteGraph::Node *node, size_t node_index, Target target)
      : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}

  ~ConvDelegateDynamicFP16Coder() override = default;
  int Prepare(CoderContext *const context) override;
  int DoCode(CoderContext *const context) override;

 protected:
  // Concrete coder doing the real work; created lazily in Prepare().
  std::unique_ptr<OperatorCoder> conv_coder_ = nullptr;
  // NOTE(review): conv_param_ is not assigned anywhere in the visible .cc —
  // confirm whether this member is still needed.
  ConvParameter *conv_param_{nullptr};
  ConvDynamicParameter dynamic_param_;

 private:
  // Fills static H/W/C dims of conv_param and the symbolic batch dims of
  // dynamic_param from the given tensors.
  void SetInputOutputShapeInfo(ConvParameter *conv_param, ConvDynamicParameter &dynamic_param,
                               const lite::Tensor *input, const lite::Tensor *output);
};

// Chooses Convolution1x1DynamicFP16Coder for 1x1 kernels, otherwise
// ConvolutionDynamicFP16Coder.
std::unique_ptr<OperatorCoder> CPUConvFP16DynamicCoderSelect(const std::vector<lite::Tensor *> &in_tensors,
                                                             const std::vector<lite::Tensor *> &out_tensors,
                                                             const LiteGraph::Node *node, size_t node_index,
                                                             Target target, int schema_version);

// Registry creator; only group == 1 convolutions are supported.
std::unique_ptr<OperatorCoder> CPUConv2DFusionDynamicFP16CoderCreator(const std::vector<lite::Tensor *> &in_tensors,
                                                                      const std::vector<lite::Tensor *> &out_tensors,
                                                                      const LiteGraph::Node *node, size_t node_index,
                                                                      Target target, int schema_version);
}  // namespace mindspore::lite::micro::nnacl
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONV2D_DELEGATE_DYNAMIC_FP16_CODER_H_
14566+ * See the License for the specific language governing permissions and 14567+ * limitations under the License. 14568+ */ 14569+ 14570+#include "coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h" 14571+#include <string> 14572+#include <vector> 14573+#include "nnacl/fp32/winograd_utils.h" 14574+#include "coder/opcoders/file_collector.h" 14575+#include "coder/opcoders/parallel.h" 14576+#include "coder/utils/coder_utils.h" 14577+ 14578+namespace mindspore::lite::micro::nnacl { 14579+int Convolution1x1DynamicFP16Coder::Prepare(CoderContext *const context) { 14580+ CHECK_LESS_RETURN(input_tensors_.size(), C2NUM); 14581+ CHECK_LESS_RETURN(output_tensors_.size(), 1); 14582+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 14583+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 14584+ "Tensor data type is invalid"); 14585+ } 14586+ for (size_t i = 0; i < output_tensors_.size(); ++i) { 14587+ MS_CHECK_TRUE_MSG(output_tensors_[i]->data_type() == kNumberTypeFloat16, RET_PARAM_INVALID, 14588+ "Tensor data type is invalid"); 14589+ } 14590+ if (target_ == kARM64) { 14591+ row_tile_ = (output_tensor_->format() == NC4HW4) ? C16NUM : C12NUM; 14592+ col_tile_ = (output_tensor_->format() == NC4HW4) ? 
C8NUM : C16NUM; 14593+ } 14594+ if (matmul_param_ == nullptr) { 14595+ matmul_param_ = new (std::nothrow) MatMulParameter(); 14596+ if (matmul_param_ == nullptr) { 14597+ MS_LOG(ERROR) << "Init matmul_param_ failed."; 14598+ return RET_ERROR; 14599+ } 14600+ } 14601+ conv_param_ = reinterpret_cast<ConvParameter *>(parameter_); 14602+ filter_tensor_ = input_tensors_.at(kWeightIndex); 14603+ MS_CHECK_PTR(filter_tensor_); 14604+ if (input_tensors_.size() == kInputSize2) { 14605+ bias_tensor_ = input_tensors_.at(kBiasIndex); 14606+ MS_CHECK_PTR(bias_tensor_); 14607+ } else { 14608+ MS_CHECK_TRUE(input_tensors_.size() == kInputSize1, "wrong input size"); 14609+ } 14610+ dynamic_param_.input_batch_ = shape_info_container_->GetTemplateShape(input_tensor_)[0]; 14611+ conv_param_->input_h_ = input_tensor_->Height(); 14612+ conv_param_->input_w_ = input_tensor_->Width(); 14613+ conv_param_->input_channel_ = input_tensor_->Channel(); 14614+ dynamic_param_.output_batch_ = shape_info_container_->GetTemplateShape(output_tensor_)[0]; 14615+ conv_param_->output_h_ = output_tensor_->Height(); 14616+ conv_param_->output_w_ = output_tensor_->Width(); 14617+ conv_param_->output_channel_ = output_tensor_->Channel(); 14618+ MS_CHECK_RET_CODE(InitWeightBias(context), "Init weight bias failed."); 14619+ MS_CHECK_RET_CODE(InitMatmulParam(), "Init matmul param failed."); 14620+ MS_CHECK_RET_CODE(InitTmpBuffer(context), "Init tmp buffer failed."); 14621+ return RET_OK; 14622+} 14623+ 14624+int Convolution1x1DynamicFP16Coder::DoCode(CoderContext *const context) { 14625+ CollectFilesForFunc(context); 14626+ NNaclFp32Serializer code; 14627+ MS_CHECK_RET_CODE(ComputeWorkspace(), "ComputeWorkspace failed."); 14628+ auto tmp_input_str = "(float16_t *)(" + allocator_->GetRuntimeAddr(static_cast<float16 *>(tmp_input_)) + ")"; 14629+ auto input_str = 14630+ "(float16_t *)(" + GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 14631+ auto output_str = 
14632+ "(float16_t *)(" + GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 14633+ auto packed_weight_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(packed_weight_)); 14634+ 14635+ code << " for (int batch_index = 0; batch_index < " << dynamic_param_.input_batch_ << "; batch_index++) {\n"; 14636+ output_ptr_ = output_str + " + batch_index * " + std::to_string(matmul_param_->row_ * matmul_param_->col_); 14637+ auto batch_in = input_str + " + batch_index * " + 14638+ std::to_string(conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_); 14639+ if (pre_trans_input_) { 14640+ code.CodeStruct("conv_parameter", *conv_param_, dynamic_param_); 14641+ code.CodeFunction("Conv1x1InputPack", batch_in, tmp_input_str, "&conv_parameter", DataTypeSize(data_type_)); 14642+ } else { 14643+ tmp_input_str = batch_in; 14644+ } 14645+ 14646+ if (output_tensor_->format() == NC4HW4) { 14647+ code.CodeFunction(target_ == kARM64 ? "RowMajor2Col16MajorFp16Opt" : "RowMajor2Col12MajorFp16Opt", tmp_input_str, 14648+ "(float16_t *)(" + pack_input_str_ + ")", matmul_param_->row_, matmul_param_->deep_); 14649+ } else { 14650+ code.CodeFunction("RowMajor2Col12MajorFp16Opt", tmp_input_str, "(float16_t *)(" + pack_input_str_ + ")", 14651+ matmul_param_->row_, matmul_param_->deep_); 14652+ } 14653+ 14654+ if (output_tensor_->format() == NC4HW4) { 14655+ code.CodeStruct("matmul_param", *matmul_param_); 14656+ code.CodeFunction("Conv1x1OutNc8hw8MultiThreadByWeightFp16", tmp_input_str, 14657+ "(float16_t *)(" + pack_input_str_ + ")", packed_weight_str, bias_data_, output_ptr_, 14658+ kDefaultTaskId, "&matmul_param"); 14659+ } else { 14660+ code.CodeFunction(target_ == kARM64 ? 
"MatMul12x16Fp16Opt" : "MatMul12x8A32Fp16", 14661+ "(float16_t *)(" + pack_input_str_ + ")", packed_weight_str, output_ptr_, bias_data_, 14662+ matmul_param_->act_type_, matmul_param_->deep_, matmul_param_->row_, matmul_param_->col_, 14663+ matmul_param_->col_, OutType_Nhwc); 14664+ } 14665+ code << " }\n"; 14666+ context->AppendCode(code.str()); 14667+ return RET_OK; 14668+} 14669+ 14670+Convolution1x1DynamicFP16Coder::~Convolution1x1DynamicFP16Coder() { 14671+ FreeTmpBuffer(); 14672+ if (matmul_param_ != nullptr) { 14673+ delete matmul_param_; 14674+ matmul_param_ = nullptr; 14675+ } 14676+ return; 14677+} 14678+ 14679+void Convolution1x1DynamicFP16Coder::FreeTmpBuffer() { 14680+ if (pre_trans_input_ && tmp_input_ != nullptr) { 14681+ free(tmp_input_); 14682+ tmp_input_ = nullptr; 14683+ } 14684+ return; 14685+} 14686+ 14687+int Convolution1x1DynamicFP16Coder::ComputeWorkspace() { 14688+ pack_input_size_ = matmul_param_->row_align_ * matmul_param_->deep_ * DataTypeSize(data_type_); 14689+ auto input_shape = shape_info_container_->GetTemplateShape(input_tensor_); 14690+ size_t scene_num = 0; 14691+ for (auto &dim_template : input_shape) { 14692+ auto dim_nums = shape_info_container_->GetRealNums(dim_template); 14693+ MS_CHECK_TRUE_MSG(!dim_nums.empty(), RET_ERROR, "Dynamic shape's num must be greater than 0."); 14694+ scene_num = std::max(scene_num, dim_nums.size()); 14695+ } 14696+ for (size_t i = 0; i < scene_num; ++i) { 14697+ pack_input_str_ = dynamic_mem_manager_->AllocWorkSpace(pack_input_size_, i); 14698+ MS_CHECK_TRUE_MSG(!pack_input_str_.empty(), RET_ERROR, "Convolution cannot alloc workspace."); 14699+ } 14700+ return RET_OK; 14701+} 14702+ 14703+int Convolution1x1DynamicFP16Coder::InitMatmulParam() { 14704+ matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; 14705+ matmul_param_->col_ = conv_param_->output_channel_; 14706+ matmul_param_->deep_ = conv_param_->input_channel_; 14707+ matmul_param_->row_align_ = 
UP_ROUND(matmul_param_->row_, row_tile_); 14708+ matmul_param_->col_align_ = UP_ROUND(matmul_param_->col_, col_tile_); 14709+ matmul_param_->act_type_ = conv_param_->act_type_; 14710+ return RET_OK; 14711+} 14712+ 14713+int Convolution1x1DynamicFP16Coder::InitWeightBias(CoderContext *const context) { 14714+ auto input_channel = filter_tensor_->Channel(); 14715+ auto output_channel = filter_tensor_->Batch(); 14716+ MS_CHECK_TRUE_RET(input_channel > 0 && output_channel > 0, RET_ERROR); 14717+ pack_weight_size_ = input_channel * UP_ROUND(output_channel, col_tile_) * DataTypeSize(data_type_); 14718+ packed_weight_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 14719+ MS_CHECK_PTR(packed_weight_); 14720+ 14721+ NNaclFp32Serializer init_code; 14722+ std::string ori_weight_addr = allocator_->GetRuntimeAddr(filter_tensor_); 14723+ size_t w_buf_size = 0; 14724+ w_buf_size += pack_weight_size_; 14725+ auto packed_weight_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(packed_weight_)); 14726+ init_code.CodeBufferOffsetExpression(packed_weight_, context->weight_name(), context->weight_offset_name(), 14727+ context->weight_size_name(), pack_weight_size_); 14728+ if (target_ == kARM64 && output_tensor_->format() != NC4HW4) { 14729+ init_code.CodeFunction("RowMajor2Col16MajorFp16Opt", ori_weight_addr, packed_weight_str, output_channel, 14730+ input_channel); 14731+ } else { 14732+ init_code.CodeFunction("ColMajor2Row8MajorFp16", ori_weight_addr, packed_weight_str, input_channel, output_channel, 14733+ true); 14734+ } 14735+ bias_data_size_ = UP_ROUND(output_channel, col_tile_) * DataTypeSize(data_type_); 14736+ if (input_tensors_.size() == kInputSize2) { 14737+ bias_data_ = 14738+ allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack"); 14739+ MS_CHECK_PTR(bias_data_); 14740+ init_code.CodeBufferOffsetExpression(bias_data_, context->weight_name(), context->weight_offset_name(), 
14741+ context->weight_size_name(), bias_data_size_); 14742+ w_buf_size += bias_data_size_; 14743+ auto bias_data_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(bias_data_)); 14744+ std::string bias_tensor_str = allocator_->GetRuntimeAddr(bias_tensor_); 14745+ init_code.CodeFunction("memcpy", bias_data_str, bias_tensor_str, bias_tensor_->Size()); 14746+ } else { 14747+ bias_data_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack"); 14748+ MS_CHECK_PTR(bias_data_); 14749+ init_code.CodeFunction("memset", bias_data_, 0, bias_data_size_); 14750+ } 14751+ context->AppendInitWeightSizeCode(w_buf_size); 14752+ context->AppendInitCode(init_code.str()); 14753+ return RET_OK; 14754+} 14755+ 14756+int Convolution1x1DynamicFP16Coder::InitTmpBuffer(CoderContext *const context) { 14757+ NNaclFp32Serializer init_code; 14758+ pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 || 14759+ conv_param_->stride_w_ != 1); 14760+ size_t w_size = 0; 14761+ if (pre_trans_input_) { 14762+ tmp_input_size_ = matmul_param_->row_ * matmul_param_->deep_ * DataTypeSize(data_type_); 14763+ tmp_input_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 14764+ MS_CHECK_PTR(tmp_input_); 14765+ w_size += tmp_input_size_; 14766+ auto tmp_input_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(tmp_input_)); 14767+ init_code.CodeBufferOffsetExpression(tmp_input_, context->weight_name(), context->weight_offset_name(), 14768+ context->weight_size_name(), tmp_input_size_); 14769+ init_code.CodeFunction("memset", tmp_input_, 0, tmp_input_size_); 14770+ } 14771+ context->AppendInitWeightSizeCode(w_size); 14772+ context->AppendInitCode(init_code.str()); 14773+ return RET_OK; 14774+} 14775+ 14776+void Convolution1x1DynamicFP16Coder::CollectFilesForFunc(CoderContext *const context) { 14777+ if (target_ == kARM64) { 14778+ Collect(context, {}, {}, 14779+ { 14780+ 
"MatmulFp16.S", 14781+ "MatmulFp16Opt.S", 14782+ "Matmul12X16Fp16.S", 14783+ }); 14784+ } else { 14785+ Collect(context, {}, {}, 14786+ { 14787+ "Matmul12x8Fp16.S", 14788+ }); 14789+ } 14790+ Collect(context, 14791+ { 14792+ "nnacl/fp16/matmul_fp16.h", 14793+ "nnacl/conv_parameter.h", 14794+ "nnacl/op_base.h", 14795+ "nnacl/fp16/conv_fp16.h", 14796+ "nnacl/base/conv1x1_base.h", 14797+ }, 14798+ { 14799+ "common_func.c", 14800+ "matmul_fp16.c", 14801+ "conv_fp16.c", 14802+ "conv1x1_base.c", 14803+ }); 14804+} 14805+} // namespace mindspore::lite::micro::nnacl 14806diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h 14807new file mode 100644 14808index 00000000..558eea53 14809--- /dev/null 14810+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h 14811@@ -0,0 +1,68 @@ 14812+/** 14813+ * Copyright 2023 Huawei Technologies Co., Ltd 14814+ * 14815+ * Licensed under the Apache License, Version 2.0 (the "License"); 14816+ * you may not use this file except in compliance with the License. 14817+ * You may obtain a copy of the License at 14818+ * 14819+ * http://www.apache.org/licenses/LICENSE-2.0 14820+ * 14821+ * Unless required by applicable law or agreed to in writing, software 14822+ * distributed under the License is distributed on an "AS IS" BASIS, 14823+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14824+ * See the License for the specific language governing permissions and 14825+ * limitations under the License. 
14826+ */ 14827+ 14828+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_1X1_DYNAMIC_FP16_CODER_H_ 14829+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_1X1_DYNAMIC_FP16_CODER_H_ 14830+ 14831+#include <vector> 14832+#include <string> 14833+#include "nnacl/conv_parameter.h" 14834+#include "nnacl/matmul_parameter.h" 14835+#include "coder/opcoders/op_coder.h" 14836+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 14837+#include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h" 14838+#include "base/float16.h" 14839+ 14840+namespace mindspore::lite::micro::nnacl { 14841+class Convolution1x1DynamicFP16Coder final : public OperatorCoder { 14842+ public: 14843+ Convolution1x1DynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 14844+ const LiteGraph::Node *node, size_t node_index, Target target) 14845+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 14846+ ~Convolution1x1DynamicFP16Coder() override; 14847+ 14848+ int Prepare(CoderContext *const context) override; 14849+ 14850+ int DoCode(CoderContext *const context) override; 14851+ 14852+ private: 14853+ void CollectFilesForFunc(CoderContext *const context); 14854+ int InitWeightBias(CoderContext *const context); 14855+ int InitMatmulParam(); 14856+ int InitTmpBuffer(CoderContext *const context); 14857+ void FreeTmpBuffer(); 14858+ int ComputeWorkspace(); 14859+ MatMulParameter *matmul_param_{nullptr}; 14860+ ConvParameter *conv_param_{nullptr}; 14861+ ConvDynamicParameter dynamic_param_; 14862+ Tensor *filter_tensor_{nullptr}; 14863+ Tensor *bias_tensor_{nullptr}; 14864+ int row_tile_{C12NUM}; 14865+ int col_tile_{C8NUM}; 14866+ void *packed_weight_{nullptr}; 14867+ void *bias_data_{nullptr}; 14868+ std::string pack_input_str_; 14869+ void *tmp_input_{nullptr}; 14870+ size_t pack_weight_size_{0}; 14871+ size_t bias_data_size_{0}; 14872+ 
size_t tmp_input_size_{0}; 14873+ size_t pack_input_size_{0}; 14874+ bool pre_trans_input_{false}; 14875+ std::string output_ptr_; 14876+ TypeId data_type_ = kNumberTypeFloat16; 14877+}; 14878+} // namespace mindspore::lite::micro::nnacl 14879+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_1X1_DYNAMIC_FP16_CODER_H_ 14880diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc 14881new file mode 100644 14882index 00000000..c917b89a 14883--- /dev/null 14884+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc 14885@@ -0,0 +1,172 @@ 14886+/** 14887+ * Copyright 2023 Huawei Technologies Co., Ltd 14888+ * 14889+ * Licensed under the Apache License, Version 2.0 (the "License"); 14890+ * you may not use this file except in compliance with the License. 14891+ * You may obtain a copy of the License at 14892+ * 14893+ * http://www.apache.org/licenses/LICENSE-2.0 14894+ * 14895+ * Unless required by applicable law or agreed to in writing, software 14896+ * distributed under the License is distributed on an "AS IS" BASIS, 14897+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14898+ * See the License for the specific language governing permissions and 14899+ * limitations under the License. 
14900+ */ 14901+ 14902+#include "coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h" 14903+#include <string> 14904+#include <vector> 14905+#include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h" 14906+#include "nnacl/fp32/winograd_utils.h" 14907+#include "coder/opcoders/file_collector.h" 14908+#include "coder/log.h" 14909+#include "coder/opcoders/parallel.h" 14910+#include "coder/utils/coder_utils.h" 14911+#include "base/float16.h" 14912+ 14913+using mindspore::schema::PrimitiveType_Conv2DFusion; 14914+namespace mindspore::lite::micro::nnacl { 14915+int ConvolutionDynamicFP16Coder::Prepare(CoderContext *const context) { 14916+ CHECK_LESS_RETURN(input_tensors_.size(), C2NUM); 14917+ CHECK_LESS_RETURN(output_tensors_.size(), 1); 14918+ if (target_ == kARM64) { 14919+ row_tile_ = C16NUM; 14920+ } 14921+ conv_param_ = reinterpret_cast<ConvParameter *>(parameter_); 14922+ MS_CHECK_PTR(conv_param_); 14923+ dynamic_param_.input_batch_ = shape_info_container_->GetTemplateShape(input_tensor_)[0]; 14924+ conv_param_->input_h_ = input_tensor_->Height(); 14925+ conv_param_->input_w_ = input_tensor_->Width(); 14926+ conv_param_->input_channel_ = input_tensor_->Channel(); 14927+ dynamic_param_.output_batch_ = shape_info_container_->GetTemplateShape(output_tensor_)[0]; 14928+ conv_param_->output_h_ = output_tensor_->Height(); 14929+ conv_param_->output_w_ = output_tensor_->Width(); 14930+ conv_param_->output_channel_ = output_tensor_->Channel(); 14931+ conv_param_->thread_num_ = 1; 14932+ MS_CHECK_RET_CODE(InitWeightBias(context), "Init weight bias failed."); 14933+ MS_CHECK_RET_CODE(InitTmpBuffer(), "Init tmp buffer failed."); 14934+ return RET_OK; 14935+} 14936+ 14937+int ConvolutionDynamicFP16Coder::InitTmpBuffer() { 14938+ int uint_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * 14939+ conv_param_->thread_num_; 14940+ packed_input_size_ = uint_size * DataTypeSize(data_type_); 14941+ auto input_shape = 
shape_info_container_->GetTemplateShape(input_tensor_); 14942+ size_t scene_num = 0; 14943+ for (auto &dim_template : input_shape) { 14944+ auto dim_nums = shape_info_container_->GetRealNums(dim_template); 14945+ MS_CHECK_TRUE_MSG(!dim_nums.empty(), RET_ERROR, "Dynamic shape's num must be greater than 0."); 14946+ scene_num = std::max(scene_num, dim_nums.size()); 14947+ } 14948+ for (size_t i = 0; i < scene_num; ++i) { 14949+ packed_input_str_ = dynamic_mem_manager_->AllocWorkSpace(packed_input_size_ * 2, i); 14950+ MS_CHECK_TRUE_MSG(!packed_input_str_.empty(), RET_ERROR, "Convolution cannot alloc workspace."); 14951+ } 14952+ col_major_input_str_ = packed_input_str_ + " + " + std::to_string(packed_input_size_); 14953+ return RET_OK; 14954+} 14955+ 14956+int ConvolutionDynamicFP16Coder::InitWeightBias(CoderContext *const context) { 14957+ filter_tensor_ = input_tensors_.at(kWeightIndex); 14958+ CHECK_NULL_RETURN(filter_tensor_); 14959+ auto shape = filter_tensor_->shape(); 14960+ if (std::find(shape.begin(), shape.end(), -1) != shape.end()) { 14961+ MS_LOG(WARNING) << "The shape of weight tensor is not ready, the weight and bias would be inited in runtime."; 14962+ return RET_OK; 14963+ } 14964+ int in_channel = filter_tensor_->Channel(); 14965+ int out_channel = filter_tensor_->Batch(); 14966+ MS_CHECK_TRUE_RET(in_channel > 0 && out_channel > 0, RET_ERROR); 14967+ conv_param_->input_channel_ = in_channel; 14968+ conv_param_->output_channel_ = out_channel; 14969+ int oc8 = UP_ROUND(out_channel, col_tile_); 14970+ int kernel_plane = filter_tensor_->Height() * filter_tensor_->Width(); 14971+ pack_weight_size_ = oc8 * in_channel * kernel_plane * DataTypeSize(data_type_); 14972+ // init weight 14973+ packed_weight_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 14974+ MS_CHECK_PTR(packed_weight_); 14975+ NNaclFp32Serializer init_code; 14976+ std::string ori_weight_addr = allocator_->GetRuntimeAddr(filter_tensor_); 14977+ size_t w_buf_size = 0; 14978+ 
w_buf_size += pack_weight_size_; 14979+ auto packed_weight_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(packed_weight_)); 14980+ init_code.CodeBufferOffsetExpression(packed_weight_, context->weight_name(), context->weight_offset_name(), 14981+ context->weight_size_name(), pack_weight_size_); 14982+ init_code.CodeFunction("RowMajor2Col8MajorFp16", ori_weight_addr, packed_weight_str, out_channel, 14983+ in_channel * kernel_plane, false); 14984+ if (input_tensors_.size() == C3NUM) { 14985+ bias_tensor_ = input_tensors_.at(kBiasIndex); 14986+ MS_CHECK_PTR(bias_tensor_); 14987+ bias_data_ = 14988+ allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack"); 14989+ MS_CHECK_PTR(bias_data_); 14990+ } else { 14991+ bias_data_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack"); 14992+ MS_CHECK_PTR(bias_data_); 14993+ } 14994+ auto bias_data_size = static_cast<size_t>(oc8 * DataTypeSize(data_type_)); 14995+ w_buf_size += bias_data_size; 14996+ init_code.CodeBufferOffsetExpression(bias_data_, context->weight_name(), context->weight_offset_name(), 14997+ context->weight_size_name(), bias_data_size); 14998+ bias_data_str_ = allocator_->GetRuntimeAddr(bias_data_); 14999+ if (input_tensors_.size() == C3NUM) { 15000+ auto origin_bias_str = allocator_->GetRuntimeAddr(bias_tensor_); 15001+ init_code.CodeFunction("memcpy", bias_data_str_, origin_bias_str, bias_tensor_->Size()); 15002+ } else { 15003+ init_code.CodeFunction("memset", bias_data_str_, 0, bias_data_size); 15004+ } 15005+ context->AppendInitWeightSizeCode(w_buf_size); 15006+ context->AppendInitCode(init_code.str()); 15007+ return RET_OK; 15008+} 15009+ 15010+void ConvolutionDynamicFP16Coder::CollectFilesForFunc(CoderContext *const context) { 15011+ Collect(context, {}, {}, 15012+ { 15013+ "MatmulFp16.S", 15014+ "MatmulFp16Opt.S", 15015+ "MatVecMulFp16.S", 15016+ "Matmul12X16Fp16.S", 15017+ }); 15018+ 
Collect(context, 15019+ { 15020+ "nnacl/fp16/matmul_fp16.h", 15021+ "nnacl/conv_parameter.h", 15022+ "nnacl/op_base.h", 15023+ "nnacl/fp16/conv_fp16.h", 15024+ }, 15025+ { 15026+ "common_func.c", 15027+ "matmul_fp16.c", 15028+ "pack_fp16.c", 15029+ "conv_fp16.c", 15030+ }); 15031+} 15032+ 15033+int ConvolutionDynamicFP16Coder::DoCode(CoderContext *const context) { 15034+ CollectFilesForFunc(context); 15035+ NNaclFp32Serializer code; 15036+ // call the op function 15037+ auto packed_weight_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(packed_weight_)); 15038+ auto input_str = 15039+ "(float16_t *)(" + GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 15040+ auto output_str = 15041+ "(float16_t *)(" + GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 15042+ // code.CodeFunction("memset", packed_input_str_, "0", packed_input_size_); 15043+ // code.CodeFunction("memset", col_major_input_str_, "0", packed_input_size_); 15044+ code.CodeStruct("conv_parameter", *conv_param_, dynamic_param_); 15045+ packed_input_str_ = "(float16_t *)(" + packed_input_str_ + ")"; 15046+ col_major_input_str_ = "(float16_t *)(" + col_major_input_str_ + ")"; 15047+ if (output_tensor_->format() == NC4HW4) { 15048+ code.CodeFunction("ConvOutNc8hw8Fp16", input_str, packed_input_str_, packed_weight_str, bias_data_str_, 15049+ col_major_input_str_, output_str, kDefaultTaskId, "&conv_parameter"); 15050+ } else { 15051+ code.CodeFunction("ConvFp16", input_str, packed_input_str_, packed_weight_str, bias_data_str_, col_major_input_str_, 15052+ output_str, kDefaultTaskId, "&conv_parameter"); 15053+ } 15054+ context->AppendCode(code.str()); 15055+ return RET_OK; 15056+} 15057+} // namespace mindspore::lite::micro::nnacl 15058diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h 15059new file mode 100644 15060index 00000000..29d70796 15061--- /dev/null 15062+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h 15063@@ -0,0 +1,59 @@ 15064+/** 15065+ * Copyright 2023 Huawei Technologies Co., Ltd 15066+ * 15067+ * Licensed under the Apache License, Version 2.0 (the "License"); 15068+ * you may not use this file except in compliance with the License. 15069+ * You may obtain a copy of the License at 15070+ * 15071+ * http://www.apache.org/licenses/LICENSE-2.0 15072+ * 15073+ * Unless required by applicable law or agreed to in writing, software 15074+ * distributed under the License is distributed on an "AS IS" BASIS, 15075+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15076+ * See the License for the specific language governing permissions and 15077+ * limitations under the License. 15078+ */ 15079+ 15080+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_DYNAMIC_FP16_CODER_H_ 15081+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_DYNAMIC_FP16_CODER_H_ 15082+ 15083+#include <vector> 15084+#include <string> 15085+#include "nnacl/conv_parameter.h" 15086+#include "coder/opcoders/op_coder.h" 15087+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 15088+#include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h" 15089+ 15090+namespace mindspore::lite::micro::nnacl { 15091+class ConvolutionDynamicFP16Coder final : public OperatorCoder { 15092+ public: 15093+ ConvolutionDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 15094+ const LiteGraph::Node *node, size_t node_index, Target target) 15095+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 15096+ 15097+ ~ConvolutionDynamicFP16Coder() override = 
default; 15098+ 15099+ int Prepare(CoderContext *const context) override; 15100+ int DoCode(CoderContext *const context) override; 15101+ 15102+ private: 15103+ void CollectFilesForFunc(CoderContext *const context); 15104+ int InitWeightBias(CoderContext *const context); 15105+ int InitTmpBuffer(); 15106+ ConvParameter *conv_param_{nullptr}; 15107+ ConvDynamicParameter dynamic_param_; 15108+ TypeId data_type_{kNumberTypeFloat16}; 15109+ int row_tile_{C12NUM}; 15110+ int col_tile_{C8NUM}; 15111+ Tensor *filter_tensor_{nullptr}; 15112+ Tensor *bias_tensor_{nullptr}; 15113+ size_t pack_weight_size_{0}; 15114+ size_t packed_input_size_{0}; 15115+ void *packed_weight_{nullptr}; 15116+ void *bias_data_{nullptr}; 15117+ std::string packed_input_str_; 15118+ std::string col_major_input_str_; 15119+ std::string bias_data_str_; 15120+}; 15121+} // namespace mindspore::lite::micro::nnacl 15122+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_DYNAMIC_FP16_CODER_H_ 15123diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc 15124new file mode 100644 15125index 00000000..8c4cc31b 15126--- /dev/null 15127+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc 15128@@ -0,0 +1,366 @@ 15129+/** 15130+ * Copyright 2023 Huawei Technologies Co., Ltd 15131+ * 15132+ * Licensed under the Apache License, Version 2.0 (the "License"); 15133+ * you may not use this file except in compliance with the License. 15134+ * You may obtain a copy of the License at 15135+ * 15136+ * http://www.apache.org/licenses/LICENSE-2.0 15137+ * 15138+ * Unless required by applicable law or agreed to in writing, software 15139+ * distributed under the License is distributed on an "AS IS" BASIS, 15140+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15141+ * See the License for the specific language governing permissions and 15142+ * limitations under the License. 15143+ */ 15144+ 15145+#include "coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h" 15146+#include <cfloat> 15147+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 15148+#include "coder/opcoders/file_collector.h" 15149+#include "coder/utils/coder_utils.h" 15150+#include "tools/common/string_util.h" 15151+ 15152+using mindspore::schema::PrimitiveType_LSTM; 15153+ 15154+namespace mindspore::lite::micro::nnacl { 15155+namespace { 15156+constexpr size_t kMindirInputTensorNum = 4; 15157+} // namespace 15158+ 15159+int LstmMindirDynamicFP16Coder::Prepare(CoderContext *const context) { 15160+ CHECK_NULL_RETURN(context); 15161+ CHECK_NOT_EQUAL_RETURN(input_tensors_.size(), kMindirInputTensorNum); 15162+ for (auto in : input_tensors_) { 15163+ MS_CHECK_TRUE_MSG(in != nullptr, RET_INPUT_TENSOR_ERROR, "LstmMindirDynamicFP16Coder input is a nullptr."); 15164+ MS_CHECK_TRUE_MSG(in->data_type() == kNumberTypeFloat16, RET_INPUT_TENSOR_ERROR, 15165+ "LstmMindirDynamicFP16Coder input must be fp16."); 15166+ MS_CHECK_TRUE_MSG(in->shape().size() == C3NUM, RET_INPUT_TENSOR_ERROR, 15167+ "LstmMindirDynamicFP16Coder input must be 3D."); 15168+ } 15169+ MS_CHECK_TRUE_MSG(input_tensors_[FOURTH_INPUT]->IsConst(), RET_INPUT_TENSOR_ERROR, 15170+ "LstmMindirDynamicFP16Coder last three inputs must be all constant."); 15171+ lstm_param_ = reinterpret_cast<LstmParameter *>(parameter_); 15172+ return InitParam(); 15173+} 15174+ 15175+int LstmMindirDynamicFP16Coder::DoCode(CoderContext *const context) { 15176+ Collect(context, 15177+ { 15178+ "nnacl/lstm_parameter.h", 15179+ "nnacl/fp16/lstm_fp16.h", 15180+ }, 15181+ {"lstm_fp16.c", "activation_fp16.c", "arithmetic_fp16.c", "matmul_fp16.c", "pack_fp16.c"}, 15182+ {"MatmulBaseFp16Neon.S"}); 15183+ 15184+ auto ret = InitInputWeightBias(context); 15185+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Lstm 
InitInputWeightBias failed."); 15186+ ret = InitStateWeightBias(context); 15187+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Lstm InitStateWeightBias failed."); 15188+ ret = InitProjectWeight(context); 15189+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Lstm InitProjectWeight failed."); 15190+ ret = ComputeWorkSpace(); 15191+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Lstm ComputeWorkSpace failed."); 15192+ CreateBufferAddrStr(); 15193+ NNaclFp32Serializer code; 15194+ code << "float16_t *buffer[7] = {"; 15195+ for (const auto &buf : buffers_str_) { 15196+ code << "(float16_t *)(" << buf << "), "; 15197+ } 15198+ code << "};\n"; 15199+ 15200+ auto input1 = dynamic_mem_manager_->GetVarTensorAddr(input_tensors_[FIRST_INPUT]); 15201+ auto hidden_init = input_tensors_[SECOND_INPUT]->IsConst() 15202+ ? allocator_->GetRuntimeAddr(input_tensors_[SECOND_INPUT], true) 15203+ : dynamic_mem_manager_->GetVarTensorAddr(input_tensors_[SECOND_INPUT]); 15204+ auto cell_init = input_tensors_[THIRD_INPUT]->IsConst() 15205+ ? 
allocator_->GetRuntimeAddr(input_tensors_[THIRD_INPUT], true) 15206+ : dynamic_mem_manager_->GetVarTensorAddr(input_tensors_[THIRD_INPUT]); 15207+ auto output1 = dynamic_mem_manager_->GetVarTensorAddr(output_tensors_[FIRST_INPUT]); 15208+ auto hidden_output = dynamic_mem_manager_->GetVarTensorAddr(output_tensors_[SECOND_INPUT]); 15209+ auto cell_output = dynamic_mem_manager_->GetVarTensorAddr(output_tensors_[THIRD_INPUT]); 15210+ MS_CHECK_TRUE_MSG(!input1.empty() && !hidden_init.empty() && !cell_init.empty() && !output1.empty() && 15211+ !hidden_output.empty() && !cell_output.empty(), 15212+ RET_ERROR, "Lstm cannot get addr."); 15213+ code.CodeStruct("lstm_param", *lstm_param_, dynamic_lstm_param_); 15214+ auto input_shape2 = shape_info_container_->GetTemplateShape(input_tensors_[SECOND_INPUT]); 15215+ int64_t const_part = 1; 15216+ std::string non_const_part; 15217+ for (const auto &item : input_shape2) { 15218+ if (IsNumber(item)) { 15219+ const_part *= std::stoi(item); 15220+ } else { 15221+ if (!non_const_part.empty()) { 15222+ non_const_part += " * "; 15223+ } 15224+ non_const_part += item; 15225+ } 15226+ } 15227+ code.CodeFunction("memcpy", hidden_output, hidden_init, 15228+ non_const_part + " * " + std::to_string(const_part * DataTypeSize(kNumberTypeFloat16))); 15229+ auto input_shape3 = shape_info_container_->GetTemplateShape(input_tensors_[THIRD_INPUT]); 15230+ const_part = 1; 15231+ non_const_part = ""; 15232+ for (const auto &item : input_shape3) { 15233+ if (IsNumber(item)) { 15234+ const_part *= std::stoi(item); 15235+ } else { 15236+ if (!non_const_part.empty()) { 15237+ non_const_part += " * "; 15238+ } 15239+ non_const_part += item; 15240+ } 15241+ } 15242+ code.CodeFunction("memcpy", cell_output, cell_init, 15243+ non_const_part + " * " + std::to_string(const_part * DataTypeSize(kNumberTypeFloat16))); 15244+ auto weight_i_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(weight_i_ptr_)); 15245+ auto weight_h_str = 
MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(weight_h_ptr_)); 15246+ auto weight_pro_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(weight_project_ptr_)); 15247+ auto input_bias_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(input_bias_)); 15248+ auto state_bias_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(hh_bias_)); 15249+ auto pro_bias_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(project_bias_)); 15250+ 15251+ code.CodeFunction("LstmFp16", "(float16_t *)(" + output1 + ")", "(float16_t *)(" + input1 + ")", weight_i_str, 15252+ weight_h_str, input_bias_str, state_bias_str, weight_pro_str, pro_bias_str, 15253+ "(float16_t *)(" + hidden_output + ")", "(float16_t *)(" + cell_output + ")", "buffer", 15254+ "&lstm_param"); 15255+ context->AppendCode(code.str()); 15256+ return RET_OK; 15257+} 15258+ 15259+int LstmMindirDynamicFP16Coder::InitParam() { 15260+ auto in_shape1 = shape_info_container_->GetTemplateShape(input_tensors_[FIRST_INPUT]); 15261+ MS_CHECK_TRUE_MSG(in_shape1.size() == C3NUM, RET_INPUT_TENSOR_ERROR, "LstmMindir first input's dim must be 3D."); 15262+ dynamic_lstm_param_.batch_ = in_shape1[1]; 15263+ dynamic_lstm_param_.seq_len_ = in_shape1[0]; 15264+ MS_CHECK_TRUE_MSG(IsNumber(in_shape1[C2NUM]), RET_NOT_SUPPORT, 15265+ "LstmMindir doesn't support input_size is dynamical in micro."); 15266+ lstm_param_->input_size_ = std::atoi(in_shape1[C2NUM].c_str()); 15267+ 15268+ auto h_init_shape = input_tensors_[SECOND_INPUT]->shape(); 15269+ auto c_init_shape = input_tensors_[THIRD_INPUT]->shape(); 15270+ lstm_param_->hidden_size_ = c_init_shape.back(); 15271+ lstm_param_->output_size_ = h_init_shape.back(); 15272+ 15273+ lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 
C2NUM * lstm_param_->batch_ * lstm_param_->output_size_ 15274+ : lstm_param_->batch_ * lstm_param_->output_size_; 15275+ weight_segment_num_ = lstm_param_->bidirectional_ ? C8NUM : C4NUM; 15276+ dynamic_lstm_param_.input_row_align_ = 15277+ "(" + dynamic_lstm_param_.batch_ + " * " + dynamic_lstm_param_.seq_len_ + " + 3) / 4 * 4"; 15278+ lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C4NUM); 15279+ 15280+ dynamic_lstm_param_.state_row_align_ = "(" + dynamic_lstm_param_.batch_ + " + 3) / 4 * 4"; 15281+ lstm_param_->state_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C4NUM); 15282+ lstm_param_->proj_col_align_ = UP_ROUND(lstm_param_->project_size_, C4NUM); 15283+ dynamic_lstm_param_.output_step_ = 15284+ std::to_string((lstm_param_->bidirectional_ ? C2NUM : C1NUM) * lstm_param_->output_size_) + " * " + 15285+ dynamic_lstm_param_.batch_; 15286+ size_t scale = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 15287+ hi_size_ = scale * C4NUM * lstm_param_->hidden_size_ * lstm_param_->input_size_; 15288+ hh_size_ = scale * C4NUM * lstm_param_->hidden_size_ * lstm_param_->output_size_; 15289+ hp_size_ = scale * lstm_param_->project_size_ * lstm_param_->hidden_size_; 15290+ bias_size_ = scale * C8NUM * lstm_param_->hidden_size_; 15291+ auto real_whole_size = input_tensors_[FOURTH_INPUT]->ElementsNum(); 15292+ gpu_state_ = (hi_size_ + hh_size_ + hp_size_ + bias_size_) == static_cast<size_t>(real_whole_size); 15293+ if (gpu_state_) { 15294+ MS_LOG(ERROR) << "LstmMindirDynamicFP16Coder doesn't suuport model which exported from GPU."; 15295+ return RET_NOT_SUPPORT; 15296+ } 15297+ if (hi_size_ + hh_size_ + hp_size_ == static_cast<size_t>(real_whole_size)) { 15298+ bias_size_ = 0; 15299+ return RET_OK; 15300+ } 15301+ bias_size_ /= C2NUM; 15302+ if ((hi_size_ + hh_size_ + hp_size_ + bias_size_) != static_cast<size_t>(real_whole_size)) { 15303+ MS_LOG(ERROR) << "Bias of LstmMindir exported from cpu only exist in hi-part."; 15304+ return RET_INPUT_TENSOR_ERROR; 
15305+ } 15306+ return RET_OK; 15307+} 15308+ 15309+int LstmMindirDynamicFP16Coder::InitInputWeightBias(CoderContext *const context) { 15310+ NNaclFp32Serializer init_code; 15311+ 15312+ size_t weight_hi_size = 15313+ weight_segment_num_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * DataTypeSize(data_type_); 15314+ weight_i_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 15315+ MS_CHECK_PTR(weight_i_ptr_); 15316+ 15317+ size_t w_buf_size = 0; 15318+ 15319+ init_code.CodeBufferOffsetExpression(weight_i_ptr_, context->weight_name(), context->weight_offset_name(), 15320+ context->weight_size_name(), weight_hi_size); 15321+ w_buf_size += weight_hi_size; 15322+ auto weight_i_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[FOURTH_INPUT]); 15323+ MS_CHECK_TRUE_MSG(!weight_i_str.empty(), RET_INPUT_TENSOR_ERROR, "Lstm cannot get weight."); 15324+ auto packed_weight_i_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(weight_i_ptr_)); 15325+ init_code << " int32_t order[4] = {0, 2, 3, 1};\n"; 15326+ init_code.CodeFunction("PackLstmWeightFp16", packed_weight_i_str, weight_i_str, weight_segment_num_, 15327+ lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, "order"); 15328+ 15329+ auto bias_stride = hi_size_ + hh_size_ + hp_size_; 15330+ input_bias_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 15331+ MS_CHECK_PTR(input_bias_); 15332+ size_t bias_i_size = weight_segment_num_ * lstm_param_->input_col_align_ * DataTypeSize(data_type_); 15333+ w_buf_size += bias_i_size; 15334+ init_code.CodeBufferOffsetExpression(input_bias_, context->weight_name(), context->weight_offset_name(), 15335+ context->weight_size_name(), bias_i_size); 15336+ auto input_bias_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(input_bias_)); 15337+ init_code.CodeFunction("memset", input_bias_str, 0, bias_i_size); 15338+ if (bias_size_ != 
0) { 15339+ init_code.CodeFunction("PackLstmBiasFp16", input_bias_str, weight_i_str + " + " + std::to_string(bias_stride), 15340+ weight_segment_num_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, 15341+ lstm_param_->bidirectional_, "order"); 15342+ } 15343+ 15344+ context->AppendInitWeightSizeCode(w_buf_size); 15345+ context->AppendInitCode(init_code.str()); 15346+ return RET_OK; 15347+} 15348+ 15349+int LstmMindirDynamicFP16Coder::InitStateWeightBias(CoderContext *const context) { 15350+ NNaclFp32Serializer init_code; 15351+ 15352+ size_t weight_hh_size = 15353+ weight_segment_num_ * lstm_param_->state_col_align_ * lstm_param_->project_size_ * DataTypeSize(data_type_); 15354+ weight_h_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 15355+ MS_CHECK_PTR(weight_h_ptr_); 15356+ 15357+ size_t w_buf_size = 0; 15358+ 15359+ init_code.CodeBufferOffsetExpression(weight_h_ptr_, context->weight_name(), context->weight_offset_name(), 15360+ context->weight_size_name(), weight_hh_size); 15361+ w_buf_size += weight_hh_size; 15362+ auto weight_hh_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[FOURTH_INPUT]); 15363+ MS_CHECK_TRUE_MSG(!weight_hh_str.empty(), RET_INPUT_TENSOR_ERROR, "Lstm cannot get weight."); 15364+ auto packed_weight_hh_str = 15365+ MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(weight_h_ptr_)); 15366+ init_code << " int32_t order[4] = {0, 2, 3, 1};\n"; 15367+ init_code.CodeFunction("PackLstmWeightFp16", packed_weight_hh_str, weight_hh_str + " + " + std::to_string(hi_size_), 15368+ weight_segment_num_, lstm_param_->project_size_, lstm_param_->hidden_size_, 15369+ lstm_param_->state_col_align_, "order"); 15370+ 15371+ hh_bias_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 15372+ MS_CHECK_PTR(hh_bias_); 15373+ size_t bias_hh_size = weight_segment_num_ * lstm_param_->state_col_align_ * DataTypeSize(data_type_); 15374+ w_buf_size += bias_hh_size; 15375+ 
init_code.CodeBufferOffsetExpression(hh_bias_, context->weight_name(), context->weight_offset_name(), 15376+ context->weight_size_name(), bias_hh_size); 15377+ auto hh_bias_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(hh_bias_)); 15378+ init_code.CodeFunction("memset", hh_bias_str, 0, bias_hh_size); 15379+ 15380+ context->AppendInitWeightSizeCode(w_buf_size); 15381+ context->AppendInitCode(init_code.str()); 15382+ return RET_OK; 15383+} 15384+ 15385+int LstmMindirDynamicFP16Coder::InitProjectWeight(CoderContext *const context) { 15386+ if (hp_size_ == 0) { 15387+ return RET_OK; 15388+ } 15389+ 15390+ NNaclFp32Serializer init_code; 15391+ size_t w_buf_size = 0; 15392+ int scale = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 15393+ int col_align = UP_ROUND(lstm_param_->project_size_, C8NUM); 15394+ size_t weight_pro_size = scale * lstm_param_->hidden_size_ * col_align * DataTypeSize(data_type_); 15395+ weight_project_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 15396+ MS_CHECK_PTR(weight_project_ptr_); 15397+ init_code.CodeBufferOffsetExpression(weight_project_ptr_, context->weight_name(), context->weight_offset_name(), 15398+ context->weight_size_name(), weight_pro_size); 15399+ w_buf_size += weight_pro_size; 15400+ auto weight_hp_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[FOURTH_INPUT]); 15401+ MS_CHECK_TRUE_MSG(!weight_hp_str.empty(), RET_INPUT_TENSOR_ERROR, "Lstm cannot get weight."); 15402+ auto weight_pro_str = 15403+ MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(weight_project_ptr_)); 15404+ init_code.CodeFunction("PackLstmWeightFp16", weight_pro_str, 15405+ weight_hp_str + " + " + std::to_string(hi_size_ + hh_size_), scale, lstm_param_->hidden_size_, 15406+ lstm_param_->project_size_, col_align, "NULL"); 15407+ 15408+ size_t bias_pro_size = col_align * DataTypeSize(data_type_); 15409+ project_bias_ = allocator_->Malloc(data_type_, kOnlineSize, 
kOnlinePackWeight); 15410+ MS_CHECK_PTR(project_bias_); 15411+ init_code.CodeBufferOffsetExpression(project_bias_, context->weight_name(), context->weight_offset_name(), 15412+ context->weight_size_name(), bias_pro_size); 15413+ w_buf_size += bias_pro_size; 15414+ auto bias_pro_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(project_bias_)); 15415+ init_code.CodeFunction("memset", bias_pro_str, 0, bias_pro_size); 15416+ 15417+ context->AppendInitWeightSizeCode(w_buf_size); 15418+ context->AppendInitCode(init_code.str()); 15419+ return RET_OK; 15420+} 15421+ 15422+int LstmMindirDynamicFP16Coder::ComputeWorkSpace() { 15423+ auto in_shape1 = shape_info_container_->GetTemplateShape(input_tensors_[FIRST_INPUT]); 15424+ auto seq_lens = shape_info_container_->GetRealNums(in_shape1[0]); 15425+ MS_CHECK_TRUE_MSG(!seq_lens.empty(), RET_ERROR, "Lstm cannot get seq_len"); 15426+ auto batches = shape_info_container_->GetRealNums(in_shape1[1]); 15427+ MS_CHECK_TRUE_MSG(!batches.empty(), RET_ERROR, "Lstm cannot get batch"); 15428+ size_t scene_num = seq_lens.size() > batches.size() ? seq_lens.size() : batches.size(); 15429+ for (size_t i = 0; i < scene_num; ++i) { 15430+ int seq_len = seq_lens[i % seq_lens.size()]; 15431+ int batch = batches[i % batches.size()]; 15432+ size_t buffer1 = 15433+ seq_len * batch <= C3NUM ? 0 : seq_len * batch * lstm_param_->input_size_ * DataTypeSize(data_type_); 15434+ size_t buffer2 = C4NUM * seq_len * batch * lstm_param_->hidden_size_ * DataTypeSize(data_type_); 15435+ size_t buffer3 = batch <= C3NUM ? 0 : batch * lstm_param_->output_size_ * DataTypeSize(data_type_); 15436+ size_t buffer4 = C4NUM * batch * lstm_param_->hidden_size_ * DataTypeSize(data_type_); 15437+ size_t buffer5 = (lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON) 15438+ ? 
0 15439+ : batch * lstm_param_->hidden_size_ * DataTypeSize(data_type_); 15440+ size_t buffer6 = (lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON) 15441+ ? 0 15442+ : batch * lstm_param_->output_size_ * DataTypeSize(data_type_); 15443+ size_t buffer7 = (batch <= C3NUM || lstm_param_->project_size_ == 0) 15444+ ? 0 15445+ : batch * lstm_param_->hidden_size_ * DataTypeSize(data_type_); 15446+ auto whole_size = buffer1 + buffer2 + buffer3 + buffer4 + buffer5 + buffer6 + buffer7; 15447+ buffers_start_ = dynamic_mem_manager_->AllocWorkSpace(whole_size, i); 15448+ MS_CHECK_TRUE_MSG(!buffers_start_.empty(), RET_ERROR, "Lstm cannot alloc workspace."); 15449+ } 15450+ 15451+ return RET_OK; 15452+} 15453+ 15454+void LstmMindirDynamicFP16Coder::CreateBufferAddrStr() { 15455+ auto in_shape1 = shape_info_container_->GetTemplateShape(input_tensors_[FIRST_INPUT]); 15456+ auto seq_len = in_shape1[0]; 15457+ auto batch = in_shape1[1]; 15458+ auto input_row_align = "(" + seq_len + " * " + batch + " + 3) / 4 * 4"; 15459+ auto state_row_align = "(" + batch + " + 3) / 4 * 4"; 15460+ buffers_str_.push_back("(" + seq_len + " * " + batch + " <= 3) ? NULL : " + buffers_start_); 15461+ auto offset = "((" + seq_len + " * " + batch + " <= 3) ? 0 : (" + seq_len + " * " + batch + ") * " + 15462+ std::to_string(lstm_param_->input_size_ * DataTypeSize(data_type_)) + ")"; 15463+ buffers_str_.push_back(buffers_start_ + " + " + offset); 15464+ offset = "(" + offset + " + " + seq_len + " * " + batch + " * " + 15465+ std::to_string(C4NUM * lstm_param_->hidden_size_ * DataTypeSize(data_type_)) + ")"; 15466+ buffers_str_.push_back(batch + " <= 3 ? NULL : (" + buffers_start_ + " + " + offset + ")"); 15467+ offset = "(" + offset + " + (" + batch + " <= 3 ? 
0 : (" + batch + ") * " + 15468+ std::to_string(lstm_param_->output_size_ * DataTypeSize(data_type_)) + "))"; 15469+ buffers_str_.push_back(buffers_start_ + " + " + offset); 15470+ offset = "(" + offset + " + " + batch + " * " + 15471+ std::to_string(C4NUM * lstm_param_->hidden_size_ * DataTypeSize(data_type_)) + ")"; 15472+ if (lstm_param_->zoneout_cell_ < -FLT_EPSILON || lstm_param_->zoneout_cell_ > FLT_EPSILON) { 15473+ buffers_str_.push_back(buffers_start_ + " + " + offset); 15474+ offset = 15475+ "(" + offset + " + " + batch + " * " + std::to_string(lstm_param_->hidden_size_ * DataTypeSize(data_type_)) + ")"; 15476+ } else { 15477+ buffers_str_.emplace_back("NULL"); 15478+ } 15479+ if (lstm_param_->zoneout_hidden_ < -FLT_EPSILON && lstm_param_->zoneout_hidden_ > FLT_EPSILON) { 15480+ buffers_str_.push_back(buffers_start_ + " + " + offset); 15481+ offset = 15482+ "(" + offset + " + " + batch + " * " + std::to_string(lstm_param_->output_size_ * DataTypeSize(data_type_)) + ")"; 15483+ } else { 15484+ buffers_str_.emplace_back("NULL"); 15485+ } 15486+ if (lstm_param_->project_size_ == 0) { 15487+ buffers_str_.emplace_back("NULL"); 15488+ } else { 15489+ buffers_str_.emplace_back(batch + " <= 3 ? 
NULL : " + "(" + buffers_start_ + " + " + offset + ")"); 15490+ } 15491+} 15492+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_LSTM, 15493+ CPUOpCoderCreator<LstmMindirDynamicFP16Coder>) 15494+} // namespace mindspore::lite::micro::nnacl 15495diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h 15496new file mode 100644 15497index 00000000..1084fa82 15498--- /dev/null 15499+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h 15500@@ -0,0 +1,66 @@ 15501+/** 15502+ * Copyright 2023 Huawei Technologies Co., Ltd 15503+ * 15504+ * Licensed under the Apache License, Version 2.0 (the "License"); 15505+ * you may not use this file except in compliance with the License. 15506+ * You may obtain a copy of the License at 15507+ * 15508+ * http://www.apache.org/licenses/LICENSE-2.0 15509+ * 15510+ * Unless required by applicable law or agreed to in writing, software 15511+ * distributed under the License is distributed on an "AS IS" BASIS, 15512+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15513+ * See the License for the specific language governing permissions and 15514+ * limitations under the License. 
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_LSTM_DYNAMIC_FP16_CODER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_LSTM_DYNAMIC_FP16_CODER_H

#include <vector>
#include <string>
#include "nnacl/lstm_parameter.h"
#include "coder/opcoders/nnacl/dynamic_parameter/dynamic_lstm_parameter.h"
#include "coder/opcoders/op_coder.h"

namespace mindspore::lite::micro::nnacl {

// Micro op-coder that generates fp16 LSTM inference code for MindIR-exported
// models with dynamic batch/seq_len. Weights are packed once at init time; work
// buffers are sized per dynamic-shape scenario.
class LstmMindirDynamicFP16Coder : public OperatorCoder {
 public:
  LstmMindirDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                             const LiteGraph::Node *node, size_t node_index, Target target)
      : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}

  ~LstmMindirDynamicFP16Coder() override = default;

  // Validates inputs and derives static/dynamic LSTM parameters.
  int Prepare(CoderContext *const context) override;
  // Emits the runtime LstmFp16 call and state-initialization code.
  int DoCode(CoderContext *const context) override;

 private:
  int InitParam();
  int ComputeWorkSpace();
  void CreateBufferAddrStr();
  int InitInputWeightBias(CoderContext *const context);
  int InitStateWeightBias(CoderContext *const context);
  int InitProjectWeight(CoderContext *const context);
  bool gpu_state_{false};  // true when the fused weight layout matches a GPU export (unsupported)
  TypeId data_type_{kNumberTypeFloat16};
  int weight_segment_num_{0};  // 4 gates, doubled for bidirectional
  // Element counts of the segments inside the fused weight tensor (input weights,
  // state weights, projection weights, bias).
  size_t hi_size_{0};
  size_t hh_size_{0};
  size_t hp_size_{0};
  size_t bias_size_{0};
  // Init-time packed-weight arena pointers (owned by the memory allocator).
  void *weight_i_ptr_{nullptr};
  void *weight_h_ptr_{nullptr};
  void *weight_project_ptr_{nullptr};
  void *input_bias_{nullptr};
  void *hh_bias_{nullptr};
  void *project_bias_{nullptr};
  LstmParameter *lstm_param_{nullptr};
  DynamicLstmParameter dynamic_lstm_param_;
  std::string buffers_start_;  // workspace base-address expression
  std::vector<std::string> buffers_str_;  // 7 buffer-pointer expressions emitted into generated code
};
}  // namespace mindspore::lite::micro::nnacl

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_LSTM_DYNAMIC_FP16_CODER_H
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc
new file mode 100644
index 00000000..f6c56f86
--- /dev/null
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc
@@ -0,0 +1,228 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
15587+ */ 15588+ 15589+#include "tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h" 15590+#include <string> 15591+#include <vector> 15592+#include "tools/converter/micro/coder/log.h" 15593+#include "tools/converter/micro/coder/opcoders/file_collector.h" 15594+#include "base/float16.h" 15595+#include "tools/common/string_util.h" 15596+#include "coder/utils/coder_utils.h" 15597+ 15598+using mindspore::schema::PrimitiveType_MatMulFusion; 15599+ 15600+namespace mindspore::lite::micro::nnacl { 15601+int MatMulDynamicFP16BaseCoder::Prepare(CoderContext *const context) { 15602+ row_tile_ = C1NUM; 15603+ col_tile_ = C4NUM; 15604+ auto ret = InitAShape(); 15605+ MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "init A-metrics' info failed"); 15606+ ret = InitBShape(); 15607+ MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "init B-metrics' info failed"); 15608+ params_->col_align_ = UP_ROUND(params_->col_, col_tile_); 15609+ return RET_OK; 15610+} 15611+ 15612+int MatMulDynamicFP16BaseCoder::DoCode(CoderContext *const context) { 15613+ CollectFilesForTarget(context); 15614+ auto ret = InitMatrixB(context); 15615+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "InitMatrixB failed."); 15616+ ret = InitBiasData(context); 15617+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "InitBiasData failed."); 15618+ 15619+ ret = ComputeWorkSpace(); 15620+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Matmul alloc workspace failed."); 15621+ auto input_a_str = dynamic_mem_manager_->GetVarTensorAddr(input_tensor_); 15622+ MS_CHECK_TRUE_MSG(!input_a_str.empty(), RET_ERROR, "Matmul cannot get matrixA"); 15623+ auto output_str = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_); 15624+ MS_CHECK_TRUE_MSG(!output_str.empty(), RET_ERROR, "Matmul cannot get output"); 15625+ NNaclFp32Serializer code; 15626+ if (params_->a_transpose_) { 15627+ code << " if (" << dynamic_params_.row_ << " == 1) {\n"; 15628+ code << " if (" << dynamic_params_.batch_ << " <= 3) {\n"; 15629+ 
code.CodeFunction("MatmulFp16OptV2", "(float16_t *)(" + input_a_str + ")", input_b_pack_str_, 15630+ "(float16_t *)(" + output_str + ")", bias_str_, params_->act_type_, params_->deep_, 15631+ dynamic_params_.batch_, params_->col_, params_->col_, OutType_Nhwc); 15632+ code << " } else {\n"; 15633+ code.CodeFunction("RowMajor2ColLadder12MajorFp16", "(float16_t *)(" + input_a_str + ")", 15634+ "(float16_t *)(" + buffer_start_ + ")", dynamic_params_.batch_, params_->deep_); 15635+ code.CodeFunction("MatmulFp16OptV2", "(float16_t *)(" + buffer_start_ + ")", input_b_pack_str_, 15636+ "(float16_t *)(" + output_str + ")", bias_str_, params_->act_type_, params_->deep_, 15637+ dynamic_params_.batch_, params_->col_, params_->col_, OutType_Nhwc); 15638+ code << " } else {\n"; 15639+ code << " int in_stride = " << dynamic_params_.row_ << " * " << params_->deep_ << ";\n"; 15640+ code << " int out_stride = " << dynamic_params_.row_ << " * " << params_->col_ << ";\n"; 15641+ code << " for (int i = 0; i < " << dynamic_params_.batch_ << "; ++i) {\n"; 15642+ code.CodeFunction("RowMajor2RowLadder12MajorFp16", "(float16_t *)(" + input_a_str + ")" + " + in_stride * i", 15643+ "(float16_t *)(" + buffer_start_ + ")", params_->deep_, dynamic_params_.row_); 15644+ code.CodeFunction("MatmulFp16OptV2", "(float16_t *)(" + buffer_start_ + ")", input_b_pack_str_, 15645+ "(float16_t *)(" + output_str + ")" + " + out_stride * i", bias_str_, params_->act_type_, 15646+ params_->deep_, dynamic_params_.row_, params_->col_, OutType_Nhwc); 15647+ code << " }\n"; 15648+ code << " }\n"; 15649+ } else { 15650+ code << " if (" << dynamic_params_.batch_ << " * " << dynamic_params_.row_ << " <= 3) {\n"; 15651+ code.CodeFunction("MatmulFp16OptV2", "(float16_t *)(" + input_a_str + ")", input_b_pack_str_, 15652+ "(float16_t *)(" + output_str + ")", bias_str_, params_->act_type_, params_->deep_, 15653+ dynamic_params_.batch_ + " * " + dynamic_params_.row_, params_->col_, params_->col_, 15654+ OutType_Nhwc); 
15655+ code << " } else {\n"; 15656+ code.CodeFunction("RowMajor2ColLadder12MajorFp16", "(float16_t *)(" + input_a_str + ")", 15657+ "(float16_t *)(" + buffer_start_ + ")", dynamic_params_.batch_ + " * " + dynamic_params_.row_, 15658+ params_->deep_); 15659+ code.CodeFunction("MatmulFp16OptV2", "(float16_t *)(" + buffer_start_ + ")", input_b_pack_str_, 15660+ "(float16_t *)(" + output_str + ")", bias_str_, params_->act_type_, params_->deep_, 15661+ dynamic_params_.batch_ + " * " + dynamic_params_.row_, params_->col_, params_->col_, 15662+ OutType_Nhwc); 15663+ } 15664+ code << " }\n"; 15665+ context->AppendCode(code.str()); 15666+ return RET_OK; 15667+} 15668+ 15669+int MatMulDynamicFP16BaseCoder::InitMatrixB(CoderContext *const context) { 15670+ NNaclFp32Serializer init_code; 15671+ if (b_pack_ptr_ != nullptr) { 15672+ return RET_OK; 15673+ } 15674+ auto b_pack_ptr_size = static_cast<size_t>(params_->col_align_ * params_->deep_ * DataTypeSize(data_type_)); 15675+ b_pack_ptr_ = allocator_->GetSharedWeightAddr(filter_tensor_); 15676+ if (b_pack_ptr_ == nullptr) { 15677+ b_pack_ptr_ = allocator_->Malloc(data_type_, b_pack_ptr_size, kOnlinePackWeight, 15678+ filter_tensor_->tensor_name() + "_online_pack"); 15679+ allocator_->MarkSharedWeight(filter_tensor_, b_pack_ptr_); 15680+ } 15681+ MS_CHECK_PTR(b_pack_ptr_); 15682+ std::string input_b_str = allocator_->GetRuntimeAddr(filter_tensor_); 15683+ input_b_pack_str_ = allocator_->GetRuntimeAddr(static_cast<float16 *>(b_pack_ptr_)); 15684+ init_code.CodeBufferOffsetExpression(b_pack_ptr_, context->weight_name(), context->weight_offset_name(), 15685+ context->weight_size_name(), b_pack_ptr_size); 15686+ if (b_batch_ == C1NUM) { 15687+ if (params_->b_transpose_) { 15688+ init_code.CodeFunction("RowMajor2ColNMajorFp16", input_b_str, input_b_pack_str_, params_->col_, params_->deep_, 15689+ "false"); 15690+ } else { 15691+ init_code.CodeFunction("RowMajor2RowNMajorFp16", input_b_str, input_b_pack_str_, params_->deep_, 
params_->col_, 15692+ "false"); 15693+ } 15694+ } else { 15695+ init_code << " for (int i = 0; i < " << b_batch_ << "; i++) {\n" 15696+ << " float16_t *src = " << input_b_str << " + i * " << params_->deep_ * params_->col_ << ";\n" 15697+ << " float16_t *dst = " << input_b_pack_str_ << " + i * " << params_->deep_ * params_->col_align_ 15698+ << ";\n"; 15699+ if (params_->b_transpose_) { 15700+ init_code << " RowMajor2ColNMajorFp16(src, dst, " << params_->col_ << ", " << params_->deep_ << ", false);\n"; 15701+ } else { 15702+ init_code << " RowMajor2RowNMajorFp16(src, dst, " << params_->deep_ << ", " << params_->col_ << ", false);\n"; 15703+ } 15704+ init_code << " }\n"; 15705+ } 15706+ context->AppendInitWeightSizeCode(b_pack_ptr_size); 15707+ context->AppendInitCode(init_code.str()); 15708+ return RET_OK; 15709+} 15710+ 15711+int MatMulDynamicFP16BaseCoder::InitBiasData(CoderContext *const context) { 15712+ NNaclFp32Serializer init_code; 15713+ if (bias_ptr_ != nullptr) { 15714+ return RET_OK; 15715+ } 15716+ auto bias_pack_ptr_size = static_cast<size_t>(params_->col_align_ * DataTypeSize(data_type_)); 15717+ if (input_tensors_.size() == C3NUM) { 15718+ bias_ptr_ = 15719+ allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack"); 15720+ MS_CHECK_PTR(bias_ptr_); 15721+ } else { 15722+ bias_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack"); 15723+ MS_CHECK_PTR(bias_ptr_); 15724+ } 15725+ init_code.CodeBufferOffsetExpression(bias_ptr_, context->weight_name(), context->weight_offset_name(), 15726+ context->weight_size_name(), bias_pack_ptr_size); 15727+ bias_str_ = allocator_->GetRuntimeAddr(bias_ptr_); 15728+ if (input_tensors_.size() == DIMENSION_3D) { 15729+ auto origin_bias_str = allocator_->GetRuntimeAddr(bias_tensor_); 15730+ init_code.CodeFunction("memcpy", bias_str_, origin_bias_str, bias_tensor_->Size()); 15731+ } else { 15732+ 
init_code.CodeFunction("memset", bias_str_, 0, bias_pack_ptr_size); 15733+ } 15734+ context->AppendInitWeightSizeCode(bias_pack_ptr_size); 15735+ context->AppendInitCode(init_code.str()); 15736+ return RET_OK; 15737+} 15738+ 15739+int MatMulDynamicFP16BaseCoder::ComputeWorkSpace() { 15740+ auto a_shape = shape_info_container_->GetTemplateShape(input_tensor_); 15741+ std::map<std::string, std::vector<int>> real_nums; 15742+ size_t scene_num = 0; 15743+ for (auto &dim_template : a_shape) { 15744+ auto dim_nums = shape_info_container_->GetRealNums(dim_template); 15745+ MS_CHECK_TRUE_MSG(!dim_nums.empty(), RET_ERROR, "Dynamic shape's num must be greater than 0."); 15746+ real_nums[dim_template] = dim_nums; 15747+ scene_num = std::max(scene_num, dim_nums.size()); 15748+ } 15749+ for (size_t i = 0; i < scene_num; ++i) { 15750+ std::vector<int> real_shape(a_shape.size()); 15751+ for (size_t j = 0; j < a_shape.size(); ++j) { 15752+ if (IsNumber(a_shape[j])) { 15753+ real_shape[j] = std::stoi(a_shape[j]); 15754+ } else { 15755+ real_shape[j] = real_nums[a_shape[j]][i % real_nums[a_shape[j]].size()]; 15756+ } 15757+ } 15758+ int a_batch = 1; 15759+ for (size_t j = 0; j < a_shape.size() - C2NUM; ++j) { 15760+ MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch, real_shape[j], RET_ERROR); 15761+ a_batch *= real_shape[j]; 15762+ } 15763+ int row = params_->a_transpose_ ? real_shape.back() : real_shape[real_shape.size() - C2NUM]; 15764+ int deep = params_->a_transpose_ ? real_shape[real_shape.size() - C2NUM] : real_shape.back(); 15765+ MS_CHECK_TRUE_MSG(deep == params_->deep_, RET_INPUT_TENSOR_ERROR, 15766+ "Matmul's matrixA doesn't match matrixB, becase their deeps are not same."); 15767+ int workspace = 0; 15768+ if (params_->a_transpose_) { 15769+ workspace = (row == 1 ? (a_batch <= C3NUM ? 0 : UP_ROUND(a_batch, row_tile_)) : UP_ROUND(row, row_tile_)) * deep; 15770+ } else { 15771+ workspace = (a_batch * row <= C3NUM ? 
0 : UP_ROUND(a_batch * row, row_tile_)) * deep; 15772+ } 15773+ buffer_start_ = dynamic_mem_manager_->AllocWorkSpace(workspace, i); 15774+ MS_CHECK_TRUE_MSG(!buffer_start_.empty(), RET_ERROR, "Matmul cannot alloc workspace."); 15775+ } 15776+ return RET_OK; 15777+} 15778+ 15779+int MatMulDynamicFP16BaseCoder::CollectFilesForTarget(CoderContext *const context) { 15780+ Collect(context, 15781+ { 15782+ "nnacl/fp16/pack_fp16.h", 15783+ "nnacl/fp16/matmul_fp16.h", 15784+ }, 15785+ { 15786+ "pack_fp16.c", 15787+ "matmul_fp16.c", 15788+ }); 15789+ if (target_ == kARM32) { 15790+ Collect(context, {}, {}, 15791+ { 15792+ "Matmul12x8Fp16.S", 15793+ "MatVecMulFp16.S", 15794+ }); 15795+ } else if (target_ == kARM64) { 15796+ Collect(context, {}, {}, {"MatmulFp16OptV2.S"}); 15797+ } 15798+ return RET_OK; 15799+} 15800+} // namespace mindspore::lite::micro::nnacl 15801diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h 15802new file mode 100644 15803index 00000000..f73cfff7 15804--- /dev/null 15805+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h 15806@@ -0,0 +1,73 @@ 15807+/** 15808+ * Copyright 2023 Huawei Technologies Co., Ltd 15809+ * 15810+ * Licensed under the Apache License, Version 2.0 (the "License"); 15811+ * you may not use this file except in compliance with the License. 15812+ * You may obtain a copy of the License at 15813+ * 15814+ * http://www.apache.org/licenses/LICENSE-2.0 15815+ * 15816+ * Unless required by applicable law or agreed to in writing, software 15817+ * distributed under the License is distributed on an "AS IS" BASIS, 15818+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15819+ * See the License for the specific language governing permissions and 15820+ * limitations under the License. 
15821+ */ 15822+ 15823+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_FP16_BASE_CODER_H_ 15824+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_FP16_BASE_CODER_H_ 15825+ 15826+#include <vector> 15827+#include <string> 15828+#include "tools/converter/micro/coder/opcoders/op_coder.h" 15829+#include "tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 15830+#include "nnacl/matmul_parameter.h" 15831+#include "tools/converter/micro/coder/shape_info_container.h" 15832+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 15833+#include "base/float16.h" 15834+#include "coder/opcoders/nnacl/dynamic_parameter/matmul_dynamic_parameter.h" 15835+ 15836+namespace mindspore::lite::micro::nnacl { 15837+class MatMulDynamicFP16BaseCoder : public OperatorCoder { 15838+ public: 15839+ MatMulDynamicFP16BaseCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 15840+ const LiteGraph::Node *node, size_t node_index, Target target) 15841+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 15842+ 15843+ ~MatMulDynamicFP16BaseCoder() override = default; 15844+ 15845+ int Prepare(CoderContext *const context) override; 15846+ 15847+ int DoCode(CoderContext *const context) override; 15848+ 15849+ private: 15850+ int InitBiasData(CoderContext *const context); 15851+ int InitMatrixB(CoderContext *const context); 15852+ int CollectFilesForTarget(CoderContext *const context); 15853+ int ComputeWorkSpace(); 15854+ 15855+ protected: 15856+ virtual int InitAShape() = 0; 15857+ virtual int InitBShape() = 0; 15858+ 15859+ protected: 15860+ Tensor *filter_tensor_{nullptr}; 15861+ Tensor *bias_tensor_{nullptr}; 15862+ MatMulParameter *params_{nullptr}; 15863+ MatmulDynamicParameter dynamic_params_; 15864+ void *a_pack_ptr_ = nullptr; 15865+ void *b_pack_ptr_ = nullptr; 15866+ void *bias_ptr_{nullptr}; 15867+ int col_tile_{0}; 15868+ int 
row_tile_{0}; 15869+ size_t a_pack_ptr_size_{0}; 15870+ TypeId data_type_{kNumberTypeFloat16}; 15871+ int a_batch_; 15872+ int b_batch_; 15873+ std::string buffer_start_; 15874+ std::string bias_str_; 15875+ std::string input_a_pack_str_; 15876+ std::string input_b_pack_str_; 15877+}; 15878+} // namespace mindspore::lite::micro::nnacl 15879+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_FP16_BASE_CODER_H_ 15880diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.cc 15881new file mode 100644 15882index 00000000..24cf7120 15883--- /dev/null 15884+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.cc 15885@@ -0,0 +1,100 @@ 15886+/** 15887+ * Copyright 2023 Huawei Technologies Co., Ltd 15888+ * 15889+ * Licensed under the Apache License, Version 2.0 (the "License"); 15890+ * you may not use this file except in compliance with the License. 15891+ * You may obtain a copy of the License at 15892+ * 15893+ * http://www.apache.org/licenses/LICENSE-2.0 15894+ * 15895+ * Unless required by applicable law or agreed to in writing, software 15896+ * distributed under the License is distributed on an "AS IS" BASIS, 15897+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15898+ * See the License for the specific language governing permissions and 15899+ * limitations under the License. 
15900+ */ 15901+ 15902+#include "coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h" 15903+#include <vector> 15904+#include "coder/log.h" 15905+#include "coder/opcoders/file_collector.h" 15906+#include "tools/common/string_util.h" 15907+#include "coder/utils/coder_utils.h" 15908+ 15909+using mindspore::schema::PrimitiveType_MatMulFusion; 15910+ 15911+namespace mindspore::lite::micro::nnacl { 15912+int MatMulDynamicFP16Coder::InitAShape() { 15913+ auto a_shape = shape_info_container_->GetTemplateShape(input_tensor_); 15914+ auto a_shape_size = a_shape.size(); 15915+ MS_CHECK_TRUE_MSG(a_shape_size >= DIMENSION_2D, RET_NOT_SUPPORT, "Matmul's a_shape_size must be not less than two."); 15916+ int64_t const_part = 1; 15917+ std::string non_const_part; 15918+ for (size_t i = 0; i < a_shape_size - C2NUM; ++i) { 15919+ if (IsNumber(a_shape[i])) { 15920+ const_part *= std::atoi(a_shape[i].c_str()); 15921+ } else { 15922+ if (!non_const_part.empty()) { 15923+ non_const_part += " * "; 15924+ } 15925+ non_const_part += a_shape[i]; 15926+ } 15927+ } 15928+ dynamic_params_.batch_ = non_const_part + " * " + std::to_string(const_part); 15929+ dynamic_params_.row_ = params_->a_transpose_ ? a_shape[a_shape.size() - C1NUM] : a_shape[a_shape.size() - C2NUM]; 15930+ return RET_OK; 15931+} 15932+ 15933+int MatMulDynamicFP16Coder::InitBShape() { 15934+ std::vector<int> b_shape = filter_tensor_->shape(); 15935+ MS_CHECK_TRUE_MSG(b_shape.size() >= DIMENSION_2D, RET_NOT_SUPPORT, 15936+ "Matmul's b_shape_size must be not less than two."); 15937+ int batch = 1; 15938+ for (size_t i = 0; i < b_shape.size() - DIMENSION_2D; ++i) { 15939+ batch *= b_shape[i]; 15940+ } 15941+ if (batch != 1) { 15942+ MS_LOG(ERROR) << "Currently, Matmul only support matrixB's batch is 1."; 15943+ } 15944+ b_batch_ = batch; 15945+ params_->col_ = params_->b_transpose_ ? 
b_shape[b_shape.size() - C2NUM] : b_shape[b_shape.size() - C1NUM]; 15946+ params_->col_8_ = UP_ROUND(params_->col_, C8NUM); 15947+ params_->deep_ = params_->b_transpose_ ? b_shape[b_shape.size() - C1NUM] : b_shape[b_shape.size() - C2NUM]; 15948+ return RET_OK; 15949+} 15950+ 15951+int MatMulDynamicFP16Coder::Prepare(CoderContext *const context) { 15952+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 15953+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 15954+ "Input tensor data type is invalid."); 15955+ } 15956+ MS_CHECK_TRUE_MSG(output_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 15957+ "Input tensor data type is invalid."); 15958+ MS_CHECK_TRUE_MSG(input_tensors_.size() == C2NUM || input_tensors_.size() == C3NUM, RET_INPUT_PARAM_INVALID, 15959+ "MatMul's input-num must be 2 or 3."); 15960+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->IsConst(), RET_NOT_SUPPORT, 15961+ "Currently, only support the first input of matmul is non-const when shape is dynamical."); 15962+ if (input_tensors_.size() == C3NUM) { 15963+ MS_CHECK_TRUE_MSG(input_tensors_[THIRD_INPUT]->IsConst(), RET_NOT_SUPPORT, 15964+ "Currently, only support the first input of matmul is non-const when shape is dynamical."); 15965+ } 15966+ params_ = reinterpret_cast<MatMulParameter *>(parameter_); 15967+ filter_tensor_ = input_tensors_.at(kWeightIndex); 15968+ MS_CHECK_PTR(filter_tensor_); 15969+ if (input_tensors_.size() == kInputSize2) { 15970+ bias_tensor_ = input_tensors_.at(kBiasIndex); 15971+ MS_CHECK_PTR(bias_tensor_); 15972+ MS_CHECK_PTR(bias_tensor_->data()); 15973+ } 15974+ params_->a_const_ = (input_tensor_->data() != nullptr); 15975+ params_->b_const_ = (filter_tensor_->data() != nullptr); 15976+ MS_CHECK_RET_CODE(MatMulDynamicFP16BaseCoder::Prepare(context), "MatMulDynamicFP16Coder prepare failed"); 15977+ return RET_OK; 15978+} 15979+ 15980+int MatMulDynamicFP16Coder::DoCode(CoderContext *const context) { 
return MatMulDynamicFP16BaseCoder::DoCode(context); } 15981+ 15982+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_MatMulFusion, 15983+ CPUOpCoderCreator<MatMulDynamicFP16Coder>) 15984+// REG_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_MatMulFusion, CPUOpCoderCreator<MatMulDynamicFP16Coder>) 15985+} // namespace mindspore::lite::micro::nnacl 15986diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h 15987new file mode 100644 15988index 00000000..1a16798c 15989--- /dev/null 15990+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h 15991@@ -0,0 +1,44 @@ 15992+/** 15993+ * Copyright 2023 Huawei Technologies Co., Ltd 15994+ * 15995+ * Licensed under the Apache License, Version 2.0 (the "License"); 15996+ * you may not use this file except in compliance with the License. 15997+ * You may obtain a copy of the License at 15998+ * 15999+ * http://www.apache.org/licenses/LICENSE-2.0 16000+ * 16001+ * Unless required by applicable law or agreed to in writing, software 16002+ * distributed under the License is distributed on an "AS IS" BASIS, 16003+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16004+ * See the License for the specific language governing permissions and 16005+ * limitations under the License. 
16006+ */ 16007+ 16008+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_DYNAMIC_FP16_CODER_H_ 16009+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_DYNAMIC_FP16_CODER_H_ 16010+ 16011+#include <vector> 16012+#include "tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h" 16013+#include "nnacl/matmul_parameter.h" 16014+#include "tools/converter/micro/coder/shape_info_container.h" 16015+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 16016+ 16017+namespace mindspore::lite::micro::nnacl { 16018+class MatMulDynamicFP16Coder final : public MatMulDynamicFP16BaseCoder { 16019+ public: 16020+ MatMulDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16021+ const LiteGraph::Node *node, size_t node_index, Target target) 16022+ : MatMulDynamicFP16BaseCoder(in_tensors, out_tensors, node, node_index, target) {} 16023+ 16024+ ~MatMulDynamicFP16Coder() override = default; 16025+ 16026+ int Prepare(CoderContext *const context) override; 16027+ 16028+ int DoCode(CoderContext *const context) override; 16029+ 16030+ private: 16031+ int InitAShape() override; 16032+ int InitBShape() override; 16033+}; 16034+} // namespace mindspore::lite::micro::nnacl 16035+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_DYNAMIC_FP16_CODER_H_ 16036diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 16037index 67f633fe..415e912d 100644 16038--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 16039+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 16040@@ -102,14 +102,15 @@ std::string MatMulFP16BaseCoder::InitMatrixA(NNaclFp32Serializer *const code, NN 16041 if (a_batch_ == 1) { 16042 if 
(params_.a_transpose_) { 16043 if (target_ == kARM64) { 16044- pack_code.CodeFunction("RowMajor2RowNMajorFp16", input_a_str, input_a_pack_str, params_.deep_, params_.row_); 16045+ pack_code.CodeFunction("RowMajor2RowNMajorFp16", input_a_str, input_a_pack_str, params_.deep_, params_.row_, 16046+ "false"); 16047 } else { 16048 pack_code.CodeFunction("RowMajor2Row12MajorFp16", input_a_str, input_a_pack_str, params_.deep_, params_.row_, 16049 false); 16050 } 16051 } else { 16052 if (target_ == kARM64) { 16053- pack_code.CodeFunction("RowMajor2ColNMajorFp16", input_a_str, input_a_pack_str, params_.row_, params_.deep_); 16054+ pack_code.CodeFunction("RowMajor2ColNMajorFp16", input_a_str, input_a_pack_str, params_.row_, params_.deep_, false); 16055 } else { 16056 pack_code.CodeFunction("RowMajor2Col12MajorFp16", input_a_str, input_a_pack_str, params_.row_, params_.deep_, 16057 false); 16058@@ -122,13 +123,13 @@ std::string MatMulFP16BaseCoder::InitMatrixA(NNaclFp32Serializer *const code, NN 16059 << ";\n"; 16060 if (params_.a_transpose_) { 16061 if (target_ == kARM64) { 16062- pack_code << " RowMajor2RowNMajorFp16(src, dst, " << params_.deep_ << ", " << params_.row_ << ");\n"; 16063+ pack_code << " RowMajor2RowNMajorFp16(src, dst, " << params_.deep_ << ", " << params_.row_ << ", false);\n"; 16064 } else { 16065 pack_code << " RowMajor2Row12MajorFp16(src, dst, " << params_.deep_ << ", " << params_.row_ << ", false);\n"; 16066 } 16067 } else { 16068 if (target_ == kARM64) { 16069- pack_code << " RowMajor2ColNMajorFp16(src, dst, " << params_.row_ << ", " << params_.deep_ << ");\n"; 16070+ pack_code << " RowMajor2ColNMajorFp16(src, dst, " << params_.row_ << ", " << params_.deep_ << ", false);\n"; 16071 } else { 16072 pack_code << " RowMajor2Col12MajorFp16(src, dst, " << params_.row_ << ", " << params_.deep_ << ", false);\n"; 16073 } 16074diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc 16075new file mode 100644 16076index 00000000..c565f5b2 16077--- /dev/null 16078+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc 16079@@ -0,0 +1,89 @@ 16080+/** 16081+ * Copyright 2023 Huawei Technologies Co., Ltd 16082+ * 16083+ * Licensed under the Apache License, Version 2.0 (the "License"); 16084+ * you may not use this file except in compliance with the License. 16085+ * You may obtain a copy of the License at 16086+ * 16087+ * http://www.apache.org/licenses/LICENSE-2.0 16088+ * 16089+ * Unless required by applicable law or agreed to in writing, software 16090+ * distributed under the License is distributed on an "AS IS" BASIS, 16091+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16092+ * See the License for the specific language governing permissions and 16093+ * limitations under the License. 16094+ */ 16095+#include "coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h" 16096+#include <cfloat> 16097+#include <string> 16098+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16099+#include "coder/log.h" 16100+#include "coder/opcoders/parallel.h" 16101+#include "coder/opcoders/file_collector.h" 16102+#include "coder/utils/coder_utils.h" 16103+ 16104+using mindspore::schema::PrimitiveType_AvgPoolFusion; 16105+using mindspore::schema::PrimitiveType_MaxPoolFusion; 16106+ 16107+namespace mindspore::lite::micro::nnacl { 16108+int PoolingDynamicFP16Coder::Prepare(CoderContext *const context) { 16109+ if (input_tensor_->data_type() != kNumberTypeFloat16 || output_tensor_->data_type() != kNumberTypeFloat16) { 16110+ MS_LOG(ERROR) << "Tensor data type is invalid"; 16111+ return lite::RET_INPUT_PARAM_INVALID; 16112+ } 16113+ param_ = reinterpret_cast<PoolingParameter *>(parameter_); 16114+ MS_CHECK_PTR(param_); 16115+ dynamic_param_.input_batch_ = 
shape_info_container_->GetTemplateShape(input_tensor_)[0]; 16116+ compute_.input_channel_ = input_tensor_->Channel(); 16117+ compute_.input_h_ = input_tensor_->Height(); 16118+ compute_.input_w_ = input_tensor_->Width(); 16119+ dynamic_param_.output_batch_ = shape_info_container_->GetTemplateShape(output_tensor_)[0]; 16120+ compute_.output_channel_ = output_tensor_->Channel(); 16121+ compute_.output_h_ = output_tensor_->Height(); 16122+ compute_.output_w_ = output_tensor_->Width(); 16123+ if (param_->global_) { 16124+ param_->window_h_ = compute_.input_h_; 16125+ param_->window_w_ = compute_.input_w_; 16126+ } 16127+ return RET_OK; 16128+} 16129+ 16130+int PoolingDynamicFP16Coder::DoCode(CoderContext *const context) { 16131+ Collect(context, 16132+ { 16133+ "nnacl/fp16/pooling_fp16.h", 16134+ }, 16135+ { 16136+ "pooling_fp16.c", 16137+ }); 16138+ NNaclFp32Serializer code; 16139+ code.CodeStruct("pooling_parameter", *param_); 16140+ code.CodeStruct("pooling_compute", compute_, dynamic_param_); 16141+ 16142+ auto input_data = 16143+ "(float16_t *)(" + GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16144+ auto output_data = 16145+ "(float16_t *)(" + GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16146+ if (param_->pool_mode_ == PoolMode_MaxPool) { 16147+ code.CodeFunction("MaxPoolingFp16", input_data, output_data, "&pooling_parameter", "&pooling_compute", 16148+ kDefaultTaskId, param_->op_parameter_.thread_num_); 16149+ } else if (param_->pool_mode_ == PoolMode_AvgPool) { 16150+ code.CodeFunction("AvgPoolingFp16", input_data, output_data, "&pooling_parameter", "&pooling_compute", 16151+ kDefaultTaskId, param_->op_parameter_.thread_num_); 16152+ } else { 16153+ MS_LOG(ERROR) << "Unsupported pooling mode."; 16154+ return lite::RET_ERROR; 16155+ } 16156+ context->AppendCode(code.str()); 16157+ return lite::RET_OK; 16158+} 16159+ 16160+REG_DYNAMIC_OPERATOR_CODER(kARM32, 
kNumberTypeFloat16, PrimitiveType_AvgPoolFusion, 16161+ CPUOpCoderCreator<PoolingDynamicFP16Coder>) 16162+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_AvgPoolFusion, 16163+ CPUOpCoderCreator<PoolingDynamicFP16Coder>) 16164+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_MaxPoolFusion, 16165+ CPUOpCoderCreator<PoolingDynamicFP16Coder>) 16166+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_MaxPoolFusion, 16167+ CPUOpCoderCreator<PoolingDynamicFP16Coder>) 16168+} // namespace mindspore::lite::micro::nnacl 16169diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h 16170new file mode 100644 16171index 00000000..7b138b61 16172--- /dev/null 16173+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h 16174@@ -0,0 +1,44 @@ 16175+/** 16176+ * Copyright 2023 Huawei Technologies Co., Ltd 16177+ * 16178+ * Licensed under the Apache License, Version 2.0 (the "License"); 16179+ * you may not use this file except in compliance with the License. 16180+ * You may obtain a copy of the License at 16181+ * 16182+ * http://www.apache.org/licenses/LICENSE-2.0 16183+ * 16184+ * Unless required by applicable law or agreed to in writing, software 16185+ * distributed under the License is distributed on an "AS IS" BASIS, 16186+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16187+ * See the License for the specific language governing permissions and 16188+ * limitations under the License. 
16189+ */ 16190+ 16191+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_POOLING_DYNAMIC_FP16_CODER_H_ 16192+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_POOLING_DYNAMIC_FP16_CODER_H_ 16193+ 16194+#include <vector> 16195+#include "coder/opcoders/op_coder.h" 16196+#include "coder/opcoders/nnacl/dynamic_parameter/pooling_dynamic_parameter.h" 16197+#include "nnacl/pooling_parameter.h" 16198+#include "nnacl/kernel/pooling.h" 16199+ 16200+namespace mindspore::lite::micro::nnacl { 16201+class PoolingDynamicFP16Coder final : public OperatorCoder { 16202+ public: 16203+ PoolingDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16204+ const LiteGraph::Node *node, size_t node_index, Target target) 16205+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 16206+ ~PoolingDynamicFP16Coder() override = default; 16207+ 16208+ int Prepare(CoderContext *const context) override; 16209+ 16210+ int DoCode(CoderContext *const context) override; 16211+ 16212+ private: 16213+ PoolingParameter *param_{nullptr}; 16214+ PoolingComputeParam compute_; 16215+ PoolingDynamicParameter dynamic_param_; 16216+}; 16217+} // namespace mindspore::lite::micro::nnacl 16218+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_POOLING_DYNAMIC_FP16_CODER_H_ 16219diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc 16220new file mode 100644 16221index 00000000..733cf49d 16222--- /dev/null 16223+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc 16224@@ -0,0 +1,128 @@ 16225+/** 16226+ * Copyright 2023 Huawei Technologies Co., Ltd 16227+ * 16228+ * Licensed under the Apache License, Version 2.0 (the "License"); 16229+ * you may not use this file except in compliance with the License. 
16230+ * You may obtain a copy of the License at 16231+ * 16232+ * http://www.apache.org/licenses/LICENSE-2.0 16233+ * 16234+ * Unless required by applicable law or agreed to in writing, software 16235+ * distributed under the License is distributed on an "AS IS" BASIS, 16236+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16237+ * See the License for the specific language governing permissions and 16238+ * limitations under the License. 16239+ */ 16240+#include "coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h" 16241+#include <string> 16242+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16243+#include "coder/opcoders/file_collector.h" 16244+#include "coder/opcoders/parallel.h" 16245+#include "coder/utils/coder_utils.h" 16246+ 16247+using mindspore::schema::PrimitiveType_ScaleFusion; 16248+ 16249+namespace mindspore::lite::micro::nnacl { 16250+int ScaleDynamicFP16Coder::Prepare(CoderContext *const context) { 16251+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 16252+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 16253+ "Input tensor data type should be fp16, now is " << input_tensors_[i]->data_type()); 16254+ } 16255+ MS_CHECK_TRUE_MSG(output_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 16256+ "Output tensor data type should be fp16, now is " << output_tensor_->data_type()); 16257+ 16258+ scale_param_ = reinterpret_cast<ScaleParameter *>(parameter_); 16259+ MS_CHECK_PTR(scale_param_); 16260+ scale_struct_.base_.param_ = parameter_; 16261+ if (input_tensors_.size() < DIMENSION_2D || input_tensors_.size() > DIMENSION_3D) { 16262+ MS_LOG(ERROR) << "inputs to Scale operator should be 2 or 3, but " << input_tensors_.size() << " is given."; 16263+ return RET_ERROR; 16264+ } 16265+ scale_tensor_ = input_tensors_.at(kWeightIndex); 16266+ MS_CHECK_PTR(scale_tensor_); 16267+ MS_CHECK_RET_CODE(CalculateParameter(), "Scale fp16 
CalculateParameter failed."); 16268+ return RET_OK; 16269+} 16270+ 16271+int ScaleDynamicFP16Coder::DoCode(CoderContext *const context) { 16272+ // init struct ScaleParameters 16273+ Collect(context, 16274+ { 16275+ "nnacl/kernel/scale.h", 16276+ "nnacl/fp16/scale_fp16.h", 16277+ }, 16278+ { 16279+ "scale_fp16.c", 16280+ }); 16281+ 16282+ NNaclFp32Serializer code; 16283+ code.CodeStruct("scale_struct", scale_struct_, dynamic_param_); 16284+ 16285+ auto scale = GetTensorAddr(scale_tensor_, scale_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 16286+ std::string offset{"NULL"}; 16287+ if (input_tensors_.size() == DIMENSION_3D) { 16288+ auto offset_tensor = input_tensors_.at(kBiasIndex); 16289+ offset = GetTensorAddr(offset_tensor, offset_tensor->IsConst(), dynamic_mem_manager_, allocator_); 16290+ } 16291+ std::string input_str = 16292+ "(float16_t *)(" + GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16293+ std::string output_str = 16294+ "(float16_t *)(" + GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16295+ switch (scale_param_->activation_type_) { 16296+ case schema::ActivationType_RELU6: 16297+ code.CodeFunction("DoScaleRelu6Fp16", input_str, output_str, scale, offset, kDefaultTaskId, "&scale_struct"); 16298+ break; 16299+ case schema::ActivationType_RELU: 16300+ code.CodeFunction("Fp16DoScaleRelu", input_str, output_str, scale, offset, kDefaultTaskId, "&scale_struct"); 16301+ break; 16302+ case schema::ActivationType_NO_ACTIVATION: 16303+ code.CodeFunction("DoScaleFp16", input_str, output_str, scale, offset, kDefaultTaskId, "&scale_struct"); 16304+ break; 16305+ default: 16306+ MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_; 16307+ return RET_ERROR; 16308+ } 16309+ context->AppendCode(code.str()); 16310+ return RET_OK; 16311+} 16312+ 16313+int ScaleDynamicFP16Coder::CalculateParameter() { 16314+ auto in_shape 
= shape_info_container_->GetTemplateShape(input_tensor_); 16315+ std::vector<std::string> scale_shape; 16316+ if (scale_tensor_->IsConst()) { 16317+ for (auto dim : scale_tensor_->shape()) { 16318+ scale_shape.emplace_back(std::to_string(dim)); 16319+ } 16320+ } else { 16321+ scale_shape = shape_info_container_->GetTemplateShape(scale_tensor_); 16322+ } 16323+ if (scale_param_->axis_ < 0) { 16324+ scale_struct_.axis_ = scale_param_->axis_ + in_shape.size(); 16325+ } 16326+ if (scale_shape.size() + scale_struct_.axis_ > in_shape.size()) { 16327+ MS_LOG(ERROR) << "Scale tensor shape is incorrect."; 16328+ return RET_ERROR; 16329+ } 16330+ dynamic_param_.outer_size_ = AccumulateShape(in_shape, 0, scale_struct_.axis_); 16331+ if (scale_tensor_->IsConst() && scale_tensor_->shape().size() == 1) { 16332+ dynamic_param_.axis_size_ = in_shape.at(scale_struct_.axis_); 16333+ } else { 16334+ dynamic_param_.axis_size_ = "{"; 16335+ for (size_t i = 0; i < scale_shape.size(); i++) { 16336+ if (in_shape.at(i + scale_struct_.axis_) != scale_shape.at(i)) { 16337+ MS_LOG(ERROR) << "Scale tensor shape is incorrect."; 16338+ return RET_ERROR; 16339+ } 16340+ dynamic_param_.axis_size_ += in_shape.at(i + scale_struct_.axis_) + ", "; 16341+ } 16342+ dynamic_param_.axis_size_ += "}"; 16343+ } 16344+ dynamic_param_.inner_size_ = AccumulateShape(in_shape, scale_struct_.axis_ + scale_shape.size(), in_shape.size()); 16345+ return RET_OK; 16346+} 16347+ 16348+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_ScaleFusion, 16349+ CPUOpCoderCreator<ScaleDynamicFP16Coder>) 16350+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_ScaleFusion, 16351+ CPUOpCoderCreator<ScaleDynamicFP16Coder>) 16352+} // namespace mindspore::lite::micro::nnacl 16353diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h 16354new file mode 100644 
16355index 00000000..02ec35ba 16356--- /dev/null 16357+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h 16358@@ -0,0 +1,46 @@ 16359+/** 16360+ * Copyright 2023 Huawei Technologies Co., Ltd 16361+ * 16362+ * Licensed under the Apache License, Version 2.0 (the "License"); 16363+ * you may not use this file except in compliance with the License. 16364+ * You may obtain a copy of the License at 16365+ * 16366+ * http://www.apache.org/licenses/LICENSE-2.0 16367+ * 16368+ * Unless required by applicable law or agreed to in writing, software 16369+ * distributed under the License is distributed on an "AS IS" BASIS, 16370+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16371+ * See the License for the specific language governing permissions and 16372+ * limitations under the License. 16373+ */ 16374+ 16375+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SCALE_DYNAMIC_FP16_CODER_H_ 16376+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SCALE_DYNAMIC_FP16_CODER_H_ 16377+ 16378+#include <vector> 16379+#include "coder/opcoders/op_coder.h" 16380+#include "coder/opcoders/nnacl/dynamic_parameter/scale_dynamic_parameter.h" 16381+#include "nnacl/kernel/scale.h" 16382+#include "nnacl/scale_parameter.h" 16383+ 16384+namespace mindspore::lite::micro::nnacl { 16385+class ScaleDynamicFP16Coder final : public OperatorCoder { 16386+ public: 16387+ ScaleDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16388+ const LiteGraph::Node *node, size_t node_index, Target target) 16389+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 16390+ ~ScaleDynamicFP16Coder() override = default; 16391+ 16392+ int Prepare(CoderContext *const context) override; 16393+ 16394+ int DoCode(CoderContext *const context) override; 16395+ 16396+ private: 16397+ int CalculateParameter(); 16398+ ScaleParameter *scale_param_{nullptr}; 
16399+ ScaleStruct scale_struct_; 16400+ ScaleDynamicParameter dynamic_param_; 16401+ Tensor *scale_tensor_{nullptr}; 16402+}; 16403+} // namespace mindspore::lite::micro::nnacl 16404+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SCALE_DYNAMIC_FP16_CODER_H_ 16405diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc 16406new file mode 100644 16407index 00000000..1c6969b2 16408--- /dev/null 16409+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc 16410@@ -0,0 +1,160 @@ 16411+/** 16412+ * Copyright 2023 Huawei Technologies Co., Ltd 16413+ * 16414+ * Licensed under the Apache License, Version 2.0 (the "License"); 16415+ * you may not use this file except in compliance with the License. 16416+ * You may obtain a copy of the License at 16417+ * 16418+ * http://www.apache.org/licenses/LICENSE-2.0 16419+ * 16420+ * Unless required by applicable law or agreed to in writing, software 16421+ * distributed under the License is distributed on an "AS IS" BASIS, 16422+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16423+ * See the License for the specific language governing permissions and 16424+ * limitations under the License. 
16425+ */ 16426+ 16427+#include "coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h" 16428+#include "coder/opcoders/file_collector.h" 16429+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16430+#include "coder/utils/coder_utils.h" 16431+ 16432+using mindspore::schema::PrimitiveType_SliceFusion; 16433+ 16434+namespace mindspore::lite::micro::nnacl { 16435+int SliceDynamicFP16Coder::Prepare(CoderContext *const context) { 16436+ CHECK_LESS_RETURN(input_tensors_.size(), C3NUM); 16437+ CHECK_LESS_RETURN(output_tensors_.size(), 1); 16438+ CHECK_NULL_RETURN(input_tensors_[FIRST_INPUT]); 16439+ CHECK_NULL_RETURN(input_tensors_[SECOND_INPUT]); 16440+ CHECK_NULL_RETURN(input_tensors_[THIRD_INPUT]); 16441+ CHECK_NULL_RETURN(output_tensor_); 16442+ param_ = reinterpret_cast<SliceParameter *>(parameter_); 16443+ CHECK_NULL_RETURN(param_); 16444+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->IsConst() && input_tensors_[THIRD_INPUT]->IsConst(), RET_NOT_SUPPORT, 16445+ "The second and third input of slice is non-const."); 16446+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt32 && 16447+ input_tensors_[THIRD_INPUT]->data_type() == kNumberTypeInt32, 16448+ RET_INPUT_PARAM_INVALID, "second or third input tensor data type need to be int32."); 16449+ if (input_tensor_->data_type() != kNumberTypeFloat16 || output_tensor_->data_type() != kNumberTypeFloat16) { 16450+ MS_LOG(ERROR) << "Tensor data type is invalid"; 16451+ return lite::RET_INPUT_PARAM_INVALID; 16452+ } 16453+ return Init(); 16454+} 16455+ 16456+int SliceDynamicFP16Coder::DoCode(CoderContext *const context) { 16457+ Collect(context, 16458+ { 16459+ "nnacl/base/slice_base.h", 16460+ }, 16461+ { 16462+ "slice_base.c", 16463+ }); 16464+ NNaclFp32Serializer code; 16465+ code.CodeStruct("slice_param", *param_, dynamic_param_); 16466+ std::string input_data = GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 16467+ std::string 
output_data = GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 16468+ if (!support_parallel_) { 16469+ code.CodeFunction("DoSliceNoParallel", input_data, output_data, "&slice_param", 16470+ DataTypeSize(input_tensor_->data_type())); 16471+ } 16472+ context->AppendCode(code.str()); 16473+ return NNACL_OK; 16474+} 16475+ 16476+int SliceDynamicFP16Coder::Init() { 16477+ auto begin_tensor = input_tensors_[SECOND_INPUT]; 16478+ auto size_tensor = input_tensors_[THIRD_INPUT]; 16479+ data_shape_ = shape_info_container_->GetTemplateShape(input_tensor_); 16480+ MS_CHECK_TRUE_MSG(data_shape_.size() == static_cast<size_t>(begin_tensor->ElementsNum()), RET_ERROR, 16481+ "The begin tensor is invalid."); 16482+ MS_CHECK_TRUE_MSG(data_shape_.size() == static_cast<size_t>(size_tensor->ElementsNum()), RET_ERROR, 16483+ "The size tensor is invalid."); 16484+ auto begin = reinterpret_cast<int32_t *>(begin_tensor->data()); 16485+ CHECK_NULL_RETURN(begin); 16486+ auto size = reinterpret_cast<int32_t *>(size_tensor->data()); 16487+ CHECK_NULL_RETURN(size); 16488+ param_->param_length_ = static_cast<int>(data_shape_.size()); 16489+ if (param_->param_length_ > DIMENSION_8D) { 16490+ MS_LOG(ERROR) << "input dimension num should <= " << DIMENSION_8D; 16491+ return RET_ERROR; 16492+ } 16493+ dynamic_param_.shape_ = "{"; 16494+ dynamic_param_.size_ = "{"; 16495+ dynamic_param_.end_ = "{"; 16496+ for (int i = 0; i < param_->param_length_; ++i) { 16497+ dynamic_param_.shape_ += data_shape_[i] + ", "; 16498+ param_->begin_[i] = begin[i]; 16499+ if (size[i] < 0) { 16500+ std::string cur_size = data_shape_[i] + " - " + std::to_string(begin[i]); 16501+ slice_size_.emplace_back(cur_size); 16502+ dynamic_param_.size_ += cur_size + ", "; 16503+ } else { 16504+ slice_size_.emplace_back(std::to_string(size[i])); 16505+ dynamic_param_.size_ += std::to_string(size[i]) + ", "; 16506+ } 16507+ std::string cur_end = std::to_string(param_->begin_[i]) + " + " + 
slice_size_[i]; 16508+ end_.emplace_back(cur_end); 16509+ dynamic_param_.end_ += cur_end + ", "; 16510+ } 16511+ dynamic_param_.shape_ += "}"; 16512+ dynamic_param_.size_ += "}"; 16513+ dynamic_param_.end_ += "}"; 16514+ if (param_->param_length_ < DIMENSION_8D) { 16515+ PadSliceParameterTo8D(); 16516+ } 16517+ return RET_OK; 16518+} 16519+ 16520+void SliceDynamicFP16Coder::PadSliceParameterTo8D() { 16521+ std::vector<int32_t> begin(DIMENSION_8D, 0); 16522+ std::vector<std::string> end(DIMENSION_8D, ""); 16523+ std::vector<std::string> slice_size(DIMENSION_8D, ""); 16524+ std::vector<std::string> data_shape(DIMENSION_8D, ""); 16525+ for (int32_t i = 0; i < param_->param_length_; ++i) { 16526+ begin[i] = param_->begin_[i]; 16527+ end[i] = end_[i]; 16528+ slice_size[i] = 16529+ slice_size_[i] + " < 0 ? " + data_shape[i] + " - " + std::to_string(begin[i]) + " : " + slice_size_[i]; 16530+ data_shape[i] = data_shape_[i]; 16531+ } 16532+ data_shape_.resize(DIMENSION_8D); 16533+ slice_size_.resize(DIMENSION_8D); 16534+ end_.resize(DIMENSION_8D); 16535+ int32_t real_index = param_->param_length_ - 1; 16536+ for (int32_t i = DIMENSION_8D - 1; i >= 0; --i) { 16537+ if (real_index >= 0) { 16538+ param_->begin_[i] = begin[real_index]; 16539+ end_[i] = end[real_index]; 16540+ slice_size_[i] = slice_size[real_index]; 16541+ data_shape_[i] = data_shape[real_index--]; 16542+ } else { 16543+ param_->begin_[i] = 0; 16544+ end_[i] = "1"; 16545+ slice_size_[i] = "1"; 16546+ data_shape_[i] = "1"; 16547+ } 16548+ } 16549+ param_->param_length_ = DIMENSION_8D; 16550+ dynamic_param_.shape_.clear(); 16551+ dynamic_param_.size_.clear(); 16552+ dynamic_param_.end_.clear(); 16553+ dynamic_param_.shape_ = "{"; 16554+ dynamic_param_.size_ = "{"; 16555+ dynamic_param_.end_ = "{"; 16556+ for (int i = 0; i < DIMENSION_8D; ++i) { 16557+ dynamic_param_.end_ += end_[i] + ", "; 16558+ dynamic_param_.size_ += slice_size_[i] + ", "; 16559+ dynamic_param_.shape_ += data_shape_[i] + ", "; 16560+ } 16561+ 
dynamic_param_.shape_ += "}"; 16562+ dynamic_param_.size_ += "}"; 16563+ dynamic_param_.end_ += "}"; 16564+} 16565+ 16566+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_SliceFusion, 16567+ CPUOpCoderCreator<SliceDynamicFP16Coder>) 16568+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_SliceFusion, 16569+ CPUOpCoderCreator<SliceDynamicFP16Coder>) 16570+}; // namespace mindspore::lite::micro::nnacl 16571diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h 16572new file mode 100644 16573index 00000000..21b1b27b 16574--- /dev/null 16575+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h 16576@@ -0,0 +1,51 @@ 16577+/** 16578+ * Copyright 2023 Huawei Technologies Co., Ltd 16579+ * 16580+ * Licensed under the Apache License, Version 2.0 (the "License"); 16581+ * you may not use this file except in compliance with the License. 16582+ * You may obtain a copy of the License at 16583+ * 16584+ * http://www.apache.org/licenses/LICENSE-2.0 16585+ * 16586+ * Unless required by applicable law or agreed to in writing, software 16587+ * distributed under the License is distributed on an "AS IS" BASIS, 16588+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16589+ * See the License for the specific language governing permissions and 16590+ * limitations under the License. 
16591+ */ 16592+ 16593+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SLICE_DYNAMIC_FP16_CODER_H_ 16594+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SLICE_DYNAMIC_FP16_CODER_H_ 16595+ 16596+#include <vector> 16597+#include "mindspore/lite/tools/converter/micro/coder/opcoders/op_coder.h" 16598+#include "coder/opcoders/nnacl/dynamic_parameter/slice_dynamic_parameter.h" 16599+#include "nnacl/slice_parameter.h" 16600+#include "nnacl/op_base.h" 16601+ 16602+namespace mindspore::lite::micro::nnacl { 16603+class SliceDynamicFP16Coder final : public OperatorCoder { 16604+ public: 16605+ SliceDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16606+ const LiteGraph::Node *node, size_t node_index, Target target) 16607+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 16608+ 16609+ ~SliceDynamicFP16Coder() override = default; 16610+ 16611+ int Prepare(CoderContext *const context) override; 16612+ 16613+ int DoCode(CoderContext *const context) override; 16614+ 16615+ protected: 16616+ int Init(); 16617+ void PadSliceParameterTo8D(); 16618+ SliceParameter *param_{nullptr}; 16619+ SliceDynamicParameter dynamic_param_; 16620+ std::vector<std::string> in_shapes_; 16621+ std::vector<std::string> out_shapes_; 16622+ std::vector<std::string> data_shape_; 16623+ std::vector<std::string> slice_size_; 16624+ std::vector<std::string> end_; 16625+}; 16626+}; // namespace mindspore::lite::micro::nnacl 16627+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SLICE_DYNAMIC_FP16_CODER_H_ 16628diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc 16629new file mode 100644 16630index 00000000..1bd09fb5 16631--- /dev/null 16632+++ 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc 16633@@ -0,0 +1,137 @@ 16634+/** 16635+ * Copyright 2023 Huawei Technologies Co., Ltd 16636+ * 16637+ * Licensed under the Apache License, Version 2.0 (the "License"); 16638+ * you may not use this file except in compliance with the License. 16639+ * You may obtain a copy of the License at 16640+ * 16641+ * http://www.apache.org/licenses/LICENSE-2.0 16642+ * 16643+ * Unless required by applicable law or agreed to in writing, software 16644+ * distributed under the License is distributed on an "AS IS" BASIS, 16645+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16646+ * See the License for the specific language governing permissions and 16647+ * limitations under the License. 16648+ */ 16649+ 16650+#include "coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h" 16651+#include <string> 16652+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16653+#include "schema/inner/ops_generated.h" 16654+#include "coder/opcoders/file_collector.h" 16655+#include "coder/utils/coder_utils.h" 16656+#include "tools/common/string_util.h" 16657+#include "base/float16.h" 16658+ 16659+using mindspore::schema::PrimitiveType_LogSoftmax; 16660+using mindspore::schema::PrimitiveType_Softmax; 16661+ 16662+namespace mindspore::lite::micro::nnacl { 16663+int SoftmaxDynamicFP16Coder::Prepare(CoderContext *const context) { 16664+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 16665+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 16666+ "Input tensor data type is invalid"); 16667+ } 16668+ for (size_t i = 0; i < output_tensors_.size(); ++i) { 16669+ MS_CHECK_TRUE_MSG(output_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 16670+ "Output tensor data type is invalid"); 16671+ } 16672+ auto ret = Init(); 16673+ MS_CHECK_RET_CODE(ret, "Init failed!"); 16674+ return RET_OK; 
16675+} 16676+ 16677+int SoftmaxDynamicFP16Coder::DoCode(CoderContext *const context) { 16678+ Collect(context, 16679+ { 16680+ "nnacl/fp16/softmax_fp16.h", 16681+ "nnacl/fp16/log_softmax_fp16.h", 16682+ }, 16683+ { 16684+ "softmax_fp16.c", 16685+ "log_softmax_fp16.c", 16686+ "exp_fp16.c", 16687+ }); 16688+ 16689+ auto ret = ComputeWorkSpace(); 16690+ MS_CHECK_RET_CODE(ret, "ComputeWorkSpace failed!"); 16691+ NNaclFp32Serializer code; 16692+ sum_data_str_ = "(float16_t *)(" + buffer_start_ + ")"; 16693+ auto primitive_type = param_->op_parameter_.type_; 16694+ std::string input_data = 16695+ "(float16_t *)(" + GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16696+ std::string output_data = 16697+ "(float16_t *)(" + GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16698+ code << " int input_shape[" << input_shape_.size() << "] = " << dynamic_param_.input_shape_ << ";\n"; 16699+ if (primitive_type == schema::PrimitiveType_Softmax) { 16700+ code.CodeFunction("SoftmaxFp16", input_data, output_data, sum_data_str_, softmax_struct_.axis_, 16701+ softmax_struct_.n_dim_, "&input_shape"); 16702+ } else { 16703+ code.CodeFunction("LogSoftmaxFp16", input_data, output_data, sum_data_str_, "&input_shape", softmax_struct_.n_dim_, 16704+ softmax_struct_.axis_); 16705+ } 16706+ context->AppendCode(code.str()); 16707+ return RET_OK; 16708+} 16709+ 16710+int SoftmaxDynamicFP16Coder::Init() { 16711+ param_ = reinterpret_cast<SoftmaxParameter *>(parameter_); 16712+ MS_CHECK_PTR(param_); 16713+ softmax_struct_.base_.param_ = parameter_; 16714+ input_shape_ = shape_info_container_->GetTemplateShape(input_tensor_); 16715+ size_t in_dims = input_shape_.size(); 16716+ softmax_struct_.n_dim_ = in_dims; 16717+ softmax_struct_.axis_ = param_->axis_ < 0 ? 
param_->axis_ + softmax_struct_.n_dim_ : param_->axis_; 16718+ dynamic_param_.element_size_ = AccumulateShape(input_shape_, 0, input_shape_.size()); 16719+ dynamic_param_.input_shape_ = "{"; 16720+ for (size_t i = 0; i < input_shape_.size(); ++i) { 16721+ dynamic_param_.input_shape_ += input_shape_[i] + ", "; 16722+ } 16723+ dynamic_param_.input_shape_ += "}"; 16724+ return RET_OK; 16725+} 16726+ 16727+int SoftmaxDynamicFP16Coder::ComputeWorkSpace() { 16728+ std::map<std::string, std::vector<int>> real_nums; 16729+ size_t scene_num = 0; 16730+ for (auto &dim_template : input_shape_) { 16731+ auto dim_nums = shape_info_container_->GetRealNums(dim_template); 16732+ MS_CHECK_TRUE_MSG(!dim_nums.empty(), RET_ERROR, "Dynamic shape's num must be greater than 0."); 16733+ real_nums[dim_template] = dim_nums; 16734+ scene_num = std::max(scene_num, dim_nums.size()); 16735+ } 16736+ for (size_t i = 0; i < scene_num; ++i) { 16737+ std::vector<int> real_shape(input_shape_.size()); 16738+ for (size_t j = 0; j < input_shape_.size(); ++j) { 16739+ if (IsNumber(input_shape_[j])) { 16740+ real_shape[j] = std::stoi(input_shape_[j]); 16741+ } else { 16742+ real_shape[j] = real_nums[input_shape_[j]][i % real_nums[input_shape_[j]].size()]; 16743+ } 16744+ } 16745+ int out_plane_size = 1; 16746+ for (int j = 0; j < softmax_struct_.axis_; ++j) { 16747+ MS_CHECK_INT_MUL_NOT_OVERFLOW(out_plane_size, real_shape[j], RET_ERROR); 16748+ out_plane_size *= real_shape[j]; 16749+ } 16750+ int in_plane_size = 1; 16751+ for (int j = softmax_struct_.axis_ + 1; j < softmax_struct_.n_dim_; ++j) { 16752+ MS_CHECK_INT_MUL_NOT_OVERFLOW(in_plane_size, real_shape[j], RET_ERROR); 16753+ in_plane_size *= real_shape[j]; 16754+ } 16755+ int workspace = out_plane_size * in_plane_size * sizeof(float16); 16756+ buffer_start_ = dynamic_mem_manager_->AllocWorkSpace(workspace, i); 16757+ MS_CHECK_TRUE_MSG(!buffer_start_.empty(), RET_ERROR, "Softmax cannot alloc workspace."); 16758+ } 16759+ return RET_OK; 16760+} 
16761+ 16762+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Softmax, 16763+ CPUOpCoderCreator<SoftmaxDynamicFP16Coder>) 16764+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Softmax, 16765+ CPUOpCoderCreator<SoftmaxDynamicFP16Coder>) 16766+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_LogSoftmax, 16767+ CPUOpCoderCreator<SoftmaxDynamicFP16Coder>) 16768+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_LogSoftmax, 16769+ CPUOpCoderCreator<SoftmaxDynamicFP16Coder>) 16770+} // namespace mindspore::lite::micro::nnacl 16771diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h 16772new file mode 100644 16773index 00000000..913f5ad4 16774--- /dev/null 16775+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h 16776@@ -0,0 +1,50 @@ 16777+/** 16778+ * Copyright 2023 Huawei Technologies Co., Ltd 16779+ * 16780+ * Licensed under the Apache License, Version 2.0 (the "License"); 16781+ * you may not use this file except in compliance with the License. 16782+ * You may obtain a copy of the License at 16783+ * 16784+ * http://www.apache.org/licenses/LICENSE-2.0 16785+ * 16786+ * Unless required by applicable law or agreed to in writing, software 16787+ * distributed under the License is distributed on an "AS IS" BASIS, 16788+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16789+ * See the License for the specific language governing permissions and 16790+ * limitations under the License. 
16791+ */ 16792+ 16793+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SOFTMAX_DYNAMIC_FP16_CODER_H_ 16794+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SOFTMAX_DYNAMIC_FP16_CODER_H_ 16795+ 16796+#include <vector> 16797+#include <string> 16798+#include "coder/opcoders/op_coder.h" 16799+#include "coder/opcoders/nnacl/dynamic_parameter/softmax_dynamic_parameter.h" 16800+#include "nnacl/softmax_parameter.h" 16801+#include "nnacl/kernel/softmax.h" 16802+ 16803+namespace mindspore::lite::micro::nnacl { 16804+class SoftmaxDynamicFP16Coder final : public OperatorCoder { 16805+ public: 16806+ SoftmaxDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16807+ const LiteGraph::Node *node, size_t node_index, Target target) 16808+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 16809+ ~SoftmaxDynamicFP16Coder() override = default; 16810+ 16811+ int Prepare(CoderContext *const context) override; 16812+ 16813+ int DoCode(CoderContext *const context) override; 16814+ 16815+ private: 16816+ int Init(); 16817+ int ComputeWorkSpace(); 16818+ SoftmaxParameter *param_{nullptr}; 16819+ SoftmaxStruct softmax_struct_; 16820+ SoftmaxDynamicParameter dynamic_param_; 16821+ std::vector<std::string> input_shape_; 16822+ std::string buffer_start_; 16823+ std::string sum_data_str_; 16824+}; 16825+} // namespace mindspore::lite::micro::nnacl 16826+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SOFTMAX_DYNAMIC_FP16_CODER_H_ 16827diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc 16828new file mode 100644 16829index 00000000..59c8d8b8 16830--- /dev/null 16831+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc 16832@@ -0,0 +1,76 @@ 16833+/** 16834+ * 
Copyright 2023 Huawei Technologies Co., Ltd 16835+ * 16836+ * Licensed under the Apache License, Version 2.0 (the "License"); 16837+ * you may not use this file except in compliance with the License. 16838+ * You may obtain a copy of the License at 16839+ * 16840+ * http://www.apache.org/licenses/LICENSE-2.0 16841+ * 16842+ * Unless required by applicable law or agreed to in writing, software 16843+ * distributed under the License is distributed on an "AS IS" BASIS, 16844+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16845+ * See the License for the specific language governing permissions and 16846+ * limitations under the License. 16847+ */ 16848+ 16849+#include "coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h" 16850+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16851+#include "coder/opcoders/file_collector.h" 16852+#include "coder/opcoders/parallel.h" 16853+#include "coder/utils/coder_utils.h" 16854+ 16855+using mindspore::schema::PrimitiveType_Transpose; 16856+namespace mindspore::lite::micro::nnacl { 16857+int TransposeDynamicFp16Coder::Prepare(CoderContext *const context) { 16858+ MS_CHECK_TRUE_MSG(input_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 16859+ "Input tensor data type is invalid."); 16860+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt32, RET_INPUT_PARAM_INVALID, 16861+ "Perm tensor data type is invalid."); 16862+ MS_CHECK_TRUE_MSG( 16863+ output_tensor_->data_type() == kNumberTypeInt32 || output_tensor_->data_type() == kNumberTypeFloat16, 16864+ RET_INPUT_PARAM_INVALID, "Output tensor data type is invalid."); 16865+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->IsConst(), RET_NOT_SUPPORT, 16866+ "The second input of transpose is non-const."); 16867+ thread_num_ = 1; 16868+ MS_CHECK_RET_CODE(Init(), "init failed"); 16869+ return RET_OK; 16870+} 16871+ 16872+int TransposeDynamicFp16Coder::DoCode(CoderContext *const context) 
{ 16873+ Collect(context, 16874+ { 16875+ "nnacl/transpose_parameter.h", 16876+ "nnacl/errorcode.h", 16877+ "nnacl/fp16/transpose_fp16.h", 16878+ }, 16879+ { 16880+ "transpose_fp16.c", 16881+ }); 16882+ 16883+ NNaclFp32Serializer code; 16884+ dims_ = static_cast<int>(out_shapes_.size()); 16885+ code << "const int32_t output_shape[" << dims_ << "] = {"; 16886+ for (size_t i = 0; i < out_shapes_.size(); ++i) { 16887+ code << out_shapes_[i] << ", "; 16888+ } 16889+ code << "};\n"; 16890+ code.CodeStruct("trans_param", *param_, dynamic_param_); 16891+ auto input_str = dynamic_mem_manager_->GetVarTensorAddr(input_tensor_); 16892+ auto output_str = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_); 16893+ if (param_->num_axes_ > DIMENSION_6D) { 16894+ code.CodeFunction("TransposeDimsFp16", input_str, output_str, "output_shape", "trans_param.perm_", 16895+ "trans_param.strides_", "trans_param.out_strides_", "trans_param.num_axes_", kDefaultTaskId, 16896+ kDefaultThreadNum); 16897+ } else { 16898+ code.CodeFunction("DoTransposeFp16", input_str, output_str, "output_shape", "trans_param.perm_", 16899+ "trans_param.strides_", "trans_param.out_strides_", "trans_param.data_num_", 16900+ "trans_param.num_axes_"); 16901+ } 16902+ context->AppendCode(code.str()); 16903+ return RET_OK; 16904+} 16905+ 16906+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Transpose, 16907+ CPUOpCoderCreator<TransposeDynamicFp16Coder>) 16908+} // namespace mindspore::lite::micro::nnacl 16909diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h 16910new file mode 100644 16911index 00000000..e008a794 16912--- /dev/null 16913+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h 16914@@ -0,0 +1,37 @@ 16915+/** 16916+ * Copyright 2023 Huawei Technologies Co., Ltd 16917+ * 16918+ * Licensed 
under the Apache License, Version 2.0 (the "License"); 16919+ * you may not use this file except in compliance with the License. 16920+ * You may obtain a copy of the License at 16921+ * 16922+ * http://www.apache.org/licenses/LICENSE-2.0 16923+ * 16924+ * Unless required by applicable law or agreed to in writing, software 16925+ * distributed under the License is distributed on an "AS IS" BASIS, 16926+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16927+ * See the License for the specific language governing permissions and 16928+ * limitations under the License. 16929+ */ 16930+ 16931+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_TRANSPOSE_DYNAMIC_FP16_CODER_H_ 16932+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_TRANSPOSE_DYNAMIC_FP16_CODER_H_ 16933+#include <vector> 16934+#include <string> 16935+#include "coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h" 16936+ 16937+namespace mindspore::lite::micro::nnacl { 16938+class TransposeDynamicFp16Coder : public TransposeDynamicFp32Coder { 16939+ public: 16940+ TransposeDynamicFp16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16941+ const LiteGraph::Node *node, size_t node_index, Target target) 16942+ : TransposeDynamicFp32Coder(in_tensors, out_tensors, node, node_index, target) {} 16943+ 16944+ ~TransposeDynamicFp16Coder() override = default; 16945+ 16946+ int Prepare(CoderContext *const context) override; 16947+ 16948+ int DoCode(CoderContext *const context) override; 16949+}; 16950+} // namespace mindspore::lite::micro::nnacl 16951+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_TRANSPOSE_DYNAMIC_FP16_CODER_H_ 16952diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.cc 16953new file mode 100644 16954index 
00000000..1dd33bbd 16955--- /dev/null 16956+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.cc 16957@@ -0,0 +1,112 @@ 16958+/** 16959+ * Copyright 2023 Huawei Technologies Co., Ltd 16960+ * 16961+ * Licensed under the Apache License, Version 2.0 (the "License"); 16962+ * you may not use this file except in compliance with the License. 16963+ * You may obtain a copy of the License at 16964+ * 16965+ * http://www.apache.org/licenses/LICENSE-2.0 16966+ * 16967+ * Unless required by applicable law or agreed to in writing, software 16968+ * distributed under the License is distributed on an "AS IS" BASIS, 16969+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16970+ * See the License for the specific language governing permissions and 16971+ * limitations under the License. 16972+ */ 16973+#include "coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.h" 16974+#include <string> 16975+#include "nnacl/fp32/activation_fp32.h" 16976+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16977+#include "coder/opcoders/file_collector.h" 16978+#include "coder/opcoders/parallel.h" 16979+#include "tools/common/string_util.h" 16980+#include "coder/utils/coder_utils.h" 16981+ 16982+using mindspore::schema::PrimitiveType_Activation; 16983+ 16984+namespace mindspore::lite::micro::nnacl { 16985+int ActivationDynamicFP32Coder::Preprocess() { 16986+ // attribute 16987+ auto in_shape = shape_info_container_->GetTemplateShape(input_tensor_); 16988+ int64_t const_part = 1; 16989+ std::string non_const_part; 16990+ for (const auto &item : in_shape) { 16991+ if (IsNumber(item)) { 16992+ const_part *= std::atoi(item.c_str()); 16993+ } else { 16994+ if (!non_const_part.empty()) { 16995+ non_const_part += " * "; 16996+ } 16997+ non_const_part += item; 16998+ } 16999+ } 17000+ count_ = std::to_string(const_part) + " * " + non_const_part; 17001+ input_data_ = 
dynamic_mem_manager_->GetVarTensorAddr(input_tensor_); 17002+ MS_CHECK_TRUE_MSG(!input_data_.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 17003+ output_data_ = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_); 17004+ MS_CHECK_TRUE_MSG(!output_data_.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 17005+ return RET_OK; 17006+} 17007+ 17008+int ActivationDynamicFP32Coder::DoCode(CoderContext *const context) { 17009+ Collect(context, 17010+ { 17011+ "wrapper/fp32/activation_fp32_wrapper.h", 17012+ "nnacl/fp32/activation_fp32.h", 17013+ }, 17014+ { 17015+ "activation_fp32_wrapper.c", 17016+ "activation_fp32.c", 17017+ }); 17018+ NNaclFp32Serializer code; 17019+ auto *activation_parameter = reinterpret_cast<ActivationParameter *>(parameter_); 17020+ int ret = Preprocess(); 17021+ MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Preprocess failed"); 17022+ 17023+ switch (activation_parameter->type_) { 17024+ case schema::ActivationType_RELU: 17025+ code.CodeFunction("Fp32Relu", input_data_, count_, output_data_); 17026+ break; 17027+ case schema::ActivationType_RELU6: 17028+ code.CodeFunction("Fp32Relu6", input_data_, count_, output_data_); 17029+ break; 17030+ case schema::ActivationType_LEAKY_RELU: 17031+ code.CodeFunction("LRelu", input_data_, count_, output_data_, activation_parameter->alpha_); 17032+ break; 17033+ case schema::ActivationType_SIGMOID: 17034+ if (!support_parallel_) { 17035+ code.CodeFunction("Sigmoid", input_data_, count_, output_data_); 17036+ } else { 17037+ code.CodeStruct("activation_param", *activation_parameter); 17038+ code.CodeBaseStruct("ActivationFp32Args", kRunArgs, input_data_, count_, output_data_, 0.0f, 17039+ "&activation_param"); 17040+ code.CodeFunction(kParallelLaunch, "DoSigmoid", kRunArgsAddr, "activation_param.op_parameter_.thread_num_"); 17041+ } 17042+ break; 17043+ case schema::ActivationType_TANH: 17044+ code.CodeFunction("Tanh", input_data_, count_, output_data_); 17045+ break; 17046+ 
case schema::ActivationType_HSWISH: 17047+ code.CodeFunction("HSwish", input_data_, count_, output_data_); 17048+ break; 17049+ case schema::ActivationType_SWISH: 17050+ code.CodeFunction("Swish", input_data_, count_, output_data_); 17051+ break; 17052+ case schema::ActivationType_HSIGMOID: 17053+ code.CodeFunction("HSigmoid", input_data_, count_, output_data_); 17054+ break; 17055+ case schema::ActivationType_ELU: 17056+ code.CodeFunction("Elu", input_data_, count_, output_data_, activation_parameter->alpha_); 17057+ break; 17058+ default: 17059+ MS_LOG(ERROR) << "Activation type error"; 17060+ return RET_ERROR; 17061+ } 17062+ MS_LOG(DEBUG) << "ActivationFP32Code has been called"; 17063+ context->AppendCode(code.str()); 17064+ return lite::RET_OK; 17065+} 17066+ 17067+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Activation, 17068+ CPUOpCoderCreator<ActivationDynamicFP32Coder>) 17069+} // namespace mindspore::lite::micro::nnacl 17070diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.h 17071new file mode 100644 17072index 00000000..1560afbb 17073--- /dev/null 17074+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.h 17075@@ -0,0 +1,46 @@ 17076+/** 17077+ * Copyright 2023 Huawei Technologies Co., Ltd 17078+ * 17079+ * Licensed under the Apache License, Version 2.0 (the "License"); 17080+ * you may not use this file except in compliance with the License. 17081+ * You may obtain a copy of the License at 17082+ * 17083+ * http://www.apache.org/licenses/LICENSE-2.0 17084+ * 17085+ * Unless required by applicable law or agreed to in writing, software 17086+ * distributed under the License is distributed on an "AS IS" BASIS, 17087+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
17088+ * See the License for the specific language governing permissions and 17089+ * limitations under the License. 17090+ */ 17091+ 17092+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_ACTIVATION_DYNAMIC_FP32_CODER_H_ 17093+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_ACTIVATION_DYNAMIC_FP32_CODER_H_ 17094+ 17095+#include <string> 17096+#include <vector> 17097+#include "tools/converter/micro/coder/opcoders/op_coder.h" 17098+#include "tools/converter/micro/coder/shape_info_container.h" 17099+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 17100+ 17101+namespace mindspore::lite::micro::nnacl { 17102+class ActivationDynamicFP32Coder : public OperatorCoder { 17103+ public: 17104+ ActivationDynamicFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 17105+ const LiteGraph::Node *node, size_t node_index, Target target) 17106+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 17107+ 17108+ ~ActivationDynamicFP32Coder() override = default; 17109+ 17110+ int Prepare(CoderContext *const context) override { return RET_OK; } 17111+ 17112+ int DoCode(CoderContext *const context) override; 17113+ 17114+ protected: 17115+ int Preprocess(); 17116+ std::string count_; 17117+ std::string input_data_; 17118+ std::string output_data_; 17119+}; 17120+} // namespace mindspore::lite::micro::nnacl 17121+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_ACTIVATION_DYNAMIC_FP32_CODER_H_ 17122diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc 17123index c15d3101..1b827283 100644 17124--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc 17125+++ 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc 17126@@ -266,7 +266,6 @@ void ConvolutionWinogradFP32Coder::CollectFilesForFunc(CoderContext *const conte 17127 } else if (target_ == kARM64) { 17128 Collect(context, {}, {}, 17129 { 17130- "BigMatmulFp32Opt.S", 17131 "MatmulFp32.S", 17132 "MatmulFp32Opt.S", 17133 "PreSum4x16Int8Peroc.S", 17134diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc 17135new file mode 100644 17136index 00000000..57d7a5dd 17137--- /dev/null 17138+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc 17139@@ -0,0 +1,106 @@ 17140+/** 17141+ * Copyright 2021-2022 Huawei Technologies Co., Ltd 17142+ * 17143+ * Licensed under the Apache License, Version 2.0 (the "License"); 17144+ * you may not use this file except in compliance with the License. 17145+ * You may obtain a copy of the License at 17146+ * 17147+ * http://www.apache.org/licenses/LICENSE-2.0 17148+ * 17149+ * Unless required by applicable law or agreed to in writing, software 17150+ * distributed under the License is distributed on an "AS IS" BASIS, 17151+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17152+ * See the License for the specific language governing permissions and 17153+ * limitations under the License. 
17154+ */ 17155+ 17156+#include "coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.h" 17157+#include <string> 17158+#include "nnacl/gather_parameter.h" 17159+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 17160+#include "coder/opcoders/file_collector.h" 17161+#include "coder/utils/coder_utils.h" 17162+#include "tools/common/string_util.h" 17163+ 17164+using mindspore::schema::PrimitiveType_Gather; 17165+ 17166+namespace mindspore::lite::micro::nnacl { 17167+int GatherDynamicFP32Coder::Prepare(CoderContext *const context) { 17168+ MS_CHECK_TRUE_MSG(input_tensors_.size() == C3NUM, RET_ERROR, "Gather's input-num must be 3."); 17169+ MS_CHECK_TRUE_MSG(input_tensors_[FIRST_INPUT]->IsConst() && input_tensors_[THIRD_INPUT]->IsConst(), RET_NOT_SUPPORT, 17170+ "Currently, only support the second input of gather is non-const when shape is dynamical."); 17171+ MS_CHECK_TRUE_MSG(input_tensors_[THIRD_INPUT]->data_type() == kNumberTypeInt32 || 17172+ input_tensors_[THIRD_INPUT]->data_type() == kNumberTypeInt, 17173+ RET_ERROR, "The data-type of Gather's third input must be int."); 17174+ auto axis = input_tensors_[THIRD_INPUT]->data(); 17175+ MS_CHECK_TRUE_MSG(axis != nullptr, RET_NULL_PTR, "Gather has no axis."); 17176+ axis_ = *(static_cast<int *>(axis)); 17177+ auto in_shape0 = input_tensors_[FIRST_INPUT]->shape(); 17178+ axis_ = axis_ >= 0 ? 
axis_ : axis_ + static_cast<int>(in_shape0.size()); 17179+ MS_CHECK_TRUE_MSG(axis_ >= 0 && axis_ < static_cast<int>(in_shape0.size()), RET_INPUT_TENSOR_ERROR, 17180+ "Gather's axis is out of range."); 17181+ return RET_OK; 17182+} 17183+ 17184+int GatherDynamicFP32Coder::DoCode(CoderContext *const context) { 17185+ Collect(context, 17186+ { 17187+ "nnacl/base/gather_base.h", 17188+ }, 17189+ { 17190+ "gather_base.c", 17191+ }); 17192+ auto in_shape0 = input_tensors_[FIRST_INPUT]->shape(); 17193+ auto data_item_size = static_cast<int>(lite::DataTypeSize(input_tensors_[FIRST_INPUT]->data_type())); 17194+ int64_t out_size = 1; 17195+ for (size_t i = 0; i < static_cast<size_t>(axis_); ++i) { 17196+ out_size *= in_shape0[i]; 17197+ } 17198+ int64_t byte_inner_size = data_item_size; 17199+ for (size_t i = axis_ + 1; i < in_shape0.size(); ++i) { 17200+ byte_inner_size *= in_shape0[i]; 17201+ } 17202+ int64_t limit = in_shape0[axis_]; 17203+ auto in_shape1 = shape_info_container_->GetTemplateShape(input_tensors_[SECOND_INPUT]); 17204+ int64_t const_part = 1; 17205+ std::string non_const_part; 17206+ for (const auto &item : in_shape1) { 17207+ if (IsNumber(item)) { 17208+ const_part *= std::stoi(item); 17209+ } else { 17210+ if (!non_const_part.empty()) { 17211+ non_const_part += " * "; 17212+ } 17213+ non_const_part += item; 17214+ } 17215+ } 17216+ std::string byte_out_stride_str = std::to_string(const_part * byte_inner_size); 17217+ std::string index_num_str = std::to_string(const_part); 17218+ if (!non_const_part.empty()) { 17219+ byte_out_stride_str += " * " + non_const_part; 17220+ index_num_str += " * " + non_const_part; 17221+ } 17222+ std::string input0_data = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[FIRST_INPUT], true); 17223+ MS_CHECK_TRUE_MSG(!input0_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 17224+ std::string input1_data = dynamic_mem_manager_->GetVarTensorAddr(input_tensors_[SECOND_INPUT]); 17225+ 
MS_CHECK_TRUE_MSG(!input1_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 17226+ std::string output_data = dynamic_mem_manager_->GetVarTensorAddr(output_tensors_[FIRST_INPUT]); 17227+ MS_CHECK_TRUE_MSG(!output_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 17228+ NNaclFp32Serializer code; 17229+ code << "\t\tconst int8_t *int8_in = (const int8_t *)(" << input0_data << ");\n"; 17230+ code << "\t\tconst int *index_data = (const int *)(" << input1_data << ");\n"; 17231+ code << "\t\tint8_t *int8_out = (int8_t *)(" << output_data << ");\n"; 17232+ // call the op function 17233+ code.CodeFunction("Gather", "int8_in", out_size, byte_inner_size, limit, "index_data", index_num_str, "int8_out", 17234+ byte_out_stride_str); 17235+ context->AppendCode(code.str()); 17236+ return RET_OK; 17237+} 17238+ 17239+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Gather, 17240+ CPUOpCoderCreator<GatherDynamicFP32Coder>) 17241+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Gather, 17242+ CPUOpCoderCreator<GatherDynamicFP32Coder>) 17243+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Gather, CPUOpCoderCreator<GatherDynamicFP32Coder>) 17244+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Gather, CPUOpCoderCreator<GatherDynamicFP32Coder>) 17245+} // namespace mindspore::lite::micro::nnacl 17246diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.h 17247new file mode 100644 17248index 00000000..9e58e1fa 17249--- /dev/null 17250+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.h 17251@@ -0,0 +1,42 @@ 17252+/** 17253+ * Copyright 2021 Huawei Technologies Co., Ltd 17254+ * 17255+ * Licensed under the Apache License, Version 2.0 (the "License"); 17256+ * you may not use 
this file except in compliance with the License. 17257+ * You may obtain a copy of the License at 17258+ * 17259+ * http://www.apache.org/licenses/LICENSE-2.0 17260+ * 17261+ * Unless required by applicable law or agreed to in writing, software 17262+ * distributed under the License is distributed on an "AS IS" BASIS, 17263+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17264+ * See the License for the specific language governing permissions and 17265+ * limitations under the License. 17266+ */ 17267+ 17268+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GATHER_DYNAMIC_FP32_CODER_H_ 17269+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GATHER_DYNAMIC_FP32_CODER_H_ 17270+ 17271+#include <string> 17272+#include <vector> 17273+#include "coder/opcoders/op_coder.h" 17274+#include "nnacl/base/tile_base.h" 17275+ 17276+namespace mindspore::lite::micro::nnacl { 17277+class GatherDynamicFP32Coder final : public OperatorCoder { 17278+ public: 17279+ GatherDynamicFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 17280+ const LiteGraph::Node *node, size_t node_index, Target target) 17281+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 17282+ 17283+ ~GatherDynamicFP32Coder() override = default; 17284+ 17285+ int Prepare(CoderContext *const context) override; 17286+ 17287+ int DoCode(CoderContext *const context) override; 17288+ 17289+ private: 17290+ int axis_{0}; 17291+}; 17292+} // namespace mindspore::lite::micro::nnacl 17293+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GATHER_DYNAMIC_FP32_CODER_H_ 17294diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc 17295new file mode 100644 17296index 00000000..4ec7f317 17297--- /dev/null 17298+++ 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc 17299@@ -0,0 +1,94 @@ 17300+/** 17301+ * Copyright 2022 Huawei Technologies Co., Ltd 17302+ * 17303+ * Licensed under the Apache License, Version 2.0 (the "License"); 17304+ * you may not use this file except in compliance with the License. 17305+ * You may obtain a copy of the License at 17306+ * 17307+ * http://www.apache.org/licenses/LICENSE-2.0 17308+ * 17309+ * Unless required by applicable law or agreed to in writing, software 17310+ * distributed under the License is distributed on an "AS IS" BASIS, 17311+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17312+ * See the License for the specific language governing permissions and 17313+ * limitations under the License. 17314+ */ 17315+#include "coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h" 17316+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 17317+#include "coder/opcoders/file_collector.h" 17318+#include "coder/opcoders/parallel.h" 17319+#include "coder/utils/coder_utils.h" 17320+#include "nnacl/op_base.h" 17321+ 17322+using mindspore::schema::PrimitiveType_Split; 17323+ 17324+namespace mindspore::lite::micro::nnacl { 17325+int SplitDynamicFP32Coder::Prepare(CoderContext *const context) { 17326+ auto input_shape = shape_info_container_->GetTemplateShape(input_tensor_); 17327+ int in_shape_size = static_cast<int>(input_shape.size()); 17328+ CHECK_LESS_RETURN(in_shape_size, 1); 17329+ CHECK_LESS_RETURN(SPLIT_STRIDES_SIZE - 1, in_shape_size); 17330+ param_ = reinterpret_cast<SplitParameter *>(parameter_); 17331+ CHECK_NULL_RETURN(param_); 17332+ 17333+ auto split_dim = param_->split_dim_; 17334+ param_->split_dim_ = split_dim >= 0 ? 
split_dim : in_shape_size + split_dim; 17335+ std::vector<std::string> strides(in_shape_size); 17336+ strides[in_shape_size - 1] = "1"; 17337+ for (int i = static_cast<int>(in_shape_size) - C2NUM; i >= 0; i--) { 17338+ strides[i] = strides[i + 1] + " * " + input_shape[i + 1]; 17339+ } 17340+ dynamic_param_.strides_ = "{"; 17341+ for (int i = 0; i < in_shape_size; ++i) { 17342+ dynamic_param_.strides_ += strides[i] + ", "; 17343+ } 17344+ dynamic_param_.strides_ += "}"; 17345+ CHECK_LESS_RETURN(in_shape_size, param_->split_dim_ + 1); 17346+ if (input_shape.at(param_->split_dim_) == "0") { 17347+ MS_LOG(ERROR) << "input_shape[" << param_->split_dim_ << "] must not be zero!"; 17348+ return RET_ERROR; 17349+ } 17350+ CHECK_LESS_RETURN(SPLIT_STRIDES_SIZE, param_->split_dim_ + 1); 17351+ if (strides[param_->split_dim_] == "0") { 17352+ MS_LOG(ERROR) << "strides[" << param_->split_dim_ << "] must not be zero!"; 17353+ return RET_ERROR; 17354+ } 17355+ dynamic_param_.split_count_ = strides[0] + " * " + input_shape[0] + " / (" + input_shape.at(param_->split_dim_) + 17356+ " * " + strides[param_->split_dim_] + ")"; 17357+ param_->n_dims_ = static_cast<int>(input_shape.size()); 17358+ CHECK_LESS_RETURN(param_->num_split_, 1); 17359+ MS_CHECK_TRUE_MSG(param_->split_sizes_[0] != 0 && param_->split_sizes_[param_->num_split_ - 1] != -1, 17360+ lite::RET_PARAM_INVALID, "Currently, split not support split_size 0 or -1"); 17361+ return RET_OK; 17362+} 17363+ 17364+int SplitDynamicFP32Coder::DoCode(CoderContext *const context) { 17365+ Collect(context, {"nnacl/base/split_base.h"}, {"split_base.c"}); 17366+ NNaclFp32Serializer code; 17367+ code << " void *output_ptrs[" << output_tensors_.size() << "] = {"; 17368+ for (int i = 0; i < param_->num_split_; i++) { 17369+ code << GetTensorAddr(output_tensors_.at(i), output_tensors_.at(i)->IsConst(), dynamic_mem_manager_, allocator_) 17370+ << ", "; 17371+ } 17372+ code << "};\n"; 17373+ auto input_shape = 
shape_info_container_->GetTemplateShape(input_tensor_); 17374+ code << " int input_dim[" << input_shape.size() << "] = {"; 17375+ for (auto &dim : input_shape) { 17376+ code << dim << ", "; 17377+ } 17378+ code << "};\n"; 17379+ std::string input_data = GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 17380+ std::string num_unit = dynamic_param_.split_count_ + " * " + std::to_string(param_->num_split_); 17381+ code.CodeStruct("split_param", *param_, dynamic_param_); 17382+ code.CodeFunction("DoSplit", input_data, "output_ptrs", "input_dim", "0", num_unit, "&split_param", 17383+ lite::DataTypeSize(input_tensor_->data_type())); 17384+ context->AppendCode(code.str()); 17385+ return RET_OK; 17386+} 17387+ 17388+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Split, 17389+ CPUOpCoderCreator<SplitDynamicFP32Coder>) 17390+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Split, CPUOpCoderCreator<SplitDynamicFP32Coder>) 17391+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Split, 17392+ CPUOpCoderCreator<SplitDynamicFP32Coder>) 17393+} // namespace mindspore::lite::micro::nnacl 17394diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h 17395new file mode 100644 17396index 00000000..e3e64cb3 17397--- /dev/null 17398+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h 17399@@ -0,0 +1,42 @@ 17400+/** 17401+ * Copyright 2023 Huawei Technologies Co., Ltd 17402+ * 17403+ * Licensed under the Apache License, Version 2.0 (the "License"); 17404+ * you may not use this file except in compliance with the License. 
17405+ * You may obtain a copy of the License at 17406+ * 17407+ * http://www.apache.org/licenses/LICENSE-2.0 17408+ * 17409+ * Unless required by applicable law or agreed to in writing, software 17410+ * distributed under the License is distributed on an "AS IS" BASIS, 17411+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17412+ * See the License for the specific language governing permissions and 17413+ * limitations under the License. 17414+ */ 17415+ 17416+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_SPLIT_DYNAMIC_FP32_CODER_H_ 17417+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_SPLIT_DYNAMIC_FP32_CODER_H_ 17418+ 17419+#include <vector> 17420+#include "coder/opcoders/op_coder.h" 17421+#include "coder/opcoders/nnacl/dynamic_parameter/split_dynamic_parameter.h" 17422+#include "nnacl/split_parameter.h" 17423+ 17424+namespace mindspore::lite::micro::nnacl { 17425+class SplitDynamicFP32Coder : public OperatorCoder { 17426+ public: 17427+ SplitDynamicFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 17428+ const LiteGraph::Node *node, size_t node_index, Target target) 17429+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 17430+ ~SplitDynamicFP32Coder() override = default; 17431+ 17432+ int Prepare(CoderContext *const context) override; 17433+ 17434+ int DoCode(CoderContext *const context) override; 17435+ 17436+ protected: 17437+ SplitParameter *param_{nullptr}; 17438+ SplitDynamicParameter dynamic_param_; 17439+}; 17440+} // namespace mindspore::lite::micro::nnacl 17441+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_SPLIT_DYNAMIC_FP32_CODER_H_ 17442diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc 17443new file mode 100644 17444index 
00000000..7fb160d5 17445--- /dev/null 17446+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc 17447@@ -0,0 +1,171 @@ 17448+/** 17449+ * Copyright 2023 Huawei Technologies Co., Ltd 17450+ * 17451+ * Licensed under the Apache License, Version 2.0 (the "License"); 17452+ * you may not use this file except in compliance with the License. 17453+ * You may obtain a copy of the License at 17454+ * 17455+ * http://www.apache.org/licenses/LICENSE-2.0 17456+ * 17457+ * Unless required by applicable law or agreed to in writing, software 17458+ * distributed under the License is distributed on an "AS IS" BASIS, 17459+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17460+ * See the License for the specific language governing permissions and 17461+ * limitations under the License. 17462+ */ 17463+ 17464+#include "coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h" 17465+#include <vector> 17466+#include <unordered_set> 17467+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 17468+#include "coder/opcoders/file_collector.h" 17469+#include "coder/opcoders/parallel.h" 17470+#include "coder/utils/coder_utils.h" 17471+ 17472+using mindspore::schema::PrimitiveType_Transpose; 17473+namespace mindspore::lite::micro::nnacl { 17474+int TransposeDynamicFp32Coder::Prepare(CoderContext *const context) { 17475+ MS_CHECK_TRUE_MSG(input_tensor_->data_type() == kNumberTypeInt32 || input_tensor_->data_type() == kNumberTypeFloat32, 17476+ RET_INPUT_PARAM_INVALID, "Input tensor data type is invalid."); 17477+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt32, RET_INPUT_PARAM_INVALID, 17478+ "Perm tensor data type is invalid."); 17479+ MS_CHECK_TRUE_MSG( 17480+ output_tensor_->data_type() == kNumberTypeInt32 || output_tensor_->data_type() == kNumberTypeFloat32, 17481+ RET_INPUT_PARAM_INVALID, "Output tensor data type is invalid."); 17482+ 
MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->IsConst(), RET_NOT_SUPPORT, 17483+ "The second input of transpose is non-const."); 17484+ thread_num_ = 1; 17485+ MS_CHECK_RET_CODE(Init(), "init failed"); 17486+ return RET_OK; 17487+} 17488+ 17489+int TransposeDynamicFp32Coder::DoCode(CoderContext *const context) { 17490+ Collect(context, 17491+ { 17492+ "nnacl/transpose_parameter.h", 17493+ "nnacl/errorcode.h", 17494+ "nnacl/fp32/transpose_fp32.h", 17495+ }, 17496+ { 17497+ "transpose_fp32.c", 17498+ }); 17499+ 17500+ NNaclFp32Serializer code; 17501+ dims_ = static_cast<int>(out_shapes_.size()); 17502+ code << "const int32_t output_shape[" << dims_ << "] = {"; 17503+ for (size_t i = 0; i < out_shapes_.size(); ++i) { 17504+ code << out_shapes_[i] << ", "; 17505+ } 17506+ code << "};\n"; 17507+ code.CodeStruct("trans_param", *param_, dynamic_param_); 17508+ auto input_str = dynamic_mem_manager_->GetVarTensorAddr(input_tensor_); 17509+ auto output_str = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_); 17510+ if (param_->num_axes_ > DIMENSION_6D) { 17511+ code.CodeFunction("TransposeDimsFp32", input_str, output_str, "output_shape", "trans_param.perm_", 17512+ "trans_param.strides_", "trans_param.out_strides_", "trans_param.num_axes_", kDefaultTaskId, 17513+ kDefaultThreadNum); 17514+ } else { 17515+ code.CodeFunction("DoTransposeFp32", input_str, output_str, "output_shape", "trans_param.perm_", 17516+ "trans_param.strides_", "trans_param.out_strides_", "trans_param.data_num_", 17517+ "trans_param.num_axes_"); 17518+ } 17519+ context->AppendCode(code.str()); 17520+ return RET_OK; 17521+} 17522+ 17523+int TransposeDynamicFp32Coder::Init() { 17524+ param_ = reinterpret_cast<TransposeParameter *>(parameter_); 17525+ MS_CHECK_PTR(param_); 17526+ param_->num_axes_ = 0; 17527+ if (input_tensors_.size() == C2NUM) { 17528+ param_->num_axes_ = input_tensors_[SECOND_INPUT]->ElementsNum(); 17529+ } 17530+ if (input_tensor_->shape().size() != 
static_cast<size_t>(param_->num_axes_)) { 17531+ return RET_OK; 17532+ } 17533+ // get perm data 17534+ auto ret = ResetStatus(); 17535+ if (ret != RET_OK) { 17536+ MS_LOG(ERROR) << "Do transpose reset failed."; 17537+ return ret; 17538+ } 17539+ 17540+ ret = ComputeOfflineInfo(); 17541+ if (ret != RET_OK) { 17542+ MS_LOG(ERROR) << "Do compute transpose offline info failed."; 17543+ return ret; 17544+ } 17545+ return RET_OK; 17546+} 17547+ 17548+int TransposeDynamicFp32Coder::ResetStatus() { 17549+ auto in_shape = shape_info_container_->GetTemplateShape(input_tensor_); 17550+ if (in_shape.size() > MAX_TRANSPOSE_DIM_SIZE) { 17551+ MS_LOG(ERROR) << "input shape out of range."; 17552+ return RET_ERROR; 17553+ } 17554+ int trans_nd[MAX_TRANSPOSE_DIM_SIZE] = {0, 2, 1}; 17555+ int *perm_data{nullptr}; 17556+ if (in_shape.size() != static_cast<size_t>(param_->num_axes_)) { 17557+ perm_data = trans_nd; 17558+ if (in_shape.size() == C3NUM && param_->num_axes_ == C4NUM) { 17559+ param_->num_axes_ = C3NUM; 17560+ } 17561+ if (param_->num_axes_ == 0) { 17562+ for (int i = 0; i < static_cast<int>(in_shape.size()); ++i) { 17563+ trans_nd[i] = static_cast<int>(in_shape.size()) - 1 - i; 17564+ } 17565+ param_->num_axes_ = static_cast<int>(in_shape.size()); 17566+ } 17567+ } else { 17568+ if (input_tensors_.size() != C2NUM) { 17569+ MS_LOG(ERROR) << "input tensors size is not equal to 2."; 17570+ return RET_ERROR; 17571+ } 17572+ auto perm_tensor = input_tensors_.at(SECOND_INPUT); 17573+ perm_data = reinterpret_cast<int *>(perm_tensor->data()); 17574+ MSLITE_CHECK_PTR(perm_data); 17575+ std::vector<int> perm(perm_data, perm_data + input_tensors_[SECOND_INPUT]->ElementsNum()); 17576+ if (perm.size() != std::unordered_set<int>(perm.cbegin(), perm.cend()).size()) { 17577+ MS_LOG(ERROR) << "Invalid perm, the same element exits in perm."; 17578+ return RET_ERROR; 17579+ } 17580+ } 17581+ MS_CHECK_TRUE_MSG(param_->num_axes_ <= MAX_TRANSPOSE_DIM_SIZE, RET_ERROR, "transpose perm is 
invalid."); 17582+ for (int i = 0; i < param_->num_axes_; ++i) { 17583+ param_->perm_[i] = perm_data[i]; 17584+ } 17585+ return RET_OK; 17586+} 17587+ 17588+int TransposeDynamicFp32Coder::ComputeOfflineInfo() { 17589+ in_shapes_ = shape_info_container_->GetTemplateShape(input_tensor_); 17590+ out_shapes_ = shape_info_container_->GetTemplateShape(output_tensor_); 17591+ const int ori_stride = 1; 17592+ dynamic_param_.strides_ = std::to_string(ori_stride) + ", "; 17593+ dynamic_param_.out_strides_ = std::to_string(ori_stride) + ", "; 17594+ dynamic_param_.data_num_ = AccumulateShape(in_shapes_, 0, in_shapes_.size()); 17595+ std::vector<std::string> strides(param_->num_axes_); 17596+ std::vector<std::string> out_strides(param_->num_axes_); 17597+ strides[param_->num_axes_ - 1] = "1"; 17598+ out_strides[param_->num_axes_ - 1] = "1"; 17599+ for (int i = param_->num_axes_ - C2NUM; i >= 0; --i) { 17600+ strides[i] = in_shapes_[i + 1] + " * " + strides[i + 1]; 17601+ out_strides[i] = out_shapes_[i + 1] + " * " + out_strides[i + 1]; 17602+ } 17603+ dynamic_param_.strides_ = "{"; 17604+ dynamic_param_.out_strides_ = "{"; 17605+ for (int i = 0; i < param_->num_axes_; ++i) { 17606+ dynamic_param_.strides_ += strides[i] + ", "; 17607+ dynamic_param_.out_strides_ += out_strides[i] + ", "; 17608+ } 17609+ dynamic_param_.strides_ += "}"; 17610+ dynamic_param_.out_strides_ += "}"; 17611+ return RET_OK; 17612+} 17613+ 17614+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat32, PrimitiveType_Transpose, 17615+ CPUOpCoderCreator<TransposeDynamicFp32Coder>) 17616+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeInt32, PrimitiveType_Transpose, 17617+ CPUOpCoderCreator<TransposeDynamicFp32Coder>) 17618+} // namespace mindspore::lite::micro::nnacl 17619diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h 17620new file mode 100644 17621index 
00000000..9230b8e3 17622--- /dev/null 17623+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h 17624@@ -0,0 +1,49 @@ 17625+/** 17626+ * Copyright 2023 Huawei Technologies Co., Ltd 17627+ * 17628+ * Licensed under the Apache License, Version 2.0 (the "License"); 17629+ * you may not use this file except in compliance with the License. 17630+ * You may obtain a copy of the License at 17631+ * 17632+ * http://www.apache.org/licenses/LICENSE-2.0 17633+ * 17634+ * Unless required by applicable law or agreed to in writing, software 17635+ * distributed under the License is distributed on an "AS IS" BASIS, 17636+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17637+ * See the License for the specific language governing permissions and 17638+ * limitations under the License. 17639+ */ 17640+ 17641+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_TRANSPOSE_DYNAMIC_FP32_CODER_H_ 17642+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_TRANSPOSE_DYNAMIC_FP32_CODER_H_ 17643+#include <vector> 17644+#include <string> 17645+#include "coder/opcoders/op_coder.h" 17646+#include "nnacl/transpose_parameter.h" 17647+#include "coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h" 17648+ 17649+namespace mindspore::lite::micro::nnacl { 17650+class TransposeDynamicFp32Coder : public OperatorCoder { 17651+ public: 17652+ TransposeDynamicFp32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 17653+ const LiteGraph::Node *node, size_t node_index, Target target) 17654+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 17655+ 17656+ ~TransposeDynamicFp32Coder() override = default; 17657+ 17658+ int Prepare(CoderContext *const context) override; 17659+ 17660+ int DoCode(CoderContext *const context) override; 17661+ 17662+ protected: 17663+ int Init(); 17664+ int ResetStatus(); 17665+ int 
ComputeOfflineInfo(); 17666+ TransposeParameter *param_{nullptr}; 17667+ TransposeDynamicParameter dynamic_param_; 17668+ int dims_{0}; 17669+ std::vector<std::string> in_shapes_; 17670+ std::vector<std::string> out_shapes_; 17671+}; 17672+} // namespace mindspore::lite::micro::nnacl 17673+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_TRANSPOSE_DYNAMIC_FP32_CODER_H_ 17674diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder.h 17675index dffaf14b..fa59e483 100644 17676--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder.h 17677+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder.h 17678@@ -28,6 +28,8 @@ 17679 #include "securec/include/securec.h" 17680 #include "tools/converter/micro/coder/opcoders/op_coder_register.h" 17681 #include "tools/converter/micro/coder/log.h" 17682+#include "tools/converter/micro/coder/shape_info_container.h" 17683+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 17684 17685 namespace mindspore::lite::micro { 17686 constexpr int kPrecision = 19; 17687@@ -71,6 +73,8 @@ class OperatorCoder { 17688 17689 void set_parameter(OpParameter *parameter); 17690 17691+ OpParameter *get_parameter() const { return parameter_; } 17692+ 17693 const LiteGraph::Node *node() const { return this->node_; } 17694 17695 void AddInitialParameters(Tensor *parameter) { initial_parameters_.push_back(parameter); } 17696@@ -88,6 +92,12 @@ class OperatorCoder { 17697 17698 void set_thread_num(int thread_num); 17699 17700+ void set_shape_info_container(ShapeInfoContainer *shape_info_container) { 17701+ shape_info_container_ = shape_info_container; 17702+ } 17703+ 17704+ void set_dynamic_mem_manager(DynamicMemManager *dynamic_mem_manager) { dynamic_mem_manager_ = dynamic_mem_manager; } 17705+ 17706 protected: 17707 std::vector<Tensor *> input_tensors_; 17708 std::vector<Tensor *> output_tensors_; 17709@@ -103,6 +113,8 @@ 
class OperatorCoder { 17710 bool support_parallel_{false}; 17711 int thread_num_{1}; 17712 int schema_version_ = lite::SCHEMA_VERSION::SCHEMA_CUR; 17713+ ShapeInfoContainer *shape_info_container_{nullptr}; 17714+ DynamicMemManager *dynamic_mem_manager_{nullptr}; 17715 17716 private: 17717 size_t node_index_{0}; 17718diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc 17719index 45b2e37f..e2d70c12 100644 17720--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc 17721+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc 17722@@ -35,7 +35,7 @@ std::unique_ptr<OperatorCoder> OpCoderBuilder::build(int schema_version) { 17723 } 17724 coder_key = CoderKey(target_, data_type_, schema::PrimitiveType_Custom, custom_type->str()); 17725 } 17726- CoderCreatorFunc creator_func = OpCoderFactory::GetInstance()->FindOpCoder(coder_key); 17727+ CoderCreatorFunc creator_func = OpCoderFactory::GetInstance()->FindOpCoder(coder_key, dynamic_); 17728 if (creator_func == nullptr) { 17729 MS_LOG(ERROR) << "caught unsupported layer: " << node_->name_; 17730 return nullptr; 17731@@ -125,5 +125,10 @@ OpCoderBuilder &OpCoderBuilder::is_builtin_custom(bool builtin_custom) { 17732 return *this; 17733 } 17734 17735+OpCoderBuilder &OpCoderBuilder::is_dynamic(bool dynamic) { 17736+ dynamic_ = dynamic; 17737+ return *this; 17738+} 17739+ 17740 void OpCoderBuilder::Reset() {} 17741 } // namespace mindspore::lite::micro 17742diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h 17743index d85f1c32..bdd815ef 100644 17744--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h 17745+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h 17746@@ -48,6 +48,8 @@ class OpCoderBuilder { 17747 17748 OpCoderBuilder 
&is_builtin_custom(bool builtin_custom); 17749 17750+ OpCoderBuilder &is_dynamic(bool dynamic); 17751+ 17752 void Reset(); 17753 17754 private: 17755@@ -74,6 +76,8 @@ class OpCoderBuilder { 17756 bool support_parallel_{false}; 17757 17758 bool builtin_custom_{false}; 17759+ 17760+ bool dynamic_{false}; 17761 }; 17762 } // namespace mindspore::lite::micro 17763 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_BUILDER_H_ 17764diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc 17765index cf26d51d..1dac9c73 100644 17766--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc 17767+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc 17768@@ -37,33 +37,38 @@ OpCoderFactory *OpCoderFactory::GetInstance() { 17769 } 17770 17771 int OpCoderFactory::RegistOpCoder(Target target, TypeId data_type, schema::PrimitiveType operator_type, 17772- const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func) { 17773+ const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func, 17774+ bool dynamic) { 17775+ auto &op_sets = dynamic ? 
dynamic_opcoder_sets_ : static_opcoder_sets_; 17776 // check key 17777 CoderKey key(target, data_type, operator_type, builtin_custom_type); 17778 // insert pair to registry 17779- if (this->opcoder_sets_.find(key) != this->opcoder_sets_.end()) { 17780+ if (op_sets.find(key) != op_sets.end()) { 17781 MS_LOG(ERROR) << "coder already exist: " << key.ToString(); 17782 return RET_ERROR; 17783 } 17784- this->opcoder_sets_.insert(std::pair<CoderKey, CoderCreatorFunc>(key, creator_func)); 17785+ op_sets.insert(std::pair<CoderKey, CoderCreatorFunc>(key, creator_func)); 17786 return RET_OK; 17787 } 17788 17789-CoderCreatorFunc OpCoderFactory::FindOpCoder(const CoderKey &key) { 17790- auto iterator = this->opcoder_sets_.find(key); 17791- if (iterator != this->opcoder_sets_.end()) { 17792+CoderCreatorFunc OpCoderFactory::FindOpCoder(const CoderKey &key, bool dynamic) { 17793+ const auto &op_sets = dynamic ? dynamic_opcoder_sets_ : static_opcoder_sets_; 17794+ auto iterator = op_sets.find(key); 17795+ if (iterator != op_sets.end()) { 17796 return iterator->second; 17797 } 17798 // matching kAllTargets 17799- iterator = this->opcoder_sets_.find(key.AllKey()); 17800- if (iterator != this->opcoder_sets_.end()) { 17801+ iterator = op_sets.find(key.AllKey()); 17802+ if (iterator != op_sets.end()) { 17803 return iterator->second; 17804 } 17805 return nullptr; 17806 } 17807 17808 OpCoderRegister::OpCoderRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type, 17809- const std::string &builtin_custom_type, const CoderCreatorFunc &creatorFunc) { 17810- OpCoderFactory::GetInstance()->RegistOpCoder(target, data_type, operator_type, builtin_custom_type, creatorFunc); 17811+ const std::string &builtin_custom_type, const CoderCreatorFunc &creatorFunc, 17812+ bool dynamic) { 17813+ OpCoderFactory::GetInstance()->RegistOpCoder(target, data_type, operator_type, builtin_custom_type, creatorFunc, 17814+ dynamic); 17815 } 17816 } // namespace mindspore::lite::micro 17817diff 
--git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h 17818index 30c8a64d..b616e287 100644 17819--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h 17820+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h 17821@@ -65,15 +65,19 @@ class OpCoderFactory { 17822 static OpCoderFactory *GetInstance(); 17823 17824 int RegistOpCoder(Target target, TypeId data_type, schema::PrimitiveType operator_type, 17825- const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func); 17826+ const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func, bool dynamic); 17827 17828- CoderCreatorFunc FindOpCoder(const CoderKey &key); 17829+ CoderCreatorFunc FindOpCoder(const CoderKey &key, bool dynamic = false); 17830 17831- ~OpCoderFactory() { opcoder_sets_.clear(); } 17832+ ~OpCoderFactory() { 17833+ static_opcoder_sets_.clear(); 17834+ dynamic_opcoder_sets_.clear(); 17835+ } 17836 17837 private: 17838 // target || data type || primitive type 17839- std::map<CoderKey, CoderCreatorFunc> opcoder_sets_; 17840+ std::map<CoderKey, CoderCreatorFunc> static_opcoder_sets_; 17841+ std::map<CoderKey, CoderCreatorFunc> dynamic_opcoder_sets_; 17842 }; 17843 17844 class OpCoderRegister { 17845@@ -81,16 +85,20 @@ class OpCoderRegister { 17846 OpCoderRegister() = delete; 17847 17848 OpCoderRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type, 17849- const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func); 17850+ const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func, bool dynamic = false); 17851 17852 ~OpCoderRegister() = default; 17853 }; 17854-#define REG_OPERATOR_CODER(target, data_type, operator_type, creator_func) \ 17855- static OpCoderRegister g_##target##data_type##operator_type##Creator(target, data_type, operator_type, "", \ 17856- creator_func); 
17857+#define REG_OPERATOR_CODER(target, data_type, operator_type, creator_func) \ 17858+ static OpCoderRegister g_##target##data_type##operator_type##StaticCreator(target, data_type, operator_type, "", \ 17859+ creator_func); 17860 17861 #define REG_BUILIN_CUSTOM_CODER(target, data_type, custom_type, creator_func) \ 17862 static OpCoderRegister g_##target##data_type##operator_type##Creator( \ 17863 target, data_type, schema::PrimitiveType_Custom, custom_type, creator_func); 17864+ 17865+#define REG_DYNAMIC_OPERATOR_CODER(target, data_type, operator_type, creator_func) \ 17866+ static OpCoderRegister g_##target##data_type##operator_type##DynamicCreator(target, data_type, operator_type, "", \ 17867+ creator_func, true); 17868 } // namespace mindspore::lite::micro 17869 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_REGISTER_H_ 17870diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc 17871index a3743b48..920f2723 100644 17872--- a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc 17873+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc 17874@@ -38,6 +38,15 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const PoolingCompu 17875 pooling_compute.maxf); 17876 } 17877 17878+void NNaclFp32Serializer::CodeStruct(const std::string &name, const PoolingComputeParam &pooling_compute, 17879+ const PoolingDynamicParameter &dynamic_pooling_param) { 17880+ CodeBaseStruct<false>("PoolingComputeParam", name, pooling_compute.input_w_, pooling_compute.input_h_, 17881+ dynamic_pooling_param.input_batch_, pooling_compute.input_channel_, pooling_compute.output_w_, 17882+ pooling_compute.output_h_, dynamic_pooling_param.output_batch_, 
pooling_compute.output_channel_, 17883+ pooling_compute.window_w_, pooling_compute.window_h_, pooling_compute.minf, 17884+ pooling_compute.maxf); 17885+} 17886+ 17887 void NNaclFp32Serializer::CodeStruct(const std::string &name, const BatchNormParameter &batch_norm_parameter) { 17888 CodeBaseStruct("BatchNormParameter", name, batch_norm_parameter.op_parameter_, batch_norm_parameter.epsilon_, 17889 batch_norm_parameter.momentum_, batch_norm_parameter.unit_, batch_norm_parameter.units_, 17890@@ -85,6 +94,29 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const ConvParamete 17891 conv_parameter.output_padding_w_, conv_parameter.output_padding_h_); 17892 } 17893 17894+void NNaclFp32Serializer::CodeStruct(const std::string &name, const ConvParameter &conv_parameter, 17895+ const ConvDynamicParameter &dynamic_conv_param) { 17896+ CodeBaseStruct<false>( 17897+ "ConvParameter", name, conv_parameter.op_parameter_, "{0}", conv_parameter.kernel_h_, conv_parameter.kernel_w_, 17898+ conv_parameter.stride_h_, conv_parameter.stride_w_, conv_parameter.dilation_h_, conv_parameter.dilation_w_, 17899+ conv_parameter.pad_u_, conv_parameter.pad_d_, conv_parameter.pad_l_, conv_parameter.pad_r_, conv_parameter.group_, 17900+ conv_parameter.tile_num_, dynamic_conv_param.input_batch_, conv_parameter.input_h_, conv_parameter.input_w_, 17901+ conv_parameter.input_channel_, dynamic_conv_param.output_batch_, conv_parameter.output_h_, conv_parameter.output_w_, 17902+ conv_parameter.output_channel_, conv_parameter.thread_num_, conv_parameter.input_unit_, conv_parameter.output_unit_, 17903+ conv_parameter.pad_mode_, conv_parameter.act_type_, conv_parameter.channel_multiplie_, 17904+ conv_parameter.output_padding_w_, conv_parameter.output_padding_h_); 17905+} 17906+ 17907+void NNaclFp32Serializer::CodeStruct(const std::string &name, const MatMulParameter &mat_mul_parameter) { 17908+ CodeBaseStruct<false>( 17909+ "MatMulParameter", name, mat_mul_parameter.op_parameter_, 
mat_mul_parameter.has_bias_, mat_mul_parameter.use_axis_, 17910+ mat_mul_parameter.a_transpose_, mat_mul_parameter.b_transpose_, mat_mul_parameter.act_type_, mat_mul_parameter.row_, 17911+ mat_mul_parameter.col_, mat_mul_parameter.row_4_, mat_mul_parameter.row_16_, mat_mul_parameter.row_align_, 17912+ mat_mul_parameter.col_8_, mat_mul_parameter.col_align_, mat_mul_parameter.deep_, mat_mul_parameter.deep_4_, 17913+ mat_mul_parameter.deep_16_, mat_mul_parameter.deep_align_, mat_mul_parameter.batch, mat_mul_parameter.a_const_, 17914+ mat_mul_parameter.b_const_, mat_mul_parameter.axis_, mat_mul_parameter.matmul_type_); 17915+} 17916+ 17917 void NNaclFp32Serializer::CodeStruct(const std::string &name, const MicroMatmulParameter µ_matmul_parameter) { 17918 CodeBaseStruct<false>("MicroMatmulParameter", name, micro_matmul_parameter.act_type_, 17919 micro_matmul_parameter.thread_num_, micro_matmul_parameter.row_, micro_matmul_parameter.col_, 17920@@ -102,18 +134,41 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const ScaleStruct 17921 scale_struct.outer_size_, scale_struct.inner_size_); 17922 } 17923 17924+void NNaclFp32Serializer::CodeStruct(const std::string &name, const ScaleStruct &scale_struct, 17925+ const ScaleDynamicParameter &dynamic_scale_param) { 17926+ CodeBaseStruct<false>("ScaleStruct", name, "{}", scale_struct.axis_, scale_struct.data_type_, 17927+ dynamic_scale_param.axis_size_, dynamic_scale_param.outer_size_, 17928+ dynamic_scale_param.inner_size_); 17929+} 17930+ 17931 void NNaclFp32Serializer::CodeStruct(const std::string &name, const SliceParameter &slice_parameter) { 17932 CodeBaseStruct("SliceParameter", name, slice_parameter.op_parameter_, ToString(slice_parameter.shape_), 17933 ToString(slice_parameter.begin_), ToString(slice_parameter.end_), ToString(slice_parameter.size_), 17934 "{0}", slice_parameter.param_length_); 17935 } 17936 17937+void NNaclFp32Serializer::CodeStruct(const std::string &name, const SliceParameter 
&slice_parameter, 17938+ const SliceDynamicParameter &dynamic_slice_param) { 17939+ CodeBaseStruct<false>("SliceParameter", name, slice_parameter.op_parameter_, dynamic_slice_param.shape_, 17940+ ToString(slice_parameter.begin_), dynamic_slice_param.end_, dynamic_slice_param.size_, "{0}", 17941+ slice_parameter.param_length_); 17942+} 17943+ 17944 void NNaclFp32Serializer::CodeStruct(const std::string &name, const SplitParameter &split_parameter) { 17945 CodeBaseStruct("SplitParameter", name, split_parameter.op_parameter_, split_parameter.num_split_, "split_sizes", 17946 split_parameter.split_dim_, ToString(split_parameter.strides_), "{0}", split_parameter.n_dims_, 17947 split_parameter.split_count_); 17948 } 17949 17950+void NNaclFp32Serializer::CodeStruct(const std::string &name, const SplitParameter &split_parameter, 17951+ const SplitDynamicParameter &dynamic_split_param) { 17952+ CodeArray("split_sizes", split_parameter.split_sizes_, split_parameter.num_split_, false); 17953+ CodeBaseStruct<false>("SplitParameter", name, split_parameter.op_parameter_, split_parameter.num_split_, nullptr, 17954+ split_parameter.split_dim_, dynamic_split_param.strides_, "{0}", split_parameter.n_dims_, 17955+ dynamic_split_param.split_count_); 17956+ code << " " << name << ".split_sizes_ = split_sizes;\n"; 17957+} 17958+ 17959 void NNaclFp32Serializer::CodeStruct(const std::string &name, const TileParameter &tile_parameter) { 17960 CodeBaseStruct("TileParameter", name, tile_parameter.op_parameter_, ToString(tile_parameter.multiples_), 17961 ToString(tile_parameter.in_shape_), ToString(tile_parameter.out_shape_), 17962@@ -127,12 +182,32 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const TransposePar 17963 ToString(transpose_parameter.out_strides_), transpose_parameter.num_axes_, transpose_parameter.data_num_); 17964 } 17965 17966+void NNaclFp32Serializer::CodeStruct(const std::string &name, const TransposeParameter &transpose_param, 17967+ const 
TransposeDynamicParameter &dynamic_transpose_param) { 17968+ CodeBaseStruct<false>("TransposeParameter", name, transpose_param.op_parameter_, ToString(transpose_param.perm_), 17969+ transpose_param.perm_size_, transpose_param.conjugate_, dynamic_transpose_param.strides_, 17970+ dynamic_transpose_param.out_strides_, transpose_param.num_axes_, 17971+ dynamic_transpose_param.data_num_); 17972+} 17973+ 17974 void NNaclFp32Serializer::CodeStruct(const std::string &name, const LstmParameter &lstm_parameter) { 17975 CodeBaseStruct("LstmParameter", name, lstm_parameter.op_parameter_, lstm_parameter.input_size_, 17976- lstm_parameter.hidden_size_, lstm_parameter.project_size_, lstm_parameter.seq_len_, 17977- lstm_parameter.batch_, lstm_parameter.output_step_, lstm_parameter.bidirectional_, 17978- lstm_parameter.zoneout_cell_, lstm_parameter.zoneout_hidden_, lstm_parameter.input_row_align_, 17979- lstm_parameter.input_col_align_, lstm_parameter.state_row_align_, lstm_parameter.state_col_align_); 17980+ lstm_parameter.hidden_size_, lstm_parameter.project_size_, lstm_parameter.output_size_, 17981+ lstm_parameter.seq_len_, lstm_parameter.batch_, lstm_parameter.output_step_, 17982+ lstm_parameter.bidirectional_, lstm_parameter.zoneout_cell_, lstm_parameter.zoneout_hidden_, 17983+ lstm_parameter.input_row_align_, lstm_parameter.input_col_align_, lstm_parameter.state_row_align_, 17984+ lstm_parameter.state_col_align_, lstm_parameter.proj_col_align_, lstm_parameter.has_bias_); 17985+} 17986+ 17987+void NNaclFp32Serializer::CodeStruct(const std::string &name, const LstmParameter &lstm_parameter, 17988+ const DynamicLstmParameter &dynamic_lstm_param) { 17989+ CodeBaseStruct("LstmParameter", name, lstm_parameter.op_parameter_, lstm_parameter.input_size_, 17990+ lstm_parameter.hidden_size_, lstm_parameter.project_size_, lstm_parameter.output_size_, 17991+ dynamic_lstm_param.seq_len_, dynamic_lstm_param.batch_, dynamic_lstm_param.output_step_, 17992+ lstm_parameter.bidirectional_, 
lstm_parameter.zoneout_cell_, lstm_parameter.zoneout_hidden_, 17993+ dynamic_lstm_param.input_row_align_, lstm_parameter.input_col_align_, 17994+ dynamic_lstm_param.state_row_align_, lstm_parameter.state_col_align_, lstm_parameter.proj_col_align_, 17995+ lstm_parameter.has_bias_); 17996 } 17997 17998 void NNaclFp32Serializer::CodeStruct(const std::string &name, const DeQuantArg &de_quant_arg) { 17999@@ -165,6 +240,17 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const StridedSlice 18000 strided_slice_parameter.newAxisMask_, strided_slice_parameter.shrinkAxisMask_); 18001 } 18002 18003+void NNaclFp32Serializer::CodeStruct(const std::string &name, const StridedSliceParameter &strided_slice_parameter, 18004+ const StridedSliceDynamicParameter &dynamic_strided_slice_param) { 18005+ CodeBaseStruct<false>("StridedSliceParameter", name, strided_slice_parameter.op_parameter_, 18006+ ToString(strided_slice_parameter.begins_), ToString(strided_slice_parameter.ends_), 18007+ ToString(strided_slice_parameter.strides_), strided_slice_parameter.isScale, 18008+ strided_slice_parameter.in_shape_length_, dynamic_strided_slice_param.in_shape_, 18009+ strided_slice_parameter.num_axes_, strided_slice_parameter.data_type, 18010+ strided_slice_parameter.begins_mask_, strided_slice_parameter.ellipsisMask_, 18011+ strided_slice_parameter.newAxisMask_, strided_slice_parameter.shrinkAxisMask_); 18012+} 18013+ 18014 void NNaclFp32Serializer::CodeStruct(const std::string &name, const ArithmeticWrapperInfo &arithmetic_wrapper_info) { 18015 CodeBaseStruct("ArithmeticWrapperInfo", name, arithmetic_wrapper_info.offset0_, arithmetic_wrapper_info.stride0_, 18016 arithmetic_wrapper_info.offset1_, arithmetic_wrapper_info.stride1_, 18017@@ -207,6 +293,12 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const BroadcastSha 18018 ToString(param.output_shape_), param.output_shape_size_); 18019 } 18020 18021+void NNaclFp32Serializer::CodeStruct(const std::string &name, 
const BroadcastShapeInfo &op_param, 18022+ const BroadcastDynamicShapeInfo &dynamic_param) { 18023+ CodeBaseStruct<false>("BroadcastShapeInfo", name, dynamic_param.input_shape_, op_param.input_shape_size_, 18024+ dynamic_param.output_shape_, op_param.output_shape_size_); 18025+} 18026+ 18027 void NNaclFp32Serializer::CodeStruct(const std::string &name, const CustomGruParameter &op_param) { 18028 CodeBaseStruct<false>("CustomGruParameter", name, op_param.op_parameter_, op_param.num_step, op_param.batch_size, 18029 op_param.input_size, op_param.hidden_size); 18030diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h 18031index d1435dea..2b1536c6 100644 18032--- a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h 18033+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h 18034@@ -53,6 +53,15 @@ 18035 #include "nnacl/kernel/pooling.h" 18036 #include "nnacl/kernel/layer_norm.h" 18037 #include "nnacl/kernel/fill.h" 18038+#include "coder/opcoders/nnacl/dynamic_parameter/dynamic_lstm_parameter.h" 18039+#include "coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h" 18040+#include "coder/opcoders/nnacl/dynamic_parameter/slice_dynamic_parameter.h" 18041+#include "coder/opcoders/nnacl/dynamic_parameter/split_dynamic_parameter.h" 18042+#include "coder/opcoders/nnacl/dynamic_parameter/strided_slice_dynamic_parameter.h" 18043+#include "coder/opcoders/nnacl/dynamic_parameter/scale_dynamic_parameter.h" 18044+#include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h" 18045+#include "coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h" 18046+#include "coder/opcoders/nnacl/dynamic_parameter/pooling_dynamic_parameter.h" 18047 18048 namespace 
mindspore::lite::micro::nnacl { 18049 class NNaclFp32Serializer : public Serializer { 18050@@ -66,6 +75,7 @@ class NNaclFp32Serializer : public Serializer { 18051 void CodeStruct(const std::string &name, const InstanceNormParameter ¶m); 18052 void CodeStruct(const std::string &name, const ArithmeticParameter &arithmetic_parameter); 18053 void CodeStruct(const std::string &name, const ConvParameter &conv_parameter); 18054+ void CodeStruct(const std::string &name, const MatMulParameter &mat_mul_parameter); 18055 void CodeStruct(const std::string &name, const MicroMatmulParameter µ_matmul_parameter); 18056 void CodeStruct(const std::string &name, const LstmParameter &lstm_parameter); 18057 void CodeStruct(const std::string &name, const ScaleStruct &scale_struct); 18058@@ -89,6 +99,24 @@ class NNaclFp32Serializer : public Serializer { 18059 void CodeStruct(const std::string &name, const SlidingWindowParam ¶m); 18060 void CodeStruct(const std::string &name, const UnstackParameter ¶m); 18061 void CodeStruct(const std::string &name, const FillStruct ¶m); 18062+ void CodeStruct(const std::string &name, const TransposeParameter &transpose_param, 18063+ const TransposeDynamicParameter &dynamic_transpose_param); 18064+ void CodeStruct(const std::string &name, const SplitParameter &split_parameter, 18065+ const SplitDynamicParameter &dynamic_split_param); 18066+ void CodeStruct(const std::string &name, const BroadcastShapeInfo ¶m, 18067+ const BroadcastDynamicShapeInfo &dynamic_param); 18068+ void CodeStruct(const std::string &name, const LstmParameter &lstm_param, 18069+ const DynamicLstmParameter &dynamic_lstm_param); 18070+ void CodeStruct(const std::string &name, const SliceParameter &slice_parameter, 18071+ const SliceDynamicParameter &dynamic_slice_param); 18072+ void CodeStruct(const std::string &name, const StridedSliceParameter &strided_slice_parameter, 18073+ const StridedSliceDynamicParameter &dynamic_strided_slice_param); 18074+ void CodeStruct(const std::string 
&name, const ScaleStruct &scale_struct, 18075+ const ScaleDynamicParameter &dynamic_scale_param); 18076+ void CodeStruct(const std::string &name, const ConvParameter &conv_parameter, 18077+ const ConvDynamicParameter &dynamic_conv_param); 18078+ void CodeStruct(const std::string &name, const PoolingComputeParam &pooling_compute, 18079+ const PoolingDynamicParameter &dynamic_pooling_param); 18080 void CodeStruct(const std::string &name, const int *list, int size); 18081 void CodeArrayStruct(const std::string &name, TensorC *tensorC, std::vector<Tensor *> tensor); 18082 18083diff --git a/mindspore/lite/tools/converter/micro/coder/session.cc b/mindspore/lite/tools/converter/micro/coder/session.cc 18084index 55df7a22..374f662d 100644 18085--- a/mindspore/lite/tools/converter/micro/coder/session.cc 18086+++ b/mindspore/lite/tools/converter/micro/coder/session.cc 18087@@ -75,7 +75,10 @@ int CoderSession::PassArgsToContext(const std::string &model_name) { 18088 context_->set_total_buffer_size(final_total_size); 18089 context_->set_graph_inputs(coder_graph_->input_tensors()); 18090 context_->set_graph_outputs(coder_graph_->output_tensors()); 18091- if (Configurator::GetInstance()->debug_mode()) { 18092+ context_->set_shape_info_container(&shape_info_container_); 18093+ context_->set_dynamic_mem_manager(&dynamic_mem_manager_); 18094+ Configurator *config = Configurator::GetInstance(); 18095+ if (config->debug_mode()) { 18096 std::vector<std::string> blocks; 18097 blocks = AddDumpDataInfo(context_->code_blocks(), op_coders_); 18098 if (blocks.size() == 0) { 18099@@ -100,7 +103,16 @@ int CoderSession::Preprocess() { 18100 Configurator::GetInstance()->changeable_weights_name()); 18101 MS_CHECK_RET_CODE(ret, "assign memory failed"); 18102 18103- // prepare, init model parameters 18104+ if (dynamic_) { 18105+ auto config = Configurator::GetInstance(); 18106+ MS_CHECK_TRUE_MSG(config != nullptr, RET_NULL_PTR, "Config is a nullptr."); 18107+ ret = 
shape_info_container_.Init(op_coders_, graph_inputs_shape_infos_); 18108+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Init ShapeInfoContainer failed."); 18109+ auto outputs = coder_graph_->output_tensors(); 18110+ ret = dynamic_mem_manager_.AllocDynamicMem(op_coders_, inputs, outputs, &shape_info_container_); 18111+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "DynamicMemManager AllocDynamicMem failed."); 18112+ } 18113+ // 2. prepare, init model parameters 18114 for (const auto &op_coder : op_coders_) { 18115 MS_CHECK_PTR(op_coder); 18116 MS_LOG(DEBUG) << "prepare: " << op_coder->name(); 18117@@ -133,7 +145,7 @@ int CoderSession::Run(const std::string &model_name) { 18118 ret = PassArgsToContext(model_name); 18119 MS_CHECK_RET_CODE(ret, "PassArgsToContext failed"); 18120 MS_LOG(INFO) << "run opcoders success"; 18121- return RET_OK; 18122+ return ret; 18123 } 18124 18125 int CoderSession::GenerateCode() { 18126@@ -161,6 +173,9 @@ int CoderSession::Init(const void *content, int size, const int model_index, boo 18127 context_ = std::make_unique<CoderContext>(model_index); 18128 context_->set_end_flag(end_flag); 18129 enable_fp16_ = enable_fp16; 18130+ Configurator *config = Configurator::GetInstance(); 18131+ MS_CHECK_TRUE_MSG(config != nullptr, RET_NULL_PTR, "Config is a nullptr."); 18132+ dynamic_ = !config->graph_inputs_shape_infos().empty(); 18133 MS_LOG(INFO) << "CoderSession::Init done"; 18134 return RET_OK; 18135 } 18136@@ -227,6 +242,7 @@ int CoderSession::InitTensorsRef() { 18137 } 18138 } 18139 tensor->set_ref_count(refcount); 18140+ tensor->set_init_ref_count(refcount); 18141 } 18142 return RET_OK; 18143 } 18144@@ -325,6 +341,7 @@ int CoderSession::CreateOpCoders() { 18145 .input_indices(input_indices) 18146 .output_indices(output_indices) 18147 .is_builtin_custom(is_built_in_custom_op) 18148+ .is_dynamic(dynamic_) 18149 .build(schema_version_); 18150 if (op_coder == nullptr) { 18151 coder_graph_->DumpUnSupportLayer(code_target); 18152@@ -348,6 +365,20 @@ int 
CoderSession::CompileGraph() { 18153 MS_CHECK_RET_CODE(InitCodeGraph(), "InitGraphInOutTensors failed"); 18154 MS_CHECK_RET_CODE(CreateOpCoders(), "CreateOpCoders failed!"); 18155 MS_CHECK_RET_CODE(InitTensorsRef(), "InitTensorsRefcount failed!"); 18156+ if (dynamic_) { 18157+ Configurator::GetInstance()->set_dynamic_shape(true); 18158+ std::vector<lite::Tensor *> inputs = coder_graph_->input_tensors(); 18159+ auto &graph_inputs_shape_infos = Configurator::GetInstance()->graph_inputs_shape_infos(); 18160+ MS_CHECK_TRUE_MSG(inputs.size() == graph_inputs_shape_infos.size(), RET_ERROR, 18161+ "Config graph_inputs_shape's num cannot match."); 18162+ for (size_t i = 0; i < inputs.size(); ++i) { 18163+ graph_inputs_shape_infos_[inputs[i]] = graph_inputs_shape_infos[i]; 18164+ } 18165+ } 18166+ for (auto &op_coder : op_coders_) { 18167+ op_coder->set_shape_info_container(&shape_info_container_); 18168+ op_coder->set_dynamic_mem_manager(&dynamic_mem_manager_); 18169+ } 18170 return RET_OK; 18171 } 18172 CoderSession::~CoderSession() { allocator_->Free(); } 18173diff --git a/mindspore/lite/tools/converter/micro/coder/session.h b/mindspore/lite/tools/converter/micro/coder/session.h 18174index 98a8d008..452e3245 100644 18175--- a/mindspore/lite/tools/converter/micro/coder/session.h 18176+++ b/mindspore/lite/tools/converter/micro/coder/session.h 18177@@ -65,6 +65,10 @@ class CoderSession { 18178 private: 18179 int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR; 18180 bool enable_fp16_{false}; 18181+ bool dynamic_{false}; 18182+ DynamicMemManager dynamic_mem_manager_; 18183+ ShapeInfoContainer shape_info_container_; 18184+ std::map<Tensor *, std::vector<std::vector<int>>> graph_inputs_shape_infos_; 18185 }; 18186 } // namespace mindspore::lite::micro 18187 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_SESSION_H_ 18188diff --git a/mindspore/lite/tools/converter/micro/coder/shape_info_container.cc b/mindspore/lite/tools/converter/micro/coder/shape_info_container.cc 18189new 
file mode 100644 18190index 00000000..c914be6c 18191--- /dev/null 18192+++ b/mindspore/lite/tools/converter/micro/coder/shape_info_container.cc 18193@@ -0,0 +1,131 @@ 18194+/** 18195+ * Copyright 2023 Huawei Technologies Co., Ltd 18196+ * 18197+ * Licensed under the Apache License, Version 2.0 (the "License"); 18198+ * you may not use this file except in compliance with the License. 18199+ * You may obtain a copy of the License at 18200+ * 18201+ * http://www.apache.org/licenses/LICENSE-2.0 18202+ * 18203+ * Unless required by applicable law or agreed to in writing, software 18204+ * distributed under the License is distributed on an "AS IS" BASIS, 18205+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18206+ * See the License for the specific language governing permissions and 18207+ * limitations under the License. 18208+ */ 18209+ 18210+#include "coder/shape_info_container.h" 18211+#include "src/litert/infer_manager.h" 18212+#include "coder/opcoders/op_coder.h" 18213+#include "coder/utils/coder_utils.h" 18214+#include "tools/common/string_util.h" 18215+ 18216+namespace mindspore::lite::micro { 18217+int ShapeInfoContainer::Init(const std::vector<std::unique_ptr<OperatorCoder>> &nodes_coder, 18218+ const std::map<Tensor *, std::vector<std::vector<int>>> &graph_inputs) { 18219+ MS_CHECK_TRUE_MSG(!graph_inputs.empty(), RET_ERROR, "Cannot get graph_inputs's shape-info"); 18220+ auto scene_num = graph_inputs.begin()->second.size(); 18221+ for (const auto &item : graph_inputs) { 18222+ MS_CHECK_TRUE_MSG(item.first, RET_NULL_PTR, "Find a nullptr in graph_inputs"); 18223+ MS_CHECK_TRUE_MSG(item.second.size() == scene_num, RET_ERROR, "Graph inputs are invalid."); 18224+ } 18225+ var_tensor_shapes_.insert(graph_inputs.begin(), graph_inputs.end()); 18226+ for (size_t i = 0; i < scene_num; ++i) { 18227+ for (const auto &item : graph_inputs) { 18228+ item.first->set_shape(item.second[i]); 18229+ } 18230+ for (const auto &node_coder : nodes_coder) { 
18231+ auto in_tensors = node_coder->input_tensors(); 18232+ auto out_tensors = node_coder->output_tensors(); 18233+ auto op_param = node_coder->get_parameter(); 18234+ MS_CHECK_TRUE_MSG(op_param, RET_NULL_PTR, "NodeCoder's op_param is a nullptr."); 18235+ auto node = node_coder->node(); 18236+ MS_CHECK_TRUE_MSG(node, RET_NULL_PTR, "NodeCoder's node is a nullptr."); 18237+ auto prim = node->primitive_; 18238+ auto ret = DoInferShape(in_tensors, out_tensors, op_param, prim); 18239+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "ShapeInfoContainer Init failed."); 18240+ } 18241+ } 18242+ auto ret = DetermineShapeVarInfos(); 18243+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "DetermineShapeVarInfos failed."); 18244+ return RET_OK; 18245+} 18246+ 18247+int ShapeInfoContainer::DoInferShape(const std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors, 18248+ OpParameter *op_param, const void *primitive) { 18249+ auto ret = KernelInferShape(in_tensors, out_tensors, primitive, {}, lite::SCHEMA_CUR); 18250+ if (ret == lite::RET_NOT_SUPPORT) { 18251+ ret = KernelInferShape(in_tensors, out_tensors, op_param); 18252+ } 18253+ if (ret != RET_OK) { 18254+ MS_LOG(ERROR) << "Infer shape failed."; 18255+ return ret; 18256+ } 18257+ for (const auto out_tensor : out_tensors) { 18258+ var_tensor_shapes_[out_tensor].push_back(out_tensor->shape()); 18259+ } 18260+ return RET_OK; 18261+} 18262+ 18263+int ShapeInfoContainer::DetermineShapeVarInfos() { 18264+ MS_CHECK_TRUE_MSG(kShapePrefixName, RET_NULL_PTR, "kShapePrefixName is a nullptr."); 18265+ int index = 0; 18266+ for (const auto &item : var_tensor_shapes_) { 18267+ auto &tensor = item.first; 18268+ auto &shapes = item.second; 18269+ MS_CHECK_TRUE_MSG(!shapes.empty(), RET_ERROR, "Cannot get some tensor's shape."); 18270+ auto shape = shapes.front(); 18271+ auto dims = shape.size(); 18272+ auto is_same_dim = 18273+ std::all_of(shapes.begin(), shapes.end(), [dims](const std::vector<int> &item) { return item.size() == dims; }); 18274+ 
MS_CHECK_TRUE_MSG(is_same_dim, RET_ERROR, "Tensor's shape-dims-num are not same."); 18275+ std::vector<std::string> shape_symbols; 18276+ for (size_t i = 0; i < dims; ++i) { 18277+ int dim = shape[i]; 18278+ std::vector<int> real_nums; 18279+ auto is_same_pos = 18280+ std::all_of(shapes.begin(), shapes.end(), [dim, i](const std::vector<int> &item) { return item[i] == dim; }); 18281+ if (is_same_pos) { 18282+ shape_symbols.push_back(std::to_string(dim)); 18283+ continue; 18284+ } 18285+ (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(real_nums), 18286+ [i](const std::vector<int> &item) { return item[i]; }); 18287+ std::string shape_symbol; 18288+ for (const auto &shape_to_num : shape_to_nums_) { 18289+ if (shape_to_num.second == real_nums) { 18290+ shape_symbol = shape_to_num.first; 18291+ break; 18292+ } 18293+ } 18294+ if (shape_symbol.empty()) { 18295+ for (size_t scene_index = 0; scene_index < real_nums.size(); ++scene_index) { 18296+ shapes_whole_scenes_[scene_index].push_back(real_nums[scene_index]); 18297+ } 18298+ shape_symbol = std::string(kShapePrefixName) + "[" + std::to_string(index++) + "]"; 18299+ shape_to_nums_[shape_symbol] = real_nums; 18300+ } 18301+ shape_symbols.push_back(shape_symbol); 18302+ } 18303+ shape_templates_[tensor] = shape_symbols; 18304+ } 18305+ return RET_OK; 18306+} 18307+ 18308+std::vector<std::string> ShapeInfoContainer::GetTemplateShape(const Tensor *tensor) const { 18309+ if (shape_templates_.find(tensor) == shape_templates_.end()) { 18310+ return {}; 18311+ } 18312+ return shape_templates_.at(tensor); 18313+} 18314+ 18315+std::vector<int> ShapeInfoContainer::GetRealNums(const std::string &shape_var) const { 18316+ if (IsNumber(shape_var)) { 18317+ return {std::stoi(shape_var)}; 18318+ } 18319+ if (shape_to_nums_.find(shape_var) == shape_to_nums_.end()) { 18320+ return {}; 18321+ } 18322+ return shape_to_nums_.at(shape_var); 18323+} 18324+} // namespace mindspore::lite::micro 18325diff --git 
a/mindspore/lite/tools/converter/micro/coder/shape_info_container.h b/mindspore/lite/tools/converter/micro/coder/shape_info_container.h 18326new file mode 100644 18327index 00000000..9268b249 18328--- /dev/null 18329+++ b/mindspore/lite/tools/converter/micro/coder/shape_info_container.h 18330@@ -0,0 +1,59 @@ 18331+/** 18332+ * Copyright 2023 Huawei Technologies Co., Ltd 18333+ * 18334+ * Licensed under the Apache License, Version 2.0 (the "License"); 18335+ * you may not use this file except in compliance with the License. 18336+ * You may obtain a copy of the License at 18337+ * 18338+ * http://www.apache.org/licenses/LICENSE-2.0 18339+ * 18340+ * Unless required by applicable law or agreed to in writing, software 18341+ * distributed under the License is distributed on an "AS IS" BASIS, 18342+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18343+ * See the License for the specific language governing permissions and 18344+ * limitations under the License. 18345+ */ 18346+ 18347+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_SHAPE_INFO_CONTAINER_H_ 18348+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_SHAPE_INFO_CONTAINER_H_ 18349+ 18350+#include <vector> 18351+#include <string> 18352+#include <map> 18353+#include "tools/converter/micro/coder/config.h" 18354+#include "include/model.h" 18355+#include "src/tensor.h" 18356+#include "nnacl/op_base.h" 18357+ 18358+namespace mindspore::lite::micro { 18359+class OperatorCoder; 18360+class ShapeInfoContainer { 18361+ public: 18362+ ShapeInfoContainer() = default; 18363+ ~ShapeInfoContainer() = default; 18364+ 18365+ int Init(const std::vector<std::unique_ptr<OperatorCoder>> &nodes_coder, 18366+ const std::map<Tensor *, std::vector<std::vector<int>>> &graph_inputs); 18367+ 18368+ const std::map<Tensor *, std::vector<std::vector<int>>> &GetVarTensorInfos() const { return var_tensor_shapes_; } 18369+ 18370+ std::vector<std::string> GetTemplateShape(const Tensor *tensor) const; 18371+ 18372+ 
const std::map<const Tensor *, std::vector<std::string>> &GetWholeTemplateShape() { return shape_templates_; } 18373+ 18374+ std::vector<int> GetRealNums(const std::string &shape_var) const; 18375+ 18376+ const std::map<int, std::vector<int>> &GetShapesWholeScenes() const { return shapes_whole_scenes_; } 18377+ 18378+ private: 18379+ int DoInferShape(const std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors, OpParameter *op_param, 18380+ const void *primitive); 18381+ int DetermineShapeVarInfos(); 18382+ std::map<Tensor *, std::vector<std::vector<int>>> var_tensor_shapes_; 18383+ std::map<const Tensor *, std::vector<std::string>> shape_templates_; 18384+ std::map<std::string, std::vector<int>> shape_to_nums_; 18385+ std::map<int, std::vector<int>> shapes_whole_scenes_; 18386+ Model *model_{nullptr}; 18387+}; 18388+} // namespace mindspore::lite::micro 18389+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_SHAPE_INFO_CONTAINER_H_ 18390diff --git a/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.cc b/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.cc 18391index c86a967d..a4c15c83 100644 18392--- a/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.cc 18393+++ b/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.cc 18394@@ -1,5 +1,5 @@ 18395 /** 18396- * Copyright 2021-2022 Huawei Technologies Co., Ltd 18397+ * Copyright 2021 Huawei Technologies Co., Ltd 18398 * 18399 * Licensed under the Apache License, Version 2.0 (the "License"); 18400 * you may not use this file except in compliance with the License. 
18401@@ -22,6 +22,7 @@ 18402 #include "tools/converter/micro/coder/log.h" 18403 #include "tools/converter/micro/coder/utils/type_cast.h" 18404 #include "tools/converter/micro/coder/allocator/allocator.h" 18405+#include "tools/common/string_util.h" 18406 18407 namespace mindspore::lite::micro { 18408 bool CheckConstantTensor(const Tensor *const tensor) { 18409@@ -145,4 +146,36 @@ std::vector<std::string> SplitString(std::string str, const std::string &pattern 18410 } 18411 return results; 18412 } 18413+ 18414+std::string AccumulateShape(const std::vector<std::string> &shape_template, size_t start_index, size_t end_index) { 18415+ int64_t const_part = 1; 18416+ std::string non_const_part; 18417+ for (size_t i = start_index; i < end_index; ++i) { 18418+ auto item = shape_template[i]; 18419+ if (IsNumber(item)) { 18420+ const_part *= std::stoi(item); 18421+ } else { 18422+ if (!non_const_part.empty()) { 18423+ non_const_part += " * "; 18424+ } 18425+ non_const_part += item; 18426+ } 18427+ } 18428+ std::string accumulate_shape = std::to_string(const_part); 18429+ if (!non_const_part.empty()) { 18430+ accumulate_shape += " * " + non_const_part; 18431+ } 18432+ return accumulate_shape; 18433+} 18434+ 18435+std::string GetTensorAddr(lite::Tensor *tensor, bool is_const, DynamicMemManager *dynamic_mem_manager, 18436+ MemoryAllocator *allocator) { 18437+ if (is_const) { 18438+ return allocator->GetRuntimeAddr(tensor, true); 18439+ } 18440+ if (dynamic_mem_manager == nullptr) { 18441+ return allocator->GetRuntimeAddr(tensor); 18442+ } 18443+ return dynamic_mem_manager->GetVarTensorAddr(tensor); 18444+} 18445 } // namespace mindspore::lite::micro 18446diff --git a/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.h b/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.h 18447index eabae70e..70a973cb 100644 18448--- a/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.h 18449+++ b/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.h 
18450@@ -41,5 +41,10 @@ std::string ArrayToString(std::vector<T> array) { 18451 std::for_each(array.begin(), array.end(), [&result](const T &t) { result += std::to_string(t) + ", "; }); 18452 return "{" + result + "}"; 18453 } 18454+ 18455+std::string AccumulateShape(const std::vector<std::string> &shape_template, size_t start_index, size_t end_index); 18456+ 18457+std::string GetTensorAddr(lite::Tensor *tensor, bool is_const, DynamicMemManager *dynamic_mem_manager, 18458+ MemoryAllocator *allocator); 18459 } // namespace mindspore::lite::micro 18460 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_UTILS_CODER_UTILS_H_ 18461diff --git a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc 18462index 61b22bae..1d3c02a0 100644 18463--- a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc 18464+++ b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc 18465@@ -54,32 +54,30 @@ std::string EnumNameDataType(TypeId type) { 18466 std::string EnumNameMSDataType(TypeId type) { 18467 switch (type) { 18468 case kNumberTypeInt: 18469- return "kMSDataTypeNumberTypeInt32"; 18470+ case kNumberTypeInt32: 18471+ return "OH_AI_DATATYPE_NUMBERTYPE_INT32"; 18472 case kNumberTypeInt8: 18473- return "kMSDataTypeNumberTypeInt8"; 18474+ return "OH_AI_DATATYPE_NUMBERTYPE_INT8"; 18475 case kNumberTypeInt16: 18476- return "kMSDataTypeNumberTypeInt16"; 18477- case kNumberTypeInt32: 18478- return "kMSDataTypeNumberTypeInt32"; 18479+ return "OH_AI_DATATYPE_NUMBERTYPE_INT16"; 18480 case kNumberTypeInt64: 18481- return "kMSDataTypeNumberTypeUInt64"; 18482+ return "OH_AI_DATATYPE_NUMBERTYPE_INT64"; 18483 case kNumberTypeUInt: 18484- return "kMSDataTypeNumberTypeUInt32"; 18485+ case kNumberTypeUInt32: 18486+ return "OH_AI_DATATYPE_NUMBERTYPE_UINT32"; 18487 case kNumberTypeUInt8: 18488- return "kMSDataTypeNumberTypeUInt8"; 18489+ return "OH_AI_DATATYPE_NUMBERTYPE_UINT8"; 18490 case kNumberTypeUInt16: 
18491- return "kMSDataTypeNumberTypeUInt16"; 18492- case kNumberTypeUInt32: 18493- return "kMSDataTypeNumberTypeUInt32"; 18494+ return "OH_AI_DATATYPE_NUMBERTYPE_UINT16"; 18495 case kNumberTypeFloat: 18496 case kNumberTypeFloat32: 18497- return "kMSDataTypeNumberTypeFloat32"; 18498+ return "OH_AI_DATATYPE_NUMBERTYPE_FLOAT32"; 18499 case kNumberTypeFloat16: 18500- return "kMSDataTypeNumberTypeFloat16"; 18501+ return "OH_AI_DATATYPE_NUMBERTYPE_FLOAT16"; 18502 case kNumberTypeFloat64: 18503- return "kMSDataTypeNumberTypeFloat64"; 18504+ return "OH_AI_DATATYPE_NUMBERTYPE_FLOAT64"; 18505 case kTypeUnknown: 18506- return "kMSDataTypeUnknown"; 18507+ return "OH_AI_DATATYPE_UNKNOWN"; 18508 default: 18509 return "unsupported"; 18510 } 18511diff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 18512index 652db4af..a82feb07 100644 18513--- a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 18514+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 18515@@ -62,7 +62,7 @@ STATUS ThirdPartyModelParser::InitConfig(const std::string &config_file) { 18516 MS_LOG(ERROR) << "Missing config file in converting third party model"; 18517 return RET_ERROR; 18518 } 18519- auto ret = config_parser.ParseConfigFile(config_file); 18520+ auto ret = config_parser.ParseConfigFile(config_file, nullptr); 18521 if (ret != RET_OK) { 18522 MS_LOG(ERROR) << "Get third party model section from config file failed"; 18523 return RET_ERROR; 18524diff --git a/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.cc 18525new file mode 100644 18526index 00000000..4caef237 18527--- /dev/null 18528+++ b/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.cc 18529@@ -0,0 +1,120 @@ 18530+/** 18531+ * Copyright 2023 Huawei Technologies Co., Ltd 18532+ * 18533+ * Licensed 
under the Apache License, Version 2.0 (the "License"); 18534+ * you may not use this file except in compliance with the License. 18535+ * You may obtain a copy of the License at 18536+ * 18537+ * http://www.apache.org/licenses/LICENSE-2.0 18538+ * 18539+ * Unless required by applicable law or agreed to in writing, software 18540+ * distributed under the License is distributed on an "AS IS" BASIS, 18541+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18542+ * See the License for the specific language governing permissions and 18543+ * limitations under the License. 18544+ */ 18545+ 18546+#define USE_DEPRECATED_API 18547+#include "tools/optimizer/fusion/tile_matmul_fusion.h" 18548+#include <memory> 18549+#include "tools/optimizer/common/gllo_utils.h" 18550+#include "nnacl/op_base.h" 18551+#include "tools/lite_exporter/fetch_content.h" 18552+#include "ops/op_utils.h" 18553+#include "ops/lite_ops.h" 18554+#include "ops/fusion/tile_fusion.h" 18555+#include "ops/fusion/mat_mul_fusion.h" 18556+ 18557+namespace mindspore { 18558+namespace opt { 18559+bool TileMatMulFusion::CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const { 18560+ auto tile_cnode = node->cast<CNodePtr>(); 18561+ MS_CHECK_TRUE_RET(tile_cnode != nullptr, false); 18562+ auto tile_primc = ops::GetOperator<ops::TileFusion>(tile_cnode->input(0)); 18563+ MS_CHECK_TRUE_RET(tile_primc != nullptr, false); 18564+ auto tile_prim_c = tile_primc->GetPrim(); 18565+ MS_CHECK_TRUE_RET(tile_prim_c != nullptr, false); 18566+ if (IsQuantParameterNode(tile_prim_c)) { 18567+ MS_LOG(INFO) << tile_primc->name() << " is quant node"; 18568+ return false; 18569+ } 18570+ auto manager = func_graph->manager(); 18571+ MS_CHECK_TRUE_RET(manager != nullptr, false); 18572+ auto node_users = manager->node_users()[tile_cnode]; 18573+ for (auto &node_user : node_users) { 18574+ auto post_node = node_user.first; 18575+ auto post_node_index = node_user.second; 18576+ if 
(!utils::isa<CNode>(post_node) || !CheckPrimitiveType(post_node, prim::kPrimMatMulFusion) || 18577+ post_node_index != C2NUM) { 18578+ MS_LOG(INFO) << "The post node of tile must be matmul's matirxB."; 18579+ return false; 18580+ } 18581+ auto matmul_primc = ops::GetOperator<ops::MatMulFusion>(GetInputs(post_node).at(0)); 18582+ MS_CHECK_TRUE_RET(matmul_primc != nullptr, false); 18583+ auto matmul_prim_c = matmul_primc->GetPrim(); 18584+ MS_CHECK_TRUE_RET(matmul_prim_c != nullptr, false); 18585+ if (IsQuantParameterNode(matmul_prim_c)) { 18586+ MS_LOG(INFO) << matmul_prim_c->name() << " is quant node"; 18587+ return false; 18588+ } 18589+ } 18590+ 18591+ lite::DataInfo data_info; 18592+ auto status = lite::FetchConstData(tile_cnode, C2NUM, converter::kFmkTypeMs, &data_info, false); 18593+ MS_CHECK_TRUE_MSG(status == RET_OK, false, "Fetch tile_cnode third input's const data failed."); 18594+ if ((data_info.data_type_ != kNumberTypeInt32 && data_info.data_type_ != kNumberTypeInt) || 18595+ data_info.data_.size() / sizeof(int) < DIMENSION_2D) { 18596+ MS_LOG(INFO) << "Tile index data is invalid."; 18597+ return false; 18598+ } 18599+ auto data = reinterpret_cast<int *>(data_info.data_.data()); 18600+ int dim = static_cast<int>(data_info.data_.size() / sizeof(int)); 18601+ for (int i = dim - C1NUM; i > dim - C3NUM; --i) { 18602+ if (data[i] != C1NUM) { 18603+ return false; 18604+ } 18605+ } 18606+ lite::DataInfo weights_info; 18607+ auto left_pre_node = tile_cnode->input(C1NUM); 18608+ if (left_pre_node->isa<Parameter>() || left_pre_node->isa<ValueNode>()) { 18609+ status = lite::FetchConstData(tile_cnode, C1NUM, converter::kFmkTypeMs, &weights_info, false); 18610+ } else { 18611+ status = lite::FetchDataFromCNode(tile_cnode, C1NUM, &weights_info); 18612+ } 18613+ MS_CHECK_TRUE_RET(status == RET_OK, false); 18614+ MS_CHECK_TRUE_MSG(weights_info.shape_.size() == static_cast<size_t>(dim), false, 18615+ "Tile_cnode second input's shape size is invalid."); 18616+ for (int 
i = 0; i < dim - C2NUM; i++) { 18617+ if (data[i] != C1NUM && weights_info.shape_[i] != C1NUM) { 18618+ return false; 18619+ } 18620+ } 18621+ return true; 18622+} 18623+ 18624+bool TileMatMulFusion::Run(const FuncGraphPtr &func_graph) { 18625+ MS_CHECK_TRUE_RET(func_graph != nullptr, false); 18626+ auto node_list = TopoSort(func_graph->get_return()); 18627+ for (auto &node : node_list) { 18628+ MS_CHECK_TRUE_RET(node != nullptr, false); 18629+ if (!utils::isa<CNode>(node)) { 18630+ continue; 18631+ } 18632+ if (!CheckPrimitiveType(node, prim::kPrimTileFusion)) { 18633+ continue; 18634+ } 18635+ if (!CheckCanFuse(func_graph, node)) { 18636+ continue; 18637+ } 18638+ auto tile_cnode = node->cast<CNodePtr>(); 18639+ MS_CHECK_TRUE_RET(tile_cnode != nullptr, false); 18640+ auto left_pre_node = tile_cnode->input(SECOND_INPUT); 18641+ auto manage = func_graph->manager(); 18642+ MS_CHECK_TRUE_RET(manage != nullptr, false); 18643+ auto success = manage->Replace(tile_cnode, left_pre_node); 18644+ MS_CHECK_TRUE_MSG(success, false, "Replace old node failed."); 18645+ } 18646+ return true; 18647+} 18648+} // namespace opt 18649+} // namespace mindspore 18650diff --git a/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.h b/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.h 18651new file mode 100644 18652index 00000000..280dc265 18653--- /dev/null 18654+++ b/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.h 18655@@ -0,0 +1,37 @@ 18656+/** 18657+ * Copyright 2023 Huawei Technologies Co., Ltd 18658+ * 18659+ * Licensed under the Apache License, Version 2.0 (the "License"); 18660+ * you may not use this file except in compliance with the License. 
18661+ * You may obtain a copy of the License at 18662+ * 18663+ * http://www.apache.org/licenses/LICENSE-2.0 18664+ * 18665+ * Unless required by applicable law or agreed to in writing, software 18666+ * distributed under the License is distributed on an "AS IS" BASIS, 18667+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18668+ * See the License for the specific language governing permissions and 18669+ * limitations under the License. 18670+ */ 18671+ 18672+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TILE_MATMUL_FUSION_H_ 18673+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TILE_MATMUL_FUSION_H_ 18674+ 18675+#include <string> 18676+#include "tools/optimizer/common/multiple_pattern_process_pass.h" 18677+#include "utils/check_convert_utils.h" 18678+ 18679+namespace mindspore { 18680+namespace opt { 18681+class TileMatMulFusion : public Pass { 18682+ public: 18683+ TileMatMulFusion() : Pass("TileMatMulFusion") {} 18684+ ~TileMatMulFusion() override = default; 18685+ bool Run(const FuncGraphPtr &func_graph) override; 18686+ 18687+ private: 18688+ bool CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; 18689+}; 18690+} // namespace opt 18691+} // namespace mindspore 18692+#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TILE_MATMUL_FUSION_H_ 18693diff --git a/mindspore/python/mindspore/ops/operations/_grad_ops.py b/mindspore/python/mindspore/ops/operations/_grad_ops.py 18694index 59c9c883..5714b832 100644 18695--- a/mindspore/python/mindspore/ops/operations/_grad_ops.py 18696+++ b/mindspore/python/mindspore/ops/operations/_grad_ops.py 18697@@ -1521,7 +1521,7 @@ class LSTMGrad(Primitive): 18698 """Computes the data and weight gradients of LSTM.""" 18699 18700 @prim_attr_register 18701- def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): 18702+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout, proj_size=0): 18703 self.input_size = 
validator.check_positive_int(input_size, 'input_size', self.name) 18704 self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name) 18705 self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name) 18706@@ -1529,12 +1529,53 @@ class LSTMGrad(Primitive): 18707 self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) 18708 self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) 18709 self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name) 18710+ self.proj_size = validator.check_int_range(proj_size, 0, hidden_size, Rel.INC_LEFT, 18711+ 'proj_size', self.name) 18712+ 18713 18714 if bidirectional: 18715 self.num_directions = 2 18716 else: 18717 self.num_directions = 1 18718 18719+ def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape, 18720+ dcy_shape, reserve_shape): 18721+ # dhy and dcy should be same shape 18722+ validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name) 18723+ validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name) 18724+ if self.proj_size == 0: 18725+ validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name) 18726+ validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name) 18727+ validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name) 18728+ 18729+ real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size 18730+ validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name) 18731+ validator.check_equal_int(dhy_shape[2], real_hidden_size, "h_shape[2]", self.name) 18732+ 18733+ validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name) 18734+ validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name) 18735+ validator.check_int(dy_shape[2], real_hidden_size * 
self.num_directions, Rel.EQ, "dy[2]", self.name) 18736+ 18737+ dx_shape = (y_shape[0], y_shape[1], self.input_size) 18738+ dhx_shape = dhy_shape 18739+ dcx_shape = dcy_shape 18740+ weight_size = 0 18741+ gate_size = 4 * self.hidden_size 18742+ for layer in range(self.num_layers): 18743+ for _ in range(self.num_directions): 18744+ input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions 18745+ weight_size += gate_size * input_layer_size 18746+ weight_size += gate_size * real_hidden_size 18747+ if self.proj_size > 0: 18748+ weight_size += self.proj_size * self.hidden_size 18749+ if self.has_bias: 18750+ weight_size += gate_size 18751+ 18752+ return (dx_shape, dhx_shape, dcx_shape, (weight_size, 1, 1)) 18753+ 18754+ def infer_dtype(self, x_dtype, hx_dtype, cx_dtype, w_dtype, y_dtype, hy_dtype, cy_dtype, dy_dtype, dhy_dtype, 18755+ dcy_dtype, reserve_dtype): 18756+ return (dy_dtype, dy_dtype, dy_dtype, hx_dtype) 18757 18758 class DynamicRNNGrad(Primitive): 18759 """Computes the input gradients of DynamicRNN.""" 18760diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py 18761index 3a0eb3d6..8ae747be 100644 18762--- a/mindspore/python/mindspore/ops/operations/nn_ops.py 18763+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py 18764@@ -4356,7 +4356,7 @@ class LSTM(Primitive): 18765 """ 18766 18767 @prim_attr_register 18768- def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): 18769+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout, proj_size=0): 18770 """Initialize LSTM.""" 18771 self.input_size = validator.check_positive_int(input_size, "input_size", self.name) 18772 self.hidden_size = validator.check_positive_int(hidden_size, "hidden_size", self.name) 18773@@ -4365,12 +4365,40 @@ class LSTM(Primitive): 18774 self.bidirectional = validator.check_value_type("bidirectional", bidirectional, 
(bool,), self.name) 18775 self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) 18776 self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name) 18777+ self.proj_size = validator.check_int_range(proj_size, 0, hidden_size, validator.INC_LEFT, 18778+ 'proj_size', self.name) 18779 18780 if bidirectional: 18781 self.num_directions = 2 18782 else: 18783 self.num_directions = 1 18784 18785+ def infer_shape(self, x_shape, h_shape, c_shape, w_shape): 18786+ validator.check_equal_int(len(x_shape), 3, "x rank", self.name) 18787+ validator.check_equal_int(x_shape[2], self.input_size, "x[2]", self.name) 18788+ 18789+ # h and c should be same shape 18790+ validator.check_equal_int(len(h_shape), 3, "h rank", self.name) 18791+ if self.proj_size == 0: 18792+ validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name) 18793+ 18794+ real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size 18795+ validator.check_int(h_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h[0]", self.name) 18796+ validator.check_equal_int(h_shape[1], x_shape[1], "h[1]", self.name) 18797+ validator.check_int(h_shape[2], real_hidden_size, Rel.EQ, "h[2]", self.name) 18798+ 18799+ y_shape = (x_shape[0], x_shape[1], real_hidden_size * self.num_directions) 18800+ 18801+ # set arbitrary shape for reserved space 18802+ reserved_shape = (1, 1) 18803+ state_shape = (1, 1) 18804+ return y_shape, h_shape, c_shape, reserved_shape, state_shape 18805+ 18806+ def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype): 18807+ args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype} 18808+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name) 18809+ return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype 18810+ 18811 18812 class SigmoidCrossEntropyWithLogits(Primitive): 18813 r""" 18814