// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Contains utils for generating and applying a Huffman coding scheme.

#ifndef LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
#define LIBSPIRV_UTIL_HUFFMAN_CODEC_H_

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iomanip>
#include <map>
#include <memory>
#include <ostream>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

namespace spvutils {

// Used to generate and apply a Huffman coding scheme.
// |Val| is the type of variable being encoded (for example a string or a
// literal). |Val| must be default-constructible, copy-assignable, hashable
// with std::hash, and streamable to std::ostream (the latter only if the
// Print* functions are used).
template <class Val>
class HuffmanCodec {
  struct Node;

 public:
  // Creates Huffman codec from a histogram mapping each value to its count.
  // Histogram counts must not be zero (checked by assert).
  // Weights of merged subtrees are summed in uint32_t, so the total of all
  // counts is expected to fit in 32 bits; a larger total wraps around and
  // the resulting tree is no longer built from the true weights.
  // An empty histogram produces an empty codec: the encoding table is empty,
  // Encode() and DecodeFromStream() return false, and the Print* functions
  // print nothing.
  explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
    if (hist.empty()) return;

    // Every merge below creates one parent for two existing nodes, so a tree
    // with N leaves ends up with exactly 2N - 1 nodes.
    all_nodes_.reserve(2 * hist.size());

    // The queue is sorted in ascending order by weight (or by node id if
    // weights are equal), i.e. top() is always the lightest node.
    using NodeCompare = bool (*)(const Node*, const Node*);
    std::vector<Node*> queue_vector;
    queue_vector.reserve(hist.size());
    std::priority_queue<Node*, std::vector<Node*>, NodeCompare> queue(
        &HuffmanCodec::LeftIsBigger, std::move(queue_vector));

    // Put all leaves in the queue. std::map iterates in key order, so leaf
    // ids are assigned deterministically.
    for (const auto& pair : hist) {
      Node* node = CreateNode();
      node->val = pair.first;
      node->weight = pair.second;
      assert(node->weight);
      queue.push(node);
    }

    // Form the tree by combining the two subtrees with the least weight,
    // and pushing the root of the new tree in the queue. The lightest node
    // becomes the right child, the second lightest the left child.
    while (queue.size() > 1) {
      Node* right = queue.top();
      queue.pop();
      Node* left = queue.top();
      queue.pop();

      Node* parent = CreateNode();
      parent->weight = right->weight + left->weight;
      parent->left = left;
      parent->right = right;
      queue.push(parent);
    }

    // The single remaining node is the root of the complete Huffman tree.
    root_ = queue.top();

    // Traverse the tree and form encoding table.
    CreateEncodingTable();
  }

  // Prints the Huffman tree in the following format:
  // w------w------'x'
  //        w------'y'
  // Where w stands for the weight of the node.
  // Right tree branches appear above left branches. Taking the right path
  // adds 1 to the code, taking the left adds 0.
  // Prints nothing if the codec is empty.
  void PrintTree(std::ostream& out) const {
    PrintTreeInternal(out, root_, 0);
  }

  // Traverses the tree breadth-first and prints the Huffman table: value,
  // optionally node weight, and code for every leaf, one leaf per line.
  // The code is printed as a string of '0'/'1' characters in the order the
  // branches are taken from the root.
  // Prints nothing if the codec is empty.
  void PrintTable(std::ostream& out, bool print_weights = true) const {
    if (!root_) return;

    std::queue<std::pair<const Node*, std::string>> queue;
    queue.emplace(root_, "");

    while (!queue.empty()) {
      const Node* node = queue.front().first;
      const std::string code = std::move(queue.front().second);
      queue.pop();

      if (!node->right && !node->left) {
        out << node->val;
        if (print_weights) out << " " << node->weight;
        out << " " << code << std::endl;
        continue;
      }

      if (node->left) queue.emplace(node->left, code + "0");
      if (node->right) queue.emplace(node->right, code + "1");
    }
  }

  // Returns the Huffman table (value -> {bits, num_bits}). The table was
  // built at construction time, this function just returns a const reference.
  const std::unordered_map<Val, std::pair<uint64_t, size_t>>&
  GetEncodingTable() const {
    return encoding_table_;
  }

