1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/quantizer/huffman_encode.h"
18 #include "src/weight_decoder.h"
19 #include "tools/converter/quantizer/quantize_util.h"
20
21 namespace mindspore {
22 namespace lite {
DoHuffmanEncode(const tensor::TensorPtr & weight,const PrimitivePtr & primitive,void * quant_datas,const size_t & bit_num)23 STATUS HuffmanEncode::DoHuffmanEncode(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, void *quant_datas,
24 const size_t &bit_num) {
25 MS_ASSERT(weight != nullptr);
26 MS_ASSERT(primitive != nullptr);
27 if (quant_datas == nullptr) {
28 MS_LOG(ERROR) << "quant data is nullptr";
29 return RET_ERROR;
30 }
31 auto *raw_datas = static_cast<int8_t *>(quant_datas);
32 size_t elem_count = weight->DataSize();
33 int packed_size = elem_count * bit_num;
34
35 HuffmanPriorityQueue pq;
36 auto status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq);
37 if (status != RET_OK) {
38 MS_LOG(ERROR) << "GetHuffmanPriorityQueue failed";
39 return status;
40 }
41 status = BuildHuffmanTree(&pq);
42 if (status != RET_OK) {
43 MS_LOG(ERROR) << "BuildHuffmanTree failed";
44 return status;
45 }
46 status = DoHuffmanCompress(raw_datas, elem_count);
47 if (status != RET_OK) {
48 MS_LOG(ERROR) << "DoHuffmanCompress failed";
49 return status;
50 }
51 int ch_size = huffman_encoded_str_.length();
52 if (ch_size < packed_size) {
53 if (ch_size != weight->data().nbytes()) {
54 MS_LOG(ERROR) << "Data size of weight is error.";
55 return RET_ERROR;
56 }
57 if (memcpy_s(weight->data_c(), weight->data().nbytes(), huffman_encoded_str_.c_str(), ch_size) != EOK) {
58 MS_LOG(ERROR) << "memcpy_s failed.";
59 return RET_MEMORY_FAILED;
60 }
61 auto quant_param_holder = quant::GetCNodeQuantHolder(primitive);
62 MS_ASSERT(quant_param_holder != nullptr);
63 quant_param_holder->set_enable_huffman_code(true);
64 }
65 huffman_encoded_str_.clear();
66 huffman_table_.clear();
67 return RET_SUCCESS;
68 }
69
GetHuffmanPriorityQueue(const int8_t * data,const size_t data_size,HuffmanPriorityQueue * pq)70 STATUS HuffmanEncode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
71 MS_ASSERT(data != nullptr);
72 std::map<int8_t, size_t> freq_map;
73 for (size_t i = 0; i < data_size; i++) {
74 freq_map[data[i]]++;
75 }
76 for (auto &kv : freq_map) {
77 if (kv.second == 0) {
78 continue;
79 }
80 auto node = new (std::nothrow) HuffmanNode();
81 if (node == nullptr) {
82 MS_LOG(ERROR) << "new HuffmanNode failed.";
83 return RET_MEMORY_FAILED;
84 }
85 this->huffman_nodes_.push_back(node);
86 node->key = kv.first;
87 node->freq = kv.second;
88 node->code = "";
89 node->left = nullptr;
90 node->right = nullptr;
91 node->parent = nullptr;
92 pq->push(node);
93 }
94
95 // insert pseudo-EOF
96 auto node = new (std::nothrow) HuffmanNode();
97 if (node == nullptr) {
98 MS_LOG(ERROR) << "new HuffmanNode failed.";
99 return RET_MEMORY_FAILED;
100 }
101 this->huffman_nodes_.push_back(node);
102 node->key = PSEUDO_EOF;
103 node->freq = 1;
104 node->code = "";
105 node->left = nullptr;
106 node->right = nullptr;
107 node->parent = nullptr;
108
109 pq->push(node);
110 return RET_OK;
111 }
112
GenerateHuffmanTable(const HuffmanNodePtr node,bool is_left_node)113 void HuffmanEncode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
114 MS_ASSERT(node != nullptr);
115 if (is_left_node) {
116 node->code = node->parent->code + "0";
117 } else {
118 node->code = node->parent->code + "1";
119 }
120
121 if (node->left == nullptr && node->right == nullptr) {
122 huffman_table_[node->key] = node->code;
123 } else {
124 if (node->left != nullptr) {
125 GenerateHuffmanTable(node->left, true);
126 }
127 if (node->right != nullptr) {
128 GenerateHuffmanTable(node->right, false);
129 }
130 }
131 }
132
BuildHuffmanTree(HuffmanPriorityQueue * pq)133 STATUS HuffmanEncode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
134 MS_ASSERT(pq != nullptr);
135 HuffmanNodePtr root = nullptr;
136 while (!pq->empty()) {
137 HuffmanNodePtr first = pq->top();
138 pq->pop();
139 if (pq->empty()) {
140 root = first;
141 break;
142 }
143 HuffmanNodePtr second = pq->top();
144 pq->pop();
145 auto new_node = new (std::nothrow) HuffmanNode();
146 if (new_node == nullptr) {
147 MS_LOG(ERROR) << "new HuffmanNode failed.";
148 return RET_MEMORY_FAILED;
149 }
150 this->huffman_nodes_.push_back(new_node);
151 new_node->freq = first->freq + second->freq;
152 new_node->left = first;
153 new_node->right = second;
154 first->parent = new_node;
155 second->parent = new_node;
156 pq->push(new_node);
157 }
158
159 if (root == nullptr) {
160 MS_LOG(ERROR) << "huffman tree root node is nullptr.";
161 return RET_ERROR;
162 }
163
164 if (root->left != nullptr) {
165 GenerateHuffmanTable(root->left, true);
166 }
167 if (root->right != nullptr) GenerateHuffmanTable(root->right, false);
168
169 return RET_OK;
170 }
171
DoHuffmanCompress(const int8_t * input_datas,const size_t data_size)172 STATUS HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
173 MS_ASSERT(input_datas != nullptr);
174 unsigned char out_c;
175 string code_str;
176 std::map<int, string>::iterator iter;
177 std::vector<std::string> encode_str = {"", "", ""};
178
179 huffman_encoded_str_.clear();
180 for (iter = huffman_table_.begin(); iter != huffman_table_.end(); ++iter) {
181 encode_str[0] += std::to_string(iter->first) + " ";
182 encode_str[1] += iter->second + " ";
183 }
184
185 for (size_t i = 0; i < data_size; i++) {
186 auto raw_num = input_datas[i];
187 iter = huffman_table_.find(raw_num);
188 if (iter != huffman_table_.end()) {
189 code_str += iter->second;
190 } else {
191 MS_LOG(ERROR) << "Can't find the huffman code " << raw_num;
192 return RET_ERROR;
193 }
194 }
195 iter = huffman_table_.find(PSEUDO_EOF);
196 if (iter != huffman_table_.end()) {
197 code_str += iter->second;
198 } else {
199 MS_LOG(ERROR) << "Can't find the huffman code pseudo-EOF";
200 return RET_ERROR;
201 }
202 out_c = 0;
203 for (size_t i = 0; i < code_str.length(); i++) {
204 auto tmp_c = code_str[i] == '0' ? 0 : 1;
205 out_c += tmp_c << ((quant::kMaxBit - 1) - (i % quant::kMaxBit));
206 if ((i + 1) % quant::kMaxBit == 0 || i == code_str.length() - 1) {
207 encode_str[2] += out_c;
208 out_c = 0;
209 }
210 }
211 huffman_encoded_str_ = encode_str[0] + "#" + encode_str[1] + "#" + encode_str[2];
212 return RET_OK;
213 }
214
~HuffmanEncode()215 HuffmanEncode::~HuffmanEncode() {
216 for (auto &node : this->huffman_nodes_) {
217 delete node;
218 }
219 this->huffman_nodes_.clear();
220 }
221 } // namespace lite
222 } // namespace mindspore
223