1diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 2index a1e7908e..7bbc3782 100644 3--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 4+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 5@@ -714,7 +714,17 @@ arm32_assembly_sources = [ 6 "assembly/arm32/WinogradTransRight.S", 7 ] 8 9-fp16_assembly_sources = [ 10+arm32_fp16_assembly_sources = [ 11+ "assembly/arm82_aarch32_fp16/Float16Tofloat32.S", 12+ "assembly/arm82_aarch32_fp16/Float32ToFloat16.S", 13+ "assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S", 14+ "assembly/arm82_aarch32_fp16/MatVecMulFp16.S", 15+ "assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S", 16+ "assembly/arm82_aarch32_fp16/WinogradTransLeft.S", 17+ "assembly/arm82_aarch32_fp16/WinogradTransRight.S", 18+] 19+ 20+arm64_fp16_assembly_sources = [ 21 "assembly/fp16/CalculateMinMaxFp16Count8.S", 22 "assembly/fp16/ConvDwFp16Border.S", 23 "assembly/fp16/ConvDwFp16Center.S", 24@@ -839,11 +849,13 @@ nnacl_sources += infer_control_sources 25 26 # source files on arm32 27 arm_only_sources = arm32_assembly_sources 28+#arm_only_sources += arm32_fp16_assembly_sources 29+not_needed(arm32_fp16_assembly_sources) 30 31 # source files on arm64 32 arm64_only_sources = fp16_kernel_sources 33 arm64_only_sources += fp16_grad_sources 34-arm64_only_sources += fp16_assembly_sources 35+arm64_only_sources += arm64_fp16_assembly_sources 36 arm64_only_sources += arm64_assembly_sources 37 arm64_only_sources += optimizing_assembly_sources 38 arm64_only_sources += arm64_fp32_kernel_sources 39diff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn 40index 6d83e6f9..9d9c299f 100644 41--- a/mindspore/lite/BUILD.gn 42+++ b/mindspore/lite/BUILD.gn 43@@ -74,10 +74,12 @@ import("//build/ohos.gni") 44 ohos_group("mindspore") { 45 deps = [ 46 ":mindspore_lib", 47+ ":mindspore_train_lib", 48 "mindir:mindir_lib", 49 ] 50 } 51 52+# Inference library 53 cxx_api_sources = [ 54 
"src/litert/cxx_api/cell.cc", 55 "src/litert/cxx_api/context.cc", 56@@ -429,7 +431,6 @@ ohos_shared_library("mindspore_lib") { 57 SUPPORT_NNRT = true 58 if (SUPPORT_NNRT) { 59 sources += [ 60- # "mindir/src/mindir_nnrt_lite_graph.cc", 61 "src/litert/delegate/nnrt/checker/primitive_check.cc", 62 "src/litert/delegate/nnrt/nnrt_delegate.cc", 63 "src/litert/delegate/nnrt/nnrt_model_kernel.cc", 64@@ -444,8 +445,9 @@ ohos_shared_library("mindspore_lib") { 65 external_deps += [ "neural_network_runtime:nnrt_target" ] 66 deps += [ "mindir:mindir_lib" ] 67 defines += [ "SUPPORT_NNRT" ] 68- defines += [ "MSLITE_ENABLE_EXPERIMENTAL_KERNEL" ] 69 } 70+ defines += [ "MSLITE_ENABLE_EXPERIMENTAL_KERNEL" ] 71+ defines += [ "SUPPORT_TRAIN" ] 72 cflags_cc = [ 73 "-Wno-ignored-qualifiers", 74 "-Wunused-private-field", 75@@ -458,6 +460,225 @@ ohos_shared_library("mindspore_lib") { 76 subsystem_name = "thirdparty" 77 } 78 79+# Train library 80+expression_cxx_api_sources = [ 81+ "src/litert/cxx_api/expression/net.cc", 82+ "src/litert/cxx_api/expression/net_impl.cc", 83+ "src/litert/cxx_api/expression/node_impl.cc", 84+] 85+ 86+expression_op_sources = [ 87+ "src/expression/ops/activation.cc", 88+ "src/expression/ops/adam.cc", 89+ "src/expression/ops/addn.cc", 90+ "src/expression/ops/arithmetic.cc", 91+ "src/expression/ops/arithmetic_self.cc", 92+ "src/expression/ops/assign.cc", 93+ "src/expression/ops/batchnorm.cc", 94+ "src/expression/ops/biasadd.cc", 95+ "src/expression/ops/conv.cc", 96+ "src/expression/ops/dense.cc", 97+ "src/expression/ops/depend.cc", 98+ "src/expression/ops/dropout.cc", 99+ "src/expression/ops/flatten.cc", 100+ "src/expression/ops/pooling.cc", 101+ "src/expression/ops/reduce.cc", 102+ "src/expression/ops/reshape.cc", 103+ "src/expression/ops/softmax.cc", 104+ "src/expression/ops/softmaxCE.cc", 105+ "src/expression/ops/tile.cc", 106+ "src/expression/ops/transpose.cc", 107+] 108+ 109+all_expression_sources = [ 110+ "src/expression/export.cc", 111+ 
"src/expression/expr.cc", 112+ "src/expression/import.cc", 113+ "src/expression/net.cc", 114+ "src/expression/node.cc", 115+ "src/expression/ops.cc", 116+ "src/expression/ops_utils.cc", 117+ "src/expression/param.cc", 118+ "src/expression/sequential.cc", 119+] 120+ 121+all_expression_sources += expression_cxx_api_sources 122+all_expression_sources += expression_op_sources 123+ 124+all_train_sources = [ 125+ # ${API_TRAIN_SRC} is empty. 126+ # ${TRAIN_SRC_WITH_MD} is empty. 127+ "src/common/quant_utils.cc", 128+ "src/litert/cxx_api/metrics/accuracy.cc", 129+ "src/litert/cxx_api/train/model_build.cc", 130+ "src/litert/cxx_api/train/model_build_impl.cc", 131+ "src/litert/cxx_api/train/converters.cc", 132+ "src/litert/cxx_api/train/train_support.cc", 133+ "src/train/train_populate_parameter.cc", 134+ "src/train/train_session.cc", 135+ "src/train/graph_fusion.cc", 136+ "src/train/graph_dropout.cc", 137+ "src/train/transfer_session.cc", 138+ "src/train/train_utils.cc", 139+ "src/train/loss_monitor.cc", 140+ "src/train/lr_scheduler.cc", 141+ "src/train/accuracy_metrics.cc", 142+# "src/train/accuracy_monitor.cc", # depends on minddata header, not compiled 143+ "src/train/classification_train_accuracy_monitor.cc", 144+ "src/train/train_export.cc", 145+ "src/train/opt_allocator.cc", 146+ "src/train/optimizer/common/fusion_utils.cc", 147+ "src/train/optimizer/fusion/matmul_activation_fusion_pass.cc", 148+ "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc", 149+ "src/train/optimizer/fusion/gru_fusion_pass.cc", 150+ "src/common/storage.cc", 151+ "tools/converter/optimizer.cc", 152+ "tools/converter/legacy_optimizer/fusion/fusion_pass.cc", 153+ "tools/converter/legacy_optimizer/fusion/fusion_pattern.cc", 154+ "tools/common/meta_graph_utils.cc", 155+ "tools/common/statistic_utils.cc", 156+ "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.cc", 157+ "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc", 158+ 
"tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc", 159+ "tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc", 160+] 161+ 162+all_train_sources += all_expression_sources 163+ 164+fp16_train_kernel_sources = [ 165+ "src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.cc", 166+ "src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.cc", 167+ "src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_self_grad.cc", 168+ "src/litert/kernel/cpu/fp16_grad/bias_fp16_grad.cc", 169+ "src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.cc", 170+ "src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.cc", 171+ "src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.cc", 172+ "src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.cc", 173+ "src/litert/kernel/cpu/fp16_grad/layernorm_fp16_grad.cc", 174+ "src/litert/kernel/cpu/fp16_grad/neg_fp16_grad.cc", 175+ "src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.cc", 176+ "src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.cc", 177+ "src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.cc", 178+ "src/litert/kernel/cpu/fp16_grad/unsorted_segment_sum_fp16.cc", 179+] 180+ 181+fp32_train_kernel_sources = [ 182+ "src/litert/kernel/cpu/fp32_grad/activation_grad.cc", 183+ "src/litert/kernel/cpu/fp32_grad/adam.cc", 184+ "src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc", 185+ "src/litert/kernel/cpu/fp32_grad/apply_momentum.cc", 186+ "src/litert/kernel/cpu/fp32_grad/arithmetic_grad.cc", 187+ "src/litert/kernel/cpu/fp32_grad/arithmetic_self_grad.cc", 188+ "src/litert/kernel/cpu/fp32_grad/assign.cc", 189+ "src/litert/kernel/cpu/fp32_grad/bias_grad.cc", 190+ "src/litert/kernel/cpu/fp32_grad/bn_grad.cc", 191+ "src/litert/kernel/cpu/fp32_grad/convolution.cc", 192+ "src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc", 193+ "src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc", 194+ "src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc", 195+ "src/litert/kernel/cpu/fp32_grad/dropout.cc", 196+ 
"src/litert/kernel/cpu/fp32_grad/dropout_grad.cc", 197+ "src/litert/kernel/cpu/fp32_grad/layernorm_grad.cc", 198+ "src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.cc", 199+ "src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.cc", 200+ "src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.cc", 201+ "src/litert/kernel/cpu/fp32_grad/neg_grad.cc", 202+ "src/litert/kernel/cpu/fp32_grad/nllloss_grad.cc", 203+ "src/litert/kernel/cpu/fp32_grad/pooling_grad.cc", 204+ "src/litert/kernel/cpu/fp32_grad/power_grad.cc", 205+ "src/litert/kernel/cpu/fp32_grad/resize_grad.cc", 206+ "src/litert/kernel/cpu/fp32_grad/sgd.cc", 207+ "src/litert/kernel/cpu/fp32_grad/sigmoid_cross_entropy_with_logits.cc", 208+ "src/litert/kernel/cpu/fp32_grad/sigmoid_cross_entropy_with_logits_grad.cc", 209+ "src/litert/kernel/cpu/fp32_grad/smooth_l1_loss.cc", 210+ "src/litert/kernel/cpu/fp32_grad/smooth_l1_loss_grad.cc", 211+ "src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.cc", 212+ "src/litert/kernel/cpu/fp32_grad/softmax_grad.cc", 213+ "src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc", 214+ "src/litert/kernel/cpu/fp32_grad/strided_slice_grad.cc", 215+ "src/litert/kernel/cpu/fp32_grad/unsorted_segment_sum.cc", 216+ "src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.cc", 217+ "src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc", 218+] 219+ 220+#all_train_sources += fp16_train_kernel_sources 221+not_needed(fp16_train_kernel_sources) 222+all_train_sources += fp32_train_kernel_sources 223+ 224+ohos_shared_library("mindspore_train_lib") { 225+ deps = [ 226+ ":mindspore_lib", 227+ ] 228+ 229+ sources = all_train_sources 230+ 231+ include_dirs = [ 232+ "//base/hiviewdfx/hilog/interfaces/native/innerkits/include", 233+ "//third_party/flatbuffers/include", 234+ "./", 235+ "../", 236+ "../../", 237+ "../core", 238+ "src", 239+ "src/c_api/", 240+ "../ccsrc/plugin/device/cpu/kernel/", 241+ "../core/mindrt/src/", 242+ "../core/mindrt/include/", 243+ 
"../../third_party/", 244+ "./schema/", 245+ "../ccsrc/", 246+ ] 247+ 248+ defines = [ 249+ "ENABLE_MINDRT", 250+ "MS_COMPILE_OHOS", 251+ "PRIMITIVE_WRITEABLE", 252+ "VERSION_STR=\"2.1.0\"", 253+ ] 254+ 255+ if (target_cpu == "arm") { 256+ defines += [ 257+ "ENABLE_ARM", 258+ "ENABLE_ARM32", 259+ "ENABLE_NEON", 260+ ] 261+ } else if (target_cpu == "arm64") { 262+ defines += [ 263+ "ENABLE_ARM", 264+ "ENABLE_ARM64", 265+ "ENABLE_NEON", 266+ "ENABLE_FP16", 267+ "USE_OPENCL_WRAPPER", 268+ "MS_OPENCL_PROFILE=false", 269+ "CL_TARGET_OPENCL_VERSION=200", 270+ "CL_HPP_TARGET_OPENCL_VERSION=120", 271+ "CL_HPP_MINIMUM_OPENCL_VERSION=120", 272+ ] 273+ } 274+ configs = [ 275+ ":mindspore_api", 276+ ":disable_android", 277+ ":train_kernel_option", 278+ ] 279+ 280+ remove_configs = [ "//build/config/compiler:no_rtti" ] 281+ external_deps = [ "hilog:libhilog" ] 282+ innerapi_tags = [ "platformsdk" ] 283+ output_name = "libmindspore-lite-train" 284+ output_extension = "so" 285+ defines += [ "SUPPORT_TRAIN" ] 286+ cflags_cc = [ 287+ "-Wno-ignored-qualifiers", 288+ "-Wunused-private-field", 289+ "-Wno-unused-private-field", 290+ "-Wno-inconsistent-missing-override", 291+ "-Wno-macro-redefined", 292+ "-Wno-constant-conversion", 293+ ] 294+ part_name = "mindspore" 295+} 296+ 297+# Build configurations 298 config("opencl_option") { 299 cflags_cc = [ "-Wno-missing-braces" ] 300 } 301@@ -482,3 +702,7 @@ config("disable_android") { 302 config("secure_option") { 303 cflags = [ "-fstack-protector-all" ] 304 } 305+ 306+config("train_kernel_option") { 307+ cflags_cc = [ "-fno-finite-math-only" ] 308+} 309diff --git a/mindspore/lite/include/registry/opencl_runtime_wrapper.h b/mindspore/lite/include/registry/opencl_runtime_wrapper.h 310index fb647d37..b55554e4 100644 311--- a/mindspore/lite/include/registry/opencl_runtime_wrapper.h 312+++ b/mindspore/lite/include/registry/opencl_runtime_wrapper.h 313@@ -1,5 +1,5 @@ 314 /** 315- * Copyright 2021 Huawei Technologies Co., Ltd 316+ * Copyright 
2021-2023 Huawei Technologies Co., Ltd 317 * 318 * Licensed under the Apache License, Version 2.0 (the "License"); 319 * you may not use this file except in compliance with the License. 320diff --git a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 321index b34e0427..48308425 100644 322--- a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 323+++ b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 324@@ -112,6 +112,7 @@ cpu_kernel_sources = [ 325 "fp32/uniform_real_fp32.cc", 326 "fp32/unstack_fp32.cc", 327 "fp32/where_fp32.cc", 328+ "fp32/oneslike_fp32.cc", 329 "fp32/online_fusion/cast_gather_reduce_fp32.cc", 330 "fp32/online_fusion/reduce_concat_fp32.cc", 331 "fp32/online_fusion/split_reduce_concat_fp32.cc", 332diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.cc 333new file mode 100644 334index 00000000..b4c3bf7e 335--- /dev/null 336+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.cc 337@@ -0,0 +1,51 @@ 338+/** 339+ * Copyright 2022 Huawei Technologies Co., Ltd 340+ * 341+ * Licensed under the Apache License, Version 2.0 (the "License"); 342+ * you may not use this file except in compliance with the License. 343+ * You may obtain a copy of the License at 344+ * 345+ * http://www.apache.org/licenses/LICENSE-2.0 346+ * 347+ * Unless required by applicable law or agreed to in writing, software 348+ * distributed under the License is distributed on an "AS IS" BASIS, 349+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 350+ * See the License for the specific language governing permissions and 351+ * limitations under the License. 
352+ */ 353+ 354+#include "src/litert/kernel/cpu/fp32/oneslike_fp32.h" 355+#include "schema/model_generated.h" 356+#include "src/litert/kernel_registry.h" 357+#include "include/errorcode.h" 358+ 359+using mindspore::kernel::KERNEL_ARCH; 360+using mindspore::lite::KernelRegistrar; 361+using mindspore::lite::RET_ERROR; 362+using mindspore::lite::RET_OK; 363+using mindspore::schema::PrimitiveType_OnesLike; 364+ 365+namespace mindspore::kernel { 366+int OnesLikeCPUKernel::Prepare() { 367+ CHECK_LESS_RETURN(in_tensors_.size(), 1); 368+ CHECK_LESS_RETURN(out_tensors_.size(), 1); 369+ return RET_OK; 370+} 371+ 372+int OnesLikeCPUKernel::Run() { 373+ auto output = out_tensors_[0]; 374+ CHECK_NULL_RETURN(output); 375+ if (output->data_type() == kNumberTypeInt32) { 376+ ApproximateOnesLike(static_cast<int *>(output->data()), output->ElementsNum()); 377+ } else if (output->data_type() == kNumberTypeFloat32) { 378+ ApproximateOnesLike(static_cast<float *>(output->data()), output->ElementsNum()); 379+ } 380+ return RET_OK; 381+} 382+ 383+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>) 384+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>) 385+#ifdef ENABLE_FP16 386+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>) 387+#endif 388+} // namespace mindspore::kernel 389\ No newline at end of file 390diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.h b/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.h 391new file mode 100644 392index 00000000..f90aebed 393--- /dev/null 394+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/oneslike_fp32.h 395@@ -0,0 +1,46 @@ 396+/** 397+ * Copyright 2022 Huawei Technologies Co., Ltd 398+ * 399+ * Licensed under the Apache License, Version 2.0 (the "License"); 400+ * you may not use this file except in compliance with the License. 
401+ * You may obtain a copy of the License at 402+ * 403+ * http://www.apache.org/licenses/LICENSE-2.0 404+ * 405+ * Unless required by applicable law or agreed to in writing, software 406+ * distributed under the License is distributed on an "AS IS" BASIS, 407+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 408+ * See the License for the specific language governing permissions and 409+ * limitations under the License. 410+ */ 411+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_ 412+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_ 413+ 414+#include <vector> 415+#include "src/litert/lite_kernel.h" 416+ 417+namespace mindspore::kernel { 418+class OnesLikeCPUKernel : public LiteKernel { 419+ public: 420+ OnesLikeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 421+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 422+ : LiteKernel(parameter, inputs, outputs, ctx) {} 423+ 424+ ~OnesLikeCPUKernel() = default; 425+ 426+ int Prepare() override; 427+ int ReSize() override { return lite::RET_OK; } 428+ int Run() override; 429+ 430+ private: 431+ template <typename T> 432+ void ApproximateOnesLike(T *output, int data_size) { 433+ for (int i = 0; i < data_size; ++i) { 434+ output[i] = 1; 435+ } 436+ return; 437+ } 438+};
439+} // namespace mindspore::kernel 440+ 441+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_ 442\ No newline at end of file 443diff --git a/mindspore/lite/src/litert/lite_model.h b/mindspore/lite/src/litert/lite_model.h 444index 2b5422fa..635b529a 100644 445--- a/mindspore/lite/src/litert/lite_model.h 446+++ b/mindspore/lite/src/litert/lite_model.h 447@@ -1,5 +1,5 @@ 448 /** 449- * Copyright 2020 Huawei Technologies Co., Ltd 450+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 451 * 452 * Licensed under the Apache License, Version 2.0 (the "License"); 453 * you may not use this file except in compliance with the 
License. 454diff --git a/mindspore/lite/src/litert/lite_session.cc b/mindspore/lite/src/litert/lite_session.cc 455index ded4d761..8f54879e 100644 456--- a/mindspore/lite/src/litert/lite_session.cc 457+++ b/mindspore/lite/src/litert/lite_session.cc 458@@ -2022,6 +2022,7 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path, 459 delete model; 460 return RET_ERROR; 461 } 462+ model->Free(); 463 set_model(model); 464 return RET_OK; 465 } 466diff --git a/mindspore/lite/src/litert/weight_decoder.h b/mindspore/lite/src/litert/weight_decoder.h 467index 9afaca55..9fbcefde 100644 468--- a/mindspore/lite/src/litert/weight_decoder.h 469+++ b/mindspore/lite/src/litert/weight_decoder.h 470@@ -1,5 +1,5 @@ 471 /** 472- * Copyright 2020-2022 Huawei Technologies Co., Ltd 473+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 474 * 475 * Licensed under the Apache License, Version 2.0 (the "License"); 476 * you may not use this file except in compliance with the License. 477diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h 478index 838892cf..f2eb4d1a 100644 479--- a/mindspore/lite/src/tensor.h 480+++ b/mindspore/lite/src/tensor.h 481@@ -69,7 +69,7 @@ enum CompressType { 482 kFSEInfer = 6 483 }; 484 485-class Tensor { 486+class MS_API Tensor { 487 public: 488 Tensor() { tensor_c_ = {false, kTypeUnknown, NHWC, VarTensor, nullptr, 0}; } 489 490diff --git a/mindspore/lite/src/tensorlist.h b/mindspore/lite/src/tensorlist.h 491index bdfdda02..6925cc95 100644 492--- a/mindspore/lite/src/tensorlist.h 493+++ b/mindspore/lite/src/tensorlist.h 494@@ -56,7 +56,7 @@ namespace mindspore::lite { 495 * 496 * See the code for other constructors. 
497 */ 498-class TensorList : public Tensor { 499+class MS_API TensorList : public Tensor { 500 public: 501 TensorList() { tensor_list_c_ = {false, kObjectTypeTensorType, DEFAULT_FORMAT, 0, kTypeUnknown, -1, nullptr, 0, 0}; } 502 503diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc 504index ef3c71f3..b581b389 100644 505--- a/mindspore/lite/src/train/train_session.cc 506+++ b/mindspore/lite/src/train/train_session.cc 507@@ -248,8 +248,8 @@ static int ReshapeWeightTensor(Tensor *orig_tensor, lite::Tensor *new_tensor) { 508 509 int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) { 510 unsigned int num_of_found_tensors = 0; 511- for (auto tensor : tensors_) { 512- for (auto modify : modify_tensors) { 513+ for (auto modify : modify_tensors) { 514+ for (auto tensor : tensors_) { 515 if (modify == nullptr) { 516 MS_LOG(ERROR) << "Tensor is nullptr"; 517 return RET_PARAM_INVALID; 518diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h 519index 43853e99..67e58a04 100644 520--- a/mindspore/lite/tools/benchmark_train/net_train.h 521+++ b/mindspore/lite/tools/benchmark_train/net_train.h 522@@ -1,5 +1,5 @@ 523 /** 524- * Copyright 2020 Huawei Technologies Co., Ltd 525+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 526 * 527 * Licensed under the Apache License, Version 2.0 (the "License"); 528 * you may not use this file except in compliance with the License. 
529diff --git a/mindspore/lite/tools/converter/converter_metagraph.cc b/mindspore/lite/tools/converter/converter_metagraph.cc 530index 6ffff71c..46a66128 100644 531--- a/mindspore/lite/tools/converter/converter_metagraph.cc 532+++ b/mindspore/lite/tools/converter/converter_metagraph.cc 533@@ -104,12 +104,14 @@ schema::MetaGraphT *ConverterToMetaGraph::Build(const std::shared_ptr<ConverterP 534 return nullptr; 535 } 536 537- // output name will be modified by Transform 538- status = UpdateMetaGraphOutputName(meta_graph, output_tensor_name); 539- if (status != RET_OK) { 540- MS_LOG(ERROR) << "UpdateGraphOutputName failed."; 541- delete meta_graph; 542- return nullptr; 543+ if (!param->train_model) { 544+ // output name will be modified by Transform 545+ status = UpdateMetaGraphOutputName(meta_graph, output_tensor_name); 546+ if (status != RET_OK) { 547+ MS_LOG(ERROR) << "UpdateGraphOutputName failed."; 548+ delete meta_graph; 549+ return nullptr; 550+ } 551 } 552 553 return meta_graph; 554diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc 555index d571b532..90b744e5 100644 556--- a/mindspore/lite/tools/converter/graphdef_transform.cc 557+++ b/mindspore/lite/tools/converter/graphdef_transform.cc 558@@ -26,6 +26,7 @@ 559 #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" 560 #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" 561 #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" 562+#include "tools/converter/legacy_optimizer/graph/node_name_pass.h" 563 #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h" 564 #include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h" 565 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" 566@@ -136,6 +137,9 @@ int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> ¶m) { 567 Optimizer forming_model_optimizer; 568 
forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type)); 569 forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass(param)); 570+ if (param->train_model) { 571+ forming_model_optimizer.AddPass(new (std::nothrow) NodeNamePass()); 572+ } 573 forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass()); 574 forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(param->weight_fp16)); 575 status = forming_model_optimizer.Run(graph_defT_); 576diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt 577index 9b16f4f8..30bccbde 100755 578--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt 579+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt 580@@ -9,6 +9,7 @@ file(GLOB GRAPH_PASS 581 ${CMAKE_CURRENT_SOURCE_DIR}/convert_fp32_to_fp16_pass.cc 582 ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc 583 ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc 584+ ${CMAKE_CURRENT_SOURCE_DIR}/node_name_pass.cc 585 ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc 586 ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc 587 ${CMAKE_CURRENT_SOURCE_DIR}/const_node_reorder_pass.cc 588diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc 589new file mode 100644 590index 00000000..712927b0 591--- /dev/null 592+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc 593@@ -0,0 +1,96 @@ 594+/** 595+ * Copyright 2022 Huawei Technologies Co., Ltd 596+ * 597+ * Licensed under the Apache License, Version 2.0 (the "License"); 598+ * you may not use this file except in compliance with the License. 
599+ * You may obtain a copy of the License at 600+ * 601+ * http://www.apache.org/licenses/LICENSE-2.0 602+ * 603+ * Unless required by applicable law or agreed to in writing, software 604+ * distributed under the License is distributed on an "AS IS" BASIS, 605+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 606+ * See the License for the specific language governing permissions and 607+ * limitations under the License. 608+ */ 609+ 610+#include "tools/converter/legacy_optimizer/graph/node_name_pass.h" 611+#include <string> 612+#include <vector> 613+#include "tools/converter/converter_context.h" 614+ 615+namespace mindspore::lite { 616+std::string CutShortName(const std::string &fullname, const std::string &delimiter) { 617+ size_t end_pos = fullname.find_last_of(delimiter); 618+ std::string name = ""; 619+ if (end_pos != std::string::npos) { 620+ name = fullname.substr(end_pos + 1); 621+ } 622+ if ((fullname.find("op") != std::string::npos) && (name.find("op") == std::string::npos) && 623+ (end_pos != std::string::npos)) { 624+ size_t pos = fullname.rfind(delimiter, end_pos - 1); 625+ if (pos != std::string::npos) { 626+ name.insert(0, fullname.substr(pos + 1, end_pos - pos)); 627+ } else { 628+ name.insert(0, fullname.substr(0, end_pos + 1)); 629+ } 630+ } 631+ 632+ const std::vector<std::string> loss_names = {"loss_fct", "_loss_fn", "SigmoidCrossEntropy"}; 633+ for (auto &s : loss_names) { 634+ if (fullname.find(s) != std::string::npos) { 635+ name.insert(0, s + "/"); 636+ break; 637+ } 638+ } 639+ 640+ if (fullname.find("Gradients") != std::string::npos) { 641+ size_t pos = fullname.find(delimiter); 642+ if (pos != std::string::npos) { 643+ name.insert(0, fullname.substr(0, pos + 1)); 644+ } 645+ } 646+ return name; 647+} 648+ 649+STATUS NodeNamePass::Run(schema::MetaGraphT *graph) { 650+ if (graph == nullptr) { 651+ MS_LOG(ERROR) << "graph is nullptr"; 652+ return RET_NULL_PTR; 653+ } 654+ 655+ std::string delimiter = "/"; 656+ for 
(auto &node : graph->nodes) { 657+ if (node == nullptr || node->primitive == nullptr) { 658+ MS_LOG(ERROR) << "node or node->primitive is nullptr"; 659+ return RET_NULL_PTR; 660+ } 661+ std::string node_name = CutShortName(node->name, delimiter); 662+ node->name = node_name != "" ? node_name : node->name; 663+ 664+ for (int i = 0; i < static_cast<int>(node->inputIndex.size()); i++) { 665+ auto tensor_id = node->inputIndex.at(i); 666+ auto &tensor = graph->allTensors.at(tensor_id); 667+ if (tensor->name.empty()) { 668+ MS_LOG(DEBUG) << "input tensor (id = " << tensor_id << ") name is null"; 669+ tensor->name = node->name + "/input-" + std::to_string(i); 670+ } else { 671+ std::string in_tensor_name = CutShortName(tensor->name, delimiter); 672+ tensor->name = in_tensor_name != "" ? in_tensor_name : tensor->name; 673+ } 674+ } 675+ 676+ for (int i = 0; i < static_cast<int>(node->outputIndex.size()); i++) { 677+ auto tensor_id = node->outputIndex.at(i); 678+ auto &tensor = graph->allTensors.at(tensor_id); 679+ if (tensor->name.empty()) { 680+ tensor->name = node->name + "/output-" + std::to_string(i); 681+ } else { 682+ std::string out_tensor_name = CutShortName(tensor->name, delimiter); 683+ tensor->name = out_tensor_name != "" ? out_tensor_name : tensor->name; 684+ } 685+ } 686+ } 687+ return RET_OK; 688+} 689+} // namespace mindspore::lite 690\ No newline at end of file 691diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h 692new file mode 100644 693index 00000000..4e58e5c7 694--- /dev/null 695+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h 696@@ -0,0 +1,35 @@ 697+/** 698+ * Copyright 2022 Huawei Technologies Co., Ltd 699+ * 700+ * Licensed under the Apache License, Version 2.0 (the "License"); 701+ * you may not use this file except in compliance with the License. 
702+ * You may obtain a copy of the License at 703+ * 704+ * http://www.apache.org/licenses/LICENSE-2.0 705+ * 706+ * Unless required by applicable law or agreed to in writing, software 707+ * distributed under the License is distributed on an "AS IS" BASIS, 708+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 709+ * See the License for the specific language governing permissions and 710+ * limitations under the License. 711+ */ 712+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_ 713+#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_ 714+ 715+#include <memory> 716+#include "tools/converter/optimizer.h" 717+#include "tools/common/graph_util.h" 718+ 719+namespace mindspore { 720+namespace lite { 721+class NodeNamePass : public GraphPass { 722+ public: 723+ NodeNamePass() {} 724+ 725+ ~NodeNamePass() override = default; 726+ 727+ STATUS Run(schema::MetaGraphT *graph) override; 728+}; 729+} // namespace lite 730+} // namespace mindspore 731+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_ 732