1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "coder/opcoders/nnacl/dequant/de_quant.h"
18 #include <string>
19 #include <vector>
20 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
21
22 namespace mindspore::lite::micro::nnacl {
// Quant-param count that marks a tensor as per-tensor quantized (single scale/zero-point).
constexpr int kPerTensor = 1;
// Shape rank at which a weight tensor is treated as per-batch (per-channel) quantized
// when its quant-param count equals its first dimension.
constexpr size_t kPerBatch = 3;
25
set_de_quant_buffer_str(const std::string & dequant_buffer_str)26 void Dequant::set_de_quant_buffer_str(const std::string &dequant_buffer_str) {
27 de_quant_buffer_str_ = "(float *)(" + dequant_buffer_str + ")";
28 }
29
DequantRecordWorkspcae(size_t curr_workspace)30 void Dequant::DequantRecordWorkspcae(size_t curr_workspace) {
31 de_quant_max_workspace_ = de_quant_max_workspace_ > curr_workspace ? de_quant_max_workspace_ : curr_workspace;
32 }
33
CheckDequantFlag(const Tensor * weight_tensor)34 bool Dequant::CheckDequantFlag(const Tensor *weight_tensor) {
35 if (weight_tensor == nullptr) {
36 return false;
37 }
38 return !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
39 weight_tensor->data() != nullptr;
40 }
41
DeQuantFunctionPerChannel(const Tensor * quant_tensor,const std::vector<DeQuantArg> & de_quant_args,const std::string & de_quant_arg_base_str,NNaclFp32Serializer * const de_quant_code)42 void Dequant::DeQuantFunctionPerChannel(const Tensor *quant_tensor, const std::vector<DeQuantArg> &de_quant_args,
43 const std::string &de_quant_arg_base_str,
44 NNaclFp32Serializer *const de_quant_code) {
45 int quant_arg_dims = static_cast<int>(quant_tensor->quant_params().size());
46 int de_quant_nums = quant_tensor->ElementsNum();
47 for (int i = 0; i < quant_arg_dims; ++i) {
48 auto de_quant_arg = de_quant_args.at(i);
49 std::string de_quant_arg_str = de_quant_arg_base_str + std::to_string(i);
50 de_quant_code->CodeStruct(de_quant_arg_str, de_quant_arg);
51 }
52 std::string de_quant_args_name = "de_quant_args";
53 *de_quant_code << "const DeQuantArg *" << de_quant_args_name << "[" << quant_arg_dims << "] = {\n";
54 for (int i = 0; i < quant_arg_dims - 1; ++i) {
55 *de_quant_code << "&" << de_quant_arg_base_str << std::to_string(i) << ", ";
56 }
57 *de_quant_code << "&" << de_quant_arg_base_str << std::to_string(quant_arg_dims - 1);
58 *de_quant_code << "};\n";
59 size_t per_batch_size = quant_tensor->shape().at(0);
60 std::string quant_tensor_addr_str = "(int8_t *)(" + quant_tensor_addr_ + ")";
61 de_quant_code->CodeFunction("DequantDataPerChannel", quant_tensor_addr_str, de_quant_args_name, de_quant_nums,
62 per_batch_size, de_quant_buffer_str_);
63 }
64
DeQuantFunction(const Tensor * quant_tensor,const std::vector<DeQuantArg> & de_quant_args,const std::string & de_quant_arg_base_str,NNaclFp32Serializer * const de_quant_code)65 void Dequant::DeQuantFunction(const Tensor *quant_tensor, const std::vector<DeQuantArg> &de_quant_args,
66 const std::string &de_quant_arg_base_str, NNaclFp32Serializer *const de_quant_code) {
67 int quant_arg_dims = static_cast<int>(quant_tensor->quant_params().size());
68 int de_quant_nums = quant_tensor->ElementsNum();
69 for (int i = 0; i < quant_arg_dims; ++i) {
70 auto de_quant_arg = de_quant_args.at(i);
71 std::string de_quant_arg_str = de_quant_arg_base_str + std::to_string(i);
72 de_quant_code->CodeStruct(de_quant_arg_str, de_quant_arg);
73 }
74 std::string de_quant_args_name = "de_quant_args";
75 *de_quant_code << "const DeQuantArg *" << de_quant_args_name << "[" << quant_arg_dims << "] = {\n";
76 for (int i = 0; i < quant_arg_dims - 1; ++i) {
77 *de_quant_code << "&" << de_quant_arg_base_str << std::to_string(i) << ", ";
78 }
79 *de_quant_code << "&" << de_quant_arg_base_str << std::to_string(quant_arg_dims - 1);
80 *de_quant_code << "};\n";
81 int32_t channels = quant_tensor->Batch();
82 std::string quant_tensor_addr_str = "(int8_t *)(" + quant_tensor_addr_ + ")";
83 de_quant_code->CodeFunction("DequantData", quant_tensor_addr_str, de_quant_args_name, de_quant_nums, channels,
84 de_quant_buffer_str_);
85 }
86
DeQuantFunctionPerTensor(const Tensor * quant_tensor,const std::vector<DeQuantArg> & de_quant_args,const std::string & de_quant_arg_base_str,NNaclFp32Serializer * const de_quant_code)87 void Dequant::DeQuantFunctionPerTensor(const Tensor *quant_tensor, const std::vector<DeQuantArg> &de_quant_args,
88 const std::string &de_quant_arg_base_str,
89 NNaclFp32Serializer *const de_quant_code) {
90 size_t de_quant_nums = quant_tensor->ElementsNum();
91 auto de_quant_arg = de_quant_args.at(0);
92 std::string de_quant_arg_str = de_quant_arg_base_str + std::to_string(0);
93 de_quant_code->CodeStruct(de_quant_arg_str, de_quant_arg);
94 std::string de_quant_args_name = "de_quant_args";
95 *de_quant_code << "const DeQuantArg *" << de_quant_args_name << "[" << 1 << "] = {\n";
96 *de_quant_code << "&" << de_quant_arg_base_str << std::to_string(0);
97 *de_quant_code << "};\n";
98 std::string quant_tensor_addr_str = "(int8_t *)(" + quant_tensor_addr_ + ")";
99 de_quant_code->CodeFunction("DequantDataPerTensor", quant_tensor_addr_str, de_quant_args_name, de_quant_nums,
100 de_quant_buffer_str_);
101 }
102
GetMicroDeQuantFunction(const Tensor * quant_tensor,const std::string & quant_tensor_addr)103 std::string Dequant::GetMicroDeQuantFunction(const Tensor *quant_tensor, const std::string &quant_tensor_addr) {
104 std::string de_quant_block;
105 if (quant_tensor == nullptr || de_quant_buffer_str_.empty()) {
106 return de_quant_block;
107 }
108 quant_tensor_addr_ = quant_tensor_addr;
109 size_t de_quant_nums = quant_tensor->ElementsNum();
110 size_t quant_arg_dims = quant_tensor->quant_params().size();
111 DequantRecordWorkspcae(static_cast<size_t>(de_quant_nums * sizeof(float)));
112 NNaclFp32Serializer de_quant_code;
113 de_quant_code << "{\n";
114 size_t quant_tensor_dims = quant_tensor->shape().size();
115 std::vector<DeQuantArg> de_quant_args;
116 std::string de_quant_arg_base_str = "de_quant_arg_";
117 for (size_t i = 0; i < quant_arg_dims; ++i) {
118 auto curr_quant_param = quant_tensor->quant_params().at(i);
119 DeQuantArg de_quant_arg = {
120 .scale = static_cast<float>(curr_quant_param.scale),
121 .zeroPoint = curr_quant_param.zeroPoint,
122 .var_corr = curr_quant_param.var_corr,
123 .mean_corr = curr_quant_param.mean_corr,
124 // this clusters is meaningless which will be supported in future
125 .clusters = {},
126 .clusters_nums = static_cast<int>(curr_quant_param.clusters.size()),
127 .bitNum = quant_tensor->quant_params().at(i).bitNum,
128 };
129 de_quant_args.emplace_back(de_quant_arg);
130 }
131 de_quant_code.CodeFunction("memset", de_quant_buffer_str_, 0, de_quant_nums * sizeof(float));
132 if (quant_tensor_dims == kPerBatch && quant_arg_dims == static_cast<size_t>(quant_tensor->shape().at(0))) {
133 DeQuantFunctionPerChannel(quant_tensor, de_quant_args, de_quant_arg_base_str, &de_quant_code);
134 } else if (quant_arg_dims != kPerTensor) {
135 DeQuantFunction(quant_tensor, de_quant_args, de_quant_arg_base_str, &de_quant_code);
136 } else {
137 DeQuantFunctionPerTensor(quant_tensor, de_quant_args, de_quant_arg_base_str, &de_quant_code);
138 }
139 de_quant_code << "}\n";
140 de_quant_block = de_quant_code.str();
141 return de_quant_block;
142 }
143 } // namespace mindspore::lite::micro::nnacl
144