From f0daa7ef13e1741f8bcd1dfad7517a4a8ae4a209 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Thu, 21 Mar 2024 19:38:34 +0800 Subject: [PATCH] DynamicQuant strategy opyimization --- .../kernel/nnacl/dynamic_quant_parameter.h | 7 +- mindspore/core/ops/dynamic_quant.cc | 12 + mindspore/core/ops/dynamic_quant.h | 10 + mindspore/core/ops/op_name.h | 1 + mindspore/lite/schema/inner/ops_generated.h | 53 +++- mindspore/lite/schema/ops.fbs | 1 + mindspore/lite/schema/ops_generated.h | 34 +- mindspore/lite/src/common/ops/ops_def.cc | 1 + .../ops/populate/dynamic_quant_populate.cc | 24 +- .../litert/kernel/cpu/int8/dynamic_quant.cc | 299 +++++++++++------- .../litert/kernel/cpu/int8/dynamic_quant.h | 59 ++-- .../cpu/int8/matmul_dynamic_base_int8.cc | 43 ++- .../cpu/int8/matmul_dynamic_base_int8.h | 7 +- .../quantizer/insert_quant_node_manager.cc | 27 +- .../quantizer/insert_quant_node_manager.h | 5 +- 15 files changed, 395 insertions(+), 188 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h index aaabe041..1fc166cb 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h @@ -21,10 +21,9 @@ typedef struct DynamicQuantParameter { OpParameter op_parameter_; bool symmetric_; - int64_t dst_type_; - bool activation_perchannel_; - int64_t prefer_axis_; - bool transpose_; + int dst_type_; + int axis_num_; + int prefer_axes_[MAX_SHAPE_SIZE]; } DynamicQuantParameter; #endif // NNACL_DYNAMIC_QUANT_PARAMETER_H_ diff --git a/mindspore/core/ops/dynamic_quant.cc b/mindspore/core/ops/dynamic_quant.cc index 63ea0be5..1949f809 100644 --- a/mindspore/core/ops/dynamic_quant.cc +++ b/mindspore/core/ops/dynamic_quant.cc @@ -48,6 +48,18 @@ bool DynamicQuant::get_transpose() const { auto value_ptr = this->GetAttr(kTrans); return GetValue(value_ptr); } + +void DynamicQuant::set_prefer_axes(const std::vector &prefer_axes) { + (void)AddAttr(kPreferAxes, api::MakeValue(prefer_axes)); +} + +std::vector DynamicQuant::get_prefer_axes() const { + auto value_ptr = GetAttr(kPreferAxes); + auto tmp = GetValue>(value_ptr); + std::vector res(tmp.begin(), tmp.end()); + return res; +} + void DynamicQuant::Init(const bool symmetric, const int64_t dst_type) { this->set_symmetric(symmetric); this->set_dst_type(dst_type); diff --git a/mindspore/core/ops/dynamic_quant.h b/mindspore/core/ops/dynamic_quant.h index 4cb446c3..963dfb37 100644 --- a/mindspore/core/ops/dynamic_quant.h +++ b/mindspore/core/ops/dynamic_quant.h @@ -91,6 +91,16 @@ class MIND_API DynamicQuant : public BaseOperator { /// /// \return Whether transpose matrix. bool get_transpose() const; + + /// \brief Method to set prefer_axis attribute. + /// + /// \param[in] prefer_axis Define the preferred axis. + void set_prefer_axes(const std::vector &prefer_axes); + + /// \brief Method to get prefer_axis attribute. + /// + /// \return the preferred axis. + std::vector get_prefer_axes() const; }; MIND_API abstract::AbstractBasePtr DynamicQuantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args); diff --git a/mindspore/core/ops/op_name.h b/mindspore/core/ops/op_name.h index ad9066e7..1282e6ea 100644 --- a/mindspore/core/ops/op_name.h +++ b/mindspore/core/ops/op_name.h @@ -410,6 +410,7 @@ constexpr auto KCurrChunkIndex = "curr_chunk_index"; constexpr auto KCurrBitCount = "curr_bit_count"; constexpr auto KTableLog = "table_log"; constexpr auto kIgnoreIndex = "ignore_index"; +constexpr auto kPreferAxes = "prefer_axes"; constexpr size_t kInputIndex0 = 0; constexpr size_t kInputIndex1 = 1; diff --git a/mindspore/lite/schema/inner/ops_generated.h b/mindspore/lite/schema/inner/ops_generated.h index 6c861aa5..b595f4b2 100644 --- a/mindspore/lite/schema/inner/ops_generated.h +++ b/mindspore/lite/schema/inner/ops_generated.h @@ -19790,6 +19790,7 @@ struct DynamicQuantT : public flatbuffers::NativeTable { bool activation_channel = false; int64_t prefer_axis = 0; bool transpose = false; + std::vector prefer_axes{}; }; struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -19803,7 +19804,8 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_DST_TYPE = 6, VT_ACTIVATION_CHANNEL = 8, VT_PREFER_AXIS = 10, - VT_TRANSPOSE = 12 + VT_TRANSPOSE = 12, + VT_PREFER_AXES = 14 }; bool symmetric() const { return GetField(VT_SYMMETRIC, 0) != 0; @@ -19835,6 +19837,12 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { bool mutate_transpose(bool _transpose) { return SetField(VT_TRANSPOSE, static_cast(_transpose), 0); } + const flatbuffers::Vector *prefer_axes() const { + return GetPointer *>(VT_PREFER_AXES); + } + flatbuffers::Vector *mutable_prefer_axes() { + return GetPointer *>(VT_PREFER_AXES); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_SYMMETRIC) && @@ -19842,6 +19850,8 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_ACTIVATION_CHANNEL) && VerifyField(verifier, VT_PREFER_AXIS) && VerifyField(verifier, VT_TRANSPOSE) && + VerifyOffset(verifier, VT_PREFER_AXES) && + verifier.VerifyVector(prefer_axes()) && verifier.EndTable(); } DynamicQuantT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -19868,6 +19878,9 @@ struct DynamicQuantBuilder { void add_transpose(bool transpose) { fbb_.AddElement(DynamicQuant::VT_TRANSPOSE, static_cast(transpose), 0); } + void add_prefer_axes(flatbuffers::Offset> prefer_axes) { + fbb_.AddOffset(DynamicQuant::VT_PREFER_AXES, prefer_axes); + } explicit DynamicQuantBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -19885,16 +19898,37 @@ inline flatbuffers::Offset CreateDynamicQuant( int64_t dst_type = 32LL, bool activation_channel = false, int64_t prefer_axis = 0, - bool transpose = false) { + bool transpose = false, + flatbuffers::Offset> prefer_axes = 0) { DynamicQuantBuilder builder_(_fbb); builder_.add_prefer_axis(prefer_axis); builder_.add_dst_type(dst_type); + builder_.add_prefer_axes(prefer_axes); builder_.add_transpose(transpose); builder_.add_activation_channel(activation_channel); builder_.add_symmetric(symmetric); return builder_.Finish(); } +inline flatbuffers::Offset CreateDynamicQuantDirect( + flatbuffers::FlatBufferBuilder &_fbb, + bool symmetric = false, + int64_t dst_type = 32LL, + bool activation_channel = false, + int64_t prefer_axis = 0, + bool transpose = false, + const std::vector *prefer_axes = nullptr) { + auto prefer_axes__ = prefer_axes ? _fbb.CreateVector(*prefer_axes) : 0; + return mindspore::schema::CreateDynamicQuant( + _fbb, + symmetric, + dst_type, + activation_channel, + prefer_axis, + transpose, + prefer_axes__); +} + flatbuffers::Offset CreateDynamicQuant(flatbuffers::FlatBufferBuilder &_fbb, const DynamicQuantT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct LSTMGradDataT : public flatbuffers::NativeTable { @@ -26903,6 +26937,7 @@ inline void DynamicQuant::UnPackTo(DynamicQuantT *_o, const flatbuffers::resolve { auto _e = activation_channel(); _o->activation_channel = _e; } { auto _e = prefer_axis(); _o->prefer_axis = _e; } { auto _e = transpose(); _o->transpose = _e; } + { auto _e = prefer_axes(); if (_e) { _o->prefer_axes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->prefer_axes[_i] = _e->Get(_i); } } } } inline flatbuffers::Offset DynamicQuant::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DynamicQuantT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -26918,13 +26953,15 @@ inline flatbuffers::Offset CreateDynamicQuant(flatbuffers::FlatBuf auto _activation_channel = _o->activation_channel; auto _prefer_axis = _o->prefer_axis; auto _transpose = _o->transpose; + auto _prefer_axes = _o->prefer_axes.size() ? _fbb.CreateVector(_o->prefer_axes) : 0; return mindspore::schema::CreateDynamicQuant( _fbb, _symmetric, _dst_type, _activation_channel, _prefer_axis, - _transpose); + _transpose, + _prefer_axes); } inline LSTMGradDataT *LSTMGradData::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -33509,10 +33546,11 @@ inline const flatbuffers::TypeTable *LSTMTypeTable() { { flatbuffers::ET_LONG, 0, -1 }, { flatbuffers::ET_FLOAT, 0, -1 }, { flatbuffers::ET_FLOAT, 0, -1 }, - { flatbuffers::ET_FLOAT, 0, -1 } + { flatbuffers::ET_FLOAT, 0, -1 }, + { flatbuffers::ET_LONG, 0, -1 } }; static const flatbuffers::TypeTable tt = { - flatbuffers::ST_TABLE, 9, type_codes, nullptr, nullptr, nullptr, nullptr + flatbuffers::ST_TABLE, 10, type_codes, nullptr, nullptr, nullptr, nullptr }; return &tt; } @@ -34744,10 +34782,11 @@ inline const flatbuffers::TypeTable *DynamicQuantTypeTable() { { flatbuffers::ET_LONG, 0, -1 }, { flatbuffers::ET_BOOL, 0, -1 }, { flatbuffers::ET_LONG, 0, -1 }, - { flatbuffers::ET_BOOL, 0, -1 } + { flatbuffers::ET_BOOL, 0, -1 }, + { flatbuffers::ET_INT, 1, -1 } }; static const flatbuffers::TypeTable tt = { - flatbuffers::ST_TABLE, 5, type_codes, nullptr, nullptr, nullptr, nullptr + flatbuffers::ST_TABLE, 6, type_codes, nullptr, nullptr, nullptr, nullptr }; return &tt; } diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 920c0d31..153a21d0 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1250,6 +1250,7 @@ table DynamicQuant { activation_channel: bool = false; prefer_axis: long = 0; transpose: bool = false; + prefer_axes: [int]; } table LSTMGradData { diff --git a/mindspore/lite/schema/ops_generated.h b/mindspore/lite/schema/ops_generated.h index 8d387e9d..d2d89bff 100644 --- a/mindspore/lite/schema/ops_generated.h +++ b/mindspore/lite/schema/ops_generated.h @@ -13118,7 +13118,8 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_DST_TYPE = 6, VT_ACTIVATION_CHANNEL = 8, VT_PREFER_AXIS = 10, - VT_TRANSPOSE = 12 + VT_TRANSPOSE = 12, + VT_PREFER_AXES = 14 }; bool symmetric() const { return GetField(VT_SYMMETRIC, 0) != 0; @@ -13135,6 +13136,9 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { bool transpose() const { return GetField(VT_TRANSPOSE, 0) != 0; } + const flatbuffers::Vector *prefer_axes() const { + return GetPointer *>(VT_PREFER_AXES); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_SYMMETRIC) && @@ -13142,6 +13146,8 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_ACTIVATION_CHANNEL) && VerifyField(verifier, VT_PREFER_AXIS) && VerifyField(verifier, VT_TRANSPOSE) && + VerifyOffset(verifier, VT_PREFER_AXES) && + verifier.VerifyVector(prefer_axes()) && verifier.EndTable(); } }; @@ -13165,6 +13171,9 @@ struct DynamicQuantBuilder { void add_transpose(bool transpose) { fbb_.AddElement(DynamicQuant::VT_TRANSPOSE, static_cast(transpose), 0); } + void add_prefer_axes(flatbuffers::Offset> prefer_axes) { + fbb_.AddOffset(DynamicQuant::VT_PREFER_AXES, prefer_axes); + } explicit DynamicQuantBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -13182,16 +13191,37 @@ inline flatbuffers::Offset CreateDynamicQuant( int64_t dst_type = 32LL, bool activation_channel = false, int64_t prefer_axis = 0, - bool transpose = false) { + bool transpose = false, + flatbuffers::Offset> prefer_axes = 0) { DynamicQuantBuilder builder_(_fbb); builder_.add_prefer_axis(prefer_axis); builder_.add_dst_type(dst_type); + builder_.add_prefer_axes(prefer_axes); builder_.add_transpose(transpose); builder_.add_activation_channel(activation_channel); builder_.add_symmetric(symmetric); return builder_.Finish(); } +inline flatbuffers::Offset CreateDynamicQuantDirect( + flatbuffers::FlatBufferBuilder &_fbb, + bool symmetric = false, + int64_t dst_type = 32LL, + bool activation_channel = false, + int64_t prefer_axis = 0, + bool transpose = false, + const std::vector *prefer_axes = nullptr) { + auto prefer_axes__ = prefer_axes ? _fbb.CreateVector(*prefer_axes) : 0; + return mindspore::schema::CreateDynamicQuant( + _fbb, + symmetric, + dst_type, + activation_channel, + prefer_axis, + transpose, + prefer_axes__); +} + struct LSTMGradData FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef LSTMGradDataBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { diff --git a/mindspore/lite/src/common/ops/ops_def.cc b/mindspore/lite/src/common/ops/ops_def.cc index baa2497a..1e973362 100644 --- a/mindspore/lite/src/common/ops/ops_def.cc +++ b/mindspore/lite/src/common/ops/ops_def.cc @@ -1254,6 +1254,7 @@ OP_ATTR_WITH_VALUE(dst_type, long, 32) OP_ATTR_WITH_VALUE(activation_channel, bool, false) OP_ATTR_WITH_VALUE(prefer_axis, long, 0) OP_ATTR_WITH_VALUE(transpose, bool, false) +OP_ATTR(prefer_axes, [int]) OP_SCHEMA_DEF_END(DynamicQuant) OP_SCHEMA_DEF(LSTMGradData) diff --git a/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc b/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc index 3566f082..8e393320 100644 --- a/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc +++ b/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc @@ -36,11 +36,27 @@ OpParameter *PopulateDynamicQuantParameter(const void *prim) { memset(param, 0, sizeof(DynamicQuantParameter)); param->op_parameter_.type_ = primitive->value_type(); - param->dst_type_ = value->dst_type(); + param->dst_type_ = static_cast(value->dst_type()); param->symmetric_ = value->symmetric(); - param->activation_perchannel_ = value->activation_channel(); - param->prefer_axis_ = value->prefer_axis(); - param->transpose_ = value->transpose(); + auto prefer_axes = value->prefer_axes(); + if (prefer_axes != nullptr) { + param->axis_num_ = static_cast(prefer_axes->size()); + if (param->axis_num_ > MAX_SHAPE_SIZE) { + MS_LOG(ERROR) << "Dynamic quant's prefer_axes's number is more than 8."; + free(param); + return nullptr; + } + for (int i = 0; i < param->axis_num_; ++i) { + param->prefer_axes_[i] = prefer_axes->Get(i); + } + return reinterpret_cast(param); + } + auto activation_channel = value->activation_channel(); + if (!activation_channel) { + return reinterpret_cast(param); + } + param->axis_num_ = 1; + param->prefer_axes_[0] = static_cast(value->prefer_axis()); return reinterpret_cast(param); } REG_POPULATE(PrimitiveType_DynamicQuant, PopulateDynamicQuantParameter, SCHEMA_CUR); diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.cc b/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.cc index e9404ef2..acc43c97 100644 --- a/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.cc +++ b/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.cc @@ -14,14 +14,16 @@ * limitations under the License. */ #include "src/litert/kernel/cpu/int8/dynamic_quant.h" +#include #include #include #include "src/litert/kernel_registry.h" #include "schema/model_generated.h" #include "include/errorcode.h" -#include "nnacl/dynamic_quant_parameter.h" #include "nnacl/int8/dynamic_quant_int8.h" #include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl/fp32/transpose_fp32.h" +#include "nnacl/int8/transpose_int8.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; @@ -44,19 +46,10 @@ int DynamicQuantCPUKernel::Prepare() { CHECK_NULL_RETURN(in_tensor); auto out_tensor = out_tensors_.front(); CHECK_NULL_RETURN(out_tensor); - auto param = reinterpret_cast(op_parameter_); - CHECK_NULL_RETURN(param); - src_dtype_ = in_tensor->data_type(); - dst_dtype_ = param->dst_type_; - symmetric_ = param->symmetric_; - activation_perchannel_ = param->activation_perchannel_; - prefer_axis_ = param->prefer_axis_; - transpose_ = param->transpose_; - if (out_tensor->data_type() != dst_dtype_) { - MS_LOG(ERROR) << "param data type and tensor data type do not match."; - return RET_ERROR; - } - + param_ = reinterpret_cast(op_parameter_); + CHECK_NULL_RETURN(param_); + MS_CHECK_TRUE_MSG(param_->dst_type_ == out_tensor->data_type(), lite::RET_ERROR, + "param data type and tensor data type do not match."); if (!InferShapeDone()) { return RET_OK; } @@ -65,71 +58,86 @@ int DynamicQuantCPUKernel::Prepare() { int DynamicQuantCPUKernel::ReSize() { auto in_tensor = in_tensors_.front(); - num_unit_ = static_cast(in_tensor->ElementsNum()); - if (num_unit_ < kMinNums) { - thread_n_num_ = 1; + auto ele_num = static_cast(in_tensor->ElementsNum()); + auto shape = in_tensor->shape(); + int segment_num = 1; + if (param_->axis_num_ == 0) { + segment_num = MSMIN(kBucketNums, ele_num / kMinNums); } else { - thread_n_num_ = MSMIN(thread_num_, num_unit_); - // Limit for 8 thread - thread_n_num_ = MSMIN(thread_n_num_, kBucketNums); + std::set prefer_axes; + for (int i = 0; i < param_->axis_num_; ++i) { + int axis = param_->prefer_axes_[i] < 0 ? param_->prefer_axes_[i] + static_cast(shape.size()) + : param_->prefer_axes_[i]; + MS_CHECK_TRUE_MSG(axis >= 0 && axis < static_cast(shape.size()), lite::RET_ERROR, + "The prefer axis is out of range."); + if (prefer_axes.find(axis) != prefer_axes.end()) { + continue; + } + segment_num *= shape[axis]; + (void)prefer_axes.insert(axis); + } + pre_perm_.resize(shape.size()); + post_perm_.resize(shape.size()); + int pre_point0 = 0; + int pre_point1 = param_->axis_num_; + for (int i = 0; i < static_cast(shape.size()); ++i) { + if (prefer_axes.find(i) != prefer_axes.end()) { + pre_perm_[pre_point0] = i; + post_perm_[i] = pre_point0; + ++pre_point0; + } else { + pre_perm_[pre_point1] = i; + post_perm_[i] = pre_point1; + ++pre_point1; + } + } } - - int min_max_array_size = 0; - if (activation_perchannel_) { - auto dims = in_tensor->shape(); - prefer_axis_ = (prefer_axis_ < 0) ? prefer_axis_ + dims.size() : prefer_axis_; - channel_num_ = dims[prefer_axis_]; - MS_CHECK_GT(channel_num_, 0, RET_ERROR); - scale_ = reinterpret_cast(malloc(channel_num_ * sizeof(float))); - MS_CHECK_TRUE_MSG(scale_ != nullptr, RET_ERROR, "Malloc scale_ failed."); - zero_point_ = reinterpret_cast(malloc(channel_num_ * sizeof(int32_t))); - MS_CHECK_TRUE_MSG(zero_point_ != nullptr, RET_ERROR, "Malloc zero_point_ failed."); - size_t last_axis = dims.size() - 1; - row_length_ = dims[last_axis]; - channel_length_ = num_unit_ / channel_num_; - thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); - if (!transpose_ && channel_length_ > thread_n_stride_) { - thread_n_num_ = 1; + need_transpose_ = false; + for (size_t i = 0; i < pre_perm_.size(); ++i) { + if (pre_perm_[i] != static_cast(i)) { + need_transpose_ = true; } - min_max_array_size = channel_num_; - } else { - min_max_array_size = kBucketNums; } - real_min_ = reinterpret_cast(malloc(min_max_array_size * sizeof(float))); - real_max_ = reinterpret_cast(malloc(min_max_array_size * sizeof(float))); - if (real_min_ == nullptr || real_max_ == nullptr) { - return RET_NULL_PTR; + if (segment_num <= 0) { + segment_num = 1; } - for (int i = 0; i < min_max_array_size; ++i) { + real_min_.resize(segment_num); + real_max_.resize(segment_num); + scale_.resize(segment_num); + zero_point_.resize(segment_num); + for (int i = 0; i < segment_num; ++i) { real_min_[i] = FLT_MAX; real_max_[i] = -FLT_MAX; } - MS_CHECK_GT(thread_n_num_, 0, RET_ERROR); - thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); + thread_num_ = MSMIN(segment_num, op_parameter_->thread_num_); + unit_num_ = UP_DIV(ele_num, segment_num); + unit_segment_num_ = UP_DIV(segment_num, thread_num_); return RET_OK; } int DynamicQuantCPUKernel::CalculateMinMax(int task_id) { - int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); - if (num_unit_thread <= 0) { - return RET_OK; - } - int thread_offset = task_id * thread_n_stride_; - float *data = float32_ptr_ + thread_offset; - if (activation_perchannel_) { - if (transpose_) { - MS_LOG(INFO) << "attribute transpose is true."; - CalculateChannelColMinMax(data, num_unit_thread, real_min_, real_max_, row_length_); - } else { - int channel_offset = task_id * thread_n_stride_ / channel_length_; - float *real_min = real_min_ + channel_offset; - float *real_max = real_max_ + channel_offset; - CalculateChannelRowMinMax(data, num_unit_thread, real_min, real_max, row_length_); + int task_unit = unit_segment_num_ * unit_num_; + int offset = task_id * task_unit; + int ele_num = static_cast(in_tensors_.front()->ElementsNum()); + int remain = ele_num - offset; + if (task_unit <= remain) { + for (int i = 0; i < unit_segment_num_; ++i) { + CalculateMinMaxFp32(float32_ptr_ + offset + i * unit_num_, unit_num_, &real_min_[task_id * unit_segment_num_ + i], + &real_max_[task_id * unit_segment_num_ + i]); } } else { - float *real_min = real_min_ + task_id; - float *real_max = real_max_ + task_id; - CalculateMinMaxFp32(data, num_unit_thread, real_min, real_max); + int segment_num = remain / unit_num_; + int remain_ele_num = remain - segment_num * unit_num_; + for (int i = 0; i < segment_num; ++i) { + CalculateMinMaxFp32(float32_ptr_ + offset + i * unit_num_, unit_num_, &real_min_[task_id * unit_segment_num_ + i], + &real_max_[task_id * unit_segment_num_ + i]); + } + if (remain_ele_num == 0) { + return RET_OK; + } + CalculateMinMaxFp32(float32_ptr_ + offset + segment_num * unit_num_, remain_ele_num, + &real_min_[task_id * unit_segment_num_ + segment_num], + &real_max_[task_id * unit_segment_num_ + segment_num]); } return RET_OK; } @@ -148,7 +156,7 @@ int CalculateMinMaxRun(void *cdata, int task_id, float, float) { void DynamicQuantCPUKernel::CalculatePerlayerScaleZp() { float real_min = FLT_MAX; float real_max = -FLT_MAX; - for (int i = 0; i < kBucketNums; i++) { + for (size_t i = 0; i < real_min_.size(); ++i) { real_min = (real_min_[i] < real_min) ? real_min_[i] : real_min; real_max = (real_max_[i] > real_max) ? real_max_[i] : real_max; } @@ -158,7 +166,7 @@ void DynamicQuantCPUKernel::CalculatePerlayerScaleZp() { int zp = 0; constexpr int kQSymmetricRange = 255; constexpr int kQAsymmetricRange = 254; - if (!symmetric_) { + if (!param_->symmetric_) { auto range = real_max - real_min; if (range <= 0) { range = kDefaultRange; @@ -175,12 +183,11 @@ void DynamicQuantCPUKernel::CalculatePerlayerScaleZp() { quant_parm.bitNum = k8Bit; quant_parm.inited = true; this->out_tensors_.front()->set_quant_params({quant_parm}); - return; } void DynamicQuantCPUKernel::CalculatePerChannelScaleZp() { std::vector quant_params; - for (int i = 0; i < channel_num_; ++i) { + for (size_t i = 0; i < real_min_.size(); ++i) { float real_min = real_min_[i]; float real_max = real_max_[i]; @@ -189,7 +196,7 @@ void DynamicQuantCPUKernel::CalculatePerChannelScaleZp() { int zp = 0; constexpr int kQSymmetricRange = 255; constexpr int kQAsymmetricRange = 254; - if (!symmetric_) { + if (!param_->symmetric_) { auto range = real_max - real_min; if (range <= 0) { range = kDefaultRange; @@ -208,40 +215,34 @@ void DynamicQuantCPUKernel::CalculatePerChannelScaleZp() { quant_params.push_back(quant_parm); } this->out_tensors_.front()->set_quant_params(quant_params); - return; } + int DynamicQuantCPUKernel::QuantData(int task_id) { - int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); - MS_CHECK_GT(num_unit_thread, 0, RET_ERROR); - TypeId data_type = out_tensors_.front()->data_type(); - if (data_type != TypeId::kNumberTypeInt8) { - MS_LOG(ERROR) << "Data type not supported:" << data_type; - return RET_PARAM_INVALID; - } - int thread_offset = task_id * thread_n_stride_; - int ret; - if (activation_perchannel_) { - MS_CHECK_EQ(out_tensors_.front()->quant_params().size(), static_cast(channel_num_), RET_ERROR); - for (int i = 0; i < channel_num_; i++) { - auto quant_arg = out_tensors_.front()->quant_params().at(i); - scale_[i] = quant_arg.scale; - zero_point_[i] = quant_arg.zeroPoint; - } - if (transpose_) { - ret = DoChannelColFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, scale_, zero_point_, - num_unit_thread, row_length_, (int32_t)INT8_MIN, (int32_t)INT8_MAX); - } else { - ret = DoChannelRowFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, scale_, zero_point_, - num_unit_thread, row_length_, (int32_t)INT8_MIN, (int32_t)INT8_MAX); - } - } else { + int task_unit = unit_segment_num_ * unit_num_; + int offset = task_id * task_unit; + int ele_num = static_cast(in_tensors_.front()->ElementsNum()); + int remain = ele_num - offset; + task_unit = MSMIN(task_unit, remain); + if (param_->axis_num_ == 0) { // per-tensor auto quant_arg = out_tensors_.front()->quant_params().front(); - ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale, - quant_arg.zeroPoint, num_unit_thread, (int32_t)INT8_MIN, (int32_t)INT8_MAX); + auto ret = DoQuantizeFp32ToInt8(float32_ptr_ + offset, int8_ptr_ + offset, quant_arg.scale, quant_arg.zeroPoint, + task_unit, (int32_t)INT8_MIN, (int32_t)INT8_MAX); + if (ret != RET_OK) { + MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; } - if (ret != RET_OK) { - MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]"; - return RET_ERROR; + int segment_num = task_unit / unit_num_; + for (int i = 0; i < segment_num; ++i) { + auto quant_arg = out_tensors_.front()->quant_params()[task_id * unit_segment_num_ + i]; + auto ret = + DoQuantizeFp32ToInt8(float32_ptr_ + offset + i * unit_num_, int8_ptr_ + offset + i * unit_num_, quant_arg.scale, + quant_arg.zeroPoint, unit_num_, (int32_t)INT8_MIN, (int32_t)INT8_MAX); + if (ret != RET_OK) { + MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } } return RET_OK; } @@ -257,26 +258,110 @@ int QuantDataRun(void *cdata, int task_id, float, float) { return RET_OK; } +int DynamicQuantCPUKernel::MallocTmpBuffer() { + auto in_size = in_tensors_.front()->Size(); + auto out_size = out_tensors_.front()->Size(); + if (ms_context_ != nullptr && ms_context_->allocator != nullptr) { + int8_ptr_ = static_cast(ms_context_->allocator->Malloc(in_size + out_size)); + } else { + int8_ptr_ = static_cast(malloc(in_size + out_size)); + } + MS_CHECK_TRUE_MSG(int8_ptr_ != nullptr, lite::RET_NULL_PTR, "DynamicQuant malloc tmp buffer failed."); + float32_ptr_ = reinterpret_cast(int8_ptr_ + out_size); + return lite::RET_OK; +} + +void DynamicQuantCPUKernel::FreeTmpBuffer() { + if (need_transpose_) { + if (int8_ptr_ != nullptr) { + if (ms_context_ != nullptr && ms_context_->allocator != nullptr) { + ms_context_->allocator->Free(int8_ptr_); + } else { + free(int8_ptr_); + } + } + } + int8_ptr_ = nullptr; + float32_ptr_ = nullptr; +} + int DynamicQuantCPUKernel::Run() { - int8_ptr_ = reinterpret_cast(out_tensors_[0]->data()); - float32_ptr_ = reinterpret_cast(in_tensors_[0]->data()); - CHECK_NULL_RETURN(int8_ptr_); - CHECK_NULL_RETURN(float32_ptr_); - auto ret = ParallelLaunch(this->ms_context_, CalculateMinMaxRun, this, thread_n_num_); + std::vector transpose_shape; + if (need_transpose_) { + auto shape = in_tensors_.front()->shape(); + transpose_shape.resize(shape.size()); + for (size_t i = 0; i < shape.size(); ++i) { + transpose_shape[i] = shape[pre_perm_[i]]; + } + if (MallocTmpBuffer() != lite::RET_OK) { + MS_LOG(ERROR) << "DynamicQuant MallocTmpBuffer failed."; + return lite::RET_NULL_PTR; + } + std::vector strides(shape.size(), 1); + std::vector out_strides(shape.size(), 1); + for (int i = static_cast(shape.size()) - C2NUM; i >= 0; i--) { + strides[i] = shape[i + 1] * strides[i + 1]; + out_strides[i] = transpose_shape[i + 1] * out_strides[i + 1]; + } + if (shape.size() <= C6NUM) { + (void)DoTransposeFp32(in_tensors_.front()->data(), float32_ptr_, transpose_shape.data(), pre_perm_.data(), + strides.data(), out_strides.data(), in_tensors_.front()->Size(), shape.size()); + } else { + TransposeDimsFp32(in_tensors_.front()->data(), float32_ptr_, transpose_shape.data(), pre_perm_.data(), + strides.data(), out_strides.data(), shape.size(), 0, 1); + } + } else { + int8_ptr_ = reinterpret_cast(out_tensors_[0]->data()); + float32_ptr_ = reinterpret_cast(in_tensors_[0]->data()); + } + if (int8_ptr_ == nullptr || float32_ptr_ == nullptr) { + FreeTmpBuffer(); + MS_LOG(ERROR) << "DynamicQuant's original data exists nullptr."; + return lite::RET_NULL_PTR; + } + auto ret = ParallelLaunch(this->ms_context_, CalculateMinMaxRun, this, thread_num_); if (ret != RET_OK) { + FreeTmpBuffer(); MS_LOG(ERROR) << "Run error error_code[" << ret << "]"; return RET_ERROR; } - if (activation_perchannel_) { + if (param_->axis_num_ != 0) { CalculatePerChannelScaleZp(); } else { CalculatePerlayerScaleZp(); } - ret = ParallelLaunch(this->ms_context_, QuantDataRun, this, thread_n_num_); + ret = ParallelLaunch(this->ms_context_, QuantDataRun, this, thread_num_); if (ret != RET_OK) { + FreeTmpBuffer(); MS_LOG(ERROR) << "Run error error_code[" << ret << "]"; return RET_ERROR; } + if (need_transpose_) { + auto out_shape = out_tensors_.front()->shape(); + TransposeParameter trans_parameter; + (void)memset(&trans_parameter, 0, sizeof(TransposeParameter)); + trans_parameter.op_parameter_.thread_num_ = 1; + trans_parameter.num_axes_ = static_cast(out_shape.size()); + trans_parameter.data_num_ = out_tensors_[0]->ElementsNum(); + trans_parameter.perm_size_ = post_perm_.size(); + int last_index = static_cast(out_shape.size()) - 1; + trans_parameter.perm_[last_index] = post_perm_[last_index]; + trans_parameter.strides_[last_index] = 1; + trans_parameter.out_strides_[last_index] = 1; + for (int i = last_index - 1; i >= 0; i--) { + trans_parameter.perm_[i] = post_perm_[i]; + trans_parameter.strides_[i] = transpose_shape[i + 1] * trans_parameter.strides_[i + 1]; + trans_parameter.out_strides_[i] = out_shape[i + 1] * trans_parameter.out_strides_[i + 1]; + } + if (out_shape.size() <= C6NUM) { + (void)DoTransposeInt8(int8_ptr_, reinterpret_cast(out_tensors_[0]->data()), out_shape.data(), + &trans_parameter); + } else { + TransposeDimsInt8(int8_ptr_, reinterpret_cast(out_tensors_[0]->data()), out_shape.data(), + &trans_parameter, 0, 1); + } + } + FreeTmpBuffer(); return RET_OK; } diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.h b/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.h index ca84f088..023f1fab 100644 --- a/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.h +++ b/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.h @@ -21,31 +21,15 @@ #include #include #include "src/litert/lite_kernel.h" +#include "nnacl/dynamic_quant_parameter.h" namespace mindspore::kernel { class DynamicQuantCPUKernel : public LiteKernel { public: DynamicQuantCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx) - : LiteKernel(parameter, inputs, outputs, ctx), thread_num_(ctx->thread_num_) {} - ~DynamicQuantCPUKernel() override { - if (real_min_ != nullptr) { - free(real_min_); - real_min_ = nullptr; - } - if (real_max_ != nullptr) { - free(real_max_); - real_max_ = nullptr; - } - if (scale_ != nullptr) { - free(scale_); - scale_ = nullptr; - } - if (zero_point_ != nullptr) { - free(zero_point_); - zero_point_ = nullptr; - } - }; + : LiteKernel(parameter, inputs, outputs, ctx) {} + ~DynamicQuantCPUKernel() override = default; int Prepare() override; int ReSize() override; @@ -57,28 +41,21 @@ class DynamicQuantCPUKernel : public LiteKernel { private: void CalculatePerlayerScaleZp(); void CalculatePerChannelScaleZp(); - - private: - int thread_num_; - int thread_n_num_{0}; - int thread_n_stride_{0}; - int num_unit_{0}; - int8_t *int8_ptr_ = nullptr; - float *float32_ptr_ = nullptr; - float *real_min_ = nullptr; - float *real_max_ = nullptr; - float *scale_ = nullptr; - int32_t *zero_point_ = nullptr; - - int32_t src_dtype_{0}; - int32_t dst_dtype_{0}; - bool symmetric_ = false; - bool activation_perchannel_ = false; - bool transpose_ = false; - int32_t prefer_axis_{-1}; - int32_t channel_num_{0}; - int32_t channel_length_{0}; - int32_t row_length_{0}; + int MallocTmpBuffer(); + void FreeTmpBuffer(); + + DynamicQuantParameter *param_{nullptr}; + std::vector real_min_; + std::vector real_max_; + std::vector scale_; + std::vector zero_point_; + std::vector pre_perm_; + std::vector post_perm_; + int8_t *int8_ptr_{nullptr}; + float *float32_ptr_{nullptr}; + int unit_num_{0}; + int unit_segment_num_{0}; + bool need_transpose_{false}; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc b/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc index adae37aa..bab1f730 100644 --- a/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc +++ b/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc @@ -54,12 +54,12 @@ void MatmulDynamicBaseInt8CPUKernel::FreeQuantParam() { } int MatmulDynamicBaseInt8CPUKernel::MallocQuantParam() { - quant_param_ = reinterpret_cast(malloc(sizeof(MatmulQuantParameter))); + quant_param_ = reinterpret_cast(malloc(sizeof(MatmulDynamicQuantParameter))); if (quant_param_ == nullptr) { MS_LOG(ERROR) << "Malloc MatmulDynamicQuantParameter for Matmul int8 op failed!"; return RET_ERROR; } - memset(quant_param_, 0, sizeof(MatmulQuantParameter)); + (void)memset(quant_param_, 0, sizeof(MatmulDynamicQuantParameter)); return RET_OK; } @@ -80,9 +80,16 @@ int MatmulDynamicBaseInt8CPUKernel::InitFilterQuantParam() { MS_LOG(ERROR) << weight_tensor->tensor_name() << " dims < 2."; return RET_ERROR; } - int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1]; filter_per_channel_ = (weight_quant_params.size() > 1); - auto channel_num = filter_per_channel_ ? col : 1; + filter_per_batch_channel_ = false; + int channel_num = 1; + if (filter_per_channel_) { + channel_num = param_->col_; + if (weight_quant_params.size() > static_cast(channel_num)) { + filter_per_batch_channel_ = true; + channel_num = in_tensors_.at(kWeightIndex)->ElementsNum() / param_->deep_; + } + } if (static_cast(weight_quant_params.size()) != channel_num) { MS_LOG(ERROR) << weight_tensor->tensor_name() << " quant params size:" << weight_quant_params.size() << " != channel_num:" << channel_num; @@ -90,10 +97,10 @@ int MatmulDynamicBaseInt8CPUKernel::InitFilterQuantParam() { } quant_param_->filter_scale_ = reinterpret_cast(malloc(channel_num * sizeof(float))); CHECK_NULL_RETURN(quant_param_->filter_scale_); - memset(quant_param_->filter_scale_, 0, sizeof(channel_num)); + (void)memset(quant_param_->filter_scale_, 0, sizeof(channel_num)); quant_param_->filter_zp_ = reinterpret_cast(malloc(channel_num * sizeof(int32_t))); CHECK_NULL_RETURN(quant_param_->filter_zp_); - memset(quant_param_->filter_zp_, 0, sizeof(channel_num)); + (void)memset(quant_param_->filter_zp_, 0, sizeof(channel_num)); for (int i = 0; i < channel_num; i++) { quant_param_->filter_scale_[i] = static_cast(weight_quant_params[i].scale); @@ -143,7 +150,15 @@ int MatmulDynamicBaseInt8CPUKernel::InitInputQuantParam(std::vector *scal return RET_ERROR; } input_per_channel_ = (in_quant_params.size() > 1); - auto channel_num = input_per_channel_ ? param_->row_ : 1; + input_per_batch_channel_ = false; + int channel_num = 1; + if (input_per_channel_) { + channel_num = param_->row_; + if (in_quant_params.size() > static_cast(channel_num)) { + input_per_batch_channel_ = true; + channel_num = in_tensors_.at(kInputIndex)->ElementsNum() / param_->deep_; + } + } if (static_cast(in_quant_params.size()) != channel_num) { MS_LOG(ERROR) << in_tensors_.at(kInputIndex)->tensor_name() << " quant params size:" << in_quant_params.size() << " != channel_num:" << channel_num; @@ -199,7 +214,7 @@ int MatmulDynamicBaseInt8CPUKernel::InitMatrixABuffer() { return lite::RET_NULL_PTR; } input_sums_ = reinterpret_cast(pack_a_ptr_ + pack_a_size); - memset(pack_a_ptr_, 0, pack_a_size + sum_a_size); + (void)memset(pack_a_ptr_, 0, pack_a_size + sum_a_size); return RET_OK; } @@ -240,8 +255,8 @@ int MatmulDynamicBaseInt8CPUKernel::InitMatrixBBuffer() { FreeTmpBuffer(); return RET_ERROR; } - memset(pack_b_ptr_, 0, b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)); - memset(weight_sums_, 0, b_batch_ * param_->col_align_ * sizeof(int)); + (void)memset(pack_b_ptr_, 0, b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)); + (void)memset(weight_sums_, 0, b_batch_ * param_->col_align_ * sizeof(int)); return RET_OK; } @@ -258,7 +273,7 @@ int MatmulDynamicBaseInt8CPUKernel::CopyBias() { FreeTmpBuffer(); return RET_MEMORY_FAILED; } - memcpy(bias_ptr_, bias_tensor->data(), bias_tensor->Size()); + (void)memcpy(bias_ptr_, bias_tensor->data(), bias_tensor->Size()); } else { bias_ptr_ = nullptr; } @@ -352,6 +367,8 @@ int MatmulDynamicBaseInt8CPUKernel::ReSize() { int MatmulDynamicBaseInt8CPUKernel::InitBroadcastParams(const std::vector &a_shape_const, const std::vector &b_shape_const, MatMulParameter *params, std::vector *a_offsets, std::vector *b_offsets) { + CHECK_NULL_RETURN(a_offsets); + CHECK_NULL_RETURN(b_offsets); std::vector a_shape = a_shape_const; if (a_shape.size() < kNCHWDimNumber) { size_t add_nums = kNCHWDimNumber - a_shape.size(); @@ -370,8 +387,8 @@ int MatmulDynamicBaseInt8CPUKernel::InitBroadcastParams(const std::vector & int batch_sizes[MAX_SHAPE_SIZE] = {0}; int a_batch_sizes[MAX_SHAPE_SIZE] = {0}; int b_batch_sizes[MAX_SHAPE_SIZE] = {0}; - for (int i = a_shape.size() - kCHWDimNumber; i >= 0; --i) { - if (static_cast(a_shape.size() - kCHWDimNumber) == i) { + for (int i = static_cast(a_shape.size()) - kCHWDimNumber; i >= 0; --i) { + if (static_cast(a_shape.size()) - kCHWDimNumber == i) { batch_sizes[i] = std::max(a_shape[i], b_shape[i]); a_batch_sizes[i] = a_shape[i]; b_batch_sizes[i] = b_shape[i]; diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h b/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h index 3fc20d80..858affc8 100644 --- a/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h +++ b/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h @@ -58,6 +58,8 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel { int b_batch_ = 1; std::vector a_offset_; std::vector b_offset_; + int a_quant_offset_ = 0; + int b_quant_offset_ = 0; typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col); virtual void InitParameter() = 0; int TransferA(); @@ -69,14 +71,15 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel { int InitMatrixABuffer(); void FreeMatrixABuffer(); - protected: MatMulParameter *param_ = nullptr; MatmulDynamicQuantParameter *quant_param_ = nullptr; int8_t *pack_a_ptr_ = nullptr; int8_t *pack_b_ptr_ = nullptr; bool input_per_channel_ = false; - bool filter_per_channel_ = true; + bool input_per_batch_channel_ = false; + bool filter_per_channel_ = false; + bool filter_per_batch_channel_ = false; int8_t *batch_input_ptr_ = nullptr; int8_t *batch_weight_ptr_ = nullptr; int8_t *batch_a_ptr_ = nullptr; diff --git a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc index 721a1a8c..03113eaa 100644 --- a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc +++ b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc @@ -102,7 +102,7 @@ int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &grap bool symmetric = activation_channel ? true : false; primitive->set_symmetric(symmetric); primitive->set_activation_channel(activation_channel); - if (activation_channel && SetPreferAxis(cnode, index, primitive) != RET_OK) { + if (activation_channel && SetPreferAxes(cnode, index, primitive) != RET_OK) { MS_LOG(ERROR) << "Set prefer axis failed, " << cnode->fullname_with_scope(); return RET_ERROR; } @@ -127,18 +127,25 @@ int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &grap return RET_OK; } -int InsertQuantNodeManager::SetPreferAxis(const CNodePtr &cnode, size_t index, +int InsertQuantNodeManager::SetPreferAxes(const CNodePtr &cnode, size_t index, const std::shared_ptr &dynamic_primitive) { auto primitive = GetValueNode(cnode->input(0)); if (primitive->name() == ops::kNameMatMulFusion || primitive->name() == ops::kNameMatMul) { auto matmul_prim = api::MakeShared(primitive); CHECK_NULL_RETURN(matmul_prim); + auto shape = opt::GetAnfNodeOutputShape(cnode->input(index), 0); + std::vector prefer_axes; + for (int i = 0; i < static_cast(shape.size()) - C2NUM; ++i) { + prefer_axes.push_back(i); + } // For MatMul A if (index == kInputIndex + kPrimOffset) { if (matmul_prim->GetAttr(ops::kTransposeA) != nullptr && matmul_prim->get_transpose_a()) { + prefer_axes.push_back(kLastFisrtIndex); dynamic_primitive->set_prefer_axis(kLastFisrtIndex); dynamic_primitive->set_transpose(true); } else { + prefer_axes.push_back(kLastSecondIndex); dynamic_primitive->set_prefer_axis(kLastSecondIndex); dynamic_primitive->set_transpose(false); } @@ -146,13 +153,16 @@ int InsertQuantNodeManager::SetPreferAxis(const CNodePtr &cnode, size_t index, // For MatMul B if (index == kWeightIndex + kPrimOffset) { if (matmul_prim->GetAttr(ops::kTransposeB) != nullptr && matmul_prim->get_transpose_b()) { + prefer_axes.push_back(kLastSecondIndex); dynamic_primitive->set_prefer_axis(kLastSecondIndex); dynamic_primitive->set_transpose(true); } else { + prefer_axes.push_back(kLastFisrtIndex); dynamic_primitive->set_prefer_axis(kLastFisrtIndex); dynamic_primitive->set_transpose(false); } } + dynamic_primitive->set_prefer_axes(prefer_axes); } else { MS_LOG(WARNING) << "cnode don't need prefer axis, cnode name: " << cnode->fullname_with_scope(); } @@ -167,13 +177,17 @@ int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const return RET_ERROR; } auto input = cnode->input(kInputIndex + kPrimOffset); + auto weight = cnode->input(kWeightIndex + kPrimOffset); + if (activation_channel && (input->isa() || IsGraphInput(input)) && + (weight->isa() || IsGraphInput(weight))) { + return RET_NOT_SUPPORT; + } if (input->isa() || IsGraphInput(input)) { auto ret = InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimOffset, activation_channel); if (ret != RET_OK) { MS_LOG(ERROR) << "Insert dynamic quant with index failed."; } } - auto weight = cnode->input(kWeightIndex + kPrimOffset); if (weight->isa() || IsGraphInput(weight)) { auto ret = InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimOffset, activation_channel); if (ret != RET_OK) { @@ -218,6 +232,9 @@ int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph, continue; } ret = NewDynamicQuantNode(graph, cnode, activation_channel); + if (ret == RET_NOT_SUPPORT) { + continue; + } if (ret != RET_OK) { MS_LOG(ERROR) << "node:" << op_name << " new dynamic quant node failed."; return ret; @@ -684,7 +701,7 @@ int InsertQuantNodeManager::InsertQuantDtypeCastFlyNode(const FuncGraphPtr &func int InsertQuantNodeManager::CalculateScaleZPNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, ParameterPtr *scales_node, ParameterPtr *zps_node, - TypeId src_dtype, TypeId dst_dtype, int axis) { + TypeId dst_dtype, int axis) { CHECK_NULL_RETURN(scales_node); CHECK_NULL_RETURN(zps_node); auto input_node = cnode->input(input_index); @@ -785,7 +802,7 @@ int InsertQuantNodeManager::InsertAscendAntiQuantNode(const FuncGraphPtr &func_g CHECK_NULL_RETURN(cast_cnode); ParameterPtr scales_node; ParameterPtr zps_node; - auto ret = CalculateScaleZPNode(func_graph, cnode, input_index, &scales_node, &zps_node, src_dtype, dst_dtype, axis); + auto ret = CalculateScaleZPNode(func_graph, cnode, input_index, &scales_node, &zps_node, dst_dtype, axis); if (ret != RET_OK) { MS_LOG(ERROR) << "Fail to Remove node: " << input_node->fullname_with_scope() << " quant param"; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h index a46e8c68..6f328485 100644 --- a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h +++ b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h @@ -75,13 +75,12 @@ class InsertQuantNodeManager { int MarkDynamicQuantize(const CNodePtr &cnode); int CalculateScaleZPNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, - ParameterPtr *scales_node, ParameterPtr *zps_node, TypeId src_dtype, TypeId dst_dtype, - int axis); + ParameterPtr *scales_node, ParameterPtr *zps_node, TypeId dst_dtype, int axis); int InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index, bool activation_channel = true); - int SetPreferAxis(const CNodePtr &cnode, size_t index, const std::shared_ptr &dynamic_primitive); + int SetPreferAxes(const CNodePtr &cnode, size_t index, const std::shared_ptr &dynamic_primitive); int SetCastNodeAbstract(const CNodePtr &cnode, const AnfNodePtr &input_node, const CNodePtr &cast_cnode); -- 2.25.1