1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/litert/huffman_decode.h"
18 #include <queue>
19
20 namespace mindspore {
21 namespace lite {
DoHuffmanDecode(const std::string & input_str,void * decoded_data,size_t data_len)22 STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decoded_data, size_t data_len) {
23 if (decoded_data == nullptr) {
24 MS_LOG(ERROR) << "decoded_data is nullptr.";
25 return RET_ERROR;
26 }
27
28 int status;
29 std::string huffman_decoded_str;
30 auto key_pos = input_str.find_first_of('#');
31 auto code_pos = input_str.find_first_of('#', key_pos + 1);
32 if (key_pos == std::string::npos || code_pos == std::string::npos) {
33 MS_LOG(ERROR) << "not found '#' in input_str";
34 return RET_ERROR;
35 }
36 if (key_pos + 1 > input_str.size() || code_pos + 1 > input_str.size()) {
37 MS_LOG(ERROR) << "pos extend input_str size.";
38 return RET_ERROR;
39 }
40 auto key = input_str.substr(0, key_pos);
41 auto code = input_str.substr(key_pos + 1, code_pos - key_pos - 1);
42 auto encoded_data = input_str.substr(code_pos + 1);
43
44 auto root = new (std::nothrow) HuffmanNode();
45 if (root == nullptr) {
46 MS_LOG(ERROR) << "new HuffmanNode failed.";
47 return RET_MEMORY_FAILED;
48 }
49 root->left = nullptr;
50 root->right = nullptr;
51 root->parent = nullptr;
52
53 status = RebuildHuffmanTree(key, code, root);
54 if (status != RET_OK) {
55 MS_LOG(ERROR) << "Rebuild huffman tree failed.";
56 delete root;
57 return status;
58 }
59
60 status = DoHuffmanDecompress(root, encoded_data, &huffman_decoded_str);
61 if (status != RET_OK) {
62 MS_LOG(ERROR) << "DoHuffmanDecompress failed.";
63 delete root;
64 return status;
65 }
66
67 size_t len = huffman_decoded_str.length();
68 if (data_len >= len) {
69 memcpy(decoded_data, huffman_decoded_str.c_str(), len);
70 } else {
71 FreeHuffmanNodeTree(root);
72 return RET_ERROR;
73 }
74 FreeHuffmanNodeTree(root);
75 return RET_OK;
76 }
77
RebuildHuffmanTree(std::string keys,std::string codes,const HuffmanNodePtr & root)78 STATUS HuffmanDecode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
79 CHECK_NULL_RETURN(root);
80 HuffmanNodePtr cur_node;
81 HuffmanNodePtr tmp_node;
82 HuffmanNodePtr new_node;
83
84 auto huffman_keys = Str2Vec(std::move(keys));
85 auto huffman_codes = Str2Vec(std::move(codes));
86
87 for (size_t i = 0; i < huffman_codes.size(); ++i) {
88 auto key = stoi(huffman_keys[i]);
89 auto code = huffman_codes[i];
90 auto code_len = code.length();
91 cur_node = root;
92 for (size_t j = 0; j < code_len; ++j) {
93 if (code[j] == '0') {
94 tmp_node = cur_node->left;
95 } else if (code[j] == '1') {
96 tmp_node = cur_node->right;
97 } else {
98 MS_LOG(ERROR) << "find huffman code is not 0 or 1";
99 return RET_ERROR;
100 }
101
102 if (tmp_node == nullptr) {
103 new_node = new (std::nothrow) HuffmanNode();
104 if (new_node == nullptr) {
105 MS_LOG(ERROR) << "new HuffmanNode failed.";
106 return RET_MEMORY_FAILED;
107 }
108 new_node->left = nullptr;
109 new_node->right = nullptr;
110 new_node->parent = cur_node;
111
112 if (j == code_len - 1) {
113 new_node->key = key;
114 new_node->code = code;
115 }
116
117 if (code[j] == '0') {
118 cur_node->left = new_node;
119 } else {
120 cur_node->right = new_node;
121 }
122
123 tmp_node = new_node;
124 } else if (j == code_len - 1) {
125 MS_LOG(ERROR) << "the huffman code is incomplete.";
126 return RET_ERROR;
127 } else if (tmp_node->left == nullptr && tmp_node->right == nullptr) {
128 MS_LOG(ERROR) << "the huffman code is incomplete";
129 return RET_ERROR;
130 }
131 cur_node = tmp_node;
132 }
133 }
134 return RET_OK;
135 }
136
DoHuffmanDecompress(HuffmanNodePtr root,std::string encoded_data,std::string * decoded_str)137 STATUS HuffmanDecode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
138 CHECK_NULL_RETURN(decoded_str);
139 CHECK_NULL_RETURN(root);
140 HuffmanNodePtr cur_node = root;
141 bool pseudo_eof = false;
142 size_t pos = 0;
143 unsigned char flag;
144
145 decoded_str->clear();
146 while (pos < encoded_data.length()) {
147 auto u_char = static_cast<unsigned char>(encoded_data[pos]);
148 flag = 0x80;
149 for (size_t i = 0; i < 8; ++i) { // traverse the 8 bit num, to find the leaf node
150 if (u_char & flag) {
151 cur_node = cur_node->right;
152 } else {
153 cur_node = cur_node->left;
154 }
155 CHECK_NULL_RETURN(cur_node);
156 if (cur_node->left == nullptr && cur_node->right == nullptr) {
157 auto key = cur_node->key;
158 if (key == PSEUDO_EOF) {
159 pseudo_eof = true;
160 break;
161 } else {
162 *decoded_str += static_cast<char>(cur_node->key);
163 cur_node = root;
164 }
165 }
166 flag = flag >> 1;
167 }
168 pos++;
169 if (pseudo_eof) {
170 break;
171 }
172 }
173 return RET_OK;
174 }
175
FreeHuffmanNodeTree(HuffmanNodePtr root)176 void HuffmanDecode::FreeHuffmanNodeTree(HuffmanNodePtr root) {
177 if (root == nullptr) {
178 return;
179 }
180 std::queue<HuffmanNodePtr> node_queue;
181 node_queue.push(root);
182 while (!node_queue.empty()) {
183 auto cur_node = node_queue.front();
184 node_queue.pop();
185 if (cur_node->left != nullptr) {
186 node_queue.push(cur_node->left);
187 }
188 if (cur_node->right != nullptr) {
189 node_queue.push(cur_node->right);
190 }
191 delete (cur_node);
192 }
193 }
194 } // namespace lite
195 } // namespace mindspore
196