1From 98e86575c68939fecc1bc3c80be9fca0e080b7fa Mon Sep 17 00:00:00 2001 2From: z00574805 <z00574805@notesmail.huawei.com/> 3Date: Wed, 24 May 2023 11:04:25 +0800 4Subject: [PATCH 1/5] xiaoyi-0001 5 6--- 7 include/api/callback/callback.h | 22 +- 8 include/api/callback/ckpt_saver.h | 4 +- 9 include/api/callback/loss_monitor.h | 4 +- 10 include/api/callback/lr_scheduler.h | 10 +- 11 include/api/callback/time_monitor.h | 4 +- 12 include/api/callback/train_accuracy.h | 10 +- 13 include/api/cfg.h | 25 ++- 14 include/api/metrics/accuracy.h | 4 +- 15 include/api/metrics/metrics.h | 7 +- 16 include/api/model_parallel_runner.h | 4 +- 17 include/api/net.h | 23 +- 18 include/api/serialization.h | 20 +- 19 include/api/types.h | 6 +- 20 mindspore/ccsrc/cxx_api/serialization.cc | 3 +- 21 .../device/cpu/kernel/nnacl/fp32/div_fp32.c | 4 - 22 .../nnacl/fp32_grad/binary_cross_entropy.c | 45 ++-- 23 .../nnacl/fp32_grad/binary_cross_entropy.h | 2 +- 24 .../fp32_grad/binary_cross_entropy_grad.c | 34 ++- 25 .../fp32_grad/binary_cross_entropy_grad.h | 2 +- 26 .../nnacl/infer/binary_cross_entropy_infer.c | 4 +- 27 .../cpu/kernel/nnacl/infer/common_infer.c | 1 + 28 mindspore/lite/include/model.h | 7 +- 29 .../include/registry/opencl_runtime_wrapper.h | 4 +- 30 .../java/src/main/native/train_config.cpp | 6 +- 31 mindspore/lite/src/CMakeLists.txt | 22 +- 32 .../binary_cross_entropy_grad_populate.cc | 45 ---- 33 .../populate/binary_cross_entropy_populate.cc | 45 ---- 34 mindspore/lite/src/common/prim_util.h | 7 +- 35 mindspore/lite/src/common/tensor_util.h | 2 +- 36 mindspore/lite/src/extendrt/CMakeLists.txt | 4 + 37 .../src/extendrt/cxx_api/serialization.cc | 3 +- 38 .../lite/src/runtime/cxx_api/converters.h | 4 +- 39 .../src/runtime/cxx_api/model/model_impl.cc | 3 + 40 .../src/runtime/cxx_api/model/model_impl.h | 6 +- 41 .../lite/src/runtime/cxx_api/serialization.cc | 31 ++- 42 .../src/runtime/cxx_api/train/converters.cc | 6 +- 43 mindspore/lite/src/runtime/infer_manager.h | 10 +- 44 mindspore/lite/src/runtime/inner_context.h | 2 +- 45 .../runtime/kernel/cpu/base/argminmax_base.cc | 1 - 46 .../kernel/cpu/base/arithmetic_base.cc | 1 + 47 .../kernel/cpu/base/group_convolution_base.cc | 16 +- 48 .../cpu/base/group_convolution_creator.cc | 14 +- 49 .../runtime/kernel/cpu/base/reshape_base.cc | 1 + 50 .../runtime/kernel/cpu/base/strided_slice.cc | 2 + 51 .../kernel/cpu/fp16/fused_batchnorm_fp16.cc | 8 +- 52 .../kernel/cpu/fp32/fused_batchnorm_fp32.cc | 16 +- 53 .../runtime/kernel/cpu/fp32/oneslike_fp32.cc | 52 +++++ 54 .../runtime/kernel/cpu/fp32/oneslike_fp32.h | 46 ++++ 55 .../cpu/fp32_grad/binary_cross_entropy.cc | 120 +++++++++++ 56 .../cpu/fp32_grad/binary_cross_entropy.h | 42 ++++ 57 .../fp32_grad/binary_cross_entropy_grad.cc | 105 +++++++++ 58 .../cpu/fp32_grad/binary_cross_entropy_grad.h | 41 ++++ 59 .../runtime/kernel/gpu/opencl/CMakeLists.txt | 11 + 60 .../src/runtime/kernel/opencl/CMakeLists.txt | 8 - 61 mindspore/lite/src/runtime/kernel_exec_util.h | 2 +- 62 mindspore/lite/src/runtime/kernel_registry.h | 4 +- 63 mindspore/lite/src/runtime/lite_kernel.h | 4 +- 64 mindspore/lite/src/runtime/lite_model.h | 4 +- 65 mindspore/lite/src/runtime/lite_session.cc | 28 ++- 66 mindspore/lite/src/runtime/lite_session.h | 10 +- 67 mindspore/lite/src/runtime/weight_decoder.h | 4 +- 68 mindspore/lite/src/tensor.h | 2 +- 69 mindspore/lite/src/tensorlist.h | 2 +- 70 mindspore/lite/src/train/graph_fusion.cc | 4 + 71 .../train/optimizer/common/fusion_utils.cc | 37 ++++ 72 .../src/train/optimizer/common/fusion_utils.h | 50 +++++ 
73 .../fusion/matmul_activation_fusion_pass.cc | 93 ++++++++ 74 .../fusion/matmul_activation_fusion_pass.h | 42 ++++ 75 .../reshape_gather_reshape_fusion_pass.cc | 148 +++++++++++++ 76 .../reshape_gather_reshape_fusion_pass.h | 42 ++++ 77 mindspore/lite/src/train/train_export.cc | 32 +++ 78 mindspore/lite/src/train/train_export.h | 4 + 79 .../src/train/train_populate_parameter.cc | 49 +++-- 80 mindspore/lite/src/train/train_session.cc | 80 ++++--- 81 mindspore/lite/src/train/train_session.h | 14 ++ 82 mindspore/lite/src/train/transfer_session.cc | 41 +++- 83 mindspore/lite/src/train/transfer_session.h | 5 + 84 .../lite/tools/benchmark_train/net_train.cc | 203 ++++++++++++++++-- 85 .../lite/tools/benchmark_train/net_train.h | 23 +- 86 .../lite/tools/converter/anf_transform.cc | 6 +- 87 mindspore/lite/tools/converter/converter.cc | 14 +- 88 .../tools/converter/graphdef_transform.cc | 4 + 89 .../legacy_optimizer/graph/CMakeLists.txt | 1 + 90 .../legacy_optimizer/graph/node_name_pass.cc | 96 +++++++++ 91 .../legacy_optimizer/graph/node_name_pass.h | 35 +++ 92 .../converter/parser/onnx/onnx_node_parser.cc | 10 + 93 .../lite/tools/lite_exporter/anf_exporter.cc | 14 +- 94 .../fusion/expanddims_reshape_fusion.cc | 73 +++++++ 95 .../fusion/expanddims_reshape_fusion.h | 40 ++++ 96 .../fusion/squeeze_expanddims_fusion.cc | 117 ++++++++++ 97 .../fusion/squeeze_expanddims_fusion.h | 40 ++++ 98 91 files changed, 1955 insertions(+), 351 deletions(-) 99 delete mode 100644 mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc 100 delete mode 100644 mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc 101 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc 102 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h 103 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc 104 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h 105 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc 106 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h 107 create mode 100644 mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt 108 delete mode 100644 mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt 109 create mode 100644 mindspore/lite/src/train/optimizer/common/fusion_utils.cc 110 create mode 100644 mindspore/lite/src/train/optimizer/common/fusion_utils.h 111 create mode 100644 mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc 112 create mode 100644 mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h 113 create mode 100644 mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc 114 create mode 100644 mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h 115 create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc 116 create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h 117 create mode 100644 mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc 118 create mode 100644 mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.h 119 create mode 100644 mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc 120 create mode 100644 mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.h 121 122diff --git a/include/api/callback/callback.h 
b/include/api/callback/callback.h 123index 3332f819..60f30b80 100644 124--- a/include/api/callback/callback.h 125+++ b/include/api/callback/callback.h 126@@ -1,5 +1,5 @@ 127 /** 128- * Copyright 2021 Huawei Technologies Co., Ltd 129+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 130 * 131 * Licensed under the Apache License, Version 2.0 (the "License"); 132 * you may not use this file except in compliance with the License. 133@@ -23,6 +23,7 @@ 134 #include <utility> 135 #include "include/api/data_type.h" 136 #include "include/api/dual_abi_helper.h" 137+#include "include/api/types.h" 138 139 namespace mindspore { 140 class Model; 141@@ -31,24 +32,19 @@ class CallbackImpl; 142 143 using GraphPoint = std::pair<int, float>; 144 145-struct TrainCallBackData { 146- TrainCallBackData(bool train_mode, int epoch, int step, Model *model): train_mode_(train_mode), epoch_(epoch), 147- step_(step), model_(model) {} 148+struct MS_API TrainCallBackData { 149+ TrainCallBackData(bool train_mode, int epoch, int step, Model *model) 150+ : train_mode_(train_mode), epoch_(epoch), step_(step), model_(model) {} 151 152 bool train_mode_; /**< training mode of LiteSession object */ 153 unsigned int epoch_; /**< the current training epoch (starts at 0) */ 154 unsigned int step_ = 0; /**< the current step within the epoch */ 155- Model *model_; /**< pointer to the Model object */ 156+ Model *model_; /**< pointer to the Model object */ 157 }; 158 159-enum CallbackRetValue : uint32_t { 160- kContinue = 0, 161- kStopTraining = 1, 162- kExit = 2, 163- kUnknownRetValue = 0xFFFFFFFF 164-}; 165+enum CallbackRetValue : uint32_t { kContinue = 0, kStopTraining = 1, kExit = 2, kUnknownRetValue = 0xFFFFFFFF }; 166 167-class TrainCallBack { 168+class MS_API TrainCallBack { 169 public: 170 virtual ~TrainCallBack() = default; 171 172@@ -90,7 +86,7 @@ class TrainCallBack { 173 protected: 174 friend class Model; 175 friend class ModelImpl; 176- CallbackImpl* callback_impl_ = nullptr; 177+ CallbackImpl *callback_impl_ = nullptr; 178 }; 179 180 } // namespace mindspore 181diff --git a/include/api/callback/ckpt_saver.h b/include/api/callback/ckpt_saver.h 182index e673c624..d9ff2f69 100644 183--- a/include/api/callback/ckpt_saver.h 184+++ b/include/api/callback/ckpt_saver.h 185@@ -1,5 +1,5 @@ 186 /** 187- * Copyright 2021 Huawei Technologies Co., Ltd 188+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 189 * 190 * Licensed under the Apache License, Version 2.0 (the "License"); 191 * you may not use this file except in compliance with the License. 192@@ -25,7 +25,7 @@ 193 194 namespace mindspore { 195 196-class CkptSaver: public TrainCallBack { 197+class MS_API CkptSaver : public TrainCallBack { 198 public: 199 inline CkptSaver(int save_every_n, const std::string &filename_prefix); 200 virtual ~CkptSaver(); 201diff --git a/include/api/callback/loss_monitor.h b/include/api/callback/loss_monitor.h 202index 9e0a8247..7efd0ca7 100644 203--- a/include/api/callback/loss_monitor.h 204+++ b/include/api/callback/loss_monitor.h 205@@ -1,5 +1,5 @@ 206 /** 207- * Copyright 2021 Huawei Technologies Co., Ltd 208+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 209 * 210 * Licensed under the Apache License, Version 2.0 (the "License"); 211 * you may not use this file except in compliance with the License. 
212@@ -23,7 +23,7 @@ 213 214 namespace mindspore { 215 216-class LossMonitor: public TrainCallBack { 217+class MS_API LossMonitor: public TrainCallBack { 218 public: 219 explicit LossMonitor(int print_every_n_steps = INT_MAX); 220 virtual ~LossMonitor(); 221diff --git a/include/api/callback/lr_scheduler.h b/include/api/callback/lr_scheduler.h 222index 2eddc66b..11fd7124 100644 223--- a/include/api/callback/lr_scheduler.h 224+++ b/include/api/callback/lr_scheduler.h 225@@ -1,5 +1,5 @@ 226 /** 227- * Copyright 2021 Huawei Technologies Co., Ltd 228+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 229 * 230 * Licensed under the Apache License, Version 2.0 (the "License"); 231 * you may not use this file except in compliance with the License. 232@@ -30,18 +30,18 @@ constexpr int UPDATE_LR = 1; 233 using LR_Lambda = std::function<int(float *lr, int epoch, void *cb_data)>; 234 235 /// \brief Multiply the LR by a factor of gamma every epoch 236-int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication); 237+MS_API int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication); 238 239 /// \brief Multiply the LR by a factor of gamma every step_size 240-int StepLRLambda(float *lr, int epoch, void *step_size); 241-struct StepLRLambda { 242+MS_API int StepLRLambda(float *lr, int epoch, void *step_size); 243+struct MS_API StepLRLambda { 244 StepLRLambda(int step, float g) : step_size(step), gamma(g) {} 245 246 int step_size; // period of LR decay 247 float gamma; // LR decay factor 248 }; 249 250-class LRScheduler: public TrainCallBack { 251+class MS_API LRScheduler : public TrainCallBack { 252 public: 253 explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step = 1); 254 virtual ~LRScheduler(); 255diff --git a/include/api/callback/time_monitor.h b/include/api/callback/time_monitor.h 256index 7e857849..45e48248 100644 257--- a/include/api/callback/time_monitor.h 258+++ b/include/api/callback/time_monitor.h 259@@ -1,5 +1,5 @@ 260 /** 261- * Copyright 2021 Huawei Technologies Co., Ltd 262+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 263 * 264 * Licensed under the Apache License, Version 2.0 (the "License"); 265 * you may not use this file except in compliance with the License. 266@@ -24,7 +24,7 @@ 267 268 namespace mindspore { 269 270-class TimeMonitor: public TrainCallBack { 271+class MS_API TimeMonitor : public TrainCallBack { 272 public: 273 virtual ~TimeMonitor() = default; 274 void EpochBegin(const TrainCallBackData &cb_data) override; 275diff --git a/include/api/callback/train_accuracy.h b/include/api/callback/train_accuracy.h 276index 5838dfd9..16774dd7 100644 277--- a/include/api/callback/train_accuracy.h 278+++ b/include/api/callback/train_accuracy.h 279@@ -1,5 +1,5 @@ 280 /** 281- * Copyright 2021 Huawei Technologies Co., Ltd 282+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 283 * 284 * Licensed under the Apache License, Version 2.0 (the "License"); 285 * you may not use this file except in compliance with the License. 
286@@ -26,12 +26,10 @@ 287 288 namespace mindspore { 289 290-class TrainAccuracy: public TrainCallBack { 291+class MS_API TrainAccuracy : public TrainCallBack { 292 public: 293- explicit TrainAccuracy(int print_every_n = INT_MAX, 294- int accuracy_metrics = METRICS_CLASSIFICATION, 295- const std::vector<int> &input_indexes = {1}, 296- const std::vector<int> &output_indexes = {0}); 297+ explicit TrainAccuracy(int print_every_n = INT_MAX, int accuracy_metrics = METRICS_CLASSIFICATION, 298+ const std::vector<int> &input_indexes = {1}, const std::vector<int> &output_indexes = {0}); 299 virtual ~TrainAccuracy(); 300 const std::vector<GraphPoint> &GetAccuracyPoints(); 301 }; 302diff --git a/include/api/cfg.h b/include/api/cfg.h 303index db915cac..8dc37bb4 100644 304--- a/include/api/cfg.h 305+++ b/include/api/cfg.h 306@@ -1,5 +1,5 @@ 307 /** 308- * Copyright 2022 Huawei Technologies Co., Ltd 309+ * Copyright 2022-2023 Huawei Technologies Co., Ltd 310 * 311 * Licensed under the Apache License, Version 2.0 (the "License"); 312 * you may not use this file except in compliance with the License. 313@@ -26,7 +26,7 @@ 314 315 namespace mindspore { 316 constexpr int iter_th = 1000; 317-class MixPrecisionCfg { 318+class MS_API MixPrecisionCfg { 319 public: 320 MixPrecisionCfg() { 321 this->dynamic_loss_scale_ = false; 322@@ -49,7 +49,7 @@ class MixPrecisionCfg { 323 bool is_raw_mix_precision_ = false; /**< Is mix precision model export from mindspore */ 324 }; 325 326-class TrainCfg { 327+class MS_API TrainCfg { 328 public: 329 TrainCfg() = default; 330 TrainCfg(const TrainCfg &rhs) { 331@@ -59,11 +59,24 @@ class TrainCfg { 332 } 333 ~TrainCfg() = default; 334 335+ /// \brief obtain part of the name that identify a loss kernel. 336+ /// 337+ /// \return loss_name. 338+ inline std::vector<std::string> GetLossName() const; 339+ /// \brief Set part of the name that identify a loss kernel. 340+ /// 341+ /// \param[in] loss_name define part of the name that identify a loss kernel. 342+ inline void SetLossName(const std::vector<std::string> &loss_name); 343+ 344 OptimizationLevel optimization_level_ = kO0; 345- std::vector<std::string> loss_name_ = { 346- "loss_fct", "_loss_fn", "SigmoidCrossEntropy"}; /**< Set part of the name that identify a loss kernel */ 347- MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */ 348+ MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */ 349 bool accumulate_gradients_ = false; 350+ 351+ private: 352+ std::vector<std::vector<char>> loss_name_ = VectorStringToChar({"loss_fct", "_loss_fn", "SigmoidCrossEntropy"}); 353 }; 354+ 355+std::vector<std::string> TrainCfg::GetLossName() const { return VectorCharToString(loss_name_); } 356+void TrainCfg::SetLossName(const std::vector<std::string> &loss_name) { loss_name_ = VectorStringToChar(loss_name); } 357 } // namespace mindspore 358 #endif // MINDSPORE_INCLUDE_API_CFG_H 359diff --git a/include/api/metrics/accuracy.h b/include/api/metrics/accuracy.h 360index 1d1732f3..4aefc3b5 100644 361--- a/include/api/metrics/accuracy.h 362+++ b/include/api/metrics/accuracy.h 363@@ -1,5 +1,5 @@ 364 /** 365- * Copyright 2021 Huawei Technologies Co., Ltd 366+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 367 * 368 * Licensed under the Apache License, Version 2.0 (the "License"); 369 * you may not use this file except in compliance with the License. 
370@@ -23,7 +23,7 @@ namespace mindspore { 371 constexpr int METRICS_CLASSIFICATION = 0; 372 constexpr int METRICS_MULTILABEL = 1; 373 374-class AccuracyMetrics : public Metrics { 375+class MS_API AccuracyMetrics : public Metrics { 376 public: 377 explicit AccuracyMetrics(int accuracy_metrics = METRICS_CLASSIFICATION, const std::vector<int> &input_indexes = {1}, 378 const std::vector<int> &output_indexes = {0}); 379diff --git a/include/api/metrics/metrics.h b/include/api/metrics/metrics.h 380index 7154332f..36eb4ed1 100644 381--- a/include/api/metrics/metrics.h 382+++ b/include/api/metrics/metrics.h 383@@ -1,5 +1,5 @@ 384 /** 385- * Copyright 2021 Huawei Technologies Co., Ltd 386+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 387 * 388 * Licensed under the Apache License, Version 2.0 (the "License"); 389 * you may not use this file except in compliance with the License. 390@@ -24,16 +24,17 @@ class MetricsImpl; 391 class ModelImpl; 392 class MSTensor; 393 394-class Metrics { 395+class MS_API Metrics { 396 public: 397 virtual ~Metrics() = default; 398 virtual void Clear() {} 399 virtual float Eval() { return 0.0; } 400 virtual void Update(std::vector<MSTensor *> inputs, std::vector<MSTensor *> outputs) {} 401+ 402 protected: 403 friend class Model; 404 friend class ModelImpl; 405- MetricsImpl* metrics_impl_; 406+ MetricsImpl *metrics_impl_; 407 }; 408 409 } // namespace mindspore 410diff --git a/include/api/model_parallel_runner.h b/include/api/model_parallel_runner.h 411index 159f4cea..360405b9 100644 412--- a/include/api/model_parallel_runner.h 413+++ b/include/api/model_parallel_runner.h 414@@ -1,5 +1,5 @@ 415 /** 416- * Copyright 2022 Huawei Technologies Co., Ltd 417+ * Copyright 2022-2023 Huawei Technologies Co., Ltd 418 * 419 * Licensed under the Apache License, Version 2.0 (the "License"); 420 * you may not use this file except in compliance with the License. 421@@ -25,7 +25,7 @@ 422 namespace mindspore { 423 /// \brief The RunnerConfig class is used to store environment variables during execution 424 /// management. 425-class RunnerConfig { 426+class MS_API RunnerConfig { 427 public: 428 struct Data; 429 RunnerConfig(); 430diff --git a/include/api/net.h b/include/api/net.h 431index c7a3a9b0..61990ae0 100644 432--- a/include/api/net.h 433+++ b/include/api/net.h 434@@ -1,5 +1,5 @@ 435 /** 436- * Copyright 2022 Huawei Technologies Co., Ltd 437+ * Copyright 2022-2023 Huawei Technologies Co., Ltd 438 * 439 * Licensed under the Apache License, Version 2.0 (the "License"); 440 * you may not use this file except in compliance with the License. 
441@@ -36,14 +36,14 @@ class NodeSet; 442 class Graph; 443 class NetData; 444 445-class NetBase { 446+class MS_API NetBase { 447 public: 448 NetBase() = default; 449 virtual std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) = 0; 450 virtual uint32_t type() = 0; 451 }; 452 453-class Node : public NetBase { 454+class MS_API Node : public NetBase { 455 public: 456 Node(); 457 virtual ~Node(); 458@@ -65,7 +65,7 @@ class Node : public NetBase { 459 std::shared_ptr<NodeImpl> impl_ = nullptr; 460 }; 461 462-class Net : public NetBase, public std::enable_shared_from_this<Net> { 463+class MS_API Net : public NetBase, public std::enable_shared_from_this<Net> { 464 public: 465 Net(); 466 virtual ~Net(); 467@@ -116,12 +116,12 @@ class Net : public NetBase, public std::enable_shared_from_this<Net> { 468 std::shared_ptr<NetImpl> impl_; 469 }; 470 471-class SoftMaxCrossEntropyCfg { 472+class MS_API SoftMaxCrossEntropyCfg { 473 public: 474 std::string reduction = "mean"; /**< Specifies reduction mode. The optional values are "none", "mean", "sum" */ 475 }; 476 477-class AdamConfig { 478+class MS_API AdamConfig { 479 public: 480 float learning_rate_ = 1e-3; 481 float beta1_ = 0.9; 482@@ -131,11 +131,12 @@ class AdamConfig { 483 }; 484 485 namespace NN { 486-Net *NetWithLoss(Net *net, Node *loss); 487-Graph *GraphWithLoss(Graph *g, Node *loss); 488-Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg); 489-Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg); 490-std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32, int fmt = NHWC); 491+MS_API Net *NetWithLoss(Net *net, Node *loss); 492+MS_API Graph *GraphWithLoss(Graph *g, Node *loss); 493+MS_API Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg); 494+MS_API Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg); 495+MS_API std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32, 496+ int fmt = NHWC); 497 }; // namespace NN 498 } // namespace mindspore 499 #endif // MINDSPORE_INCLUDE_API_NET_H 500diff --git a/include/api/serialization.h b/include/api/serialization.h 501index 2dc9d028..1a0c1f57 100644 502--- a/include/api/serialization.h 503+++ b/include/api/serialization.h 504@@ -79,10 +79,16 @@ class MS_API Serialization { 505 /// 506 /// \param[in] model The model data. 507 /// \param[in] model_type The model file type. 508- /// \param[out] model_data The model parameter data. 509+ /// \param[out] model_data The model buffer. 510+ /// \param[in] quantization_type The quantification type. 511+ /// \param[in] export_inference_only Whether to export a reasoning only model. 512+ /// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as 513+ /// empty, and export the complete reasoning model. 514 /// 515 /// \return Status. 516- static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); 517+ inline static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data, 518+ QuantizationType quantization_type = kNoQuant, bool export_inference_only = true, 519+ const std::vector<std::string> &output_tensor_name = {}); 520 521 /// \brief Export training model from file. 
522 /// 523@@ -110,6 +116,9 @@ class MS_API Serialization { 524 static Status ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file, 525 QuantizationType quantization_type, bool export_inference_only, 526 const std::vector<std::vector<char>> &output_tensor_name); 527+ static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data, 528+ QuantizationType quantization_type, bool export_inference_only, 529+ const std::vector<std::vector<char>> &output_tensor_name); 530 }; 531 532 Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, 533@@ -134,5 +143,12 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons 534 VectorStringToChar(output_tensor_name)); 535 } 536 537+Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data, 538+ QuantizationType quantization_type, bool export_inference_only, 539+ const std::vector<std::string> &output_tensor_name) { 540+ return ExportModel(model, model_type, model_data, quantization_type, export_inference_only, 541+ VectorStringToChar(output_tensor_name)); 542+} 543+ 544 } // namespace mindspore 545 #endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H 546diff --git a/include/api/types.h b/include/api/types.h 547index 6cf04523..377b5db0 100644 548--- a/include/api/types.h 549+++ b/include/api/types.h 550@@ -1,5 +1,5 @@ 551 /** 552- * Copyright 2020 Huawei Technologies Co., Ltd 553+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 554 * 555 * Licensed under the Apache License, Version 2.0 (the "License"); 556 * you may not use this file except in compliance with the License. 557@@ -350,7 +350,7 @@ std::string MSTensor::Name() const { return CharToString(CharName()); } 558 559 void MSTensor::SetTensorName(const std::string &name) { return SetTensorName(StringToChar(name)); } 560 561-using Key = struct Key { 562+using Key = struct MS_API Key { 563 const size_t max_key_len = 32; 564 size_t len = 0; 565 unsigned char key[32] = {0}; 566@@ -371,7 +371,7 @@ struct MSCallBackParam { 567 using MSKernelCallBack = std::function<bool(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs, 568 const MSCallBackParam &opInfo)>; 569 570-std::vector<char> CharVersion(); 571+MS_API std::vector<char> CharVersion(); 572 inline std::string Version() { return CharToString(CharVersion()); } 573 574 } // namespace mindspore 575diff --git a/mindspore/ccsrc/cxx_api/serialization.cc b/mindspore/ccsrc/cxx_api/serialization.cc 576index 1ea95935..bf33e85d 100644 577--- a/mindspore/ccsrc/cxx_api/serialization.cc 578+++ b/mindspore/ccsrc/cxx_api/serialization.cc 579@@ -334,7 +334,8 @@ Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model 580 return kMEFailed; 581 } 582 583-Status Serialization::ExportModel(const Model &, ModelType, Buffer *) { 584+Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool, 585+ const std::vector<std::vector<char>> & /* output_tensor_name */) { 586 MS_LOG(ERROR) << "Unsupported feature."; 587 return kMEFailed; 588 } 589diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c 590index f6fa5994..60a27df1 100644 591--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c 592+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c 593@@ -29,10 +29,6 @@ int ElementOptDiv(const float *in0, const float *in1, float 
*out, int size, cons 594 out[index] = in0[0] / in1[index]; 595 } 596 } else { 597- if (in1[0] == 0) { 598- return NNACL_ERRCODE_DIVISOR_ZERO; 599- } 600- 601 SIMD_RUN_NO_SCALAR(ElementOptDivNum1, index, in0, in1, out, size); 602 for (; index < size; index++) { 603 out[index] = in0[index] / in1[0]; 604diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c 605index 2db54161..cf2f867c 100644 606--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c 607+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c 608@@ -18,28 +18,44 @@ 609 #include "nnacl/fp32_grad/binary_cross_entropy.h" 610 611 static void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const float *input_x, 612- const float *input_y, const float *weight, float *loss, float *tmp_loss) { 613+ const float *input_y, const float *weight, float *loss, float *tmp_loss, 614+ bool weight_defined) { 615 const float epsilon = 1e-12; 616- if (reduction == 0) { 617- for (int i = 0; i < input_size; i++) { 618- float value = 619- -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon)); 620- loss[i] = value; 621+ 622+ if (reduction == Reduction_None) { 623+ if (weight_defined) { 624+ for (int i = 0; i < input_size; i++) { 625+ float value = 626+ -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon)); 627+ loss[i] = value; 628+ } 629+ } else { 630+ for (int i = 0; i < input_size; i++) { 631+ float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon)); 632+ loss[i] = value; 633+ } 634 } 635 } else { 636- for (int i = 0; i < input_size; i++) { 637- float value = 638- -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon)); 639- tmp_loss[i] = value; 640+ if (weight_defined) { 641+ for (int i = 0; i < input_size; i++) { 642+ float value = 643+ -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon)); 644+ tmp_loss[i] = value; 645+ } 646+ } else { 647+ for (int i = 0; i < input_size; i++) { 648+ float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon)); 649+ tmp_loss[i] = value; 650+ } 651 } 652 } 653 } 654 655 void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y, 656- const float *weight, float *loss, float *tmp_loss) { 657+ const float *weight, float *loss, float *tmp_loss, bool weight_defined) { 658 loss[0] = 0.0f; 659- BinaryCrossEntropyLossKernel(input_size, reduction, input_x, input_y, weight, loss, tmp_loss); 660- if (reduction != 0) { 661+ BinaryCrossEntropyLossKernel(input_size, reduction, input_x, input_y, weight, loss, tmp_loss, weight_defined); 662+ if (reduction != Reduction_None) { 663 if (input_size % 2 == 1) { 664 tmp_loss[0] += tmp_loss[input_size - 1]; 665 } 666@@ -47,13 +63,12 @@ void BinaryCrossEntropy(const int input_size, const int reduction, const float * 667 for (int i = 0; i < stride; i++) { 668 tmp_loss[i] += tmp_loss[i + stride]; 669 } 670- 671 if (stride > 2 && stride % 2 == 1) { 672 tmp_loss[0] += tmp_loss[stride - 1]; 673 } 674 } 675 loss[0] += tmp_loss[0]; 676- if (reduction == 1) { 677+ if (reduction == Reduction_Mean) { 678 loss[0] /= input_size; 679 } 680 } 
681diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h 682index 6ba6422d..abf6e63b 100644 683--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h 684+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h 685@@ -28,7 +28,7 @@ extern "C" { 686 #endif 687 688 void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y, 689- const float *weight, float *loss, float *tmp_loss); 690+ const float *weight, float *loss, float *tmp_loss, bool weight_defined); 691 692 #ifdef __cplusplus 693 } 694diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c 695index 95e28c8c..12d20356 100644 696--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c 697+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c 698@@ -19,23 +19,37 @@ 699 #define MAX(a, b) ((a) > (b) ? (a) : (b)) 700 701 int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y, 702- const float *weight, const float *dloss, float *dx) { 703+ const float *weight, const float *dloss, float *dx, bool weight_defined) { 704 const float epsilon = 1e-12f; 705- if (reduction == 0) { 706- for (int i = 0; i < input_size; i++) { 707- float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); 708- float value = weight[i] * (input_x[i] - input_y[i]) / denominator; 709- dx[i] = value * dloss[i]; 710+ if (reduction == Reduction_None) { 711+ if (weight_defined) { 712+ for (int i = 0; i < input_size; i++) { 713+ float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); 714+ float value = weight[i] * (input_x[i] - input_y[i]) / denominator; 715+ dx[i] = value * dloss[i]; 716+ } 717+ } else { 718+ for (int i = 0; i < input_size; i++) { 719+ float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); 720+ float value = (input_x[i] - input_y[i]) / denominator; 721+ dx[i] = value * dloss[i]; 722+ } 723 } 724 } else { 725 float dloss1 = dloss[0]; 726- if (reduction == 1) { 727+ if (reduction == Reduction_Mean) { 728 dloss1 = dloss[0] / input_size; 729 } 730 for (int i = 0; i < input_size; i++) { 731- float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); 732- float value = weight[i] * (input_x[i] - input_y[i]) / denominator; 733- dx[i] = value * dloss1; 734+ if (weight_defined) { 735+ float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); 736+ float value = weight[i] * (input_x[i] - input_y[i]) / denominator; 737+ dx[i] = value * dloss1; 738+ } else { 739+ float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); 740+ float value = (input_x[i] - input_y[i]) / denominator; 741+ dx[i] = value * dloss1; 742+ } 743 } 744 } 745 return 0; 746diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h 747index f3506f4f..3033fa98 100644 748--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h 749+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h 750@@ -28,7 +28,7 @@ extern "C" { 751 #endif 752 753 int BinaryCrossEntropyGrad(const int input_size, const 
int reduction, const float *input_x, const float *input_y, 754- const float *weight, const float *dloss, float *dx); 755+ const float *weight, const float *dloss, float *dx, bool weight_defined); 756 757 #ifdef __cplusplus 758 } 759diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c 760index e280ad2e..22e207ac 100644 761--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c 762+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c 763@@ -27,8 +27,8 @@ int BinaryCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_siz 764 TensorC *out = outputs[0]; 765 SetDataTypeFormat(out, x); 766 BinaryCrossEntropyParameter *param = (BinaryCrossEntropyParameter *)parameter; 767- int reduction = param->reduction; 768- if (reduction == 1 || reduction == 2) { 769+ ReductionType reduction = (ReductionType)(param->reduction); 770+ if (reduction == Reduction_Mean || reduction == Reduction_Sum) { 771 out->shape_size_ = 1; 772 out->shape_[0] = 1; 773 } else { 774diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c 775index 3073385f..875c3bc0 100644 776--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c 777+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c 778@@ -237,6 +237,7 @@ REG_INFER(LRN, PrimType_LRN, CommonInferShapeWithNHWC) 779 REG_INFER(L2Normalize, PrimType_L2NormalizeFusion, CommonInferShape) 780 REG_INFER(Neg, PrimType_Neg, CommonInferShape) 781 REG_INFER(NegGrad, PrimType_NegGrad, CommonGradInferShape) 782+REG_INFER(OnesLike, PrimType_OnesLike, CommonInferShape) 783 REG_INFER(PowerGrad, PrimType_PowerGrad, CommonGradInferShape) 784 REG_INFER(PReLU, PrimType_PReLUFusion, CommonInferShape) 785 REG_INFER(Reciprocal, PrimType_Reciprocal, CommonInferShape) 786diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h 787index a54904c8..a635e745 100644 788--- a/mindspore/lite/include/model.h 789+++ b/mindspore/lite/include/model.h 790@@ -1,5 +1,5 @@ 791 /** 792- * Copyright 2020 Huawei Technologies Co., Ltd 793+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 794 * 795 * Licensed under the Apache License, Version 2.0 (the "License"); 796 * you may not use this file except in compliance with the License. 797@@ -19,6 +19,7 @@ 798 #include <memory> 799 #include <string> 800 #include <vector> 801+#include "include/api/types.h" 802 803 namespace mindspore { 804 namespace schema { 805@@ -30,7 +31,7 @@ typedef enum { ModelType_MSLite, ModelType_MindIR } LiteModelType; 806 807 // LiteGraph can be considered as a light weight and subset of FuncGraph, it can not support the advanced expression of 808 // FuncGraph, e.g., non-tail recursive. 
809-struct LiteGraph { 810+struct MS_API LiteGraph { 811 struct Node { 812 std::string name_; 813 std::string op_type_; 814@@ -66,7 +67,7 @@ struct LiteGraph { 815 std::string ToString() const; 816 }; 817 818-struct Model { 819+struct MS_API Model { 820 LiteGraph graph_; 821 char *buf = nullptr; 822 size_t buf_size_ = 0; 823diff --git a/mindspore/lite/include/registry/opencl_runtime_wrapper.h b/mindspore/lite/include/registry/opencl_runtime_wrapper.h 824index 5b95a8a3..b55554e4 100644 825--- a/mindspore/lite/include/registry/opencl_runtime_wrapper.h 826+++ b/mindspore/lite/include/registry/opencl_runtime_wrapper.h 827@@ -1,5 +1,5 @@ 828 /** 829- * Copyright 2021 Huawei Technologies Co., Ltd 830+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 831 * 832 * Licensed under the Apache License, Version 2.0 (the "License"); 833 * you may not use this file except in compliance with the License. 834@@ -30,7 +30,7 @@ 835 #include "include/api/dual_abi_helper.h" 836 837 namespace mindspore::registry::opencl { 838-class OpenCLRuntimeWrapper { 839+class MS_API OpenCLRuntimeWrapper { 840 public: 841 OpenCLRuntimeWrapper() = default; 842 ~OpenCLRuntimeWrapper() = default; 843diff --git a/mindspore/lite/java/src/main/native/train_config.cpp b/mindspore/lite/java/src/main/native/train_config.cpp 844index 4177e96b..d2452acf 100644 845--- a/mindspore/lite/java/src/main/native/train_config.cpp 846+++ b/mindspore/lite/java/src/main/native/train_config.cpp 847@@ -1,5 +1,5 @@ 848 /** 849- * Copyright 2021 Huawei Technologies Co., Ltd 850+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 851 * 852 * Licensed under the Apache License, Version 2.0 (the "License"); 853 * you may not use this file except in compliance with the License. 854@@ -50,7 +50,9 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_TrainCfg_createTrai 855 return (jlong) nullptr; 856 } 857 if (loss_name != nullptr) { 858- traincfg_ptr->loss_name_.emplace_back(env->GetStringUTFChars(loss_name, JNI_FALSE)); 859+ std::vector<std::string> traincfg_loss_name = traincfg_ptr->GetLossName(); 860+ traincfg_loss_name.emplace_back(env->GetStringUTFChars(loss_name, JNI_FALSE)); 861+ traincfg_ptr->SetLossName(traincfg_loss_name); 862 } 863 traincfg_ptr->optimization_level_ = ol; 864 traincfg_ptr->accumulate_gradients_ = accmulateGrads; 865diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt 866index 16ae2e63..48e0fe7c 100644 867--- a/mindspore/lite/src/CMakeLists.txt 868+++ b/mindspore/lite/src/CMakeLists.txt 869@@ -50,6 +50,11 @@ if(APPLE OR PLATFORM_ARM32 OR PLATFORM_ARM64) 870 -fdata-sections -ffast-math -fno-rtti -fno-exceptions -Wno-shorten-64-to-32 \ 871 -fno-aligned-allocation -DTARGET_OS_OSX") 872 endif() 873+ if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND NOT MSLITE_ENABLE_TESTCASES) 874+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility-inlines-hidden -fvisibility=hidden") 875+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility-inlines-hidden -fvisibility=hidden") 876+ set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--gc-sections") 877+ endif() 878 elseif(NOT MSVC) 879 if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") 880 set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing -ffunction-sections \ 881@@ -312,16 +317,6 @@ set(LITE_SRC 882 ${CMAKE_CURRENT_SOURCE_DIR}/runtime/weight_decoder.cc 883 ) 884 885-if(MSLITE_GPU_BACKEND STREQUAL opencl) 886- file(GLOB_RECURSE OPENCL_RUNTIME_SRC 887- ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/gpu/opencl/*.cc 888- ) 889- 
set(LITE_SRC 890- ${LITE_SRC} 891- ${OPENCL_RUNTIME_SRC} 892- ) 893-endif() 894- 895 if(MSLITE_GPU_BACKEND STREQUAL cuda) 896 file(GLOB CUDA_RUNTIME_SRC 897 ${CMAKE_CURRENT_SOURCE_DIR}/runtime/gpu/*.cc 898@@ -384,6 +379,9 @@ set(TRAIN_SRC 899 ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc 900 ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc 901 ${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc 902+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/common/fusion_utils.cc 903+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_activation_fusion_pass.cc 904+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc 905 ${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc 906 ${TOOLS_DIR}/converter/optimizer.cc 907 ${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc 908@@ -393,7 +391,7 @@ set(TRAIN_SRC 909 ${TOOLS_DIR}/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc 910 ${TOOLS_DIR}/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc 911 ${TOOLS_DIR}/converter/legacy_optimizer/graph/subgraph_node_pass.cc 912- ) 913+ train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc) 914 915 if(MSLITE_ENABLE_MINDRT) 916 add_subdirectory(${CORE_DIR}/mindrt mindspore_mindrt) 917@@ -527,7 +525,7 @@ else() 918 endif() 919 920 if(MSLITE_GPU_BACKEND STREQUAL opencl) 921- add_subdirectory(runtime/kernel/opencl) 922+ add_subdirectory(runtime/kernel/gpu/opencl) 923 target_link_libraries(mindspore-lite opencl_kernel_mid) 924 target_link_libraries(mindspore-lite_static opencl_kernel_mid) 925 endif() 926diff --git a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc b/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc 927deleted file mode 100644 928index 5da193bc..00000000 929--- a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc 930+++ /dev/null 931@@ -1,45 +0,0 @@ 932-/** 933- * Copyright 2019-2021 Huawei Technologies Co., Ltd 934- * 935- * Licensed under the Apache License, Version 2.0 (the "License"); 936- * you may not use this file except in compliance with the License. 937- * You may obtain a copy of the License at 938- * 939- * http://www.apache.org/licenses/LICENSE-2.0 940- * 941- * Unless required by applicable law or agreed to in writing, software 942- * distributed under the License is distributed on an "AS IS" BASIS, 943- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 944- * See the License for the specific language governing permissions and 945- * limitations under the License. 
946- */ 947-#include "src/common/ops/populate/populate_register.h" 948-#include "nnacl/fp32_grad/binary_cross_entropy_grad.h" 949-using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad; 950- 951-namespace mindspore { 952-namespace lite { 953-OpParameter *PopulateBinaryCrossEntropyGradParameter(const void *prim) { 954- auto *primitive = static_cast<const schema::Primitive *>(prim); 955- MS_ASSERT(primitive != nullptr); 956- auto value = primitive->value_as_BinaryCrossEntropyGrad(); 957- if (value == nullptr) { 958- MS_LOG(ERROR) << "param is nullptr"; 959- return nullptr; 960- } 961- 962- auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter))); 963- if (param == nullptr) { 964- MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed."; 965- return nullptr; 966- } 967- memset(param, 0, sizeof(BinaryCrossEntropyGradParameter)); 968- 969- param->op_parameter_.type_ = primitive->value_type(); 970- param->reduction = value->reduction(); 971- return reinterpret_cast<OpParameter *>(param); 972-} 973- 974-REG_POPULATE(PrimitiveType_BinaryCrossEntropyGrad, PopulateBinaryCrossEntropyGradParameter, SCHEMA_CUR); 975-} // namespace lite 976-} // namespace mindspore 977diff --git a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc b/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc 978deleted file mode 100644 979index 10060d3f..00000000 980--- a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc 981+++ /dev/null 982@@ -1,45 +0,0 @@ 983-/** 984- * Copyright 2019-2021 Huawei Technologies Co., Ltd 985- * 986- * Licensed under the Apache License, Version 2.0 (the "License"); 987- * you may not use this file except in compliance with the License. 988- * You may obtain a copy of the License at 989- * 990- * http://www.apache.org/licenses/LICENSE-2.0 991- * 992- * Unless required by applicable law or agreed to in writing, software 993- * distributed under the License is distributed on an "AS IS" BASIS, 994- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 995- * See the License for the specific language governing permissions and 996- * limitations under the License. 
997- */ 998-#include "src/common/ops/populate/populate_register.h" 999-#include "nnacl/fp32_grad/binary_cross_entropy.h" 1000-using mindspore::schema::PrimitiveType_BinaryCrossEntropy; 1001- 1002-namespace mindspore { 1003-namespace lite { 1004-OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) { 1005- auto primitive = static_cast<const schema::Primitive *>(prim); 1006- MS_ASSERT(primitive != nullptr); 1007- auto value = primitive->value_as_BinaryCrossEntropy(); 1008- if (value == nullptr) { 1009- MS_LOG(ERROR) << "value is nullptr"; 1010- return nullptr; 1011- } 1012- 1013- auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter))); 1014- if (param == nullptr) { 1015- MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed."; 1016- return nullptr; 1017- } 1018- memset(param, 0, sizeof(BinaryCrossEntropyParameter)); 1019- 1020- param->op_parameter_.type_ = primitive->value_type(); 1021- param->reduction = value->reduction(); 1022- return reinterpret_cast<OpParameter *>(param); 1023-} 1024- 1025-REG_POPULATE(PrimitiveType_BinaryCrossEntropy, PopulateBinaryCrossEntropyParameter, SCHEMA_CUR); 1026-} // namespace lite 1027-} // namespace mindspore 1028diff --git a/mindspore/lite/src/common/prim_util.h b/mindspore/lite/src/common/prim_util.h 1029index 2714b6df..38733f82 100644 1030--- a/mindspore/lite/src/common/prim_util.h 1031+++ b/mindspore/lite/src/common/prim_util.h 1032@@ -1,5 +1,5 @@ 1033 /** 1034- * Copyright 2020 Huawei Technologies Co., Ltd 1035+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 1036 * 1037 * Licensed under the Apache License, Version 2.0 (the "License"); 1038 * you may not use this file except in compliance with the License. 1039@@ -16,13 +16,14 @@ 1040 1041 #ifndef MINDSPORE_LITE_SRC_COMMON_PRIM_UTIL_H_ 1042 #define MINDSPORE_LITE_SRC_COMMON_PRIM_UTIL_H_ 1043+#include "include/api/types.h" 1044 1045 namespace mindspore { 1046 namespace lite { 1047-int GetPrimitiveType(const void *primitive, int schema_version); 1048+MS_API int GetPrimitiveType(const void *primitive, int schema_version); 1049 const char *GetPrimitiveTypeName(const void *primitive, int schema_version); 1050 const char *PrimitiveCurVersionTypeName(int type); 1051-int GenPrimVersionKey(int primitive_type, int schema_version); 1052+MS_API int GenPrimVersionKey(int primitive_type, int schema_version); 1053 bool IsPartialNode(const void *primitive, int schema_version); 1054 bool IsCallNode(const void *primitive, int schema_version); 1055 bool IsSwitchNode(const void *primitive, int schema_version); 1056diff --git a/mindspore/lite/src/common/tensor_util.h b/mindspore/lite/src/common/tensor_util.h 1057index 6e8ac3af..caced545 100644 1058--- a/mindspore/lite/src/common/tensor_util.h 1059+++ b/mindspore/lite/src/common/tensor_util.h 1060@@ -41,7 +41,7 @@ int GenerateInTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<Ten 1061 std::shared_ptr<Allocator> allocator = nullptr); 1062 int GenerateOutTensorC(const OpParameter *const parameter, const std::vector<lite::Tensor *> &outputs, 1063 std::vector<TensorC *> *out_tensor_c, std::shared_ptr<Allocator> allocator = nullptr); 1064-int CheckTensorsInvalid(const std::vector<Tensor *> &tensors); 1065+MS_API int CheckTensorsInvalid(const std::vector<Tensor *> &tensors); 1066 int CheckGraphInputShapes(const std::vector<Tensor *> &inputs, 1067 const std::unordered_map<Tensor *, std::vector<int>> &input_shape_map); 1068 std::vector<mindspore::MSTensor> LiteTensorsToMSTensors(const 
std::vector<lite::Tensor *> &lite_tensors); 1069diff --git a/mindspore/lite/src/extendrt/CMakeLists.txt b/mindspore/lite/src/extendrt/CMakeLists.txt 1070index 4f43b01f..70d734be 100644 1071--- a/mindspore/lite/src/extendrt/CMakeLists.txt 1072+++ b/mindspore/lite/src/extendrt/CMakeLists.txt 1073@@ -1,3 +1,7 @@ 1074+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") 1075+string(REPLACE "-fvisibility=hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") 1076+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 1077+string(REPLACE "-fvisibility=hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 1078 set(MODEL_LOADER_FRAMEWORK_SRC 1079 ${MODEL_LOADER_FRAMEWORK_SRC} 1080 ${CMAKE_CURRENT_SOURCE_DIR}/mindir_loader/model_loader.cc 1081diff --git a/mindspore/lite/src/extendrt/cxx_api/serialization.cc b/mindspore/lite/src/extendrt/cxx_api/serialization.cc 1082index 344cfca7..c1e3d065 100644 1083--- a/mindspore/lite/src/extendrt/cxx_api/serialization.cc 1084+++ b/mindspore/lite/src/extendrt/cxx_api/serialization.cc 1085@@ -332,7 +332,8 @@ Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model 1086 return kMEFailed; 1087 } 1088 1089-Status Serialization::ExportModel(const Model &, ModelType, Buffer *) { 1090+Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool, 1091+ const std::vector<std::vector<char>> & /* output_tensor_name */) { 1092 MS_LOG(ERROR) << "Unsupported feature."; 1093 return kMEFailed; 1094 } 1095diff --git a/mindspore/lite/src/runtime/cxx_api/converters.h b/mindspore/lite/src/runtime/cxx_api/converters.h 1096index bd7daabb..45ed6a5b 100644 1097--- a/mindspore/lite/src/runtime/cxx_api/converters.h 1098+++ b/mindspore/lite/src/runtime/cxx_api/converters.h 1099@@ -1,5 +1,5 @@ 1100 /** 1101- * Copyright 2021 Huawei Technologies Co., Ltd 1102+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 1103 * 1104 * Licensed under the Apache License, Version 2.0 (the "License"); 1105 * you may not use this file except in compliance with the License. 
1106@@ -27,7 +27,7 @@ 1107 #include "src/runtime/c_api/context_c.h" 1108 1109 namespace mindspore { 1110-class ContextUtils { 1111+class MS_API ContextUtils { 1112 public: 1113 static lite::InnerContext *Convert(Context *context); 1114 static lite::InnerContext *Convert(const ContextC *context_c); 1115diff --git a/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc b/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc 1116index e7a0e272..f5f275e4 100644 1117--- a/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc 1118+++ b/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc 1119@@ -717,6 +717,9 @@ Status ModelImpl::UpdateWeights(const std::vector<MSTensor> &new_weights) { 1120 inner_weights[i] = lite_impl->lite_tensor(); 1121 } 1122 auto ret = session_->UpdateWeights(inner_weights); 1123+ if (ret != kSuccess) { 1124+ MS_LOG(ERROR) << "UpdateWeights failed, and the origin weights may have been changed."; 1125+ } 1126 return static_cast<StatusCode>(ret); 1127 } 1128 1129diff --git a/mindspore/lite/src/runtime/cxx_api/model/model_impl.h b/mindspore/lite/src/runtime/cxx_api/model/model_impl.h 1130index 3d359f14..5c572883 100644 1131--- a/mindspore/lite/src/runtime/cxx_api/model/model_impl.h 1132+++ b/mindspore/lite/src/runtime/cxx_api/model/model_impl.h 1133@@ -1,5 +1,5 @@ 1134 /** 1135- * Copyright 2021 Huawei Technologies Co., Ltd 1136+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 1137 * 1138 * Licensed under the Apache License, Version 2.0 (the "License"); 1139 * you may not use this file except in compliance with the License. 1140@@ -47,10 +47,10 @@ namespace mindspore { 1141 typedef std::shared_ptr<lite::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data, 1142 std::shared_ptr<TrainCfg> cfg, 1143 lite::InnerContext *context); 1144-CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr); 1145+MS_API CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr); 1146 1147 using ExpressionLoader = std::function<Status(const char *, Graph *)>; 1148-ExpressionLoader CreateExpressionLoader(ExpressionLoader loader = nullptr); 1149+MS_API ExpressionLoader CreateExpressionLoader(ExpressionLoader loader = nullptr); 1150 1151 namespace session { 1152 class Metrics; 1153diff --git a/mindspore/lite/src/runtime/cxx_api/serialization.cc b/mindspore/lite/src/runtime/cxx_api/serialization.cc 1154index 3db32826..8405f4b2 100644 1155--- a/mindspore/lite/src/runtime/cxx_api/serialization.cc 1156+++ b/mindspore/lite/src/runtime/cxx_api/serialization.cc 1157@@ -157,9 +157,34 @@ Status Serialization::SetParameters(const std::map<std::string, Buffer> ¶met 1158 return kMEFailed; 1159 } 1160 1161-Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) { 1162- MS_LOG(ERROR) << "Unsupported feature."; 1163- return kMEFailed; 1164+Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data, 1165+ QuantizationType quantization_type, bool export_inference_only, 1166+ const std::vector<std::vector<char>> &output_tensor_name) { 1167+ if (model.impl_ == nullptr) { 1168+ MS_LOG(ERROR) << "Model implement is null."; 1169+ return kLiteUninitializedObj; 1170+ } 1171+ if (!model.impl_->IsTrainModel()) { 1172+ MS_LOG(ERROR) << "Model is not TrainModel."; 1173+ return kLiteError; 1174+ } 1175+ if (model_data == nullptr) { 1176+ MS_LOG(ERROR) << "model_data is nullptr."; 1177+ return kLiteParamInvalid; 1178+ } 1179+ if 
(model_type != kMindIR && model_type != kMindIR_Lite) { 1180+ MS_LOG(ERROR) << "Unsupported Export Format " << model_type; 1181+ return kLiteParamInvalid; 1182+ } 1183+ if (model.impl_->session_ == nullptr) { 1184+ MS_LOG(ERROR) << "Model session is nullptr."; 1185+ return kLiteError; 1186+ } 1187+ auto ret = model.impl_->session_->Export(model_data, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN, 1188+ A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, 1189+ VectorCharToString(output_tensor_name)); 1190+ 1191+ return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError; 1192 } 1193 1194 Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file, 1195diff --git a/mindspore/lite/src/runtime/cxx_api/train/converters.cc b/mindspore/lite/src/runtime/cxx_api/train/converters.cc 1196index 694259b3..b0801804 100644 1197--- a/mindspore/lite/src/runtime/cxx_api/train/converters.cc 1198+++ b/mindspore/lite/src/runtime/cxx_api/train/converters.cc 1199@@ -1,5 +1,5 @@ 1200 /** 1201- * Copyright 2021 Huawei Technologies Co., Ltd 1202+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 1203 * 1204 * Licensed under the Apache License, Version 2.0 (the "License"); 1205 * you may not use this file except in compliance with the License. 1206@@ -25,8 +25,8 @@ Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cf 1207 return kLiteNullptr; 1208 } 1209 1210- l_train_cfg->loss_name_.clear(); 1211- l_train_cfg->loss_name_.assign(a_train_cfg->loss_name_.begin(), a_train_cfg->loss_name_.end()); 1212+ std::vector<std::string> a_loss_name = a_train_cfg->GetLossName(); 1213+ l_train_cfg->loss_name_.assign(a_loss_name.begin(), a_loss_name.end()); 1214 l_train_cfg->mix_precision_cfg_.dynamic_loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_; 1215 l_train_cfg->mix_precision_cfg_.loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_; 1216 l_train_cfg->mix_precision_cfg_.keep_batchnorm_fp32_ = (a_train_cfg->optimization_level_ != kO3); 1217diff --git a/mindspore/lite/src/runtime/infer_manager.h b/mindspore/lite/src/runtime/infer_manager.h 1218index 31da532e..a851b7d2 100644 1219--- a/mindspore/lite/src/runtime/infer_manager.h 1220+++ b/mindspore/lite/src/runtime/infer_manager.h 1221@@ -31,11 +31,11 @@ 1222 #include "include/api/allocator.h" 1223 1224 namespace mindspore::lite { 1225-int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs, 1226- OpParameter *parameter, std::shared_ptr<Allocator> allocator = nullptr); 1227-int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, 1228- const void *primitive, std::set<std::string> &&providers, int schema_version, 1229- const kernel::Kernel *kernel = nullptr); 1230+MS_API int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs, 1231+ OpParameter *parameter, std::shared_ptr<Allocator> allocator = nullptr); 1232+MS_API int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, 1233+ const void *primitive, std::set<std::string> &&providers, int schema_version, 1234+ const kernel::Kernel *kernel = nullptr); 1235 class InferManager { 1236 public: 1237 static InferManager *GetInstance() { 1238diff --git a/mindspore/lite/src/runtime/inner_context.h b/mindspore/lite/src/runtime/inner_context.h 1239index adbeacbf..ff58995f 100644 1240--- 
a/mindspore/lite/src/runtime/inner_context.h 1241+++ b/mindspore/lite/src/runtime/inner_context.h 1242@@ -35,7 +35,7 @@ namespace mindspore::lite { 1243 #ifdef ENABLE_MINDRT 1244 constexpr int kDefaultParallelNum = 2; 1245 #endif 1246-struct InnerContext : public Context { 1247+struct MS_API InnerContext : public Context { 1248 public: 1249 InnerContext() { InitDeviceFp16(); } 1250 1251diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc 1252index 843fc0c9..5b94867b 100644 1253--- a/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc 1254+++ b/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc 1255@@ -54,7 +54,6 @@ int ArgMinMaxCPUKernel::ReSize() { 1256 ComputeStrides(in_shape.data(), arg_param_->in_strides_, in_shape.size()); 1257 CHECK_NULL_RETURN(out_tensors_.at(0)); 1258 auto out_shape = out_tensors_.at(0)->shape(); 1259- CHECK_NULL_RETURN(out_shape.data()); 1260 ComputeStrides(out_shape.data(), arg_param_->out_strides_, out_shape.size()); 1261 return RET_OK; 1262 } 1263diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc 1264index 68f5cce3..14c97bf8 100644 1265--- a/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc 1266+++ b/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc 1267@@ -285,6 +285,7 @@ void ArithmeticBaseCPUKernel::ComputeOfflineInfo() { 1268 c_matric_.batch_post_sum[i] = c_matric_.shape[i] * c_matric_.batch_post_sum[i + 1]; 1269 } 1270 } 1271+ scalar_opt_ = false; 1272 if (a_matric_.inner_size == 1) { 1273 param_->in_elements_num0_ = 1; 1274 scalar_opt_ = true; 1275diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc 1276index aa50a916..b5370ddd 100644 1277--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc 1278+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc 1279@@ -50,8 +50,6 @@ int GroupConvolutionBaseCPUKernel::ReSize() { 1280 if (group_num_ == 0) { 1281 return RET_ERROR; 1282 } 1283- conv_param_->input_channel_ /= group_num_; 1284- conv_param_->output_channel_ /= group_num_; 1285 return RET_OK; 1286 } 1287 1288@@ -96,7 +94,8 @@ int GroupConvolutionBaseCPUKernel::PreProcess() { 1289 // in 1290 auto in_tensor = in_tensors_.front(); 1291 CHECK_NULL_RETURN(in_tensor); 1292- in_shape = {in_tensor->Batch(), in_tensor->Height(), in_tensor->Width(), conv_param_->input_channel_}; 1293+ in_shape = {in_tensor->Batch(), in_tensor->Height(), in_tensor->Width(), 1294+ conv_param_->input_channel_ / group_num_}; 1295 auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front(); 1296 CHECK_NULL_RETURN(sub_kernel_in_tensor); 1297 sub_kernel_in_tensor->set_shape(in_shape); 1298@@ -108,7 +107,8 @@ int GroupConvolutionBaseCPUKernel::PreProcess() { 1299 // out 1300 auto out_tensor = out_tensors_.front(); 1301 CHECK_NULL_RETURN(out_tensor); 1302- out_shape = {out_tensor->Batch(), out_tensor->Height(), out_tensor->Width(), conv_param_->output_channel_}; 1303+ out_shape = {out_tensor->Batch(), out_tensor->Height(), out_tensor->Width(), 1304+ conv_param_->output_channel_ / group_num_}; 1305 auto sub_kernel_out_tensors = group_convs_.at(i)->out_tensors(); 1306 for (auto tensor : sub_kernel_out_tensors) { 1307 CHECK_NULL_RETURN(tensor); 1308@@ -148,8 +148,8 @@ int GroupConvolutionBaseCPUKernel::InitGroupParam() { 1309 
MS_LOG(ERROR) << "get in_plane_ from in_tensor failed."; 1310 return RET_ERROR; 1311 } 1312- sub_in_channel_ = conv_param_->input_channel_; 1313- ori_in_channel_ = sub_in_channel_ * group_num_; 1314+ sub_in_channel_ = conv_param_->input_channel_ / group_num_; 1315+ ori_in_channel_ = conv_param_->input_channel_; 1316 in_thread_num_ = MSMIN(MSMAX(1, ctx_->thread_num_), in_plane_); 1317 1318 auto out_tensor = out_tensors_.front(); 1319@@ -159,8 +159,8 @@ int GroupConvolutionBaseCPUKernel::InitGroupParam() { 1320 MS_LOG(ERROR) << "get out_plane_ from out_tensor failed."; 1321 return RET_ERROR; 1322 } 1323- sub_out_channel_ = conv_param_->output_channel_; 1324- ori_out_channel_ = sub_out_channel_ * group_num_; 1325+ sub_out_channel_ = conv_param_->output_channel_ / group_num_; 1326+ ori_out_channel_ = conv_param_->output_channel_; 1327 out_thread_num_ = MSMIN(MSMAX(1, ctx_->thread_num_), out_plane_); 1328 return RET_OK; 1329 } 1330diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc 1331index f2a29bfd..fc78a887 100644 1332--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc 1333+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc 1334@@ -96,15 +96,10 @@ lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info, bool inferred) { 1335 tensor->set_data_type(tensor_info.data_type_); 1336 tensor->set_format(tensor_info.format_); 1337 tensor->set_category(tensor_info.tensor_type_); 1338- if (tensor_info.is_in_) { 1339- tensor->set_shape(tensor_info.shape_); 1340- } 1341+ tensor->set_shape(tensor_info.shape_); 1342 1343 if (inferred) { 1344 // set shape of out tensor 1345- if (!tensor_info.is_in_) { 1346- tensor->set_shape(tensor_info.shape_); 1347- } 1348 return TensorMalloc(tensor); 1349 } 1350 return tensor; 1351@@ -185,13 +180,16 @@ void GroupConvCreator::SetShapeOfTensors() { 1352 /* set shape */ 1353 set_filter_shape({new_out_channel, conv_param_->kernel_h_, conv_param_->kernel_w_, new_in_channel}); 1354 set_bias_shape({new_out_channel}); 1355+ conv_param_->input_channel_ = new_in_channel; 1356+ conv_param_->output_channel_ = new_out_channel; 1357 if (infered_) { 1358- conv_param_->input_channel_ = new_in_channel; 1359- conv_param_->output_channel_ = new_out_channel; 1360 set_input_shape({origin_inputs_.front()->Batch(), origin_inputs_.front()->Height(), origin_inputs_.front()->Width(), 1361 new_in_channel}); 1362 set_output_shape({origin_inputs_.front()->Batch(), origin_outputs_.front()->Height(), 1363 origin_outputs_.front()->Width(), new_out_channel}); 1364+ } else { 1365+ set_input_shape({-1}); 1366+ set_output_shape({-1}); 1367 } 1368 } 1369 1370diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc 1371index 58f953b8..89af7aae 100644 1372--- a/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc 1373+++ b/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc 1374@@ -105,6 +105,7 @@ REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ExpandDims, LiteKernelCreator<R 1375 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>) 1376 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>) 1377 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>) 1378+REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_ExpandDims, 
LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Squeeze, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Squeeze, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Squeeze, LiteKernelCreator<ReshapeBaseCPUKernel>)
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc b/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc
index 5db44a0a..ec2080ef 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc
@@ -56,6 +56,8 @@ void StridedSliceCPUKernel::InitFastRunParam() {
   for (size_t i = static_cast<size_t>(split_axis_ + 1); i < in_shape.size(); i++) {
     inner_ *= in_shape[i];
   }
+  parallel_on_split_axis_ = false;
+  parallel_on_outer_ = false;
   outer_ == 1 ? (parallel_on_split_axis_ = true) : (parallel_on_outer_ = true);
 
   if (UpdateThreadNumPass(TC_TYPE(PrimitiveType_StridedSlice, parallel_on_outer_), 1, 1,
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc b/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc
index 84b5a1a4..8ab4969f 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc
@@ -23,6 +23,7 @@
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
+using mindspore::lite::RET_NO_CHANGE;
 using mindspore::schema::PrimitiveType_FusedBatchNorm;
 
 namespace mindspore::kernel {
@@ -41,8 +42,11 @@ constexpr static int kOutCurrentVarIdx = 4;
 int FusedBatchnormFp16CPUKernel::Batchnorm2Scale(const void *scale_data, const void *bias_data, const void *mean_data,
                                                  const void *var_data, float eps, int kernel_num) {
   auto ret = InitScaleParam();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Init scale parameter when converting fused_batchnorm to scale.";
+  if (ret == RET_NO_CHANGE) {
+    MS_LOG(INFO) << "Converting fused batch norm to scale is not supported.";
+    return RET_NO_CHANGE;
+  } else if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init scale param failed.";
     return RET_ERROR;
   }
 
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc
index 7243e3b0..afed28ae 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc
@@ -19,6 +19,7 @@
 
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NO_CHANGE;
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_FusedBatchNorm;
 
@@ -65,7 +66,7 @@ int FusedBatchnormCPUKernel::InitScaleParam() {
 
   scale_param_->axis_ = kNHWC_C;
   auto in_shape = in_tensors_[0]->shape();
-  CHECK_LESS_RETURN(in_shape.size(), DIMENSION_5D);
+  MS_CHECK_TRUE_RET(in_shape.size() == DIMENSION_4D, RET_NO_CHANGE);
   scale_param_->outer_size_ = 1;
   for (auto i = 0; i < scale_param_->axis_; i++) {
     scale_param_->outer_size_ *= in_shape[i];
@@ -80,8 +81,11 @@ int FusedBatchnormCPUKernel::Batchnorm2Scale(const void
 int FusedBatchnormCPUKernel::Batchnorm2Scale(const void *scale_data, const void *bias_data, const void *mean_data,
                                              const void *var_data, float eps, int kernel_num) {
   auto ret = InitScaleParam();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Init scale parameter when converting fused_batchnorm to scale.";
+  if (ret == RET_NO_CHANGE) {
+    MS_LOG(INFO) << "Converting fused batch norm to scale is not supported.";
+    return RET_NO_CHANGE;
+  } else if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init scale param failed.";
     return RET_ERROR;
   }
 
@@ -131,6 +135,10 @@ int FusedBatchnormCPUKernel::InitConstTensor() {
       return RET_OK;
     } else {
       FreeScaleAndOffset();
+      if (ret != RET_NO_CHANGE) {
+        MS_LOG(ERROR) << "Converting batch norm to scale failed.";
+        return RET_ERROR;
+      }
     }
   }
 
@@ -188,7 +196,7 @@ int FusedBatchnormCPUKernel::Run() {
 
     trained_ = true;  // trained at least once
   } else {
-    if (out_tensors_.size() >= DIMENSION_5D) {
+    if (op_parameter_->is_train_session_ && out_tensors_.size() >= DIMENSION_5D) {
       (void)memcpy(out_tensors_.at(SECOND_INPUT)->data(), scale_, out_tensors_.at(SECOND_INPUT)->Size());
       (void)memcpy(out_tensors_.at(THIRD_INPUT)->data(), offset_, out_tensors_.at(THIRD_INPUT)->Size());
       (void)memcpy(out_tensors_.at(FOURTH_INPUT)->data(), mean_, out_tensors_.at(FOURTH_INPUT)->Size());
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc
new file mode 100644
index 00000000..627f19f0
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
1496+ */ 1497+ 1498+#include "src/runtime/kernel/cpu/fp32/oneslike_fp32.h" 1499+#include "schema/model_generated.h" 1500+#include "nnacl/base/zeroslike_base.h" 1501+#include "src/runtime/kernel_registry.h" 1502+#include "include/errorcode.h" 1503+ 1504+using mindspore::kernel::KERNEL_ARCH; 1505+using mindspore::lite::KernelRegistrar; 1506+using mindspore::lite::RET_ERROR; 1507+using mindspore::lite::RET_OK; 1508+using mindspore::schema::PrimitiveType_OnesLike; 1509+ 1510+namespace mindspore::kernel { 1511+int OnesLikeCPUKernel::Prepare() { 1512+ CHECK_LESS_RETURN(in_tensors_.size(), 1); 1513+ CHECK_LESS_RETURN(out_tensors_.size(), 1); 1514+ return RET_OK; 1515+} 1516+ 1517+int OnesLikeCPUKernel::Run() { 1518+ auto output = out_tensors_[0]; 1519+ CHECK_NULL_RETURN(output); 1520+ if (output->data_type() == kNumberTypeInt32) { 1521+ ApproximateOnesLike(static_cast<int *>(output->data()), output->ElementsNum()); 1522+ } else if (output->data_type() == kNumberTypeFloat32) { 1523+ ApproximateOnesLike(static_cast<float *>(output->data()), output->ElementsNum()); 1524+ } 1525+ return RET_OK; 1526+} 1527+ 1528+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>) 1529+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>) 1530+#ifdef ENABLE_FP16 1531+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>) 1532+#endif 1533+} // namespace mindspore::kernel 1534\ No newline at end of file 1535diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h 1536new file mode 100644 1537index 00000000..fdca97cb 1538--- /dev/null 1539+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h 1540@@ -0,0 +1,46 @@ 1541+/** 1542+ * Copyright 2022 Huawei Technologies Co., Ltd 1543+ * 1544+ * Licensed under the Apache License, Version 2.0 (the "License"); 1545+ * you may not use this file except in compliance with the License. 1546+ * You may obtain a copy of the License at 1547+ * 1548+ * http://www.apache.org/licenses/LICENSE-2.0 1549+ * 1550+ * Unless required by applicable law or agreed to in writing, software 1551+ * distributed under the License is distributed on an "AS IS" BASIS, 1552+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1553+ * See the License for the specific language governing permissions and 1554+ * limitations under the License. 
+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
+
+#include <vector>
+#include "src/runtime/lite_kernel.h"
+
+namespace mindspore::kernel {
+class OnesLikeCPUKernel : public LiteKernel {
+ public:
+  OnesLikeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                    const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+      : LiteKernel(parameter, inputs, outputs, ctx) {}
+
+  ~OnesLikeCPUKernel() = default;
+
+  int Prepare() override;
+  int ReSize() override { return lite::RET_OK; }
+  int Run() override;
+
+ private:
+  template <typename T>
+  void ApproximateOnesLike(T *output, int data_size) {
+    for (int i = 0; i < data_size; ++i) {
+      output[i] = 1;
+    }
+    return;
+  }
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
\ No newline at end of file
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc
new file mode 100644
index 00000000..a24976f8
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc
@@ -0,0 +1,120 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h"
+#include "src/runtime/kernel_registry.h"
+#include "include/errorcode.h"
+#include "nnacl/fp32_grad/binary_cross_entropy.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_BinaryCrossEntropy;
+
+namespace mindspore::kernel {
+BinaryCrossEntropyCPUKernel::~BinaryCrossEntropyCPUKernel() {
+  if (tmp_loss_ != nullptr) {
+    free(tmp_loss_);
+    tmp_loss_ = nullptr;
+  }
+}
+
+int BinaryCrossEntropyCPUKernel::ReSize() {
+  CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
+  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(in_tensors_.at(0));
+  CHECK_NULL_RETURN(in_tensors_.at(1));
+  if (in_tensors_.size() == C3NUM) {
+    weight_defined_ = true;
+    CHECK_NULL_RETURN(in_tensors_.at(C2NUM));
+  }
+  CHECK_NULL_RETURN(out_tensors_.at(0));
+  CHECK_NULL_RETURN(op_parameter_);
+
+  auto param_ = reinterpret_cast<BinaryCrossEntropyParameter *>(op_parameter_);
+  CHECK_NULL_RETURN(param_);
+  if (tmp_loss_ != nullptr) {
+    free(tmp_loss_);
+    tmp_loss_ = nullptr;
+  }
+  size_t input_size = in_tensors_.at(0)->ElementsNum();
+  tmp_loss_ = reinterpret_cast<float *>(malloc(input_size * sizeof(float)));
+  if (tmp_loss_ == nullptr) {
+    MS_LOG(ERROR) << "malloc tmp_loss_ for BinaryCrossEntropy op failed";
+    return RET_ERROR;
+  }
+
+  return RET_OK;
+}
+
+int BinaryCrossEntropyCPUKernel::DoExecute(int task_id) {
+  auto logits = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
+  CHECK_NULL_RETURN(logits);
+  auto labels = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
+  CHECK_NULL_RETURN(labels);
+  auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
+  CHECK_NULL_RETURN(out);
+
+  auto param_ = reinterpret_cast<BinaryCrossEntropyParameter *>(op_parameter_);
+  int reduction = param_->reduction;
+  size_t input_size = in_tensors_.at(0)->ElementsNum();
+  if (weight_defined_) {
+    weight_ = reinterpret_cast<float *>(in_tensors_.at(C2NUM)->MutableData());
+    CHECK_NULL_RETURN(weight_);
+  }
+  BinaryCrossEntropy(input_size, reduction, logits, labels, weight_, out, tmp_loss_, weight_defined_);
+  return RET_OK;
+}
+
+int BinaryCrossEntropyRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
+  CHECK_NULL_RETURN(cdata);
+  auto bin_crs_ent_kernel = reinterpret_cast<BinaryCrossEntropyCPUKernel *>(cdata);
+  auto error_code = bin_crs_ent_kernel->DoExecute(task_id);
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "BinaryCrossEntropy error task_id[" << task_id << "] error_code[" << error_code << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int BinaryCrossEntropyCPUKernel::Run() {
+  int error_code = ParallelLaunch(this->ms_context_, BinaryCrossEntropyRun, this, 1);
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "BinaryCrossEntropy function error error_code[" << error_code << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int BinaryCrossEntropyCPUKernel::Prepare() { return ReSize(); }
+
+kernel::LiteKernel *CpuBinaryCrossEntropyFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                                           const std::vector<lite::Tensor *> &outputs,
+                                                           OpParameter *opParameter, const lite::Context *ctx,
+                                                           const kernel::KernelKey &desc) {
+  MS_ASSERT(opParameter != nullptr);
+  MS_ASSERT(desc.type == schema::PrimitiveType_BinaryCrossEntropy);
+  auto *kernel = new (std::nothrow)
+      BinaryCrossEntropyCPUKernel(opParameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "new BinaryCrossEntropyCPUKernel failed";
+    return nullptr;
+  }
+  return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BinaryCrossEntropy, CpuBinaryCrossEntropyFp32KernelCreator)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h
new file mode 100644
index 00000000..39a7181e
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_H_
+
+#include <vector>
+#include "src/runtime/lite_kernel.h"
+
+namespace mindspore::kernel {
+class BinaryCrossEntropyCPUKernel : public LiteKernel {
+ public:
+  explicit BinaryCrossEntropyCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+      : LiteKernel(parameter, inputs, outputs, ctx) {}
+  ~BinaryCrossEntropyCPUKernel() override;
+  int Prepare() override;
+  int ReSize() override;
+  int Run() override;
+  int DoExecute(int task_id);
+
+ protected:
+  float *tmp_loss_ = nullptr;
+  bool weight_defined_{false};
+  float *weight_ = nullptr;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_H_
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc
new file mode 100644
index 00000000..abac8fd1
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
1773+ * You may obtain a copy of the License at 1774+ * 1775+ * http://www.apache.org/licenses/LICENSE-2.0 1776+ * 1777+ * Unless required by applicable law or agreed to in writing, software 1778+ * distributed under the License is distributed on an "AS IS" BASIS, 1779+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1780+ * See the License for the specific language governing permissions and 1781+ * limitations under the License. 1782+ */ 1783+ 1784+#include "src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h" 1785+#include "src/runtime/kernel_registry.h" 1786+#include "include/errorcode.h" 1787+#include "nnacl/fp32_grad/binary_cross_entropy_grad.h" 1788+ 1789+using mindspore::lite::KernelRegistrar; 1790+using mindspore::lite::RET_ERROR; 1791+using mindspore::lite::RET_OK; 1792+using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad; 1793+ 1794+namespace mindspore::kernel { 1795+int BinaryCrossEntropyGradCPUKernel::ReSize() { 1796+ CHECK_LESS_RETURN(in_tensors_.size(), C3NUM); 1797+ CHECK_LESS_RETURN(out_tensors_.size(), C1NUM); 1798+ CHECK_NULL_RETURN(in_tensors_.at(C0NUM)); 1799+ CHECK_NULL_RETURN(in_tensors_.at(C1NUM)); 1800+ CHECK_NULL_RETURN(in_tensors_.at(C2NUM)); 1801+ if (in_tensors_.size() == C4NUM) { 1802+ weight_defined_ = true; 1803+ CHECK_NULL_RETURN(in_tensors_.at(C3NUM)); 1804+ } 1805+ CHECK_NULL_RETURN(out_tensors_.at(0)); 1806+ CHECK_NULL_RETURN(op_parameter_); 1807+ auto param_ = reinterpret_cast<BinaryCrossEntropyGradParameter *>(op_parameter_); 1808+ CHECK_NULL_RETURN(param_); 1809+ 1810+ return RET_OK; 1811+} 1812+ 1813+int BinaryCrossEntropyGradCPUKernel::DoExecute(int task_id) { 1814+ auto input_x = reinterpret_cast<float *>(in_tensors_.at(C0NUM)->MutableData()); 1815+ CHECK_NULL_RETURN(input_x); 1816+ auto input_y = reinterpret_cast<float *>(in_tensors_.at(C1NUM)->MutableData()); 1817+ CHECK_NULL_RETURN(input_y); 1818+ auto dloss = reinterpret_cast<float *>(in_tensors_.at(C2NUM)->MutableData()); 1819+ CHECK_NULL_RETURN(dloss); 1820+ if (weight_defined_) { 1821+ weight_ = reinterpret_cast<float *>(in_tensors_.at(C3NUM)->MutableData()); 1822+ CHECK_NULL_RETURN(weight_); 1823+ } 1824+ auto *out = reinterpret_cast<float *>(out_tensors_.at(C0NUM)->MutableData()); 1825+ CHECK_NULL_RETURN(out); 1826+ 1827+ auto param_ = reinterpret_cast<BinaryCrossEntropyGradParameter *>(op_parameter_); 1828+ int reduction = param_->reduction; 1829+ size_t input_size = in_tensors_.at(0)->ElementsNum(); 1830+ BinaryCrossEntropyGrad(input_size, reduction, input_x, input_y, weight_, dloss, out, weight_defined_); 1831+ return RET_OK; 1832+} 1833+ 1834+int BinaryCrossEntropyGradRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) { 1835+ CHECK_NULL_RETURN(cdata); 1836+ auto bin_crs_ent_kernel = reinterpret_cast<BinaryCrossEntropyGradCPUKernel *>(cdata); 1837+ auto error_code = bin_crs_ent_kernel->DoExecute(task_id); 1838+ if (error_code != RET_OK) { 1839+ MS_LOG(ERROR) << "BinaryCrossEntropyGrad error task_id[" << task_id << "] error_code[" << error_code << "]"; 1840+ return RET_ERROR; 1841+ } 1842+ return RET_OK; 1843+} 1844+ 1845+int BinaryCrossEntropyGradCPUKernel::Run() { 1846+ int error_code = ParallelLaunch(this->ms_context_, BinaryCrossEntropyGradRun, this, 1); 1847+ if (error_code != RET_OK) { 1848+ MS_LOG(ERROR) << "BinaryCrossEntropyGrad function error error_code[" << error_code << "]"; 1849+ return RET_ERROR; 1850+ } 1851+ return RET_OK; 1852+} 1853+ 1854+int BinaryCrossEntropyGradCPUKernel::Prepare() { return ReSize(); } 1855+ 
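+// Registry creator for the grad kernel; the static_cast to InnerContext below
+// mirrors the BinaryCrossEntropy creator in binary_cross_entropy.cc.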
1856+kernel::LiteKernel *CpuBinaryCrossEntropyGradFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, 1857+ const std::vector<lite::Tensor *> &outputs, 1858+ OpParameter *opParameter, const lite::Context *ctx, 1859+ const kernel::KernelKey &desc) { 1860+ MS_ASSERT(opParameter != nullptr); 1861+ MS_ASSERT(desc.type == schema::PrimitiveType_BinaryCrossEntropyGrad); 1862+ auto *kernel = new (std::nothrow) 1863+ BinaryCrossEntropyGradCPUKernel(opParameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx)); 1864+ if (kernel == nullptr) { 1865+ MS_LOG(ERROR) << "new BinaryCrossEntropyGrad failed"; 1866+ return nullptr; 1867+ } 1868+ return kernel; 1869+} 1870+ 1871+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BinaryCrossEntropyGrad, CpuBinaryCrossEntropyGradFp32KernelCreator) 1872+} // namespace mindspore::kernel 1873diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h 1874new file mode 100644 1875index 00000000..d289eb65 1876--- /dev/null 1877+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h 1878@@ -0,0 +1,41 @@ 1879+/** 1880+ * Copyright 2022 Huawei Technologies Co., Ltd 1881+ * 1882+ * Licensed under the Apache License, Version 2.0 (the "License"); 1883+ * you may not use this file except in compliance with the License. 1884+ * You may obtain a copy of the License at 1885+ * 1886+ * http://www.apache.org/licenses/LICENSE-2.0 1887+ * 1888+ * Unless required by applicable law or agreed to in writing, software 1889+ * distributed under the License is distributed on an "AS IS" BASIS, 1890+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1891+ * See the License for the specific language governing permissions and 1892+ * limitations under the License. 
1893+ */ 1894+ 1895+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_GRAD_H_ 1896+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_GRAD_H_ 1897+ 1898+#include <vector> 1899+#include "src/runtime/lite_kernel.h" 1900+ 1901+namespace mindspore::kernel { 1902+class BinaryCrossEntropyGradCPUKernel : public LiteKernel { 1903+ public: 1904+ explicit BinaryCrossEntropyGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 1905+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 1906+ : LiteKernel(parameter, inputs, outputs, ctx) {} 1907+ ~BinaryCrossEntropyGradCPUKernel() override {} 1908+ int Prepare() override; 1909+ int ReSize() override; 1910+ int Run() override; 1911+ int DoExecute(int task_id); 1912+ 1913+ protected: 1914+ bool weight_defined_{false}; 1915+ float *weight_ = nullptr; 1916+}; 1917+} // namespace mindspore::kernel 1918+ 1919+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_GRAD_H_ 1920diff --git a/mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt 1921new file mode 100644 1922index 00000000..3b42ed7a 1923--- /dev/null 1924+++ b/mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt 1925@@ -0,0 +1,11 @@ 1926+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") 1927+string(REPLACE "-fvisibility=hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") 1928+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 1929+string(REPLACE "-fvisibility=hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 1930+if(MSLITE_GPU_BACKEND STREQUAL opencl) 1931+ file(GLOB_RECURSE OPENCL_KERNEL_SRC 1932+ ${CMAKE_CURRENT_SOURCE_DIR}/*.cc 1933+ ${CMAKE_CURRENT_SOURCE_DIR}/../../opencl/*.cc) 1934+ add_library(opencl_kernel_mid OBJECT ${OPENCL_KERNEL_SRC}) 1935+ add_dependencies(opencl_kernel_mid fbs_src) 1936+endif() 1937diff --git a/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt 1938deleted file mode 100644 1939index cad0f8f7..00000000 1940--- a/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt 1941+++ /dev/null 1942@@ -1,8 +0,0 @@ 1943-if(MSLITE_GPU_BACKEND STREQUAL opencl) 1944- file(GLOB_RECURSE OPENCL_KERNEL_SRC 1945- ${CMAKE_CURRENT_SOURCE_DIR}/*.cc 1946- ${CMAKE_CURRENT_SOURCE_DIR}/kernel/*.cc 1947- ${CMAKE_CURRENT_SOURCE_DIR}/kernel/int8/*.cc) 1948- add_library(opencl_kernel_mid OBJECT ${OPENCL_KERNEL_SRC}) 1949- add_dependencies(opencl_kernel_mid fbs_src) 1950-endif() 1951diff --git a/mindspore/lite/src/runtime/kernel_exec_util.h b/mindspore/lite/src/runtime/kernel_exec_util.h 1952index e45c185b..9ce5267e 100644 1953--- a/mindspore/lite/src/runtime/kernel_exec_util.h 1954+++ b/mindspore/lite/src/runtime/kernel_exec_util.h 1955@@ -24,7 +24,7 @@ 1956 1957 namespace mindspore::kernel { 1958 1959-class KernelExecUtil { 1960+class MS_API KernelExecUtil { 1961 public: 1962 static std::vector<KernelExec *> SubgraphInputNodes(const std::vector<KernelExec *> &kernels); 1963 static std::vector<KernelExec *> SubgraphOutputNodes(const std::vector<KernelExec *> &kernels); 1964diff --git a/mindspore/lite/src/runtime/kernel_registry.h b/mindspore/lite/src/runtime/kernel_registry.h 1965index 853d863a..f563d82d 100644 1966--- a/mindspore/lite/src/runtime/kernel_registry.h 1967+++ b/mindspore/lite/src/runtime/kernel_registry.h 1968@@ -1,5 +1,5 @@ 1969 /** 1970- * Copyright 2020 Huawei 
Technologies Co., Ltd 1971+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 1972 * 1973 * Licensed under the Apache License, Version 2.0 (the "License"); 1974 * you may not use this file except in compliance with the License. 1975@@ -31,7 +31,7 @@ using mindspore::schema::PrimitiveType_MAX; 1976 using mindspore::schema::PrimitiveType_MIN; 1977 1978 namespace mindspore::lite { 1979-class KernelRegistry { 1980+class MS_API KernelRegistry { 1981 public: 1982 KernelRegistry() = default; 1983 virtual ~KernelRegistry(); 1984diff --git a/mindspore/lite/src/runtime/lite_kernel.h b/mindspore/lite/src/runtime/lite_kernel.h 1985index ce829320..a27f77d8 100644 1986--- a/mindspore/lite/src/runtime/lite_kernel.h 1987+++ b/mindspore/lite/src/runtime/lite_kernel.h 1988@@ -1,5 +1,5 @@ 1989 /** 1990- * Copyright 2021-2022 Huawei Technologies Co., Ltd 1991+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 1992 * 1993 * Licensed under the Apache License, Version 2.0 (the "License"); 1994 * you may not use this file except in compliance with the License. 1995@@ -37,7 +37,7 @@ 1996 using mindspore::infer::Abstractkernel; 1997 1998 namespace mindspore::kernel { 1999-class LiteKernel : public Abstractkernel { 2000+class MS_API LiteKernel : public Abstractkernel { 2001 public: 2002 LiteKernel() = default; 2003 2004diff --git a/mindspore/lite/src/runtime/lite_model.h b/mindspore/lite/src/runtime/lite_model.h 2005index f6c7ebc4..af62cb91 100644 2006--- a/mindspore/lite/src/runtime/lite_model.h 2007+++ b/mindspore/lite/src/runtime/lite_model.h 2008@@ -1,5 +1,5 @@ 2009 /** 2010- * Copyright 2020 Huawei Technologies Co., Ltd 2011+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 2012 * 2013 * Licensed under the Apache License, Version 2.0 (the "License"); 2014 * you may not use this file except in compliance with the License. 
@@ -36,7 +36,7 @@
 
 namespace mindspore {
 namespace lite {
-class LiteModel : public Model {
+class MS_API LiteModel : public Model {
  public:
   explicit LiteModel(std::string model_path = "") : model_path_(std::move(model_path)) {}
 
diff --git a/mindspore/lite/src/runtime/lite_session.cc b/mindspore/lite/src/runtime/lite_session.cc
index b8808e21..dffb39e7 100644
--- a/mindspore/lite/src/runtime/lite_session.cc
+++ b/mindspore/lite/src/runtime/lite_session.cc
@@ -504,7 +504,7 @@ void LiteSession::FreePackOpWeight(const std::vector<kernel::KernelExec *> &kern
     auto inputs = kernel->in_tensors();
     for (auto *tensor : inputs) {
       MS_ASSERT(tensor != nullptr);
-      if (!tensor->IsConst()) {
+      if (!tensor->IsConst() || tensor->ref_count() >= 1) {
         continue;
       }
       tensor->FreeData();
@@ -512,6 +512,29 @@ void LiteSession::FreePackOpWeight(const std::vector<kernel::KernelExec *> &kern
   }
 }
 
+void LiteSession::MarkSharedWeight(const std::vector<kernel::KernelExec *> &kernels) {
+  // Bump the ref count of const tensors that non-packed ops consume, so that
+  // FreePackOpWeight only releases weights used exclusively by pack ops.
+  for (auto *kernel : kernels) {
+    MS_ASSERT(kernel != nullptr);
+    if (kernel->subgraph_type() == kernel::kNotSubGraph) {
+      if (IsPackedOp(static_cast<int>(kernel->type()))) {
+        continue;
+      }
+    } else {
+      auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
+      MarkSharedWeight(subgraph->nodes());
+    }
+    auto inputs = kernel->in_tensors();
+    for (auto *tensor : inputs) {
+      MS_ASSERT(tensor != nullptr);
+      if (tensor->IsConst()) {
+        tensor->IncRefCount();
+      }
+    }
+  }
+}
+
 int LiteSession::CompileGraph(Model *model) {
   auto ret = PreCheck(model);
   if (ret != RET_OK) {
@@ -572,7 +595,7 @@ int LiteSession::CompileGraph(Model *model) {
     is_running_.store(false);
     return ret;
   }
-
+  MarkSharedWeight(kernels_);
   FreePackOpWeight(kernels_);
 
   ret = RuntimeAllocatorInit();
@@ -1727,6 +1750,7 @@ int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore
     delete model;
     return RET_ERROR;
   }
+  model->Free();
   set_model(model);
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/lite_session.h b/mindspore/lite/src/runtime/lite_session.h
index 255e90b5..2fdb1eb7 100644
--- a/mindspore/lite/src/runtime/lite_session.h
+++ b/mindspore/lite/src/runtime/lite_session.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
2095@@ -39,7 +39,7 @@ 2096 2097 namespace mindspore { 2098 namespace lite { 2099-class LiteSession { 2100+class MS_API LiteSession { 2101 public: 2102 LiteSession(); 2103 virtual ~LiteSession(); 2104@@ -101,6 +101,11 @@ class LiteSession { 2105 std::vector<std::string> out_put_tensor_name = {}) { 2106 return mindspore::lite::RET_ERROR; 2107 } 2108+ virtual int Export(Buffer *model_buffer, lite::ModelType model_type = lite::MT_TRAIN, 2109+ lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS, 2110+ std::vector<std::string> out_put_tensor_name = {}) { 2111+ return mindspore::lite::RET_ERROR; 2112+ } 2113 virtual int UpdateWeights(std::vector<lite::Tensor *> new_weights) { return mindspore::lite::RET_ERROR; } 2114 virtual std::vector<lite::Tensor *> GetFeatureMaps() const { 2115 std::vector<lite::Tensor *> features; 2116@@ -142,6 +147,7 @@ class LiteSession { 2117 const std::unordered_map<Tensor *, Tensor *> &isolate_input_map = std::unordered_map<Tensor *, Tensor *>()); 2118 static void FreePackOpWeight(const std::vector<kernel::KernelExec *> &kernels); 2119 std::string ParseWeightPath(); 2120+ static void MarkSharedWeight(const std::vector<kernel::KernelExec *> &kernels); 2121 2122 private: 2123 int PreCheck(Model *model); 2124diff --git a/mindspore/lite/src/runtime/weight_decoder.h b/mindspore/lite/src/runtime/weight_decoder.h 2125index 7c9e514c..006b4895 100644 2126--- a/mindspore/lite/src/runtime/weight_decoder.h 2127+++ b/mindspore/lite/src/runtime/weight_decoder.h 2128@@ -1,5 +1,5 @@ 2129 /** 2130- * Copyright 2020-2022 Huawei Technologies Co., Ltd 2131+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 2132 * 2133 * Licensed under the Apache License, Version 2.0 (the "License"); 2134 * you may not use this file except in compliance with the License. 2135@@ -39,7 +39,7 @@ static constexpr int kBitNum32 = 32; 2136 2137 namespace mindspore::lite { 2138 2139-class WeightDecoder { 2140+class MS_API WeightDecoder { 2141 public: 2142 static int DequantNode(const OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type, 2143 const std::string &model_version, bool float_mode); 2144diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h 2145index f30fe090..178d4754 100644 2146--- a/mindspore/lite/src/tensor.h 2147+++ b/mindspore/lite/src/tensor.h 2148@@ -53,7 +53,7 @@ struct LiteQuantParam { 2149 double max{255.0}; 2150 }; 2151 2152-class Tensor { 2153+class MS_API Tensor { 2154 public: 2155 Tensor() = default; 2156 2157diff --git a/mindspore/lite/src/tensorlist.h b/mindspore/lite/src/tensorlist.h 2158index 2e6f8d79..39058057 100644 2159--- a/mindspore/lite/src/tensorlist.h 2160+++ b/mindspore/lite/src/tensorlist.h 2161@@ -55,7 +55,7 @@ namespace mindspore::lite { 2162 * 2163 * See the code for other constructors. 
2164 */ 2165-class TensorList : public Tensor { 2166+class MS_API TensorList : public Tensor { 2167 public: 2168 TensorList() = default; 2169 2170diff --git a/mindspore/lite/src/train/graph_fusion.cc b/mindspore/lite/src/train/graph_fusion.cc 2171index 0e582e40..1af44e45 100644 2172--- a/mindspore/lite/src/train/graph_fusion.cc 2173+++ b/mindspore/lite/src/train/graph_fusion.cc 2174@@ -22,6 +22,8 @@ 2175 #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" 2176 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" 2177 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" 2178+#include "src/train/optimizer/fusion/matmul_activation_fusion_pass.h" 2179+#include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h" 2180 2181 namespace mindspore { 2182 namespace lite { 2183@@ -41,7 +43,9 @@ STATUS GraphFusion::Run(schema::MetaGraphT *graph) { 2184 } 2185 auto old_nodes = GetGraphNodes(*graph); 2186 Optimizer fusion_optimizer; 2187+ fusion_optimizer.AddPass(new (std::nothrow) ReshapeGatherReshapeFusionPass()); 2188 fusion_optimizer.AddPass(new (std::nothrow) MatMulBiasAddFusionPass()); 2189+ fusion_optimizer.AddPass(new (std::nothrow) MatMulActivationFusionPass()); 2190 fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); 2191 fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); 2192 auto status = fusion_optimizer.Run(graph); 2193diff --git a/mindspore/lite/src/train/optimizer/common/fusion_utils.cc b/mindspore/lite/src/train/optimizer/common/fusion_utils.cc 2194new file mode 100644 2195index 00000000..3edb1a1b 2196--- /dev/null 2197+++ b/mindspore/lite/src/train/optimizer/common/fusion_utils.cc 2198@@ -0,0 +1,37 @@ 2199+/** 2200+ * Copyright 2022 Huawei Technologies Co., Ltd 2201+ * 2202+ * Licensed under the Apache License, Version 2.0 (the "License"); 2203+ * you may not use this file except in compliance with the License. 2204+ * You may obtain a copy of the License at 2205+ * 2206+ * http://www.apache.org/licenses/LICENSE-2.0 2207+ * 2208+ * Unless required by applicable law or agreed to in writing, software 2209+ * distributed under the License is distributed on an "AS IS" BASIS, 2210+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2211+ * See the License for the specific language governing permissions and 2212+ * limitations under the License. 
2213+ */ 2214+#include <vector> 2215+#include <unordered_map> 2216+#include <string> 2217+#include <memory> 2218+#include "src/common/log_util.h" 2219+#include "src/train/optimizer/common/fusion_utils.h" 2220+ 2221+namespace mindspore { 2222+namespace opt { 2223+STATUS GetMatchNodeIndex(schema::MetaGraphT *graph, 2224+ const std::unordered_map<std::string, std::shared_ptr<lite::Path>> &matched_path, 2225+ const std::string &node_name, size_t *node_index) { 2226+ auto node_path_iter = matched_path.find(node_name); 2227+ MS_CHECK_TRUE_MSG(node_path_iter != matched_path.end(), RET_ERROR, "cannot find node_path"); 2228+ const auto &node_path = node_path_iter->second; 2229+ MS_CHECK_TRUE_MSG(node_path != nullptr, RET_NULL_PTR, "node_path is empty"); 2230+ *node_index = node_path->nodeIdx; 2231+ MS_CHECK_TRUE_MSG(*node_index < graph->nodes.size(), RET_ERROR, "node_index is out of range"); 2232+ return RET_OK; 2233+} 2234+} // namespace opt 2235+} // namespace mindspore 2236diff --git a/mindspore/lite/src/train/optimizer/common/fusion_utils.h b/mindspore/lite/src/train/optimizer/common/fusion_utils.h 2237new file mode 100644 2238index 00000000..7f80cd49 2239--- /dev/null 2240+++ b/mindspore/lite/src/train/optimizer/common/fusion_utils.h 2241@@ -0,0 +1,50 @@ 2242+/** 2243+ * Copyright 2022 Huawei Technologies Co., Ltd 2244+ * 2245+ * Licensed under the Apache License, Version 2.0 (the "License"); 2246+ * you may not use this file except in compliance with the License. 2247+ * You may obtain a copy of the License at 2248+ * 2249+ * http://www.apache.org/licenses/LICENSE-2.0 2250+ * 2251+ * Unless required by applicable law or agreed to in writing, software 2252+ * distributed under the License is distributed on an "AS IS" BASIS, 2253+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2254+ * See the License for the specific language governing permissions and 2255+ * limitations under the License. 
2256+ */ 2257+ 2258+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_COMMON_FUSION_UTILS_H_ 2259+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_COMMON_FUSION_UTILS_H_ 2260+ 2261+#include <vector> 2262+#include <unordered_map> 2263+#include <string> 2264+#include <memory> 2265+#include "src/common/utils.h" 2266+#include "schema/inner/model_generated.h" 2267+#include "tools/converter/legacy_optimizer/fusion/fusion_pattern.h" 2268+ 2269+using mindspore::lite::RET_ERROR; 2270+using mindspore::lite::RET_NULL_PTR; 2271+using mindspore::lite::RET_OK; 2272+using mindspore::lite::STATUS; 2273+namespace mindspore { 2274+namespace opt { 2275+inline constexpr int kInputIndexZero = 0; 2276+inline constexpr int kInputIndexOne = 1; 2277+inline constexpr int kInputIndexTwo = 2; 2278+inline constexpr int kOutputIndexZero = 0; 2279+inline constexpr int kOutputIndexOne = 1; 2280+inline constexpr size_t kInputSizeTwo = 2; 2281+inline constexpr size_t kInputSizeThree = 3; 2282+inline constexpr size_t kOutputSizeOne = 1; 2283+inline constexpr size_t kMatchPathLenTwo = 2; 2284+inline constexpr size_t kMatchPathLenThree = 3; 2285+ 2286+STATUS GetMatchNodeIndex(schema::MetaGraphT *graph, 2287+ const std::unordered_map<std::string, std::shared_ptr<lite::Path>> &matched_path, 2288+ const std::string &node_name, size_t *node_index); 2289+} // namespace opt 2290+} // namespace mindspore 2291+#endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_COMMON_FUSION_UTILS_H_ 2292diff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc 2293new file mode 100644 2294index 00000000..b809f2c9 2295--- /dev/null 2296+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc 2297@@ -0,0 +1,93 @@ 2298+/** 2299+ * Copyright 2022 Huawei Technologies Co., Ltd 2300+ * 2301+ * Licensed under the Apache License, Version 2.0 (the "License"); 2302+ * you may not use this file except in compliance with the License. 2303+ * You may obtain a copy of the License at 2304+ * 2305+ * http://www.apache.org/licenses/LICENSE-2.0 2306+ * 2307+ * Unless required by applicable law or agreed to in writing, software 2308+ * distributed under the License is distributed on an "AS IS" BASIS, 2309+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2310+ * See the License for the specific language governing permissions and 2311+ * limitations under the License. 
+ */
+
+#include "src/train/optimizer/fusion/matmul_activation_fusion_pass.h"
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include <memory>
+#include "schema/inner/model_generated.h"
+#include "tools/common/meta_graph_utils.h"
+#include "src/train/optimizer/common/fusion_utils.h"
+namespace {
+constexpr std::string_view MatMulName = "MATMUL";
+constexpr std::string_view ActName = "ACTIVATION";
+}  // namespace
+namespace mindspore {
+namespace lite {
+STATUS MatMulActivationFusionPass::DefinePattern() {
+  auto matmul_op = std::make_shared<PatternOp>();
+  MS_CHECK_TRUE_RET(matmul_op != nullptr, RET_NULL_PTR);
+  matmul_op->id = MatMulName;
+  matmul_op->types = {schema::PrimitiveType_MatMulFusion};
+  auto act_op = std::make_shared<PatternOp>();
+  MS_CHECK_TRUE_RET(act_op != nullptr, RET_NULL_PTR);
+  act_op->id = ActName;
+  act_op->types = {schema::PrimitiveType_Activation};
+  act_op->left = matmul_op;
+  auto fusion_pattern = std::make_unique<FusionPattern>("MatMulActivationFusion");
+  MS_CHECK_TRUE_MSG(fusion_pattern != nullptr, RET_NULL_PTR, "new fusion_pattern failed");
+  fusion_pattern->AddPatternOp(matmul_op);
+  fusion_pattern->AddPatternOp(act_op);
+  fusion_pattern->Finish();
+  this->patterns.emplace_back(fusion_pattern.release());
+  return RET_OK;
+}
+
+STATUS MatMulActivationFusionPass::DoFusion(
+    MetaGraphT *graph, const std::string &pattern_name,
+    const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) {
+  MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR);
+  if (matched_path.size() != opt::kMatchPathLenTwo) {
+    MS_LOG(ERROR) << "MatMul-Activation-Fusion should have two NodeIndex in matchedPair";
+    return RET_PARAM_INVALID;
+  }
+
+  size_t matmul_index = 0;
+  auto ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(MatMulName), &matmul_index);
+  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get matmul_index");
+  auto &matmul_node = graph->nodes.at(matmul_index);
+  MS_CHECK_TRUE_MSG(matmul_node != nullptr, RET_NULL_PTR, "matmul_node is nullptr");
+  size_t act_index = 0;
+  ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(ActName), &act_index);
+  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get act_index");
+  auto &act_node = graph->nodes.at(act_index);
+  MS_CHECK_TRUE_MSG(act_node != nullptr, RET_NULL_PTR, "act_node is nullptr");
+
+  if (matmul_node->quantType == schema::QuantType_QUANT_ALL ||
+      matmul_node->quantType == schema::QuantType_QUANT_DYNAMIC) {
+    MS_LOG(DEBUG) << "quantized MatMul cannot be fused.";
+    return RET_NO_CHANGE;
+  }
+  MS_CHECK_TRUE_RET(matmul_node->primitive != nullptr, RET_NULL_PTR);
+  auto matmul_type = matmul_node->primitive->value.AsMatMulFusion();
+  MS_CHECK_TRUE_RET(matmul_type != nullptr && matmul_type->activation_type == ActivationType::ActivationType_NO_ACTIVATION, RET_NO_CHANGE);
+  MS_CHECK_TRUE_RET(act_node->primitive != nullptr && act_node->primitive->value.AsActivation() != nullptr, RET_NULL_PTR);
+  auto act_type = act_node->primitive->value.AsActivation()->activation_type;
+  MS_CHECK_TRUE_RET(act_type == ActivationType::ActivationType_RELU || act_type == ActivationType::ActivationType_RELU6,
+                    RET_NO_CHANGE);
+  matmul_type->activation_type = act_type;
+  matmul_node->outputIndex = act_node->outputIndex;
+  // Do not delete the fused node here; doing so would invalidate node indices cached by other patterns.
+  // Instead, detach it so that IsolatedNodeRemovePass removes it later.
+  act_node->inputIndex.clear();
+  act_node->outputIndex.clear();
+  return RET_OK;
+}
+
+MatMulActivationFusionPass::~MatMulActivationFusionPass() = default;
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h
new file mode 100644
index 00000000..57891eb3
--- /dev/null
+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_
+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_
+
+#include <string>
+#include <unordered_map>
+#include <memory>
+#include <algorithm>
+#include <utility>
+#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
+
+namespace mindspore {
+namespace lite {
+class MatMulActivationFusionPass : public FusionPass {
+ public:
+  MatMulActivationFusionPass() = default;
+
+  ~MatMulActivationFusionPass() override;
+
+  STATUS DefinePattern() override;
+
+  STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name,
+                  const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override;
+};
+}  // namespace lite
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_
diff --git a/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
new file mode 100644
index 00000000..7fb8d1f4
--- /dev/null
+++ b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
2459+ */ 2460+ 2461+#include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h" 2462+#include <string> 2463+#include <unordered_map> 2464+#include <vector> 2465+#include <memory> 2466+#include "schema/inner/model_generated.h" 2467+#include "tools/common/meta_graph_utils.h" 2468+#include "src/train/optimizer/common/fusion_utils.h" 2469+ 2470+namespace { 2471+constexpr std::string_view Reshape1Name = "RESHAPE1"; 2472+constexpr std::string_view Reshape2Name = "RESHAPE2"; 2473+constexpr std::string_view GatherName = "GATHER"; 2474+} // namespace 2475+namespace mindspore { 2476+namespace lite { 2477+/* 2478+ * The subgraph such as the following. 2479+ * any 2480+ * / | 2481+ * reshape | 2482+ * \ | 2483+ * gather 2484+ * / | 2485+ * reshape | 2486+ * \ | 2487+ * any 2488+ */ 2489+STATUS ReshapeGatherReshapeFusionPass::DefinePattern() { 2490+ auto reshape_op1 = std::make_shared<PatternOp>(); 2491+ MS_CHECK_TRUE_RET(reshape_op1 != nullptr, RET_NULL_PTR); 2492+ reshape_op1->id = Reshape1Name; 2493+ reshape_op1->types = {schema::PrimitiveType_Reshape}; 2494+ 2495+ auto gather_op = std::make_shared<PatternOp>(); 2496+ MS_CHECK_TRUE_RET(gather_op != nullptr, RET_NULL_PTR); 2497+ gather_op->id = GatherName; 2498+ gather_op->types = {schema::PrimitiveType_Gather}; 2499+ gather_op->left = reshape_op1; 2500+ 2501+ auto reshape_op2 = std::make_shared<PatternOp>(); 2502+ MS_CHECK_TRUE_RET(reshape_op2 != nullptr, RET_NULL_PTR); 2503+ reshape_op2->id = Reshape2Name; 2504+ reshape_op2->types = {schema::PrimitiveType_Reshape}; 2505+ reshape_op2->left = gather_op; 2506+ 2507+ auto fusion_pattern = std::make_unique<FusionPattern>("ReshapeGatherReshapeFusion"); 2508+ MS_CHECK_TRUE_MSG(fusion_pattern != nullptr, RET_NULL_PTR, "new fusion_pattern failed"); 2509+ fusion_pattern->AddPatternOp(reshape_op1); 2510+ fusion_pattern->AddPatternOp(gather_op); 2511+ fusion_pattern->AddPatternOp(reshape_op2); 2512+ fusion_pattern->Finish(); 2513+ this->patterns.emplace_back(fusion_pattern.release()); 2514+ return RET_OK; 2515+} 2516+ 2517+STATUS ReshapeGatherReshapeFusionPass::DoFusion( 2518+ MetaGraphT *graph, const std::string &pattern_name, 2519+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) { 2520+ MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR); 2521+ if (matched_path.size() != opt::kMatchPathLenThree) { 2522+ MS_LOG(ERROR) << "Reshape-Gather-Reshape-Fusion should have three NodeIndex in matchedPair"; 2523+ return RET_PARAM_INVALID; 2524+ } 2525+ 2526+ size_t reshape1_index = 0; 2527+ STATUS ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(Reshape1Name), &reshape1_index); 2528+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get reshape1_index"); 2529+ auto &reshape1_node = graph->nodes.at(reshape1_index); 2530+ MS_CHECK_TRUE_MSG(reshape1_node != nullptr, RET_NULL_PTR, "reshape1_node is nullptr"); 2531+ size_t gather_index = 0; 2532+ ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(GatherName), &gather_index); 2533+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get gather_index"); 2534+ auto &gather_node = graph->nodes.at(gather_index); 2535+ MS_CHECK_TRUE_MSG(gather_node != nullptr, RET_NULL_PTR, "gather_node is nullptr"); 2536+ size_t reshape2_index = 0; 2537+ ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(Reshape2Name), &reshape2_index); 2538+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get reshape2_index"); 2539+ auto &reshape2_node = graph->nodes.at(reshape2_index); 2540+ MS_CHECK_TRUE_MSG(reshape2_node != nullptr, RET_NULL_PTR, 
"reshape2_node is nullptr"); 2541+ 2542+ if (reshape1_node->inputIndex.size() != opt::kInputSizeTwo || 2543+ reshape1_node->outputIndex.size() != opt::kOutputSizeOne || 2544+ reshape1_node->quantType == schema::QuantType_QUANT_ALL || 2545+ reshape1_node->quantType == schema::QuantType_QUANT_DYNAMIC || 2546+ reshape2_node->inputIndex.size() != opt::kInputSizeTwo || 2547+ reshape2_node->outputIndex.size() != opt::kOutputSizeOne || 2548+ reshape2_node->quantType == schema::QuantType_QUANT_ALL || 2549+ reshape2_node->quantType == schema::QuantType_QUANT_DYNAMIC || 2550+ gather_node->quantType == schema::QuantType_QUANT_ALL || 2551+ gather_node->quantType == schema::QuantType_QUANT_DYNAMIC) { 2552+ MS_LOG(ERROR) << "reshape_node cannot fusion"; 2553+ return RET_NO_CHANGE; 2554+ } 2555+ 2556+ auto old_shape = graph->allTensors.at(reshape2_node->outputIndex.at(opt::kOutputIndexZero))->dims; 2557+ auto gather_shape0 = graph->allTensors.at(gather_node->inputIndex.at(opt::kInputIndexZero))->dims; 2558+ auto gather_shape1 = graph->allTensors.at(reshape1_node->inputIndex.at(opt::kInputIndexZero))->dims; 2559+ if (old_shape.empty() || gather_shape0.empty() || gather_shape1.empty()) { 2560+ return RET_NO_CHANGE; 2561+ } 2562+ int gather_axis; 2563+ auto data = graph->allTensors.at(gather_node->inputIndex.at(opt::kInputIndexTwo))->data; 2564+ if (data.empty()) { 2565+ gather_axis = 0; 2566+ } else { 2567+ memcpy(&gather_axis, &data[0], data.size()); 2568+ } 2569+ if (gather_axis < 0) { 2570+ gather_axis += gather_shape1.size(); 2571+ } 2572+ gather_shape0.erase(gather_shape0.begin() + gather_axis); 2573+ (void)gather_shape0.insert(gather_shape0.begin() + gather_axis, gather_shape1.begin(), gather_shape1.end()); 2574+ if (gather_shape0 != old_shape) { 2575+ return RET_NO_CHANGE; 2576+ } 2577+ 2578+ gather_node->inputIndex.at(opt::kInputIndexOne) = reshape1_node->inputIndex.at(opt::kInputIndexZero); 2579+ gather_node->outputIndex.at(opt::kOutputIndexZero) = reshape2_node->outputIndex.at(opt::kOutputIndexZero); 2580+ 2581+ // cannot delete node here, otherwise will destroy order in other pattern's node index 2582+ // make it an isolated node to be removed in IsolatedNodeRemovePass 2583+ reshape1_node->inputIndex.clear(); 2584+ reshape1_node->outputIndex.clear(); 2585+ reshape2_node->inputIndex.clear(); 2586+ reshape2_node->outputIndex.clear(); 2587+ return RET_OK; 2588+} 2589+ 2590+ReshapeGatherReshapeFusionPass::~ReshapeGatherReshapeFusionPass() = default; 2591+} // namespace lite 2592+} // namespace mindspore 2593diff --git a/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h 2594new file mode 100644 2595index 00000000..ef184a3c 2596--- /dev/null 2597+++ b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h 2598@@ -0,0 +1,42 @@ 2599+/** 2600+ * Copyright 2022 Huawei Technologies Co., Ltd 2601+ * 2602+ * Licensed under the Apache License, Version 2.0 (the "License"); 2603+ * you may not use this file except in compliance with the License. 2604+ * You may obtain a copy of the License at 2605+ * 2606+ * http://www.apache.org/licenses/LICENSE-2.0 2607+ * 2608+ * Unless required by applicable law or agreed to in writing, software 2609+ * distributed under the License is distributed on an "AS IS" BASIS, 2610+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
2611+ * See the License for the specific language governing permissions and
2612+ * limitations under the License.
2613+ */
2614+
2615+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_RESHAPE_GATHER_RESHAPE_FUSION_PASS_H_
2616+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_RESHAPE_GATHER_RESHAPE_FUSION_PASS_H_
2617+
2618+#include <string>
2619+#include <unordered_map>
2620+#include <memory>
2621+#include <algorithm>
2622+#include <utility>
2623+#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
2624+
2625+namespace mindspore {
2626+namespace lite {
2627+class ReshapeGatherReshapeFusionPass : public FusionPass {
2628+ public:
2629+ ReshapeGatherReshapeFusionPass() = default;
2630+
2631+ ~ReshapeGatherReshapeFusionPass() override;
2632+
2633+ STATUS DefinePattern() override;
2634+
2635+ STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name,
2636+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override;
2637+};
2638+} // namespace lite
2639+} // namespace mindspore
2640+#endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_RESHAPE_GATHER_RESHAPE_FUSION_PASS_H_
2641diff --git a/mindspore/lite/src/train/train_export.cc b/mindspore/lite/src/train/train_export.cc
2642index a9990963..7e504c4e 100644
2643--- a/mindspore/lite/src/train/train_export.cc
2644+++ b/mindspore/lite/src/train/train_export.cc
2645@@ -612,8 +612,40 @@ int TrainExport::SaveModel(lite::Model *model, const std::string &file_name) {
2646 return status;
2647 }
2648 
2649+int TrainExport::SaveModel(lite::Model *model, Buffer *model_buffer) {
2650+ MS_CHECK_FALSE_MSG(model == nullptr, RET_ERROR, "model cannot be nullptr.");
2651+ MS_CHECK_FALSE_MSG(model_buffer == nullptr, RET_ERROR, "model_buffer cannot be nullptr.");
2652+ auto *liteModel = reinterpret_cast<LiteModel *>(model);
2653+ auto size = liteModel->buf_size_;
2654+ model_buffer->ResizeData(size);
2655+
2656+ size_t out_size = model_buffer->DataSize();
2657+ int status = mindspore::lite::Model::Export(model, static_cast<char *>(model_buffer->MutableData()), &out_size);
2658+ if (out_size != size) {
2659+ MS_LOG(ERROR) << "exported size does not match the model buffer size.";
2660+ return RET_ERROR;
2661+ }
2662+
2663+ return status;
2664+}
2665+
2666 int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }
2667 
2668+int TrainExport::SaveToBuffer() {
2669+ constexpr size_t kFbBuilderInitSize = 1024;
2670+ flatbuffers::FlatBufferBuilder builder(kFbBuilderInitSize);
2671+ auto offset = schema::MetaGraph::Pack(builder, meta_graph_);
2672+ builder.Finish(offset);
2673+ schema::FinishMetaGraphBuffer(builder, offset);
2674+ size_t size = builder.GetSize();
2675+ auto content = builder.GetBufferPointer();
2676+ MS_CHECK_FALSE_MSG(content == nullptr, RET_ERROR, "content cannot be nullptr.");
2677+ MS_CHECK_FALSE_MSG(model_buffer_ == nullptr, RET_ERROR, "model_buffer_ cannot be nullptr.");
2678+ model_buffer_->SetData(content, size);
2679+ return RET_OK;
2680+}
2681+
2682+
2683 bool TrainExport::IsInputTensor(const schema::TensorT &t) {
2684 int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
2685 return ((t.data.size() == 0) && (total_dims != 0));
2686diff --git a/mindspore/lite/src/train/train_export.h b/mindspore/lite/src/train/train_export.h
2687index 7727109b..8e802021 100644
2688--- a/mindspore/lite/src/train/train_export.h
2689+++ b/mindspore/lite/src/train/train_export.h
2690@@ -44,24 +44,28 @@ struct tensor_info {
2691 class TrainExport {
2692 public:
2693 explicit TrainExport(const
std::string file_name) : file_name_(file_name) {} 2694+ explicit TrainExport(Buffer *model_buffer) : model_buffer_(model_buffer) {} 2695 virtual ~TrainExport(); 2696 int ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels, 2697 const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names, 2698 const Model *model, QuantizationType quant_type, const Model *bb_model = nullptr); 2699 int ExportInit(const std::string model_name, std::string version); 2700 int SaveToFile(); 2701+ int SaveToBuffer(); 2702 void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; } 2703 int LoadModel(void *buf, size_t buf_size); 2704 int AddTransformNode(); 2705 int TrainModelFusion(); 2706 int TrainModelDrop(); 2707 int SaveModel(lite::Model *model, const std::string &file_name); 2708+ int SaveModel(lite::Model *model, Buffer *model_buffer); 2709 2710 protected: 2711 virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor); 2712 2713 private: 2714 std::string file_name_; 2715+ Buffer *model_buffer_ = nullptr; 2716 schema::MetaGraphT *meta_graph_ = nullptr; 2717 std::vector<size_t> out_idx_; 2718 std::map<size_t, size_t> remap_; 2719diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc 2720index bda5d0a5..9874a30d 100644 2721--- a/mindspore/lite/src/train/train_populate_parameter.cc 2722+++ b/mindspore/lite/src/train/train_populate_parameter.cc 2723@@ -31,6 +31,8 @@ 2724 #include "nnacl/fp32_grad/smooth_l1_loss.h" 2725 #include "nnacl/fp32_grad/resize_grad_parameter.h" 2726 #include "nnacl/fp32_grad/lstm_grad_fp32.h" 2727+#include "nnacl/fp32_grad/binary_cross_entropy.h" 2728+#include "nnacl/fp32_grad/binary_cross_entropy_grad.h" 2729 2730 using mindspore::lite::Registry; 2731 2732@@ -88,29 +90,44 @@ OpParameter *PopulateApplyMomentumParameter(const void *prim) { 2733 } 2734 2735 OpParameter *PopulateBCEParameter(const void *prim) { 2736- int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t))); 2737- if (reduction == nullptr) { 2738- MS_LOG(ERROR) << "malloc reduction failed."; 2739- return nullptr; 2740- } 2741 auto primitive = static_cast<const schema::Primitive *>(prim); 2742+ MS_ASSERT(primitive != nullptr); 2743 auto value = primitive->value_as_BinaryCrossEntropy(); 2744- MS_ASSERT(value != nullptr); 2745- *reduction = value->reduction(); 2746- return reinterpret_cast<OpParameter *>(reduction); 2747+ if (value == nullptr) { 2748+ MS_LOG(ERROR) << "value is nullptr"; 2749+ return nullptr; 2750+ } 2751+ 2752+ auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter))); 2753+ if (param == nullptr) { 2754+ MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed."; 2755+ return nullptr; 2756+ } 2757+ memset(param, 0, sizeof(BinaryCrossEntropyParameter)); 2758+ 2759+ param->op_parameter_.type_ = primitive->value_type(); 2760+ param->reduction = value->reduction(); 2761+ return reinterpret_cast<OpParameter *>(param); 2762 } 2763 2764 OpParameter *PopulateBCEGradParameter(const void *prim) { 2765- int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t))); 2766- if (reduction == nullptr) { 2767- MS_LOG(ERROR) << "malloc reduction failed."; 2768+ auto *primitive = static_cast<const schema::Primitive *>(prim); 2769+ MS_ASSERT(primitive != nullptr); 2770+ auto value = primitive->value_as_BinaryCrossEntropyGrad(); 2771+ if (value == nullptr) { 2772+ MS_LOG(ERROR) << 
"param is nullptr"; 2773 return nullptr; 2774 } 2775- auto primitive = static_cast<const schema::Primitive *>(prim); 2776- auto value = primitive->value_as_BinaryCrossEntropyGrad(); 2777- MS_ASSERT(value != nullptr); 2778- *reduction = value->reduction(); 2779- return reinterpret_cast<OpParameter *>(reduction); 2780+ auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter))); 2781+ if (param == nullptr) { 2782+ MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed."; 2783+ return nullptr; 2784+ } 2785+ memset(param, 0, sizeof(BinaryCrossEntropyGradParameter)); 2786+ 2787+ param->op_parameter_.type_ = primitive->value_type(); 2788+ param->reduction = value->reduction(); 2789+ return reinterpret_cast<OpParameter *>(param); 2790 } 2791 2792 OpParameter *PopulateAdamParameter(const void *prim) { 2793diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc 2794index cd1af1b6..b40ff8c2 100644 2795--- a/mindspore/lite/src/train/train_session.cc 2796+++ b/mindspore/lite/src/train/train_session.cc 2797@@ -206,10 +206,11 @@ static int ReshapeWeightTensor(Tensor *orig_tensor, lite::Tensor *new_tensor) { 2798 } 2799 } 2800 2801- orig_tensor->FreeData(); 2802- orig_tensor->set_data(nullptr); 2803- orig_tensor->set_shape(new_tensor->shape()); 2804- 2805+ if (orig_tensor->shape() != new_tensor->shape()) { 2806+ orig_tensor->FreeData(); 2807+ orig_tensor->set_data(nullptr); 2808+ orig_tensor->set_shape(new_tensor->shape()); 2809+ } 2810 uint8_t *dst_data = reinterpret_cast<uint8_t *>(orig_tensor->MutableData()); 2811 if (dst_data == nullptr) { 2812 MS_LOG(ERROR) << "Allocation of Data Failed"; 2813@@ -228,6 +229,9 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) { 2814 return RET_PARAM_INVALID; 2815 } 2816 if (modify->tensor_name() == tensor->tensor_name()) { 2817+ if (tensor->Size() != modify->Size()) { 2818+ model_buff_changed_ = true; 2819+ } 2820 auto ret = ReshapeWeightTensor(tensor, modify); 2821 num_of_found_tensors++; 2822 if (ret != RET_OK) { 2823@@ -243,6 +247,7 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) { 2824 } 2825 auto ret = ReSizeKernels(kernels_); 2826 if (ret != RET_OK) { 2827+ model_buff_changed_ = false; 2828 MS_LOG(ERROR) << "Resize kernels fail!"; 2829 return ret; 2830 } 2831@@ -1154,9 +1159,17 @@ int TrainSession::FindExportKernels(std::vector<kernel::KernelExec *> *export_ke 2832 return RET_OK; 2833 } 2834 2835-int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type, 2836- FormatType format, std::vector<std::string> out_put_tensor_name) { 2837- MS_CHECK_FALSE_MSG(file_name.empty(), RET_ERROR, "File name cannot be empty"); 2838+template <typename DestType> 2839+int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, 2840+ FormatType format, std::vector<std::string> out_put_tensor_name) { 2841+ if constexpr (std::is_same_v<DestType, const std::string &>) { 2842+ MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty"); 2843+ } else if constexpr (std::is_same_v<DestType, Buffer *>) { 2844+ MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr"); 2845+ } else { 2846+ MS_LOG(ERROR) << "Unsupported destination."; 2847+ return RET_ERROR; 2848+ } 2849 MS_CHECK_FALSE_MSG(model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN, RET_ERROR, 2850 
"Export model type parameter error"); 2851 MS_CHECK_FALSE_MSG(quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT, RET_ERROR, 2852@@ -1165,27 +1178,21 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua 2853 2854 bool orig_train_state = IsTrain(); 2855 Eval(); 2856- TrainExport texport(file_name); 2857+ TrainExport texport(destination); 2858 int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_); 2859- if (status != RET_OK) { 2860- MS_LOG(ERROR) << "cannot init export"; 2861- return status; 2862- } 2863+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export"); 2864 2865 if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) { 2866 std::vector<kernel::KernelExec *> export_kernels = {}; 2867 status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_); 2868- if (status != RET_OK) { 2869- MS_LOG(ERROR) << "FindExportKernels failed."; 2870- return RET_ERROR; 2871- } 2872+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed."); 2873 status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type); 2874 } else { 2875- if ((quant_type == QT_NONE) && (model_type == MT_TRAIN) && 2876+ if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) && 2877 std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) { 2878 return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE; 2879 })) { 2880- status = texport.SaveModel(model_.get(), file_name); 2881+ status = texport.SaveModel(model_.get(), destination); 2882 if (orig_train_state) Train(); 2883 return status; 2884 } else { 2885@@ -1194,35 +1201,42 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua 2886 model_.get(), quant_type); 2887 } 2888 } 2889+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network."); 2890 2891- if (status != RET_OK) { 2892- MS_LOG(ERROR) << "cannot export Network"; 2893- return status; 2894- } 2895 if (model_type == MT_INFERENCE) { 2896 status = texport.TrainModelDrop(); 2897- if (status != RET_OK) { 2898- MS_LOG(ERROR) << "TrainModelDrop failed."; 2899- return status; 2900- } 2901+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed."); 2902 status = texport.TrainModelFusion(); 2903+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed."); 2904+ } 2905+ if constexpr (std::is_same_v<DestType, const std::string &>) { 2906+ status = texport.SaveToFile(); 2907 if (status != RET_OK) { 2908- MS_LOG(ERROR) << "TrainModelFusion failed."; 2909+ MS_LOG(ERROR) << "failed to save to " << destination; 2910 return status; 2911 } 2912- } 2913- status = texport.SaveToFile(); 2914- if (status != RET_OK) { 2915- MS_LOG(ERROR) << "failed to save to " << file_name; 2916- return status; 2917+ } else { 2918+ status = texport.SaveToBuffer(); 2919+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer."); 2920 } 2921 if (orig_train_state) Train(); 2922 return status; 2923 } 2924+ 2925+int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type, 2926+ FormatType format, std::vector<std::string> out_put_tensor_name) { 2927+ return ExportInner<const std::string &>(file_name, model_type, quant_type, format, out_put_tensor_name); 2928+} 2929+ 2930+int 
2930+int TrainSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
2931+ std::vector<std::string> out_put_tensor_name) {
2932+ return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
2933+}
2934+
2935 std::vector<lite::Tensor *> TrainSession::GetFeatureMaps() const {
2936 std::vector<lite::Tensor *> features;
2937 for (auto cur_tensor : this->tensors_) {
2938- if (cur_tensor->IsConst() && cur_tensor->data_type() == kNumberTypeFloat32) {
2939+ if (cur_tensor->category() == lite::Category::CONST_TENSOR && cur_tensor->data_type() == kNumberTypeFloat32) {
2940 features.push_back(cur_tensor);
2941 }
2942 }
2943diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h
2944index 0a0ce640..5acff82a 100644
2945--- a/mindspore/lite/src/train/train_session.h
2946+++ b/mindspore/lite/src/train/train_session.h
2947@@ -36,6 +36,14 @@
2948 +-------------------------------+
2949 */
2950 
2951+// Logs msg and returns errcode when value is true.
2952+#define TRAIN_SESSION_CHECK_FALSE_MSG(value, errcode, msg) \
2953+ do { \
2954+ if ((value)) { \
2955+ MS_LOG(ERROR) << (msg); \
2956+ return errcode; \
2957+ } \
2958+ } while (0)
2959+
2960 namespace mindspore {
2961 namespace lite {
2962 using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>;
2963@@ -96,6 +104,8 @@ class TrainSession : virtual public lite::LiteSession {
2964 }
2965 int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
2966 std::vector<std::string> out_put_tensor_name = {}) override;
2967+ int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
2968+ std::vector<std::string> out_put_tensor_name = {}) override;
2969 
2970 std::vector<lite::Tensor *> GetFeatureMaps() const override;
2971 
2972@@ -165,6 +175,9 @@ class TrainSession : virtual public lite::LiteSession {
2973 const std::unordered_map<lite::Tensor *, size_t> &offset_map,
2974 std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx);
2975 
2976+ template <typename DestType>
2977+ int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
2978+ std::vector<std::string> out_put_tensor_name = {});
2979 std::map<Tensor *, Tensor *> restored_origin_tensors_;
2980 int virtual_batch_idx_ = 0;
2981 int virtual_batch_multiplier_ = 0;
2982@@ -172,6 +185,7 @@ class TrainSession : virtual public lite::LiteSession {
2983 void *workspace_ = nullptr;
2984 SchedCallBack sched_mix_precision_callback_;
2985 bool train_mode_ = false;
2986+ bool model_buff_changed_ = false;
2987 void *tensors_data_ = nullptr;
2988 size_t tensors_data_size_ = 0;
2989 std::shared_ptr<Allocator> allocator_;
2990diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc
2991index b54f348e..031c4a6b 100644
2992--- a/mindspore/lite/src/train/transfer_session.cc
2993+++ b/mindspore/lite/src/train/transfer_session.cc
2994@@ -183,15 +183,24 @@ std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
2995 return map;
2996 }
2997 
2998-int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
2999- FormatType format, std::vector<std::string> out_put_tensor_name) {
3000+template <typename DestType>
3001+int TransferSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
3002+ FormatType format, std::vector<std::string> out_put_tensor_name) {
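+ // As in TrainSession::ExportInner, DestType selects the destination at compile
+ // time (file path vs. in-memory Buffer); MT_TRAIN requests are delegated to
+ // TrainSession::Export below.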
 if constexpr (std::is_same_v<DestType, const std::string &>) {
3003+ MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
3004+ } else if constexpr (std::is_same_v<DestType, Buffer *>) {
3005+ MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
3006+ } else {
3007+ MS_LOG(ERROR) << "Unsupported destination.";
3008+ return RET_ERROR;
3009+ }
3010 if (format != FT_FLATBUFFERS) {
3011 MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
3012 return RET_ERROR;
3013 }
3014 
3015 if (model_type == MT_TRAIN) {
3016- return TrainSession::Export(filename, model_type, quant_type, format);
3017+ return TrainSession::Export(destination, model_type, quant_type, format);
3018 }
3019 
3020 bool orig_train_state = IsTrain();
3021@@ -199,7 +208,7 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
3022 MS_LOG(ERROR) << "eval failed.";
3023 return RET_ERROR;
3024 }
3025- TrainExport texport(filename);
3026+ TrainExport texport(destination);
3027 int status = texport.LoadModel(lite_model_, size_backbone_);
3028 if (status != RET_OK) {
3029 MS_LOG(ERROR) << "cannot init export";
3030@@ -231,10 +240,15 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
3031 MS_LOG(ERROR) << "cannot serialize head";
3032 return status;
3033 }
3034- status = texport.SaveToFile();
3035- if (status != RET_OK) {
3036- MS_LOG(ERROR) << "failed to save to " << filename;
3037- return status;
3038+ if constexpr (std::is_same_v<DestType, const std::string &>) {
3039+ status = texport.SaveToFile();
3040+ if (status != RET_OK) {
3041+ MS_LOG(ERROR) << "failed to save to " << destination;
3042+ return status;
3043+ }
3044+ } else {
3045+ status = texport.SaveToBuffer();
3046+ MS_CHECK_FALSE_MSG(status != RET_OK, status, "failed to save to model buffer.");
3047 }
3048 if (orig_train_state) {
3049 auto ret = Train();
3050@@ -246,6 +260,17 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
3051 return status;
3052 }
3053 
3054+int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
3055+ FormatType format, std::vector<std::string> out_put_tensor_name) {
3056+ return ExportInner<const std::string &>(filename, model_type, quant_type, format, out_put_tensor_name);
3057+}
3058+
3059+int TransferSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
3060+ std::vector<std::string> out_put_tensor_name) {
3061+ return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
3062+}
3063+
3064+
3065 lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
3066 const char *model_buf_head, size_t size_head, const lite::Context *context,
3067 bool train_mode, const lite::TrainCfg *cfg) {
3068diff --git a/mindspore/lite/src/train/transfer_session.h b/mindspore/lite/src/train/transfer_session.h
3069index 48a38b8b..6cd06c60 100644
3070--- a/mindspore/lite/src/train/transfer_session.h
3071+++ b/mindspore/lite/src/train/transfer_session.h
3072@@ -63,6 +63,8 @@ class TransferSession : public lite::TrainSession {
3073 int CompileTransferGraph();
3074 int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
3075 std::vector<std::string> out_put_tensor_name = {}) override;
3076+ int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
3077+ std::vector<std::string> out_put_tensor_name = {}) override; 3078 3079 protected: 3080 LiteSession *backbone_session_ = nullptr; 3081@@ -72,6 +74,9 @@ class TransferSession : public lite::TrainSession { 3082 bool is_valid_ = false; 3083 3084 private: 3085+ template <typename DestType> 3086+ int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType, 3087+ std::vector<std::string> out_put_tensor_name = {}); 3088 bool CompileFormatTransform(lite::Tensor *out, lite::Tensor *in, int *mask, size_t mask_len); 3089 std::unordered_map<size_t, size_t> ConnectionMap(); 3090 bool nchw2nhwc_ = false; 3091diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc 3092index cad52545..20d2d298 100644 3093--- a/mindspore/lite/tools/benchmark_train/net_train.cc 3094+++ b/mindspore/lite/tools/benchmark_train/net_train.cc 3095@@ -1,5 +1,5 @@ 3096 /** 3097- * Copyright 2020 Huawei Technologies Co., Ltd 3098+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 3099 * 3100 * Licensed under the Apache License, Version 2.0 (the "License"); 3101 * you may not use this file except in compliance with the License. 3102@@ -42,6 +42,21 @@ constexpr int kField4 = 4; 3103 constexpr int kFieldsToPrint = 5; 3104 constexpr int kPrintOffset = 4; 3105 static const int kTHOUSAND = 1000; 3106+constexpr int kDumpInputsAndOutputs = 0; 3107+constexpr int kDumpOutputs = 2; 3108+ 3109+const std::unordered_map<int, std::string> kTypeIdMap{ 3110+ {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"}, {kNumberTypeFloat32, "Float32"}, 3111+ {kNumberTypeInt8, "Int8"}, {kNumberTypeInt16, "Int16"}, {kNumberTypeInt, "Int32"}, 3112+ {kNumberTypeInt32, "Int32"}, {kNumberTypeUInt8, "UInt8"}, {kNumberTypeUInt16, "UInt16"}, 3113+ {kNumberTypeUInt, "UInt32"}, {kNumberTypeUInt32, "UInt32"}, {kObjectTypeString, "String"}, 3114+ {kNumberTypeBool, "Bool"}, {kObjectTypeTensorType, "Tensor"}}; 3115+ 3116+const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap{ 3117+ {mindspore::NCHW, "NCHW"}, {mindspore::NHWC, "NHWC"}, {mindspore::NHWC4, "NHWC4"}, {mindspore::HWKC, "HWKC"}, 3118+ {mindspore::HWCK, "HWCK"}, {mindspore::KCHW, "KCHW"}, {mindspore::CKHW, "CKHW"}, {mindspore::KHWC, "KHWC"}, 3119+ {mindspore::CHWK, "CHWK"}, {mindspore::HW, "HW"}, {mindspore::HW4, "HW4"}, {mindspore::NC, "NC"}, 3120+ {mindspore::NC4, "NC4"}, {mindspore::NC4HW4, "NC4HW4"}, {mindspore::NCDHW, "NCDHW"}}; 3121 3122 std::function<int(NetTrainFlags *)> NetTrain::nr_cb_ = nullptr; 3123 3124@@ -249,9 +264,7 @@ int NetTrain::MarkPerformance() { 3125 3126 for (int i = 0; i < flags_->epochs_; i++) { 3127 auto start = GetTimeUs(); 3128- auto status = flags_->time_profiling_ 3129- ? 
ms_model_.Predict(ms_inputs_for_api_, &outputs, before_call_back_, after_call_back_)
3130- : ms_model_.Predict(ms_inputs_for_api_, &outputs);
3131+ auto status = ms_model_.RunStep(before_call_back_, after_call_back_);
3132 if (status != mindspore::kSuccess) {
3133 MS_LOG(ERROR) << "Inference error " << status;
3134 std::cerr << "Inference error " << status;
3135@@ -300,7 +313,7 @@ int NetTrain::MarkAccuracy(bool enforce_accuracy) {
3136 }
3137 }
3138 std::vector<MSTensor> outputs;
3139- auto status = ms_model_.Predict(ms_inputs_for_api_, &outputs);
3140+ auto status = ms_model_.RunStep(before_call_back_, after_call_back_);
3141 if (status != mindspore::kSuccess) {
3142 MS_LOG(ERROR) << "Inference error " << status;
3143 std::cerr << "Inference error " << status << std::endl;
3144@@ -405,17 +418,22 @@ void NetTrain::InitTrainCfg(const std::shared_ptr<TrainCfg> &train_cfg) {
3145 if (flags_->loss_name_.empty()) {
3146 return;
3147 }
3148- train_cfg->loss_name_.clear();
3149+ std::vector<std::string> empty_loss_name;
3150+ train_cfg->SetLossName(empty_loss_name); // clear train_cfg's loss_name
3151 std::string delimiter = ",";
3152 size_t pos = 0;
3153 std::string token;
3154 while ((pos = flags_->loss_name_.find(delimiter)) != std::string::npos) {
3155 token = flags_->loss_name_.substr(0, pos);
3156 flags_->loss_name_.erase(0, pos + delimiter.length()); // drop the consumed token and its delimiter
3157- train_cfg->loss_name_.emplace_back(token);
3158+ std::vector<std::string> train_cfg_loss_name = train_cfg->GetLossName();
3159+ train_cfg_loss_name.emplace_back(token);
3160+ train_cfg->SetLossName(train_cfg_loss_name);
3161 }
3162 if (!(flags_->loss_name_.empty())) {
3163- train_cfg->loss_name_.emplace_back(flags_->loss_name_);
3164+ std::vector<std::string> train_cfg_loss_name = train_cfg->GetLossName();
3165+ train_cfg_loss_name.emplace_back(flags_->loss_name_);
3166+ train_cfg->SetLossName(train_cfg_loss_name);
3167 }
3168 }
3169 
3170@@ -635,7 +653,79 @@ void NetTrain::CheckSum(MSTensor *tensor, const std::string &node_type, int id,
3171 }
3172 }
3173 
3174-int NetTrain::InitCallbackParameter() {
3175+std::string GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name,
3176+ const std::string &file_type, const size_t &idx) {
3177+ std::string file_name = op_name;
3178+ auto pos = file_name.find_first_of('/');
3179+ while (pos != std::string::npos) {
3180+ file_name.replace(pos, 1, ".");
3181+ pos = file_name.find_first_of('/');
3182+ }
3183+ file_name += "_" + file_type + "_" + std::to_string(idx) + "_shape_";
3184+ for (const auto &dim : tensor->Shape()) {
3185+ file_name += std::to_string(dim) + "_";
3186+ }
3187+ if (kTypeIdMap.find(static_cast<int>(tensor->DataType())) != kTypeIdMap.end()) {
3188+ file_name += kTypeIdMap.at(static_cast<int>(tensor->DataType()));
3189+ }
3190+ auto tensor_format = tensor->format();
3191+ if (kTensorFormatMap.find(tensor_format) != kTensorFormatMap.end()) {
3192+ // append only the format here; the ".bin" suffix is added exactly once below
3193+ file_name += "_" + kTensorFormatMap.at(tensor_format);
3194+ }
3195+
3196+ file_name += ".bin";
3197+ return file_name;
3198+}
3199+
3200+int NetTrain::InitDumpTensorDataCallbackParameter() {
3201+ // before callback
3202+ before_call_back_ = [&](const std::vector<mindspore::MSTensor> &before_inputs,
3203+ const std::vector<mindspore::MSTensor> &before_outputs, const MSCallBackParam &call_param) {
3204+ auto dump_mode = dump_cfg_json_[dump::kSettings][dump::kMode].get<int>();
3205+ auto input_output_mode = dump_cfg_json_[dump::kSettings][dump::kInputOutput].get<int>();
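+ // The callbacks below dump a kernel's inputs when input_output_mode is 0 or 1
+ // and its outputs when it is 0 or 2 (kDumpInputsAndOutputs / kDumpOutputs);
+ // dump_mode 0 dumps every kernel, otherwise only those listed in "kernels".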
3205+ auto kernels = dump_cfg_json_[dump::kSettings][dump::kKernels].get<std::vector<std::string>>(); 3206+ if (dump_mode == 0 || std::find(kernels.begin(), kernels.end(), call_param.node_name) != kernels.end()) { 3207+ if (input_output_mode == 0 || input_output_mode == 1) { 3208+ for (size_t i = 0; i < before_inputs.size(); i++) { 3209+ auto ms_tensor = before_inputs.at(i); 3210+ auto file_name = GenerateOutputFileName(&ms_tensor, call_param.node_name, "input", i); 3211+ auto abs_file_path = dump_file_output_dir_ + "/" + file_name; 3212+ if (WriteToBin(abs_file_path, ms_tensor.MutableData(), ms_tensor.DataSize()) != RET_OK) { // save to file 3213+ MS_LOG(ERROR) << "write tensor data to file failed."; 3214+ return false; 3215+ } 3216+ } 3217+ } 3218+ } 3219+ return true; 3220+ }; 3221+ 3222+ // after callback 3223+ after_call_back_ = [&](const std::vector<mindspore::MSTensor> &after_inputs, 3224+ const std::vector<mindspore::MSTensor> &after_outputs, const MSCallBackParam &call_param) { 3225+ auto dump_mode = dump_cfg_json_[dump::kSettings][dump::kMode].get<int>(); 3226+ auto input_output_mode = dump_cfg_json_[dump::kSettings][dump::kInputOutput].get<int>(); 3227+ auto kernels = dump_cfg_json_[dump::kSettings][dump::kKernels].get<std::vector<std::string>>(); 3228+ if (dump_mode == kDumpInputsAndOutputs || 3229+ std::find(kernels.begin(), kernels.end(), call_param.node_name) != kernels.end()) { 3230+ if (input_output_mode == kDumpInputsAndOutputs || input_output_mode == kDumpOutputs) { 3231+ for (size_t i = 0; i < after_outputs.size(); i++) { 3232+ auto ms_tensor = after_outputs.at(i); 3233+ auto file_name = GenerateOutputFileName(&ms_tensor, call_param.node_name, "output", i); 3234+ auto abs_file_path = dump_file_output_dir_ + "/" + file_name; 3235+ if (WriteToBin(abs_file_path, ms_tensor.MutableData(), ms_tensor.DataSize()) != RET_OK) { // save to file 3236+ MS_LOG(ERROR) << "write tensor data to file failed."; 3237+ return false; 3238+ } 3239+ } 3240+ } 3241+ } 3242+ return true; 3243+ }; 3244+ return RET_OK; 3245+} 3246+ 3247+int NetTrain::InitTimeProfilingCallbackParameter() { 3248 // before callback 3249 before_call_back_ = [&](const std::vector<mindspore::MSTensor> &before_inputs, 3250 const std::vector<mindspore::MSTensor> &before_outputs, 3251@@ -696,6 +786,16 @@ int NetTrain::InitCallbackParameter() { 3252 return RET_OK; 3253 } 3254 3255+int NetTrain::InitCallbackParameter() { 3256+ int ret = RET_OK; 3257+ if (flags_->dump_tensor_data_) { 3258+ ret = InitDumpTensorDataCallbackParameter(); 3259+ } else if (flags_->time_profiling_) { 3260+ ret = InitTimeProfilingCallbackParameter(); 3261+ } 3262+ return ret; 3263+} 3264+ 3265 void NetTrainFlags::InitResizeDimsList() { 3266 std::string content = this->resize_dims_in_; 3267 std::vector<int> shape; 3268@@ -761,14 +861,25 @@ int NetTrain::Init() { 3269 return 1; 3270 } 3271 3272- if (flags_->time_profiling_) { 3273- auto status = InitCallbackParameter(); 3274- if (status != RET_OK) { 3275- MS_LOG(ERROR) << "Init callback Parameter failed."; 3276- std::cerr << "Init callback Parameter failed." 
<< std::endl;
3277+ // get dump data output path
3278+ auto dump_cfg_path = std::getenv(dump::kConfigPath);
3279+ if (dump_cfg_path != nullptr) {
3280+ flags_->dump_tensor_data_ = true;
3281+ if (InitDumpConfigFromJson(dump_cfg_path) != RET_OK) {
3282+ MS_LOG(ERROR) << "parse dump config file failed.";
3283 return RET_ERROR;
3284 }
3285+ } else {
3286+ MS_LOG(INFO) << "No MINDSPORE_DUMP_CONFIG in env, no need to dump data";
3287+ }
3288+
3289+ auto status = InitCallbackParameter();
3290+ if (status != RET_OK) {
3291+ MS_LOG(ERROR) << "Init callback Parameter failed.";
3292+ std::cerr << "Init callback Parameter failed." << std::endl;
3293+ return RET_ERROR;
3294 }
3295+
3296 flags_->InitResizeDimsList();
3297 if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&
3298 flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
3299@@ -779,6 +890,70 @@ int NetTrain::Init() {
3300 return RET_OK;
3301 }
3302 
3303+int NetTrain::InitDumpConfigFromJson(char *path) {
3304+ auto real_path = RealPath(path);
3305+ std::ifstream ifs(real_path);
3306+ if (!ifs.good()) {
3307+ MS_LOG(ERROR) << "file: " << real_path << " does not exist";
3308+ return RET_ERROR;
3309+ }
3310+ if (!ifs.is_open()) {
3311+ MS_LOG(ERROR) << "file: " << real_path << " open failed";
3312+ return RET_ERROR;
3313+ }
3314+
3315+ try {
3316+ dump_cfg_json_ = nlohmann::json::parse(ifs);
3317+ } catch (const nlohmann::json::parse_error &error) {
3318+ MS_LOG(ERROR) << "parse json file failed, please check your file.";
3319+ return RET_ERROR;
3320+ }
3321+ if (dump_cfg_json_[dump::kSettings] == nullptr) {
3322+ MS_LOG(ERROR) << "\"common_dump_settings\" is required.";
3323+ return RET_ERROR;
3324+ }
3325+ if (dump_cfg_json_[dump::kSettings][dump::kMode] == nullptr) {
3326+ MS_LOG(ERROR) << "\"dump_mode\" is required.";
3327+ return RET_ERROR;
3328+ }
3329+ if (dump_cfg_json_[dump::kSettings][dump::kPath] == nullptr) {
3330+ MS_LOG(ERROR) << "\"path\" is required.";
3331+ return RET_ERROR;
3332+ }
3333+ if (dump_cfg_json_[dump::kSettings][dump::kNetName] == nullptr) {
3334+ dump_cfg_json_[dump::kSettings][dump::kNetName] = "default";
3335+ }
3336+ if (dump_cfg_json_[dump::kSettings][dump::kInputOutput] == nullptr) {
3337+ dump_cfg_json_[dump::kSettings][dump::kInputOutput] = 0;
3338+ }
3339+ if (dump_cfg_json_[dump::kSettings][dump::kKernels] != nullptr &&
3340+ !dump_cfg_json_[dump::kSettings][dump::kKernels].empty()) {
3341+ if (dump_cfg_json_[dump::kSettings][dump::kMode] == 0) {
3342+ MS_LOG(ERROR) << R"("dump_mode" should be 1 when "kernels" isn't empty.)";
3343+ return RET_ERROR;
3344+ }
3345+ }
3346+
3347+ auto abs_path = dump_cfg_json_[dump::kSettings][dump::kPath].get<std::string>();
3348+ auto net_name = dump_cfg_json_[dump::kSettings][dump::kNetName].get<std::string>();
3349+ if (abs_path.back() == '\\' || abs_path.back() == '/') {
3350+ dump_file_output_dir_ = abs_path + net_name;
3351+ } else {
3352+#ifdef _WIN32
3353+ dump_file_output_dir_ = abs_path + "\\" + net_name;
3354+#else
3355+ dump_file_output_dir_ = abs_path + "/" + net_name;
3356+#endif
3357+ }
3358+
3359+ auto status = CreateOutputDir(&dump_file_output_dir_);
3360+ if (status != RET_OK) {
3361+ MS_LOG(ERROR) << "create data output directory failed.";
3362+ return RET_ERROR;
3363+ }
3364+ return RET_OK;
3365+}
3366+
3367 namespace {
3368 constexpr int kNumToPrint = 5;
3369 }
3370diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h
3371index c67e952f..ea0aaacd 100644
3372---
a/mindspore/lite/tools/benchmark_train/net_train.h 3373+++ b/mindspore/lite/tools/benchmark_train/net_train.h 3374@@ -1,5 +1,5 @@ 3375 /** 3376- * Copyright 2020 Huawei Technologies Co., Ltd 3377+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 3378 * 3379 * Licensed under the Apache License, Version 2.0 (the "License"); 3380 * you may not use this file except in compliance with the License. 3381@@ -30,6 +30,7 @@ 3382 #include <cfloat> 3383 #include <utility> 3384 #include <algorithm> 3385+#include <nlohmann/json.hpp> 3386 #include "include/api/model.h" 3387 #include "include/api/types.h" 3388 #include "include/api/context.h" 3389@@ -54,6 +55,18 @@ enum MS_API DataType { kImage = 0, kBinary = 1 }; 3390 3391 constexpr float relativeTolerance = 1e-5; 3392 constexpr float absoluteTolerance = 1e-8; 3393+extern const std::unordered_map<int, std::string> kTypeIdMap; 3394+extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap; 3395+ 3396+namespace dump { 3397+constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG"; 3398+constexpr auto kSettings = "common_dump_settings"; 3399+constexpr auto kMode = "dump_mode"; 3400+constexpr auto kPath = "path"; 3401+constexpr auto kNetName = "net_name"; 3402+constexpr auto kInputOutput = "input_output"; 3403+constexpr auto kKernels = "kernels"; 3404+} // namespace dump 3405 3406 template <typename T> 3407 float TensorSum(const void *data, int size) { 3408@@ -122,6 +135,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { 3409 std::string loss_name_ = ""; 3410 std::string inference_file_ = ""; 3411 bool unified_api_ = false; 3412+ bool dump_tensor_data_ = false; 3413 }; 3414 3415 class MS_API NetTrain { 3416@@ -193,6 +207,7 @@ class MS_API NetTrain { 3417 } 3418 return meanError; 3419 } 3420+ int InitDumpConfigFromJson(char *path); 3421 3422 private: 3423 // call GenerateInputData or ReadInputFile to init inputTensors 3424@@ -219,6 +234,10 @@ class MS_API NetTrain { 3425 const std::shared_ptr<TrainCfg> &train_cfg, int epochs); 3426 int InitCallbackParameter(); 3427 3428+ int InitDumpTensorDataCallbackParameter(); 3429+ 3430+ int InitTimeProfilingCallbackParameter(); 3431+ 3432 int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result); 3433 3434 template <typename T> 3435@@ -280,6 +299,8 @@ class MS_API NetTrain { 3436 3437 mindspore::MSKernelCallBack before_call_back_{nullptr}; 3438 mindspore::MSKernelCallBack after_call_back_{nullptr}; 3439+ nlohmann::json dump_cfg_json_; 3440+ std::string dump_file_output_dir_; 3441 }; 3442 3443 int MS_API RunNetTrain(int argc, const char **argv); 3444diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc 3445index 57c818b1..03cac4c0 100644 3446--- a/mindspore/lite/tools/converter/anf_transform.cc 3447+++ b/mindspore/lite/tools/converter/anf_transform.cc 3448@@ -95,6 +95,8 @@ 3449 #include "tools/optimizer/fusion/groupnorm_fusion.h" 3450 #include "tools/optimizer/fusion/mul_reduce_fusion.h" 3451 #include "tools/converter/import/cast_op_adjust.h" 3452+#include "tools/optimizer/fusion/expanddims_reshape_fusion.h" 3453+#include "tools/optimizer/fusion/squeeze_expanddims_fusion.h" 3454 3455 using std::string; 3456 namespace mindspore::lite { 3457@@ -226,7 +228,9 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared 3458 std::make_shared<opt::FullConnectedFusion>(), 3459 std::make_shared<opt::FullconnectedAddFusion>(), 3460 
std::make_shared<opt::TensorDotFusion>(), 3461- std::make_shared<opt::MatMulActivationFusion>(param)}; 3462+ std::make_shared<opt::MatMulActivationFusion>(param), 3463+ std::make_shared<opt::ExpandDimsReshapeFusion>(), 3464+ std::make_shared<opt::SqueezeExpandDimsFusion>()}; 3465 for (size_t index = 0; index < fusions.size(); index++) { 3466 auto pass_ptr = fusions.at(index); 3467 auto pass_name = pass_ptr->name(); 3468diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc 3469index 449c6ef9..eaa18d6b 100644 3470--- a/mindspore/lite/tools/converter/converter.cc 3471+++ b/mindspore/lite/tools/converter/converter.cc 3472@@ -236,12 +236,14 @@ schema::MetaGraphT *ConverterImpl::TransferFuncGraph(const std::shared_ptr<Conve 3473 return nullptr; 3474 } 3475 3476- status = UpdateGraphOutputName(meta_graph); 3477- if (status != RET_OK) { 3478- MS_LOG(ERROR) << "UpdateGraphOutputName failed."; 3479- ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); 3480- delete meta_graph; 3481- return nullptr; 3482+ if (!param->train_model) { 3483+ status = UpdateGraphOutputName(meta_graph); 3484+ if (status != RET_OK) { 3485+ MS_LOG(ERROR) << "UpdateGraphOutputName failed."; 3486+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); 3487+ delete meta_graph; 3488+ return nullptr; 3489+ } 3490 } 3491 3492 return meta_graph; 3493diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc 3494index 09f7366e..955e346d 100644 3495--- a/mindspore/lite/tools/converter/graphdef_transform.cc 3496+++ b/mindspore/lite/tools/converter/graphdef_transform.cc 3497@@ -27,6 +27,7 @@ 3498 #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" 3499 #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" 3500 #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" 3501+#include "tools/converter/legacy_optimizer/graph/node_name_pass.h" 3502 #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" 3503 #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h" 3504 #include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h" 3505@@ -169,6 +170,9 @@ int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> ¶m) { 3506 Optimizer forming_model_optimizer; 3507 forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type)); 3508 forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass(param)); 3509+ if (param->train_model) { 3510+ forming_model_optimizer.AddPass(new (std::nothrow) NodeNamePass()); 3511+ } 3512 forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass()); 3513 forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(param->weight_fp16)); 3514 status = forming_model_optimizer.Run(graph_defT_); 3515diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt 3516index f3e245f4..c274b7db 100644 3517--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt 3518+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt 3519@@ -9,6 +9,7 @@ file(GLOB GRAPH_PASS 3520 ${CMAKE_CURRENT_SOURCE_DIR}/convert_fp32_to_fp16_pass.cc 3521 ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc 3522 ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc 3523+ 
${CMAKE_CURRENT_SOURCE_DIR}/node_name_pass.cc 3524 ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc 3525 ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc 3526 ) 3527diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc 3528new file mode 100644 3529index 00000000..712927b0 3530--- /dev/null 3531+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc 3532@@ -0,0 +1,96 @@ 3533+/** 3534+ * Copyright 2022 Huawei Technologies Co., Ltd 3535+ * 3536+ * Licensed under the Apache License, Version 2.0 (the "License"); 3537+ * you may not use this file except in compliance with the License. 3538+ * You may obtain a copy of the License at 3539+ * 3540+ * http://www.apache.org/licenses/LICENSE-2.0 3541+ * 3542+ * Unless required by applicable law or agreed to in writing, software 3543+ * distributed under the License is distributed on an "AS IS" BASIS, 3544+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3545+ * See the License for the specific language governing permissions and 3546+ * limitations under the License. 3547+ */ 3548+ 3549+#include "tools/converter/legacy_optimizer/graph/node_name_pass.h" 3550+#include <string> 3551+#include <vector> 3552+#include "tools/converter/converter_context.h" 3553+ 3554+namespace mindspore::lite { 3555+std::string CutShortName(const std::string &fullname, const std::string &delimiter) { 3556+ size_t end_pos = fullname.find_last_of(delimiter); 3557+ std::string name = ""; 3558+ if (end_pos != std::string::npos) { 3559+ name = fullname.substr(end_pos + 1); 3560+ } 3561+ if ((fullname.find("op") != std::string::npos) && (name.find("op") == std::string::npos) && 3562+ (end_pos != std::string::npos)) { 3563+ size_t pos = fullname.rfind(delimiter, end_pos - 1); 3564+ if (pos != std::string::npos) { 3565+ name.insert(0, fullname.substr(pos + 1, end_pos - pos)); 3566+ } else { 3567+ name.insert(0, fullname.substr(0, end_pos + 1)); 3568+ } 3569+ } 3570+ 3571+ const std::vector<std::string> loss_names = {"loss_fct", "_loss_fn", "SigmoidCrossEntropy"}; 3572+ for (auto &s : loss_names) { 3573+ if (fullname.find(s) != std::string::npos) { 3574+ name.insert(0, s + "/"); 3575+ break; 3576+ } 3577+ } 3578+ 3579+ if (fullname.find("Gradients") != std::string::npos) { 3580+ size_t pos = fullname.find(delimiter); 3581+ if (pos != std::string::npos) { 3582+ name.insert(0, fullname.substr(0, pos + 1)); 3583+ } 3584+ } 3585+ return name; 3586+} 3587+ 3588+STATUS NodeNamePass::Run(schema::MetaGraphT *graph) { 3589+ if (graph == nullptr) { 3590+ MS_LOG(ERROR) << "graph is nullptr"; 3591+ return RET_NULL_PTR; 3592+ } 3593+ 3594+ std::string delimiter = "/"; 3595+ for (auto &node : graph->nodes) { 3596+ if (node == nullptr || node->primitive == nullptr) { 3597+ MS_LOG(ERROR) << "node or node->primitive is nullptr"; 3598+ return RET_NULL_PTR; 3599+ } 3600+ std::string node_name = CutShortName(node->name, delimiter); 3601+ node->name = node_name != "" ? 
node_name : node->name; 3602+ 3603+ for (int i = 0; i < static_cast<int>(node->inputIndex.size()); i++) { 3604+ auto tensor_id = node->inputIndex.at(i); 3605+ auto &tensor = graph->allTensors.at(tensor_id); 3606+ if (tensor->name.empty()) { 3607+ MS_LOG(DEBUG) << "input tensor (id = " << tensor_id << ") name is null"; 3608+ tensor->name = node->name + "/input-" + std::to_string(i); 3609+ } else { 3610+ std::string in_tensor_name = CutShortName(tensor->name, delimiter); 3611+ tensor->name = in_tensor_name != "" ? in_tensor_name : tensor->name; 3612+ } 3613+ } 3614+ 3615+ for (int i = 0; i < static_cast<int>(node->outputIndex.size()); i++) { 3616+ auto tensor_id = node->outputIndex.at(i); 3617+ auto &tensor = graph->allTensors.at(tensor_id); 3618+ if (tensor->name.empty()) { 3619+ tensor->name = node->name + "/output-" + std::to_string(i); 3620+ } else { 3621+ std::string out_tensor_name = CutShortName(tensor->name, delimiter); 3622+ tensor->name = out_tensor_name != "" ? out_tensor_name : tensor->name; 3623+ } 3624+ } 3625+ } 3626+ return RET_OK; 3627+} 3628+} // namespace mindspore::lite 3629\ No newline at end of file 3630diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h 3631new file mode 100644 3632index 00000000..4e58e5c7 3633--- /dev/null 3634+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h 3635@@ -0,0 +1,35 @@ 3636+/** 3637+ * Copyright 2022 Huawei Technologies Co., Ltd 3638+ * 3639+ * Licensed under the Apache License, Version 2.0 (the "License"); 3640+ * you may not use this file except in compliance with the License. 3641+ * You may obtain a copy of the License at 3642+ * 3643+ * http://www.apache.org/licenses/LICENSE-2.0 3644+ * 3645+ * Unless required by applicable law or agreed to in writing, software 3646+ * distributed under the License is distributed on an "AS IS" BASIS, 3647+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3648+ * See the License for the specific language governing permissions and 3649+ * limitations under the License. 
3650+ */ 3651+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_ 3652+#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_ 3653+ 3654+#include <memory> 3655+#include "tools/converter/optimizer.h" 3656+#include "tools/common/graph_util.h" 3657+ 3658+namespace mindspore { 3659+namespace lite { 3660+class NodeNamePass : public GraphPass { 3661+ public: 3662+ NodeNamePass() {} 3663+ 3664+ ~NodeNamePass() override = default; 3665+ 3666+ STATUS Run(schema::MetaGraphT *graph) override; 3667+}; 3668+} // namespace lite 3669+} // namespace mindspore 3670+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_ 3671diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc 3672index 04e856a8..edd6a538 100644 3673--- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc 3674+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc 3675@@ -316,6 +316,16 @@ const void *OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_t 3676 *data_size = data_count * sizeof(uint8_t); 3677 onnx_data = onnx_const_tensor.raw_data().data(); 3678 break; 3679+ case kNumberTypeFloat16: 3680+ if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(uint16_t), SIZE_MAX)) { 3681+ MS_LOG(ERROR) << "data_size overflow"; 3682+ return nullptr; 3683+ } 3684+ *data_size = data_count * sizeof(uint16_t); 3685+ if (!onnx_const_tensor.raw_data().empty()) { 3686+ onnx_data = onnx_const_tensor.raw_data().data(); 3687+ } 3688+ break; 3689 default: 3690 MS_LOG(ERROR) << "unsupported data type " << data_type; 3691 return nullptr; 3692diff --git a/mindspore/lite/tools/lite_exporter/anf_exporter.cc b/mindspore/lite/tools/lite_exporter/anf_exporter.cc 3693index e4063edc..6c4d2544 100644 3694--- a/mindspore/lite/tools/lite_exporter/anf_exporter.cc 3695+++ b/mindspore/lite/tools/lite_exporter/anf_exporter.cc 3696@@ -798,6 +798,17 @@ int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, cons 3697 int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive, 3698 const std::unique_ptr<schema::MetaGraphT> &meta_graphT, 3699 schema::CNodeT *op_node) { 3700+ MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "cnode is nullptr"); 3701+ MS_CHECK_TRUE_MSG(primitive != nullptr, RET_NULL_PTR, "primitive is nullptr"); 3702+ MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr"); 3703+ MS_CHECK_TRUE_MSG(op_node != nullptr, RET_NULL_PTR, "op_node is nullptr"); 3704+ auto value_node = cnode->input(index)->cast<ValueNodePtr>(); 3705+ MS_ASSERT(value_node != nullptr); 3706+ auto key = std::make_pair(value_node, 0); 3707+ if (node_id_map_.find(key) != node_id_map_.end()) { 3708+ op_node->inputIndex.emplace_back(node_id_map_[key]); 3709+ return RET_OK; 3710+ } 3711 DataInfo data_info; 3712 auto status = 3713 FetchDataFromValueNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info, true); 3714@@ -810,13 +821,12 @@ int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, cons 3715 } 3716 auto schema_tensor = std::make_unique<schema::TensorT>(); 3717 MS_CHECK_TRUE_MSG(schema_tensor != nullptr, RET_ERROR, "schema is nullptr"); 3718- schema_tensor->name = cnode->input(index)->fullname_with_scope(); 3719+ schema_tensor->name = value_node->fullname_with_scope(); 3720 schema_tensor->format = static_cast<schema::Format>(data_info.format_); 3721 
schema_tensor->dataType = data_info.data_type_; 3722 schema_tensor->dims = data_info.shape_; 3723 schema_tensor->data = data_info.data_; 3724 3725- auto key = std::make_pair(cnode->input(index), 0); 3726 node_id_map_[key] = meta_graphT->allTensors.size(); 3727 op_node->inputIndex.emplace_back(meta_graphT->allTensors.size()); 3728 meta_graphT->allTensors.emplace_back(std::move(schema_tensor)); 3729diff --git a/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc b/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc 3730new file mode 100644 3731index 00000000..fb047193 3732--- /dev/null 3733+++ b/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc 3734@@ -0,0 +1,73 @@ 3735+/** 3736+ * Copyright 2022 Huawei Technologies Co., Ltd 3737+ * 3738+ * Licensed under the Apache License, Version 2.0 (the "License"); 3739+ * you may not use this file except in compliance with the License. 3740+ * You may obtain a copy of the License at 3741+ * 3742+ * http://www.apache.org/licenses/LICENSE-2.0 3743+ * 3744+ * Unless required by applicable law or agreed to in writing, software 3745+ * distributed under the License is distributed on an "AS IS" BASIS, 3746+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3747+ * See the License for the specific language governing permissions and 3748+ * limitations under the License. 3749+ */ 3750+ 3751+#define USE_DEPRECATED_API 3752+#include "tools/optimizer/fusion/expanddims_reshape_fusion.h" 3753+#include "tools/lite_exporter/fetch_content.h" 3754+#include "ops/op_utils.h" 3755+#include "ops/reshape.h" 3756+#include "tools/optimizer/common/gllo_utils.h" 3757+#include "nnacl/op_base.h" 3758+#include "include/registry/converter_context.h" 3759+ 3760+namespace mindspore::opt { 3761+const BaseRef ExpandDimsReshapeFusion::DefinePattern() const { 3762+ auto is_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>); 3763+ MS_CHECK_TRUE_RET(is_reshape != nullptr, {}); 3764+ auto reshape_shape = std::make_shared<Var>(); 3765+ MS_CHECK_TRUE_RET(reshape_shape != nullptr, {}); 3766+ auto is_expanddims = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimExpandDims>); 3767+ MS_CHECK_TRUE_RET(is_expanddims != nullptr, {}); 3768+ return VectorRef({is_reshape, is_expanddims, reshape_shape}); 3769+} 3770+ 3771+bool ExpandDimsReshapeFusion::CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const { 3772+ auto reshape_cnode = node->cast<CNodePtr>(); 3773+ MS_CHECK_TRUE_RET(reshape_cnode != nullptr, false); 3774+ 3775+ MS_CHECK_TRUE_RET(reshape_cnode->input(SECOND_INPUT) != nullptr, false); 3776+ auto expanddims_cnode = reshape_cnode->input(SECOND_INPUT)->cast<CNodePtr>(); 3777+ MS_CHECK_TRUE_RET(expanddims_cnode != nullptr, false); 3778+ if (IsMultiOutputTensors(func_graph, expanddims_cnode)) { 3779+ return false; 3780+ } 3781+ auto expanddims_primc = GetValueNode<PrimitiveCPtr>(expanddims_cnode->input(0)); 3782+ MS_CHECK_TRUE_RET(expanddims_primc != nullptr, false); 3783+ if (IsQuantParameterNode(expanddims_primc)) { 3784+ MS_LOG(INFO) << expanddims_primc->name() << " is quant node"; 3785+ return false; 3786+ } 3787+ return true; 3788+} 3789+ 3790+const AnfNodePtr ExpandDimsReshapeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, 3791+ const EquivPtr &equiv) const { 3792+ if (func_graph == nullptr || node == nullptr || equiv == nullptr) { 3793+ return nullptr; 3794+ } 3795+ 3796+ if (!CheckCanFuse(func_graph, node)) { 3797+ return nullptr; 3798+ } 3799+ 
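+ // A Reshape's output is fully determined by its shape operand, so the
+ // intermediate ExpandDims is redundant here: rewire the Reshape's data input
+ // (edge C1NUM) directly to ExpandDims' own input and keep the Reshape node.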
3800+ auto reshape_cnode = node->cast<CNodePtr>(); 3801+ auto expanddims_cnode = reshape_cnode->input(SECOND_INPUT)->cast<CNodePtr>(); 3802+ auto manage = Manage(func_graph); 3803+ MS_CHECK_TRUE_RET(manage != nullptr, nullptr); 3804+ manage->SetEdge(reshape_cnode, C1NUM, expanddims_cnode->input(SECOND_INPUT)); 3805+ return reshape_cnode; 3806+} 3807+} // namespace mindspore::opt 3808diff --git a/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.h b/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.h 3809new file mode 100644 3810index 00000000..09475591 3811--- /dev/null 3812+++ b/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.h 3813@@ -0,0 +1,40 @@ 3814+/** 3815+ * Copyright 2022 Huawei Technologies Co., Ltd 3816+ * 3817+ * Licensed under the Apache License, Version 2.0 (the "License"); 3818+ * you may not use this file except in compliance with the License. 3819+ * You may obtain a copy of the License at 3820+ * 3821+ * http://www.apache.org/licenses/LICENSE-2.0 3822+ * 3823+ * Unless required by applicable law or agreed to in writing, software 3824+ * distributed under the License is distributed on an "AS IS" BASIS, 3825+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3826+ * See the License for the specific language governing permissions and 3827+ * limitations under the License. 3828+ */ 3829+ 3830+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_EXPANDDIMS_RESHAPE_FUSION_H_ 3831+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_EXPANDDIMS_RESHAPE_FUSION_H_ 3832+ 3833+#include <string> 3834+#include <memory> 3835+#include "utils/check_convert_utils.h" 3836+#include "backend/common/optimizer/optimizer.h" 3837+ 3838+namespace mindspore { 3839+namespace opt { 3840+class ExpandDimsReshapeFusion : public PatternProcessPass { 3841+ public: 3842+ explicit ExpandDimsReshapeFusion(bool multigraph = true, const std::string &name = "ExpandDimsReshapeFusion") 3843+ : PatternProcessPass(name, multigraph) {} 3844+ ~ExpandDimsReshapeFusion() override = default; 3845+ const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 3846+ const BaseRef DefinePattern() const override; 3847+ 3848+ private: 3849+ bool CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; 3850+}; 3851+} // namespace opt 3852+} // namespace mindspore 3853+#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_EXPANDDIMS_RESHAPE_FUSION_H_ 3854diff --git a/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc b/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc 3855new file mode 100644 3856index 00000000..daa8ac64 3857--- /dev/null 3858+++ b/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc 3859@@ -0,0 +1,117 @@ 3860+/** 3861+ * Copyright 2022 Huawei Technologies Co., Ltd 3862+ * 3863+ * Licensed under the Apache License, Version 2.0 (the "License"); 3864+ * you may not use this file except in compliance with the License. 3865+ * You may obtain a copy of the License at 3866+ * 3867+ * http://www.apache.org/licenses/LICENSE-2.0 3868+ * 3869+ * Unless required by applicable law or agreed to in writing, software 3870+ * distributed under the License is distributed on an "AS IS" BASIS, 3871+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3872+ * See the License for the specific language governing permissions and 3873+ * limitations under the License. 
3874+ */ 3875+ 3876+#define USE_DEPRECATED_API 3877+#include "tools/optimizer/fusion/squeeze_expanddims_fusion.h" 3878+#include <vector> 3879+#include "tools/lite_exporter/fetch_content.h" 3880+#include "ops/op_utils.h" 3881+#include "ops/squeeze.h" 3882+#include "tools/optimizer/common/gllo_utils.h" 3883+#include "nnacl/op_base.h" 3884+#include "include/registry/converter_context.h" 3885+ 3886+namespace mindspore::opt { 3887+const BaseRef SqueezeExpandDimsFusion::DefinePattern() const { 3888+ auto is_expanddims = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimExpandDims>); 3889+ MS_CHECK_TRUE_RET(is_expanddims != nullptr, {}); 3890+ auto ex_shape = std::make_shared<Var>(); 3891+ MS_CHECK_TRUE_RET(ex_shape != nullptr, {}); 3892+ auto is_squeeze = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqueeze>); 3893+ MS_CHECK_TRUE_RET(is_squeeze != nullptr, {}); 3894+ return VectorRef({is_expanddims, is_squeeze, ex_shape}); 3895+} 3896+ 3897+bool SqueezeExpandDimsFusion::CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const { 3898+ auto expanddims_cnode = node->cast<CNodePtr>(); 3899+ MS_CHECK_TRUE_RET(expanddims_cnode != nullptr, false); 3900+ MS_CHECK_TRUE_RET(expanddims_cnode->input(SECOND_INPUT) != nullptr, false); 3901+ auto squeeze_cnode = expanddims_cnode->input(SECOND_INPUT)->cast<CNodePtr>(); 3902+ MS_CHECK_TRUE_RET(squeeze_cnode != nullptr, false); 3903+ if (IsMultiOutputTensors(func_graph, squeeze_cnode)) { 3904+ return false; 3905+ } 3906+ auto squeeze_primitive = GetValueNode<PrimitiveCPtr>(squeeze_cnode->input(0)); 3907+ MS_CHECK_TRUE_RET(squeeze_primitive != nullptr, false); 3908+ MS_CHECK_TRUE_RET(!IsQuantParameterNode(squeeze_primitive), false); 3909+ 3910+ MS_CHECK_TRUE_RET(expanddims_cnode->input(THIRD_INPUT) != nullptr, false); 3911+ lite::DataInfo data_info; 3912+ if (lite::FetchConstData(expanddims_cnode, THIRD_INPUT, converter::kFmkTypeMs, &data_info, false) != lite::RET_OK) { 3913+ return false; 3914+ } 3915+ if ((data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) || 3916+ data_info.data_.size() != C4NUM) { 3917+ return false; 3918+ } 3919+ auto expanddims_axis = *reinterpret_cast<int *>(data_info.data_.data()); 3920+ 3921+ auto squeeze_prim = api::MakeShared<mindspore::ops::Squeeze>(squeeze_primitive); 3922+ MS_CHECK_TRUE_RET(squeeze_prim != nullptr, false); 3923+ auto squeeze_axises = squeeze_prim->get_axis(); 3924+ MS_CHECK_TRUE_RET(squeeze_axises.size() < DIMENSION_2D, false); 3925+ int64_t squeeze_axis; 3926+ if (squeeze_axises.empty()) { 3927+ squeeze_axis = INT64_MIN; 3928+ } else { 3929+ squeeze_axis = squeeze_axises.at(C0NUM); 3930+ } 3931+ if (squeeze_axis == expanddims_axis) { 3932+ return true; 3933+ } else { 3934+ // squeeze_axis or expanddims_axis is less than zero 3935+ MS_CHECK_TRUE_RET(squeeze_cnode->input(SECOND_INPUT) != nullptr, false); 3936+ auto squeeze_abt = squeeze_cnode->input(SECOND_INPUT)->abstract(); 3937+ MS_CHECK_TRUE_RET(squeeze_abt != nullptr, false); 3938+ std::vector<int64_t> squeeze_shape; 3939+ if (FetchShapeFromAbstract(squeeze_abt, &squeeze_shape) != lite::RET_OK) { 3940+ return false; 3941+ } 3942+ 3943+ auto expanddims_prim = GetCNodePrimitive(expanddims_cnode); 3944+ MS_CHECK_TRUE_RET(expanddims_prim != nullptr, false); 3945+ auto is_inferred = 3946+ expanddims_prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(expanddims_prim->GetAttr(kInferDone)); 3947+ MS_CHECK_TRUE_RET(is_inferred, false); 3948+ auto expanddims_abt = expanddims_cnode->abstract(); 3949+ 
MS_CHECK_TRUE_RET(expanddims_abt != nullptr, false); 3950+ std::vector<int64_t> expanddims_shape; 3951+ if (FetchShapeFromAbstract(expanddims_abt, &expanddims_shape) != lite::RET_OK) { 3952+ return false; 3953+ } 3954+ MS_CHECK_TRUE_RET(squeeze_shape == expanddims_shape, false); 3955+ } 3956+ return true; 3957+} 3958+ 3959+const AnfNodePtr SqueezeExpandDimsFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, 3960+ const EquivPtr &equiv) const { 3961+ if (func_graph == nullptr || node == nullptr || equiv == nullptr) { 3962+ return nullptr; 3963+ } 3964+ 3965+ if (!CheckCanFuse(func_graph, node)) { 3966+ return nullptr; 3967+ } 3968+ 3969+ auto expanddims_cnode = node->cast<CNodePtr>(); 3970+ auto squeeze_cnode = expanddims_cnode->input(SECOND_INPUT)->cast<CNodePtr>(); 3971+ auto manage = Manage(func_graph); 3972+ MS_CHECK_TRUE_RET(manage != nullptr, nullptr); 3973+ manage->Replace(expanddims_cnode, squeeze_cnode->input(SECOND_INPUT)); 3974+ return nullptr; 3975+} 3976+} // namespace mindspore::opt 3977diff --git a/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.h b/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.h 3978new file mode 100644 3979index 00000000..1173cb0a 3980--- /dev/null 3981+++ b/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.h 3982@@ -0,0 +1,40 @@ 3983+/** 3984+ * Copyright 2022 Huawei Technologies Co., Ltd 3985+ * 3986+ * Licensed under the Apache License, Version 2.0 (the "License"); 3987+ * you may not use this file except in compliance with the License. 3988+ * You may obtain a copy of the License at 3989+ * 3990+ * http://www.apache.org/licenses/LICENSE-2.0 3991+ * 3992+ * Unless required by applicable law or agreed to in writing, software 3993+ * distributed under the License is distributed on an "AS IS" BASIS, 3994+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3995+ * See the License for the specific language governing permissions and 3996+ * limitations under the License. 3997+ */ 3998+ 3999+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SQUEEZE_EXPANDDIMS_FUSION_H_ 4000+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SQUEEZE_EXPANDDIMS_FUSION_H_ 4001+ 4002+#include <string> 4003+#include <memory> 4004+#include "utils/check_convert_utils.h" 4005+#include "backend/common/optimizer/optimizer.h" 4006+ 4007+namespace mindspore { 4008+namespace opt { 4009+class SqueezeExpandDimsFusion : public PatternProcessPass { 4010+ public: 4011+ explicit SqueezeExpandDimsFusion(bool multigraph = true, const std::string &name = "SqueezeExpandDimsFusion") 4012+ : PatternProcessPass(name, multigraph) {} 4013+ ~SqueezeExpandDimsFusion() override = default; 4014+ const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 4015+ const BaseRef DefinePattern() const override; 4016+ 4017+ private: 4018+ bool CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; 4019+}; 4020+} // namespace opt 4021+} // namespace mindspore 4022+#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SQUEEZE_EXPANDDIMS_FUSION_H_ 4023-- 40242.17.1 4025 4026