• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
// Contains utils for building a Huffman coding scheme from a histogram and
// for encoding, decoding and debug printing values with it.
16 
17 #ifndef LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
18 #define LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
19 
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iomanip>
#include <map>
#include <memory>
#include <ostream>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
33 
34 namespace spvutils {
35 
// Used to generate and apply a Huffman coding scheme.
// |Val| is the type of variable being encoded (for example a string or a
// literal). |Val| must be default-constructible, copyable, usable as a key of
// std::map and std::unordered_map, and streamable to std::ostream if the
// Print* methods are used.
template <class Val>
class HuffmanCodec {
  struct Node;

 public:
  // Creates Huffman codec from a histogram (value -> number of occurrences).
  // Histogram counts must not be zero.
  // An empty histogram produces an empty codec: Encode() and
  // DecodeFromStream() always return false and the Print* methods print
  // nothing.
  explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
    if (hist.empty()) return;

    // A binary tree with N leaves where every internal node has two children
    // has exactly 2N - 1 nodes.
    all_nodes_.reserve(2 * hist.size() - 1);

    // The queue is sorted in ascending order by weight (or by node id if
    // weights are equal), i.e. top() is the lightest node.
    using NodeOrder = bool (*)(const Node*, const Node*);
    std::vector<Node*> queue_storage;
    queue_storage.reserve(hist.size());
    std::priority_queue<Node*, std::vector<Node*>, NodeOrder> queue(
        &LeftIsBigger, std::move(queue_storage));

    // Put all leaves in the queue. std::map iterates in key order, so leaf ids
    // (and therefore tie-breaking) are deterministic.
    for (const auto& pair : hist) {
      Node* leaf = CreateNode();
      leaf->val = pair.first;
      leaf->weight = pair.second;
      assert(leaf->weight);
      queue.push(leaf);
    }

    // Form the tree by repeatedly combining the two subtrees with the least
    // weight and pushing the root of the combined tree back in the queue.
    while (queue.size() > 1) {
      Node* right = queue.top();
      queue.pop();
      Node* left = queue.top();
      queue.pop();

      Node* parent = CreateNode();
      // Weights are accumulated in 64 bits, so the sum of 32-bit histogram
      // counts cannot wrap around and corrupt the ordering.
      parent->weight = right->weight + left->weight;
      parent->left = left;
      parent->right = right;
      queue.push(parent);
    }

    // The only node left is the root of the complete Huffman tree.
    assert(!queue.empty());
    root_ = queue.top();

    // Traverse the tree and form encoding table.
    CreateEncodingTable();
  }

  // Prints the Huffman tree in the following format:
  // w------w------'x'
  //        w------'y'
  // Where w stands for the weight of the node.
  // Right tree branches appear above left branches. Taking the right path
  // adds 1 to the code, taking the left adds 0.
  // Prints nothing if the codec is empty.
  void PrintTree(std::ostream& out) const {
    PrintTreeInternal(out, root_, 0);
  }

  // Traverses the tree breadth-first and prints one line per leaf:
  // value, node weight (only if |print_weights| is true) and code, separated
  // by spaces. The code is printed as a string of '0' (left) and '1' (right)
  // characters in root-to-leaf order. Prints nothing if the codec is empty.
  void PrintTable(std::ostream& out, bool print_weights = true) const {
    if (!root_) return;

    std::queue<std::pair<const Node*, std::string>> pending;
    pending.emplace(root_, "");

    while (!pending.empty()) {
      const Node* node = pending.front().first;
      const std::string code = pending.front().second;
      pending.pop();

      const bool is_leaf = !node->left && !node->right;
      if (is_leaf) {
        out << node->val;
        if (print_weights) out << " " << node->weight;
        out << " " << code << std::endl;
        continue;
      }

      if (node->left) pending.emplace(node->left, code + "0");
      if (node->right) pending.emplace(node->right, code + "1");
    }
  }

  // Returns the Huffman table. The table was built at construction time,
  // this function just returns a const reference.
  const std::unordered_map<Val, std::pair<uint64_t, size_t>>&
      GetEncodingTable() const {
    return encoding_table_;
  }

  // Encodes |val| and stores its Huffman code in the lower |num_bits| of
  // |bits|. The first branch taken from the root is stored in the least
  // significant bit. Returns false if |val| is not in the Huffman table, in
  // which case |bits| and |num_bits| are left untouched.
  bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) const {
    const auto it = encoding_table_.find(val);
    if (it == encoding_table_.end()) return false;
    *bits = it->second.first;
    *num_bits = it->second.second;
    return true;
  }

  // Reads bits one-by-one using callback |read_bit| until a match is found.
  // Matching value is stored in |val|. Returns false if |read_bit| terminates
  // before a code was matched, or if the codec is empty; |val| is left
  // untouched in that case.
  // |read_bit| has type bool func(bool* bit). When called, the next bit is
  // stored in |bit|. |read_bit| returns false if the stream terminates
  // prematurely.
  // If the codec holds a single value, its code is empty and |read_bit| is
  // never called.
  bool DecodeFromStream(const std::function<bool(bool*)>& read_bit,
                        Val* val) const {
    const Node* node = root_;
    if (!node) return false;  // Empty codec: nothing can be decoded.

    // Walk down from the root until a leaf is reached.
    while (node->left != nullptr || node->right != nullptr) {
      bool go_right = false;
      if (!read_bit(&go_right)) return false;

      node = go_right ? node->right : node->left;
      // Internal nodes always have both children by construction.
      assert(node);
    }

    *val = node->val;
    return true;
  }

 private:
  // Huffman tree node.
  struct Node {
    Val val = Val();
    // Leaf weights come from 32-bit histogram counts; internal nodes hold the
    // sum of their subtrees, which needs more than 32 bits in general.
    uint64_t weight = 0;
    // Ids are issued sequentially starting from 1. Ids are used as an ordering
    // tie-breaker, to make sure that the ordering (and resulting coding scheme)
    // are consistent across multiple platforms.
    uint32_t id = 0;
    Node* left = nullptr;
    Node* right = nullptr;
  };

  // Returns true if |left| has bigger weight than |right|. Node ids are
  // used as tie-breaker.
  static bool LeftIsBigger(const Node* left, const Node* right) {
    if (left->weight != right->weight) return left->weight > right->weight;
    assert(left->id != right->id);
    return left->id > right->id;
  }

  // Prints subtree rooted at |node| (helper function used by PrintTree).
  // |depth| is the distance of |node| from the root and controls how far the
  // left branch is indented.
  static void PrintTreeInternal(std::ostream& out, const Node* node,
                                size_t depth) {
    if (!node) return;

    const size_t kTextFieldWidth = 7;

    if (!node->right && !node->left) {
      out << node->val << std::endl;
      return;
    }

    // Labels are formatted in a temporary stream so that the fill character
    // and alignment of |out| are not modified.
    const auto weight_label = [kTextFieldWidth](const Node* child) {
      std::stringstream label;
      label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
            << child->weight;
      return label.str();
    };

    if (node->right) {
      // The right branch continues the current output line.
      out << weight_label(node->right);
      PrintTreeInternal(out, node->right, depth + 1);
    }

    if (node->left) {
      // The left branch starts a new line and is indented to its depth.
      out << std::string(depth * kTextFieldWidth, ' ');
      out << weight_label(node->left);
      PrintTreeInternal(out, node->left, depth + 1);
    }
  }

  // Traverses the Huffman tree and saves paths to the leaves as bit
  // sequences to encoding_table_. Must only be called with non-null root_.
  void CreateEncodingTable() {
    struct Context {
      Context(const Node* in_node, uint64_t in_bits, size_t in_depth)
          : node(in_node), bits(in_bits), depth(in_depth) {}
      const Node* node;
      // Path from the root to |node|: bit i is set if the right branch was
      // taken at depth i. Codes are stored in 64 bits, so the tree depth must
      // not exceed 64. For practical applications tree depth is much smaller.
      uint64_t bits;
      size_t depth;
    };

    assert(root_);
    std::queue<Context> pending;
    pending.emplace(root_, 0, 0);

    while (!pending.empty()) {
      const Context context = pending.front();
      pending.pop();
      const Node* node = context.node;

      if (!node->right && !node->left) {
        auto insertion_result = encoding_table_.emplace(
            node->val,
            std::pair<uint64_t, size_t>(context.bits, context.depth));
        assert(insertion_result.second);
        (void)insertion_result;
        continue;
      }

      // Shifting a 64-bit value by 64 or more is undefined behavior.
      assert(context.depth < 64);

      if (node->left)
        pending.emplace(node->left, context.bits, context.depth + 1);

      if (node->right)
        pending.emplace(node->right, context.bits | (1ULL << context.depth),
                        context.depth + 1);
    }
  }

  // Creates new Huffman tree node, assigns it the next sequential id and
  // stores it in the deleter array. The returned pointer stays valid for the
  // lifetime of the codec.
  Node* CreateNode() {
    all_nodes_.push_back(std::unique_ptr<Node>(new Node()));
    Node* node = all_nodes_.back().get();
    node->id = next_node_id_++;
    return node;
  }

  // Huffman tree root. Null if the codec was built from an empty histogram.
  Node* root_ = nullptr;

  // Huffman tree deleter: owns every node of the tree.
  std::vector<std::unique_ptr<Node>> all_nodes_;

  // Encoding table value -> {bits, num_bits}.
  // Huffman codes are expected to never exceed 64 bit length; this holds for
  // any realistic histogram and is asserted when the table is built.
  std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;

  // Next node id issued by CreateNode();
  uint32_t next_node_id_ = 1;
};
296 
297 }  // namespace spvutils
298 
299 #endif  // LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
300