1diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 2index a1e7908e..7bbc3782 100644 3--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 4+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 5@@ -714,7 +714,17 @@ arm32_assembly_sources = [ 6 "assembly/arm32/WinogradTransRight.S", 7 ] 8 9-fp16_assembly_sources = [ 10+arm32_fp16_assembly_sources = [ 11+ "assembly/arm82_aarch32_fp16/Float16Tofloat32.S", 12+ "assembly/arm82_aarch32_fp16/Float32ToFloat16.S", 13+ "assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S", 14+ "assembly/arm82_aarch32_fp16/MatVecMulFp16.S", 15+ "assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S", 16+ "assembly/arm82_aarch32_fp16/WinogradTransLeft.S", 17+ "assembly/arm82_aarch32_fp16/WinogradTransRight.S", 18+] 19+ 20+arm64_fp16_assembly_sources = [ 21 "assembly/fp16/CalculateMinMaxFp16Count8.S", 22 "assembly/fp16/ConvDwFp16Border.S", 23 "assembly/fp16/ConvDwFp16Center.S", 24@@ -839,11 +849,13 @@ nnacl_sources += infer_control_sources 25 26 # source files on arm32 27 arm_only_sources = arm32_assembly_sources 28+#arm_only_sources += arm32_fp16_assembly_sources 29+not_needed(arm32_fp16_assembly_sources) 30 31 # source files on arm64 32 arm64_only_sources = fp16_kernel_sources 33 arm64_only_sources += fp16_grad_sources 34-arm64_only_sources += fp16_assembly_sources 35+arm64_only_sources += arm64_fp16_assembly_sources 36 arm64_only_sources += arm64_assembly_sources 37 arm64_only_sources += optimizing_assembly_sources 38 arm64_only_sources += arm64_fp32_kernel_sources 39diff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn 40index 6d83e6f9..9d9c299f 100644 41--- a/mindspore/lite/BUILD.gn 42+++ b/mindspore/lite/BUILD.gn 43@@ -74,10 +74,12 @@ import("//build/ohos.gni") 44 ohos_group("mindspore") { 45 deps = [ 46 ":mindspore_lib", 47+ ":mindspore_train_lib", 48 "mindir:mindir_lib", 49 ] 50 } 51 52+# Inference library 53 cxx_api_sources = [ 54 
"src/litert/cxx_api/cell.cc", 55 "src/litert/cxx_api/context.cc", 56@@ -429,7 +431,6 @@ ohos_shared_library("mindspore_lib") { 57 SUPPORT_NNRT = true 58 if (SUPPORT_NNRT) { 59 sources += [ 60- # "mindir/src/mindir_nnrt_lite_graph.cc", 61 "src/litert/delegate/nnrt/checker/primitive_check.cc", 62 "src/litert/delegate/nnrt/nnrt_delegate.cc", 63 "src/litert/delegate/nnrt/nnrt_model_kernel.cc", 64@@ -444,8 +445,9 @@ ohos_shared_library("mindspore_lib") { 65 external_deps += [ "neural_network_runtime:nnrt_target" ] 66 deps += [ "mindir:mindir_lib" ] 67 defines += [ "SUPPORT_NNRT" ] 68- defines += [ "MSLITE_ENABLE_EXPERIMENTAL_KERNEL" ] 69 } 70+ defines += [ "MSLITE_ENABLE_EXPERIMENTAL_KERNEL" ] 71+ defines += [ "SUPPORT_TRAIN" ] 72 cflags_cc = [ 73 "-Wno-ignored-qualifiers", 74 "-Wunused-private-field", 75@@ -458,6 +460,224 @@ ohos_shared_library("mindspore_lib") { 76 subsystem_name = "thirdparty" 77 } 78 79+# Train library 80+expression_cxx_api_sources = [ 81+ "src/litert/cxx_api/expression/net.cc", 82+ "src/litert/cxx_api/expression/net_impl.cc", 83+ "src/litert/cxx_api/expression/node_impl.cc", 84+] 85+ 86+expression_op_sources = [ 87+ "src/expression/ops/activation.cc", 88+ "src/expression/ops/adam.cc", 89+ "src/expression/ops/addn.cc", 90+ "src/expression/ops/arithmetic.cc", 91+ "src/expression/ops/arithmetic_self.cc", 92+ "src/expression/ops/assign.cc", 93+ "src/expression/ops/batchnorm.cc", 94+ "src/expression/ops/biasadd.cc", 95+ "src/expression/ops/conv.cc", 96+ "src/expression/ops/dense.cc", 97+ "src/expression/ops/depend.cc", 98+ "src/expression/ops/dropout.cc", 99+ "src/expression/ops/flatten.cc", 100+ "src/expression/ops/pooling.cc", 101+ "src/expression/ops/reduce.cc", 102+ "src/expression/ops/reshape.cc", 103+ "src/expression/ops/softmax.cc", 104+ "src/expression/ops/softmaxCE.cc", 105+ "src/expression/ops/tile.cc", 106+ "src/expression/ops/transpose.cc", 107+] 108+ 109+all_expression_sources = [ 110+ "src/expression/export.cc", 111+ 
"src/expression/expr.cc", 112+ "src/expression/import.cc", 113+ "src/expression/net.cc", 114+ "src/expression/node.cc", 115+ "src/expression/ops.cc", 116+ "src/expression/ops_utils.cc", 117+ "src/expression/param.cc", 118+ "src/expression/sequential.cc", 119+] 120+ 121+all_expression_sources += expression_cxx_api_sources 122+all_expression_sources += expression_op_sources 123+ 124+all_train_sources = [ 125+ # ${API_TRAIN_SRC} is empty. 126+ # ${TRAIN_SRC_WITH_MD} is empty. 127+ "src/common/quant_utils.cc", 128+ "src/litert/cxx_api/metrics/accuracy.cc", 129+ "src/litert/cxx_api/train/model_build.cc", 130+ "src/litert/cxx_api/train/model_build_impl.cc", 131+ "src/litert/cxx_api/train/converters.cc", 132+ "src/litert/cxx_api/train/train_support.cc", 133+ "src/train/train_populate_parameter.cc", 134+ "src/train/train_session.cc", 135+ "src/train/graph_fusion.cc", 136+ "src/train/graph_dropout.cc", 137+ "src/train/transfer_session.cc", 138+ "src/train/train_utils.cc", 139+ "src/train/loss_monitor.cc", 140+ "src/train/lr_scheduler.cc", 141+ "src/train/accuracy_metrics.cc", 142+# "src/train/accuracy_monitor.cc", # depends on minddata header, not compiled 143+ "src/train/classification_train_accuracy_monitor.cc", 144+ "src/train/train_export.cc", 145+ "src/train/opt_allocator.cc", 146+ "src/train/optimizer/common/fusion_utils.cc", 147+ "src/train/optimizer/fusion/matmul_activation_fusion_pass.cc", 148+ "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc", 149+ "src/train/optimizer/fusion/gru_fusion_pass.cc", 150+ "src/common/storage.cc", 151+ "tools/converter/optimizer.cc", 152+ "tools/converter/legacy_optimizer/fusion/fusion_pass.cc", 153+ "tools/converter/legacy_optimizer/fusion/fusion_pattern.cc", 154+ "tools/common/meta_graph_utils.cc", 155+ "tools/common/statistic_utils.cc", 156+ "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.cc", 157+ "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc", 158+ 
"tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc", 159+ "tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc", 160+] 161+ 162+all_train_sources += all_expression_sources 163+ 164+fp16_train_kernel_sources = [ 165+ "src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.cc", 166+ "src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.cc", 167+ "src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_self_grad.cc", 168+ "src/litert/kernel/cpu/fp16_grad/bias_fp16_grad.cc", 169+ "src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.cc", 170+ "src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.cc", 171+ "src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.cc", 172+ "src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.cc", 173+ "src/litert/kernel/cpu/fp16_grad/layernorm_fp16_grad.cc", 174+ "src/litert/kernel/cpu/fp16_grad/neg_fp16_grad.cc", 175+ "src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.cc", 176+ "src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.cc", 177+ "src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.cc", 178+ "src/litert/kernel/cpu/fp16_grad/unsorted_segment_sum_fp16.cc", 179+] 180+ 181+fp32_train_kernel_sources = [ 182+ "src/litert/kernel/cpu/fp32_grad/activation_grad.cc", 183+ "src/litert/kernel/cpu/fp32_grad/adam.cc", 184+ "src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc", 185+ "src/litert/kernel/cpu/fp32_grad/apply_momentum.cc", 186+ "src/litert/kernel/cpu/fp32_grad/arithmetic_grad.cc", 187+ "src/litert/kernel/cpu/fp32_grad/arithmetic_self_grad.cc", 188+ "src/litert/kernel/cpu/fp32_grad/assign.cc", 189+ "src/litert/kernel/cpu/fp32_grad/bias_grad.cc", 190+ "src/litert/kernel/cpu/fp32_grad/bn_grad.cc", 191+ "src/litert/kernel/cpu/fp32_grad/convolution.cc", 192+ "src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc", 193+ "src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc", 194+ "src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc", 195+ "src/litert/kernel/cpu/fp32_grad/dropout.cc", 196+ 
"src/litert/kernel/cpu/fp32_grad/dropout_grad.cc", 197+ "src/litert/kernel/cpu/fp32_grad/layernorm_grad.cc", 198+ "src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.cc", 199+ "src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.cc", 200+ "src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.cc", 201+ "src/litert/kernel/cpu/fp32_grad/neg_grad.cc", 202+ "src/litert/kernel/cpu/fp32_grad/nllloss_grad.cc", 203+ "src/litert/kernel/cpu/fp32_grad/pooling_grad.cc", 204+ "src/litert/kernel/cpu/fp32_grad/power_grad.cc", 205+ "src/litert/kernel/cpu/fp32_grad/resize_grad.cc", 206+ "src/litert/kernel/cpu/fp32_grad/sgd.cc", 207+ "src/litert/kernel/cpu/fp32_grad/sigmoid_cross_entropy_with_logits.cc", 208+ "src/litert/kernel/cpu/fp32_grad/sigmoid_cross_entropy_with_logits_grad.cc", 209+ "src/litert/kernel/cpu/fp32_grad/smooth_l1_loss.cc", 210+ "src/litert/kernel/cpu/fp32_grad/smooth_l1_loss_grad.cc", 211+ "src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.cc", 212+ "src/litert/kernel/cpu/fp32_grad/softmax_grad.cc", 213+ "src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc", 214+ "src/litert/kernel/cpu/fp32_grad/strided_slice_grad.cc", 215+ "src/litert/kernel/cpu/fp32_grad/unsorted_segment_sum.cc", 216+ "src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.cc", 217+ "src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc", 218+] 219+ 220+#all_train_sources += fp16_train_kernel_sources 221+not_needed(fp16_train_kernel_sources) 222+all_train_sources += fp32_train_kernel_sources 223+ 224+ohos_shared_library("mindspore_train_lib") { 225+ deps = [ 226+ ":mindspore_lib", 227+ ] 228+ 229+ sources = all_train_sources 230+ 231+ include_dirs = [ 232+ "//base/hiviewdfx/hilog/interfaces/native/innerkits/include", 233+ "//third_party/flatbuffers/include", 234+ "./", 235+ "../", 236+ "../../", 237+ "../core", 238+ "src", 239+ "src/c_api/", 240+ "../ccsrc/plugin/device/cpu/kernel/", 241+ "../core/mindrt/src/", 242+ "../core/mindrt/include/", 243+ 
"../../third_party/", 244+ "./schema/", 245+ "../ccsrc/", 246+ ] 247+ 248+ defines = [ 249+ "ENABLE_MINDRT", 250+ "MS_COMPILE_OHOS", 251+ "PRIMITIVE_WRITEABLE", 252+ "VERSION_STR=\"2.1.0\"", 253+ ] 254+ 255+ if (target_cpu == "arm") { 256+ defines += [ 257+ "ENABLE_ARM", 258+ "ENABLE_ARM32", 259+ "ENABLE_NEON", 260+ ] 261+ } else if (target_cpu == "arm64") { 262+ defines += [ 263+ "ENABLE_ARM", 264+ "ENABLE_ARM64", 265+ "ENABLE_NEON", 266+ "ENABLE_FP16", 267+ "USE_OPENCL_WRAPPER", 268+ "MS_OPENCL_PROFILE=false", 269+ "CL_TARGET_OPENCL_VERSION=200", 270+ "CL_HPP_TARGET_OPENCL_VERSION=120", 271+ "CL_HPP_MINIMUM_OPENCL_VERSION=120", 272+ ] 273+ } 274+ configs = [ 275+ ":mindspore_api", 276+ ":disable_android", 277+ ":train_kernel_option", 278+ ] 279+ 280+ remove_configs = [ "//build/config/compiler:no_rtti" ] 281+ external_deps = [ "hilog:libhilog" ] 282+ output_name = "libmindspore-lite-train" 283+ output_extension = "so" 284+ defines += [ "SUPPORT_TRAIN" ] 285+ cflags_cc = [ 286+ "-Wno-ignored-qualifiers", 287+ "-Wunused-private-field", 288+ "-Wno-unused-private-field", 289+ "-Wno-inconsistent-missing-override", 290+ "-Wno-macro-redefined", 291+ "-Wno-constant-conversion", 292+ ] 293+ part_name = "mindspore" 294+} 295+ 296+# Build configurations 297 config("opencl_option") { 298 cflags_cc = [ "-Wno-missing-braces" ] 299 } 300@@ -482,3 +702,7 @@ config("disable_android") { 301 config("secure_option") { 302 cflags = [ "-fstack-protector-all" ] 303 } 304+ 305+config("train_kernel_option") { 306+ cflags_cc = [ "-fno-finite-math-only" ] 307+} 308diff --git a/mindspore/lite/include/registry/opencl_runtime_wrapper.h b/mindspore/lite/include/registry/opencl_runtime_wrapper.h 309index fb647d37..b55554e4 100644 310--- a/mindspore/lite/include/registry/opencl_runtime_wrapper.h 311+++ b/mindspore/lite/include/registry/opencl_runtime_wrapper.h 312@@ -1,5 +1,5 @@ 313 /** 314- * Copyright 2021 Huawei Technologies Co., Ltd 315+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 316 
* 317 * Licensed under the Apache License, Version 2.0 (the "License"); 318 * you may not use this file except in compliance with the License. 319diff --git a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 320index b34e0427..48308425 100644 321--- a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 322+++ b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 323@@ -112,6 +112,7 @@ cpu_kernel_sources = [ 324 "fp32/uniform_real_fp32.cc", 325 "fp32/unstack_fp32.cc", 326 "fp32/where_fp32.cc", 327+ "fp32/oneslike_fp32.cc", 328 "fp32/online_fusion/cast_gather_reduce_fp32.cc", 329 "fp32/online_fusion/reduce_concat_fp32.cc", 330 "fp32/online_fusion/split_reduce_concat_fp32.cc", 331diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.cc 332new file mode 100644 333index 00000000..b4c3bf7e 334--- /dev/null 335+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.cc 336@@ -0,0 +1,51 @@ 337+/** 338+ * Copyright 2022 Huawei Technologies Co., Ltd 339+ * 340+ * Licensed under the Apache License, Version 2.0 (the "License"); 341+ * you may not use this file except in compliance with the License. 342+ * You may obtain a copy of the License at 343+ * 344+ * http://www.apache.org/licenses/LICENSE-2.0 345+ * 346+ * Unless required by applicable law or agreed to in writing, software 347+ * distributed under the License is distributed on an "AS IS" BASIS, 348+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 349+ * See the License for the specific language governing permissions and 350+ * limitations under the License. 
351+ */ 352+ 353+#include "src/litert/kernel/cpu/fp32/oneslike_fp32.h" 354+#include "schema/model_generated.h" 355+#include "src/litert/kernel_registry.h" 356+#include "include/errorcode.h" 357+ 358+using mindspore::kernel::KERNEL_ARCH; 359+using mindspore::lite::KernelRegistrar; 360+using mindspore::lite::RET_ERROR; 361+using mindspore::lite::RET_OK; 362+using mindspore::schema::PrimitiveType_OnesLike; 363+ 364+namespace mindspore::kernel { 365+int OnesLikeCPUKernel::Prepare() { 366+ CHECK_LESS_RETURN(in_tensors_.size(), 1); 367+ CHECK_LESS_RETURN(out_tensors_.size(), 1); 368+ return RET_OK; 369+} 370+ 371+int OnesLikeCPUKernel::Run() { 372+ auto output = out_tensors_[0]; 373+ CHECK_NULL_RETURN(output); 374+ if (output->data_type() == kNumberTypeInt32) { 375+ ApproximateOnesLike(static_cast<int *>(output->data()), output->ElementsNum()); 376+ } else if (output->data_type() == kNumberTypeFloat32) { 377+ ApproximateOnesLike(static_cast<float *>(output->data()), output->ElementsNum()); 378+ } 379+ return RET_OK; 380+} 381+ 382+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>) 383+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>) 384+#ifdef ENABLE_FP16 385+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>) 386+#endif 387+} // namespace mindspore::kernel 388\ No newline at end of file 389diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.h b/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.h 390new file mode 100644 391index 00000000..f90aebed 392--- /dev/null 393+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.h 394@@ -0,0 +1,46 @@ 395+/** 396+ * Copyright 2022 Huawei Technologies Co., Ltd 397+ * 398+ * Licensed under the Apache License, Version 2.0 (the "License"); 399+ * you may not use this file except in compliance with the License. 
400+ * You may obtain a copy of the License at
 401+ *
 402+ * http://www.apache.org/licenses/LICENSE-2.0
 403+ *
 404+ * Unless required by applicable law or agreed to in writing, software
 405+ * distributed under the License is distributed on an "AS IS" BASIS,
 406+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 407+ * See the License for the specific language governing permissions and
 408+ * limitations under the License.
 409+ */
 410+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
 411+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
 412+
 413+#include <vector>
 414+#include "src/litert/lite_kernel.h"
 415+
 416+namespace mindspore::kernel {
 417+class OnesLikeCPUKernel : public LiteKernel {
 418+ public:
 419+ OnesLikeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
 420+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
 421+ : LiteKernel(parameter, inputs, outputs, ctx) {}
 422+
 423+ ~OnesLikeCPUKernel() = default;
 424+
 425+ int Prepare() override;
 426+ int ReSize() override { return lite::RET_OK; }
 427+ int Run() override;
 428+
 429+ private:
 430+ template <typename T>
 431+ void ApproximateOnesLike(T *output, int data_size) {
 432+ for (int i = 0; i < data_size; ++i) {
 433+ output[i] = 1;
 434+ }
 435+ return;
 436+ }
 437+};
 438+} // namespace mindspore::kernel
 439+
 440+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
 441\ No newline at end of file
 442diff --git a/mindspore/lite/src/litert/lite_model.h b/mindspore/lite/src/litert/lite_model.h
 443index 2b5422fa..635b529a 100644
 444--- a/mindspore/lite/src/litert/lite_model.h
 445+++ b/mindspore/lite/src/litert/lite_model.h
 446@@ -1,5 +1,5 @@
 447 /**
 448- * Copyright 2020 Huawei Technologies Co., Ltd
 449+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
 450 *
 451 * Licensed under the Apache License, Version 2.0 (the "License");
 452 * you may not use this file except in compliance with the 
License. 453diff --git a/mindspore/lite/src/litert/lite_session.cc b/mindspore/lite/src/litert/lite_session.cc 454index ded4d761..8f54879e 100644 455--- a/mindspore/lite/src/litert/lite_session.cc 456+++ b/mindspore/lite/src/litert/lite_session.cc 457@@ -2022,6 +2022,7 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path, 458 delete model; 459 return RET_ERROR; 460 } 461+ model->Free(); 462 set_model(model); 463 return RET_OK; 464 } 465diff --git a/mindspore/lite/src/litert/weight_decoder.h b/mindspore/lite/src/litert/weight_decoder.h 466index 9afaca55..9fbcefde 100644 467--- a/mindspore/lite/src/litert/weight_decoder.h 468+++ b/mindspore/lite/src/litert/weight_decoder.h 469@@ -1,5 +1,5 @@ 470 /** 471- * Copyright 2020-2022 Huawei Technologies Co., Ltd 472+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 473 * 474 * Licensed under the Apache License, Version 2.0 (the "License"); 475 * you may not use this file except in compliance with the License. 476diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h 477index 838892cf..f2eb4d1a 100644 478--- a/mindspore/lite/src/tensor.h 479+++ b/mindspore/lite/src/tensor.h 480@@ -69,7 +69,7 @@ enum CompressType { 481 kFSEInfer = 6 482 }; 483 484-class Tensor { 485+class MS_API Tensor { 486 public: 487 Tensor() { tensor_c_ = {false, kTypeUnknown, NHWC, VarTensor, nullptr, 0}; } 488 489diff --git a/mindspore/lite/src/tensorlist.h b/mindspore/lite/src/tensorlist.h 490index bdfdda02..6925cc95 100644 491--- a/mindspore/lite/src/tensorlist.h 492+++ b/mindspore/lite/src/tensorlist.h 493@@ -56,7 +56,7 @@ namespace mindspore::lite { 494 * 495 * See the code for other constructors. 
496 */ 497-class TensorList : public Tensor { 498+class MS_API TensorList : public Tensor { 499 public: 500 TensorList() { tensor_list_c_ = {false, kObjectTypeTensorType, DEFAULT_FORMAT, 0, kTypeUnknown, -1, nullptr, 0, 0}; } 501 502diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc 503index ef3c71f3..b581b389 100644 504--- a/mindspore/lite/src/train/train_session.cc 505+++ b/mindspore/lite/src/train/train_session.cc 506@@ -248,8 +248,8 @@ static int ReshapeWeightTensor(Tensor *orig_tensor, lite::Tensor *new_tensor) { 507 508 int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) { 509 unsigned int num_of_found_tensors = 0; 510- for (auto tensor : tensors_) { 511- for (auto modify : modify_tensors) { 512+ for (auto modify : modify_tensors) { 513+ for (auto tensor : tensors_) { 514 if (modify == nullptr) { 515 MS_LOG(ERROR) << "Tensor is nullptr"; 516 return RET_PARAM_INVALID; 517diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h 518index 43853e99..67e58a04 100644 519--- a/mindspore/lite/tools/benchmark_train/net_train.h 520+++ b/mindspore/lite/tools/benchmark_train/net_train.h 521@@ -1,5 +1,5 @@ 522 /** 523- * Copyright 2020 Huawei Technologies Co., Ltd 524+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 525 * 526 * Licensed under the Apache License, Version 2.0 (the "License"); 527 * you may not use this file except in compliance with the License. 
528diff --git a/mindspore/lite/tools/converter/converter_metagraph.cc b/mindspore/lite/tools/converter/converter_metagraph.cc 529index 6ffff71c..46a66128 100644 530--- a/mindspore/lite/tools/converter/converter_metagraph.cc 531+++ b/mindspore/lite/tools/converter/converter_metagraph.cc 532@@ -104,12 +104,14 @@ schema::MetaGraphT *ConverterToMetaGraph::Build(const std::shared_ptr<ConverterP 533 return nullptr; 534 } 535 536- // output name will be modified by Transform 537- status = UpdateMetaGraphOutputName(meta_graph, output_tensor_name); 538- if (status != RET_OK) { 539- MS_LOG(ERROR) << "UpdateGraphOutputName failed."; 540- delete meta_graph; 541- return nullptr; 542+ if (!param->train_model) { 543+ // output name will be modified by Transform 544+ status = UpdateMetaGraphOutputName(meta_graph, output_tensor_name); 545+ if (status != RET_OK) { 546+ MS_LOG(ERROR) << "UpdateGraphOutputName failed."; 547+ delete meta_graph; 548+ return nullptr; 549+ } 550 } 551 552 return meta_graph; 553diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc 554index d571b532..90b744e5 100644 555--- a/mindspore/lite/tools/converter/graphdef_transform.cc 556+++ b/mindspore/lite/tools/converter/graphdef_transform.cc 557@@ -26,6 +26,7 @@ 558 #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" 559 #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" 560 #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" 561+#include "tools/converter/legacy_optimizer/graph/node_name_pass.h" 562 #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h" 563 #include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h" 564 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" 565@@ -136,6 +137,9 @@ int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> ¶m) { 566 Optimizer forming_model_optimizer; 567 
forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type)); 568 forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass(param)); 569+ if (param->train_model) { 570+ forming_model_optimizer.AddPass(new (std::nothrow) NodeNamePass()); 571+ } 572 forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass()); 573 forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(param->weight_fp16)); 574 status = forming_model_optimizer.Run(graph_defT_); 575diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt 576index 9b16f4f8..30bccbde 100755 577--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt 578+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt 579@@ -9,6 +9,7 @@ file(GLOB GRAPH_PASS 580 ${CMAKE_CURRENT_SOURCE_DIR}/convert_fp32_to_fp16_pass.cc 581 ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc 582 ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc 583+ ${CMAKE_CURRENT_SOURCE_DIR}/node_name_pass.cc 584 ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc 585 ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc 586 ${CMAKE_CURRENT_SOURCE_DIR}/const_node_reorder_pass.cc 587diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc 588new file mode 100644 589index 00000000..712927b0 590--- /dev/null 591+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc 592@@ -0,0 +1,96 @@ 593+/** 594+ * Copyright 2022 Huawei Technologies Co., Ltd 595+ * 596+ * Licensed under the Apache License, Version 2.0 (the "License"); 597+ * you may not use this file except in compliance with the License. 
598+ * You may obtain a copy of the License at 599+ * 600+ * http://www.apache.org/licenses/LICENSE-2.0 601+ * 602+ * Unless required by applicable law or agreed to in writing, software 603+ * distributed under the License is distributed on an "AS IS" BASIS, 604+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 605+ * See the License for the specific language governing permissions and 606+ * limitations under the License. 607+ */ 608+ 609+#include "tools/converter/legacy_optimizer/graph/node_name_pass.h" 610+#include <string> 611+#include <vector> 612+#include "tools/converter/converter_context.h" 613+ 614+namespace mindspore::lite { 615+std::string CutShortName(const std::string &fullname, const std::string &delimiter) { 616+ size_t end_pos = fullname.find_last_of(delimiter); 617+ std::string name = ""; 618+ if (end_pos != std::string::npos) { 619+ name = fullname.substr(end_pos + 1); 620+ } 621+ if ((fullname.find("op") != std::string::npos) && (name.find("op") == std::string::npos) && 622+ (end_pos != std::string::npos)) { 623+ size_t pos = fullname.rfind(delimiter, end_pos - 1); 624+ if (pos != std::string::npos) { 625+ name.insert(0, fullname.substr(pos + 1, end_pos - pos)); 626+ } else { 627+ name.insert(0, fullname.substr(0, end_pos + 1)); 628+ } 629+ } 630+ 631+ const std::vector<std::string> loss_names = {"loss_fct", "_loss_fn", "SigmoidCrossEntropy"}; 632+ for (auto &s : loss_names) { 633+ if (fullname.find(s) != std::string::npos) { 634+ name.insert(0, s + "/"); 635+ break; 636+ } 637+ } 638+ 639+ if (fullname.find("Gradients") != std::string::npos) { 640+ size_t pos = fullname.find(delimiter); 641+ if (pos != std::string::npos) { 642+ name.insert(0, fullname.substr(0, pos + 1)); 643+ } 644+ } 645+ return name; 646+} 647+ 648+STATUS NodeNamePass::Run(schema::MetaGraphT *graph) { 649+ if (graph == nullptr) { 650+ MS_LOG(ERROR) << "graph is nullptr"; 651+ return RET_NULL_PTR; 652+ } 653+ 654+ std::string delimiter = "/"; 655+ for 
(auto &node : graph->nodes) { 656+ if (node == nullptr || node->primitive == nullptr) { 657+ MS_LOG(ERROR) << "node or node->primitive is nullptr"; 658+ return RET_NULL_PTR; 659+ } 660+ std::string node_name = CutShortName(node->name, delimiter); 661+ node->name = node_name != "" ? node_name : node->name; 662+ 663+ for (int i = 0; i < static_cast<int>(node->inputIndex.size()); i++) { 664+ auto tensor_id = node->inputIndex.at(i); 665+ auto &tensor = graph->allTensors.at(tensor_id); 666+ if (tensor->name.empty()) { 667+ MS_LOG(DEBUG) << "input tensor (id = " << tensor_id << ") name is null"; 668+ tensor->name = node->name + "/input-" + std::to_string(i); 669+ } else { 670+ std::string in_tensor_name = CutShortName(tensor->name, delimiter); 671+ tensor->name = in_tensor_name != "" ? in_tensor_name : tensor->name; 672+ } 673+ } 674+ 675+ for (int i = 0; i < static_cast<int>(node->outputIndex.size()); i++) { 676+ auto tensor_id = node->outputIndex.at(i); 677+ auto &tensor = graph->allTensors.at(tensor_id); 678+ if (tensor->name.empty()) { 679+ tensor->name = node->name + "/output-" + std::to_string(i); 680+ } else { 681+ std::string out_tensor_name = CutShortName(tensor->name, delimiter); 682+ tensor->name = out_tensor_name != "" ? out_tensor_name : tensor->name; 683+ } 684+ } 685+ } 686+ return RET_OK; 687+} 688+} // namespace mindspore::lite 689\ No newline at end of file 690diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h 691new file mode 100644 692index 00000000..4e58e5c7 693--- /dev/null 694+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h 695@@ -0,0 +1,35 @@ 696+/** 697+ * Copyright 2022 Huawei Technologies Co., Ltd 698+ * 699+ * Licensed under the Apache License, Version 2.0 (the "License"); 700+ * you may not use this file except in compliance with the License. 
701+ * You may obtain a copy of the License at 702+ * 703+ * http://www.apache.org/licenses/LICENSE-2.0 704+ * 705+ * Unless required by applicable law or agreed to in writing, software 706+ * distributed under the License is distributed on an "AS IS" BASIS, 707+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 708+ * See the License for the specific language governing permissions and 709+ * limitations under the License. 710+ */ 711+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_ 712+#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_ 713+ 714+#include <memory> 715+#include "tools/converter/optimizer.h" 716+#include "tools/common/graph_util.h" 717+ 718+namespace mindspore { 719+namespace lite { 720+class NodeNamePass : public GraphPass { 721+ public: 722+ NodeNamePass() {} 723+ 724+ ~NodeNamePass() override = default; 725+ 726+ STATUS Run(schema::MetaGraphT *graph) override; 727+}; 728+} // namespace lite 729+} // namespace mindspore 730+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_ 731