  // Encodes |val| and stores its Huffman code in the lower |num_bits| of
  // |bits|. The branch taken at the root is stored in the least significant
  // bit, so the code must be emitted starting from bit 0 to be readable by
  // DecodeFromStream(). Returns false if |val| is not in the Huffman table,
  // in which case |bits| and |num_bits| are left untouched.
  bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) const {
    const auto it = encoding_table_.find(val);
    if (it == encoding_table_.end()) return false;
    *bits = it->second.first;
    *num_bits = it->second.second;
    return true;
  }

  // Reads bits one-by-one using callback |read_bit| until a match is found.
  // Matching value is stored in |val|. Returns false if |read_bit| terminates
  // before a code was matched, or if the codec is empty; |val| is left
  // untouched in both cases.
  // |read_bit| has type bool func(bool* bit). When called, the next bit is
  // stored in |bit|. |read_bit| returns false if the stream terminates
  // prematurely.
  // If the codec holds a single value, its code has zero length and the
  // value is returned without calling |read_bit|.
  bool DecodeFromStream(const std::function<bool(bool*)>& read_bit,
                        Val* val) const {
    const Node* node = root_;
    if (!node) return false;

    // Descend until a leaf is reached.
    while (node->left || node->right) {
      bool go_right = false;
      if (!read_bit(&go_right)) return false;

      node = go_right ? node->right : node->left;
      // Internal nodes are always created with both children.
      assert(node);
    }

    *val = node->val;
    return true;
  }

 private:
  // Huffman tree node.
  struct Node {
    Val val = Val();
    uint32_t weight = 0;
    // Ids are issued sequentially starting from 1. Ids are used as an ordering
    // tie-breaker, to make sure that the ordering (and resulting coding scheme)
    // are consistent across multiple platforms.
    uint32_t id = 0;
    Node* left = nullptr;
    Node* right = nullptr;
  };

  // Returns true if |left| has bigger weight than |right|. Node ids are
  // used as tie-breaker.
  static bool LeftIsBigger(const Node* left, const Node* right) {
    if (left->weight == right->weight) {
      assert(left->id != right->id);
      return left->id > right->id;
    }
    return left->weight > right->weight;
  }

  // Prints subtree (helper function used by PrintTree).
  // |depth| is the depth of |node|; it determines the indentation of the
  // left branch, which starts on a new output line.
  static void PrintTreeInternal(std::ostream& out, const Node* node,
                                size_t depth) {
    if (!node) return;

    static constexpr size_t kTextFieldWidth = 7;

    if (!node->right && !node->left) {
      out << node->val << std::endl;
      return;
    }

    // Formats the weight of |child| left-aligned in a '-'-padded field.
    const auto weight_label = [](const Node* child) {
      std::stringstream label;
      label << std::setfill('-') << std::left
            << std::setw(static_cast<int>(kTextFieldWidth)) << child->weight;
      return label.str();
    };

    if (node->right) {
      // The right branch continues the current output line, no indent needed.
      out << weight_label(node->right);
      PrintTreeInternal(out, node->right, depth + 1);
    }

    if (node->left) {
      out << std::string(depth * kTextFieldWidth, ' ');
      out << weight_label(node->left);
      PrintTreeInternal(out, node->left, depth + 1);
    }
  }

  // Traverses the Huffman tree and saves paths to the leaves as bit
  // sequences to encoding_table_. The branch taken at depth d is stored in
  // bit d of the code (0 for left, 1 for right).
  void CreateEncodingTable() {
    if (!root_) return;

    struct Context {
      Context(const Node* in_node, uint64_t in_bits, size_t in_depth)
          : node(in_node), bits(in_bits), depth(in_depth) {}
      const Node* node;
      // Codes are stored in 64 bits. Tree depth is expected to stay below 64
      // as long as histogram counts are positive and their total fits in
      // uint32_t. For practical applications tree depth would be much
      // smaller than 64.
      uint64_t bits;
      size_t depth;
    };

    std::queue<Context> queue;
    queue.emplace(root_, 0, 0);

    while (!queue.empty()) {
      // Copy the entry before popping; a reference to front() would dangle.
      const Context context = queue.front();
      queue.pop();

      const Node* node = context.node;

      if (!node->right && !node->left) {
        auto insertion_result = encoding_table_.emplace(
            node->val,
            std::pair<uint64_t, size_t>(context.bits, context.depth));
        assert(insertion_result.second);
        (void)insertion_result;
        continue;
      }

      // Shifting a 64-bit value by 64 or more is undefined behaviour.
      assert(context.depth < 64);

      if (node->left)
        queue.emplace(node->left, context.bits, context.depth + 1);

      if (node->right)
        queue.emplace(node->right,
                      context.bits | (uint64_t{1} << context.depth),
                      context.depth + 1);
    }
  }

  // Creates new Huffman tree node, assigns it the next sequential id and
  // stores it in the deleter array. The returned pointer stays valid for the
  // lifetime of all_nodes_.
  Node* CreateNode() {
    // Wrap in unique_ptr first so the node is not leaked if growing the
    // vector throws.
    std::unique_ptr<Node> node(new Node());
    node->id = next_node_id_++;
    all_nodes_.push_back(std::move(node));
    return all_nodes_.back().get();
  }

  // Huffman tree root. Null if the codec was built from an empty histogram.
  Node* root_ = nullptr;

  // Huffman tree deleter.
  std::vector<std::unique_ptr<Node>> all_nodes_;

  // Encoding table value -> {bits, num_bits}.
  // Huffman codes are expected to never exceed 64 bit length (see the note
  // in CreateEncodingTable()).
  std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;

  // Next node id issued by CreateNode();
  uint32_t next_node_id_ = 1;
};

}  // namespace spvutils

#endif  // LIBSPIRV_UTIL_HUFFMAN_CODEC_H_