From 98e86575c68939fecc1bc3c80be9fca0e080b7fa Mon Sep 17 00:00:00 2001
From: z00574805
Date: Wed, 24 May 2023 11:04:25 +0800
Subject: [PATCH 1/5] xiaoyi-0001
---
include/api/callback/callback.h | 22 +-
include/api/callback/ckpt_saver.h | 4 +-
include/api/callback/loss_monitor.h | 4 +-
include/api/callback/lr_scheduler.h | 10 +-
include/api/callback/time_monitor.h | 4 +-
include/api/callback/train_accuracy.h | 10 +-
include/api/cfg.h | 25 ++-
include/api/metrics/accuracy.h | 4 +-
include/api/metrics/metrics.h | 7 +-
include/api/model_parallel_runner.h | 4 +-
include/api/net.h | 23 +-
include/api/serialization.h | 20 +-
include/api/types.h | 6 +-
mindspore/ccsrc/cxx_api/serialization.cc | 3 +-
.../device/cpu/kernel/nnacl/fp32/div_fp32.c | 4 -
.../nnacl/fp32_grad/binary_cross_entropy.c | 45 ++--
.../nnacl/fp32_grad/binary_cross_entropy.h | 2 +-
.../fp32_grad/binary_cross_entropy_grad.c | 34 ++-
.../fp32_grad/binary_cross_entropy_grad.h | 2 +-
.../nnacl/infer/binary_cross_entropy_infer.c | 4 +-
.../cpu/kernel/nnacl/infer/common_infer.c | 1 +
mindspore/lite/include/model.h | 7 +-
.../include/registry/opencl_runtime_wrapper.h | 4 +-
.../java/src/main/native/train_config.cpp | 6 +-
mindspore/lite/src/CMakeLists.txt | 22 +-
.../binary_cross_entropy_grad_populate.cc | 45 ----
.../populate/binary_cross_entropy_populate.cc | 45 ----
mindspore/lite/src/common/prim_util.h | 7 +-
mindspore/lite/src/common/tensor_util.h | 2 +-
mindspore/lite/src/extendrt/CMakeLists.txt | 4 +
.../src/extendrt/cxx_api/serialization.cc | 3 +-
.../lite/src/runtime/cxx_api/converters.h | 4 +-
.../src/runtime/cxx_api/model/model_impl.cc | 3 +
.../src/runtime/cxx_api/model/model_impl.h | 6 +-
.../lite/src/runtime/cxx_api/serialization.cc | 31 ++-
.../src/runtime/cxx_api/train/converters.cc | 6 +-
mindspore/lite/src/runtime/infer_manager.h | 10 +-
mindspore/lite/src/runtime/inner_context.h | 2 +-
.../runtime/kernel/cpu/base/argminmax_base.cc | 1 -
.../kernel/cpu/base/arithmetic_base.cc | 1 +
.../kernel/cpu/base/group_convolution_base.cc | 16 +-
.../cpu/base/group_convolution_creator.cc | 14 +-
.../runtime/kernel/cpu/base/reshape_base.cc | 1 +
.../runtime/kernel/cpu/base/strided_slice.cc | 2 +
.../kernel/cpu/fp16/fused_batchnorm_fp16.cc | 8 +-
.../kernel/cpu/fp32/fused_batchnorm_fp32.cc | 16 +-
.../runtime/kernel/cpu/fp32/oneslike_fp32.cc | 52 +++++
.../runtime/kernel/cpu/fp32/oneslike_fp32.h | 46 ++++
.../cpu/fp32_grad/binary_cross_entropy.cc | 120 +++++++++++
.../cpu/fp32_grad/binary_cross_entropy.h | 42 ++++
.../fp32_grad/binary_cross_entropy_grad.cc | 105 +++++++++
.../cpu/fp32_grad/binary_cross_entropy_grad.h | 41 ++++
.../runtime/kernel/gpu/opencl/CMakeLists.txt | 11 +
.../src/runtime/kernel/opencl/CMakeLists.txt | 8 -
mindspore/lite/src/runtime/kernel_exec_util.h | 2 +-
mindspore/lite/src/runtime/kernel_registry.h | 4 +-
mindspore/lite/src/runtime/lite_kernel.h | 4 +-
mindspore/lite/src/runtime/lite_model.h | 4 +-
mindspore/lite/src/runtime/lite_session.cc | 28 ++-
mindspore/lite/src/runtime/lite_session.h | 10 +-
mindspore/lite/src/runtime/weight_decoder.h | 4 +-
mindspore/lite/src/tensor.h | 2 +-
mindspore/lite/src/tensorlist.h | 2 +-
mindspore/lite/src/train/graph_fusion.cc | 4 +
.../train/optimizer/common/fusion_utils.cc | 37 ++++
.../src/train/optimizer/common/fusion_utils.h | 50 +++++
.../fusion/matmul_activation_fusion_pass.cc | 93 ++++++++
.../fusion/matmul_activation_fusion_pass.h | 42 ++++
.../reshape_gather_reshape_fusion_pass.cc | 148 +++++++++++++
.../reshape_gather_reshape_fusion_pass.h | 42 ++++
mindspore/lite/src/train/train_export.cc | 32 +++
mindspore/lite/src/train/train_export.h | 4 +
.../src/train/train_populate_parameter.cc | 49 +++--
mindspore/lite/src/train/train_session.cc | 80 ++++---
mindspore/lite/src/train/train_session.h | 14 ++
mindspore/lite/src/train/transfer_session.cc | 41 +++-
mindspore/lite/src/train/transfer_session.h | 5 +
.../lite/tools/benchmark_train/net_train.cc | 203 ++++++++++++++++--
.../lite/tools/benchmark_train/net_train.h | 23 +-
.../lite/tools/converter/anf_transform.cc | 6 +-
mindspore/lite/tools/converter/converter.cc | 14 +-
.../tools/converter/graphdef_transform.cc | 4 +
.../legacy_optimizer/graph/CMakeLists.txt | 1 +
.../legacy_optimizer/graph/node_name_pass.cc | 96 +++++++++
.../legacy_optimizer/graph/node_name_pass.h | 35 +++
.../converter/parser/onnx/onnx_node_parser.cc | 10 +
.../lite/tools/lite_exporter/anf_exporter.cc | 14 +-
.../fusion/expanddims_reshape_fusion.cc | 73 +++++++
.../fusion/expanddims_reshape_fusion.h | 40 ++++
.../fusion/squeeze_expanddims_fusion.cc | 117 ++++++++++
.../fusion/squeeze_expanddims_fusion.h | 40 ++++
91 files changed, 1955 insertions(+), 351 deletions(-)
delete mode 100644 mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc
delete mode 100644 mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc
create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc
create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h
create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc
create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h
create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc
create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h
create mode 100644 mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt
delete mode 100644 mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt
create mode 100644 mindspore/lite/src/train/optimizer/common/fusion_utils.cc
create mode 100644 mindspore/lite/src/train/optimizer/common/fusion_utils.h
create mode 100644 mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc
create mode 100644 mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h
create mode 100644 mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
create mode 100644 mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h
create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.cc
create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/graph/node_name_pass.h
create mode 100644 mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc
create mode 100644 mindspore/lite/tools/optimizer/fusion/expanddims_reshape_fusion.h
create mode 100644 mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc
create mode 100644 mindspore/lite/tools/optimizer/fusion/squeeze_expanddims_fusion.h
diff --git a/include/api/callback/callback.h b/include/api/callback/callback.h
index 3332f819..60f30b80 100644
--- a/include/api/callback/callback.h
+++ b/include/api/callback/callback.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
#include <utility>
#include "include/api/data_type.h"
#include "include/api/dual_abi_helper.h"
+#include "include/api/types.h"
namespace mindspore {
class Model;
@@ -31,24 +32,19 @@ class CallbackImpl;
using GraphPoint = std::pair<int, float>;
-struct TrainCallBackData {
- TrainCallBackData(bool train_mode, int epoch, int step, Model *model): train_mode_(train_mode), epoch_(epoch),
- step_(step), model_(model) {}
+struct MS_API TrainCallBackData {
+ TrainCallBackData(bool train_mode, int epoch, int step, Model *model)
+ : train_mode_(train_mode), epoch_(epoch), step_(step), model_(model) {}
bool train_mode_; /**< training mode of LiteSession object */
unsigned int epoch_; /**< the current training epoch (starts at 0) */
unsigned int step_ = 0; /**< the current step within the epoch */
- Model *model_; /**< pointer to the Model object */
+ Model *model_; /**< pointer to the Model object */
};
-enum CallbackRetValue : uint32_t {
- kContinue = 0,
- kStopTraining = 1,
- kExit = 2,
- kUnknownRetValue = 0xFFFFFFFF
-};
+enum CallbackRetValue : uint32_t { kContinue = 0, kStopTraining = 1, kExit = 2, kUnknownRetValue = 0xFFFFFFFF };
-class TrainCallBack {
+class MS_API TrainCallBack {
public:
virtual ~TrainCallBack() = default;
@@ -90,7 +86,7 @@ class TrainCallBack {
protected:
friend class Model;
friend class ModelImpl;
- CallbackImpl* callback_impl_ = nullptr;
+ CallbackImpl *callback_impl_ = nullptr;
};
} // namespace mindspore
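Exporting TrainCallBack with MS_API is what keeps user-defined callbacks linkable once the library is built with hidden symbol visibility (see the CMake hunk later in this patch). A minimal sketch of a custom callback against this header; the EpochEnd() hook is assumed from the elided part of the class, so treat it as illustrative:

```cpp
#include "include/api/callback/callback.h"

// Sketch only: assumes TrainCallBack declares a virtual EpochEnd() hook,
// which the hunk above elides.
class EarlyStop : public mindspore::TrainCallBack {
 public:
  mindspore::CallbackRetValue EpochEnd(const mindspore::TrainCallBackData &cb_data) override {
    // epoch_ starts at 0, so this stops training after ten epochs.
    return (cb_data.epoch_ >= 9) ? mindspore::kStopTraining : mindspore::kContinue;
  }
};
```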
diff --git a/include/api/callback/ckpt_saver.h b/include/api/callback/ckpt_saver.h
index e673c624..d9ff2f69 100644
--- a/include/api/callback/ckpt_saver.h
+++ b/include/api/callback/ckpt_saver.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@
namespace mindspore {
-class CkptSaver: public TrainCallBack {
+class MS_API CkptSaver : public TrainCallBack {
public:
inline CkptSaver(int save_every_n, const std::string &filename_prefix);
virtual ~CkptSaver();
diff --git a/include/api/callback/loss_monitor.h b/include/api/callback/loss_monitor.h
index 9e0a8247..7efd0ca7 100644
--- a/include/api/callback/loss_monitor.h
+++ b/include/api/callback/loss_monitor.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@
namespace mindspore {
-class LossMonitor: public TrainCallBack {
+class MS_API LossMonitor : public TrainCallBack {
public:
explicit LossMonitor(int print_every_n_steps = INT_MAX);
virtual ~LossMonitor();
diff --git a/include/api/callback/lr_scheduler.h b/include/api/callback/lr_scheduler.h
index 2eddc66b..11fd7124 100644
--- a/include/api/callback/lr_scheduler.h
+++ b/include/api/callback/lr_scheduler.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -30,18 +30,18 @@ constexpr int UPDATE_LR = 1;
using LR_Lambda = std::function<int(float *lr, int epoch, void *cb_data)>;
/// \brief Multiply the LR by a factor of gamma every epoch
-int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication);
+MS_API int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication);
/// \brief Multiply the LR by a factor of gamma every step_size
-int StepLRLambda(float *lr, int epoch, void *step_size);
-struct StepLRLambda {
+MS_API int StepLRLambda(float *lr, int epoch, void *step_size);
+struct MS_API StepLRLambda {
StepLRLambda(int step, float g) : step_size(step), gamma(g) {}
int step_size; // period of LR decay
float gamma; // LR decay factor
};
-class LRScheduler: public TrainCallBack {
+class MS_API LRScheduler : public TrainCallBack {
public:
explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step = 1);
virtual ~LRScheduler();
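Since StepLRLambda names both the exported update function and its parameter struct in namespace mindspore, a usage sketch may help; the elaborated `struct` specifier disambiguates the struct from the same-named function:

```cpp
#include "include/api/callback/lr_scheduler.h"

// Sketch: decay the LR by gamma = 0.1 every 30 epochs. The struct holds the
// parameters; the function of the same name performs the update.
struct mindspore::StepLRLambda step_lr(30, 0.1f);
mindspore::LRScheduler scheduler(mindspore::StepLRLambda, static_cast<void *>(&step_lr), /*step=*/1);
```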
diff --git a/include/api/callback/time_monitor.h b/include/api/callback/time_monitor.h
index 7e857849..45e48248 100644
--- a/include/api/callback/time_monitor.h
+++ b/include/api/callback/time_monitor.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@
namespace mindspore {
-class TimeMonitor: public TrainCallBack {
+class MS_API TimeMonitor : public TrainCallBack {
public:
virtual ~TimeMonitor() = default;
void EpochBegin(const TrainCallBackData &cb_data) override;
diff --git a/include/api/callback/train_accuracy.h b/include/api/callback/train_accuracy.h
index 5838dfd9..16774dd7 100644
--- a/include/api/callback/train_accuracy.h
+++ b/include/api/callback/train_accuracy.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -26,12 +26,10 @@
namespace mindspore {
-class TrainAccuracy: public TrainCallBack {
+class MS_API TrainAccuracy : public TrainCallBack {
public:
- explicit TrainAccuracy(int print_every_n = INT_MAX,
- int accuracy_metrics = METRICS_CLASSIFICATION,
- const std::vector<int> &input_indexes = {1},
- const std::vector<int> &output_indexes = {0});
+ explicit TrainAccuracy(int print_every_n = INT_MAX, int accuracy_metrics = METRICS_CLASSIFICATION,
+ const std::vector<int> &input_indexes = {1}, const std::vector<int> &output_indexes = {0});
virtual ~TrainAccuracy();
const std::vector<GraphPoint> &GetAccuracyPoints();
};
diff --git a/include/api/cfg.h b/include/api/cfg.h
index db915cac..8dc37bb4 100644
--- a/include/api/cfg.h
+++ b/include/api/cfg.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2022 Huawei Technologies Co., Ltd
+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@
namespace mindspore {
constexpr int iter_th = 1000;
-class MixPrecisionCfg {
+class MS_API MixPrecisionCfg {
public:
MixPrecisionCfg() {
this->dynamic_loss_scale_ = false;
@@ -49,7 +49,7 @@ class MixPrecisionCfg {
bool is_raw_mix_precision_ = false; /**< Is mix precision model export from mindspore */
};
-class TrainCfg {
+class MS_API TrainCfg {
public:
TrainCfg() = default;
TrainCfg(const TrainCfg &rhs) {
@@ -59,11 +59,24 @@ class TrainCfg {
}
~TrainCfg() = default;
+ /// \brief Obtain part of the name that identifies a loss kernel.
+ ///
+ /// \return loss_name.
+ inline std::vector<std::string> GetLossName() const;
+ /// \brief Set part of the name that identifies a loss kernel.
+ ///
+ /// \param[in] loss_name Part of the name that identifies a loss kernel.
+ inline void SetLossName(const std::vector<std::string> &loss_name);
+
OptimizationLevel optimization_level_ = kO0;
- std::vector<std::string> loss_name_ = {
- "loss_fct", "_loss_fn", "SigmoidCrossEntropy"}; /**< Set part of the name that identify a loss kernel */
- MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */
+ MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */
bool accumulate_gradients_ = false;
+
+ private:
+ std::vector<std::vector<char>> loss_name_ = VectorStringToChar({"loss_fct", "_loss_fn", "SigmoidCrossEntropy"});
};
+
+std::vector<std::string> TrainCfg::GetLossName() const { return VectorCharToString(loss_name_); }
+void TrainCfg::SetLossName(const std::vector<std::string> &loss_name) { loss_name_ = VectorStringToChar(loss_name); }
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CFG_H
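The loss-name list now moves behind dual-ABI accessors: the private member stores std::vector<std::vector<char>>, and the inline GetLossName/SetLossName convert at the call site, so no std::string crosses the library boundary. A minimal usage sketch (the extra kernel-name fragment is hypothetical):

```cpp
#include "include/api/cfg.h"

mindspore::TrainCfg cfg;
// The conversion to and from std::vector<std::vector<char>> happens inline, in the
// caller's binary -- the usual dual-ABI pattern in this API.
std::vector<std::string> loss_names = cfg.GetLossName();
loss_names.push_back("MyCustomLoss");  // hypothetical fragment of a loss-kernel name
cfg.SetLossName(loss_names);
```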
diff --git a/include/api/metrics/accuracy.h b/include/api/metrics/accuracy.h
index 1d1732f3..4aefc3b5 100644
--- a/include/api/metrics/accuracy.h
+++ b/include/api/metrics/accuracy.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ namespace mindspore {
constexpr int METRICS_CLASSIFICATION = 0;
constexpr int METRICS_MULTILABEL = 1;
-class AccuracyMetrics : public Metrics {
+class MS_API AccuracyMetrics : public Metrics {
public:
explicit AccuracyMetrics(int accuracy_metrics = METRICS_CLASSIFICATION, const std::vector<int> &input_indexes = {1},
const std::vector<int> &output_indexes = {0});
diff --git a/include/api/metrics/metrics.h b/include/api/metrics/metrics.h
index 7154332f..36eb4ed1 100644
--- a/include/api/metrics/metrics.h
+++ b/include/api/metrics/metrics.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,16 +24,17 @@ class MetricsImpl;
class ModelImpl;
class MSTensor;
-class Metrics {
+class MS_API Metrics {
public:
virtual ~Metrics() = default;
virtual void Clear() {}
virtual float Eval() { return 0.0; }
virtual void Update(std::vector<MSTensor *> inputs, std::vector<MSTensor *> outputs) {}
+
protected:
friend class Model;
friend class ModelImpl;
- MetricsImpl* metrics_impl_;
+ MetricsImpl *metrics_impl_;
};
} // namespace mindspore
diff --git a/include/api/model_parallel_runner.h b/include/api/model_parallel_runner.h
index 159f4cea..360405b9 100644
--- a/include/api/model_parallel_runner.h
+++ b/include/api/model_parallel_runner.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2022 Huawei Technologies Co., Ltd
+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@
namespace mindspore {
/// \brief The RunnerConfig class is used to store environment variables during execution
/// management.
-class RunnerConfig {
+class MS_API RunnerConfig {
public:
struct Data;
RunnerConfig();
diff --git a/include/api/net.h b/include/api/net.h
index c7a3a9b0..61990ae0 100644
--- a/include/api/net.h
+++ b/include/api/net.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2022 Huawei Technologies Co., Ltd
+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -36,14 +36,14 @@ class NodeSet;
class Graph;
class NetData;
-class NetBase {
+class MS_API NetBase {
public:
NetBase() = default;
virtual std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) = 0;
virtual uint32_t type() = 0;
};
-class Node : public NetBase {
+class MS_API Node : public NetBase {
public:
Node();
virtual ~Node();
@@ -65,7 +65,7 @@ class Node : public NetBase {
std::shared_ptr<NodeImpl> impl_ = nullptr;
};
-class Net : public NetBase, public std::enable_shared_from_this<Net> {
+class MS_API Net : public NetBase, public std::enable_shared_from_this<Net> {
public:
Net();
virtual ~Net();
@@ -116,12 +116,12 @@ class Net : public NetBase, public std::enable_shared_from_this<Net> {
std::shared_ptr<NetImpl> impl_;
};
-class SoftMaxCrossEntropyCfg {
+class MS_API SoftMaxCrossEntropyCfg {
public:
std::string reduction = "mean"; /**< Specifies reduction mode. The optional values are "none", "mean", "sum" */
};
-class AdamConfig {
+class MS_API AdamConfig {
public:
float learning_rate_ = 1e-3;
float beta1_ = 0.9;
@@ -131,11 +131,12 @@ class AdamConfig {
};
namespace NN {
-Net *NetWithLoss(Net *net, Node *loss);
-Graph *GraphWithLoss(Graph *g, Node *loss);
-Node *Adam(std::shared_ptr<Node> learn, const AdamConfig &cfg);
-Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);
-std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32, int fmt = NHWC);
+MS_API Net *NetWithLoss(Net *net, Node *loss);
+MS_API Graph *GraphWithLoss(Graph *g, Node *loss);
+MS_API Node *Adam(std::shared_ptr<Node> learn, const AdamConfig &cfg);
+MS_API Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);
+MS_API std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32,
+ int fmt = NHWC);
}; // namespace NN
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_NET_H
diff --git a/include/api/serialization.h b/include/api/serialization.h
index 2dc9d028..1a0c1f57 100644
--- a/include/api/serialization.h
+++ b/include/api/serialization.h
@@ -79,10 +79,16 @@ class MS_API Serialization {
///
/// \param[in] model The model data.
/// \param[in] model_type The model file type.
- /// \param[out] model_data The model parameter data.
+ /// \param[out] model_data The model buffer.
+ /// \param[in] quantization_type The quantization type.
+ /// \param[in] export_inference_only Whether to export an inference-only model.
+ /// \param[in] output_tensor_name The names of the output tensors of the exported inference model; empty by
+ /// default, which exports the complete inference model.
///
/// \return Status.
- static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
+ inline static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
+ QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
+ const std::vector<std::string> &output_tensor_name = {});
/// \brief Export training model from file.
///
@@ -110,6 +116,9 @@ class MS_API Serialization {
static Status ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
QuantizationType quantization_type, bool export_inference_only,
const std::vector<std::vector<char>> &output_tensor_name);
+ static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
+ QuantizationType quantization_type, bool export_inference_only,
+ const std::vector<std::vector<char>> &output_tensor_name);
};
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
@@ -134,5 +143,12 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
VectorStringToChar(output_tensor_name));
}
+Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
+ QuantizationType quantization_type, bool export_inference_only,
+ const std::vector<std::string> &output_tensor_name) {
+ return ExportModel(model, model_type, model_data, quantization_type, export_inference_only,
+ VectorStringToChar(output_tensor_name));
+}
+
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
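The new overload makes in-memory export symmetric with the existing file-based one. A sketch of the intended call, assuming `model` is a loaded training Model (the runtime implementation later in this patch accepts only kMindIR or kMindIR_Lite):

```cpp
#include "include/api/serialization.h"

mindspore::Buffer buf;
mindspore::Status status = mindspore::Serialization::ExportModel(
    model, mindspore::kMindIR, &buf, mindspore::kNoQuant, /*export_inference_only=*/true);
if (status != mindspore::kSuccess) {
  // handle export failure; buf stays untouched
}
```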
diff --git a/include/api/types.h b/include/api/types.h
index 6cf04523..377b5db0 100644
--- a/include/api/types.h
+++ b/include/api/types.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -350,7 +350,7 @@ std::string MSTensor::Name() const { return CharToString(CharName()); }
void MSTensor::SetTensorName(const std::string &name) { return SetTensorName(StringToChar(name)); }
-using Key = struct Key {
+using Key = struct MS_API Key {
const size_t max_key_len = 32;
size_t len = 0;
unsigned char key[32] = {0};
@@ -371,7 +371,7 @@ struct MSCallBackParam {
using MSKernelCallBack = std::function<bool(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
const MSCallBackParam &opInfo)>;
-std::vector<char> CharVersion();
+MS_API std::vector<char> CharVersion();
inline std::string Version() { return CharToString(CharVersion()); }
} // namespace mindspore
diff --git a/mindspore/ccsrc/cxx_api/serialization.cc b/mindspore/ccsrc/cxx_api/serialization.cc
index 1ea95935..bf33e85d 100644
--- a/mindspore/ccsrc/cxx_api/serialization.cc
+++ b/mindspore/ccsrc/cxx_api/serialization.cc
@@ -334,7 +334,8 @@ Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model
return kMEFailed;
}
-Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
+Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool,
+ const std::vector<std::vector<char>> & /* output_tensor_name */) {
MS_LOG(ERROR) << "Unsupported feature.";
return kMEFailed;
}
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c
index f6fa5994..60a27df1 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/div_fp32.c
@@ -29,10 +29,6 @@ int ElementOptDiv(const float *in0, const float *in1, float *out, int size, cons
out[index] = in0[0] / in1[index];
}
} else {
- if (in1[0] == 0) {
- return NNACL_ERRCODE_DIVISOR_ZERO;
- }
-
SIMD_RUN_NO_SCALAR(ElementOptDivNum1, index, in0, in1, out, size);
for (; index < size; index++) {
out[index] = in0[index] / in1[0];
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c
index 2db54161..cf2f867c 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.c
@@ -18,28 +18,44 @@
#include "nnacl/fp32_grad/binary_cross_entropy.h"
static void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const float *input_x,
- const float *input_y, const float *weight, float *loss, float *tmp_loss) {
+ const float *input_y, const float *weight, float *loss, float *tmp_loss,
+ bool weight_defined) {
const float epsilon = 1e-12;
- if (reduction == 0) {
- for (int i = 0; i < input_size; i++) {
- float value =
- -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
- loss[i] = value;
+
+ if (reduction == Reduction_None) {
+ if (weight_defined) {
+ for (int i = 0; i < input_size; i++) {
+ float value =
+ -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+ loss[i] = value;
+ }
+ } else {
+ for (int i = 0; i < input_size; i++) {
+ float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+ loss[i] = value;
+ }
}
} else {
- for (int i = 0; i < input_size; i++) {
- float value =
- -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
- tmp_loss[i] = value;
+ if (weight_defined) {
+ for (int i = 0; i < input_size; i++) {
+ float value =
+ -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+ tmp_loss[i] = value;
+ }
+ } else {
+ for (int i = 0; i < input_size; i++) {
+ float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+ tmp_loss[i] = value;
+ }
}
}
}
void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y,
- const float *weight, float *loss, float *tmp_loss) {
+ const float *weight, float *loss, float *tmp_loss, bool weight_defined) {
loss[0] = 0.0f;
- BinaryCrossEntropyLossKernel(input_size, reduction, input_x, input_y, weight, loss, tmp_loss);
- if (reduction != 0) {
+ BinaryCrossEntropyLossKernel(input_size, reduction, input_x, input_y, weight, loss, tmp_loss, weight_defined);
+ if (reduction != Reduction_None) {
if (input_size % 2 == 1) {
tmp_loss[0] += tmp_loss[input_size - 1];
}
@@ -47,13 +63,12 @@ void BinaryCrossEntropy(const int input_size, const int reduction, const float *
for (int i = 0; i < stride; i++) {
tmp_loss[i] += tmp_loss[i + stride];
}
-
if (stride > 2 && stride % 2 == 1) {
tmp_loss[0] += tmp_loss[stride - 1];
}
}
loss[0] += tmp_loss[0];
- if (reduction == 1) {
+ if (reduction == Reduction_Mean) {
loss[0] /= input_size;
}
}
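For reference, the kernel evaluates the standard weighted binary cross-entropy; when weight_defined is false the weight is implicitly 1, which is exactly what the new branches implement:

```latex
% Per-element loss; w_i \equiv 1 when weight_defined is false.
\ell_i = -\,w_i\bigl[\,y_i\log(x_i+\epsilon) + (1-y_i)\log(1-x_i+\epsilon)\,\bigr],\qquad \epsilon = 10^{-12}
% Reduction_None writes \ell_i element-wise into loss; Reduction_Sum returns
% \sum_i \ell_i; Reduction_Mean returns \tfrac{1}{n}\sum_i \ell_i (the pairwise
% accumulation over tmp_loss computes the sum).
```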
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h
index 6ba6422d..abf6e63b 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy.h
@@ -28,7 +28,7 @@ extern "C" {
#endif
void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y,
- const float *weight, float *loss, float *tmp_loss);
+ const float *weight, float *loss, float *tmp_loss, bool weight_defined);
#ifdef __cplusplus
}
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c
index 95e28c8c..12d20356 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.c
@@ -19,23 +19,37 @@
#define MAX(a, b) ((a) > (b) ? (a) : (b))
int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y,
- const float *weight, const float *dloss, float *dx) {
+ const float *weight, const float *dloss, float *dx, bool weight_defined) {
const float epsilon = 1e-12f;
- if (reduction == 0) {
- for (int i = 0; i < input_size; i++) {
- float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
- float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
- dx[i] = value * dloss[i];
+ if (reduction == Reduction_None) {
+ if (weight_defined) {
+ for (int i = 0; i < input_size; i++) {
+ float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
+ float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
+ dx[i] = value * dloss[i];
+ }
+ } else {
+ for (int i = 0; i < input_size; i++) {
+ float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
+ float value = (input_x[i] - input_y[i]) / denominator;
+ dx[i] = value * dloss[i];
+ }
}
} else {
float dloss1 = dloss[0];
- if (reduction == 1) {
+ if (reduction == Reduction_Mean) {
dloss1 = dloss[0] / input_size;
}
for (int i = 0; i < input_size; i++) {
- float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
- float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
- dx[i] = value * dloss1;
+ if (weight_defined) {
+ float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
+ float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
+ dx[i] = value * dloss1;
+ } else {
+ float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
+ float value = (input_x[i] - input_y[i]) / denominator;
+ dx[i] = value * dloss1;
+ }
}
}
return 0;
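The gradient kernel follows the same pattern; differentiating the binary cross-entropy above with respect to the prediction gives, per element:

```latex
% w_i \equiv 1 when weight_defined is false; \epsilon = 10^{-12} guards the denominator.
\frac{\partial \ell}{\partial x_i} = w_i\,\frac{x_i - y_i}{\max\bigl(x_i(1-x_i),\ \epsilon\bigr)}\cdot d_i,
\qquad d_i = \begin{cases} dloss_i & \text{Reduction\_None} \\ dloss_0 & \text{Reduction\_Sum} \\ dloss_0 / n & \text{Reduction\_Mean} \end{cases}
```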
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h
index f3506f4f..3033fa98 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/binary_cross_entropy_grad.h
@@ -28,7 +28,7 @@ extern "C" {
#endif
int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y,
- const float *weight, const float *dloss, float *dx);
+ const float *weight, const float *dloss, float *dx, bool weight_defined);
#ifdef __cplusplus
}
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c
index e280ad2e..22e207ac 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/binary_cross_entropy_infer.c
@@ -27,8 +27,8 @@ int BinaryCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_siz
TensorC *out = outputs[0];
SetDataTypeFormat(out, x);
BinaryCrossEntropyParameter *param = (BinaryCrossEntropyParameter *)parameter;
- int reduction = param->reduction;
- if (reduction == 1 || reduction == 2) {
+ ReductionType reduction = (ReductionType)(param->reduction);
+ if (reduction == Reduction_Mean || reduction == Reduction_Sum) {
out->shape_size_ = 1;
out->shape_[0] = 1;
} else {
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c
index 3073385f..875c3bc0 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c
@@ -237,6 +237,7 @@ REG_INFER(LRN, PrimType_LRN, CommonInferShapeWithNHWC)
REG_INFER(L2Normalize, PrimType_L2NormalizeFusion, CommonInferShape)
REG_INFER(Neg, PrimType_Neg, CommonInferShape)
REG_INFER(NegGrad, PrimType_NegGrad, CommonGradInferShape)
+REG_INFER(OnesLike, PrimType_OnesLike, CommonInferShape)
REG_INFER(PowerGrad, PrimType_PowerGrad, CommonGradInferShape)
REG_INFER(PReLU, PrimType_PReLUFusion, CommonInferShape)
REG_INFER(Reciprocal, PrimType_Reciprocal, CommonInferShape)
diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h
index a54904c8..a635e745 100644
--- a/mindspore/lite/include/model.h
+++ b/mindspore/lite/include/model.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
#include <string>
#include <vector>
#include <memory>
+#include "include/api/types.h"
namespace mindspore {
namespace schema {
@@ -30,7 +31,7 @@ typedef enum { ModelType_MSLite, ModelType_MindIR } LiteModelType;
// LiteGraph can be considered as a light weight and subset of FuncGraph, it can not support the advanced expression of
// FuncGraph, e.g., non-tail recursive.
-struct LiteGraph {
+struct MS_API LiteGraph {
struct Node {
std::string name_;
std::string op_type_;
@@ -66,7 +67,7 @@ struct LiteGraph {
std::string ToString() const;
};
-struct Model {
+struct MS_API Model {
LiteGraph graph_;
char *buf = nullptr;
size_t buf_size_ = 0;
diff --git a/mindspore/lite/include/registry/opencl_runtime_wrapper.h b/mindspore/lite/include/registry/opencl_runtime_wrapper.h
index 5b95a8a3..b55554e4 100644
--- a/mindspore/lite/include/registry/opencl_runtime_wrapper.h
+++ b/mindspore/lite/include/registry/opencl_runtime_wrapper.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -30,7 +30,7 @@
#include "include/api/dual_abi_helper.h"
namespace mindspore::registry::opencl {
-class OpenCLRuntimeWrapper {
+class MS_API OpenCLRuntimeWrapper {
public:
OpenCLRuntimeWrapper() = default;
~OpenCLRuntimeWrapper() = default;
diff --git a/mindspore/lite/java/src/main/native/train_config.cpp b/mindspore/lite/java/src/main/native/train_config.cpp
index 4177e96b..d2452acf 100644
--- a/mindspore/lite/java/src/main/native/train_config.cpp
+++ b/mindspore/lite/java/src/main/native/train_config.cpp
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -50,7 +50,9 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_TrainCfg_createTrai
return (jlong) nullptr;
}
if (loss_name != nullptr) {
- traincfg_ptr->loss_name_.emplace_back(env->GetStringUTFChars(loss_name, JNI_FALSE));
+ std::vector<std::string> traincfg_loss_name = traincfg_ptr->GetLossName();
+ traincfg_loss_name.emplace_back(env->GetStringUTFChars(loss_name, JNI_FALSE));
+ traincfg_ptr->SetLossName(traincfg_loss_name);
}
traincfg_ptr->optimization_level_ = ol;
traincfg_ptr->accumulate_gradients_ = accmulateGrads;
diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
index 16ae2e63..48e0fe7c 100644
--- a/mindspore/lite/src/CMakeLists.txt
+++ b/mindspore/lite/src/CMakeLists.txt
@@ -50,6 +50,11 @@ if(APPLE OR PLATFORM_ARM32 OR PLATFORM_ARM64)
-fdata-sections -ffast-math -fno-rtti -fno-exceptions -Wno-shorten-64-to-32 \
-fno-aligned-allocation -DTARGET_OS_OSX")
endif()
+ if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND NOT MSLITE_ENABLE_TESTCASES)
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility-inlines-hidden -fvisibility=hidden")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility-inlines-hidden -fvisibility=hidden")
+ set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--gc-sections")
+ endif()
elseif(NOT MSVC)
if("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing -ffunction-sections \
@@ -312,16 +317,6 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/runtime/weight_decoder.cc
)
-if(MSLITE_GPU_BACKEND STREQUAL opencl)
- file(GLOB_RECURSE OPENCL_RUNTIME_SRC
- ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/gpu/opencl/*.cc
- )
- set(LITE_SRC
- ${LITE_SRC}
- ${OPENCL_RUNTIME_SRC}
- )
-endif()
-
if(MSLITE_GPU_BACKEND STREQUAL cuda)
file(GLOB CUDA_RUNTIME_SRC
${CMAKE_CURRENT_SOURCE_DIR}/runtime/gpu/*.cc
@@ -384,6 +379,9 @@ set(TRAIN_SRC
${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/common/fusion_utils.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_activation_fusion_pass.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc
${TOOLS_DIR}/converter/optimizer.cc
${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc
@@ -527,7 +525,7 @@ else()
endif()
if(MSLITE_GPU_BACKEND STREQUAL opencl)
- add_subdirectory(runtime/kernel/opencl)
+ add_subdirectory(runtime/kernel/gpu/opencl)
target_link_libraries(mindspore-lite opencl_kernel_mid)
target_link_libraries(mindspore-lite_static opencl_kernel_mid)
endif()
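These hidden-visibility flags are the reason MS_API annotations are added across the headers in this patch: once symbols default to local, everything the public API (and the trainer/benchmark tools) links against must be exported explicitly. The macro itself is defined in include/api/types.h, not in this diff; it conventionally expands along these lines (a sketch, not the verbatim definition):

```cpp
// Typical shape of the visibility macro this patch relies on.
#ifdef _WIN32
#define MS_API __declspec(dllexport)
#else
#define MS_API __attribute__((visibility("default")))
#endif
```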
diff --git a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc b/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc
deleted file mode 100644
index 5da193bc..00000000
--- a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
-using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad;
-
-namespace mindspore {
-namespace lite {
-OpParameter *PopulateBinaryCrossEntropyGradParameter(const void *prim) {
- auto *primitive = static_cast(prim);
- MS_ASSERT(primitive != nullptr);
- auto value = primitive->value_as_BinaryCrossEntropyGrad();
- if (value == nullptr) {
- MS_LOG(ERROR) << "param is nullptr";
- return nullptr;
- }
-
- auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
- if (param == nullptr) {
- MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
- return nullptr;
- }
- memset(param, 0, sizeof(BinaryCrossEntropyGradParameter));
-
- param->op_parameter_.type_ = primitive->value_type();
- param->reduction = value->reduction();
- return reinterpret_cast<OpParameter *>(param);
-}
-
-REG_POPULATE(PrimitiveType_BinaryCrossEntropyGrad, PopulateBinaryCrossEntropyGradParameter, SCHEMA_CUR);
-} // namespace lite
-} // namespace mindspore
diff --git a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc b/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc
deleted file mode 100644
index 10060d3f..00000000
--- a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32_grad/binary_cross_entropy.h"
-using mindspore::schema::PrimitiveType_BinaryCrossEntropy;
-
-namespace mindspore {
-namespace lite {
-OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) {
- auto primitive = static_cast<const schema::Primitive *>(prim);
- MS_ASSERT(primitive != nullptr);
- auto value = primitive->value_as_BinaryCrossEntropy();
- if (value == nullptr) {
- MS_LOG(ERROR) << "value is nullptr";
- return nullptr;
- }
-
- auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
- if (param == nullptr) {
- MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
- return nullptr;
- }
- memset(param, 0, sizeof(BinaryCrossEntropyParameter));
-
- param->op_parameter_.type_ = primitive->value_type();
- param->reduction = value->reduction();
- return reinterpret_cast<OpParameter *>(param);
-}
-
-REG_POPULATE(PrimitiveType_BinaryCrossEntropy, PopulateBinaryCrossEntropyParameter, SCHEMA_CUR);
-} // namespace lite
-} // namespace mindspore
diff --git a/mindspore/lite/src/common/prim_util.h b/mindspore/lite/src/common/prim_util.h
index 2714b6df..38733f82 100644
--- a/mindspore/lite/src/common/prim_util.h
+++ b/mindspore/lite/src/common/prim_util.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,13 +16,14 @@
#ifndef MINDSPORE_LITE_SRC_COMMON_PRIM_UTIL_H_
#define MINDSPORE_LITE_SRC_COMMON_PRIM_UTIL_H_
+#include "include/api/types.h"
namespace mindspore {
namespace lite {
-int GetPrimitiveType(const void *primitive, int schema_version);
+MS_API int GetPrimitiveType(const void *primitive, int schema_version);
const char *GetPrimitiveTypeName(const void *primitive, int schema_version);
const char *PrimitiveCurVersionTypeName(int type);
-int GenPrimVersionKey(int primitive_type, int schema_version);
+MS_API int GenPrimVersionKey(int primitive_type, int schema_version);
bool IsPartialNode(const void *primitive, int schema_version);
bool IsCallNode(const void *primitive, int schema_version);
bool IsSwitchNode(const void *primitive, int schema_version);
diff --git a/mindspore/lite/src/common/tensor_util.h b/mindspore/lite/src/common/tensor_util.h
index 6e8ac3af..caced545 100644
--- a/mindspore/lite/src/common/tensor_util.h
+++ b/mindspore/lite/src/common/tensor_util.h
@@ -41,7 +41,7 @@ int GenerateInTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<TensorC *> *in_tensor_c, std::shared_ptr<Allocator> allocator = nullptr);
int GenerateOutTensorC(const OpParameter *const parameter, const std::vector<lite::Tensor *> &outputs,
std::vector<TensorC *> *out_tensor_c, std::shared_ptr<Allocator> allocator = nullptr);
-int CheckTensorsInvalid(const std::vector<Tensor *> &tensors);
+MS_API int CheckTensorsInvalid(const std::vector<Tensor *> &tensors);
int CheckGraphInputShapes(const std::vector<Tensor *> &inputs,
const std::unordered_map<Tensor *, std::vector<int>> &input_shape_map);
std::vector<mindspore::MSTensor> LiteTensorsToMSTensors(const std::vector<lite::Tensor *> &lite_tensors);
diff --git a/mindspore/lite/src/extendrt/CMakeLists.txt b/mindspore/lite/src/extendrt/CMakeLists.txt
index 4f43b01f..70d734be 100644
--- a/mindspore/lite/src/extendrt/CMakeLists.txt
+++ b/mindspore/lite/src/extendrt/CMakeLists.txt
@@ -1,3 +1,7 @@
+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
+string(REPLACE "-fvisibility=hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+string(REPLACE "-fvisibility=hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
set(MODEL_LOADER_FRAMEWORK_SRC
${MODEL_LOADER_FRAMEWORK_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/mindir_loader/model_loader.cc
diff --git a/mindspore/lite/src/extendrt/cxx_api/serialization.cc b/mindspore/lite/src/extendrt/cxx_api/serialization.cc
index 344cfca7..c1e3d065 100644
--- a/mindspore/lite/src/extendrt/cxx_api/serialization.cc
+++ b/mindspore/lite/src/extendrt/cxx_api/serialization.cc
@@ -332,7 +332,8 @@ Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model
return kMEFailed;
}
-Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
+Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool,
+ const std::vector<std::vector<char>> & /* output_tensor_name */) {
MS_LOG(ERROR) << "Unsupported feature.";
return kMEFailed;
}
diff --git a/mindspore/lite/src/runtime/cxx_api/converters.h b/mindspore/lite/src/runtime/cxx_api/converters.h
index bd7daabb..45ed6a5b 100644
--- a/mindspore/lite/src/runtime/cxx_api/converters.h
+++ b/mindspore/lite/src/runtime/cxx_api/converters.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,7 +27,7 @@
#include "src/runtime/c_api/context_c.h"
namespace mindspore {
-class ContextUtils {
+class MS_API ContextUtils {
public:
static lite::InnerContext *Convert(Context *context);
static lite::InnerContext *Convert(const ContextC *context_c);
diff --git a/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc b/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc
index e7a0e272..f5f275e4 100644
--- a/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc
+++ b/mindspore/lite/src/runtime/cxx_api/model/model_impl.cc
@@ -717,6 +717,9 @@ Status ModelImpl::UpdateWeights(const std::vector &new_weights) {
inner_weights[i] = lite_impl->lite_tensor();
}
auto ret = session_->UpdateWeights(inner_weights);
+ if (ret != kSuccess) {
+ MS_LOG(ERROR) << "UpdateWeights failed, and the origin weights may have been changed.";
+ }
return static_cast(ret);
}
diff --git a/mindspore/lite/src/runtime/cxx_api/model/model_impl.h b/mindspore/lite/src/runtime/cxx_api/model/model_impl.h
index 3d359f14..5c572883 100644
--- a/mindspore/lite/src/runtime/cxx_api/model/model_impl.h
+++ b/mindspore/lite/src/runtime/cxx_api/model/model_impl.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -47,10 +47,10 @@ namespace mindspore {
typedef std::shared_ptr<session::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data,
std::shared_ptr<TrainCfg> cfg,
lite::InnerContext *context);
-CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);
+MS_API CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);
using ExpressionLoader = std::function<Status(const char *, Graph *)>;
-ExpressionLoader CreateExpressionLoader(ExpressionLoader loader = nullptr);
+MS_API ExpressionLoader CreateExpressionLoader(ExpressionLoader loader = nullptr);
namespace session {
class Metrics;
diff --git a/mindspore/lite/src/runtime/cxx_api/serialization.cc b/mindspore/lite/src/runtime/cxx_api/serialization.cc
index 3db32826..8405f4b2 100644
--- a/mindspore/lite/src/runtime/cxx_api/serialization.cc
+++ b/mindspore/lite/src/runtime/cxx_api/serialization.cc
@@ -157,9 +157,34 @@ Status Serialization::SetParameters(const std::map ¶met
return kMEFailed;
}
-Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) {
- MS_LOG(ERROR) << "Unsupported feature.";
- return kMEFailed;
+Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
+ QuantizationType quantization_type, bool export_inference_only,
+ const std::vector<std::vector<char>> &output_tensor_name) {
+ if (model.impl_ == nullptr) {
+ MS_LOG(ERROR) << "Model implement is null.";
+ return kLiteUninitializedObj;
+ }
+ if (!model.impl_->IsTrainModel()) {
+ MS_LOG(ERROR) << "Model is not TrainModel.";
+ return kLiteError;
+ }
+ if (model_data == nullptr) {
+ MS_LOG(ERROR) << "model_data is nullptr.";
+ return kLiteParamInvalid;
+ }
+ if (model_type != kMindIR && model_type != kMindIR_Lite) {
+ MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
+ return kLiteParamInvalid;
+ }
+ if (model.impl_->session_ == nullptr) {
+ MS_LOG(ERROR) << "Model session is nullptr.";
+ return kLiteError;
+ }
+ auto ret = model.impl_->session_->Export(model_data, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
+ A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS,
+ VectorCharToString(output_tensor_name));
+
+ return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
}
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
diff --git a/mindspore/lite/src/runtime/cxx_api/train/converters.cc b/mindspore/lite/src/runtime/cxx_api/train/converters.cc
index 694259b3..b0801804 100644
--- a/mindspore/lite/src/runtime/cxx_api/train/converters.cc
+++ b/mindspore/lite/src/runtime/cxx_api/train/converters.cc
@@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,8 +25,8 @@ Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cf
return kLiteNullptr;
}
- l_train_cfg->loss_name_.clear();
- l_train_cfg->loss_name_.assign(a_train_cfg->loss_name_.begin(), a_train_cfg->loss_name_.end());
+ std::vector<std::string> a_loss_name = a_train_cfg->GetLossName();
+ l_train_cfg->loss_name_.assign(a_loss_name.begin(), a_loss_name.end());
l_train_cfg->mix_precision_cfg_.dynamic_loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_;
l_train_cfg->mix_precision_cfg_.loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_;
l_train_cfg->mix_precision_cfg_.keep_batchnorm_fp32_ = (a_train_cfg->optimization_level_ != kO3);
diff --git a/mindspore/lite/src/runtime/infer_manager.h b/mindspore/lite/src/runtime/infer_manager.h
index 31da532e..a851b7d2 100644
--- a/mindspore/lite/src/runtime/infer_manager.h
+++ b/mindspore/lite/src/runtime/infer_manager.h
@@ -31,11 +31,11 @@
#include "include/api/allocator.h"
namespace mindspore::lite {
-int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs,
- OpParameter *parameter, std::shared_ptr<Allocator> allocator = nullptr);
-int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
- const void *primitive, std::set<std::string> &&providers, int schema_version,
- const kernel::Kernel *kernel = nullptr);
+MS_API int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs,
+ OpParameter *parameter, std::shared_ptr<Allocator> allocator = nullptr);
+MS_API int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
+ const void *primitive, std::set<std::string> &&providers, int schema_version,
+ const kernel::Kernel *kernel = nullptr);
class InferManager {
public:
static InferManager *GetInstance() {
diff --git a/mindspore/lite/src/runtime/inner_context.h b/mindspore/lite/src/runtime/inner_context.h
index adbeacbf..ff58995f 100644
--- a/mindspore/lite/src/runtime/inner_context.h
+++ b/mindspore/lite/src/runtime/inner_context.h
@@ -35,7 +35,7 @@ namespace mindspore::lite {
#ifdef ENABLE_MINDRT
constexpr int kDefaultParallelNum = 2;
#endif
-struct InnerContext : public Context {
+struct MS_API InnerContext : public Context {
public:
InnerContext() { InitDeviceFp16(); }
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc
index 843fc0c9..5b94867b 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/argminmax_base.cc
@@ -54,7 +54,6 @@ int ArgMinMaxCPUKernel::ReSize() {
ComputeStrides(in_shape.data(), arg_param_->in_strides_, in_shape.size());
CHECK_NULL_RETURN(out_tensors_.at(0));
auto out_shape = out_tensors_.at(0)->shape();
- CHECK_NULL_RETURN(out_shape.data());
ComputeStrides(out_shape.data(), arg_param_->out_strides_, out_shape.size());
return RET_OK;
}
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc
index 68f5cce3..14c97bf8 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/arithmetic_base.cc
@@ -285,6 +285,7 @@ void ArithmeticBaseCPUKernel::ComputeOfflineInfo() {
c_matric_.batch_post_sum[i] = c_matric_.shape[i] * c_matric_.batch_post_sum[i + 1];
}
}
+ scalar_opt_ = false;
if (a_matric_.inner_size == 1) {
param_->in_elements_num0_ = 1;
scalar_opt_ = true;
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
index aa50a916..b5370ddd 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
@@ -50,8 +50,6 @@ int GroupConvolutionBaseCPUKernel::ReSize() {
if (group_num_ == 0) {
return RET_ERROR;
}
- conv_param_->input_channel_ /= group_num_;
- conv_param_->output_channel_ /= group_num_;
return RET_OK;
}
@@ -96,7 +94,8 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
// in
auto in_tensor = in_tensors_.front();
CHECK_NULL_RETURN(in_tensor);
- in_shape = {in_tensor->Batch(), in_tensor->Height(), in_tensor->Width(), conv_param_->input_channel_};
+ in_shape = {in_tensor->Batch(), in_tensor->Height(), in_tensor->Width(),
+ conv_param_->input_channel_ / group_num_};
auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front();
CHECK_NULL_RETURN(sub_kernel_in_tensor);
sub_kernel_in_tensor->set_shape(in_shape);
@@ -108,7 +107,8 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
// out
auto out_tensor = out_tensors_.front();
CHECK_NULL_RETURN(out_tensor);
- out_shape = {out_tensor->Batch(), out_tensor->Height(), out_tensor->Width(), conv_param_->output_channel_};
+ out_shape = {out_tensor->Batch(), out_tensor->Height(), out_tensor->Width(),
+ conv_param_->output_channel_ / group_num_};
auto sub_kernel_out_tensors = group_convs_.at(i)->out_tensors();
for (auto tensor : sub_kernel_out_tensors) {
CHECK_NULL_RETURN(tensor);
@@ -148,8 +148,8 @@ int GroupConvolutionBaseCPUKernel::InitGroupParam() {
MS_LOG(ERROR) << "get in_plane_ from in_tensor failed.";
return RET_ERROR;
}
- sub_in_channel_ = conv_param_->input_channel_;
- ori_in_channel_ = sub_in_channel_ * group_num_;
+ sub_in_channel_ = conv_param_->input_channel_ / group_num_;
+ ori_in_channel_ = conv_param_->input_channel_;
in_thread_num_ = MSMIN(MSMAX(1, ctx_->thread_num_), in_plane_);
auto out_tensor = out_tensors_.front();
@@ -159,8 +159,8 @@ int GroupConvolutionBaseCPUKernel::InitGroupParam() {
MS_LOG(ERROR) << "get out_plane_ from out_tensor failed.";
return RET_ERROR;
}
- sub_out_channel_ = conv_param_->output_channel_;
- ori_out_channel_ = sub_out_channel_ * group_num_;
+ sub_out_channel_ = conv_param_->output_channel_ / group_num_;
+ ori_out_channel_ = conv_param_->output_channel_;
out_thread_num_ = MSMIN(MSMAX(1, ctx_->thread_num_), out_plane_);
return RET_OK;
}
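The deleted divisions in ReSize() fixed a re-entrancy bug: ReSize() can run more than once over a kernel's lifetime, and dividing conv_param_'s channel counts in place corrupted them on every call after the first. The hunks above now keep the original counts and derive the per-group values at each use site. A self-contained illustration of the failure mode, with hypothetical names:

```cpp
#include <cassert>

struct ConvParam { int input_channel; };

// Old style: ReSize() mutates the field it also reads -- not idempotent.
void ResizeOld(ConvParam *p, int group_num) { p->input_channel /= group_num; }

// New style: keep the original count, divide at the point of use.
int SubChannels(const ConvParam &p, int group_num) { return p.input_channel / group_num; }

int main() {
  ConvParam p{64};
  ResizeOld(&p, 2);
  ResizeOld(&p, 2);                 // a second ReSize() halves the count again
  assert(p.input_channel == 16);    // corrupted: should still describe 64 / 2 == 32

  const ConvParam q{64};
  assert(SubChannels(q, 2) == 32);  // stays correct however often it is queried
  return 0;
}
```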
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
index f2a29bfd..fc78a887 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
@@ -96,15 +96,10 @@ lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info, bool inferred) {
tensor->set_data_type(tensor_info.data_type_);
tensor->set_format(tensor_info.format_);
tensor->set_category(tensor_info.tensor_type_);
- if (tensor_info.is_in_) {
- tensor->set_shape(tensor_info.shape_);
- }
+ tensor->set_shape(tensor_info.shape_);
if (inferred) {
// set shape of out tensor
- if (!tensor_info.is_in_) {
- tensor->set_shape(tensor_info.shape_);
- }
return TensorMalloc(tensor);
}
return tensor;
@@ -185,13 +180,16 @@ void GroupConvCreator::SetShapeOfTensors() {
/* set shape */
set_filter_shape({new_out_channel, conv_param_->kernel_h_, conv_param_->kernel_w_, new_in_channel});
set_bias_shape({new_out_channel});
+ conv_param_->input_channel_ = new_in_channel;
+ conv_param_->output_channel_ = new_out_channel;
if (infered_) {
- conv_param_->input_channel_ = new_in_channel;
- conv_param_->output_channel_ = new_out_channel;
set_input_shape({origin_inputs_.front()->Batch(), origin_inputs_.front()->Height(), origin_inputs_.front()->Width(),
new_in_channel});
set_output_shape({origin_inputs_.front()->Batch(), origin_outputs_.front()->Height(),
origin_outputs_.front()->Width(), new_out_channel});
+ } else {
+ set_input_shape({-1});
+ set_output_shape({-1});
}
}
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc
index 58f953b8..89af7aae 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/reshape_base.cc
@@ -105,6 +105,7 @@ REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Squeeze, LiteKernelCreator<ReshapeBaseCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Squeeze, LiteKernelCreator<ReshapeBaseCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Squeeze, LiteKernelCreator<ReshapeBaseCPUKernel>)
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc b/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc
index 5db44a0a..ec2080ef 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/strided_slice.cc
@@ -56,6 +56,8 @@ void StridedSliceCPUKernel::InitFastRunParam() {
for (size_t i = static_cast<size_t>(split_axis_ + 1); i < in_shape.size(); i++) {
inner_ *= in_shape[i];
}
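+ // Reset both flags before choosing: ReSize may call InitFastRunParam again and a stale value
+ // would pick the wrong parallel dimension below.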
+ parallel_on_split_axis_ = false;
+ parallel_on_outer_ = false;
outer_ == 1 ? (parallel_on_split_axis_ = true) : (parallel_on_outer_ = true);
if (UpdateThreadNumPass(TC_TYPE(PrimitiveType_StridedSlice, parallel_on_outer_), 1, 1,
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc b/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc
index 84b5a1a4..8ab4969f 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/fused_batchnorm_fp16.cc
@@ -23,6 +23,7 @@
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
+using mindspore::lite::RET_NO_CHANGE;
using mindspore::schema::PrimitiveType_FusedBatchNorm;
namespace mindspore::kernel {
@@ -41,8 +42,11 @@ constexpr static int kOutCurrentVarIdx = 4;
int FusedBatchnormFp16CPUKernel::Batchnorm2Scale(const void *scale_data, const void *bias_data, const void *mean_data,
const void *var_data, float eps, int kernel_num) {
auto ret = InitScaleParam();
- if (ret != RET_OK) {
- MS_LOG(ERROR) << "Init scale parameter when converting fused_batchnorm to scale.";
+ if (ret == RET_NO_CHANGE) {
+ MS_LOG(INFO) << "Unsupported to convert fused batch norm to scale.";
+ return RET_NO_CHANGE;
+ } else if (ret != RET_OK) {
+ MS_LOG(ERROR) << "Init scale param failed.";
return RET_ERROR;
}
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc
index 7243e3b0..afed28ae 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/fused_batchnorm_fp32.cc
@@ -19,6 +19,7 @@
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NO_CHANGE;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_FusedBatchNorm;
@@ -65,7 +66,7 @@ int FusedBatchnormCPUKernel::InitScaleParam() {
scale_param_->axis_ = kNHWC_C;
auto in_shape = in_tensors_[0]->shape();
- CHECK_LESS_RETURN(in_shape.size(), DIMENSION_5D);
+ MS_CHECK_TRUE_RET(in_shape.size() == DIMENSION_4D, RET_NO_CHANGE);
scale_param_->outer_size_ = 1;
for (auto i = 0; i < scale_param_->axis_; i++) {
scale_param_->outer_size_ *= in_shape[i];
@@ -80,8 +81,11 @@ int FusedBatchnormCPUKernel::InitScaleParam() {
int FusedBatchnormCPUKernel::Batchnorm2Scale(const void *scale_data, const void *bias_data, const void *mean_data,
const void *var_data, float eps, int kernel_num) {
auto ret = InitScaleParam();
- if (ret != RET_OK) {
- MS_LOG(ERROR) << "Init scale parameter when converting fused_batchnorm to scale.";
+ if (ret == RET_NO_CHANGE) {
+ MS_LOG(INFO) << "Unsupported to convert fused batch norm to scale.";
+ return RET_NO_CHANGE;
+ } else if (ret != RET_OK) {
+ MS_LOG(ERROR) << "Init scale param failed.";
return RET_ERROR;
}
@@ -131,6 +135,10 @@ int FusedBatchnormCPUKernel::InitConstTensor() {
return RET_OK;
} else {
FreeScaleAndOffset();
+ if (ret != RET_NO_CHANGE) {
+ MS_LOG(ERROR) << "convert batch norm to scale failed.";
+ return RET_ERROR;
+ }
}
}
@@ -188,7 +196,7 @@ int FusedBatchnormCPUKernel::Run() {
trained_ = true; // trained at least once
} else {
- if (out_tensors_.size() >= DIMENSION_5D) {
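+ // Only a train session materializes the extra scale/offset/mean/var outputs; guard the copies
+ // so inference graphs skip them.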
+ if (op_parameter_->is_train_session_ && out_tensors_.size() >= DIMENSION_5D) {
(void)memcpy(out_tensors_.at(SECOND_INPUT)->data(), scale_, out_tensors_.at(SECOND_INPUT)->Size());
(void)memcpy(out_tensors_.at(THIRD_INPUT)->data(), offset_, out_tensors_.at(THIRD_INPUT)->Size());
(void)memcpy(out_tensors_.at(FOURTH_INPUT)->data(), mean_, out_tensors_.at(FOURTH_INPUT)->Size());
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc
new file mode 100644
index 00000000..627f19f0
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.cc
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/cpu/fp32/oneslike_fp32.h"
+#include "schema/model_generated.h"
+#include "nnacl/base/zeroslike_base.h"
+#include "src/runtime/kernel_registry.h"
+#include "include/errorcode.h"
+
+using mindspore::kernel::KERNEL_ARCH;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_OnesLike;
+
+namespace mindspore::kernel {
+int OnesLikeCPUKernel::Prepare() {
+ CHECK_LESS_RETURN(in_tensors_.size(), 1);
+ CHECK_LESS_RETURN(out_tensors_.size(), 1);
+ return RET_OK;
+}
+
+int OnesLikeCPUKernel::Run() {
+ auto output = out_tensors_[0];
+ CHECK_NULL_RETURN(output);
+ if (output->data_type() == kNumberTypeInt32) {
+ ApproximateOnesLike(static_cast<int *>(output->data()), output->ElementsNum());
+ } else if (output->data_type() == kNumberTypeFloat32) {
+ ApproximateOnesLike(static_cast<float *>(output->data()), output->ElementsNum());
+ }
+ return RET_OK;
+}
+
+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>)
+#ifdef ENABLE_FP16
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_OnesLike, LiteKernelCreator<OnesLikeCPUKernel>)
+#endif
+} // namespace mindspore::kernel
\ No newline at end of file
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h
new file mode 100644
index 00000000..fdca97cb
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/oneslike_fp32.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
+
+#include <vector>
+#include "src/runtime/lite_kernel.h"
+
+namespace mindspore::kernel {
+class OnesLikeCPUKernel : public LiteKernel {
+ public:
+ OnesLikeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+ : LiteKernel(parameter, inputs, outputs, ctx) {}
+
+ ~OnesLikeCPUKernel() = default;
+
+ int Prepare() override;
+ int ReSize() override { return lite::RET_OK; }
+ int Run() override;
+
+ private:
+ template <typename T>
+ void ApproximateOnesLike(T *output, int data_size) {
+ for (int i = 0; i < data_size; ++i) {
+ output[i] = 1;
+ }
+ }
+};
+} // namespace mindspore::kernel
+
+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_ONESLIKE_FP32_H_
\ No newline at end of file
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc
new file mode 100644
index 00000000..a24976f8
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.cc
@@ -0,0 +1,120 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h"
+#include "src/runtime/kernel_registry.h"
+#include "include/errorcode.h"
+#include "nnacl/fp32_grad/binary_cross_entropy.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_BinaryCrossEntropy;
+
+namespace mindspore::kernel {
+BinaryCrossEntropyCPUKernel::~BinaryCrossEntropyCPUKernel() {
+ if (tmp_loss_ != nullptr) {
+ free(tmp_loss_);
+ tmp_loss_ = nullptr;
+ }
+}
+
+int BinaryCrossEntropyCPUKernel::ReSize() {
+ CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
+ CHECK_LESS_RETURN(out_tensors_.size(), 1);
+ CHECK_NULL_RETURN(in_tensors_.at(0));
+ CHECK_NULL_RETURN(in_tensors_.at(1));
+ if (in_tensors_.size() == C3NUM) {
+ weight_defined_ = true;
+ CHECK_NULL_RETURN(in_tensors_.at(C2NUM));
+ }
+ CHECK_NULL_RETURN(out_tensors_.at(0));
+ CHECK_NULL_RETURN(op_parameter_);
+
+ auto param_ = reinterpret_cast<BinaryCrossEntropyParameter *>(op_parameter_);
+ CHECK_NULL_RETURN(param_);
+ if (tmp_loss_ != nullptr) {
+ free(tmp_loss_);
+ tmp_loss_ = nullptr;
+ }
+ size_t input_size = in_tensors_.at(0)->ElementsNum();
+ tmp_loss_ = reinterpret_cast<float *>(malloc(input_size * sizeof(float)));
+ if (tmp_loss_ == nullptr) {
+ MS_LOG(ERROR) << "malloc tmp_loss_ for BinaryCrossEntropy op failed";
+ return RET_ERROR;
+ }
+
+ return RET_OK;
+}
+
+int BinaryCrossEntropyCPUKernel::DoExecute(int task_id) {
+ auto logits = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
+ CHECK_NULL_RETURN(logits);
+ auto labels = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
+ CHECK_NULL_RETURN(labels);
+ auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
+ CHECK_NULL_RETURN(out);
+
+ auto param_ = reinterpret_cast<BinaryCrossEntropyParameter *>(op_parameter_);
+ int reduction = param_->reduction;
+ size_t input_size = in_tensors_.at(0)->ElementsNum();
+ if (weight_defined_) {
+ weight_ = reinterpret_cast<float *>(in_tensors_.at(C2NUM)->MutableData());
+ CHECK_NULL_RETURN(weight_);
+ }
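+ // tmp_loss_ buffers the per-element losses; the nnacl kernel folds them into out according to reduction.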
+ BinaryCrossEntropy(input_size, reduction, logits, labels, weight_, out, tmp_loss_, weight_defined_);
+ return RET_OK;
+}
+
+int BinaryCrossEntropyRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
+ CHECK_NULL_RETURN(cdata);
+ auto bin_crs_ent_kernel = reinterpret_cast<BinaryCrossEntropyCPUKernel *>(cdata);
+ auto error_code = bin_crs_ent_kernel->DoExecute(task_id);
+ if (error_code != RET_OK) {
+ MS_LOG(ERROR) << "BinaryCrossEntropy error task_id[" << task_id << "] error_code[" << error_code << "]";
+ return RET_ERROR;
+ }
+ return RET_OK;
+}
+
+int BinaryCrossEntropyCPUKernel::Run() {
+ int error_code = ParallelLaunch(this->ms_context_, BinaryCrossEntropyRun, this, 1);
+ if (error_code != RET_OK) {
+ MS_LOG(ERROR) << "SigmoidCrossEntropyWithLogits function error error_code[" << error_code << "]";
+ return RET_ERROR;
+ }
+ return RET_OK;
+}
+
+int BinaryCrossEntropyCPUKernel::Prepare() { return ReSize(); }
+
+kernel::LiteKernel *CpuBinaryCrossEntropyFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
+ const std::vector<lite::Tensor *> &outputs,
+ OpParameter *opParameter, const lite::Context *ctx,
+ const kernel::KernelKey &desc) {
+ MS_ASSERT(opParameter != nullptr);
+ MS_ASSERT(desc.type == schema::PrimitiveType_BinaryCrossEntropy);
+ auto *kernel = new (std::nothrow)
+ BinaryCrossEntropyCPUKernel(opParameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
+ if (kernel == nullptr) {
+ MS_LOG(ERROR) << "new BinaryCrossEntropyCPUKernel failed";
+ return nullptr;
+ }
+ return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BinaryCrossEntropy, CpuBinaryCrossEntropyFp32KernelCreator)
+} // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h
new file mode 100644
index 00000000..39a7181e
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_H_
+
+#include <vector>
+#include "src/runtime/lite_kernel.h"
+
+namespace mindspore::kernel {
+class BinaryCrossEntropyCPUKernel : public LiteKernel {
+ public:
+ explicit BinaryCrossEntropyCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+ : LiteKernel(parameter, inputs, outputs, ctx) {}
+ ~BinaryCrossEntropyCPUKernel() override;
+ int Prepare() override;
+ int ReSize() override;
+ int Run() override;
+ int DoExecute(int task_id);
+
+ protected:
+ float *tmp_loss_ = nullptr;
+ bool weight_defined_{false};
+ float *weight_ = nullptr;
+};
+} // namespace mindspore::kernel
+
+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_H_
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc
new file mode 100644
index 00000000..abac8fd1
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc
@@ -0,0 +1,105 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h"
+#include "src/runtime/kernel_registry.h"
+#include "include/errorcode.h"
+#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad;
+
+namespace mindspore::kernel {
+int BinaryCrossEntropyGradCPUKernel::ReSize() {
+ CHECK_LESS_RETURN(in_tensors_.size(), C3NUM);
+ CHECK_LESS_RETURN(out_tensors_.size(), C1NUM);
+ CHECK_NULL_RETURN(in_tensors_.at(C0NUM));
+ CHECK_NULL_RETURN(in_tensors_.at(C1NUM));
+ CHECK_NULL_RETURN(in_tensors_.at(C2NUM));
+ if (in_tensors_.size() == C4NUM) {
+ weight_defined_ = true;
+ CHECK_NULL_RETURN(in_tensors_.at(C3NUM));
+ }
+ CHECK_NULL_RETURN(out_tensors_.at(0));
+ CHECK_NULL_RETURN(op_parameter_);
+ auto param_ = reinterpret_cast<BinaryCrossEntropyGradParameter *>(op_parameter_);
+ CHECK_NULL_RETURN(param_);
+
+ return RET_OK;
+}
+
+int BinaryCrossEntropyGradCPUKernel::DoExecute(int task_id) {
+ auto input_x = reinterpret_cast<float *>(in_tensors_.at(C0NUM)->MutableData());
+ CHECK_NULL_RETURN(input_x);
+ auto input_y = reinterpret_cast<float *>(in_tensors_.at(C1NUM)->MutableData());
+ CHECK_NULL_RETURN(input_y);
+ auto dloss = reinterpret_cast<float *>(in_tensors_.at(C2NUM)->MutableData());
+ CHECK_NULL_RETURN(dloss);
+ if (weight_defined_) {
+ weight_ = reinterpret_cast<float *>(in_tensors_.at(C3NUM)->MutableData());
+ CHECK_NULL_RETURN(weight_);
+ }
+ auto *out = reinterpret_cast<float *>(out_tensors_.at(C0NUM)->MutableData());
+ CHECK_NULL_RETURN(out);
+
+ auto param_ = reinterpret_cast<BinaryCrossEntropyGradParameter *>(op_parameter_);
+ int reduction = param_->reduction;
+ size_t input_size = in_tensors_.at(0)->ElementsNum();
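+ // weight_ stays nullptr unless a fourth input supplied it; weight_defined_ tells the nnacl kernel which path to take.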
+ BinaryCrossEntropyGrad(input_size, reduction, input_x, input_y, weight_, dloss, out, weight_defined_);
+ return RET_OK;
+}
+
+int BinaryCrossEntropyGradRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
+ CHECK_NULL_RETURN(cdata);
+ auto bin_crs_ent_kernel = reinterpret_cast<BinaryCrossEntropyGradCPUKernel *>(cdata);
+ auto error_code = bin_crs_ent_kernel->DoExecute(task_id);
+ if (error_code != RET_OK) {
+ MS_LOG(ERROR) << "BinaryCrossEntropyGrad error task_id[" << task_id << "] error_code[" << error_code << "]";
+ return RET_ERROR;
+ }
+ return RET_OK;
+}
+
+int BinaryCrossEntropyGradCPUKernel::Run() {
+ int error_code = ParallelLaunch(this->ms_context_, BinaryCrossEntropyGradRun, this, 1);
+ if (error_code != RET_OK) {
+ MS_LOG(ERROR) << "BinaryCrossEntropyGrad function error error_code[" << error_code << "]";
+ return RET_ERROR;
+ }
+ return RET_OK;
+}
+
+int BinaryCrossEntropyGradCPUKernel::Prepare() { return ReSize(); }
+
+kernel::LiteKernel *CpuBinaryCrossEntropyGradFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
+ const std::vector<lite::Tensor *> &outputs,
+ OpParameter *opParameter, const lite::Context *ctx,
+ const kernel::KernelKey &desc) {
+ MS_ASSERT(opParameter != nullptr);
+ MS_ASSERT(desc.type == schema::PrimitiveType_BinaryCrossEntropyGrad);
+ auto *kernel = new (std::nothrow)
+ BinaryCrossEntropyGradCPUKernel(opParameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
+ if (kernel == nullptr) {
+ MS_LOG(ERROR) << "new BinaryCrossEntropyGrad failed";
+ return nullptr;
+ }
+ return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BinaryCrossEntropyGrad, CpuBinaryCrossEntropyGradFp32KernelCreator)
+} // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h
new file mode 100644
index 00000000..d289eb65
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_GRAD_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_GRAD_H_
+
+#include <vector>
+#include "src/runtime/lite_kernel.h"
+
+namespace mindspore::kernel {
+class BinaryCrossEntropyGradCPUKernel : public LiteKernel {
+ public:
+ explicit BinaryCrossEntropyGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+ : LiteKernel(parameter, inputs, outputs, ctx) {}
+ ~BinaryCrossEntropyGradCPUKernel() override {}
+ int Prepare() override;
+ int ReSize() override;
+ int Run() override;
+ int DoExecute(int task_id);
+
+ protected:
+ bool weight_defined_{false};
+ float *weight_ = nullptr;
+};
+} // namespace mindspore::kernel
+
+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_BINARY_CROSS_ENTROPY_GRAD_H_
diff --git a/mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt
new file mode 100644
index 00000000..3b42ed7a
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/gpu/opencl/CMakeLists.txt
@@ -0,0 +1,11 @@
+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
+string(REPLACE "-fvisibility=hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
+string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+string(REPLACE "-fvisibility=hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+if(MSLITE_GPU_BACKEND STREQUAL opencl)
+ file(GLOB_RECURSE OPENCL_KERNEL_SRC
+ ${CMAKE_CURRENT_SOURCE_DIR}/*.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/../../opencl/*.cc)
+ add_library(opencl_kernel_mid OBJECT ${OPENCL_KERNEL_SRC})
+ add_dependencies(opencl_kernel_mid fbs_src)
+endif()
diff --git a/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt
deleted file mode 100644
index cad0f8f7..00000000
--- a/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-if(MSLITE_GPU_BACKEND STREQUAL opencl)
- file(GLOB_RECURSE OPENCL_KERNEL_SRC
- ${CMAKE_CURRENT_SOURCE_DIR}/*.cc
- ${CMAKE_CURRENT_SOURCE_DIR}/kernel/*.cc
- ${CMAKE_CURRENT_SOURCE_DIR}/kernel/int8/*.cc)
- add_library(opencl_kernel_mid OBJECT ${OPENCL_KERNEL_SRC})
- add_dependencies(opencl_kernel_mid fbs_src)
-endif()
diff --git a/mindspore/lite/src/runtime/kernel_exec_util.h b/mindspore/lite/src/runtime/kernel_exec_util.h
index e45c185b..9ce5267e 100644
--- a/mindspore/lite/src/runtime/kernel_exec_util.h
+++ b/mindspore/lite/src/runtime/kernel_exec_util.h
@@ -24,7 +24,7 @@
namespace mindspore::kernel {
-class KernelExecUtil {
+class MS_API KernelExecUtil {
public:
static std::vector<KernelExec *> SubgraphInputNodes(const std::vector<KernelExec *> &kernels);
static std::vector<KernelExec *> SubgraphOutputNodes(const std::vector<KernelExec *> &kernels);
diff --git a/mindspore/lite/src/runtime/kernel_registry.h b/mindspore/lite/src/runtime/kernel_registry.h
index 853d863a..f563d82d 100644
--- a/mindspore/lite/src/runtime/kernel_registry.h
+++ b/mindspore/lite/src/runtime/kernel_registry.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN;
namespace mindspore::lite {
-class KernelRegistry {
+class MS_API KernelRegistry {
public:
KernelRegistry() = default;
virtual ~KernelRegistry();
diff --git a/mindspore/lite/src/runtime/lite_kernel.h b/mindspore/lite/src/runtime/lite_kernel.h
index ce829320..a27f77d8 100644
--- a/mindspore/lite/src/runtime/lite_kernel.h
+++ b/mindspore/lite/src/runtime/lite_kernel.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@
using mindspore::infer::Abstractkernel;
namespace mindspore::kernel {
-class LiteKernel : public Abstractkernel {
+class MS_API LiteKernel : public Abstractkernel {
public:
LiteKernel() = default;
diff --git a/mindspore/lite/src/runtime/lite_model.h b/mindspore/lite/src/runtime/lite_model.h
index f6c7ebc4..af62cb91 100644
--- a/mindspore/lite/src/runtime/lite_model.h
+++ b/mindspore/lite/src/runtime/lite_model.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@
namespace mindspore {
namespace lite {
-class LiteModel : public Model {
+class MS_API LiteModel : public Model {
public:
explicit LiteModel(std::string model_path = "") : model_path_(std::move(model_path)) {}
diff --git a/mindspore/lite/src/runtime/lite_session.cc b/mindspore/lite/src/runtime/lite_session.cc
index b8808e21..dffb39e7 100644
--- a/mindspore/lite/src/runtime/lite_session.cc
+++ b/mindspore/lite/src/runtime/lite_session.cc
@@ -504,7 +504,7 @@ void LiteSession::FreePackOpWeight(const std::vector<kernel::KernelExec *> &kern
auto inputs = kernel->in_tensors();
for (auto *tensor : inputs) {
MS_ASSERT(tensor != nullptr);
- if (!tensor->IsConst()) {
+ if (!tensor->IsConst() || tensor->ref_count() >= 1) {
continue;
}
tensor->FreeData();
@@ -512,6 +512,29 @@ void LiteSession::FreePackOpWeight(const std::vector<kernel::KernelExec *> &kern
}
}
+void LiteSession::MarkSharedWeight(const std::vector<kernel::KernelExec *> &kernels) {
+ // Bump the ref count of const inputs of non-packed kernels so that FreePackOpWeight,
+ // which skips tensors whose ref_count() >= 1, keeps weights that are still shared.
+ for (auto *kernel : kernels) {
+ MS_ASSERT(kernel != nullptr);
+ if (kernel->subgraph_type() == kernel::kNotSubGraph) {
+ if (IsPackedOp(static_cast<int>(kernel->type()))) {
+ continue;
+ }
+ } else {
+ auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
+ MarkSharedWeight(subgraph->nodes());
+ }
+ auto inputs = kernel->in_tensors();
+ for (auto *tensor : inputs) {
+ MS_ASSERT(tensor != nullptr);
+ if (tensor->IsConst()) {
+ tensor->IncRefCount();
+ }
+ }
+ }
+}
+
int LiteSession::CompileGraph(Model *model) {
auto ret = PreCheck(model);
if (ret != RET_OK) {
@@ -572,7 +595,7 @@ int LiteSession::CompileGraph(Model *model) {
is_running_.store(false);
return ret;
}
-
+ MarkSharedWeight(kernels_);
FreePackOpWeight(kernels_);
ret = RuntimeAllocatorInit();
@@ -1727,6 +1750,7 @@ int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore
delete model;
return RET_ERROR;
}
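+ // Compilation has copied what it needs; release the model buffer to reduce runtime RAM.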
+ model->Free();
set_model(model);
return RET_OK;
}
diff --git a/mindspore/lite/src/runtime/lite_session.h b/mindspore/lite/src/runtime/lite_session.h
index 255e90b5..2fdb1eb7 100644
--- a/mindspore/lite/src/runtime/lite_session.h
+++ b/mindspore/lite/src/runtime/lite_session.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -39,7 +39,7 @@
namespace mindspore {
namespace lite {
-class LiteSession {
+class MS_API LiteSession {
public:
LiteSession();
virtual ~LiteSession();
@@ -101,6 +101,11 @@ class LiteSession {
std::vector<std::string> out_put_tensor_name = {}) {
return mindspore::lite::RET_ERROR;
}
+ virtual int Export(Buffer *model_buffer, lite::ModelType model_type = lite::MT_TRAIN,
+ lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS,
+ std::vector<std::string> out_put_tensor_name = {}) {
+ return mindspore::lite::RET_ERROR;
+ }
virtual int UpdateWeights(std::vector<lite::Tensor *> new_weights) { return mindspore::lite::RET_ERROR; }
virtual std::vector<lite::Tensor *> GetFeatureMaps() const {
std::vector<lite::Tensor *> features;
@@ -142,6 +147,7 @@ class LiteSession {
const std::unordered_map<Tensor *, Tensor *> &isolate_input_map = std::unordered_map<Tensor *, Tensor *>());
static void FreePackOpWeight(const std::vector<kernel::KernelExec *> &kernels);
std::string ParseWeightPath();
+ static void MarkSharedWeight(const std::vector<kernel::KernelExec *> &kernels);
private:
int PreCheck(Model *model);
diff --git a/mindspore/lite/src/runtime/weight_decoder.h b/mindspore/lite/src/runtime/weight_decoder.h
index 7c9e514c..006b4895 100644
--- a/mindspore/lite/src/runtime/weight_decoder.h
+++ b/mindspore/lite/src/runtime/weight_decoder.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -39,7 +39,7 @@ static constexpr int kBitNum32 = 32;
namespace mindspore::lite {
-class WeightDecoder {
+class MS_API WeightDecoder {
public:
static int DequantNode(const OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type,
const std::string &model_version, bool float_mode);
diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h
index f30fe090..178d4754 100644
--- a/mindspore/lite/src/tensor.h
+++ b/mindspore/lite/src/tensor.h
@@ -53,7 +53,7 @@ struct LiteQuantParam {
double max{255.0};
};
-class Tensor {
+class MS_API Tensor {
public:
Tensor() = default;
diff --git a/mindspore/lite/src/tensorlist.h b/mindspore/lite/src/tensorlist.h
index 2e6f8d79..39058057 100644
--- a/mindspore/lite/src/tensorlist.h
+++ b/mindspore/lite/src/tensorlist.h
@@ -55,7 +55,7 @@ namespace mindspore::lite {
*
* See the code for other constructors.
*/
-class TensorList : public Tensor {
+class MS_API TensorList : public Tensor {
public:
TensorList() = default;
diff --git a/mindspore/lite/src/train/graph_fusion.cc b/mindspore/lite/src/train/graph_fusion.cc
index 0e582e40..1af44e45 100644
--- a/mindspore/lite/src/train/graph_fusion.cc
+++ b/mindspore/lite/src/train/graph_fusion.cc
@@ -22,6 +22,8 @@
#include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
+#include "src/train/optimizer/fusion/matmul_activation_fusion_pass.h"
+#include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h"
namespace mindspore {
namespace lite {
@@ -41,7 +43,9 @@ STATUS GraphFusion::Run(schema::MetaGraphT *graph) {
}
auto old_nodes = GetGraphNodes(*graph);
Optimizer fusion_optimizer;
+ fusion_optimizer.AddPass(new (std::nothrow) ReshapeGatherReshapeFusionPass());
fusion_optimizer.AddPass(new (std::nothrow) MatMulBiasAddFusionPass());
+ fusion_optimizer.AddPass(new (std::nothrow) MatMulActivationFusionPass());
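+ // The fusion passes above only disconnect fused nodes; IsolatedNodeRemovePass actually deletes them.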
fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
auto status = fusion_optimizer.Run(graph);
diff --git a/mindspore/lite/src/train/optimizer/common/fusion_utils.cc b/mindspore/lite/src/train/optimizer/common/fusion_utils.cc
new file mode 100644
index 00000000..3edb1a1b
--- /dev/null
+++ b/mindspore/lite/src/train/optimizer/common/fusion_utils.cc
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "src/common/log_util.h"
+#include "src/train/optimizer/common/fusion_utils.h"
+
+namespace mindspore {
+namespace opt {
+STATUS GetMatchNodeIndex(schema::MetaGraphT *graph,
+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path,
+ const std::string &node_name, size_t *node_index) {
+ auto node_path_iter = matched_path.find(node_name);
+ MS_CHECK_TRUE_MSG(node_path_iter != matched_path.end(), RET_ERROR, "cannot find node_path");
+ const auto &node_path = node_path_iter->second;
+ MS_CHECK_TRUE_MSG(node_path != nullptr, RET_NULL_PTR, "node_path is empty");
+ *node_index = node_path->nodeIdx;
+ MS_CHECK_TRUE_MSG(*node_index < graph->nodes.size(), RET_ERROR, "node_index is out of range");
+ return RET_OK;
+}
+} // namespace opt
+} // namespace mindspore
diff --git a/mindspore/lite/src/train/optimizer/common/fusion_utils.h b/mindspore/lite/src/train/optimizer/common/fusion_utils.h
new file mode 100644
index 00000000..7f80cd49
--- /dev/null
+++ b/mindspore/lite/src/train/optimizer/common/fusion_utils.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_COMMON_FUSION_UTILS_H_
+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_COMMON_FUSION_UTILS_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "src/common/utils.h"
+#include "schema/inner/model_generated.h"
+#include "tools/converter/legacy_optimizer/fusion/fusion_pattern.h"
+
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NULL_PTR;
+using mindspore::lite::RET_OK;
+using mindspore::lite::STATUS;
+namespace mindspore {
+namespace opt {
+inline constexpr int kInputIndexZero = 0;
+inline constexpr int kInputIndexOne = 1;
+inline constexpr int kInputIndexTwo = 2;
+inline constexpr int kOutputIndexZero = 0;
+inline constexpr int kOutputIndexOne = 1;
+inline constexpr size_t kInputSizeTwo = 2;
+inline constexpr size_t kInputSizeThree = 3;
+inline constexpr size_t kOutputSizeOne = 1;
+inline constexpr size_t kMatchPathLenTwo = 2;
+inline constexpr size_t kMatchPathLenThree = 3;
+
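+// Resolves the node matched under node_name in a fusion-pattern result to its index in
+// graph->nodes, with bounds checking; shared by the train fusion passes.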
+STATUS GetMatchNodeIndex(schema::MetaGraphT *graph,
+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path,
+ const std::string &node_name, size_t *node_index);
+} // namespace opt
+} // namespace mindspore
+#endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_COMMON_FUSION_UTILS_H_
diff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc
new file mode 100644
index 00000000..b809f2c9
--- /dev/null
+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.cc
@@ -0,0 +1,93 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/train/optimizer/fusion/matmul_activation_fusion_pass.h"
+#include <memory>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include "schema/inner/model_generated.h"
+#include "tools/common/meta_graph_utils.h"
+#include "src/train/optimizer/common/fusion_utils.h"
+namespace {
+constexpr std::string_view MatMulName = "MATMUL";
+constexpr std::string_view ActName = "ACTIVATION";
+} // namespace
+namespace mindspore {
+namespace lite {
+STATUS MatMulActivationFusionPass::DefinePattern() {
+ auto matmul_op = std::make_shared<PatternOp>();
+ MS_CHECK_TRUE_RET(matmul_op != nullptr, RET_NULL_PTR);
+ matmul_op->id = MatMulName;
+ matmul_op->types = {schema::PrimitiveType_MatMulFusion};
+ auto act_op = std::make_shared<PatternOp>();
+ MS_CHECK_TRUE_RET(act_op != nullptr, RET_NULL_PTR);
+ act_op->id = ActName;
+ act_op->types = {schema::PrimitiveType_Activation};
+ act_op->left = matmul_op;
+ auto fusion_pattern = std::make_unique<FusionPattern>("MatMulActivationFusion");
+ MS_CHECK_TRUE_MSG(fusion_pattern != nullptr, RET_NULL_PTR, "new fusion_pattern failed");
+ fusion_pattern->AddPatternOp(matmul_op);
+ fusion_pattern->AddPatternOp(act_op);
+ fusion_pattern->Finish();
+ this->patterns.emplace_back(fusion_pattern.release());
+ return RET_OK;
+}
+
+STATUS MatMulActivationFusionPass::DoFusion(
+ MetaGraphT *graph, const std::string &pattern_name,
+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) {
+ MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR);
+ if (matched_path.size() != opt::kMatchPathLenTwo) {
+ MS_LOG(ERROR) << "MatMul-Activation-Fusion should have two NodeIndex in matchedPair";
+ return RET_PARAM_INVALID;
+ }
+
+ size_t matmul_index = 0;
+ auto ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(MatMulName), &matmul_index);
+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get matmul_index");
+ auto &matmul_node = graph->nodes.at(matmul_index);
+ MS_CHECK_TRUE_MSG(matmul_node != nullptr, RET_NULL_PTR, "matmul_node is nullptr");
+ size_t act_index = 0;
+ ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(ActName), &act_index);
+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get act_index");
+ auto &act_node = graph->nodes.at(act_index);
+ MS_CHECK_TRUE_MSG(act_node != nullptr, RET_NULL_PTR, "act_node is nullptr");
+
+ if (matmul_node->quantType == schema::QuantType_QUANT_ALL ||
+ matmul_node->quantType == schema::QuantType_QUANT_DYNAMIC) {
+ MS_LOG(DEBUG) << "cannot fusion.";
+ return RET_NO_CHANGE;
+ }
+ MS_CHECK_TRUE_RET(matmul_node->primitive != nullptr, RET_NULL_PTR);
+ auto matmul_type = matmul_node->primitive->value.AsMatMulFusion();
+ MS_CHECK_TRUE_RET(matmul_type->activation_type == ActivationType::ActivationType_NO_ACTIVATION, RET_NO_CHANGE);
+ MS_CHECK_TRUE_RET(act_node->primitive != nullptr, RET_NULL_PTR);
+ auto act_type = act_node->primitive->value.AsActivation()->activation_type;
+ MS_CHECK_TRUE_RET(act_type == ActivationType::ActivationType_RELU || act_type == ActivationType::ActivationType_RELU6,
+ RET_NO_CHANGE);
+ matmul_type->activation_type = act_type;
+ matmul_node->outputIndex = {act_node->outputIndex};
+ // cannot delete node here, otherwise will destroy order in other pattern's node index
+ // make it an isolated node to be removed in IsolatedNodeRemovePass
+ act_node->inputIndex.clear();
+ act_node->outputIndex.clear();
+ return RET_OK;
+}
+
+MatMulActivationFusionPass::~MatMulActivationFusionPass() = default;
+} // namespace lite
+} // namespace mindspore
diff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h
new file mode 100644
index 00000000..57891eb3
--- /dev/null
+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_
+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
+
+namespace mindspore {
+namespace lite {
+class MatMulActivationFusionPass : public FusionPass {
+ public:
+ MatMulActivationFusionPass() = default;
+
+ ~MatMulActivationFusionPass() override;
+
+ STATUS DefinePattern() override;
+
+ STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name,
+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override;
+};
+} // namespace lite
+} // namespace mindspore
+#endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_
diff --git a/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
new file mode 100644
index 00000000..7fb8d1f4
--- /dev/null
+++ b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h"
+#include <memory>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include "schema/inner/model_generated.h"
+#include "tools/common/meta_graph_utils.h"
+#include "src/train/optimizer/common/fusion_utils.h"
+
+namespace {
+constexpr std::string_view Reshape1Name = "RESHAPE1";
+constexpr std::string_view Reshape2Name = "RESHAPE2";
+constexpr std::string_view GatherName = "GATHER";
+} // namespace
+namespace mindspore {
+namespace lite {
+/*
+ * The subgraph such as the following.
+ * any
+ * / |
+ * reshape |
+ * \ |
+ * gather
+ * / |
+ * reshape |
+ * \ |
+ * any
+ */
+STATUS ReshapeGatherReshapeFusionPass::DefinePattern() {
+ auto reshape_op1 = std::make_shared<PatternOp>();
+ MS_CHECK_TRUE_RET(reshape_op1 != nullptr, RET_NULL_PTR);
+ reshape_op1->id = Reshape1Name;
+ reshape_op1->types = {schema::PrimitiveType_Reshape};
+
+ auto gather_op = std::make_shared<PatternOp>();
+ MS_CHECK_TRUE_RET(gather_op != nullptr, RET_NULL_PTR);
+ gather_op->id = GatherName;
+ gather_op->types = {schema::PrimitiveType_Gather};
+ gather_op->left = reshape_op1;
+
+ auto reshape_op2 = std::make_shared<PatternOp>();
+ MS_CHECK_TRUE_RET(reshape_op2 != nullptr, RET_NULL_PTR);
+ reshape_op2->id = Reshape2Name;
+ reshape_op2->types = {schema::PrimitiveType_Reshape};
+ reshape_op2->left = gather_op;
+
+ auto fusion_pattern = std::make_unique("ReshapeGatherReshapeFusion");
+ MS_CHECK_TRUE_MSG(fusion_pattern != nullptr, RET_NULL_PTR, "new fusion_pattern failed");
+ fusion_pattern->AddPatternOp(reshape_op1);
+ fusion_pattern->AddPatternOp(gather_op);
+ fusion_pattern->AddPatternOp(reshape_op2);
+ fusion_pattern->Finish();
+ this->patterns.emplace_back(fusion_pattern.release());
+ return RET_OK;
+}
+
+STATUS ReshapeGatherReshapeFusionPass::DoFusion(
+ MetaGraphT *graph, const std::string &pattern_name,
+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) {
+ MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR);
+ if (matched_path.size() != opt::kMatchPathLenThree) {
+ MS_LOG(ERROR) << "Reshape-Gather-Reshape-Fusion should have three NodeIndex in matchedPair";
+ return RET_PARAM_INVALID;
+ }
+
+ size_t reshape1_index = 0;
+ STATUS ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(Reshape1Name), &reshape1_index);
+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get reshape1_index");
+ auto &reshape1_node = graph->nodes.at(reshape1_index);
+ MS_CHECK_TRUE_MSG(reshape1_node != nullptr, RET_NULL_PTR, "reshape1_node is nullptr");
+ size_t gather_index = 0;
+ ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(GatherName), &gather_index);
+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get gather_index");
+ auto &gather_node = graph->nodes.at(gather_index);
+ MS_CHECK_TRUE_MSG(gather_node != nullptr, RET_NULL_PTR, "gather_node is nullptr");
+ size_t reshape2_index = 0;
+ ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(Reshape2Name), &reshape2_index);
+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get reshape2_index");
+ auto &reshape2_node = graph->nodes.at(reshape2_index);
+ MS_CHECK_TRUE_MSG(reshape2_node != nullptr, RET_NULL_PTR, "reshape2_node is nullptr");
+
+ if (reshape1_node->inputIndex.size() != opt::kInputSizeTwo ||
+ reshape1_node->outputIndex.size() != opt::kOutputSizeOne ||
+ reshape1_node->quantType == schema::QuantType_QUANT_ALL ||
+ reshape1_node->quantType == schema::QuantType_QUANT_DYNAMIC ||
+ reshape2_node->inputIndex.size() != opt::kInputSizeTwo ||
+ reshape2_node->outputIndex.size() != opt::kOutputSizeOne ||
+ reshape2_node->quantType == schema::QuantType_QUANT_ALL ||
+ reshape2_node->quantType == schema::QuantType_QUANT_DYNAMIC ||
+ gather_node->quantType == schema::QuantType_QUANT_ALL ||
+ gather_node->quantType == schema::QuantType_QUANT_DYNAMIC) {
+ MS_LOG(ERROR) << "reshape_node cannot fusion";
+ return RET_NO_CHANGE;
+ }
+
+ auto old_shape = graph->allTensors.at(reshape2_node->outputIndex.at(opt::kOutputIndexZero))->dims;
+ auto gather_shape0 = graph->allTensors.at(gather_node->inputIndex.at(opt::kInputIndexZero))->dims;
+ auto gather_shape1 = graph->allTensors.at(reshape1_node->inputIndex.at(opt::kInputIndexZero))->dims;
+ if (old_shape.empty() || gather_shape0.empty() || gather_shape1.empty()) {
+ return RET_NO_CHANGE;
+ }
+ int gather_axis;
+ auto data = graph->allTensors.at(gather_node->inputIndex.at(opt::kInputIndexTwo))->data;
+ if (data.empty()) {
+ gather_axis = 0;
+ } else {
+ memcpy(&gather_axis, &data[0], data.size());
+ }
+ if (gather_axis < 0) {
+ gather_axis += gather_shape1.size();
+ }
+ gather_shape0.erase(gather_shape0.begin() + gather_axis);
+ (void)gather_shape0.insert(gather_shape0.begin() + gather_axis, gather_shape1.begin(), gather_shape1.end());
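+ // The fusion is valid only if gathering on the original (pre-reshape) input yields exactly
+ // the shape the trailing reshape would produce.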
+ if (gather_shape0 != old_shape) {
+ return RET_NO_CHANGE;
+ }
+
+ gather_node->inputIndex.at(opt::kInputIndexOne) = reshape1_node->inputIndex.at(opt::kInputIndexZero);
+ gather_node->outputIndex.at(opt::kOutputIndexZero) = reshape2_node->outputIndex.at(opt::kOutputIndexZero);
+
+ // cannot delete node here, otherwise will destroy order in other pattern's node index
+ // make it an isolated node to be removed in IsolatedNodeRemovePass
+ reshape1_node->inputIndex.clear();
+ reshape1_node->outputIndex.clear();
+ reshape2_node->inputIndex.clear();
+ reshape2_node->outputIndex.clear();
+ return RET_OK;
+}
+
+ReshapeGatherReshapeFusionPass::~ReshapeGatherReshapeFusionPass() = default;
+} // namespace lite
+} // namespace mindspore
diff --git a/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h
new file mode 100644
index 00000000..ef184a3c
--- /dev/null
+++ b/mindspore/lite/src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_RESHAPE_GATHER_RESHAPE_FUSION_PASS_H_
+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_RESHAPE_GATHER_RESHAPE_FUSION_PASS_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
+
+namespace mindspore {
+namespace lite {
+class ReshapeGatherReshapeFusionPass : public FusionPass {
+ public:
+ ReshapeGatherReshapeFusionPass() = default;
+
+ ~ReshapeGatherReshapeFusionPass() override;
+
+ STATUS DefinePattern() override;
+
+ STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name,
+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override;
+};
+} // namespace lite
+} // namespace mindspore
+#endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_RESHAPE_GATHER_RESHAPE_FUSION_PASS_H_
diff --git a/mindspore/lite/src/train/train_export.cc b/mindspore/lite/src/train/train_export.cc
index a9990963..7e504c4e 100644
--- a/mindspore/lite/src/train/train_export.cc
+++ b/mindspore/lite/src/train/train_export.cc
@@ -612,8 +612,40 @@ int TrainExport::SaveModel(lite::Model *model, const std::string &file_name) {
return status;
}
+int TrainExport::SaveModel(lite::Model *model, Buffer *model_buffer) {
+ MS_CHECK_FALSE_MSG(model == nullptr, RET_ERROR, "model cannot be nullptr.");
+ MS_CHECK_FALSE_MSG(model_buffer == nullptr, RET_ERROR, "model_buffer cannot be nullptr.");
+ auto *liteModel = reinterpret_cast<LiteModel *>(model);
+ auto size = liteModel->buf_size_;
+ model_buffer->ResizeData(size);
+
+ size_t out_size = model_buffer->DataSize();
+ int status = mindspore::lite::Model::Export(model, static_cast<char *>(model_buffer->MutableData()), &out_size);
+ if (out_size != size) {
+ MS_LOG(ERROR) << "model_buffer resize failed.";
+ return RET_ERROR;
+ }
+
+ return status;
+}
+
int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }
+int TrainExport::SaveToBuffer() {
+ constexpr size_t kFbBuilderInitSize = 1024;
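+ // 1 KiB is only the initial capacity; FlatBufferBuilder grows on demand while packing.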
+ flatbuffers::FlatBufferBuilder builder(kFbBuilderInitSize);
+ auto offset = schema::MetaGraph::Pack(builder, meta_graph_);
+ builder.Finish(offset);
+ schema::FinishMetaGraphBuffer(builder, offset);
+ size_t size = builder.GetSize();
+ auto content = builder.GetBufferPointer();
+ MS_CHECK_FALSE_MSG(content == nullptr, RET_ERROR, "export content cannot be nullptr.");
+ MS_CHECK_FALSE_MSG(model_buffer_ == nullptr, RET_ERROR, "model_buffer_ cannot be nullptr.");
+ model_buffer_->SetData(content, size);
+ return RET_OK;
+}
+
+
bool TrainExport::IsInputTensor(const schema::TensorT &t) {
int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
return ((t.data.size() == 0) && (total_dims != 0));
diff --git a/mindspore/lite/src/train/train_export.h b/mindspore/lite/src/train/train_export.h
index 7727109b..8e802021 100644
--- a/mindspore/lite/src/train/train_export.h
+++ b/mindspore/lite/src/train/train_export.h
@@ -44,24 +44,28 @@ struct tensor_info {
class TrainExport {
public:
explicit TrainExport(const std::string file_name) : file_name_(file_name) {}
+ explicit TrainExport(Buffer *model_buffer) : model_buffer_(model_buffer) {}
virtual ~TrainExport();
int ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels,
const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names,
const Model *model, QuantizationType quant_type, const Model *bb_model = nullptr);
int ExportInit(const std::string model_name, std::string version);
int SaveToFile();
+ int SaveToBuffer();
void set_connect(const std::unordered_map &map) { connect_ = map; }
int LoadModel(void *buf, size_t buf_size);
int AddTransformNode();
int TrainModelFusion();
int TrainModelDrop();
int SaveModel(lite::Model *model, const std::string &file_name);
+ int SaveModel(lite::Model *model, Buffer *model_buffer);
protected:
virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor);
private:
std::string file_name_;
+ Buffer *model_buffer_ = nullptr;
schema::MetaGraphT *meta_graph_ = nullptr;
std::vector<size_t> out_idx_;
std::map<size_t, size_t> remap_;
diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc
index bda5d0a5..9874a30d 100644
--- a/mindspore/lite/src/train/train_populate_parameter.cc
+++ b/mindspore/lite/src/train/train_populate_parameter.cc
@@ -31,6 +31,8 @@
#include "nnacl/fp32_grad/smooth_l1_loss.h"
#include "nnacl/fp32_grad/resize_grad_parameter.h"
#include "nnacl/fp32_grad/lstm_grad_fp32.h"
+#include "nnacl/fp32_grad/binary_cross_entropy.h"
+#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
using mindspore::lite::Registry;
@@ -88,29 +90,44 @@ OpParameter *PopulateApplyMomentumParameter(const void *prim) {
}
OpParameter *PopulateBCEParameter(const void *prim) {
- int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
- if (reduction == nullptr) {
- MS_LOG(ERROR) << "malloc reduction failed.";
- return nullptr;
- }
auto primitive = static_cast<const schema::Primitive *>(prim);
+ MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_BinaryCrossEntropy();
- MS_ASSERT(value != nullptr);
- *reduction = value->reduction();
- return reinterpret_cast<OpParameter *>(reduction);
+ if (value == nullptr) {
+ MS_LOG(ERROR) << "value is nullptr";
+ return nullptr;
+ }
+
+ auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
+ if (param == nullptr) {
+ MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
+ return nullptr;
+ }
+ memset(param, 0, sizeof(BinaryCrossEntropyParameter));
+
+ param->op_parameter_.type_ = primitive->value_type();
+ param->reduction = value->reduction();
+ return reinterpret_cast<OpParameter *>(param);
}
OpParameter *PopulateBCEGradParameter(const void *prim) {
- int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
- if (reduction == nullptr) {
- MS_LOG(ERROR) << "malloc reduction failed.";
+ auto *primitive = static_cast<const schema::Primitive *>(prim);
+ MS_ASSERT(primitive != nullptr);
+ auto value = primitive->value_as_BinaryCrossEntropyGrad();
+ if (value == nullptr) {
+ MS_LOG(ERROR) << "param is nullptr";
return nullptr;
}
- auto primitive = static_cast<const schema::Primitive *>(prim);
- auto value = primitive->value_as_BinaryCrossEntropyGrad();
- MS_ASSERT(value != nullptr);
- *reduction = value->reduction();
- return reinterpret_cast<OpParameter *>(reduction);
+ auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
+ if (param == nullptr) {
+ MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
+ return nullptr;
+ }
+ memset(param, 0, sizeof(BinaryCrossEntropyGradParameter));
+
+ param->op_parameter_.type_ = primitive->value_type();
+ param->reduction = value->reduction();
+ return reinterpret_cast<OpParameter *>(param);
}
OpParameter *PopulateAdamParameter(const void *prim) {
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index cd1af1b6..b40ff8c2 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -206,10 +206,11 @@ static int ReshapeWeightTensor(Tensor *orig_tensor, lite::Tensor *new_tensor) {
}
}
- orig_tensor->FreeData();
- orig_tensor->set_data(nullptr);
- orig_tensor->set_shape(new_tensor->shape());
-
+ if (orig_tensor->shape() != new_tensor->shape()) {
+ orig_tensor->FreeData();
+ orig_tensor->set_data(nullptr);
+ orig_tensor->set_shape(new_tensor->shape());
+ }
uint8_t *dst_data = reinterpret_cast<uint8_t *>(orig_tensor->MutableData());
if (dst_data == nullptr) {
MS_LOG(ERROR) << "Allocation of Data Failed";
@@ -228,6 +229,9 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) {
return RET_PARAM_INVALID;
}
if (modify->tensor_name() == tensor->tensor_name()) {
+ if (tensor->Size() != modify->Size()) {
+ model_buff_changed_ = true;
+ }
auto ret = ReshapeWeightTensor(tensor, modify);
num_of_found_tensors++;
if (ret != RET_OK) {
@@ -243,6 +247,7 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) {
}
auto ret = ReSizeKernels(kernels_);
if (ret != RET_OK) {
+ model_buff_changed_ = false;
MS_LOG(ERROR) << "Resize kernels fail!";
return ret;
}
@@ -1154,9 +1159,17 @@ int TrainSession::FindExportKernels(std::vector<kernel::KernelExec *> *export_ke
return RET_OK;
}
-int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
- FormatType format, std::vector<std::string> out_put_tensor_name) {
- MS_CHECK_FALSE_MSG(file_name.empty(), RET_ERROR, "File name cannot be empty");
+template <typename DestType>
+int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
+ FormatType format, std::vector<std::string> out_put_tensor_name) {
+ if constexpr (std::is_same_v<DestType, const std::string &>) {
+ MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
+ } else if constexpr (std::is_same_v<DestType, Buffer *>) {
+ MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
+ } else {
+ MS_LOG(ERROR) << "Unsupported destination.";
+ return RET_ERROR;
+ }
MS_CHECK_FALSE_MSG(model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN, RET_ERROR,
"Export model type parameter error");
MS_CHECK_FALSE_MSG(quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT, RET_ERROR,
@@ -1165,27 +1178,21 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
bool orig_train_state = IsTrain();
Eval();
- TrainExport texport(file_name);
+ TrainExport texport(destination);
int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
- if (status != RET_OK) {
- MS_LOG(ERROR) << "cannot init export";
- return status;
- }
+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
std::vector<kernel::KernelExec *> export_kernels = {};
status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
- if (status != RET_OK) {
- MS_LOG(ERROR) << "FindExportKernels failed.";
- return RET_ERROR;
- }
+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
} else {
- if ((quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
+ if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
})) {
- status = texport.SaveModel(model_.get(), file_name);
+ status = texport.SaveModel(model_.get(), destination);
if (orig_train_state) Train();
return status;
} else {
@@ -1194,35 +1201,42 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
model_.get(), quant_type);
}
}
+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
- if (status != RET_OK) {
- MS_LOG(ERROR) << "cannot export Network";
- return status;
- }
if (model_type == MT_INFERENCE) {
status = texport.TrainModelDrop();
- if (status != RET_OK) {
- MS_LOG(ERROR) << "TrainModelDrop failed.";
- return status;
- }
+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
status = texport.TrainModelFusion();
+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
+ }
+ if constexpr (std::is_same_v<DestType, const std::string &>) {
+ status = texport.SaveToFile();
if (status != RET_OK) {
- MS_LOG(ERROR) << "TrainModelFusion failed.";
+ MS_LOG(ERROR) << "failed to save to " << destination;
return status;
}
- }
- status = texport.SaveToFile();
- if (status != RET_OK) {
- MS_LOG(ERROR) << "failed to save to " << file_name;
- return status;
+ } else {
+ status = texport.SaveToBuffer();
+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
}
if (orig_train_state) Train();
return status;
}
+
+int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
+ FormatType format, std::vector<std::string> out_put_tensor_name) {
+ return ExportInner<const std::string &>(file_name, model_type, quant_type, format, out_put_tensor_name);
+}
+
+int TrainSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
+ std::vector<std::string> out_put_tensor_name) {
+ return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
+}
+
std::vector<lite::Tensor *> TrainSession::GetFeatureMaps() const {
std::vector<lite::Tensor *> features;
for (auto cur_tensor : this->tensors_) {
- if (cur_tensor->IsConst() && cur_tensor->data_type() == kNumberTypeFloat32) {
+ if (cur_tensor->category() == lite::Category::CONST_TENSOR && cur_tensor->data_type() == kNumberTypeFloat32) {
features.push_back(cur_tensor);
}
}
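
Taken together, the train_session.cc changes keep the file-based Export source
compatible while adding an in-memory target. A hedged caller-side sketch (the
session setup is elided; Buffer is the existing mindspore::Buffer from
include/api/types.h):

    // Export to a file, as before.
    int ret = session->Export("trained.ms", mindspore::lite::MT_TRAIN,
                              mindspore::lite::QT_NONE, mindspore::lite::FT_FLATBUFFERS);
    // New overload: export into an in-memory buffer instead of the filesystem.
    mindspore::Buffer buf;
    ret = session->Export(&buf, mindspore::lite::MT_TRAIN,
                          mindspore::lite::QT_NONE, mindspore::lite::FT_FLATBUFFERS);

Both overloads funnel into the private ExportInner template, so parameter
validation and the fusion/drop passes live in exactly one place.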
diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h
index 0a0ce640..5acff82a 100644
--- a/mindspore/lite/src/train/train_session.h
+++ b/mindspore/lite/src/train/train_session.h
@@ -36,6 +36,14 @@
+-------------------------------+
*/
+#define TRAIN_SESSION_CHECK_FALSE_MSG(value, errcode, msg) \
+ do { \
+ if ((value)) { \
+ MS_LOG(ERROR) << #msg; \
+ return errcode; \
+ } \
+ } while (0)
+
namespace mindspore {
namespace lite {
using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>;
@@ -96,6 +104,8 @@ class TrainSession : virtual public lite::LiteSession {
}
int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
std::vector<std::string> out_put_tensor_name = {}) override;
+ int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
+ std::vector<std::string> out_put_tensor_name = {}) override;
std::vector<lite::Tensor *> GetFeatureMaps() const override;
@@ -165,6 +175,9 @@ class TrainSession : virtual public lite::LiteSession {
const std::unordered_map<lite::Tensor *, size_t> &offset_map,
std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx);
+ template <typename DestType>
+ int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
+ std::vector<std::string> out_put_tensor_name = {});
std::map<Tensor *, Tensor *> restored_origin_tensors_;
int virtual_batch_idx_ = 0;
int virtual_batch_multiplier_ = 0;
@@ -172,6 +185,7 @@ class TrainSession : virtual public lite::LiteSession {
void *workspace_ = nullptr;
SchedCallBack sched_mix_precision_callback_;
bool train_mode_ = false;
+ bool model_buff_changed_ = false;
void *tensors_data_ = nullptr;
size_t tensors_data_size_ = 0;
std::shared_ptr<Allocator> allocator_;
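
As a worked example of the new TRAIN_SESSION_CHECK_FALSE_MSG macro, the call
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export")
expands to roughly:

    do {
      if ((status != RET_OK)) {
        MS_LOG(ERROR) << "\"Fail to init export\"";
        return status;
      }
    } while (0)

One detail worth noting: #msg stringizes its argument, so a string-literal
argument is logged with an extra pair of escaped quotes. Streaming msg directly
(MS_LOG(ERROR) << msg) would avoid that, if the quoting is unintended.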
diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc
index b54f348e..031c4a6b 100644
--- a/mindspore/lite/src/train/transfer_session.cc
+++ b/mindspore/lite/src/train/transfer_session.cc
@@ -183,15 +183,24 @@ std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
return map;
}
-int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
- FormatType format, std::vector<std::string> out_put_tensor_name) {
+template <typename DestType>
+int TransferSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
+ FormatType format, std::vector<std::string> out_put_tensor_name) {
+ if constexpr (std::is_same_v<DestType, const std::string &>) {
+ MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
+ } else if constexpr (std::is_same_v<DestType, Buffer *>) {
+ MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
+ } else {
+ MS_LOG(ERROR) << "Unsupported destination.";
+ return RET_ERROR;
+ }
if (format != FT_FLATBUFFERS) {
MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
return RET_ERROR;
}
if (model_type == MT_TRAIN) {
- return TrainSession::Export(filename, model_type, quant_type, format);
+ return TrainSession::Export(destination, model_type, quant_type, format);
}
bool orig_train_state = IsTrain();
@@ -199,7 +208,7 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
MS_LOG(ERROR) << "eval failed.";
return RET_ERROR;
}
- TrainExport texport(filename);
+ TrainExport texport(destination);
int status = texport.LoadModel(lite_model_, size_backbone_);
if (status != RET_OK) {
MS_LOG(ERROR) << "cannot init export";
@@ -231,10 +240,15 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
MS_LOG(ERROR) << "cannot serialize head";
return status;
}
- status = texport.SaveToFile();
- if (status != RET_OK) {
- MS_LOG(ERROR) << "failed to save to " << filename;
- return status;
+ if constexpr (std::is_same_v<DestType, const std::string &>) {
+ status = texport.SaveToFile();
+ if (status != RET_OK) {
+ MS_LOG(ERROR) << "failed to save to " << destination;
+ return status;
+ }
+ } else {
+ status = texport.SaveToBuffer();
+ MS_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
}
if (orig_train_state) {
auto ret = Train();
@@ -246,6 +260,17 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
return status;
}
+int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
+ FormatType format, std::vector<std::string> out_put_tensor_name) {
+ return ExportInner<const std::string &>(filename, model_type, quant_type, format, out_put_tensor_name);
+}
+
+int TransferSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
+ std::vector<std::string> out_put_tensor_name) {
+ return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
+}
+
+
lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
const char *model_buf_head, size_t size_head, const lite::Context *context,
bool train_mode, const lite::TrainCfg *cfg) {
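
The compile-time dispatch that both ExportInner templates rely on can be
isolated in a minimal, self-contained C++17 sketch (all names below are
illustrative stand-ins, not code from this patch):

    #include <iostream>
    #include <string>
    #include <type_traits>

    struct Buffer {};  // stand-in for mindspore::Buffer

    template <typename DestType>
    int Save(DestType destination) {
      if constexpr (std::is_same_v<DestType, const std::string &>) {
        std::cout << "writing file: " << destination << "\n";  // file branch
      } else if constexpr (std::is_same_v<DestType, Buffer *>) {
        std::cout << "writing to in-memory buffer\n";  // buffer branch
      } else {
        return -1;  // mirrors the patch's "Unsupported destination" fallback
      }
      return 0;
    }

    int main() {
      const std::string path = "model.ms";
      Buffer buf;
      Save<const std::string &>(path);  // explicit argument, as the wrappers do
      Save(&buf);                       // deduction yields DestType = Buffer *
      return 0;
    }

The explicit Save<const std::string &>(path) matters: deduction from the
by-value parameter would yield DestType = std::string, which fails the
is_same_v check, so the public Export(const std::string &, ...) wrappers have
to spell out the template argument.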
diff --git a/mindspore/lite/src/train/transfer_session.h b/mindspore/lite/src/train/transfer_session.h
index 48a38b8b..6cd06c60 100644
--- a/mindspore/lite/src/train/transfer_session.h
+++ b/mindspore/lite/src/train/transfer_session.h
@@ -63,6 +63,8 @@ class TransferSession : public lite::TrainSession {
int CompileTransferGraph();
int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
std::vector<std::string> out_put_tensor_name = {}) override;
+ int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
+ std::vector<std::string> out_put_tensor_name = {}) override;
protected:
LiteSession *backbone_session_ = nullptr;
@@ -72,6 +74,9 @@ class TransferSession : public lite::TrainSession {
bool is_valid_ = false;
private:
+ template <typename DestType>
+ int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
+ std::vector<std::string> out_put_tensor_name = {});
bool CompileFormatTransform(lite::Tensor *out, lite::Tensor *in, int *mask, size_t mask_len);
std::unordered_map<size_t, size_t> ConnectionMap();
bool nchw2nhwc_ = false;
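
A note on the shape of this API: ExportInner is a private member template
declared in the header but defined only in the corresponding .cc file, and the
two public Export overrides in the same translation unit are its only
instantiation points. That keeps the template body out of the public header
while still sharing one implementation between the file and buffer paths, and
it gives the virtual, override-checked Export signatures a concrete home.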
diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc
index cad52545..20d2d298 100644
--- a/mindspore/lite/tools/benchmark_train/net_train.cc
+++ b/mindspore/lite/tools/benchmark_train/net_train.cc
@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -42,6 +42,21 @@ constexpr int kField4 = 4;
constexpr int kFieldsToPrint = 5;
constexpr int kPrintOffset = 4;
static const int kTHOUSAND = 1000;
+constexpr int kDumpInputsAndOutputs = 0;
+constexpr int kDumpOutputs = 2;
+
+const std::unordered_map