• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "coder/opcoders/nnacl/dequant/de_quant.h"
#include <algorithm>
#include <string>
#include <vector>
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
21 
namespace mindspore::lite::micro::nnacl {
// Exactly one quant param on a tensor means per-tensor quantization.
constexpr int kPerTensor = 1;
// Shape-dim count that marks a candidate for per-channel (per-batch) dequant;
// see GetMicroDeQuantFunction, which also requires one quant param per batch.
constexpr size_t kPerBatch = 3;
25 
set_de_quant_buffer_str(const std::string & dequant_buffer_str)26 void Dequant::set_de_quant_buffer_str(const std::string &dequant_buffer_str) {
27   de_quant_buffer_str_ = "(float *)(" + dequant_buffer_str + ")";
28 }
29 
DequantRecordWorkspcae(size_t curr_workspace)30 void Dequant::DequantRecordWorkspcae(size_t curr_workspace) {
31   de_quant_max_workspace_ = de_quant_max_workspace_ > curr_workspace ? de_quant_max_workspace_ : curr_workspace;
32 }
33 
CheckDequantFlag(const Tensor * weight_tensor)34 bool Dequant::CheckDequantFlag(const Tensor *weight_tensor) {
35   if (weight_tensor == nullptr) {
36     return false;
37   }
38   return !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
39          weight_tensor->data() != nullptr;
40 }
41 
DeQuantFunctionPerChannel(const Tensor * quant_tensor,const std::vector<DeQuantArg> & de_quant_args,const std::string & de_quant_arg_base_str,NNaclFp32Serializer * const de_quant_code)42 void Dequant::DeQuantFunctionPerChannel(const Tensor *quant_tensor, const std::vector<DeQuantArg> &de_quant_args,
43                                         const std::string &de_quant_arg_base_str,
44                                         NNaclFp32Serializer *const de_quant_code) {
45   int quant_arg_dims = static_cast<int>(quant_tensor->quant_params().size());
46   int de_quant_nums = quant_tensor->ElementsNum();
47   for (int i = 0; i < quant_arg_dims; ++i) {
48     auto de_quant_arg = de_quant_args.at(i);
49     std::string de_quant_arg_str = de_quant_arg_base_str + std::to_string(i);
50     de_quant_code->CodeStruct(de_quant_arg_str, de_quant_arg);
51   }
52   std::string de_quant_args_name = "de_quant_args";
53   *de_quant_code << "const DeQuantArg *" << de_quant_args_name << "[" << quant_arg_dims << "] = {\n";
54   for (int i = 0; i < quant_arg_dims - 1; ++i) {
55     *de_quant_code << "&" << de_quant_arg_base_str << std::to_string(i) << ", ";
56   }
57   *de_quant_code << "&" << de_quant_arg_base_str << std::to_string(quant_arg_dims - 1);
58   *de_quant_code << "};\n";
59   size_t per_batch_size = quant_tensor->shape().at(0);
60   std::string quant_tensor_addr_str = "(int8_t *)(" + quant_tensor_addr_ + ")";
61   de_quant_code->CodeFunction("DequantDataPerChannel", quant_tensor_addr_str, de_quant_args_name, de_quant_nums,
62                               per_batch_size, de_quant_buffer_str_);
63 }
64 
DeQuantFunction(const Tensor * quant_tensor,const std::vector<DeQuantArg> & de_quant_args,const std::string & de_quant_arg_base_str,NNaclFp32Serializer * const de_quant_code)65 void Dequant::DeQuantFunction(const Tensor *quant_tensor, const std::vector<DeQuantArg> &de_quant_args,
66                               const std::string &de_quant_arg_base_str, NNaclFp32Serializer *const de_quant_code) {
67   int quant_arg_dims = static_cast<int>(quant_tensor->quant_params().size());
68   int de_quant_nums = quant_tensor->ElementsNum();
69   for (int i = 0; i < quant_arg_dims; ++i) {
70     auto de_quant_arg = de_quant_args.at(i);
71     std::string de_quant_arg_str = de_quant_arg_base_str + std::to_string(i);
72     de_quant_code->CodeStruct(de_quant_arg_str, de_quant_arg);
73   }
74   std::string de_quant_args_name = "de_quant_args";
75   *de_quant_code << "const DeQuantArg *" << de_quant_args_name << "[" << quant_arg_dims << "] = {\n";
76   for (int i = 0; i < quant_arg_dims - 1; ++i) {
77     *de_quant_code << "&" << de_quant_arg_base_str << std::to_string(i) << ", ";
78   }
79   *de_quant_code << "&" << de_quant_arg_base_str << std::to_string(quant_arg_dims - 1);
80   *de_quant_code << "};\n";
81   int32_t channels = quant_tensor->Batch();
82   std::string quant_tensor_addr_str = "(int8_t *)(" + quant_tensor_addr_ + ")";
83   de_quant_code->CodeFunction("DequantData", quant_tensor_addr_str, de_quant_args_name, de_quant_nums, channels,
84                               de_quant_buffer_str_);
85 }
86 
DeQuantFunctionPerTensor(const Tensor * quant_tensor,const std::vector<DeQuantArg> & de_quant_args,const std::string & de_quant_arg_base_str,NNaclFp32Serializer * const de_quant_code)87 void Dequant::DeQuantFunctionPerTensor(const Tensor *quant_tensor, const std::vector<DeQuantArg> &de_quant_args,
88                                        const std::string &de_quant_arg_base_str,
89                                        NNaclFp32Serializer *const de_quant_code) {
90   size_t de_quant_nums = quant_tensor->ElementsNum();
91   auto de_quant_arg = de_quant_args.at(0);
92   std::string de_quant_arg_str = de_quant_arg_base_str + std::to_string(0);
93   de_quant_code->CodeStruct(de_quant_arg_str, de_quant_arg);
94   std::string de_quant_args_name = "de_quant_args";
95   *de_quant_code << "const DeQuantArg *" << de_quant_args_name << "[" << 1 << "] = {\n";
96   *de_quant_code << "&" << de_quant_arg_base_str << std::to_string(0);
97   *de_quant_code << "};\n";
98   std::string quant_tensor_addr_str = "(int8_t *)(" + quant_tensor_addr_ + ")";
99   de_quant_code->CodeFunction("DequantDataPerTensor", quant_tensor_addr_str, de_quant_args_name, de_quant_nums,
100                               de_quant_buffer_str_);
101 }
102 
// Generates a C code block that dequantizes `quant_tensor`'s int8 data (read
// from the expression `quant_tensor_addr`) into the float buffer previously
// registered via set_de_quant_buffer_str(). Returns an empty string when the
// tensor is null or no destination buffer has been set.
std::string Dequant::GetMicroDeQuantFunction(const Tensor *quant_tensor, const std::string &quant_tensor_addr) {
  std::string de_quant_block;
  if (quant_tensor == nullptr || de_quant_buffer_str_.empty()) {
    return de_quant_block;
  }
  quant_tensor_addr_ = quant_tensor_addr;
  size_t de_quant_nums = quant_tensor->ElementsNum();
  size_t quant_arg_dims = quant_tensor->quant_params().size();
  // The output buffer needs one float per element; record the running max.
  DequantRecordWorkspcae(static_cast<size_t>(de_quant_nums * sizeof(float)));
  NNaclFp32Serializer de_quant_code;
  de_quant_code << "{\n";
  size_t quant_tensor_dims = quant_tensor->shape().size();
  std::vector<DeQuantArg> de_quant_args;
  std::string de_quant_arg_base_str = "de_quant_arg_";
  // Mirror each quant param into a DeQuantArg consumed by the emitted code.
  for (size_t i = 0; i < quant_arg_dims; ++i) {
    auto curr_quant_param = quant_tensor->quant_params().at(i);
    DeQuantArg de_quant_arg = {
      .scale = static_cast<float>(curr_quant_param.scale),
      .zeroPoint = curr_quant_param.zeroPoint,
      .var_corr = curr_quant_param.var_corr,
      .mean_corr = curr_quant_param.mean_corr,
      // this clusters is meaningless which will be supported in future
      .clusters = {},
      .clusters_nums = static_cast<int>(curr_quant_param.clusters.size()),
      .bitNum = quant_tensor->quant_params().at(i).bitNum,
    };
    de_quant_args.emplace_back(de_quant_arg);
  }
  // Zero the destination before the emitted dequant call writes into it.
  de_quant_code.CodeFunction("memset", de_quant_buffer_str_, 0, de_quant_nums * sizeof(float));
  // Select the flavor: per-channel when a 3-D tensor has exactly one quant
  // param per batch; generic multi-arg otherwise; single-arg per-tensor last.
  if (quant_tensor_dims == kPerBatch && quant_arg_dims == static_cast<size_t>(quant_tensor->shape().at(0))) {
    DeQuantFunctionPerChannel(quant_tensor, de_quant_args, de_quant_arg_base_str, &de_quant_code);
  } else if (quant_arg_dims != kPerTensor) {
    DeQuantFunction(quant_tensor, de_quant_args, de_quant_arg_base_str, &de_quant_code);
  } else {
    DeQuantFunctionPerTensor(quant_tensor, de_quant_args, de_quant_arg_base_str, &de_quant_code);
  }
  de_quant_code << "}\n";
  de_quant_block = de_quant_code.str();
  return de_quant_block;
}
143 }  // namespace mindspore::lite::micro::nnacl
144