• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1From 98e86575c68939fecc1bc3c80be9fca0e080b7fa Mon Sep 17 00:00:00 2001
2From: z00574805 <z00574805@notesmail.huawei.com>
3Date: Wed, 24 May 2023 11:04:25 +0800
4Subject: [PATCH 1/5] xiaoyi-0001
5
6---
7 include/api/callback/callback.h               |  22 +-
8 include/api/callback/ckpt_saver.h             |   4 +-
9 include/api/callback/loss_monitor.h           |   4 +-
10 include/api/callback/lr_scheduler.h           |  10 +-
11 include/api/callback/time_monitor.h           |   4 +-
12 include/api/callback/train_accuracy.h         |  10 +-
13 include/api/cfg.h                             |  25 ++-
14 include/api/metrics/accuracy.h                |   4 +-
15 include/api/metrics/metrics.h                 |   7 +-
16 include/api/model_parallel_runner.h           |   4 +-
17 include/api/net.h                             |  23 +-
18 include/api/serialization.h                   |  20 +-
19 include/api/types.h                           |   6 +-
20 mindspore/ccsrc/cxx_api/serialization.cc      |   3 +-
21 .../device/cpu/kernel/nnacl/fp32/div_fp32.c   |   4 -
22 .../nnacl/fp32_grad/binary_cross_entropy.c    |  45 ++--
23 .../nnacl/fp32_grad/binary_cross_entropy.h    |   2 +-
24 .../fp32_grad/binary_cross_entropy_grad.c     |  34 ++-
25 .../fp32_grad/binary_cross_entropy_grad.h     |   2 +-
26 .../nnacl/infer/binary_cross_entropy_infer.c  |   4 +-
27 .../cpu/kernel/nnacl/infer/common_infer.c     |   1 +
28 mindspore/lite/include/model.h                |   7 +-
29 .../include/registry/opencl_runtime_wrapper.h |   4 +-
30 .../java/src/main/native/train_config.cpp     |   6 +-
31 mindspore/lite/src/CMakeLists.txt             |  22 +-
32 .../binary_cross_entropy_grad_populate.cc     |  45 ----
33 .../populate/binary_cross_entropy_populate.cc |  45 ----
34 mindspore/lite/src/common/prim_util.h         |   7 +-
35 mindspore/lite/src/common/tensor_util.h       |   2 +-
36 mindspore/lite/src/extendrt/CMakeLists.txt    |   4 +
37 .../src/extendrt/cxx_api/serialization.cc     |   3 +-
38 .../lite/src/runtime/cxx_api/converters.h     |   4 +-
39 .../src/runtime/cxx_api/model/model_impl.cc   |   3 +
40 .../src/runtime/cxx_api/model/model_impl.h    |   6 +-
41 .../lite/src/runtime/cxx_api/serialization.cc |  31 ++-
42 .../src/runtime/cxx_api/train/converters.cc   |   6 +-
43 mindspore/lite/src/runtime/infer_manager.h    |  10 +-
44 mindspore/lite/src/runtime/inner_context.h    |   2 +-
45 .../runtime/kernel/cpu/base/argminmax_base.cc |   1 -
46 .../kernel/cpu/base/arithmetic_base.cc        |   1 +
47 .../kernel/cpu/base/group_convolution_base.cc |  16 +-
48 .../cpu/base/group_convolution_creator.cc     |  14 +-
49 .../runtime/kernel/cpu/base/reshape_base.cc   |   1 +
50 .../runtime/kernel/cpu/base/strided_slice.cc  |   2 +
51 .../kernel/cpu/fp16/fused_batchnorm_fp16.cc   |   8 +-
52 .../kernel/cpu/fp32/fused_batchnorm_fp32.cc   |  16 +-
53 .../runtime/kernel/cpu/fp32/oneslike_fp32.cc  |  52 +++++
54 .../runtime/kernel/cpu/fp32/oneslike_fp32.h   |  46 ++++
55 .../cpu/fp32_grad/binary_cross_entropy.cc     | 120 +++++++++++
56 .../cpu/fp32_grad/binary_cross_entropy.h      |  42 ++++
57 .../fp32_grad/binary_cross_entropy_grad.cc    | 105 +++++++++
58 .../cpu/fp32_grad/binary_cross_entropy_grad.h |  41 ++++
59 .../runtime/kernel/gpu/opencl/CMakeLists.txt  |  11 +
60 .../src/runtime/kernel/opencl/CMakeLists.txt  |   8 -
61 mindspore/lite/src/runtime/kernel_exec_util.h |   2 +-
62 mindspore/lite/src/runtime/kernel_registry.h  |   4 +-
63 mindspore/lite/src/runtime/lite_kernel.h      |   4 +-
64 mindspore/lite/src/runtime/lite_model.h       |   4 +-
65 mindspore/lite/src/runtime/lite_session.cc    |  28 ++-
66 mindspore/lite/src/runtime/lite_session.h     |  10 +-
67 mindspore/lite/src/runtime/weight_decoder.h   |   4 +-
68 mindspore/lite/src/tensor.h                   |   2 +-
69 mindspore/lite/src/tensorlist.h               |   2 +-
70 mindspore/lite/src/train/graph_fusion.cc      |   4 +
71 .../train/optimizer/common/fusion_utils.cc    |  37 ++++
72 .../src/train/optimizer/common/fusion_utils.h |  50 +++++
73 .../fusion/matmul_activation_fusion_pass.cc   |  93 ++++++++
74 .../fusion/matmul_activation_fusion_pass.h    |  42 ++++
75 .../reshape_gather_reshape_fusion_pass.cc     | 148 +++++++++++++
76 .../reshape_gather_reshape_fusion_pass.h      |  42 ++++
77 mindspore/lite/src/train/train_export.cc      |  32 +++
78 mindspore/lite/src/train/train_export.h       |   4 +
79 .../src/train/train_populate_parameter.cc     |  49 +++--
80 mindspore/lite/src/train/train_session.cc     |  80 ++++---
81 mindspore/lite/src/train/train_session.h      |  14 ++
82 mindspore/lite/src/train/transfer_session.cc  |  41 +++-
83 mindspore/lite/src/train/transfer_session.h   |   5 +
84 .../lite/tools/benchmark_train/net_train.cc   | 203 ++++++++++++++++--
85 .../lite/tools/benchmark_train/net_train.h    |  23 +-
86 .../lite/tools/converter/anf_transform.cc     |   6 +-
87 mindspore/lite/tools/converter/converter.cc   |  14 +-
88 .../tools/converter/graphdef_transform.cc     |   4 +
89 .../legacy_optimizer/graph/CMakeLists.txt     |   1 +
90 .../legacy_optimizer/graph/node_name_pass.cc  |  96 +++++++++
91 .../legacy_optimizer/graph/node_name_pass.h   |  35 +++
92 .../converter/parser/onnx/onnx_node_parser.cc |  10 +
93 .../lite/tools/lite_exporter/anf_exporter.cc  |  14 +-
94 .../fusion/expanddims_reshape_fusion.cc       |  73 +++++++
95 .../fusion/expanddims_reshape_fusion.h        |  40 ++++
96 .../fusion/squeeze_expanddims_fusion.cc       | 117 ++++++++++
97 .../fusion/squeeze_expanddims_fusion.h        |  40 ++++
98 91 files changed, 1955 insertions(+), 351 deletions(-)
99 delete mode 100644 mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc
100 delete mode 100644 mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc
101 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc
102 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h
103 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc
104 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h
105 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc
106 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h
107 create mode 100644 mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt
108 delete mode 100644 mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt
109 create mode 100644 mindspore/lite/src/train/optimizer/common/fusion_utils.cc
110 create mode 100644 mindspore/lite/src/train/optimizer/common/fusion_utils.h
111 create mode 100644 mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc
112 create mode 100644 mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h
113 create mode 100644 mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
114 create mode 100644 mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h
115 create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc
116 create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h
117 create mode 100644 mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc
118 create mode 100644 mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.h
119 create mode 100644 mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc
120 create mode 100644 mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.h
121
122diff --git a/include/api/callback/callback.h b/include/api/callback/callback.h
123index 3332f819..60f30b80 100644
124--- a/include/api/callback/callback.h
125+++ b/include/api/callback/callback.h
126@@ -1,5 +1,5 @@
127 /**
128- * Copyright 2021 Huawei Technologies Co., Ltd
129+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
130  *
131  * Licensed under the Apache License, Version 2.0 (the "License");
132  * you may not use this file except in compliance with the License.
133@@ -23,6 +23,7 @@
134 #include <utility>
135 #include "include/api/data_type.h"
136 #include "include/api/dual_abi_helper.h"
137+#include "include/api/types.h"
138
139 namespace mindspore {
140 class Model;
141@@ -31,24 +32,19 @@ class CallbackImpl;
142
143 using GraphPoint = std::pair<int, float>;
144
145-struct TrainCallBackData {
146-  TrainCallBackData(bool train_mode, int epoch, int step, Model *model): train_mode_(train_mode), epoch_(epoch),
147-                    step_(step), model_(model) {}
148+struct MS_API TrainCallBackData {
149+  TrainCallBackData(bool train_mode, int epoch, int step, Model *model)
150+      : train_mode_(train_mode), epoch_(epoch), step_(step), model_(model) {}
151
152   bool train_mode_;       /**< training mode of LiteSession object */
153   unsigned int epoch_;    /**< the current training epoch (starts at 0) */
154   unsigned int step_ = 0; /**< the current step within the epoch */
155-  Model *model_;  /**< pointer to the Model object */
156+  Model *model_;          /**< pointer to the Model object */
157 };
158
159-enum CallbackRetValue : uint32_t {
160-  kContinue = 0,
161-  kStopTraining = 1,
162-  kExit = 2,
163-  kUnknownRetValue = 0xFFFFFFFF
164-};
165+enum CallbackRetValue : uint32_t { kContinue = 0, kStopTraining = 1, kExit = 2, kUnknownRetValue = 0xFFFFFFFF };
166
167-class TrainCallBack {
168+class MS_API TrainCallBack {
169  public:
170   virtual ~TrainCallBack() = default;
171
172@@ -90,7 +86,7 @@ class TrainCallBack {
173  protected:
174   friend class Model;
175   friend class ModelImpl;
176-  CallbackImpl* callback_impl_ = nullptr;
177+  CallbackImpl *callback_impl_ = nullptr;
178 };
179
180 }  // namespace mindspore
181diff --git a/include/api/callback/ckpt_saver.h b/include/api/callback/ckpt_saver.h
182index e673c624..d9ff2f69 100644
183--- a/include/api/callback/ckpt_saver.h
184+++ b/include/api/callback/ckpt_saver.h
185@@ -1,5 +1,5 @@
186 /**
187- * Copyright 2021 Huawei Technologies Co., Ltd
188+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
189  *
190  * Licensed under the Apache License, Version 2.0 (the "License");
191  * you may not use this file except in compliance with the License.
192@@ -25,7 +25,7 @@
193
194 namespace mindspore {
195
196-class CkptSaver: public TrainCallBack {
197+class MS_API CkptSaver : public TrainCallBack {
198  public:
199   inline CkptSaver(int save_every_n, const std::string &filename_prefix);
200   virtual ~CkptSaver();
201diff --git a/include/api/callback/loss_monitor.h b/include/api/callback/loss_monitor.h
202index 9e0a8247..7efd0ca7 100644
203--- a/include/api/callback/loss_monitor.h
204+++ b/include/api/callback/loss_monitor.h
205@@ -1,5 +1,5 @@
206 /**
207- * Copyright 2021 Huawei Technologies Co., Ltd
208+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
209  *
210  * Licensed under the Apache License, Version 2.0 (the "License");
211  * you may not use this file except in compliance with the License.
212@@ -23,7 +23,7 @@
213
214 namespace mindspore {
215
216-class LossMonitor: public TrainCallBack {
217+class MS_API LossMonitor: public TrainCallBack {
218  public:
219   explicit LossMonitor(int print_every_n_steps = INT_MAX);
220   virtual ~LossMonitor();
221diff --git a/include/api/callback/lr_scheduler.h b/include/api/callback/lr_scheduler.h
222index 2eddc66b..11fd7124 100644
223--- a/include/api/callback/lr_scheduler.h
224+++ b/include/api/callback/lr_scheduler.h
225@@ -1,5 +1,5 @@
226 /**
227- * Copyright 2021 Huawei Technologies Co., Ltd
228+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
229  *
230  * Licensed under the Apache License, Version 2.0 (the "License");
231  * you may not use this file except in compliance with the License.
232@@ -30,18 +30,18 @@ constexpr int UPDATE_LR = 1;
233 using LR_Lambda = std::function<int(float *lr, int epoch, void *cb_data)>;
234
235 /// \brief Multiply the LR by a factor of gamma every epoch
236-int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication);
237+MS_API int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication);
238
239 /// \brief Multiply the LR by a factor of gamma every step_size
240-int StepLRLambda(float *lr, int epoch, void *step_size);
241-struct StepLRLambda {
242+MS_API int StepLRLambda(float *lr, int epoch, void *step_size);
243+struct MS_API StepLRLambda {
244   StepLRLambda(int step, float g) : step_size(step), gamma(g) {}
245
246   int step_size;  // period of LR decay
247   float gamma;    // LR decay factor
248 };
249
250-class LRScheduler: public TrainCallBack {
251+class MS_API LRScheduler : public TrainCallBack {
252  public:
253   explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step = 1);
254   virtual ~LRScheduler();
255diff --git a/include/api/callback/time_monitor.h b/include/api/callback/time_monitor.h
256index 7e857849..45e48248 100644
257--- a/include/api/callback/time_monitor.h
258+++ b/include/api/callback/time_monitor.h
259@@ -1,5 +1,5 @@
260 /**
261- * Copyright 2021 Huawei Technologies Co., Ltd
262+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
263  *
264  * Licensed under the Apache License, Version 2.0 (the "License");
265  * you may not use this file except in compliance with the License.
266@@ -24,7 +24,7 @@
267
268 namespace mindspore {
269
270-class TimeMonitor: public TrainCallBack {
271+class MS_API TimeMonitor : public TrainCallBack {
272  public:
273   virtual ~TimeMonitor() = default;
274   void EpochBegin(const TrainCallBackData &cb_data) override;
275diff --git a/include/api/callback/train_accuracy.h b/include/api/callback/train_accuracy.h
276index 5838dfd9..16774dd7 100644
277--- a/include/api/callback/train_accuracy.h
278+++ b/include/api/callback/train_accuracy.h
279@@ -1,5 +1,5 @@
280 /**
281- * Copyright 2021 Huawei Technologies Co., Ltd
282+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
283  *
284  * Licensed under the Apache License, Version 2.0 (the "License");
285  * you may not use this file except in compliance with the License.
286@@ -26,12 +26,10 @@
287
288 namespace mindspore {
289
290-class TrainAccuracy: public TrainCallBack {
291+class MS_API TrainAccuracy : public TrainCallBack {
292  public:
293-  explicit TrainAccuracy(int print_every_n = INT_MAX,
294-                         int accuracy_metrics = METRICS_CLASSIFICATION,
295-                         const std::vector<int> &input_indexes = {1},
296-                         const std::vector<int> &output_indexes = {0});
297+  explicit TrainAccuracy(int print_every_n = INT_MAX, int accuracy_metrics = METRICS_CLASSIFICATION,
298+                         const std::vector<int> &input_indexes = {1}, const std::vector<int> &output_indexes = {0});
299   virtual ~TrainAccuracy();
300   const std::vector<GraphPoint> &GetAccuracyPoints();
301 };
302diff --git a/include/api/cfg.h b/include/api/cfg.h
303index db915cac..8dc37bb4 100644
304--- a/include/api/cfg.h
305+++ b/include/api/cfg.h
306@@ -1,5 +1,5 @@
307 /**
308- * Copyright 2022 Huawei Technologies Co., Ltd
309+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
310  *
311  * Licensed under the Apache License, Version 2.0 (the "License");
312  * you may not use this file except in compliance with the License.
313@@ -26,7 +26,7 @@
314
315 namespace mindspore {
316 constexpr int iter_th = 1000;
317-class MixPrecisionCfg {
318+class MS_API MixPrecisionCfg {
319  public:
320   MixPrecisionCfg() {
321     this->dynamic_loss_scale_ = false;
322@@ -49,7 +49,7 @@ class MixPrecisionCfg {
323   bool is_raw_mix_precision_ = false; /**< Is mix precision model export from mindspore  */
324 };
325
326-class TrainCfg {
327+class MS_API TrainCfg {
328  public:
329   TrainCfg() = default;
330   TrainCfg(const TrainCfg &rhs) {
331@@ -59,11 +59,24 @@ class TrainCfg {
332   }
333   ~TrainCfg() = default;
334
335+  /// \brief obtain part of the name that identify a loss kernel.
336+  ///
337+  /// \return loss_name.
338+  inline std::vector<std::string> GetLossName() const;
339+  /// \brief Set part of the name that identify a loss kernel.
340+  ///
341+  /// \param[in] loss_name define part of the name that identify a loss kernel.
342+  inline void SetLossName(const std::vector<std::string> &loss_name);
343+
344   OptimizationLevel optimization_level_ = kO0;
345-  std::vector<std::string> loss_name_ = {
346-    "loss_fct", "_loss_fn", "SigmoidCrossEntropy"}; /**< Set part of the name that identify a loss kernel */
347-  MixPrecisionCfg mix_precision_cfg_;               /**< Mix precision configuration */
348+  MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */
349   bool accumulate_gradients_ = false;
350+
351+ private:
352+  std::vector<std::vector<char>> loss_name_ = VectorStringToChar({"loss_fct", "_loss_fn", "SigmoidCrossEntropy"});
353 };
354+
355+std::vector<std::string> TrainCfg::GetLossName() const { return VectorCharToString(loss_name_); }
356+void TrainCfg::SetLossName(const std::vector<std::string> &loss_name) { loss_name_ = VectorStringToChar(loss_name); }
357 }  // namespace mindspore
358 #endif  // MINDSPORE_INCLUDE_API_CFG_H
359diff --git a/include/api/metrics/accuracy.h b/include/api/metrics/accuracy.h
360index 1d1732f3..4aefc3b5 100644
361--- a/include/api/metrics/accuracy.h
362+++ b/include/api/metrics/accuracy.h
363@@ -1,5 +1,5 @@
364 /**
365- * Copyright 2021 Huawei Technologies Co., Ltd
366+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
367  *
368  * Licensed under the Apache License, Version 2.0 (the "License");
369  * you may not use this file except in compliance with the License.
370@@ -23,7 +23,7 @@ namespace mindspore {
371 constexpr int METRICS_CLASSIFICATION = 0;
372 constexpr int METRICS_MULTILABEL = 1;
373
374-class AccuracyMetrics : public Metrics {
375+class MS_API AccuracyMetrics : public Metrics {
376  public:
377   explicit AccuracyMetrics(int accuracy_metrics = METRICS_CLASSIFICATION, const std::vector<int> &input_indexes = {1},
378                            const std::vector<int> &output_indexes = {0});
379diff --git a/include/api/metrics/metrics.h b/include/api/metrics/metrics.h
380index 7154332f..36eb4ed1 100644
381--- a/include/api/metrics/metrics.h
382+++ b/include/api/metrics/metrics.h
383@@ -1,5 +1,5 @@
384 /**
385- * Copyright 2021 Huawei Technologies Co., Ltd
386+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
387  *
388  * Licensed under the Apache License, Version 2.0 (the "License");
389  * you may not use this file except in compliance with the License.
390@@ -24,16 +24,17 @@ class MetricsImpl;
391 class ModelImpl;
392 class MSTensor;
393
394-class Metrics {
395+class MS_API Metrics {
396  public:
397   virtual ~Metrics() = default;
398   virtual void Clear() {}
399   virtual float Eval() { return 0.0; }
400   virtual void Update(std::vector<MSTensor *> inputs, std::vector<MSTensor *> outputs) {}
401+
402  protected:
403   friend class Model;
404   friend class ModelImpl;
405-  MetricsImpl* metrics_impl_;
406+  MetricsImpl *metrics_impl_;
407 };
408
409 }  // namespace mindspore
410diff --git a/include/api/model_parallel_runner.h b/include/api/model_parallel_runner.h
411index 159f4cea..360405b9 100644
412--- a/include/api/model_parallel_runner.h
413+++ b/include/api/model_parallel_runner.h
414@@ -1,5 +1,5 @@
415 /**
416- * Copyright 2022 Huawei Technologies Co., Ltd
417+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
418  *
419  * Licensed under the Apache License, Version 2.0 (the "License");
420  * you may not use this file except in compliance with the License.
421@@ -25,7 +25,7 @@
422 namespace mindspore {
423 /// \brief The RunnerConfig class  is used to store environment variables during execution
424 /// management.
425-class RunnerConfig {
426+class MS_API RunnerConfig {
427  public:
428   struct Data;
429   RunnerConfig();
430diff --git a/include/api/net.h b/include/api/net.h
431index c7a3a9b0..61990ae0 100644
432--- a/include/api/net.h
433+++ b/include/api/net.h
434@@ -1,5 +1,5 @@
435 /**
436- * Copyright 2022 Huawei Technologies Co., Ltd
437+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
438  *
439  * Licensed under the Apache License, Version 2.0 (the "License");
440  * you may not use this file except in compliance with the License.
441@@ -36,14 +36,14 @@ class NodeSet;
442 class Graph;
443 class NetData;
444
445-class NetBase {
446+class MS_API NetBase {
447  public:
448   NetBase() = default;
449   virtual std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) = 0;
450   virtual uint32_t type() = 0;
451 };
452
453-class Node : public NetBase {
454+class MS_API Node : public NetBase {
455  public:
456   Node();
457   virtual ~Node();
458@@ -65,7 +65,7 @@ class Node : public NetBase {
459   std::shared_ptr<NodeImpl> impl_ = nullptr;
460 };
461
462-class Net : public NetBase, public std::enable_shared_from_this<Net> {
463+class MS_API Net : public NetBase, public std::enable_shared_from_this<Net> {
464  public:
465   Net();
466   virtual ~Net();
467@@ -116,12 +116,12 @@ class Net : public NetBase, public std::enable_shared_from_this<Net> {
468   std::shared_ptr<NetImpl> impl_;
469 };
470
471-class SoftMaxCrossEntropyCfg {
472+class MS_API SoftMaxCrossEntropyCfg {
473  public:
474   std::string reduction = "mean"; /**<  Specifies reduction mode. The optional values are "none", "mean", "sum" */
475 };
476
477-class AdamConfig {
478+class MS_API AdamConfig {
479  public:
480   float learning_rate_ = 1e-3;
481   float beta1_ = 0.9;
482@@ -131,11 +131,12 @@ class AdamConfig {
483 };
484
485 namespace NN {
486-Net *NetWithLoss(Net *net, Node *loss);
487-Graph *GraphWithLoss(Graph *g, Node *loss);
488-Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg);
489-Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);
490-std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32, int fmt = NHWC);
491+MS_API Net *NetWithLoss(Net *net, Node *loss);
492+MS_API Graph *GraphWithLoss(Graph *g, Node *loss);
493+MS_API Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg);
494+MS_API Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);
495+MS_API std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32,
496+                                   int fmt = NHWC);
497 };  // namespace NN
498 }  // namespace mindspore
499 #endif  // MINDSPORE_INCLUDE_API_NET_H
500diff --git a/include/api/serialization.h b/include/api/serialization.h
501index 2dc9d028..1a0c1f57 100644
502--- a/include/api/serialization.h
503+++ b/include/api/serialization.h
504@@ -79,10 +79,16 @@ class MS_API Serialization {
505   ///
506   /// \param[in] model The model data.
507   /// \param[in] model_type The model file type.
508-  /// \param[out] model_data The model parameter data.
509+  /// \param[out] model_data The model buffer.
510+  /// \param[in] quantization_type The quantization type.
511+  /// \param[in] export_inference_only Whether to export an inference-only model.
512+  /// \param[in] output_tensor_name Set the names of the output tensors of the exported inference model; defaults
513+  /// to empty, which exports the complete inference model.
514   ///
515   /// \return Status.
516-  static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
517+  inline static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
518+                                   QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
519+                                   const std::vector<std::string> &output_tensor_name = {});
520
521   /// \brief Export training model from file.
522   ///
523@@ -110,6 +116,9 @@ class MS_API Serialization {
524   static Status ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
525                             QuantizationType quantization_type, bool export_inference_only,
526                             const std::vector<std::vector<char>> &output_tensor_name);
527+  static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
528+                            QuantizationType quantization_type, bool export_inference_only,
529+                            const std::vector<std::vector<char>> &output_tensor_name);
530 };
531
532 Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
533@@ -134,5 +143,12 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
534                      VectorStringToChar(output_tensor_name));
535 }
536
537+Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
538+                                  QuantizationType quantization_type, bool export_inference_only,
539+                                  const std::vector<std::string> &output_tensor_name) {
540+  return ExportModel(model, model_type, model_data, quantization_type, export_inference_only,
541+                     VectorStringToChar(output_tensor_name));
542+}
543+
544 }  // namespace mindspore
545 #endif  // MINDSPORE_INCLUDE_API_SERIALIZATION_H
546diff --git a/include/api/types.h b/include/api/types.h
547index 6cf04523..377b5db0 100644
548--- a/include/api/types.h
549+++ b/include/api/types.h
550@@ -1,5 +1,5 @@
551 /**
552- * Copyright 2020 Huawei Technologies Co., Ltd
553+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
554  *
555  * Licensed under the Apache License, Version 2.0 (the "License");
556  * you may not use this file except in compliance with the License.
557@@ -350,7 +350,7 @@ std::string MSTensor::Name() const { return CharToString(CharName()); }
558
559 void MSTensor::SetTensorName(const std::string &name) { return SetTensorName(StringToChar(name)); }
560
561-using Key = struct Key {
562+using Key = struct MS_API Key {
563   const size_t max_key_len = 32;
564   size_t len = 0;
565   unsigned char key[32] = {0};
566@@ -371,7 +371,7 @@ struct MSCallBackParam {
567 using MSKernelCallBack = std::function<bool(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
568                                             const MSCallBackParam &opInfo)>;
569
570-std::vector<char> CharVersion();
571+MS_API std::vector<char> CharVersion();
572 inline std::string Version() { return CharToString(CharVersion()); }
573
574 }  // namespace mindspore
575diff --git a/mindspore/ccsrc/cxx_api/serialization.cc b/mindspore/ccsrc/cxx_api/serialization.cc
576index 1ea95935..bf33e85d 100644
577--- a/mindspore/ccsrc/cxx_api/serialization.cc
578+++ b/mindspore/ccsrc/cxx_api/serialization.cc
579@@ -334,7 +334,8 @@ Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model
580   return kMEFailed;
581 }
582
583-Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
584+Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool,
585+                                  const std::vector<std::vector<char>> & /* output_tensor_name */) {
586   MS_LOG(ERROR) << "Unsupported feature.";
587   return kMEFailed;
588 }
589diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c
590index f6fa5994..60a27df1 100644
591--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c
592+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c
593@@ -29,10 +29,6 @@ int ElementOptDiv(const float *in0, const float *in1, float *out, int size, cons
594       out[index] = in0[0] / in1[index];
595     }
596   } else {
597-    if (in1[0] == 0) {
598-      return NNACL_ERRCODE_DIVISOR_ZERO;
599-    }
600-
601     SIMD_RUN_NO_SCALAR(ElementOptDivNum1, index, in0, in1, out, size);
602     for (; index < size; index++) {
603       out[index] = in0[index] / in1[0];
604diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c
605index 2db54161..cf2f867c 100644
606--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c
607+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c
608@@ -18,28 +18,44 @@
609 #include "nnacl/fp32_grad/binary_cross_entropy.h"
610
611 static void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const float *input_x,
612-                                         const float *input_y, const float *weight, float *loss, float *tmp_loss) {
613+                                         const float *input_y, const float *weight, float *loss, float *tmp_loss,
614+                                         bool weight_defined) {
615   const float epsilon = 1e-12;
616-  if (reduction == 0) {
617-    for (int i = 0; i < input_size; i++) {
618-      float value =
619-        -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
620-      loss[i] = value;
621+
622+  if (reduction == Reduction_None) {
623+    if (weight_defined) {
624+      for (int i = 0; i < input_size; i++) {
625+        float value =
626+          -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
627+        loss[i] = value;
628+      }
629+    } else {
630+      for (int i = 0; i < input_size; i++) {
631+        float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
632+        loss[i] = value;
633+      }
634     }
635   } else {
636-    for (int i = 0; i < input_size; i++) {
637-      float value =
638-        -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
639-      tmp_loss[i] = value;
640+    if (weight_defined) {
641+      for (int i = 0; i < input_size; i++) {
642+        float value =
643+          -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
644+        tmp_loss[i] = value;
645+      }
646+    } else {
647+      for (int i = 0; i < input_size; i++) {
648+        float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
649+        tmp_loss[i] = value;
650+      }
651     }
652   }
653 }
654
655 void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y,
656-                        const float *weight, float *loss, float *tmp_loss) {
657+                        const float *weight, float *loss, float *tmp_loss, bool weight_defined) {
658   loss[0] = 0.0f;
659-  BinaryCrossEntropyLossKernel(input_size, reduction, input_x, input_y, weight, loss, tmp_loss);
660-  if (reduction != 0) {
661+  BinaryCrossEntropyLossKernel(input_size, reduction, input_x, input_y, weight, loss, tmp_loss, weight_defined);
662+  if (reduction != Reduction_None) {
663     if (input_size % 2 == 1) {
664       tmp_loss[0] += tmp_loss[input_size - 1];
665     }
666@@ -47,13 +63,12 @@ void BinaryCrossEntropy(const int input_size, const int reduction, const float *
667       for (int i = 0; i < stride; i++) {
668         tmp_loss[i] += tmp_loss[i + stride];
669       }
670-
671       if (stride > 2 && stride % 2 == 1) {
672         tmp_loss[0] += tmp_loss[stride - 1];
673       }
674     }
675     loss[0] += tmp_loss[0];
676-    if (reduction == 1) {
677+    if (reduction == Reduction_Mean) {
678       loss[0] /= input_size;
679     }
680   }
681diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h
682index 6ba6422d..abf6e63b 100644
683--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h
684+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h
685@@ -28,7 +28,7 @@ extern "C" {
686 #endif
687
688 void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y,
689-                        const float *weight, float *loss, float *tmp_loss);
690+                        const float *weight, float *loss, float *tmp_loss, bool weight_defined);
691
692 #ifdef __cplusplus
693 }
694diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c
695index 95e28c8c..12d20356 100644
696--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c
697+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c
698@@ -19,23 +19,37 @@
699 #define MAX(a, b) ((a) > (b) ? (a) : (b))
700
701 int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y,
702-                           const float *weight, const float *dloss, float *dx) {
703+                           const float *weight, const float *dloss, float *dx, bool weight_defined) {
704   const float epsilon = 1e-12f;
705-  if (reduction == 0) {
706-    for (int i = 0; i < input_size; i++) {
707-      float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
708-      float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
709-      dx[i] = value * dloss[i];
710+  if (reduction == Reduction_None) {
711+    if (weight_defined) {
712+      for (int i = 0; i < input_size; i++) {
713+        float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
714+        float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
715+        dx[i] = value * dloss[i];
716+      }
717+    } else {
718+      for (int i = 0; i < input_size; i++) {
719+        float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
720+        float value = (input_x[i] - input_y[i]) / denominator;
721+        dx[i] = value * dloss[i];
722+      }
723     }
724   } else {
725     float dloss1 = dloss[0];
726-    if (reduction == 1) {
727+    if (reduction == Reduction_Mean) {
728       dloss1 = dloss[0] / input_size;
729     }
730     for (int i = 0; i < input_size; i++) {
731-      float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
732-      float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
733-      dx[i] = value * dloss1;
734+      if (weight_defined) {
735+        float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
736+        float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
737+        dx[i] = value * dloss1;
738+      } else {
739+        float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
740+        float value = (input_x[i] - input_y[i]) / denominator;
741+        dx[i] = value * dloss1;
742+      }
743     }
744   }
745   return 0;
746diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h
747index f3506f4f..3033fa98 100644
748--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h
749+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h
750@@ -28,7 +28,7 @@ extern "C" {
751 #endif
752
753 int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y,
754-                           const float *weight, const float *dloss, float *dx);
755+                           const float *weight, const float *dloss, float *dx, bool weight_defined);
756
757 #ifdef __cplusplus
758 }
759diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c
760index e280ad2e..22e207ac 100644
761--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c
762+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c
763@@ -27,8 +27,8 @@ int BinaryCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_siz
764   TensorC *out = outputs[0];
765   SetDataTypeFormat(out, x);
766   BinaryCrossEntropyParameter *param = (BinaryCrossEntropyParameter *)parameter;
767-  int reduction = param->reduction;
768-  if (reduction == 1 || reduction == 2) {
769+  ReductionType reduction = (ReductionType)(param->reduction);
770+  if (reduction == Reduction_Mean || reduction == Reduction_Sum) {
771     out->shape_size_ = 1;
772     out->shape_[0] = 1;
773   } else {
774diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c
775index 3073385f..875c3bc0 100644
776--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c
777+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c
778@@ -237,6 +237,7 @@ REG_INFER(LRN, PrimType_LRN, CommonInferShapeWithNHWC)
779 REG_INFER(L2Normalize, PrimType_L2NormalizeFusion, CommonInferShape)
780 REG_INFER(Neg, PrimType_Neg, CommonInferShape)
781 REG_INFER(NegGrad, PrimType_NegGrad, CommonGradInferShape)
782+REG_INFER(OnesLike, PrimType_OnesLike, CommonInferShape)
783 REG_INFER(PowerGrad, PrimType_PowerGrad, CommonGradInferShape)
784 REG_INFER(PReLU, PrimType_PReLUFusion, CommonInferShape)
785 REG_INFER(Reciprocal, PrimType_Reciprocal, CommonInferShape)
786diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h
787index a54904c8..a635e745 100644
788--- a/mindspore/lite/include/model.h
789+++ b/mindspore/lite/include/model.h
790@@ -1,5 +1,5 @@
791 /**
792- * Copyright 2020 Huawei Technologies Co., Ltd
793+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
794  *
795  * Licensed under the Apache License, Version 2.0 (the "License");
796  * you may not use this file except in compliance with the License.
797@@ -19,6 +19,7 @@
798 #include <memory>
799 #include <string>
800 #include <vector>
801+#include "include/api/types.h"
802
803 namespace mindspore {
804 namespace schema {
805@@ -30,7 +31,7 @@ typedef enum { ModelType_MSLite, ModelType_MindIR } LiteModelType;
806
807 // LiteGraph can be considered as a light weight and subset of FuncGraph, it can not support the advanced expression of
808 // FuncGraph, e.g., non-tail recursive.
809-struct LiteGraph {
810+struct MS_API LiteGraph {
811   struct Node {
812     std::string name_;
813     std::string op_type_;
814@@ -66,7 +67,7 @@ struct LiteGraph {
815   std::string ToString() const;
816 };
817
818-struct Model {
819+struct MS_API Model {
820   LiteGraph graph_;
821   char *buf = nullptr;
822   size_t buf_size_ = 0;
823diff --git a/mindspore/lite/include/registry/opencl_runtime_wrapper.h b/mindspore/lite/include/registry/opencl_runtime_wrapper.h
824index 5b95a8a3..b55554e4 100644
825--- a/mindspore/lite/include/registry/opencl_runtime_wrapper.h
826+++ b/mindspore/lite/include/registry/opencl_runtime_wrapper.h
827@@ -1,5 +1,5 @@
828 /**
829- * Copyright 2021 Huawei Technologies Co., Ltd
830+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
831  *
832  * Licensed under the Apache License, Version 2.0 (the "License");
833  * you may not use this file except in compliance with the License.
834@@ -30,7 +30,7 @@
835 #include "include/api/dual_abi_helper.h"
836
837 namespace mindspore::registry::opencl {
838-class OpenCLRuntimeWrapper {
839+class MS_API OpenCLRuntimeWrapper {
840  public:
841   OpenCLRuntimeWrapper() = default;
842   ~OpenCLRuntimeWrapper() = default;
843diff --git a/mindspore/lite/java/src/main/native/train_config.cpp b/mindspore/lite/java/src/main/native/train_config.cpp
844index 4177e96b..d2452acf 100644
845--- a/mindspore/lite/java/src/main/native/train_config.cpp
846+++ b/mindspore/lite/java/src/main/native/train_config.cpp
847@@ -1,5 +1,5 @@
848 /**
849- * Copyright 2021 Huawei Technologies Co., Ltd
850+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
851  *
852  * Licensed under the Apache License, Version 2.0 (the "License");
853  * you may not use this file except in compliance with the License.
854@@ -50,7 +50,9 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_TrainCfg_createTrai
855     return (jlong) nullptr;
856   }
857   if (loss_name != nullptr) {
858-    traincfg_ptr->loss_name_.emplace_back(env->GetStringUTFChars(loss_name, JNI_FALSE));
859+    std::vector<std::string> traincfg_loss_name = traincfg_ptr->GetLossName();
860+    traincfg_loss_name.emplace_back(env->GetStringUTFChars(loss_name, JNI_FALSE));
861+    traincfg_ptr->SetLossName(traincfg_loss_name);
862   }
863   traincfg_ptr->optimization_level_ = ol;
864   traincfg_ptr->accumulate_gradients_ = accmulateGrads;
865diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
866index 16ae2e63..48e0fe7c 100644
867--- a/mindspore/lite/src/CMakeLists.txt
868+++ b/mindspore/lite/src/CMakeLists.txt
869@@ -50,6 +50,11 @@ if(APPLE OR PLATFORM_ARM32 OR PLATFORM_ARM64)
870         -fdata-sections -ffast-math -fno-rtti -fno-exceptions -Wno-shorten-64-to-32 \
871         -fno-aligned-allocation -DTARGET_OS_OSX")
872     endif()
873+    if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND NOT MSLITE_ENABLE_TESTCASES)
874+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=hidden")
875+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility-inlines-hidden -fvisibility=hidden")
876+        set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--gc-sections")
877+    endif()
878 elseif(NOT MSVC)
879     if("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
880         set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}  -fomit-frame-pointer -fstrict-aliasing -ffunction-sections \
881@@ -312,16 +317,6 @@ set(LITE_SRC
882     ${CMAKE_CURRENT_SOURCE_DIR}/runtime/weight_decoder.cc
883     )
884
885-if(MSLITE_GPU_BACKEND STREQUAL opencl)
886-    file(GLOB_RECURSE OPENCL_RUNTIME_SRC
887-            ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/gpu/opencl/*.cc
888-            )
889-    set(LITE_SRC
890-            ${LITE_SRC}
891-            ${OPENCL_RUNTIME_SRC}
892-            )
893-endif()
894-
895 if(MSLITE_GPU_BACKEND STREQUAL cuda)
896     file(GLOB CUDA_RUNTIME_SRC
897             ${CMAKE_CURRENT_SOURCE_DIR}/runtime/gpu/*.cc
898@@ -384,6 +379,9 @@ set(TRAIN_SRC
899         ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
900         ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
901         ${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
902+        ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/common/fusion_utils.cc
903+        ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_activation_fusion_pass.cc
904+        ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
905         ${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc
906         ${TOOLS_DIR}/converter/optimizer.cc
907         ${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc
908@@ -393,7 +391,7 @@ set(TRAIN_SRC
909         ${TOOLS_DIR}/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc
910         ${TOOLS_DIR}/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc
911         ${TOOLS_DIR}/converter/legacy_optimizer/graph/subgraph_node_pass.cc
912-        )
913+        )
914
915 if(MSLITE_ENABLE_MINDRT)
916     add_subdirectory(${CORE_DIR}/mindrt mindspore_mindrt)
917@@ -527,7 +525,7 @@ else()
918 endif()
919
920 if(MSLITE_GPU_BACKEND STREQUAL opencl)
921-    add_subdirectory(runtime/kernel/opencl)
922+    add_subdirectory(runtime/kernel/gpu/opencl)
923     target_link_libraries(mindspore-lite opencl_kernel_mid)
924     target_link_libraries(mindspore-lite_static opencl_kernel_mid)
925 endif()
926diff --git a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc b/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc
927deleted file mode 100644
928index 5da193bc..00000000
929--- a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc
930+++ /dev/null
931@@ -1,45 +0,0 @@
932-/**
933- * Copyright 2019-2021 Huawei Technologies Co., Ltd
934- *
935- * Licensed under the Apache License, Version 2.0 (the "License");
936- * you may not use this file except in compliance with the License.
937- * You may obtain a copy of the License at
938- *
939- * http://www.apache.org/licenses/LICENSE-2.0
940- *
941- * Unless required by applicable law or agreed to in writing, software
942- * distributed under the License is distributed on an "AS IS" BASIS,
943- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
944- * See the License for the specific language governing permissions and
945- * limitations under the License.
946- */
947-#include "src/common/ops/populate/populate_register.h"
948-#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
949-using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad;
950-
951-namespace mindspore {
952-namespace lite {
953-OpParameter *PopulateBinaryCrossEntropyGradParameter(const void *prim) {
954-  auto *primitive = static_cast<const schema::Primitive *>(prim);
955-  MS_ASSERT(primitive != nullptr);
956-  auto value = primitive->value_as_BinaryCrossEntropyGrad();
957-  if (value == nullptr) {
958-    MS_LOG(ERROR) << "param is nullptr";
959-    return nullptr;
960-  }
961-
962-  auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
963-  if (param == nullptr) {
964-    MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
965-    return nullptr;
966-  }
967-  memset(param, 0, sizeof(BinaryCrossEntropyGradParameter));
968-
969-  param->op_parameter_.type_ = primitive->value_type();
970-  param->reduction = value->reduction();
971-  return reinterpret_cast<OpParameter *>(param);
972-}
973-
974-REG_POPULATE(PrimitiveType_BinaryCrossEntropyGrad, PopulateBinaryCrossEntropyGradParameter, SCHEMA_CUR);
975-}  // namespace lite
976-}  // namespace mindspore
977diff --git a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc b/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc
978deleted file mode 100644
979index 10060d3f..00000000
980--- a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc
981+++ /dev/null
982@@ -1,45 +0,0 @@
983-/**
984- * Copyright 2019-2021 Huawei Technologies Co., Ltd
985- *
986- * Licensed under the Apache License, Version 2.0 (the "License");
987- * you may not use this file except in compliance with the License.
988- * You may obtain a copy of the License at
989- *
990- * http://www.apache.org/licenses/LICENSE-2.0
991- *
992- * Unless required by applicable law or agreed to in writing, software
993- * distributed under the License is distributed on an "AS IS" BASIS,
994- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
995- * See the License for the specific language governing permissions and
996- * limitations under the License.
997- */
998-#include "src/common/ops/populate/populate_register.h"
999-#include "nnacl/fp32_grad/binary_cross_entropy.h"
1000-using mindspore::schema::PrimitiveType_BinaryCrossEntropy;
1001-
1002-namespace mindspore {
1003-namespace lite {
1004-OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) {
1005-  auto primitive = static_cast<const schema::Primitive *>(prim);
1006-  MS_ASSERT(primitive != nullptr);
1007-  auto value = primitive->value_as_BinaryCrossEntropy();
1008-  if (value == nullptr) {
1009-    MS_LOG(ERROR) << "value is nullptr";
1010-    return nullptr;
1011-  }
1012-
1013-  auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
1014-  if (param == nullptr) {
1015-    MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
1016-    return nullptr;
1017-  }
1018-  memset(param, 0, sizeof(BinaryCrossEntropyParameter));
1019-
1020-  param->op_parameter_.type_ = primitive->value_type();
1021-  param->reduction = value->reduction();
1022-  return reinterpret_cast<OpParameter *>(param);
1023-}
1024-
1025-REG_POPULATE(PrimitiveType_BinaryCrossEntropy, PopulateBinaryCrossEntropyParameter, SCHEMA_CUR);
1026-}  // namespace lite
1027-}  // namespace mindspore
1028diff --git a/mindspore/lite/src/common/prim_util.h b/mindspore/lite/src/common/prim_util.h
1029index 2714b6df..38733f82 100644
1030--- a/mindspore/lite/src/common/prim_util.h
1031+++ b/mindspore/lite/src/common/prim_util.h
1032@@ -1,5 +1,5 @@
1033 /**
1034- * Copyright 2020 Huawei Technologies Co., Ltd
1035+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
1036  *
1037  * Licensed under the Apache License, Version 2.0 (the "License");
1038  * you may not use this file except in compliance with the License.
1039@@ -16,13 +16,14 @@
1040
1041 #ifndef MINDSPORE_LITE_SRC_COMMON_PRIM_UTIL_H_
1042 #define MINDSPORE_LITE_SRC_COMMON_PRIM_UTIL_H_
1043+#include "include/api/types.h"
1044
1045 namespace mindspore {
1046 namespace lite {
1047-int GetPrimitiveType(const void *primitive, int schema_version);
1048+MS_API int GetPrimitiveType(const void *primitive, int schema_version);
1049 const char *GetPrimitiveTypeName(const void *primitive, int schema_version);
1050 const char *PrimitiveCurVersionTypeName(int type);
1051-int GenPrimVersionKey(int primitive_type, int schema_version);
1052+MS_API int GenPrimVersionKey(int primitive_type, int schema_version);
1053 bool IsPartialNode(const void *primitive, int schema_version);
1054 bool IsCallNode(const void *primitive, int schema_version);
1055 bool IsSwitchNode(const void *primitive, int schema_version);
1056diff --git a/mindspore/lite/src/common/tensor_util.h b/mindspore/lite/src/common/tensor_util.h
1057index 6e8ac3af..caced545 100644
1058--- a/mindspore/lite/src/common/tensor_util.h
1059+++ b/mindspore/lite/src/common/tensor_util.h
1060@@ -41,7 +41,7 @@ int GenerateInTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<Ten
1061                       std::shared_ptr<Allocator> allocator = nullptr);
1062 int GenerateOutTensorC(const OpParameter *const parameter, const std::vector<lite::Tensor *> &outputs,
1063                        std::vector<TensorC *> *out_tensor_c, std::shared_ptr<Allocator> allocator = nullptr);
1064-int CheckTensorsInvalid(const std::vector<Tensor *> &tensors);
1065+MS_API int CheckTensorsInvalid(const std::vector<Tensor *> &tensors);
1066 int CheckGraphInputShapes(const std::vector<Tensor *> &inputs,
1067                           const std::unordered_map<Tensor *, std::vector<int>> &input_shape_map);
1068 std::vector<mindspore::MSTensor> LiteTensorsToMSTensors(const std::vector<lite::Tensor *> &lite_tensors);
1069diff --git a/mindspore/lite/src/extendrt/CMakeLists.txt b/mindspore/lite/src/extendrt/CMakeLists.txt
1070index 4f43b01f..70d734be 100644
1071--- a/mindspore/lite/src/extendrt/CMakeLists.txt
1072+++ b/mindspore/lite/src/extendrt/CMakeLists.txt
1073@@ -1,3 +1,7 @@
1074+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
1075+string(REPLACE "-fvisibility=hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
1076+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
1077+string(REPLACE "-fvisibility=hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
1078 set(MODEL_LOADER_FRAMEWORK_SRC
1079         ${MODEL_LOADER_FRAMEWORK_SRC}
1080         ${CMAKE_CURRENT_SOURCE_DIR}/mindir_loader/model_loader.cc
1081diff --git a/mindspore/lite/src/extendrt/cxx_api/serialization.cc b/mindspore/lite/src/extendrt/cxx_api/serialization.cc
1082index 344cfca7..c1e3d065 100644
1083--- a/mindspore/lite/src/extendrt/cxx_api/serialization.cc
1084+++ b/mindspore/lite/src/extendrt/cxx_api/serialization.cc
1085@@ -332,7 +332,8 @@ Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model
1086   return kMEFailed;
1087 }
1088
1089-Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
1090+Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool,
1091+                                  const std::vector<std::vector<char>> & /* output_tensor_name */) {
1092   MS_LOG(ERROR) << "Unsupported feature.";
1093   return kMEFailed;
1094 }
1095diff --git a/mindspore/lite/src/runtime/cxx_api/converters.h b/mindspore/lite/src/runtime/cxx_api/converters.h
1096index bd7daabb..45ed6a5b 100644
1097--- a/mindspore/lite/src/runtime/cxx_api/converters.h
1098+++ b/mindspore/lite/src/runtime/cxx_api/converters.h
1099@@ -1,5 +1,5 @@
1100 /**
1101- * Copyright 2021 Huawei Technologies Co., Ltd
1102+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
1103  *
1104  * Licensed under the Apache License, Version 2.0 (the "License");
1105  * you may not use this file except in compliance with the License.
1106@@ -27,7 +27,7 @@
1107 #include "src/runtime/c_api/context_c.h"
1108
1109 namespace mindspore {
1110-class ContextUtils {
1111+class MS_API ContextUtils {
1112  public:
1113   static lite::InnerContext *Convert(Context *context);
1114   static lite::InnerContext *Convert(const ContextC *context_c);
1115diff --git a/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc b/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc
1116index e7a0e272..f5f275e4 100644
1117--- a/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc
1118+++ b/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc
1119@@ -717,6 +717,9 @@ Status ModelImpl::UpdateWeights(const std::vector<MSTensor> &new_weights) {
1120     inner_weights[i] = lite_impl->lite_tensor();
1121   }
1122   auto ret = session_->UpdateWeights(inner_weights);
1123+  if (ret != kSuccess) {
1124+    MS_LOG(ERROR) << "UpdateWeights failed, and the origin weights may have been changed.";
1125+  }
1126   return static_cast<StatusCode>(ret);
1127 }
1128
1129diff --git a/mindspore/lite/src/runtime/cxx_api/model/model_impl.h b/mindspore/lite/src/runtime/cxx_api/model/model_impl.h
1130index 3d359f14..5c572883 100644
1131--- a/mindspore/lite/src/runtime/cxx_api/model/model_impl.h
1132+++ b/mindspore/lite/src/runtime/cxx_api/model/model_impl.h
1133@@ -1,5 +1,5 @@
1134 /**
1135- * Copyright 2021 Huawei Technologies Co., Ltd
1136+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
1137  *
1138  * Licensed under the Apache License, Version 2.0 (the "License");
1139  * you may not use this file except in compliance with the License.
1140@@ -47,10 +47,10 @@ namespace mindspore {
1141 typedef std::shared_ptr<lite::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data,
1142                                                                     std::shared_ptr<TrainCfg> cfg,
1143                                                                     lite::InnerContext *context);
1144-CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);
1145+MS_API CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);
1146
1147 using ExpressionLoader = std::function<Status(const char *, Graph *)>;
1148-ExpressionLoader CreateExpressionLoader(ExpressionLoader loader = nullptr);
1149+MS_API ExpressionLoader CreateExpressionLoader(ExpressionLoader loader = nullptr);
1150
1151 namespace session {
1152 class Metrics;
1153diff --git a/mindspore/lite/src/runtime/cxx_api/serialization.cc b/mindspore/lite/src/runtime/cxx_api/serialization.cc
1154index 3db32826..8405f4b2 100644
1155--- a/mindspore/lite/src/runtime/cxx_api/serialization.cc
1156+++ b/mindspore/lite/src/runtime/cxx_api/serialization.cc
1157@@ -157,9 +157,34 @@ Status Serialization::SetParameters(const std::map<std::string, Buffer> &paramet
1158   return kMEFailed;
1159 }
1160
1161-Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) {
1162-  MS_LOG(ERROR) << "Unsupported feature.";
1163-  return kMEFailed;
1164+Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
1165+                                  QuantizationType quantization_type, bool export_inference_only,
1166+                                  const std::vector<std::vector<char>> &output_tensor_name) {
1167+  if (model.impl_ == nullptr) {
1168+    MS_LOG(ERROR) << "Model implement is null.";
1169+    return kLiteUninitializedObj;
1170+  }
1171+  if (!model.impl_->IsTrainModel()) {
1172+    MS_LOG(ERROR) << "Model is not TrainModel.";
1173+    return kLiteError;
1174+  }
1175+  if (model_data == nullptr) {
1176+    MS_LOG(ERROR) << "model_data is nullptr.";
1177+    return kLiteParamInvalid;
1178+  }
1179+  if (model_type != kMindIR && model_type != kMindIR_Lite) {
1180+    MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
1181+    return kLiteParamInvalid;
1182+  }
1183+  if (model.impl_->session_ == nullptr) {
1184+    MS_LOG(ERROR) << "Model session is nullptr.";
1185+    return kLiteError;
1186+  }
1187+  auto ret = model.impl_->session_->Export(model_data, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
1188+                                           A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS,
1189+                                           VectorCharToString(output_tensor_name));
1190+
1191+  return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
1192 }
1193
1194 Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
1195diff --git a/mindspore/lite/src/runtime/cxx_api/train/converters.cc b/mindspore/lite/src/runtime/cxx_api/train/converters.cc
1196index 694259b3..b0801804 100644
1197--- a/mindspore/lite/src/runtime/cxx_api/train/converters.cc
1198+++ b/mindspore/lite/src/runtime/cxx_api/train/converters.cc
1199@@ -1,5 +1,5 @@
1200 /**
1201- * Copyright 2021 Huawei Technologies Co., Ltd
1202+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
1203  *
1204  * Licensed under the Apache License, Version 2.0 (the "License");
1205  * you may not use this file except in compliance with the License.
1206@@ -25,8 +25,8 @@ Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cf
1207     return kLiteNullptr;
1208   }
1209
1210-  l_train_cfg->loss_name_.clear();
1211-  l_train_cfg->loss_name_.assign(a_train_cfg->loss_name_.begin(), a_train_cfg->loss_name_.end());
1212+  std::vector<std::string> a_loss_name = a_train_cfg->GetLossName();
1213+  l_train_cfg->loss_name_.assign(a_loss_name.begin(), a_loss_name.end());
1214   l_train_cfg->mix_precision_cfg_.dynamic_loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_;
1215   l_train_cfg->mix_precision_cfg_.loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_;
1216   l_train_cfg->mix_precision_cfg_.keep_batchnorm_fp32_ = (a_train_cfg->optimization_level_ != kO3);
1217diff --git a/mindspore/lite/src/runtime/infer_manager.h b/mindspore/lite/src/runtime/infer_manager.h
1218index 31da532e..a851b7d2 100644
1219--- a/mindspore/lite/src/runtime/infer_manager.h
1220+++ b/mindspore/lite/src/runtime/infer_manager.h
1221@@ -31,11 +31,11 @@
1222 #include "include/api/allocator.h"
1223
1224 namespace mindspore::lite {
1225-int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs,
1226-                     OpParameter *parameter, std::shared_ptr<Allocator> allocator = nullptr);
1227-int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
1228-                     const void *primitive, std::set<std::string> &&providers, int schema_version,
1229-                     const kernel::Kernel *kernel = nullptr);
1230+MS_API int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs,
1231+                            OpParameter *parameter, std::shared_ptr<Allocator> allocator = nullptr);
1232+MS_API int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
1233+                            const void *primitive, std::set<std::string> &&providers, int schema_version,
1234+                            const kernel::Kernel *kernel = nullptr);
1235 class InferManager {
1236  public:
1237   static InferManager *GetInstance() {
1238diff --git a/mindspore/lite/src/runtime/inner_context.h b/mindspore/lite/src/runtime/inner_context.h
1239index adbeacbf..ff58995f 100644
1240--- a/mindspore/lite/src/runtime/inner_context.h
1241+++ b/mindspore/lite/src/runtime/inner_context.h
1242@@ -35,7 +35,7 @@ namespace mindspore::lite {
1243 #ifdef ENABLE_MINDRT
1244 constexpr int kDefaultParallelNum = 2;
1245 #endif
1246-struct InnerContext : public Context {
1247+struct MS_API InnerContext : public Context {
1248  public:
1249   InnerContext() { InitDeviceFp16(); }
1250
1251diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc
1252index 843fc0c9..5b94867b 100644
1253--- a/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc
1254+++ b/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc
1255@@ -54,7 +54,6 @@ int ArgMinMaxCPUKernel::ReSize() {
1256   ComputeStrides(in_shape.data(), arg_param_->in_strides_, in_shape.size());
1257   CHECK_NULL_RETURN(out_tensors_.at(0));
1258   auto out_shape = out_tensors_.at(0)->shape();
1259-  CHECK_NULL_RETURN(out_shape.data());
1260   ComputeStrides(out_shape.data(), arg_param_->out_strides_, out_shape.size());
1261   return RET_OK;
1262 }
1263diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc
1264index 68f5cce3..14c97bf8 100644
1265--- a/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc
1266+++ b/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc
1267@@ -285,6 +285,7 @@ void ArithmeticBaseCPUKernel::ComputeOfflineInfo() {
1268       c_matric_.batch_post_sum[i] = c_matric_.shape[i] * c_matric_.batch_post_sum[i + 1];
1269     }
1270   }
1271+  scalar_opt_ = false;
1272   if (a_matric_.inner_size == 1) {
1273     param_->in_elements_num0_ = 1;
1274     scalar_opt_ = true;
1275diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
1276index aa50a916..b5370ddd 100644
1277--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
1278+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
1279@@ -50,8 +50,6 @@ int GroupConvolutionBaseCPUKernel::ReSize() {
1280   if (group_num_ == 0) {
1281     return RET_ERROR;
1282   }
1283-  conv_param_->input_channel_ /= group_num_;
1284-  conv_param_->output_channel_ /= group_num_;
1285   return RET_OK;
1286 }
1287
1288@@ -96,7 +94,8 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
1289       // in
1290       auto in_tensor = in_tensors_.front();
1291       CHECK_NULL_RETURN(in_tensor);
1292-      in_shape = {in_tensor->Batch(), in_tensor->Height(), in_tensor->Width(), conv_param_->input_channel_};
1293+      in_shape = {in_tensor->Batch(), in_tensor->Height(), in_tensor->Width(),
1294+                  conv_param_->input_channel_ / group_num_};
1295       auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front();
1296       CHECK_NULL_RETURN(sub_kernel_in_tensor);
1297       sub_kernel_in_tensor->set_shape(in_shape);
1298@@ -108,7 +107,8 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
1299       // out
1300       auto out_tensor = out_tensors_.front();
1301       CHECK_NULL_RETURN(out_tensor);
1302-      out_shape = {out_tensor->Batch(), out_tensor->Height(), out_tensor->Width(), conv_param_->output_channel_};
1303+      out_shape = {out_tensor->Batch(), out_tensor->Height(), out_tensor->Width(),
1304+                   conv_param_->output_channel_ / group_num_};
1305       auto sub_kernel_out_tensors = group_convs_.at(i)->out_tensors();
1306       for (auto tensor : sub_kernel_out_tensors) {
1307         CHECK_NULL_RETURN(tensor);
1308@@ -148,8 +148,8 @@ int GroupConvolutionBaseCPUKernel::InitGroupParam() {
1309     MS_LOG(ERROR) << "get in_plane_ from in_tensor failed.";
1310     return RET_ERROR;
1311   }
1312-  sub_in_channel_ = conv_param_->input_channel_;
1313-  ori_in_channel_ = sub_in_channel_ * group_num_;
1314+  sub_in_channel_ = conv_param_->input_channel_ / group_num_;
1315+  ori_in_channel_ = conv_param_->input_channel_;
1316   in_thread_num_ = MSMIN(MSMAX(1, ctx_->thread_num_), in_plane_);
1317
1318   auto out_tensor = out_tensors_.front();
1319@@ -159,8 +159,8 @@ int GroupConvolutionBaseCPUKernel::InitGroupParam() {
1320     MS_LOG(ERROR) << "get out_plane_ from out_tensor failed.";
1321     return RET_ERROR;
1322   }
1323-  sub_out_channel_ = conv_param_->output_channel_;
1324-  ori_out_channel_ = sub_out_channel_ * group_num_;
1325+  sub_out_channel_ = conv_param_->output_channel_ / group_num_;
1326+  ori_out_channel_ = conv_param_->output_channel_;
1327   out_thread_num_ = MSMIN(MSMAX(1, ctx_->thread_num_), out_plane_);
1328   return RET_OK;
1329 }
1330diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
1331index f2a29bfd..fc78a887 100644
1332--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
1333+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
1334@@ -96,15 +96,10 @@ lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info, bool inferred) {
1335   tensor->set_data_type(tensor_info.data_type_);
1336   tensor->set_format(tensor_info.format_);
1337   tensor->set_category(tensor_info.tensor_type_);
1338-  if (tensor_info.is_in_) {
1339-    tensor->set_shape(tensor_info.shape_);
1340-  }
1341+  tensor->set_shape(tensor_info.shape_);
1342
1343   if (inferred) {
1344     // set shape of out tensor
1345-    if (!tensor_info.is_in_) {
1346-      tensor->set_shape(tensor_info.shape_);
1347-    }
1348     return TensorMalloc(tensor);
1349   }
1350   return tensor;
1351@@ -185,13 +180,16 @@ void GroupConvCreator::SetShapeOfTensors() {
1352   /* set shape */
1353   set_filter_shape({new_out_channel, conv_param_->kernel_h_, conv_param_->kernel_w_, new_in_channel});
1354   set_bias_shape({new_out_channel});
1355+  conv_param_->input_channel_ = new_in_channel;
1356+  conv_param_->output_channel_ = new_out_channel;
1357   if (infered_) {
1358-    conv_param_->input_channel_ = new_in_channel;
1359-    conv_param_->output_channel_ = new_out_channel;
1360     set_input_shape({origin_inputs_.front()->Batch(), origin_inputs_.front()->Height(), origin_inputs_.front()->Width(),
1361                      new_in_channel});
1362     set_output_shape({origin_inputs_.front()->Batch(), origin_outputs_.front()->Height(),
1363                       origin_outputs_.front()->Width(), new_out_channel});
1364+  } else {
1365+    set_input_shape({-1});
1366+    set_output_shape({-1});
1367   }
1368 }
1369
1370diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc
1371index 58f953b8..89af7aae 100644
1372--- a/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc
1373+++ b/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc
1374@@ -105,6 +105,7 @@ REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ExpandDims, LiteKernelCreator<R
1375 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
1376 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
1377 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
1378+REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
1379 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Squeeze, LiteKernelCreator<ReshapeBaseCPUKernel>)
1380 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Squeeze, LiteKernelCreator<ReshapeBaseCPUKernel>)
1381 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Squeeze, LiteKernelCreator<ReshapeBaseCPUKernel>)
1382diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc b/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc
1383index 5db44a0a..ec2080ef 100644
1384--- a/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc
1385+++ b/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc
1386@@ -56,6 +56,8 @@ void StridedSliceCPUKernel::InitFastRunParam() {
1387   for (size_t i = static_cast<size_t>(split_axis_ + 1); i < in_shape.size(); i++) {
1388     inner_ *= in_shape[i];
1389   }
1390+  parallel_on_split_axis_ = false;
1391+  parallel_on_outer_ = false;
1392   outer_ == 1 ? (parallel_on_split_axis_ = true) : (parallel_on_outer_ = true);
1393
1394   if (UpdateThreadNumPass(TC_TYPE(PrimitiveType_StridedSlice, parallel_on_outer_), 1, 1,
1395diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc b/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc
1396index 84b5a1a4..8ab4969f 100644
1397--- a/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc
1398+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc
1399@@ -23,6 +23,7 @@
1400 using mindspore::lite::KernelRegistrar;
1401 using mindspore::lite::RET_ERROR;
1402 using mindspore::lite::RET_OK;
1403+using mindspore::lite::RET_NO_CHANGE;
1404 using mindspore::schema::PrimitiveType_FusedBatchNorm;
1405
1406 namespace mindspore::kernel {
1407@@ -41,8 +42,11 @@ constexpr static int kOutCurrentVarIdx = 4;
1408 int FusedBatchnormFp16CPUKernel::Batchnorm2Scale(const void *scale_data, const void *bias_data, const void *mean_data,
1409                                                  const void *var_data, float eps, int kernel_num) {
1410   auto ret = InitScaleParam();
1411-  if (ret != RET_OK) {
1412-    MS_LOG(ERROR) << "Init scale parameter when converting fused_batchnorm to scale.";
1413+  if (ret == RET_NO_CHANGE) {
1414+    MS_LOG(INFO) << "Unsupported to convert fused batch norm to scale.";
1415+    return RET_NO_CHANGE;
1416+  } else if (ret != RET_OK) {
1417+    MS_LOG(ERROR) << "Init scale param failed.";
1418     return RET_ERROR;
1419   }
1420
1421diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc
1422index 7243e3b0..afed28ae 100644
1423--- a/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc
1424+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc
1425@@ -19,6 +19,7 @@
1426
1427 using mindspore::lite::KernelRegistrar;
1428 using mindspore::lite::RET_ERROR;
1429+using mindspore::lite::RET_NO_CHANGE;
1430 using mindspore::lite::RET_OK;
1431 using mindspore::schema::PrimitiveType_FusedBatchNorm;
1432
1433@@ -65,7 +66,7 @@ int FusedBatchnormCPUKernel::InitScaleParam() {
1434
1435   scale_param_->axis_ = kNHWC_C;
1436   auto in_shape = in_tensors_[0]->shape();
1437-  CHECK_LESS_RETURN(in_shape.size(), DIMENSION_5D);
1438+  MS_CHECK_TRUE_RET(in_shape.size() == DIMENSION_4D, RET_NO_CHANGE);
1439   scale_param_->outer_size_ = 1;
1440   for (auto i = 0; i < scale_param_->axis_; i++) {
1441     scale_param_->outer_size_ *= in_shape[i];
1442@@ -80,8 +81,11 @@ int FusedBatchnormCPUKernel::InitScaleParam() {
1443 int FusedBatchnormCPUKernel::Batchnorm2Scale(const void *scale_data, const void *bias_data, const void *mean_data,
1444                                              const void *var_data, float eps, int kernel_num) {
1445   auto ret = InitScaleParam();
1446-  if (ret != RET_OK) {
1447-    MS_LOG(ERROR) << "Init scale parameter when converting fused_batchnorm to scale.";
1448+  if (ret == RET_NO_CHANGE) {
1449+    MS_LOG(INFO) << "Unsupported to convert fused batch norm to scale.";
1450+    return RET_NO_CHANGE;
1451+  } else if (ret != RET_OK) {
1452+    MS_LOG(ERROR) << "Init scale param failed.";
1453     return RET_ERROR;
1454   }
1455
1456@@ -131,6 +135,10 @@ int FusedBatchnormCPUKernel::InitConstTensor() {
1457       return RET_OK;
1458     } else {
1459       FreeScaleAndOffset();
1460+      if (ret != RET_NO_CHANGE) {
1461+        MS_LOG(ERROR) << "convert batch norm to scale failed.";
1462+        return RET_ERROR;
1463+      }
1464     }
1465   }
1466
1467@@ -188,7 +196,7 @@ int FusedBatchnormCPUKernel::Run() {
1468
1469     trained_ = true;  // trained at least once
1470   } else {
1471-    if (out_tensors_.size() >= DIMENSION_5D) {
1472+    if (op_parameter_->is_train_session_ && out_tensors_.size() >= DIMENSION_5D) {
1473       (void)memcpy(out_tensors_.at(SECOND_INPUT)->data(), scale_, out_tensors_.at(SECOND_INPUT)->Size());
1474       (void)memcpy(out_tensors_.at(THIRD_INPUT)->data(), offset_, out_tensors_.at(THIRD_INPUT)->Size());
1475       (void)memcpy(out_tensors_.at(FOURTH_INPUT)->data(), mean_, out_tensors_.at(FOURTH_INPUT)->Size());
1476diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc
1477new file mode 100644
1478index 00000000..627f19f0
1479--- /dev/null
1480+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc
1481@@ -0,0 +1,52 @@
1482+/**
1483+ * Copyright 2022 Huawei Technologies Co., Ltd
1484+ *
1485+ * Licensed under the Apache License, Version 2.0 (the "License");
1486+ * you may not use this file except in compliance with the License.
1487+ * You may obtain a copy of the License at
1488+ *
1489+ * http://www.apache.org/licenses/LICENSE-2.0
1490+ *
1491+ * Unless required by applicable law or agreed to in writing, software
1492+ * distributed under the License is distributed on an "AS IS" BASIS,
1493+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1494+ * See the License for the specific language governing permissions and
1495+ * limitations under the License.
1496+ */
1497+
1498+#include "src/runtime/kernel/cpu/fp32/oneslike_fp32.h"
1499+#include "schema/model_generated.h"
1500+#include "nnacl/base/zeroslike_base.h"
1501+#include "src/runtime/kernel_registry.h"
1502+#include "include/errorcode.h"
1503+
1504+using mindspore::kernel::KERNEL_ARCH;
1505+using mindspore::lite::KernelRegistrar;
1506+using mindspore::lite::RET_ERROR;
1507+using mindspore::lite::RET_OK;
1508+using mindspore::schema::PrimitiveType_OnesLike;
1509+
1510+namespace mindspore::kernel {
1511+int OnesLikeCPUKernel::Prepare() {
1512+  CHECK_LESS_RETURN(in_tensors_.size(), 1);
1513+  CHECK_LESS_RETURN(out_tensors_.size(), 1);
1514+  return RET_OK;
1515+}
1516+
1517+int OnesLikeCPUKernel::Run() {
1518+  auto output = out_tensors_[0];
1519+  CHECK_NULL_RETURN(output);
1520+  if (output->data_type() == kNumberTypeInt32) {
1521+    ApproximateOnesLike(static_cast<int *>(output->data()), output->ElementsNum());
1522+  } else if (output->data_type() == kNumberTypeFloat32) {
1523+    ApproximateOnesLike(static_cast<float *>(output->data()), output->ElementsNum());
1524+  }
1525+  return RET_OK;
1526+}
1527+
1528+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>)
1529+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>)
1530+#ifdef ENABLE_FP16
1531+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>)
1532+#endif
1533+}  // namespace mindspore::kernel
1534\ No newline at end of file
1535diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h
1536new file mode 100644
1537index 00000000..fdca97cb
1538--- /dev/null
1539+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h
1540@@ -0,0 +1,46 @@
1541+/**
1542+ * Copyright 2022 Huawei Technologies Co., Ltd
1543+ *
1544+ * Licensed under the Apache License, Version 2.0 (the "License");
1545+ * you may not use this file except in compliance with the License.
1546+ * You may obtain a copy of the License at
1547+ *
1548+ * http://www.apache.org/licenses/LICENSE-2.0
1549+ *
1550+ * Unless required by applicable law or agreed to in writing, software
1551+ * distributed under the License is distributed on an "AS IS" BASIS,
1552+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1553+ * See the License for the specific language governing permissions and
1554+ * limitations under the License.
1555+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
1558+
1559+#include <vector>
1560+#include "src/runtime/lite_kernel.h"
1561+
1562+namespace mindspore::kernel {
1563+class OnesLikeCPUKernel : public LiteKernel {
1564+ public:
1565+  OnesLikeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
1566+                    const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
1567+      : LiteKernel(parameter, inputs, outputs, ctx) {}
1568+
1569+  ~OnesLikeCPUKernel() = default;
1570+
1571+  int Prepare() override;
1572+  int ReSize() override { return lite::RET_OK; }
1573+  int Run() override;
1574+
1575+ private:
1576+  template <typename T>
1577+  void ApproximateOnesLike(T *output, int data_size) {
1578+    for (int i = 0; i < data_size; ++i) {
1579+      output[i] = 1;
1580+    }
1581+    return;
1582+  }
1583+};
1584+}  // namespace mindspore::kernel
1585+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
1587\ No newline at end of file
1588diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc
1589new file mode 100644
1590index 00000000..a24976f8
1591--- /dev/null
1592+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc
1593@@ -0,0 +1,120 @@
1594+/**
1595+ * Copyright 2022 Huawei Technologies Co., Ltd
1596+ *
1597+ * Licensed under the Apache License, Version 2.0 (the "License");
1598+ * you may not use this file except in compliance with the License.
1599+ * You may obtain a copy of the License at
1600+ *
1601+ * http://www.apache.org/licenses/LICENSE-2.0
1602+ *
1603+ * Unless required by applicable law or agreed to in writing, software
1604+ * distributed under the License is distributed on an "AS IS" BASIS,
1605+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1606+ * See the License for the specific language governing permissions and
1607+ * limitations under the License.
1608+ */
1609+
1610+#include "src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h"
1611+#include "src/runtime/kernel_registry.h"
1612+#include "include/errorcode.h"
1613+#include "nnacl/fp32_grad/binary_cross_entropy.h"
1614+
1615+using mindspore::lite::KernelRegistrar;
1616+using mindspore::lite::RET_ERROR;
1617+using mindspore::lite::RET_OK;
1618+using mindspore::schema::PrimitiveType_BinaryCrossEntropy;
1619+
1620+namespace mindspore::kernel {
1621+BinaryCrossEntropyCPUKernel::~BinaryCrossEntropyCPUKernel() {
1622+  if (tmp_loss_ != nullptr) {
1623+    free(tmp_loss_);
1624+    tmp_loss_ = nullptr;
1625+  }
1626+}
1627+
1628+int BinaryCrossEntropyCPUKernel::ReSize() {
1629+  CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
1630+  CHECK_LESS_RETURN(out_tensors_.size(), 1);
1631+  CHECK_NULL_RETURN(in_tensors_.at(0));
1632+  CHECK_NULL_RETURN(in_tensors_.at(1));
1633+  if (in_tensors_.size() == C3NUM) {
1634+    weight_defined_ = true;
1635+    CHECK_NULL_RETURN(in_tensors_.at(C2NUM));
1636+  }
1637+  CHECK_NULL_RETURN(out_tensors_.at(0));
1638+  CHECK_NULL_RETURN(op_parameter_);
1639+
1640+  auto param_ = reinterpret_cast<BinaryCrossEntropyParameter *>(op_parameter_);
1641+  CHECK_NULL_RETURN(param_);
1642+  if (tmp_loss_ != nullptr) {
1643+    free(tmp_loss_);
1644+    tmp_loss_ = nullptr;
1645+  }
1646+  size_t input_size = in_tensors_.at(0)->ElementsNum();
1647+  tmp_loss_ = reinterpret_cast<float *>(malloc(input_size * sizeof(float)));
1648+  if (tmp_loss_ == nullptr) {
1649+    MS_LOG(ERROR) << "malloc tmp_loss_ for BinaryCrossEntropy op failed";
1650+    return RET_ERROR;
1651+  }
1652+
1653+  return RET_OK;
1654+}
1655+
1656+int BinaryCrossEntropyCPUKernel::DoExecute(int task_id) {
1657+  auto logits = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
1658+  CHECK_NULL_RETURN(logits);
1659+  auto labels = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
1660+  CHECK_NULL_RETURN(labels);
1661+  auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
1662+  CHECK_NULL_RETURN(out);
1663+
1664+  auto param_ = reinterpret_cast<BinaryCrossEntropyParameter *>(op_parameter_);
1665+  int reduction = param_->reduction;
1666+  size_t input_size = in_tensors_.at(0)->ElementsNum();
1667+  if (weight_defined_) {
1668+    weight_ = reinterpret_cast<float *>(in_tensors_.at(C2NUM)->MutableData());
1669+    CHECK_NULL_RETURN(weight_);
1670+  }
1671+  BinaryCrossEntropy(input_size, reduction, logits, labels, weight_, out, tmp_loss_, weight_defined_);
1672+  return RET_OK;
1673+}
1674+
1675+int BinaryCrossEntropyRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
1676+  CHECK_NULL_RETURN(cdata);
1677+  auto bin_crs_ent_kernel = reinterpret_cast<BinaryCrossEntropyCPUKernel *>(cdata);
1678+  auto error_code = bin_crs_ent_kernel->DoExecute(task_id);
1679+  if (error_code != RET_OK) {
1680+    MS_LOG(ERROR) << "BinaryCrossEntropy error task_id[" << task_id << "] error_code[" << error_code << "]";
1681+    return RET_ERROR;
1682+  }
1683+  return RET_OK;
1684+}
1685+
1686+int BinaryCrossEntropyCPUKernel::Run() {
1687+  int error_code = ParallelLaunch(this->ms_context_, BinaryCrossEntropyRun, this, 1);
1688+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "BinaryCrossEntropy function error error_code[" << error_code << "]";
1690+    return RET_ERROR;
1691+  }
1692+  return RET_OK;
1693+}
1694+
1695+int BinaryCrossEntropyCPUKernel::Prepare() { return ReSize(); }
1696+
1697+kernel::LiteKernel *CpuBinaryCrossEntropyFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
1698+                                                           const std::vector<lite::Tensor *> &outputs,
1699+                                                           OpParameter *opParameter, const lite::Context *ctx,
1700+                                                           const kernel::KernelKey &desc) {
1701+  MS_ASSERT(opParameter != nullptr);
1702+  MS_ASSERT(desc.type == schema::PrimitiveType_BinaryCrossEntropy);
1703+  auto *kernel = new (std::nothrow)
1704+    BinaryCrossEntropyCPUKernel(opParameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
1705+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "new BinaryCrossEntropyCPUKernel failed";
1707+    return nullptr;
1708+  }
1709+  return kernel;
1710+}
1711+
1712+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BinaryCrossEntropy, CpuBinaryCrossEntropyFp32KernelCreator)
1713+}  // namespace mindspore::kernel
1714diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h
1715new file mode 100644
1716index 00000000..39a7181e
1717--- /dev/null
1718+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h
1719@@ -0,0 +1,42 @@
1720+/**
1721+ * Copyright 2022 Huawei Technologies Co., Ltd
1722+ *
1723+ * Licensed under the Apache License, Version 2.0 (the "License");
1724+ * you may not use this file except in compliance with the License.
1725+ * You may obtain a copy of the License at
1726+ *
1727+ * http://www.apache.org/licenses/LICENSE-2.0
1728+ *
1729+ * Unless required by applicable law or agreed to in writing, software
1730+ * distributed under the License is distributed on an "AS IS" BASIS,
1731+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1732+ * See the License for the specific language governing permissions and
1733+ * limitations under the License.
1734+ */
1735+
1736+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_H_
1737+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_H_
1738+
1739+#include <vector>
1740+#include "src/runtime/lite_kernel.h"
1741+
1742+namespace mindspore::kernel {
1743+class BinaryCrossEntropyCPUKernel : public LiteKernel {
1744+ public:
1745+  explicit BinaryCrossEntropyCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
1746+                                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
1747+      : LiteKernel(parameter, inputs, outputs, ctx) {}
1748+  ~BinaryCrossEntropyCPUKernel() override;
1749+  int Prepare() override;
1750+  int ReSize() override;
1751+  int Run() override;
1752+  int DoExecute(int task_id);
1753+
1754+ protected:
1755+  float *tmp_loss_ = nullptr;
1756+  bool weight_defined_{false};
1757+  float *weight_ = nullptr;
1758+};
1759+}  // namespace mindspore::kernel
1760+
1761+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_H_
1762diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc
1763new file mode 100644
1764index 00000000..abac8fd1
1765--- /dev/null
1766+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc
1767@@ -0,0 +1,105 @@
1768+/**
1769+ * Copyright 2022 Huawei Technologies Co., Ltd
1770+ *
1771+ * Licensed under the Apache License, Version 2.0 (the "License");
1772+ * you may not use this file except in compliance with the License.
1773+ * You may obtain a copy of the License at
1774+ *
1775+ * http://www.apache.org/licenses/LICENSE-2.0
1776+ *
1777+ * Unless required by applicable law or agreed to in writing, software
1778+ * distributed under the License is distributed on an "AS IS" BASIS,
1779+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1780+ * See the License for the specific language governing permissions and
1781+ * limitations under the License.
1782+ */
1783+
1784+#include "src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h"
1785+#include "src/runtime/kernel_registry.h"
1786+#include "include/errorcode.h"
1787+#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
1788+
1789+using mindspore::lite::KernelRegistrar;
1790+using mindspore::lite::RET_ERROR;
1791+using mindspore::lite::RET_OK;
1792+using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad;
1793+
1794+namespace mindspore::kernel {
1795+int BinaryCrossEntropyGradCPUKernel::ReSize() {
1796+  CHECK_LESS_RETURN(in_tensors_.size(), C3NUM);
1797+  CHECK_LESS_RETURN(out_tensors_.size(), C1NUM);
1798+  CHECK_NULL_RETURN(in_tensors_.at(C0NUM));
1799+  CHECK_NULL_RETURN(in_tensors_.at(C1NUM));
1800+  CHECK_NULL_RETURN(in_tensors_.at(C2NUM));
1801+  if (in_tensors_.size() == C4NUM) {
1802+    weight_defined_ = true;
1803+    CHECK_NULL_RETURN(in_tensors_.at(C3NUM));
1804+  }
1805+  CHECK_NULL_RETURN(out_tensors_.at(0));
1806+  CHECK_NULL_RETURN(op_parameter_);
1807+  auto param_ = reinterpret_cast<BinaryCrossEntropyGradParameter *>(op_parameter_);
1808+  CHECK_NULL_RETURN(param_);
1809+
1810+  return RET_OK;
1811+}
1812+
1813+int BinaryCrossEntropyGradCPUKernel::DoExecute(int task_id) {
1814+  auto input_x = reinterpret_cast<float *>(in_tensors_.at(C0NUM)->MutableData());
1815+  CHECK_NULL_RETURN(input_x);
1816+  auto input_y = reinterpret_cast<float *>(in_tensors_.at(C1NUM)->MutableData());
1817+  CHECK_NULL_RETURN(input_y);
1818+  auto dloss = reinterpret_cast<float *>(in_tensors_.at(C2NUM)->MutableData());
1819+  CHECK_NULL_RETURN(dloss);
1820+  if (weight_defined_) {
1821+    weight_ = reinterpret_cast<float *>(in_tensors_.at(C3NUM)->MutableData());
1822+    CHECK_NULL_RETURN(weight_);
1823+  }
1824+  auto *out = reinterpret_cast<float *>(out_tensors_.at(C0NUM)->MutableData());
1825+  CHECK_NULL_RETURN(out);
1826+
1827+  auto param_ = reinterpret_cast<BinaryCrossEntropyGradParameter *>(op_parameter_);
1828+  int reduction = param_->reduction;
1829+  size_t input_size = in_tensors_.at(0)->ElementsNum();
1830+  BinaryCrossEntropyGrad(input_size, reduction, input_x, input_y, weight_, dloss, out, weight_defined_);
1831+  return RET_OK;
1832+}
1833+
1834+int BinaryCrossEntropyGradRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
1835+  CHECK_NULL_RETURN(cdata);
1836+  auto bin_crs_ent_kernel = reinterpret_cast<BinaryCrossEntropyGradCPUKernel *>(cdata);
1837+  auto error_code = bin_crs_ent_kernel->DoExecute(task_id);
1838+  if (error_code != RET_OK) {
1839+    MS_LOG(ERROR) << "BinaryCrossEntropyGrad error task_id[" << task_id << "] error_code[" << error_code << "]";
1840+    return RET_ERROR;
1841+  }
1842+  return RET_OK;
1843+}
1844+
1845+int BinaryCrossEntropyGradCPUKernel::Run() {
1846+  int error_code = ParallelLaunch(this->ms_context_, BinaryCrossEntropyGradRun, this, 1);
1847+  if (error_code != RET_OK) {
1848+    MS_LOG(ERROR) << "BinaryCrossEntropyGrad function error error_code[" << error_code << "]";
1849+    return RET_ERROR;
1850+  }
1851+  return RET_OK;
1852+}
1853+
1854+int BinaryCrossEntropyGradCPUKernel::Prepare() { return ReSize(); }
1855+
1856+kernel::LiteKernel *CpuBinaryCrossEntropyGradFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
1857+                                                               const std::vector<lite::Tensor *> &outputs,
1858+                                                               OpParameter *opParameter, const lite::Context *ctx,
1859+                                                               const kernel::KernelKey &desc) {
1860+  MS_ASSERT(opParameter != nullptr);
1861+  MS_ASSERT(desc.type == schema::PrimitiveType_BinaryCrossEntropyGrad);
1862+  auto *kernel = new (std::nothrow)
1863+    BinaryCrossEntropyGradCPUKernel(opParameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
1864+  if (kernel == nullptr) {
1865+    MS_LOG(ERROR) << "new BinaryCrossEntropyGrad failed";
1866+    return nullptr;
1867+  }
1868+  return kernel;
1869+}
1870+
1871+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BinaryCrossEntropyGrad, CpuBinaryCrossEntropyGradFp32KernelCreator)
1872+}  // namespace mindspore::kernel
1873diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h
1874new file mode 100644
1875index 00000000..d289eb65
1876--- /dev/null
1877+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h
1878@@ -0,0 +1,41 @@
1879+/**
1880+ * Copyright 2022 Huawei Technologies Co., Ltd
1881+ *
1882+ * Licensed under the Apache License, Version 2.0 (the "License");
1883+ * you may not use this file except in compliance with the License.
1884+ * You may obtain a copy of the License at
1885+ *
1886+ * http://www.apache.org/licenses/LICENSE-2.0
1887+ *
1888+ * Unless required by applicable law or agreed to in writing, software
1889+ * distributed under the License is distributed on an "AS IS" BASIS,
1890+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1891+ * See the License for the specific language governing permissions and
1892+ * limitations under the License.
1893+ */
1894+
1895+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_GRAD_H_
1896+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_GRAD_H_
1897+
1898+#include <vector>
1899+#include "src/runtime/lite_kernel.h"
1900+
1901+namespace mindspore::kernel {
1902+class BinaryCrossEntropyGradCPUKernel : public LiteKernel {
1903+ public:
1904+  explicit BinaryCrossEntropyGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
1905+                                           const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
1906+      : LiteKernel(parameter, inputs, outputs, ctx) {}
1907+  ~BinaryCrossEntropyGradCPUKernel() override {}
1908+  int Prepare() override;
1909+  int ReSize() override;
1910+  int Run() override;
1911+  int DoExecute(int task_id);
1912+
1913+ protected:
1914+  bool weight_defined_{false};
1915+  float *weight_ = nullptr;
1916+};
1917+}  // namespace mindspore::kernel
1918+
1919+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_GRAD_H_
1920diff --git a/mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt
1921new file mode 100644
1922index 00000000..3b42ed7a
1923--- /dev/null
1924+++ b/mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt
1925@@ -0,0 +1,11 @@
1926+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
1927+string(REPLACE "-fvisibility=hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
1928+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
1929+string(REPLACE "-fvisibility=hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
1930+if(MSLITE_GPU_BACKEND STREQUAL opencl)
1931+    file(GLOB_RECURSE OPENCL_KERNEL_SRC
1932+        ${CMAKE_CURRENT_SOURCE_DIR}/*.cc
1933+        ${CMAKE_CURRENT_SOURCE_DIR}/../../opencl/*.cc)
1934+    add_library(opencl_kernel_mid OBJECT ${OPENCL_KERNEL_SRC})
1935+    add_dependencies(opencl_kernel_mid fbs_src)
1936+endif()
1937diff --git a/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt
1938deleted file mode 100644
1939index cad0f8f7..00000000
1940--- a/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt
1941+++ /dev/null
1942@@ -1,8 +0,0 @@
1943-if(MSLITE_GPU_BACKEND STREQUAL opencl)
1944-    file(GLOB_RECURSE OPENCL_KERNEL_SRC
1945-        ${CMAKE_CURRENT_SOURCE_DIR}/*.cc
1946-        ${CMAKE_CURRENT_SOURCE_DIR}/kernel/*.cc
1947-        ${CMAKE_CURRENT_SOURCE_DIR}/kernel/int8/*.cc)
1948-    add_library(opencl_kernel_mid OBJECT ${OPENCL_KERNEL_SRC})
1949-    add_dependencies(opencl_kernel_mid fbs_src)
1950-endif()
1951diff --git a/mindspore/lite/src/runtime/kernel_exec_util.h b/mindspore/lite/src/runtime/kernel_exec_util.h
1952index e45c185b..9ce5267e 100644
1953--- a/mindspore/lite/src/runtime/kernel_exec_util.h
1954+++ b/mindspore/lite/src/runtime/kernel_exec_util.h
1955@@ -24,7 +24,7 @@
1956
1957 namespace mindspore::kernel {
1958
1959-class KernelExecUtil {
1960+class MS_API KernelExecUtil {
1961  public:
1962   static std::vector<KernelExec *> SubgraphInputNodes(const std::vector<KernelExec *> &kernels);
1963   static std::vector<KernelExec *> SubgraphOutputNodes(const std::vector<KernelExec *> &kernels);
1964diff --git a/mindspore/lite/src/runtime/kernel_registry.h b/mindspore/lite/src/runtime/kernel_registry.h
1965index 853d863a..f563d82d 100644
1966--- a/mindspore/lite/src/runtime/kernel_registry.h
1967+++ b/mindspore/lite/src/runtime/kernel_registry.h
1968@@ -1,5 +1,5 @@
1969 /**
1970- * Copyright 2020 Huawei Technologies Co., Ltd
1971+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
1972  *
1973  * Licensed under the Apache License, Version 2.0 (the "License");
1974  * you may not use this file except in compliance with the License.
1975@@ -31,7 +31,7 @@ using mindspore::schema::PrimitiveType_MAX;
1976 using mindspore::schema::PrimitiveType_MIN;
1977
1978 namespace mindspore::lite {
1979-class KernelRegistry {
1980+class MS_API KernelRegistry {
1981  public:
1982   KernelRegistry() = default;
1983   virtual ~KernelRegistry();
1984diff --git a/mindspore/lite/src/runtime/lite_kernel.h b/mindspore/lite/src/runtime/lite_kernel.h
1985index ce829320..a27f77d8 100644
1986--- a/mindspore/lite/src/runtime/lite_kernel.h
1987+++ b/mindspore/lite/src/runtime/lite_kernel.h
1988@@ -1,5 +1,5 @@
1989 /**
1990- * Copyright 2021-2022 Huawei Technologies Co., Ltd
1991+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
1992  *
1993  * Licensed under the Apache License, Version 2.0 (the "License");
1994  * you may not use this file except in compliance with the License.
1995@@ -37,7 +37,7 @@
1996 using mindspore::infer::Abstractkernel;
1997
1998 namespace mindspore::kernel {
1999-class LiteKernel : public Abstractkernel {
2000+class MS_API LiteKernel : public Abstractkernel {
2001  public:
2002   LiteKernel() = default;
2003
2004diff --git a/mindspore/lite/src/runtime/lite_model.h b/mindspore/lite/src/runtime/lite_model.h
2005index f6c7ebc4..af62cb91 100644
2006--- a/mindspore/lite/src/runtime/lite_model.h
2007+++ b/mindspore/lite/src/runtime/lite_model.h
2008@@ -1,5 +1,5 @@
2009 /**
2010- * Copyright 2020 Huawei Technologies Co., Ltd
2011+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
2012  *
2013  * Licensed under the Apache License, Version 2.0 (the "License");
2014  * you may not use this file except in compliance with the License.
2015@@ -36,7 +36,7 @@
2016
2017 namespace mindspore {
2018 namespace lite {
2019-class LiteModel : public Model {
2020+class MS_API LiteModel : public Model {
2021  public:
2022   explicit LiteModel(std::string model_path = "") : model_path_(std::move(model_path)) {}
2023
2024diff --git a/mindspore/lite/src/runtime/lite_session.cc b/mindspore/lite/src/runtime/lite_session.cc
2025index b8808e21..dffb39e7 100644
2026--- a/mindspore/lite/src/runtime/lite_session.cc
2027+++ b/mindspore/lite/src/runtime/lite_session.cc
2028@@ -504,7 +504,7 @@ void LiteSession::FreePackOpWeight(const std::vector<kernel::KernelExec *> &kern
2029     auto inputs = kernel->in_tensors();
2030     for (auto *tensor : inputs) {
2031       MS_ASSERT(tensor != nullptr);
2032-      if (!tensor->IsConst()) {
2033+      if (!tensor->IsConst() || tensor->ref_count() >= 1) {
2034         continue;
2035       }
2036       tensor->FreeData();
2037@@ -512,6 +512,29 @@ void LiteSession::FreePackOpWeight(const std::vector<kernel::KernelExec *> &kern
2038   }
2039 }
2040
2041+void LiteSession::MarkSharedWeight(const std::vector<kernel::KernelExec *> &kernels) {
2042+  // For reducing runtime RAM
2043+  // free pack-op weight because pack-op will not access origin weight in runtime
2044+  for (auto *kernel : kernels) {
2045+    MS_ASSERT(kernel != nullptr);
2046+    if (kernel->subgraph_type() == kernel::kNotSubGraph) {
2047+      if (IsPackedOp(static_cast<int>(kernel->type()))) {
2048+        continue;
2049+      }
2050+    } else {
2051+      auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
2052+      MarkSharedWeight(subgraph->nodes());
2053+    }
2054+    auto inputs = kernel->in_tensors();
2055+    for (auto *tensor : inputs) {
2056+      MS_ASSERT(tensor != nullptr);
2057+      if (tensor->IsConst()) {
2058+        tensor->IncRefCount();
2059+      }
2060+    }
2061+  }
2062+}
2063+
2064 int LiteSession::CompileGraph(Model *model) {
2065   auto ret = PreCheck(model);
2066   if (ret != RET_OK) {
2067@@ -572,7 +595,7 @@ int LiteSession::CompileGraph(Model *model) {
2068     is_running_.store(false);
2069     return ret;
2070   }
2071-
2072+  MarkSharedWeight(kernels_);
2073   FreePackOpWeight(kernels_);
2074
2075   ret = RuntimeAllocatorInit();
2076@@ -1727,6 +1750,7 @@ int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore
2077     delete model;
2078     return RET_ERROR;
2079   }
2080+  model->Free();
2081   set_model(model);
2082   return RET_OK;
2083 }
2084diff --git a/mindspore/lite/src/runtime/lite_session.h b/mindspore/lite/src/runtime/lite_session.h
2085index 255e90b5..2fdb1eb7 100644
2086--- a/mindspore/lite/src/runtime/lite_session.h
2087+++ b/mindspore/lite/src/runtime/lite_session.h
2088@@ -1,5 +1,5 @@
2089 /**
2090- * Copyright 2020 Huawei Technologies Co., Ltd
2091+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
2092  *
2093  * Licensed under the Apache License, Version 2.0 (the "License");
2094  * you may not use this file except in compliance with the License.
2095@@ -39,7 +39,7 @@
2096
2097 namespace mindspore {
2098 namespace lite {
2099-class LiteSession {
2100+class MS_API LiteSession {
2101  public:
2102   LiteSession();
2103   virtual ~LiteSession();
2104@@ -101,6 +101,11 @@ class LiteSession {
2105                      std::vector<std::string> out_put_tensor_name = {}) {
2106     return mindspore::lite::RET_ERROR;
2107   }
2108+  virtual int Export(Buffer *model_buffer, lite::ModelType model_type = lite::MT_TRAIN,
2109+                     lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS,
2110+                     std::vector<std::string> out_put_tensor_name = {}) {
2111+    return mindspore::lite::RET_ERROR;
2112+  }
2113   virtual int UpdateWeights(std::vector<lite::Tensor *> new_weights) { return mindspore::lite::RET_ERROR; }
2114   virtual std::vector<lite::Tensor *> GetFeatureMaps() const {
2115     std::vector<lite::Tensor *> features;
2116@@ -142,6 +147,7 @@ class LiteSession {
2117     const std::unordered_map<Tensor *, Tensor *> &isolate_input_map = std::unordered_map<Tensor *, Tensor *>());
2118   static void FreePackOpWeight(const std::vector<kernel::KernelExec *> &kernels);
2119   std::string ParseWeightPath();
2120+  static void MarkSharedWeight(const std::vector<kernel::KernelExec *> &kernels);
2121
2122  private:
2123   int PreCheck(Model *model);
2124diff --git a/mindspore/lite/src/runtime/weight_decoder.h b/mindspore/lite/src/runtime/weight_decoder.h
2125index 7c9e514c..006b4895 100644
2126--- a/mindspore/lite/src/runtime/weight_decoder.h
2127+++ b/mindspore/lite/src/runtime/weight_decoder.h
2128@@ -1,5 +1,5 @@
2129 /**
2130- * Copyright 2020-2022 Huawei Technologies Co., Ltd
2131+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
2132  *
2133  * Licensed under the Apache License, Version 2.0 (the "License");
2134  * you may not use this file except in compliance with the License.
2135@@ -39,7 +39,7 @@ static constexpr int kBitNum32 = 32;
2136
2137 namespace mindspore::lite {
2138
2139-class WeightDecoder {
2140+class MS_API WeightDecoder {
2141  public:
2142   static int DequantNode(const OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type,
2143                          const std::string &model_version, bool float_mode);
2144diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h
2145index f30fe090..178d4754 100644
2146--- a/mindspore/lite/src/tensor.h
2147+++ b/mindspore/lite/src/tensor.h
2148@@ -53,7 +53,7 @@ struct LiteQuantParam {
2149   double max{255.0};
2150 };
2151
2152-class Tensor {
2153+class MS_API Tensor {
2154  public:
2155   Tensor() = default;
2156
2157diff --git a/mindspore/lite/src/tensorlist.h b/mindspore/lite/src/tensorlist.h
2158index 2e6f8d79..39058057 100644
2159--- a/mindspore/lite/src/tensorlist.h
2160+++ b/mindspore/lite/src/tensorlist.h
2161@@ -55,7 +55,7 @@ namespace mindspore::lite {
2162  *
2163  *  See the code for other constructors.
2164  */
2165-class TensorList : public Tensor {
2166+class MS_API TensorList : public Tensor {
2167  public:
2168   TensorList() = default;
2169
2170diff --git a/mindspore/lite/src/train/graph_fusion.cc b/mindspore/lite/src/train/graph_fusion.cc
2171index 0e582e40..1af44e45 100644
2172--- a/mindspore/lite/src/train/graph_fusion.cc
2173+++ b/mindspore/lite/src/train/graph_fusion.cc
2174@@ -22,6 +22,8 @@
2175 #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
2176 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
2177 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
2178+#include "src/train/optimizer/fusion/matmul_activation_fusion_pass.h"
2179+#include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h"
2180
2181 namespace mindspore {
2182 namespace lite {
2183@@ -41,7 +43,9 @@ STATUS GraphFusion::Run(schema::MetaGraphT *graph) {
2184   }
2185   auto old_nodes = GetGraphNodes(*graph);
2186   Optimizer fusion_optimizer;
2187+  fusion_optimizer.AddPass(new (std::nothrow) ReshapeGatherReshapeFusionPass());
2188   fusion_optimizer.AddPass(new (std::nothrow) MatMulBiasAddFusionPass());
2189+  fusion_optimizer.AddPass(new (std::nothrow) MatMulActivationFusionPass());
2190   fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
2191   fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
2192   auto status = fusion_optimizer.Run(graph);
2193diff --git a/mindspore/lite/src/train/optimizer/common/fusion_utils.cc b/mindspore/lite/src/train/optimizer/common/fusion_utils.cc
2194new file mode 100644
2195index 00000000..3edb1a1b
2196--- /dev/null
2197+++ b/mindspore/lite/src/train/optimizer/common/fusion_utils.cc
2198@@ -0,0 +1,37 @@
2199+/**
2200+ * Copyright 2022 Huawei Technologies Co., Ltd
2201+ *
2202+ * Licensed under the Apache License, Version 2.0 (the "License");
2203+ * you may not use this file except in compliance with the License.
2204+ * You may obtain a copy of the License at
2205+ *
2206+ * http://www.apache.org/licenses/LICENSE-2.0
2207+ *
2208+ * Unless required by applicable law or agreed to in writing, software
2209+ * distributed under the License is distributed on an "AS IS" BASIS,
2210+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2211+ * See the License for the specific language governing permissions and
2212+ * limitations under the License.
2213+ */
2214+#include <vector>
2215+#include <unordered_map>
2216+#include <string>
2217+#include <memory>
2218+#include "src/common/log_util.h"
2219+#include "src/train/optimizer/common/fusion_utils.h"
2220+
2221+namespace mindspore {
2222+namespace opt {
2223+STATUS GetMatchNodeIndex(schema::MetaGraphT *graph,
2224+                         const std::unordered_map<std::string, std::shared_ptr<lite::Path>> &matched_path,
2225+                         const std::string &node_name, size_t *node_index) {
2226+  auto node_path_iter = matched_path.find(node_name);
2227+  MS_CHECK_TRUE_MSG(node_path_iter != matched_path.end(), RET_ERROR, "cannot find node_path");
2228+  const auto &node_path = node_path_iter->second;
2229+  MS_CHECK_TRUE_MSG(node_path != nullptr, RET_NULL_PTR, "node_path is empty");
2230+  *node_index = node_path->nodeIdx;
2231+  MS_CHECK_TRUE_MSG(*node_index < graph->nodes.size(), RET_ERROR, "node_index is out of range");
2232+  return RET_OK;
2233+}
2234+}  // namespace opt
2235+}  // namespace mindspore
2236diff --git a/mindspore/lite/src/train/optimizer/common/fusion_utils.h b/mindspore/lite/src/train/optimizer/common/fusion_utils.h
2237new file mode 100644
2238index 00000000..7f80cd49
2239--- /dev/null
2240+++ b/mindspore/lite/src/train/optimizer/common/fusion_utils.h
2241@@ -0,0 +1,50 @@
2242+/**
2243+ * Copyright 2022 Huawei Technologies Co., Ltd
2244+ *
2245+ * Licensed under the Apache License, Version 2.0 (the "License");
2246+ * you may not use this file except in compliance with the License.
2247+ * You may obtain a copy of the License at
2248+ *
2249+ * http://www.apache.org/licenses/LICENSE-2.0
2250+ *
2251+ * Unless required by applicable law or agreed to in writing, software
2252+ * distributed under the License is distributed on an "AS IS" BASIS,
2253+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2254+ * See the License for the specific language governing permissions and
2255+ * limitations under the License.
2256+ */
2257+
2258+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_COMMON_FUSION_UTILS_H_
2259+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_COMMON_FUSION_UTILS_H_
2260+
2261+#include <vector>
2262+#include <unordered_map>
2263+#include <string>
2264+#include <memory>
2265+#include "src/common/utils.h"
2266+#include "schema/inner/model_generated.h"
2267+#include "tools/converter/legacy_optimizer/fusion/fusion_pattern.h"
2268+
2269+using mindspore::lite::RET_ERROR;
2270+using mindspore::lite::RET_NULL_PTR;
2271+using mindspore::lite::RET_OK;
2272+using mindspore::lite::STATUS;
2273+namespace mindspore {
2274+namespace opt {
2275+inline constexpr int kInputIndexZero = 0;
2276+inline constexpr int kInputIndexOne = 1;
2277+inline constexpr int kInputIndexTwo = 2;
2278+inline constexpr int kOutputIndexZero = 0;
2279+inline constexpr int kOutputIndexOne = 1;
2280+inline constexpr size_t kInputSizeTwo = 2;
2281+inline constexpr size_t kInputSizeThree = 3;
2282+inline constexpr size_t kOutputSizeOne = 1;
2283+inline constexpr size_t kMatchPathLenTwo = 2;
2284+inline constexpr size_t kMatchPathLenThree = 3;
2285+
2286+STATUS GetMatchNodeIndex(schema::MetaGraphT *graph,
2287+                         const std::unordered_map<std::string, std::shared_ptr<lite::Path>> &matched_path,
2288+                         const std::string &node_name, size_t *node_index);
2289+}  // namespace opt
2290+}  // namespace mindspore
2291+#endif  // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_COMMON_FUSION_UTILS_H_
2292diff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc
2293new file mode 100644
2294index 00000000..b809f2c9
2295--- /dev/null
2296+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc
2297@@ -0,0 +1,93 @@
2298+/**
2299+ * Copyright 2022 Huawei Technologies Co., Ltd
2300+ *
2301+ * Licensed under the Apache License, Version 2.0 (the "License");
2302+ * you may not use this file except in compliance with the License.
2303+ * You may obtain a copy of the License at
2304+ *
2305+ * http://www.apache.org/licenses/LICENSE-2.0
2306+ *
2307+ * Unless required by applicable law or agreed to in writing, software
2308+ * distributed under the License is distributed on an "AS IS" BASIS,
2309+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2310+ * See the License for the specific language governing permissions and
2311+ * limitations under the License.
2312+ */
2313+
2314+#include "src/train/optimizer/fusion/matmul_activation_fusion_pass.h"
2315+#include <string>
2316+#include <unordered_map>
2317+#include <vector>
2318+#include <memory>
2319+#include "schema/inner/model_generated.h"
2320+#include "tools/common/meta_graph_utils.h"
2321+#include "src/train/optimizer/common/fusion_utils.h"
2322+namespace {
2323+constexpr std::string_view MatMulName = "MATMUL";
2324+constexpr std::string_view ActName = "ACTIVATION";
2325+}  // namespace
2326+namespace mindspore {
2327+namespace lite {
2328+STATUS MatMulActivationFusionPass::DefinePattern() {
2329+  auto matmul_op = std::make_shared<PatternOp>();
2330+  MS_CHECK_TRUE_RET(matmul_op != nullptr, RET_NULL_PTR);
2331+  matmul_op->id = MatMulName;
2332+  matmul_op->types = {schema::PrimitiveType_MatMulFusion};
2333+  auto act_op = std::make_shared<PatternOp>();
2334+  MS_CHECK_TRUE_RET(act_op != nullptr, RET_NULL_PTR);
2335+  act_op->id = ActName;
2336+  act_op->types = {schema::PrimitiveType_Activation};
2337+  act_op->left = matmul_op;
2338+  auto fusion_pattern = std::make_unique<FusionPattern>("MatMulActivationFusion");
2339+  MS_CHECK_TRUE_MSG(fusion_pattern != nullptr, RET_NULL_PTR, "new fusion_pattern failed");
2340+  fusion_pattern->AddPatternOp(matmul_op);
2341+  fusion_pattern->AddPatternOp(act_op);
2342+  fusion_pattern->Finish();
2343+  this->patterns.emplace_back(fusion_pattern.release());
2344+  return RET_OK;
2345+}
2346+
2347+STATUS MatMulActivationFusionPass::DoFusion(
2348+  MetaGraphT *graph, const std::string &pattern_name,
2349+  const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) {
2350+  MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR);
2351+  if (matched_path.size() != opt::kMatchPathLenTwo) {
2352+    MS_LOG(ERROR) << "MatMul-Activation-Fusion should have two NodeIndex in matchedPair";
2353+    return RET_PARAM_INVALID;
2354+  }
2355+
2356+  size_t matmul_index = 0;
2357+  auto ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(MatMulName), &matmul_index);
2358+  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get matmul_index");
2359+  auto &matmul_node = graph->nodes.at(matmul_index);
2360+  MS_CHECK_TRUE_MSG(matmul_node != nullptr, RET_NULL_PTR, "matmul_node is nullptr");
2361+  size_t act_index = 0;
2362+  ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(ActName), &act_index);
2363+  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get act_index");
2364+  auto &act_node = graph->nodes.at(act_index);
2365+  MS_CHECK_TRUE_MSG(act_node != nullptr, RET_NULL_PTR, "act_node is nullptr");
2366+
2367+  if (matmul_node->quantType == schema::QuantType_QUANT_ALL ||
2368+      matmul_node->quantType == schema::QuantType_QUANT_DYNAMIC) {
2369+    MS_LOG(DEBUG) << "cannot fusion.";
2370+    return RET_NO_CHANGE;
2371+  }
2372+  MS_CHECK_TRUE_RET(matmul_node->primitive != nullptr, RET_NULL_PTR);
2373+  auto matmul_type = matmul_node->primitive->value.AsMatMulFusion();
2374+  MS_CHECK_TRUE_RET(matmul_type->activation_type == ActivationType::ActivationType_NO_ACTIVATION, RET_NO_CHANGE);
2375+  MS_CHECK_TRUE_RET(act_node->primitive != nullptr, RET_NULL_PTR);
2376+  auto act_type = act_node->primitive->value.AsActivation()->activation_type;
2377+  MS_CHECK_TRUE_RET(act_type == ActivationType::ActivationType_RELU || act_type == ActivationType::ActivationType_RELU6,
2378+                    RET_NO_CHANGE);
2379+  matmul_type->activation_type = act_type;
2380+  matmul_node->outputIndex = {act_node->outputIndex};
2381+  // cannot delete node here, otherwise will destroy order in other pattern's node index
2382+  // make it an isolated node to be removed in IsolatedNodeRemovePass
2383+  act_node->inputIndex.clear();
2384+  act_node->outputIndex.clear();
2385+  return RET_OK;
2386+}
2387+
2388+MatMulActivationFusionPass::~MatMulActivationFusionPass() = default;
2389+}  // namespace lite
2390+}  // namespace mindspore
2391diff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h
2392new file mode 100644
2393index 00000000..57891eb3
2394--- /dev/null
2395+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h
2396@@ -0,0 +1,42 @@
2397+/**
2398+ * Copyright 2022 Huawei Technologies Co., Ltd
2399+ *
2400+ * Licensed under the Apache License, Version 2.0 (the "License");
2401+ * you may not use this file except in compliance with the License.
2402+ * You may obtain a copy of the License at
2403+ *
2404+ * http://www.apache.org/licenses/LICENSE-2.0
2405+ *
2406+ * Unless required by applicable law or agreed to in writing, software
2407+ * distributed under the License is distributed on an "AS IS" BASIS,
2408+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2409+ * See the License for the specific language governing permissions and
2410+ * limitations under the License.
2411+ */
2412+
2413+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_
2414+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_
2415+
2416+#include <string>
2417+#include <unordered_map>
2418+#include <memory>
2419+#include <algorithm>
2420+#include <utility>
2421+#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
2422+
2423+namespace mindspore {
2424+namespace lite {
2425+class MatMulActivationFusionPass : public FusionPass {
2426+ public:
2427+  MatMulActivationFusionPass() = default;
2428+
2429+  ~MatMulActivationFusionPass() override;
2430+
2431+  STATUS DefinePattern() override;
2432+
2433+  STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name,
2434+                  const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override;
2435+};
2436+}  // namespace lite
2437+}  // namespace mindspore
2438+#endif  // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_
2439diff --git a/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
2440new file mode 100644
2441index 00000000..7fb8d1f4
2442--- /dev/null
2443+++ b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
2444@@ -0,0 +1,148 @@
2445+/**
2446+ * Copyright 2022 Huawei Technologies Co., Ltd
2447+ *
2448+ * Licensed under the Apache License, Version 2.0 (the "License");
2449+ * you may not use this file except in compliance with the License.
2450+ * You may obtain a copy of the License at
2451+ *
2452+ * http://www.apache.org/licenses/LICENSE-2.0
2453+ *
2454+ * Unless required by applicable law or agreed to in writing, software
2455+ * distributed under the License is distributed on an "AS IS" BASIS,
2456+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2457+ * See the License for the specific language governing permissions and
2458+ * limitations under the License.
2459+ */
2460+
2461+#include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h"
2462+#include <string>
2463+#include <unordered_map>
2464+#include <vector>
2465+#include <memory>
2466+#include "schema/inner/model_generated.h"
2467+#include "tools/common/meta_graph_utils.h"
2468+#include "src/train/optimizer/common/fusion_utils.h"
2469+
2470+namespace {
2471+constexpr std::string_view Reshape1Name = "RESHAPE1";
2472+constexpr std::string_view Reshape2Name = "RESHAPE2";
2473+constexpr std::string_view GatherName = "GATHER";
2474+}  // namespace
2475+namespace mindspore {
2476+namespace lite {
2477+/*
2478+ * The subgraph such as the following.
2479+ *            any
2480+ *           /  |
2481+ *     reshape  |
2482+ *           \  |
2483+ *           gather
2484+ *           /  |
2485+ *     reshape  |
2486+ *           \  |
2487+ *            any
2488+ */
2489+STATUS ReshapeGatherReshapeFusionPass::DefinePattern() {
2490+  auto reshape_op1 = std::make_shared<PatternOp>();
2491+  MS_CHECK_TRUE_RET(reshape_op1 != nullptr, RET_NULL_PTR);
2492+  reshape_op1->id = Reshape1Name;
2493+  reshape_op1->types = {schema::PrimitiveType_Reshape};
2494+
2495+  auto gather_op = std::make_shared<PatternOp>();
2496+  MS_CHECK_TRUE_RET(gather_op != nullptr, RET_NULL_PTR);
2497+  gather_op->id = GatherName;
2498+  gather_op->types = {schema::PrimitiveType_Gather};
2499+  gather_op->left = reshape_op1;
2500+
2501+  auto reshape_op2 = std::make_shared<PatternOp>();
2502+  MS_CHECK_TRUE_RET(reshape_op2 != nullptr, RET_NULL_PTR);
2503+  reshape_op2->id = Reshape2Name;
2504+  reshape_op2->types = {schema::PrimitiveType_Reshape};
2505+  reshape_op2->left = gather_op;
2506+
2507+  auto fusion_pattern = std::make_unique<FusionPattern>("ReshapeGatherReshapeFusion");
2508+  MS_CHECK_TRUE_MSG(fusion_pattern != nullptr, RET_NULL_PTR, "new fusion_pattern failed");
2509+  fusion_pattern->AddPatternOp(reshape_op1);
2510+  fusion_pattern->AddPatternOp(gather_op);
2511+  fusion_pattern->AddPatternOp(reshape_op2);
2512+  fusion_pattern->Finish();
2513+  this->patterns.emplace_back(fusion_pattern.release());
2514+  return RET_OK;
2515+}
2516+
2517+STATUS ReshapeGatherReshapeFusionPass::DoFusion(
2518+  MetaGraphT *graph, const std::string &pattern_name,
2519+  const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) {
2520+  MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR);
2521+  if (matched_path.size() != opt::kMatchPathLenThree) {
2522+    MS_LOG(ERROR) << "Reshape-Gather-Reshape-Fusion should have three NodeIndex in matchedPair";
2523+    return RET_PARAM_INVALID;
2524+  }
2525+
2526+  size_t reshape1_index = 0;
2527+  STATUS ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(Reshape1Name), &reshape1_index);
2528+  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get reshape1_index");
2529+  auto &reshape1_node = graph->nodes.at(reshape1_index);
2530+  MS_CHECK_TRUE_MSG(reshape1_node != nullptr, RET_NULL_PTR, "reshape1_node is nullptr");
2531+  size_t gather_index = 0;
2532+  ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(GatherName), &gather_index);
2533+  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get gather_index");
2534+  auto &gather_node = graph->nodes.at(gather_index);
2535+  MS_CHECK_TRUE_MSG(gather_node != nullptr, RET_NULL_PTR, "gather_node is nullptr");
2536+  size_t reshape2_index = 0;
2537+  ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(Reshape2Name), &reshape2_index);
2538+  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get reshape2_index");
2539+  auto &reshape2_node = graph->nodes.at(reshape2_index);
2540+  MS_CHECK_TRUE_MSG(reshape2_node != nullptr, RET_NULL_PTR, "reshape2_node is nullptr");
2541+
2542+  if (reshape1_node->inputIndex.size() != opt::kInputSizeTwo ||
2543+      reshape1_node->outputIndex.size() != opt::kOutputSizeOne ||
2544+      reshape1_node->quantType == schema::QuantType_QUANT_ALL ||
2545+      reshape1_node->quantType == schema::QuantType_QUANT_DYNAMIC ||
2546+      reshape2_node->inputIndex.size() != opt::kInputSizeTwo ||
2547+      reshape2_node->outputIndex.size() != opt::kOutputSizeOne ||
2548+      reshape2_node->quantType == schema::QuantType_QUANT_ALL ||
2549+      reshape2_node->quantType == schema::QuantType_QUANT_DYNAMIC ||
2550+      gather_node->quantType == schema::QuantType_QUANT_ALL ||
2551+      gather_node->quantType == schema::QuantType_QUANT_DYNAMIC) {
2552+    MS_LOG(DEBUG) << "reshape_node cannot fusion";
2553+    return RET_NO_CHANGE;
2554+  }
2555+
2556+  auto old_shape = graph->allTensors.at(reshape2_node->outputIndex.at(opt::kOutputIndexZero))->dims;
2557+  auto gather_shape0 = graph->allTensors.at(gather_node->inputIndex.at(opt::kInputIndexZero))->dims;
2558+  auto gather_shape1 = graph->allTensors.at(reshape1_node->inputIndex.at(opt::kInputIndexZero))->dims;
2559+  if (old_shape.empty() || gather_shape0.empty() || gather_shape1.empty()) {
2560+    return RET_NO_CHANGE;
2561+  }
2562+  int gather_axis = 0;
2563+  auto data = graph->allTensors.at(gather_node->inputIndex.at(opt::kInputIndexTwo))->data;
2564+  if (data.empty()) {
2565+    gather_axis = 0;
2566+  } else {
2567+    memcpy(&gather_axis, data.data(), data.size() < sizeof(gather_axis) ? data.size() : sizeof(gather_axis));
2568+  }
2569+  if (gather_axis < 0) {
2570+    gather_axis += static_cast<int>(gather_shape0.size());
2571+  }
2572+  gather_shape0.erase(gather_shape0.begin() + gather_axis);
2573+  (void)gather_shape0.insert(gather_shape0.begin() + gather_axis, gather_shape1.begin(), gather_shape1.end());
2574+  if (gather_shape0 != old_shape) {
2575+    return RET_NO_CHANGE;
2576+  }
2577+
2578+  gather_node->inputIndex.at(opt::kInputIndexOne) = reshape1_node->inputIndex.at(opt::kInputIndexZero);
2579+  gather_node->outputIndex.at(opt::kOutputIndexZero) = reshape2_node->outputIndex.at(opt::kOutputIndexZero);
2580+
2581+  // cannot delete node here, otherwise will destroy order in other pattern's node index
2582+  // make it an isolated node to be removed in IsolatedNodeRemovePass
2583+  reshape1_node->inputIndex.clear();
2584+  reshape1_node->outputIndex.clear();
2585+  reshape2_node->inputIndex.clear();
2586+  reshape2_node->outputIndex.clear();
2587+  return RET_OK;
2588+}
2589+
2590+ReshapeGatherReshapeFusionPass::~ReshapeGatherReshapeFusionPass() = default;
2591+}  // namespace lite
2592+}  // namespace mindspore
2593diff --git a/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h
2594new file mode 100644
2595index 00000000..ef184a3c
2596--- /dev/null
2597+++ b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h
2598@@ -0,0 +1,42 @@
2599+/**
2600+ * Copyright 2022 Huawei Technologies Co., Ltd
2601+ *
2602+ * Licensed under the Apache License, Version 2.0 (the "License");
2603+ * you may not use this file except in compliance with the License.
2604+ * You may obtain a copy of the License at
2605+ *
2606+ * http://www.apache.org/licenses/LICENSE-2.0
2607+ *
2608+ * Unless required by applicable law or agreed to in writing, software
2609+ * distributed under the License is distributed on an "AS IS" BASIS,
2610+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2611+ * See the License for the specific language governing permissions and
2612+ * limitations under the License.
2613+ */
2614+
2615+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_RESHAPE_GATHER_RESHAPE_FUSION_PASS_H_
2616+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_RESHAPE_GATHER_RESHAPE_FUSION_PASS_H_
2617+
2618+#include <string>
2619+#include <unordered_map>
2620+#include <memory>
2621+#include <algorithm>
2622+#include <utility>
2623+#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
2624+
2625+namespace mindspore {
2626+namespace lite {
2627+class ReshapeGatherReshapeFusionPass : public FusionPass {
2628+ public:
2629+  ReshapeGatherReshapeFusionPass() = default;
2630+
2631+  ~ReshapeGatherReshapeFusionPass() override;
2632+
2633+  STATUS DefinePattern() override;
2634+
2635+  STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name,
2636+                  const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override;
2637+};
2638+}  // namespace lite
2639+}  // namespace mindspore
2640+#endif  // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_RESHAPE_GATHER_RESHAPE_FUSION_PASS_H_
2641diff --git a/mindspore/lite/src/train/train_export.cc b/mindspore/lite/src/train/train_export.cc
2642index a9990963..7e504c4e 100644
2643--- a/mindspore/lite/src/train/train_export.cc
2644+++ b/mindspore/lite/src/train/train_export.cc
2645@@ -612,8 +612,40 @@ int TrainExport::SaveModel(lite::Model *model, const std::string &file_name) {
2646   return status;
2647 }
2648
2649+int TrainExport::SaveModel(lite::Model *model, Buffer *model_buffer) {
2650+  MS_CHECK_FALSE_MSG(model == nullptr, RET_ERROR, "model cannot be empty.");
2651+  MS_CHECK_FALSE_MSG(model_buffer == nullptr, RET_ERROR, "model_buffer cannot be empty.");
2652+  auto *liteModel = reinterpret_cast<LiteModel *>(model);
2653+  auto size = liteModel->buf_size_;
2654+  model_buffer->ResizeData(size);
2655+
2656+  size_t out_size = model_buffer->DataSize();
2657+  int status = mindspore::lite::Model::Export(model, static_cast<char *>(model_buffer->MutableData()), &out_size);
2658+  if (out_size != size) {
2659+    MS_LOG(ERROR) << "model_buffer resize failed.";
2660+    return RET_ERROR;
2661+  }
2662+
2663+  return status;
2664+}
2665+
2666 int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }
2667
2668+int TrainExport::SaveToBuffer() {
2669+  constexpr size_t kFbBuilderInitSize = 1024;
2670+  flatbuffers::FlatBufferBuilder builder(kFbBuilderInitSize);
2671+  auto offset = schema::MetaGraph::Pack(builder, meta_graph_);
2672+  builder.Finish(offset);
2673+  schema::FinishMetaGraphBuffer(builder, offset);
2674+  size_t size = builder.GetSize();
2675+  auto content = builder.GetBufferPointer();
2676+  MS_CHECK_FALSE_MSG(content == nullptr, RET_ERROR, "content cannot be empty.");
2677+  MS_CHECK_FALSE_MSG(model_buffer_ == nullptr, RET_ERROR, "model_buffer_ cannot be empty.");
2678+  model_buffer_->SetData(content, size);
2679+  return RET_OK;
2680+}
2681+
2682+
2683 bool TrainExport::IsInputTensor(const schema::TensorT &t) {
2684   int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
2685   return ((t.data.size() == 0) && (total_dims != 0));
2686diff --git a/mindspore/lite/src/train/train_export.h b/mindspore/lite/src/train/train_export.h
2687index 7727109b..8e802021 100644
2688--- a/mindspore/lite/src/train/train_export.h
2689+++ b/mindspore/lite/src/train/train_export.h
2690@@ -44,24 +44,28 @@ struct tensor_info {
2691 class TrainExport {
2692  public:
2693   explicit TrainExport(const std::string file_name) : file_name_(file_name) {}
2694+  explicit TrainExport(Buffer *model_buffer) : model_buffer_(model_buffer) {}
2695   virtual ~TrainExport();
2696   int ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels,
2697                 const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names,
2698                 const Model *model, QuantizationType quant_type, const Model *bb_model = nullptr);
2699   int ExportInit(const std::string model_name, std::string version);
2700   int SaveToFile();
2701+  int SaveToBuffer();
2702   void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
2703   int LoadModel(void *buf, size_t buf_size);
2704   int AddTransformNode();
2705   int TrainModelFusion();
2706   int TrainModelDrop();
2707   int SaveModel(lite::Model *model, const std::string &file_name);
2708+  int SaveModel(lite::Model *model, Buffer *model_buffer);
2709
2710  protected:
2711   virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor);
2712
2713  private:
2714   std::string file_name_;
2715+  Buffer *model_buffer_ = nullptr;
2716   schema::MetaGraphT *meta_graph_ = nullptr;
2717   std::vector<size_t> out_idx_;
2718   std::map<size_t, size_t> remap_;
2719diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc
2720index bda5d0a5..9874a30d 100644
2721--- a/mindspore/lite/src/train/train_populate_parameter.cc
2722+++ b/mindspore/lite/src/train/train_populate_parameter.cc
2723@@ -31,6 +31,8 @@
2724 #include "nnacl/fp32_grad/smooth_l1_loss.h"
2725 #include "nnacl/fp32_grad/resize_grad_parameter.h"
2726 #include "nnacl/fp32_grad/lstm_grad_fp32.h"
2727+#include "nnacl/fp32_grad/binary_cross_entropy.h"
2728+#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
2729
2730 using mindspore::lite::Registry;
2731
2732@@ -88,29 +90,44 @@ OpParameter *PopulateApplyMomentumParameter(const void *prim) {
2733 }
2734
2735 OpParameter *PopulateBCEParameter(const void *prim) {
2736-  int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
2737-  if (reduction == nullptr) {
2738-    MS_LOG(ERROR) << "malloc reduction failed.";
2739-    return nullptr;
2740-  }
2741   auto primitive = static_cast<const schema::Primitive *>(prim);
2742+  MS_ASSERT(primitive != nullptr);
2743   auto value = primitive->value_as_BinaryCrossEntropy();
2744-  MS_ASSERT(value != nullptr);
2745-  *reduction = value->reduction();
2746-  return reinterpret_cast<OpParameter *>(reduction);
2747+  if (value == nullptr) {
2748+    MS_LOG(ERROR) << "value is nullptr";
2749+    return nullptr;
2750+  }
2751+
2752+  auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
2753+  if (param == nullptr) {
2754+    MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
2755+    return nullptr;
2756+  }
2757+  memset(param, 0, sizeof(BinaryCrossEntropyParameter));
2758+
2759+  param->op_parameter_.type_ = primitive->value_type();
2760+  param->reduction = value->reduction();
2761+  return reinterpret_cast<OpParameter *>(param);
2762 }
2763
2764 OpParameter *PopulateBCEGradParameter(const void *prim) {
2765-  int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
2766-  if (reduction == nullptr) {
2767-    MS_LOG(ERROR) << "malloc reduction failed.";
2768+  auto *primitive = static_cast<const schema::Primitive *>(prim);
2769+  MS_ASSERT(primitive != nullptr);
2770+  auto value = primitive->value_as_BinaryCrossEntropyGrad();
2771+  if (value == nullptr) {
2772+    MS_LOG(ERROR) << "value is nullptr";
2773     return nullptr;
2774   }
2775-  auto primitive = static_cast<const schema::Primitive *>(prim);
2776-  auto value = primitive->value_as_BinaryCrossEntropyGrad();
2777-  MS_ASSERT(value != nullptr);
2778-  *reduction = value->reduction();
2779-  return reinterpret_cast<OpParameter *>(reduction);
2780+  auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
2781+  if (param == nullptr) {
2782+    MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
2783+    return nullptr;
2784+  }
2785+  memset(param, 0, sizeof(BinaryCrossEntropyGradParameter));
2786+
2787+  param->op_parameter_.type_ = primitive->value_type();
2788+  param->reduction = value->reduction();
2789+  return reinterpret_cast<OpParameter *>(param);
2790 }
2791
2792 OpParameter *PopulateAdamParameter(const void *prim) {
2793diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
2794index cd1af1b6..b40ff8c2 100644
2795--- a/mindspore/lite/src/train/train_session.cc
2796+++ b/mindspore/lite/src/train/train_session.cc
2797@@ -206,10 +206,11 @@ static int ReshapeWeightTensor(Tensor *orig_tensor, lite::Tensor *new_tensor) {
2798     }
2799   }
2800
2801-  orig_tensor->FreeData();
2802-  orig_tensor->set_data(nullptr);
2803-  orig_tensor->set_shape(new_tensor->shape());
2804-
2805+  if (orig_tensor->shape() != new_tensor->shape()) {
2806+    orig_tensor->FreeData();
2807+    orig_tensor->set_data(nullptr);
2808+    orig_tensor->set_shape(new_tensor->shape());
2809+  }
2810   uint8_t *dst_data = reinterpret_cast<uint8_t *>(orig_tensor->MutableData());
2811   if (dst_data == nullptr) {
2812     MS_LOG(ERROR) << "Allocation of Data Failed";
2813@@ -228,6 +229,9 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) {
2814         return RET_PARAM_INVALID;
2815       }
2816       if (modify->tensor_name() == tensor->tensor_name()) {
2817+        if (tensor->Size() != modify->Size()) {
2818+          model_buff_changed_ = true;
2819+        }
2820         auto ret = ReshapeWeightTensor(tensor, modify);
2821         num_of_found_tensors++;
2822         if (ret != RET_OK) {
2823@@ -243,6 +247,7 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) {
2824   }
2825   auto ret = ReSizeKernels(kernels_);
2826   if (ret != RET_OK) {
2827+    model_buff_changed_ = false;
2828     MS_LOG(ERROR) << "Resize kernels fail!";
2829     return ret;
2830   }
2831@@ -1154,9 +1159,17 @@ int TrainSession::FindExportKernels(std::vector<kernel::KernelExec *> *export_ke
2832   return RET_OK;
2833 }
2834
2835-int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
2836-                         FormatType format, std::vector<std::string> out_put_tensor_name) {
2837-  MS_CHECK_FALSE_MSG(file_name.empty(), RET_ERROR, "File name cannot be empty");
2838+template <typename DestType>
2839+int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
2840+                              FormatType format, std::vector<std::string> out_put_tensor_name) {
2841+  if constexpr (std::is_same_v<DestType, const std::string &>) {
2842+    MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
2843+  } else if constexpr (std::is_same_v<DestType, Buffer *>) {
2844+    MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
2845+  } else {
2846+    MS_LOG(ERROR) << "Unsupported destination.";
2847+    return RET_ERROR;
2848+  }
2849   MS_CHECK_FALSE_MSG(model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN, RET_ERROR,
2850                      "Export model type parameter error");
2851   MS_CHECK_FALSE_MSG(quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT, RET_ERROR,
2852@@ -1165,27 +1178,21 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
2853
2854   bool orig_train_state = IsTrain();
2855   Eval();
2856-  TrainExport texport(file_name);
2857+  TrainExport texport(destination);
2858   int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
2859-  if (status != RET_OK) {
2860-    MS_LOG(ERROR) << "cannot init export";
2861-    return status;
2862-  }
2863+  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
2864
2865   if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
2866     std::vector<kernel::KernelExec *> export_kernels = {};
2867     status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
2868-    if (status != RET_OK) {
2869-      MS_LOG(ERROR) << "FindExportKernels failed.";
2870-      return RET_ERROR;
2871-    }
2872+    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
2873     status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
2874   } else {
2875-    if ((quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
2876+    if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
2877         std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
2878           return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
2879         })) {
2880-      status = texport.SaveModel(model_.get(), file_name);
2881+      status = texport.SaveModel(model_.get(), destination);
2882       if (orig_train_state) Train();
2883       return status;
2884     } else {
2885@@ -1194,35 +1201,42 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
2886                                  model_.get(), quant_type);
2887     }
2888   }
2889+  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
2890
2891-  if (status != RET_OK) {
2892-    MS_LOG(ERROR) << "cannot export Network";
2893-    return status;
2894-  }
2895   if (model_type == MT_INFERENCE) {
2896     status = texport.TrainModelDrop();
2897-    if (status != RET_OK) {
2898-      MS_LOG(ERROR) << "TrainModelDrop failed.";
2899-      return status;
2900-    }
2901+    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
2902     status = texport.TrainModelFusion();
2903+    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
2904+  }
2905+  if constexpr (std::is_same_v<DestType, const std::string &>) {
2906+    status = texport.SaveToFile();
2907     if (status != RET_OK) {
2908-      MS_LOG(ERROR) << "TrainModelFusion failed.";
2909+      MS_LOG(ERROR) << "failed to save to " << destination;
2910       return status;
2911     }
2912-  }
2913-  status = texport.SaveToFile();
2914-  if (status != RET_OK) {
2915-    MS_LOG(ERROR) << "failed to save to " << file_name;
2916-    return status;
2917+  } else {
2918+    status = texport.SaveToBuffer();
2919+    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
2920   }
2921   if (orig_train_state) Train();
2922   return status;
2923 }
2924+
2925+int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
2926+                         FormatType format, std::vector<std::string> out_put_tensor_name) {
2927+  return ExportInner<const std::string &>(file_name, model_type, quant_type, format, out_put_tensor_name);
2928+}
2929+
2930+int TrainSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
2931+                         std::vector<std::string> out_put_tensor_name) {
2932+  return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
2933+}
2934+
2935 std::vector<lite::Tensor *> TrainSession::GetFeatureMaps() const {
2936   std::vector<lite::Tensor *> features;
2937   for (auto cur_tensor : this->tensors_) {
2938-    if (cur_tensor->IsConst() && cur_tensor->data_type() == kNumberTypeFloat32) {
2939+    if (cur_tensor->category() == lite::Category::CONST_TENSOR && cur_tensor->data_type() == kNumberTypeFloat32) {
2940       features.push_back(cur_tensor);
2941     }
2942   }
2943diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h
2944index 0a0ce640..5acff82a 100644
2945--- a/mindspore/lite/src/train/train_session.h
2946+++ b/mindspore/lite/src/train/train_session.h
2947@@ -36,6 +36,14 @@
2948   +-------------------------------+
2949 */
2950
2951+#define TRAIN_SESSION_CHECK_FALSE_MSG(value, errcode, msg) \
2952+  do {                                                     \
2953+    if ((value)) {                                         \
2954+      MS_LOG(ERROR) << msg;                                \
2955+      return errcode;                                      \
2956+    }                                                      \
2957+  } while (0)
2958+
2959 namespace mindspore {
2960 namespace lite {
2961 using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>;
2962@@ -96,6 +104,8 @@ class TrainSession : virtual public lite::LiteSession {
2963   }
2964   int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
2965              std::vector<std::string> out_put_tensor_name = {}) override;
2966+  int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
2967+             std::vector<std::string> out_put_tensor_name = {}) override;
2968
2969   std::vector<lite::Tensor *> GetFeatureMaps() const override;
2970
2971@@ -165,6 +175,9 @@ class TrainSession : virtual public lite::LiteSession {
2972                                 const std::unordered_map<lite::Tensor *, size_t> &offset_map,
2973                                 std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx);
2974
2975+  template <typename DestType>
2976+  int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
2977+                  std::vector<std::string> out_put_tensor_name = {});
2978   std::map<Tensor *, Tensor *> restored_origin_tensors_;
2979   int virtual_batch_idx_ = 0;
2980   int virtual_batch_multiplier_ = 0;
2981@@ -172,6 +185,7 @@ class TrainSession : virtual public lite::LiteSession {
2982   void *workspace_ = nullptr;
2983   SchedCallBack sched_mix_precision_callback_;
2984   bool train_mode_ = false;
2985+  bool model_buff_changed_ = false;
2986   void *tensors_data_ = nullptr;
2987   size_t tensors_data_size_ = 0;
2988   std::shared_ptr<Allocator> allocator_;
2989diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc
2990index b54f348e..031c4a6b 100644
2991--- a/mindspore/lite/src/train/transfer_session.cc
2992+++ b/mindspore/lite/src/train/transfer_session.cc
2993@@ -183,15 +183,24 @@ std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
2994   return map;
2995 }
2996
2997-int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
2998-                            FormatType format, std::vector<std::string> out_put_tensor_name) {
2999+template <typename DestType>
3000+int TransferSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
3001+                                 FormatType format, std::vector<std::string> out_put_tensor_name) {
3002+  if constexpr (std::is_same_v<DestType, const std::string &>) {
3003+    MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
3004+  } else if constexpr (std::is_same_v<DestType, Buffer *>) {
3005+    MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
3006+  } else {
3007+    MS_LOG(ERROR) << "Unsupported destination.";
3008+    return RET_ERROR;
3009+  }
3010   if (format != FT_FLATBUFFERS) {
3011     MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
3012     return RET_ERROR;
3013   }
3014
3015   if (model_type == MT_TRAIN) {
3016-    return TrainSession::Export(filename, model_type, quant_type, format);
3017+    return TrainSession::Export(destination, model_type, quant_type, format);
3018   }
3019
3020   bool orig_train_state = IsTrain();
3021@@ -199,7 +208,7 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
3022     MS_LOG(ERROR) << "eval failed.";
3023     return RET_ERROR;
3024   }
3025-  TrainExport texport(filename);
3026+  TrainExport texport(destination);
3027   int status = texport.LoadModel(lite_model_, size_backbone_);
3028   if (status != RET_OK) {
3029     MS_LOG(ERROR) << "cannot init export";
3030@@ -231,10 +240,15 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
3031     MS_LOG(ERROR) << "cannot serialize head";
3032     return status;
3033   }
3034-  status = texport.SaveToFile();
3035-  if (status != RET_OK) {
3036-    MS_LOG(ERROR) << "failed to save to " << filename;
3037-    return status;
3038+  if constexpr (std::is_same_v<DestType, const std::string &>) {
3039+    status = texport.SaveToFile();
3040+    if (status != RET_OK) {
3041+      MS_LOG(ERROR) << "failed to save to " << destination;
3042+      return status;
3043+    }
3044+  } else {
3045+    status = texport.SaveToBuffer();
3046+    MS_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
3047   }
3048   if (orig_train_state) {
3049     auto ret = Train();
3050@@ -246,6 +260,17 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
3051   return status;
3052 }
3053
3054+int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
3055+                            FormatType format, std::vector<std::string> out_put_tensor_name) {
3056+  return ExportInner<const std::string &>(filename, model_type, quant_type, format, out_put_tensor_name);
3057+}
3058+
3059+int TransferSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
3060+                            std::vector<std::string> out_put_tensor_name) {
3061+  return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
3062+}
3063+
3064+
3065 lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
3066                                             const char *model_buf_head, size_t size_head, const lite::Context *context,
3067                                             bool train_mode, const lite::TrainCfg *cfg) {
3068diff --git a/mindspore/lite/src/train/transfer_session.h b/mindspore/lite/src/train/transfer_session.h
3069index 48a38b8b..6cd06c60 100644
3070--- a/mindspore/lite/src/train/transfer_session.h
3071+++ b/mindspore/lite/src/train/transfer_session.h
3072@@ -63,6 +63,8 @@ class TransferSession : public lite::TrainSession {
3073   int CompileTransferGraph();
3074   int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
3075              std::vector<std::string> out_put_tensor_name = {}) override;
3076+  int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
3077+             std::vector<std::string> out_put_tensor_name = {}) override;
3078
3079  protected:
3080   LiteSession *backbone_session_ = nullptr;
3081@@ -72,6 +74,9 @@ class TransferSession : public lite::TrainSession {
3082   bool is_valid_ = false;
3083
3084  private:
3085+  template <typename DestType>
3086+  int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
3087+                  std::vector<std::string> out_put_tensor_name = {});
3088   bool CompileFormatTransform(lite::Tensor *out, lite::Tensor *in, int *mask, size_t mask_len);
3089   std::unordered_map<size_t, size_t> ConnectionMap();
3090   bool nchw2nhwc_ = false;
3091diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc
3092index cad52545..20d2d298 100644
3093--- a/mindspore/lite/tools/benchmark_train/net_train.cc
3094+++ b/mindspore/lite/tools/benchmark_train/net_train.cc
3095@@ -1,5 +1,5 @@
3096 /**
3097- * Copyright 2020 Huawei Technologies Co., Ltd
3098+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
3099  *
3100  * Licensed under the Apache License, Version 2.0 (the "License");
3101  * you may not use this file except in compliance with the License.
3102@@ -42,6 +42,21 @@ constexpr int kField4 = 4;
3103 constexpr int kFieldsToPrint = 5;
3104 constexpr int kPrintOffset = 4;
3105 static const int kTHOUSAND = 1000;
3106+constexpr int kDumpInputsAndOutputs = 0;
3107+constexpr int kDumpOutputs = 2;
3108+
3109+const std::unordered_map<int, std::string> kTypeIdMap{
3110+  {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"},    {kNumberTypeFloat32, "Float32"},
3111+  {kNumberTypeInt8, "Int8"},       {kNumberTypeInt16, "Int16"},      {kNumberTypeInt, "Int32"},
3112+  {kNumberTypeInt32, "Int32"},     {kNumberTypeUInt8, "UInt8"},      {kNumberTypeUInt16, "UInt16"},
3113+  {kNumberTypeUInt, "UInt32"},     {kNumberTypeUInt32, "UInt32"},    {kObjectTypeString, "String"},
3114+  {kNumberTypeBool, "Bool"},       {kObjectTypeTensorType, "Tensor"}};
3115+
3116+const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap{
3117+  {mindspore::NCHW, "NCHW"}, {mindspore::NHWC, "NHWC"},     {mindspore::NHWC4, "NHWC4"}, {mindspore::HWKC, "HWKC"},
3118+  {mindspore::HWCK, "HWCK"}, {mindspore::KCHW, "KCHW"},     {mindspore::CKHW, "CKHW"},   {mindspore::KHWC, "KHWC"},
3119+  {mindspore::CHWK, "CHWK"}, {mindspore::HW, "HW"},         {mindspore::HW4, "HW4"},     {mindspore::NC, "NC"},
3120+  {mindspore::NC4, "NC4"},   {mindspore::NC4HW4, "NC4HW4"}, {mindspore::NCDHW, "NCDHW"}};
3121
3122 std::function<int(NetTrainFlags *)> NetTrain::nr_cb_ = nullptr;
3123
3124@@ -249,9 +264,7 @@ int NetTrain::MarkPerformance() {
3125
3126   for (int i = 0; i < flags_->epochs_; i++) {
3127     auto start = GetTimeUs();
3128-    auto status = flags_->time_profiling_
3129-                    ? ms_model_.Predict(ms_inputs_for_api_, &outputs, before_call_back_, after_call_back_)
3130-                    : ms_model_.Predict(ms_inputs_for_api_, &outputs);
3131+    auto status = ms_model_.RunStep(before_call_back_, after_call_back_);
3132     if (status != mindspore::kSuccess) {
3133       MS_LOG(ERROR) << "Inference error " << status;
3134       std::cerr << "Inference error " << status;
3135@@ -300,7 +313,7 @@ int NetTrain::MarkAccuracy(bool enforce_accuracy) {
3136     }
3137   }
3138   std::vector<MSTensor> outputs;
3139-  auto status = ms_model_.Predict(ms_inputs_for_api_, &outputs);
3140+  auto status = ms_model_.RunStep(before_call_back_, after_call_back_);
3141   if (status != mindspore::kSuccess) {
3142     MS_LOG(ERROR) << "Inference error " << status;
3143     std::cerr << "Inference error " << status << std::endl;
3144@@ -405,17 +418,22 @@ void NetTrain::InitTrainCfg(const std::shared_ptr<TrainCfg> &train_cfg) {
3145   if (flags_->loss_name_.empty()) {
3146     return;
3147   }
3148-  train_cfg->loss_name_.clear();
3149+  std::vector<std::string> empty_loss_name;
3150+  train_cfg->SetLossName(empty_loss_name);  // clear train_cfg's loss_name
3151   std::string delimiter = ",";
3152   size_t pos = 0;
3153   std::string token;
3154   while ((pos = flags_->loss_name_.find(delimiter)) != std::string::npos) {
3155     token = flags_->loss_name_.substr(0, pos);
3156     flags_->loss_name_.erase(0, pos + delimiter.length());  // change to delim without deletion
3157-    train_cfg->loss_name_.emplace_back(token);
3158+    std::vector<std::string> train_cfg_loss_name = train_cfg->GetLossName();
3159+    train_cfg_loss_name.emplace_back(token);
3160+    train_cfg->SetLossName(train_cfg_loss_name);
3161   }
3162   if (!(flags_->loss_name_.empty())) {
3163-    train_cfg->loss_name_.emplace_back(flags_->loss_name_);
3164+    std::vector<std::string> train_cfg_loss_name = train_cfg->GetLossName();
3165+    train_cfg_loss_name.emplace_back(flags_->loss_name_);
3166+    train_cfg->SetLossName(train_cfg_loss_name);
3167   }
3168 }
3169
3170@@ -635,7 +653,79 @@ void NetTrain::CheckSum(MSTensor *tensor, const std::string &node_type, int id,
3171   }
3172 }
3173
3174-int NetTrain::InitCallbackParameter() {
3175+std::string GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name,
3176+                                   const std::string &file_type, const size_t &idx) {
3177+  std::string file_name = op_name;
3178+  auto pos = file_name.find_first_of('/');
3179+  while (pos != std::string::npos) {
3180+    file_name.replace(pos, 1, ".");
3181+    pos = file_name.find_first_of('/');
3182+  }
3183+  file_name += "_" + file_type + "_" + std::to_string(idx) + "_shape_";
3184+  for (const auto &dim : tensor->Shape()) {
3185+    file_name += std::to_string(dim) + "_";
3186+  }
3187+  if (kTypeIdMap.find(static_cast<int>(tensor->DataType())) != kTypeIdMap.end()) {
3188+    file_name += kTypeIdMap.at(static_cast<int>(tensor->DataType()));
3189+  }
3190+  auto tensor_format = tensor->format();
3191+  if (kTensorFormatMap.find(tensor_format) != kTensorFormatMap.end()) {
3192+    file_name += "_" + kTensorFormatMap.at(tensor_format);
3193+  }
3194+
3195+  file_name += ".bin";
3196+  return file_name;
3197+}
3198+
3199+int NetTrain::InitDumpTensorDataCallbackParameter() {
3200+  // before callback
3201+  before_call_back_ = [&](const std::vector<mindspore::MSTensor> &before_inputs,
3202+                          const std::vector<mindspore::MSTensor> &before_outputs, const MSCallBackParam &call_param) {
3203+    auto dump_mode = dump_cfg_json_[dump::kSettings][dump::kMode].get<int>();
3204+    auto input_output_mode = dump_cfg_json_[dump::kSettings][dump::kInputOutput].get<int>();
3205+    auto kernels = dump_cfg_json_[dump::kSettings][dump::kKernels].get<std::vector<std::string>>();
3206+    if (dump_mode == 0 || std::find(kernels.begin(), kernels.end(), call_param.node_name) != kernels.end()) {
3207+      if (input_output_mode == 0 || input_output_mode == 1) {
3208+        for (size_t i = 0; i < before_inputs.size(); i++) {
3209+          auto ms_tensor = before_inputs.at(i);
3210+          auto file_name = GenerateOutputFileName(&ms_tensor, call_param.node_name, "input", i);
3211+          auto abs_file_path = dump_file_output_dir_ + "/" + file_name;
3212+          if (WriteToBin(abs_file_path, ms_tensor.MutableData(), ms_tensor.DataSize()) != RET_OK) {  // save to file
3213+            MS_LOG(ERROR) << "write tensor data to file failed.";
3214+            return false;
3215+          }
3216+        }
3217+      }
3218+    }
3219+    return true;
3220+  };
3221+
3222+  // after callback
3223+  after_call_back_ = [&](const std::vector<mindspore::MSTensor> &after_inputs,
3224+                         const std::vector<mindspore::MSTensor> &after_outputs, const MSCallBackParam &call_param) {
3225+    auto dump_mode = dump_cfg_json_[dump::kSettings][dump::kMode].get<int>();
3226+    auto input_output_mode = dump_cfg_json_[dump::kSettings][dump::kInputOutput].get<int>();
3227+    auto kernels = dump_cfg_json_[dump::kSettings][dump::kKernels].get<std::vector<std::string>>();
3228+    if (dump_mode == kDumpInputsAndOutputs ||
3229+        std::find(kernels.begin(), kernels.end(), call_param.node_name) != kernels.end()) {
3230+      if (input_output_mode == kDumpInputsAndOutputs || input_output_mode == kDumpOutputs) {
3231+        for (size_t i = 0; i < after_outputs.size(); i++) {
3232+          auto ms_tensor = after_outputs.at(i);
3233+          auto file_name = GenerateOutputFileName(&ms_tensor, call_param.node_name, "output", i);
3234+          auto abs_file_path = dump_file_output_dir_ + "/" + file_name;
3235+          if (WriteToBin(abs_file_path, ms_tensor.MutableData(), ms_tensor.DataSize()) != RET_OK) {  // save to file
3236+            MS_LOG(ERROR) << "write tensor data to file failed.";
3237+            return false;
3238+          }
3239+        }
3240+      }
3241+    }
3242+    return true;
3243+  };
3244+  return RET_OK;
3245+}
3246+
3247+int NetTrain::InitTimeProfilingCallbackParameter() {
3248   // before callback
3249   before_call_back_ = [&](const std::vector<mindspore::MSTensor> &before_inputs,
3250                           const std::vector<mindspore::MSTensor> &before_outputs,
3251@@ -696,6 +786,16 @@ int NetTrain::InitCallbackParameter() {
3252   return RET_OK;
3253 }
3254
3255+int NetTrain::InitCallbackParameter() {
3256+  int ret = RET_OK;
3257+  if (flags_->dump_tensor_data_) {
3258+    ret = InitDumpTensorDataCallbackParameter();
3259+  } else if (flags_->time_profiling_) {
3260+    ret = InitTimeProfilingCallbackParameter();
3261+  }
3262+  return ret;
3263+}
3264+
3265 void NetTrainFlags::InitResizeDimsList() {
3266   std::string content = this->resize_dims_in_;
3267   std::vector<int> shape;
3268@@ -761,14 +861,25 @@ int NetTrain::Init() {
3269     return 1;
3270   }
3271
3272-  if (flags_->time_profiling_) {
3273-    auto status = InitCallbackParameter();
3274-    if (status != RET_OK) {
3275-      MS_LOG(ERROR) << "Init callback Parameter failed.";
3276-      std::cerr << "Init callback Parameter failed." << std::endl;
3277+  // get dump data output path
3278+  auto dump_cfg_path = std::getenv(dump::kConfigPath);
3279+  if (dump_cfg_path != nullptr) {
3280+    flags_->dump_tensor_data_ = true;
3281+    if (InitDumpConfigFromJson(dump_cfg_path) != RET_OK) {
3282+      MS_LOG(ERROR) << "parse dump config file failed.";
3283       return RET_ERROR;
3284     }
3285+  } else {
3286+    MS_LOG(INFO) << "No MINDSPORE_DUMP_CONFIG in env, don't need to dump data";
3287+  }
3288+
3289+  auto status = InitCallbackParameter();
3290+  if (status != RET_OK) {
3291+    MS_LOG(ERROR) << "Init callback Parameter failed.";
3292+    std::cerr << "Init callback Parameter failed." << std::endl;
3293+    return RET_ERROR;
3294   }
3295+
3296   flags_->InitResizeDimsList();
3297   if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&
3298       flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
3299@@ -779,6 +890,70 @@ int NetTrain::Init() {
3300   return RET_OK;
3301 }
3302
3303+int NetTrain::InitDumpConfigFromJson(char *path) {
3304+  auto real_path = RealPath(path);
3305+  std::ifstream ifs(real_path);
3306+  if (!ifs.good()) {
3307+    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
3308+    return RET_ERROR;
3309+  }
3310+  if (!ifs.is_open()) {
3311+    MS_LOG(ERROR) << "file: " << real_path << " open failed";
3312+    return RET_ERROR;
3313+  }
3314+
3315+  try {
3316+    dump_cfg_json_ = nlohmann::json::parse(ifs);
3317+  } catch (const nlohmann::json::parse_error &error) {
3318+    MS_LOG(ERROR) << "parse json file failed, please check your file.";
3319+    return RET_ERROR;
3320+  }
3321+  if (dump_cfg_json_[dump::kSettings] == nullptr) {
3322+    MS_LOG(ERROR) << "\"common_dump_settings\" is required.";
3323+    return RET_ERROR;
3324+  }
3325+  if (dump_cfg_json_[dump::kSettings][dump::kMode] == nullptr) {
3326+    MS_LOG(ERROR) << "\"dump_mode\" is required.";
3327+    return RET_ERROR;
3328+  }
3329+  if (dump_cfg_json_[dump::kSettings][dump::kPath] == nullptr) {
3330+    MS_LOG(ERROR) << "\"path\" is required.";
3331+    return RET_ERROR;
3332+  }
3333+  if (dump_cfg_json_[dump::kSettings][dump::kNetName] == nullptr) {
3334+    dump_cfg_json_[dump::kSettings][dump::kNetName] = "default";
3335+  }
3336+  if (dump_cfg_json_[dump::kSettings][dump::kInputOutput] == nullptr) {
3337+    dump_cfg_json_[dump::kSettings][dump::kInputOutput] = 0;
3338+  }
3339+  if (dump_cfg_json_[dump::kSettings][dump::kKernels] != nullptr &&
3340+      !dump_cfg_json_[dump::kSettings][dump::kKernels].empty()) {
3341+    if (dump_cfg_json_[dump::kSettings][dump::kMode] == 0) {
3342+      MS_LOG(ERROR) << R"("dump_mode" should be 1 when "kernels" isn't empty.)";
3343+      return RET_ERROR;
3344+    }
3345+  }
3346+
3347+  auto abs_path = dump_cfg_json_[dump::kSettings][dump::kPath].get<std::string>();
3348+  auto net_name = dump_cfg_json_[dump::kSettings][dump::kNetName].get<std::string>();
3349+  if (abs_path.back() == '\\' || abs_path.back() == '/') {
3350+    dump_file_output_dir_ = abs_path + net_name;
3351+  } else {
3352+#ifdef _WIN32
3353+    dump_file_output_dir_ = abs_path + "\\" + net_name;
3354+#else
3355+    dump_file_output_dir_ = abs_path + "/" + net_name;
3356+#endif
3357+  }
3358+
3359+  auto status = CreateOutputDir(&dump_file_output_dir_);
3360+  if (status != RET_OK) {
3361+    MS_LOG(ERROR) << "create data output directory failed.";
3362+    return RET_ERROR;
3363+  }
3364+  return RET_OK;
3365+}
3366+
3367 namespace {
3368 constexpr int kNumToPrint = 5;
3369 }
3370diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h
3371index c67e952f..ea0aaacd 100644
3372--- a/mindspore/lite/tools/benchmark_train/net_train.h
3373+++ b/mindspore/lite/tools/benchmark_train/net_train.h
3374@@ -1,5 +1,5 @@
3375 /**
3376- * Copyright 2020 Huawei Technologies Co., Ltd
3377+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
3378  *
3379  * Licensed under the Apache License, Version 2.0 (the "License");
3380  * you may not use this file except in compliance with the License.
3381@@ -30,6 +30,7 @@
3382 #include <cfloat>
3383 #include <utility>
3384 #include <algorithm>
3385+#include <nlohmann/json.hpp>
3386 #include "include/api/model.h"
3387 #include "include/api/types.h"
3388 #include "include/api/context.h"
3389@@ -54,6 +55,18 @@ enum MS_API DataType { kImage = 0, kBinary = 1 };
3390
3391 constexpr float relativeTolerance = 1e-5;
3392 constexpr float absoluteTolerance = 1e-8;
3393+extern const std::unordered_map<int, std::string> kTypeIdMap;
3394+extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap;
3395+
3396+namespace dump {
3397+constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG";
3398+constexpr auto kSettings = "common_dump_settings";
3399+constexpr auto kMode = "dump_mode";
3400+constexpr auto kPath = "path";
3401+constexpr auto kNetName = "net_name";
3402+constexpr auto kInputOutput = "input_output";
3403+constexpr auto kKernels = "kernels";
3404+}  // namespace dump
3405
3406 template <typename T>
3407 float TensorSum(const void *data, int size) {
3408@@ -122,6 +135,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
3409   std::string loss_name_ = "";
3410   std::string inference_file_ = "";
3411   bool unified_api_ = false;
3412+  bool dump_tensor_data_ = false;
3413 };
3414
3415 class MS_API NetTrain {
3416@@ -193,6 +207,7 @@ class MS_API NetTrain {
3417     }
3418     return meanError;
3419   }
3420+  int InitDumpConfigFromJson(char *path);
3421
3422  private:
3423   // call GenerateInputData or ReadInputFile to init inputTensors
3424@@ -219,6 +234,10 @@ class MS_API NetTrain {
3425                                   const std::shared_ptr<TrainCfg> &train_cfg, int epochs);
3426   int InitCallbackParameter();
3427
3428+  int InitDumpTensorDataCallbackParameter();
3429+
3430+  int InitTimeProfilingCallbackParameter();
3431+
3432   int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result);
3433
3434   template <typename T>
3435@@ -280,6 +299,8 @@ class MS_API NetTrain {
3436
3437   mindspore::MSKernelCallBack before_call_back_{nullptr};
3438   mindspore::MSKernelCallBack after_call_back_{nullptr};
3439+  nlohmann::json dump_cfg_json_;
3440+  std::string dump_file_output_dir_;
3441 };
3442
3443 int MS_API RunNetTrain(int argc, const char **argv);
3444diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
3445index 57c818b1..03cac4c0 100644
3446--- a/mindspore/lite/tools/converter/anf_transform.cc
3447+++ b/mindspore/lite/tools/converter/anf_transform.cc
3448@@ -95,6 +95,8 @@
3449 #include "tools/optimizer/fusion/groupnorm_fusion.h"
3450 #include "tools/optimizer/fusion/mul_reduce_fusion.h"
3451 #include "tools/converter/import/cast_op_adjust.h"
3452+#include "tools/optimizer/fusion/expanddims_reshape_fusion.h"
3453+#include "tools/optimizer/fusion/squeeze_expanddims_fusion.h"
3454
3455 using std::string;
3456 namespace mindspore::lite {
3457@@ -226,7 +228,9 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared
3458                                     std::make_shared<opt::FullConnectedFusion>(),
3459                                     std::make_shared<opt::FullconnectedAddFusion>(),
3460                                     std::make_shared<opt::TensorDotFusion>(),
3461-                                    std::make_shared<opt::MatMulActivationFusion>(param)};
3462+                                    std::make_shared<opt::MatMulActivationFusion>(param),
3463+                                    std::make_shared<opt::ExpandDimsReshapeFusion>(),
3464+                                    std::make_shared<opt::SqueezeExpandDimsFusion>()};
3465   for (size_t index = 0; index < fusions.size(); index++) {
3466     auto pass_ptr = fusions.at(index);
3467     auto pass_name = pass_ptr->name();
3468diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
3469index 449c6ef9..eaa18d6b 100644
3470--- a/mindspore/lite/tools/converter/converter.cc
3471+++ b/mindspore/lite/tools/converter/converter.cc
3472@@ -236,12 +236,14 @@ schema::MetaGraphT *ConverterImpl::TransferFuncGraph(const std::shared_ptr<Conve
3473     return nullptr;
3474   }
3475
3476-  status = UpdateGraphOutputName(meta_graph);
3477-  if (status != RET_OK) {
3478-    MS_LOG(ERROR) << "UpdateGraphOutputName failed.";
3479-    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
3480-    delete meta_graph;
3481-    return nullptr;
3482+  if (!param->train_model) {
3483+    status = UpdateGraphOutputName(meta_graph);
3484+    if (status != RET_OK) {
3485+      MS_LOG(ERROR) << "UpdateGraphOutputName failed.";
3486+      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
3487+      delete meta_graph;
3488+      return nullptr;
3489+    }
3490   }
3491
3492   return meta_graph;
3493diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc
3494index 09f7366e..955e346d 100644
3495--- a/mindspore/lite/tools/converter/graphdef_transform.cc
3496+++ b/mindspore/lite/tools/converter/graphdef_transform.cc
3497@@ -27,6 +27,7 @@
3498 #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
3499 #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h"
3500 #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
3501+#include "tools/converter/legacy_optimizer/graph/node_name_pass.h"
3502 #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
3503 #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h"
3504 #include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h"
3505@@ -169,6 +170,9 @@ int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> &param) {
3506     Optimizer forming_model_optimizer;
3507     forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type));
3508     forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass(param));
3509+    if (param->train_model) {
3510+      forming_model_optimizer.AddPass(new (std::nothrow) NodeNamePass());
3511+    }
3512     forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass());
3513     forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(param->weight_fp16));
3514     status = forming_model_optimizer.Run(graph_defT_);
3515diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
3516index f3e245f4..c274b7db 100644
3517--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
3518+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
3519@@ -9,6 +9,7 @@ file(GLOB GRAPH_PASS
3520         ${CMAKE_CURRENT_SOURCE_DIR}/convert_fp32_to_fp16_pass.cc
3521         ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc
3522         ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc
3523+        ${CMAKE_CURRENT_SOURCE_DIR}/node_name_pass.cc
3524         ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc
3525         ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc
3526         )
3527diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc
3528new file mode 100644
3529index 00000000..712927b0
3530--- /dev/null
3531+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc
3532@@ -0,0 +1,96 @@
3533+/**
3534+ * Copyright 2022 Huawei Technologies Co., Ltd
3535+ *
3536+ * Licensed under the Apache License, Version 2.0 (the "License");
3537+ * you may not use this file except in compliance with the License.
3538+ * You may obtain a copy of the License at
3539+ *
3540+ * http://www.apache.org/licenses/LICENSE-2.0
3541+ *
3542+ * Unless required by applicable law or agreed to in writing, software
3543+ * distributed under the License is distributed on an "AS IS" BASIS,
3544+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3545+ * See the License for the specific language governing permissions and
3546+ * limitations under the License.
3547+ */
3548+
3549+#include "tools/converter/legacy_optimizer/graph/node_name_pass.h"
3550+#include <string>
3551+#include <vector>
3552+#include "tools/converter/converter_context.h"
3553+
3554+namespace mindspore::lite {
3555+std::string CutShortName(const std::string &fullname, const std::string &delimiter) {
3556+  size_t end_pos = fullname.find_last_of(delimiter);
3557+  std::string name = "";
3558+  if (end_pos != std::string::npos) {
3559+    name = fullname.substr(end_pos + 1);
3560+  }
3561+  if ((fullname.find("op") != std::string::npos) && (name.find("op") == std::string::npos) &&
3562+      (end_pos != std::string::npos)) {
3563+    size_t pos = fullname.rfind(delimiter, end_pos - 1);
3564+    if (pos != std::string::npos) {
3565+      name.insert(0, fullname.substr(pos + 1, end_pos - pos));
3566+    } else {
3567+      name.insert(0, fullname.substr(0, end_pos + 1));
3568+    }
3569+  }
3570+
3571+  const std::vector<std::string> loss_names = {"loss_fct", "_loss_fn", "SigmoidCrossEntropy"};
3572+  for (auto &s : loss_names) {
3573+    if (fullname.find(s) != std::string::npos) {
3574+      name.insert(0, s + "/");
3575+      break;
3576+    }
3577+  }
3578+
3579+  if (fullname.find("Gradients") != std::string::npos) {
3580+    size_t pos = fullname.find(delimiter);
3581+    if (pos != std::string::npos) {
3582+      name.insert(0, fullname.substr(0, pos + 1));
3583+    }
3584+  }
3585+  return name;
3586+}
3587+
3588+STATUS NodeNamePass::Run(schema::MetaGraphT *graph) {
3589+  if (graph == nullptr) {
3590+    MS_LOG(ERROR) << "graph is nullptr";
3591+    return RET_NULL_PTR;
3592+  }
3593+
3594+  std::string delimiter = "/";
3595+  for (auto &node : graph->nodes) {
3596+    if (node == nullptr || node->primitive == nullptr) {
3597+      MS_LOG(ERROR) << "node or node->primitive is nullptr";
3598+      return RET_NULL_PTR;
3599+    }
3600+    std::string node_name = CutShortName(node->name, delimiter);
3601+    node->name = node_name != "" ? node_name : node->name;
3602+
3603+    for (int i = 0; i < static_cast<int>(node->inputIndex.size()); i++) {
3604+      auto tensor_id = node->inputIndex.at(i);
3605+      auto &tensor = graph->allTensors.at(tensor_id);
3606+      if (tensor->name.empty()) {
3607+        MS_LOG(DEBUG) << "input tensor (id = " << tensor_id << ") name is null";
3608+        tensor->name = node->name + "/input-" + std::to_string(i);
3609+      } else {
3610+        std::string in_tensor_name = CutShortName(tensor->name, delimiter);
3611+        tensor->name = in_tensor_name != "" ? in_tensor_name : tensor->name;
3612+      }
3613+    }
3614+
3615+    for (int i = 0; i < static_cast<int>(node->outputIndex.size()); i++) {
3616+      auto tensor_id = node->outputIndex.at(i);
3617+      auto &tensor = graph->allTensors.at(tensor_id);
3618+      if (tensor->name.empty()) {
3619+        tensor->name = node->name + "/output-" + std::to_string(i);
3620+      } else {
3621+        std::string out_tensor_name = CutShortName(tensor->name, delimiter);
3622+        tensor->name = out_tensor_name != "" ? out_tensor_name : tensor->name;
3623+      }
3624+    }
3625+  }
3626+  return RET_OK;
3627+}
3628+}  // namespace mindspore::lite
3629\ No newline at end of file
3630diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h
3631new file mode 100644
3632index 00000000..4e58e5c7
3633--- /dev/null
3634+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h
3635@@ -0,0 +1,35 @@
3636+/**
3637+ * Copyright 2022 Huawei Technologies Co., Ltd
3638+ *
3639+ * Licensed under the Apache License, Version 2.0 (the "License");
3640+ * you may not use this file except in compliance with the License.
3641+ * You may obtain a copy of the License at
3642+ *
3643+ * http://www.apache.org/licenses/LICENSE-2.0
3644+ *
3645+ * Unless required by applicable law or agreed to in writing, software
3646+ * distributed under the License is distributed on an "AS IS" BASIS,
3647+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3648+ * See the License for the specific language governing permissions and
3649+ * limitations under the License.
3650+ */
3651+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_
3652+#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_
3653+
3654+#include <memory>
3655+#include "tools/converter/optimizer.h"
3656+#include "tools/common/graph_util.h"
3657+
3658+namespace mindspore {
3659+namespace lite {
3660+class NodeNamePass : public GraphPass {
3661+ public:
3662+  NodeNamePass() {}
3663+
3664+  ~NodeNamePass() override = default;
3665+
3666+  STATUS Run(schema::MetaGraphT *graph) override;
3667+};
3668+}  // namespace lite
3669+}  // namespace mindspore
3670+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_NODE_NAME_PASS_H_
3671diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
3672index 04e856a8..edd6a538 100644
3673--- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
3674+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
3675@@ -316,6 +316,16 @@ const void *OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_t
3676       *data_size = data_count * sizeof(uint8_t);
3677       onnx_data = onnx_const_tensor.raw_data().data();
3678       break;
3679+    case kNumberTypeFloat16:
3680+      if (INT_MUL_OVERFLOW_THRESHOLD(data_count, sizeof(uint16_t), SIZE_MAX)) {
3681+        MS_LOG(ERROR) << "data_size overflow";
3682+        return nullptr;
3683+      }
3684+      *data_size = data_count * sizeof(uint16_t);
3685+      if (!onnx_const_tensor.raw_data().empty()) {
3686+        onnx_data = onnx_const_tensor.raw_data().data();
3687+      }
3688+      break;
3689     default:
3690       MS_LOG(ERROR) << "unsupported data type " << data_type;
3691       return nullptr;
3692diff --git a/mindspore/lite/tools/lite_exporter/anf_exporter.cc b/mindspore/lite/tools/lite_exporter/anf_exporter.cc
3693index e4063edc..6c4d2544 100644
3694--- a/mindspore/lite/tools/lite_exporter/anf_exporter.cc
3695+++ b/mindspore/lite/tools/lite_exporter/anf_exporter.cc
3696@@ -798,6 +798,17 @@ int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, cons
3697 int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
3698                                        const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
3699                                        schema::CNodeT *op_node) {
3700+  MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "cnode is nullptr");
3701+  MS_CHECK_TRUE_MSG(primitive != nullptr, RET_NULL_PTR, "primitive is nullptr");
3702+  MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
3703+  MS_CHECK_TRUE_MSG(op_node != nullptr, RET_NULL_PTR, "op_node is nullptr");
3704+  auto value_node = cnode->input(index)->cast<ValueNodePtr>();
3705+  MS_ASSERT(value_node != nullptr);
3706+  auto key = std::make_pair(value_node, 0);
3707+  if (node_id_map_.find(key) != node_id_map_.end()) {
3708+    op_node->inputIndex.emplace_back(node_id_map_[key]);
3709+    return RET_OK;
3710+  }
3711   DataInfo data_info;
3712   auto status =
3713     FetchDataFromValueNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info, true);
3714@@ -810,13 +821,12 @@ int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, cons
3715   }
3716   auto schema_tensor = std::make_unique<schema::TensorT>();
3717   MS_CHECK_TRUE_MSG(schema_tensor != nullptr, RET_ERROR, "schema is nullptr");
3718-  schema_tensor->name = cnode->input(index)->fullname_with_scope();
3719+  schema_tensor->name = value_node->fullname_with_scope();
3720   schema_tensor->format = static_cast<schema::Format>(data_info.format_);
3721   schema_tensor->dataType = data_info.data_type_;
3722   schema_tensor->dims = data_info.shape_;
3723   schema_tensor->data = data_info.data_;
3724
3725-  auto key = std::make_pair(cnode->input(index), 0);
3726   node_id_map_[key] = meta_graphT->allTensors.size();
3727   op_node->inputIndex.emplace_back(meta_graphT->allTensors.size());
3728   meta_graphT->allTensors.emplace_back(std::move(schema_tensor));
3729diff --git a/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc b/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc
3730new file mode 100644
3731index 00000000..fb047193
3732--- /dev/null
3733+++ b/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc
3734@@ -0,0 +1,73 @@
3735+/**
3736+ * Copyright 2022 Huawei Technologies Co., Ltd
3737+ *
3738+ * Licensed under the Apache License, Version 2.0 (the "License");
3739+ * you may not use this file except in compliance with the License.
3740+ * You may obtain a copy of the License at
3741+ *
3742+ * http://www.apache.org/licenses/LICENSE-2.0
3743+ *
3744+ * Unless required by applicable law or agreed to in writing, software
3745+ * distributed under the License is distributed on an "AS IS" BASIS,
3746+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3747+ * See the License for the specific language governing permissions and
3748+ * limitations under the License.
3749+ */
3750+
3751+#define USE_DEPRECATED_API
3752+#include "tools/optimizer/fusion/expanddims_reshape_fusion.h"
3753+#include "tools/lite_exporter/fetch_content.h"
3754+#include "ops/op_utils.h"
3755+#include "ops/reshape.h"
3756+#include "tools/optimizer/common/gllo_utils.h"
3757+#include "nnacl/op_base.h"
3758+#include "include/registry/converter_context.h"
3759+
3760+namespace mindspore::opt {
3761+const BaseRef ExpandDimsReshapeFusion::DefinePattern() const {
3762+  auto is_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
3763+  MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
3764+  auto reshape_shape = std::make_shared<Var>();
3765+  MS_CHECK_TRUE_RET(reshape_shape != nullptr, {});
3766+  auto is_expanddims = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimExpandDims>);
3767+  MS_CHECK_TRUE_RET(is_expanddims != nullptr, {});
3768+  return VectorRef({is_reshape, is_expanddims, reshape_shape});
3769+}
3770+
3771+bool ExpandDimsReshapeFusion::CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
3772+  auto reshape_cnode = node->cast<CNodePtr>();
3773+  MS_CHECK_TRUE_RET(reshape_cnode != nullptr, false);
3774+
3775+  MS_CHECK_TRUE_RET(reshape_cnode->input(SECOND_INPUT) != nullptr, false);
3776+  auto expanddims_cnode = reshape_cnode->input(SECOND_INPUT)->cast<CNodePtr>();
3777+  MS_CHECK_TRUE_RET(expanddims_cnode != nullptr, false);
3778+  if (IsMultiOutputTensors(func_graph, expanddims_cnode)) {
3779+    return false;
3780+  }
3781+  auto expanddims_primc = GetValueNode<PrimitiveCPtr>(expanddims_cnode->input(0));
3782+  MS_CHECK_TRUE_RET(expanddims_primc != nullptr, false);
3783+  if (IsQuantParameterNode(expanddims_primc)) {
3784+    MS_LOG(INFO) << expanddims_primc->name() << " is quant node";
3785+    return false;
3786+  }
3787+  return true;
3788+}
3789+
3790+const AnfNodePtr ExpandDimsReshapeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
3791+                                                  const EquivPtr &equiv) const {
3792+  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
3793+    return nullptr;
3794+  }
3795+
3796+  if (!CheckCanFuse(func_graph, node)) {
3797+    return nullptr;
3798+  }
3799+
3800+  auto reshape_cnode = node->cast<CNodePtr>();
3801+  auto expanddims_cnode = reshape_cnode->input(SECOND_INPUT)->cast<CNodePtr>();
3802+  auto manage = Manage(func_graph);
3803+  MS_CHECK_TRUE_RET(manage != nullptr, nullptr);
3804+  manage->SetEdge(reshape_cnode, C1NUM, expanddims_cnode->input(SECOND_INPUT));
3805+  return reshape_cnode;
3806+}
3807+}  // namespace mindspore::opt
3808diff --git a/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.h b/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.h
3809new file mode 100644
3810index 00000000..09475591
3811--- /dev/null
3812+++ b/mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.h
3813@@ -0,0 +1,40 @@
3814+/**
3815+ * Copyright 2022 Huawei Technologies Co., Ltd
3816+ *
3817+ * Licensed under the Apache License, Version 2.0 (the "License");
3818+ * you may not use this file except in compliance with the License.
3819+ * You may obtain a copy of the License at
3820+ *
3821+ * http://www.apache.org/licenses/LICENSE-2.0
3822+ *
3823+ * Unless required by applicable law or agreed to in writing, software
3824+ * distributed under the License is distributed on an "AS IS" BASIS,
3825+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3826+ * See the License for the specific language governing permissions and
3827+ * limitations under the License.
3828+ */
3829+
3830+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_EXPANDDIMS_RESHAPE_FUSION_H_
3831+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_EXPANDDIMS_RESHAPE_FUSION_H_
3832+
3833+#include <string>
3834+#include <memory>
3835+#include "utils/check_convert_utils.h"
3836+#include "backend/common/optimizer/optimizer.h"
3837+
3838+namespace mindspore {
3839+namespace opt {
3840+class ExpandDimsReshapeFusion : public PatternProcessPass {
3841+ public:
3842+  explicit ExpandDimsReshapeFusion(bool multigraph = true, const std::string &name = "ExpandDimsReshapeFusion")
3843+      : PatternProcessPass(name, multigraph) {}
3844+  ~ExpandDimsReshapeFusion() override = default;
3845+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
3846+  const BaseRef DefinePattern() const override;
3847+
3848+ private:
3849+  bool CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
3850+};
3851+}  // namespace opt
3852+}  // namespace mindspore
3853+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_EXPANDDIMS_RESHAPE_FUSION_H_
3854diff --git a/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc b/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc
3855new file mode 100644
3856index 00000000..daa8ac64
3857--- /dev/null
3858+++ b/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc
3859@@ -0,0 +1,117 @@
3860+/**
3861+ * Copyright 2022 Huawei Technologies Co., Ltd
3862+ *
3863+ * Licensed under the Apache License, Version 2.0 (the "License");
3864+ * you may not use this file except in compliance with the License.
3865+ * You may obtain a copy of the License at
3866+ *
3867+ * http://www.apache.org/licenses/LICENSE-2.0
3868+ *
3869+ * Unless required by applicable law or agreed to in writing, software
3870+ * distributed under the License is distributed on an "AS IS" BASIS,
3871+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3872+ * See the License for the specific language governing permissions and
3873+ * limitations under the License.
3874+ */
3875+
3876+#define USE_DEPRECATED_API
3877+#include "tools/optimizer/fusion/squeeze_expanddims_fusion.h"
3878+#include <vector>
3879+#include "tools/lite_exporter/fetch_content.h"
3880+#include "ops/op_utils.h"
3881+#include "ops/squeeze.h"
3882+#include "tools/optimizer/common/gllo_utils.h"
3883+#include "nnacl/op_base.h"
3884+#include "include/registry/converter_context.h"
3885+
3886+namespace mindspore::opt {
3887+const BaseRef SqueezeExpandDimsFusion::DefinePattern() const {
3888+  auto is_expanddims = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimExpandDims>);
3889+  MS_CHECK_TRUE_RET(is_expanddims != nullptr, {});
3890+  auto ex_shape = std::make_shared<Var>();
3891+  MS_CHECK_TRUE_RET(ex_shape != nullptr, {});
3892+  auto is_squeeze = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqueeze>);
3893+  MS_CHECK_TRUE_RET(is_squeeze != nullptr, {});
3894+  return VectorRef({is_expanddims, is_squeeze, ex_shape});
3895+}
3896+
3897+bool SqueezeExpandDimsFusion::CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
3898+  auto expanddims_cnode = node->cast<CNodePtr>();
3899+  MS_CHECK_TRUE_RET(expanddims_cnode != nullptr, false);
3900+  MS_CHECK_TRUE_RET(expanddims_cnode->input(SECOND_INPUT) != nullptr, false);
3901+  auto squeeze_cnode = expanddims_cnode->input(SECOND_INPUT)->cast<CNodePtr>();
3902+  MS_CHECK_TRUE_RET(squeeze_cnode != nullptr, false);
3903+  if (IsMultiOutputTensors(func_graph, squeeze_cnode)) {
3904+    return false;
3905+  }
3906+  auto squeeze_primitive = GetValueNode<PrimitiveCPtr>(squeeze_cnode->input(0));
3907+  MS_CHECK_TRUE_RET(squeeze_primitive != nullptr, false);
3908+  MS_CHECK_TRUE_RET(!IsQuantParameterNode(squeeze_primitive), false);
3909+
3910+  MS_CHECK_TRUE_RET(expanddims_cnode->input(THIRD_INPUT) != nullptr, false);
3911+  lite::DataInfo data_info;
3912+  if (lite::FetchConstData(expanddims_cnode, THIRD_INPUT, converter::kFmkTypeMs, &data_info, false) != lite::RET_OK) {
3913+    return false;
3914+  }
3915+  if ((data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) ||
3916+      data_info.data_.size() != C4NUM) {
3917+    return false;
3918+  }
3919+  auto expanddims_axis = *reinterpret_cast<int *>(data_info.data_.data());
3920+
3921+  auto squeeze_prim = api::MakeShared<mindspore::ops::Squeeze>(squeeze_primitive);
3922+  MS_CHECK_TRUE_RET(squeeze_prim != nullptr, false);
3923+  auto squeeze_axises = squeeze_prim->get_axis();
3924+  MS_CHECK_TRUE_RET(squeeze_axises.size() < DIMENSION_2D, false);
3925+  int64_t squeeze_axis;
3926+  if (squeeze_axises.empty()) {
3927+    squeeze_axis = INT64_MIN;
3928+  } else {
3929+    squeeze_axis = squeeze_axises.at(C0NUM);
3930+  }
3931+  if (squeeze_axis == expanddims_axis) {
3932+    return true;
3933+  } else {
3934+    // squeeze_axis or expanddims_axis is less than zero
3935+    MS_CHECK_TRUE_RET(squeeze_cnode->input(SECOND_INPUT) != nullptr, false);
3936+    auto squeeze_abt = squeeze_cnode->input(SECOND_INPUT)->abstract();
3937+    MS_CHECK_TRUE_RET(squeeze_abt != nullptr, false);
3938+    std::vector<int64_t> squeeze_shape;
3939+    if (FetchShapeFromAbstract(squeeze_abt, &squeeze_shape) != lite::RET_OK) {
3940+      return false;
3941+    }
3942+
3943+    auto expanddims_prim = GetCNodePrimitive(expanddims_cnode);
3944+    MS_CHECK_TRUE_RET(expanddims_prim != nullptr, false);
3945+    auto is_inferred =
3946+      expanddims_prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(expanddims_prim->GetAttr(kInferDone));
3947+    MS_CHECK_TRUE_RET(is_inferred, false);
3948+    auto expanddims_abt = expanddims_cnode->abstract();
3949+    MS_CHECK_TRUE_RET(expanddims_abt != nullptr, false);
3950+    std::vector<int64_t> expanddims_shape;
3951+    if (FetchShapeFromAbstract(expanddims_abt, &expanddims_shape) != lite::RET_OK) {
3952+      return false;
3953+    }
3954+    MS_CHECK_TRUE_RET(squeeze_shape == expanddims_shape, false);
3955+  }
3956+  return true;
3957+}
3958+
3959+const AnfNodePtr SqueezeExpandDimsFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
3960+                                                  const EquivPtr &equiv) const {
3961+  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
3962+    return nullptr;
3963+  }
3964+
3965+  if (!CheckCanFuse(func_graph, node)) {
3966+    return nullptr;
3967+  }
3968+
3969+  auto expanddims_cnode = node->cast<CNodePtr>();
3970+  auto squeeze_cnode = expanddims_cnode->input(SECOND_INPUT)->cast<CNodePtr>();
3971+  auto manage = Manage(func_graph);
3972+  MS_CHECK_TRUE_RET(manage != nullptr, nullptr);
3973+  manage->Replace(expanddims_cnode, squeeze_cnode->input(SECOND_INPUT));
3974+  return nullptr;
3975+}
3976+}  // namespace mindspore::opt
3977diff --git a/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.h b/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.h
3978new file mode 100644
3979index 00000000..1173cb0a
3980--- /dev/null
3981+++ b/mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.h
3982@@ -0,0 +1,40 @@
3983+/**
3984+ * Copyright 2022 Huawei Technologies Co., Ltd
3985+ *
3986+ * Licensed under the Apache License, Version 2.0 (the "License");
3987+ * you may not use this file except in compliance with the License.
3988+ * You may obtain a copy of the License at
3989+ *
3990+ * http://www.apache.org/licenses/LICENSE-2.0
3991+ *
3992+ * Unless required by applicable law or agreed to in writing, software
3993+ * distributed under the License is distributed on an "AS IS" BASIS,
3994+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3995+ * See the License for the specific language governing permissions and
3996+ * limitations under the License.
3997+ */
3998+
3999+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SQUEEZE_EXPANDDIMS_FUSION_H_
4000+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SQUEEZE_EXPANDDIMS_FUSION_H_
4001+
4002+#include <string>
4003+#include <memory>
4004+#include "utils/check_convert_utils.h"
4005+#include "backend/common/optimizer/optimizer.h"
4006+
4007+namespace mindspore {
4008+namespace opt {
4009+class SqueezeExpandDimsFusion : public PatternProcessPass {
4010+ public:
4011+  explicit SqueezeExpandDimsFusion(bool multigraph = true, const std::string &name = "SqueezeExpandDimsFusion")
4012+      : PatternProcessPass(name, multigraph) {}
4013+  ~SqueezeExpandDimsFusion() override = default;
4014+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
4015+  const BaseRef DefinePattern() const override;
4016+
4017+ private:
4018+  bool CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
4019+};
4020+}  // namespace opt
4021+}  // namespace mindspore
4022+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SQUEEZE_EXPANDDIMS_FUSION_H_
4023--
40242.17.1
4025
4026