• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/quantizer/huffman_encode.h"
18 #include "src/weight_decoder.h"
19 #include "tools/converter/quantizer/quantize_util.h"
20 
21 namespace mindspore {
22 namespace lite {
DoHuffmanEncode(const tensor::TensorPtr & weight,const PrimitivePtr & primitive,void * quant_datas,const size_t & bit_num)23 STATUS HuffmanEncode::DoHuffmanEncode(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, void *quant_datas,
24                                       const size_t &bit_num) {
25   MS_ASSERT(weight != nullptr);
26   MS_ASSERT(primitive != nullptr);
27   if (quant_datas == nullptr) {
28     MS_LOG(ERROR) << "quant data is nullptr";
29     return RET_ERROR;
30   }
31   auto *raw_datas = static_cast<int8_t *>(quant_datas);
32   size_t elem_count = weight->DataSize();
33   int packed_size = elem_count * bit_num;
34 
35   HuffmanPriorityQueue pq;
36   auto status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq);
37   if (status != RET_OK) {
38     MS_LOG(ERROR) << "GetHuffmanPriorityQueue failed";
39     return status;
40   }
41   status = BuildHuffmanTree(&pq);
42   if (status != RET_OK) {
43     MS_LOG(ERROR) << "BuildHuffmanTree failed";
44     return status;
45   }
46   status = DoHuffmanCompress(raw_datas, elem_count);
47   if (status != RET_OK) {
48     MS_LOG(ERROR) << "DoHuffmanCompress failed";
49     return status;
50   }
51   int ch_size = huffman_encoded_str_.length();
52   if (ch_size < packed_size) {
53     if (ch_size != weight->data().nbytes()) {
54       MS_LOG(ERROR) << "Data size of weight is error.";
55       return RET_ERROR;
56     }
57     if (memcpy_s(weight->data_c(), weight->data().nbytes(), huffman_encoded_str_.c_str(), ch_size) != EOK) {
58       MS_LOG(ERROR) << "memcpy_s failed.";
59       return RET_MEMORY_FAILED;
60     }
61     auto quant_param_holder = quant::GetCNodeQuantHolder(primitive);
62     MS_ASSERT(quant_param_holder != nullptr);
63     quant_param_holder->set_enable_huffman_code(true);
64   }
65   huffman_encoded_str_.clear();
66   huffman_table_.clear();
67   return RET_SUCCESS;
68 }
69 
GetHuffmanPriorityQueue(const int8_t * data,const size_t data_size,HuffmanPriorityQueue * pq)70 STATUS HuffmanEncode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
71   MS_ASSERT(data != nullptr);
72   std::map<int8_t, size_t> freq_map;
73   for (size_t i = 0; i < data_size; i++) {
74     freq_map[data[i]]++;
75   }
76   for (auto &kv : freq_map) {
77     if (kv.second == 0) {
78       continue;
79     }
80     auto node = new (std::nothrow) HuffmanNode();
81     if (node == nullptr) {
82       MS_LOG(ERROR) << "new HuffmanNode failed.";
83       return RET_MEMORY_FAILED;
84     }
85     this->huffman_nodes_.push_back(node);
86     node->key = kv.first;
87     node->freq = kv.second;
88     node->code = "";
89     node->left = nullptr;
90     node->right = nullptr;
91     node->parent = nullptr;
92     pq->push(node);
93   }
94 
95   // insert pseudo-EOF
96   auto node = new (std::nothrow) HuffmanNode();
97   if (node == nullptr) {
98     MS_LOG(ERROR) << "new HuffmanNode failed.";
99     return RET_MEMORY_FAILED;
100   }
101   this->huffman_nodes_.push_back(node);
102   node->key = PSEUDO_EOF;
103   node->freq = 1;
104   node->code = "";
105   node->left = nullptr;
106   node->right = nullptr;
107   node->parent = nullptr;
108 
109   pq->push(node);
110   return RET_OK;
111 }
112 
GenerateHuffmanTable(const HuffmanNodePtr node,bool is_left_node)113 void HuffmanEncode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
114   MS_ASSERT(node != nullptr);
115   if (is_left_node) {
116     node->code = node->parent->code + "0";
117   } else {
118     node->code = node->parent->code + "1";
119   }
120 
121   if (node->left == nullptr && node->right == nullptr) {
122     huffman_table_[node->key] = node->code;
123   } else {
124     if (node->left != nullptr) {
125       GenerateHuffmanTable(node->left, true);
126     }
127     if (node->right != nullptr) {
128       GenerateHuffmanTable(node->right, false);
129     }
130   }
131 }
132 
BuildHuffmanTree(HuffmanPriorityQueue * pq)133 STATUS HuffmanEncode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
134   MS_ASSERT(pq != nullptr);
135   HuffmanNodePtr root = nullptr;
136   while (!pq->empty()) {
137     HuffmanNodePtr first = pq->top();
138     pq->pop();
139     if (pq->empty()) {
140       root = first;
141       break;
142     }
143     HuffmanNodePtr second = pq->top();
144     pq->pop();
145     auto new_node = new (std::nothrow) HuffmanNode();
146     if (new_node == nullptr) {
147       MS_LOG(ERROR) << "new HuffmanNode failed.";
148       return RET_MEMORY_FAILED;
149     }
150     this->huffman_nodes_.push_back(new_node);
151     new_node->freq = first->freq + second->freq;
152     new_node->left = first;
153     new_node->right = second;
154     first->parent = new_node;
155     second->parent = new_node;
156     pq->push(new_node);
157   }
158 
159   if (root == nullptr) {
160     MS_LOG(ERROR) << "huffman tree root node is nullptr.";
161     return RET_ERROR;
162   }
163 
164   if (root->left != nullptr) {
165     GenerateHuffmanTable(root->left, true);
166   }
167   if (root->right != nullptr) GenerateHuffmanTable(root->right, false);
168 
169   return RET_OK;
170 }
171 
DoHuffmanCompress(const int8_t * input_datas,const size_t data_size)172 STATUS HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
173   MS_ASSERT(input_datas != nullptr);
174   unsigned char out_c;
175   string code_str;
176   std::map<int, string>::iterator iter;
177   std::vector<std::string> encode_str = {"", "", ""};
178 
179   huffman_encoded_str_.clear();
180   for (iter = huffman_table_.begin(); iter != huffman_table_.end(); ++iter) {
181     encode_str[0] += std::to_string(iter->first) + " ";
182     encode_str[1] += iter->second + " ";
183   }
184 
185   for (size_t i = 0; i < data_size; i++) {
186     auto raw_num = input_datas[i];
187     iter = huffman_table_.find(raw_num);
188     if (iter != huffman_table_.end()) {
189       code_str += iter->second;
190     } else {
191       MS_LOG(ERROR) << "Can't find the huffman code " << raw_num;
192       return RET_ERROR;
193     }
194   }
195   iter = huffman_table_.find(PSEUDO_EOF);
196   if (iter != huffman_table_.end()) {
197     code_str += iter->second;
198   } else {
199     MS_LOG(ERROR) << "Can't find the huffman code pseudo-EOF";
200     return RET_ERROR;
201   }
202   out_c = 0;
203   for (size_t i = 0; i < code_str.length(); i++) {
204     auto tmp_c = code_str[i] == '0' ? 0 : 1;
205     out_c += tmp_c << ((quant::kMaxBit - 1) - (i % quant::kMaxBit));
206     if ((i + 1) % quant::kMaxBit == 0 || i == code_str.length() - 1) {
207       encode_str[2] += out_c;
208       out_c = 0;
209     }
210   }
211   huffman_encoded_str_ = encode_str[0] + "#" + encode_str[1] + "#" + encode_str[2];
212   return RET_OK;
213 }
214 
~HuffmanEncode()215 HuffmanEncode::~HuffmanEncode() {
216   for (auto &node : this->huffman_nodes_) {
217     delete node;
218   }
219   this->huffman_nodes_.clear();
220 }
221 }  // namespace lite
222 }  // namespace mindspore
223