• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <algorithm>
#include <atomic>
#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <numeric>
#include <queue>
#include <set>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>

#include <openssl/sha.h>

#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"

#include "src/core/ext/transport/chttp2/transport/huffsyms.h"
#include "src/core/util/env.h"
#include "src/core/util/match.h"
43 
44 ///////////////////////////////////////////////////////////////////////////////
45 // SHA256 hash handling
46 // We need strong uniqueness checks of some very long strings - so we hash
47 // them with SHA256 and compare.
48 struct Hash {
49   uint8_t bytes[SHA256_DIGEST_LENGTH];
operator ==Hash50   bool operator==(const Hash& other) const {
51     return memcmp(bytes, other.bytes, SHA256_DIGEST_LENGTH) == 0;
52   }
operator <Hash53   bool operator<(const Hash& other) const {
54     return memcmp(bytes, other.bytes, SHA256_DIGEST_LENGTH) < 0;
55   }
ToStringHash56   std::string ToString() const {
57     std::string result;
58     for (int i = 0; i < SHA256_DIGEST_LENGTH; i++) {
59       absl::StrAppend(&result, absl::Hex(bytes[i], absl::kZeroPad2));
60     }
61     return result;
62   }
63 };
64 
65 // Given a vector of ints (T), return a Hash object with the sha256
66 template <typename T>
HashVec(absl::string_view type,const std::vector<T> & v)67 Hash HashVec(absl::string_view type, const std::vector<T>& v) {
68   Hash h;
69   std::string text = absl::StrCat(type, ":", absl::StrJoin(v, ","));
70   SHA256(reinterpret_cast<const uint8_t*>(text.data()), text.size(), h.bytes);
71   return h;
72 }
73 
74 ///////////////////////////////////////////////////////////////////////////////
75 // BitQueue
76 // A utility that treats a sequence of bits like a queue
77 class BitQueue {
78  public:
BitQueue(unsigned mask,int len)79   BitQueue(unsigned mask, int len) : mask_(mask), len_(len) {}
BitQueue()80   BitQueue() : BitQueue(0, 0) {}
81 
82   // Return the most significant bit (the front of the queue)
Front() const83   int Front() const { return (mask_ >> (len_ - 1)) & 1; }
84   // Pop one bit off the queue
Pop()85   void Pop() {
86     mask_ &= ~(1 << (len_ - 1));
87     len_--;
88   }
Empty() const89   bool Empty() const { return len_ == 0; }
length() const90   int length() const { return len_; }
mask() const91   unsigned mask() const { return mask_; }
92 
93   // Text representation of the queue
ToString() const94   std::string ToString() const {
95     return absl::StrCat(absl::Hex(mask_), "/", len_);
96   }
97 
98   // Comparisons so that we can use BitQueue as a key in a std::map
operator <(const BitQueue & other) const99   bool operator<(const BitQueue& other) const {
100     return std::tie(mask_, len_) < std::tie(other.mask_, other.len_);
101   }
102 
103  private:
104   // The bits
105   unsigned mask_;
106   // How many bits have we
107   int len_;
108 };
109 
110 ///////////////////////////////////////////////////////////////////////////////
111 // Symbol sets for the huffman tree
112 
113 // A Sym is one symbol in the tree, and the bits that we need to read to decode
114 // that symbol. As we progress through decoding we remove bits from the symbol,
115 // but also condense the number of symbols we're considering.
116 struct Sym {
117   BitQueue bits;
118   int symbol;
119 
operator <Sym120   bool operator<(const Sym& other) const {
121     return std::tie(bits, symbol) < std::tie(other.bits, other.symbol);
122   }
123 };
124 
125 // A SymSet is all the symbols we're considering at some time
126 using SymSet = std::vector<Sym>;
127 
128 // Debug utility to turn a SymSet into a string
SymSetString(const SymSet & syms)129 std::string SymSetString(const SymSet& syms) {
130   std::vector<std::string> parts;
131   for (const Sym& sym : syms) {
132     parts.push_back(absl::StrCat(sym.symbol, ":", sym.bits.ToString()));
133   }
134   return absl::StrJoin(parts, ",");
135 }
136 
137 // Initial SymSet - all the symbols [0..256] with their bits initialized from
138 // the http2 static huffman tree.
AllSyms()139 SymSet AllSyms() {
140   SymSet syms;
141   for (int i = 0; i < GRPC_CHTTP2_NUM_HUFFSYMS; i++) {
142     Sym sym;
143     sym.bits =
144         BitQueue(grpc_chttp2_huffsyms[i].bits, grpc_chttp2_huffsyms[i].length);
145     sym.symbol = i;
146     syms.push_back(sym);
147   }
148   return syms;
149 }
150 
// What would we do after reading a set of bits?
// ActionsFor builds this with aggregate initialization, so the member order
// (emit, consumed, remaining) matters.
struct ReadActions {
  // Emit these symbols
  std::vector<int> emit;
  // Number of bits that were consumed by the read
  int consumed;
  // Remaining SymSet that we need to consider on the next read action
  SymSet remaining;
};
160 
161 // Given a SymSet \a pending, read through the bits in \a index and determine
162 // what actions the decoder should take.
163 // allow_multiple controls the behavior should we get to the last bit in pending
164 // and hence know which symbol to emit, but we still have bits in index.
165 // We could either start decoding the next symbol (allow_multiple == true), or
166 // we could stop (allow_multiple == false).
167 // If allow_multiple is true we tend to emit more per read op, but generate
168 // bigger tables.
ActionsFor(BitQueue index,SymSet pending,bool allow_multiple)169 ReadActions ActionsFor(BitQueue index, SymSet pending, bool allow_multiple) {
170   std::vector<int> emit;
171   int len_start = index.length();
172   int len_consume = len_start;
173 
174   // We read one bit in index at a time, so whilst we have bits...
175   while (!index.Empty()) {
176     SymSet next_pending;
177     // For each symbol in the pending set
178     for (auto sym : pending) {
179       // If the first bit doesn't match, then that symbol is not part of our
180       // remaining set.
181       if (sym.bits.Front() != index.Front()) continue;
182       sym.bits.Pop();
183       next_pending.push_back(sym);
184     }
185     switch (next_pending.size()) {
186       case 0:
187         // There should be no bit patterns that are undecodable.
188         abort();
189       case 1:
190         // If we have one symbol left, we need to have decoded all of it.
191         if (!next_pending[0].bits.Empty()) abort();
192         // Emit that symbol
193         emit.push_back(next_pending[0].symbol);
194         // Track how many bits we've read.
195         len_consume = index.length() - 1;
196         // If we allow multiple, reprime pending and continue, otherwise stop.
197         if (!allow_multiple) goto done;
198         pending = AllSyms();
199         break;
200       default:
201         pending = std::move(next_pending);
202         break;
203     }
204     // Finished with this bit, continue with next
205     index.Pop();
206   }
207 done:
208   return ReadActions{std::move(emit), len_start - len_consume, pending};
209 }
210 
211 ///////////////////////////////////////////////////////////////////////////////
212 // MatchCase
213 // A variant that helps us bunch together related ReadActions
214 
// A Matched in a MatchCase says that some number of symbols must be emitted.
struct Matched {
  // How many symbols to emit.
  int emits;

  // Ordered by emit count.
  bool operator<(const Matched& rhs) const {
    return rhs.emits > emits;
  }
};
223 
224 // Unmatched says we didn't emit anything and we need to keep decoding
225 struct Unmatched {
226   SymSet syms;
227 
operator <Unmatched228   bool operator<(const Unmatched& other) const { return syms < other.syms; }
229 };
230 
// Emit end of stream. Carries no data, so no End orders before another.
struct End {
  bool operator<(End /*rhs*/) const {
    return false;
  }
};
235 
// A MatchCase holds exactly one of Matched, Unmatched or End.
using MatchCase = absl::variant<Matched, Unmatched, End>;
237 
238 ///////////////////////////////////////////////////////////////////////////////
239 // Text & numeric helper functions
240 
// Prefix every line in \a lines with \a n indents (2 spaces per indent) and
// return the result. Empty lines receive the padding too.
std::vector<std::string> IndentLines(std::vector<std::string> lines,
                                     int n = 1) {
  const std::string padding(2 * n, ' ');
  for (std::string& text : lines) {
    text.insert(0, padding);
  }
  return lines;
}
251 
// Given a snake_case_name return a PascalCaseName.
// Underscores are dropped; the first character of the input and every
// character that follows an underscore is upper-cased, everything else is
// copied unchanged.
std::string ToPascalCase(const std::string& in) {
  std::string out;
  out.reserve(in.size());
  bool next_upper = true;
  for (const char c : in) {
    if (c == '_') {
      next_upper = true;
      continue;
    }
    if (next_upper) {
      // std::toupper is only defined for values representable as unsigned
      // char (or EOF), so a plain char must not be passed directly: it may be
      // negative.
      out.push_back(
          static_cast<char>(std::toupper(static_cast<unsigned char>(c))));
      next_upper = false;
    } else {
      out.push_back(c);
    }
  }
  return out;
}
270 
// Name of the unsigned integer type with the given width
// (16 -> "uint16_t", 32 -> "uint32_t").
std::string Uint(int bits) {
  std::string name = "uint";
  name += std::to_string(bits);
  name += "_t";
  return name;
}
273 
// Width (8, 16 or 32) of the narrowest uint that can hold values up to
// \a max. Anything at or below 255, including negative input, maps to 8.
int TypeBitsForMax(int max) {
  if (max > 65535) return 32;
  if (max > 255) return 16;
  return 8;
}
284 
285 // Combine Uint & TypeBitsForMax to make for more concise code
TypeForMax(int max)286 std::string TypeForMax(int max) { return Uint(TypeBitsForMax(max)); }
287 
// How many bits are needed to encode a value: the smallest n such that
// x < 2^n. Returns 0 for x <= 0.
int BitsForMaxValue(int x) {
  int n = 0;
  // Shift x down rather than shifting 1 up: comparing against (1 << n) would
  // need 1 << 31 once x >= 2^30, which overflows a signed int.
  while (x > 0) {
    x >>= 1;
    n++;
  }
  return n;
}
294 
295 ///////////////////////////////////////////////////////////////////////////////
296 // Codegen framework
297 // Some helpers so we don't need to generate all the code linearly, which helps
298 // organize this a little more nicely.
299 
300 // An Item is our primitive for code generation, it can generate some lines
301 // that it would like to emit - those lines are fed to a parent item that might
302 // generate more lines or mutate the ones we return, and so on until codegen
303 // is complete.
304 class Item {
305  public:
306   virtual ~Item() = default;
307   virtual std::vector<std::string> ToLines() const = 0;
ToString() const308   std::string ToString() const {
309     return absl::StrCat(absl::StrJoin(ToLines(), "\n"), "\n");
310   }
311 };
312 using ItemPtr = std::unique_ptr<Item>;
313 
314 // An item that emits one line (the one given as an argument!)
315 class String : public Item {
316  public:
String(std::string s)317   explicit String(std::string s) : s_(std::move(s)) {}
ToLines() const318   std::vector<std::string> ToLines() const override { return {s_}; }
319 
320  private:
321   std::string s_;
322 };
323 
324 // An item that returns a fixed copyright notice and autogenerated note text.
325 class Prelude final : public Item {
326  public:
Prelude(absl::string_view comment_prefix,int copyright_year)327   explicit Prelude(absl::string_view comment_prefix, int copyright_year)
328       : comment_prefix_(comment_prefix), copyright_year_(copyright_year) {}
ToLines() const329   std::vector<std::string> ToLines() const override {
330     auto line = [this](absl::string_view text) {
331       return absl::StrCat(comment_prefix_, " ", text);
332     };
333     return {
334         line(absl::StrCat("Copyright ", copyright_year_, " gRPC authors.")),
335         line(""),
336         line("Licensed under the Apache License, Version 2.0 (the "
337              "\"License\");"),
338         line(
339             "you may not use this file except in compliance with the License."),
340         line("You may obtain a copy of the License at"),
341         line(""),
342         line("    http://www.apache.org/licenses/LICENSE-2.0"),
343         line(""),
344         line("Unless required by applicable law or agreed to in writing, "
345              "software"),
346         line("distributed under the License is distributed on an \"AS IS\" "
347              "BASIS,"),
348         line("WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or "
349              "implied."),
350         line("See the License for the specific language governing permissions "
351              "and"),
352         line("limitations under the License."),
353         "",
354         line("This file is autogenerated: see "
355              "tools/codegen/core/gen_huffman_decompressor.cc"),
356         ""};
357   }
358 
359  private:
360   absl::string_view comment_prefix_;
361   int copyright_year_;
362 };
363 
364 class Switch;
365 
366 // A Sink is an Item that we can add more Items to.
367 // At codegen time it calls each of its children in turn and concatenates
368 // their results together.
369 class Sink : public Item {
370  public:
ToLines() const371   std::vector<std::string> ToLines() const override {
372     std::vector<std::string> lines;
373     for (const auto& item : children_) {
374       for (const auto& line : item->ToLines()) {
375         lines.push_back(line);
376       }
377     }
378     return lines;
379   }
380 
381   // Add one string to our output.
Add(std::string s)382   void Add(std::string s) {
383     children_.push_back(std::make_unique<String>(std::move(s)));
384   }
385 
386   // Add an item of type T to our output (constructing it with args).
387   template <typename T, typename... Args>
Add(Args &&...args)388   T* Add(Args&&... args) {
389     auto v = std::make_unique<T>(std::forward<Args>(args)...);
390     auto* r = v.get();
391     children_.push_back(std::move(v));
392     return r;
393   }
394 
395  private:
396   std::vector<ItemPtr> children_;
397 };
398 
399 // A sink that indents its lines by one indent (2 spaces)
400 class Indent : public Sink {
401  public:
ToLines() const402   std::vector<std::string> ToLines() const override {
403     return IndentLines(Sink::ToLines());
404   }
405 };
406 
407 // A Sink that wraps its lines in a while block
408 class While : public Sink {
409  public:
While(std::string cond)410   explicit While(std::string cond) : cond_(std::move(cond)) {}
ToLines() const411   std::vector<std::string> ToLines() const override {
412     std::vector<std::string> lines;
413     lines.push_back(absl::StrCat("while (", cond_, ") {"));
414     for (const auto& line : IndentLines(Sink::ToLines())) {
415       lines.push_back(line);
416     }
417     lines.push_back("}");
418     return lines;
419   }
420 
421  private:
422   std::string cond_;
423 };
424 
425 // A switch statement.
426 // Cases can be modified by calling the Case member.
427 // Identical cases are collapsed into 'case X: case Y:' type blocks.
428 class Switch : public Item {
429  public:
430   struct Default {
operator <Switch::Default431     bool operator<(const Default&) const { return false; }
operator ==Switch::Default432     bool operator==(const Default&) const { return true; }
433   };
434   using CaseLabel = absl::variant<int, std::string, Default>;
435   // \a cond is the condition to place at the head of the switch statement.
436   // eg. "switch (cond) {".
Switch(std::string cond)437   explicit Switch(std::string cond) : cond_(std::move(cond)) {}
ToLines() const438   std::vector<std::string> ToLines() const override {
439     std::map<std::string, std::vector<CaseLabel>> reverse_map;
440     for (const auto& kv : cases_) {
441       reverse_map[kv.second.ToString()].push_back(kv.first);
442     }
443     std::vector<std::pair<std::string, std::vector<CaseLabel>>>
444         sorted_reverse_map;
445     sorted_reverse_map.reserve(reverse_map.size());
446     for (auto& kv : reverse_map) {
447       sorted_reverse_map.push_back(kv);
448     }
449     for (auto& e : sorted_reverse_map) {
450       std::sort(e.second.begin(), e.second.end());
451     }
452     std::sort(sorted_reverse_map.begin(), sorted_reverse_map.end(),
453               [](const auto& a, const auto& b) { return a.second < b.second; });
454     std::vector<std::string> lines;
455     lines.push_back(absl::StrCat("switch (", cond_, ") {"));
456     for (const auto& kv : sorted_reverse_map) {
457       for (const auto& cond : kv.second) {
458         lines.push_back(absl::StrCat(
459             "  ",
460             grpc_core::Match(
461                 cond, [](Default) -> std::string { return "default"; },
462                 [](int i) { return absl::StrCat("case ", i); },
463                 [](const std::string& s) { return absl::StrCat("case ", s); }),
464             ":"));
465       }
466       lines.back().append(" {");
467       for (const auto& case_line :
468            IndentLines(cases_.find(kv.second[0])->second.ToLines(), 2)) {
469         lines.push_back(case_line);
470       }
471       lines.push_back("  }");
472     }
473     lines.push_back("}");
474     return lines;
475   }
476 
Case(CaseLabel cond)477   Sink* Case(CaseLabel cond) { return &cases_[cond]; }
478 
479  private:
480   std::string cond_;
481   std::map<CaseLabel, Sink> cases_;
482 };
483 
484 ///////////////////////////////////////////////////////////////////////////////
485 // BuildCtx declaration
486 // Shared state for one code gen attempt
487 
488 class TableBuilder;
489 class FunMaker;
490 
491 class BuildCtx {
492  public:
BuildCtx(std::vector<int> max_bits_for_depth,Sink * global_fns,Sink * global_decls,Sink * global_values,FunMaker * fun_maker)493   BuildCtx(std::vector<int> max_bits_for_depth, Sink* global_fns,
494            Sink* global_decls, Sink* global_values, FunMaker* fun_maker)
495       : max_bits_for_depth_(std::move(max_bits_for_depth)),
496         global_fns_(global_fns),
497         global_decls_(global_decls),
498         global_values_(global_values),
499         fun_maker_(fun_maker) {}
500 
501   void AddStep(SymSet start_syms, int num_bits, bool is_top, bool refill,
502                int depth, Sink* out);
503   void AddMatchBody(TableBuilder* table_builder, std::string index,
504                     std::string ofs, const MatchCase& match_case, bool refill,
505                     int depth, Sink* out);
506   void AddDone(SymSet start_syms, int num_bits, bool all_ones_so_far,
507                Sink* out);
508 
NewId()509   int NewId() { return next_id_++; }
MaxBitsForTop() const510   int MaxBitsForTop() const { return max_bits_for_depth_[0]; }
511 
PreviousNameForArtifact(std::string proposed_name,Hash hash)512   absl::optional<std::string> PreviousNameForArtifact(std::string proposed_name,
513                                                       Hash hash) {
514     auto it = arrays_.find(hash);
515     if (it == arrays_.end()) {
516       arrays_.emplace(hash, proposed_name);
517       return absl::nullopt;
518     }
519     return it->second;
520   }
521 
global_fns() const522   Sink* global_fns() const { return global_fns_; }
global_decls() const523   Sink* global_decls() const { return global_decls_; }
global_values() const524   Sink* global_values() const { return global_values_; }
525 
526  private:
527   void AddDoneCase(size_t n, size_t n_bits, bool all_ones_so_far, SymSet syms,
528                    std::vector<uint8_t> emit, TableBuilder* table_builder,
529                    std::map<absl::optional<int>, int>* cases);
530 
531   const std::vector<int> max_bits_for_depth_;
532   std::map<Hash, std::string> arrays_;
533   int next_id_ = 1;
534   Sink* const global_fns_;
535   Sink* const global_decls_;
536   Sink* const global_values_;
537   FunMaker* const fun_maker_;
538 };
539 
540 ///////////////////////////////////////////////////////////////////////////////
541 // TableBuilder
542 // All our magic for building decode tables.
// We have two kinds of tables to generate:
544 // 1. op tables that translate a bit sequence to which decode case we should
545 //    execute (and arguments to it), and
546 // 2. emit tables that translate an index given by the op table and tell us
547 //    which symbols to emit
548 // Op table format
549 // Our opcodes contain an offset into an emit table, a number of bits consumed
550 // and an operation. The consumed bits are how many of the presented to us bits
551 // we actually took. The operation tells whether to emit some symbols (and how
552 // many) or to keep decoding.
553 // Optimization 1:
554 // op tables are essentially dense maps of bits -> opcode, and it turns out
555 // that *many* of the opcodes repeat across index bits for some of our tables
556 // so for those we split the table into two levels: first level indexes into
557 // a child table, and the child table contains the deduped opcodes.
558 // Optimization 2:
559 // Emit tables are a bit list of uint8_ts, and are indexed into by the op
560 // table (with an offset and length) - since many symbols get repeated, we try
561 // to overlay the symbols in the emit table to reduce the size.
562 // Optimization 3:
563 // We shard the table into some number of slices and use the top bits of the
564 // incoming lookup to select the shard. This tends to allow us to use smaller
565 // types to represent the table, saving on footprint.
566 
567 class TableBuilder {
568  public:
TableBuilder(BuildCtx * ctx)569   explicit TableBuilder(BuildCtx* ctx) : ctx_(ctx), id_(ctx->NewId()) {}
570 
571   // Append one case to the table
Add(int match_case,std::vector<uint8_t> emit,int consumed_bits)572   void Add(int match_case, std::vector<uint8_t> emit, int consumed_bits) {
573     elems_.push_back({match_case, std::move(emit), consumed_bits});
574     max_consumed_bits_ = std::max(max_consumed_bits_, consumed_bits);
575     max_match_case_ = std::max(max_match_case_, match_case);
576   }
577 
578   // Build the table
Build() const579   void Build() const {
580     Choose()->Build(this, BitsForMaxValue(elems_.size() - 1));
581   }
582 
583   // Generate a call to the accessor function for the emit table
EmitAccessor(std::string index,std::string offset)584   std::string EmitAccessor(std::string index, std::string offset) {
585     return absl::StrCat("GetEmit", id_, "(", index, ", ", offset, ")");
586   }
587 
588   // Generate a call to the accessor function for the op table
OpAccessor(std::string index)589   std::string OpAccessor(std::string index) {
590     return absl::StrCat("GetOp", id_, "(", index, ")");
591   }
592 
ConsumeBits() const593   int ConsumeBits() const { return BitsForMaxValue(max_consumed_bits_); }
MatchBits() const594   int MatchBits() const { return BitsForMaxValue(max_match_case_); }
595 
596  private:
597   // One element in the op table.
598   struct Elem {
599     int match_case;
600     std::vector<uint8_t> emit;
601     int consumed_bits;
602   };
603 
604   // A nested slice is one slice of a table using two level lookup
605   // - i.e. we look at an outer table to get an index into the inner table,
606   //   and then fetch the result from there.
607   struct NestedSlice {
608     std::vector<uint8_t> emit;
609     std::vector<uint64_t> inner;
610     std::vector<int> outer;
611 
612     // Various sizes return number of bits to be generated
613 
InnerSizeTableBuilder::NestedSlice614     size_t InnerSize() const {
615       return inner.size() *
616              TypeBitsForMax(*std::max_element(inner.begin(), inner.end()));
617     }
618 
OuterSizeTableBuilder::NestedSlice619     size_t OuterSize() const {
620       return outer.size() *
621              TypeBitsForMax(*std::max_element(outer.begin(), outer.end()));
622     }
623 
EmitSizeTableBuilder::NestedSlice624     size_t EmitSize() const { return emit.size() * 8; }
625   };
626 
627   // A slice is one part of a larger table.
628   struct Slice {
629     std::vector<uint8_t> emit;
630     std::vector<uint64_t> ops;
631 
632     // Various sizes return number of bits to be generated
633 
OpsSizeTableBuilder::Slice634     size_t OpsSize() const {
635       return ops.size() *
636              TypeBitsForMax(*std::max_element(ops.begin(), ops.end()));
637     }
638 
EmitSizeTableBuilder::Slice639     size_t EmitSize() const { return emit.size() * 8; }
640 
641     // Given a vector of symbols to emit, return the offset into the emit table
642     // that they're at (adding them to the emit table if necessary).
OffsetOfTableBuilder::Slice643     int OffsetOf(const std::vector<uint8_t>& x) {
644       if (x.empty()) return 0;
645       auto r = std::search(emit.begin(), emit.end(), x.begin(), x.end());
646       if (r == emit.end()) {
647         // look for a partial match @ end
648         for (size_t check_len = x.size() - 1; check_len > 0; check_len--) {
649           if (emit.size() < check_len) continue;
650           bool matches = true;
651           for (size_t i = 0; matches && i < check_len; i++) {
652             if (emit[emit.size() - check_len + i] != x[i]) matches = false;
653           }
654           if (matches) {
655             int offset = emit.size() - check_len;
656             for (size_t i = check_len; i < x.size(); i++) {
657               emit.push_back(x[i]);
658             }
659             for (size_t i = 0; i < x.size(); i++) {
660               if (emit[offset + i] != x[i]) {
661                 abort();
662               }
663             }
664             return offset;
665           }
666         }
667         // add new
668         int result = emit.size();
669         for (auto v : x) emit.push_back(v);
670         return result;
671       }
672       return r - emit.begin();
673     }
674 
675     // Convert this slice to a nested slice.
MakeNestedSliceTableBuilder::Slice676     NestedSlice MakeNestedSlice() const {
677       NestedSlice result;
678       result.emit = emit;
679       std::map<uint64_t, int> op_to_inner;
680       for (auto v : ops) {
681         auto it = op_to_inner.find(v);
682         if (it == op_to_inner.end()) {
683           it = op_to_inner.emplace(v, op_to_inner.size()).first;
684           result.inner.push_back(v);
685         }
686         result.outer.push_back(it->second);
687       }
688       return result;
689     }
690   };
691 
692   // An EncodeOption is a potential way of encoding a table.
693   struct EncodeOption {
694     // Overall size (in bits) of the table encoding
695     virtual size_t Size() const = 0;
696     // Generate the code
697     virtual void Build(const TableBuilder* builder, int op_bits) const = 0;
~EncodeOptionTableBuilder::EncodeOption698     virtual ~EncodeOption() {}
699   };
700 
701   // NestedTable is a table that uses two level lookup for each slice
702   struct NestedTable : public EncodeOption {
703     std::vector<NestedSlice> slices;
704     int slice_bits;
SizeTableBuilder::NestedTable705     size_t Size() const override {
706       size_t sum = 0;
707       std::vector<Hash> h_emit;
708       std::vector<Hash> h_inner;
709       std::vector<Hash> h_outer;
710       for (size_t i = 0; i < slices.size(); i++) {
711         h_emit.push_back(HashVec("uint8_t", slices[i].emit));
712         h_inner.push_back(HashVec(TypeForMax(MaxInner()), slices[i].inner));
713         h_outer.push_back(HashVec(TypeForMax(MaxOuter()), slices[i].outer));
714       }
715       std::set<Hash> seen;
716       for (size_t i = 0; i < slices.size(); i++) {
717         // Try to account for deduplication in the size calculation.
718         if (seen.count(h_emit[i]) == 0) sum += slices[i].EmitSize();
719         if (seen.count(h_outer[i]) == 0) sum += slices[i].OuterSize();
720         if (seen.count(h_inner[i]) == 0) sum += slices[i].OuterSize();
721         seen.insert(h_emit[i]);
722         seen.insert(h_outer[i]);
723         seen.insert(h_inner[i]);
724       }
725       if (slice_bits != 0) sum += 3 * 64 * slices.size();
726       return sum;
727     }
BuildTableBuilder::NestedTable728     void Build(const TableBuilder* builder, int op_bits) const override {
729       Sink* const global_fns = builder->ctx_->global_fns();
730       Sink* const global_decls = builder->ctx_->global_decls();
731       Sink* const global_values = builder->ctx_->global_values();
732       const int id = builder->id_;
733       std::vector<std::string> lines;
734       const uint64_t max_inner = MaxInner();
735       const uint64_t max_outer = MaxOuter();
736       std::vector<std::unique_ptr<Array>> emit_names;
737       std::vector<std::unique_ptr<Array>> inner_names;
738       std::vector<std::unique_ptr<Array>> outer_names;
739       for (size_t i = 0; i < slices.size(); i++) {
740         emit_names.push_back(builder->GenArray(
741             slice_bits != 0, absl::StrCat("table", id, "_", i, "_emit"),
742             "uint8_t", slices[i].emit, true, global_decls, global_values));
743         inner_names.push_back(builder->GenArray(
744             slice_bits != 0, absl::StrCat("table", id, "_", i, "_inner"),
745             TypeForMax(max_inner), slices[i].inner, true, global_decls,
746             global_values));
747         outer_names.push_back(builder->GenArray(
748             slice_bits != 0, absl::StrCat("table", id, "_", i, "_outer"),
749             TypeForMax(max_outer), slices[i].outer, false, global_decls,
750             global_values));
751       }
752       if (slice_bits == 0) {
753         global_fns->Add(absl::StrCat(
754             "static inline uint64_t GetOp", id, "(size_t i) { return ",
755             inner_names[0]->Index(outer_names[0]->Index("i")), "; }"));
756         global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
757                                      "(size_t, size_t emit) { return ",
758                                      emit_names[0]->Index("emit"), "; }"));
759       } else {
760         GenCompound(id, emit_names, "emit", "uint8_t", global_decls,
761                     global_values);
762         GenCompound(id, inner_names, "inner", TypeForMax(max_inner),
763                     global_decls, global_values);
764         GenCompound(id, outer_names, "outer", TypeForMax(max_outer),
765                     global_decls, global_values);
766         global_fns->Add(absl::StrCat(
767             "static inline uint64_t GetOp", id, "(size_t i) { return table", id,
768             "_inner_[i >> ", op_bits - slice_bits, "][table", id,
769             "_outer_[i >> ", op_bits - slice_bits, "][i & 0x",
770             absl::Hex((1 << (op_bits - slice_bits)) - 1), "]]; }"));
771         global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
772                                      "(size_t i, size_t emit) { return table",
773                                      id, "_emit_[i >> ", op_bits - slice_bits,
774                                      "][emit]; }"));
775       }
776     }
MaxInnerTableBuilder::NestedTable777     uint64_t MaxInner() const {
778       if (max_inner == 0) {
779         for (size_t i = 0; i < slices.size(); i++) {
780           max_inner =
781               std::max(max_inner, *std::max_element(slices[i].inner.begin(),
782                                                     slices[i].inner.end()));
783         }
784       }
785       return max_inner;
786     }
MaxOuterTableBuilder::NestedTable787     int MaxOuter() const {
788       if (max_outer == 0) {
789         for (size_t i = 0; i < slices.size(); i++) {
790           max_outer =
791               std::max(max_outer, *std::max_element(slices[i].outer.begin(),
792                                                     slices[i].outer.end()));
793         }
794       }
795       return max_outer;
796     }
797     mutable uint64_t max_inner = 0;
798     mutable int max_outer = 0;
799   };
800 
801   // Encoding that uses single level lookup for each slice.
802   struct Table : public EncodeOption {
803     std::vector<Slice> slices;
804     int slice_bits;
SizeTableBuilder::Table805     size_t Size() const override {
806       size_t sum = 0;
807       std::vector<Hash> h_emit;
808       std::vector<Hash> h_ops;
809       for (size_t i = 0; i < slices.size(); i++) {
810         h_emit.push_back(HashVec("uint8_t", slices[i].emit));
811         h_ops.push_back(HashVec(TypeForMax(MaxOp()), slices[i].ops));
812       }
813       std::set<Hash> seen;
814       for (size_t i = 0; i < slices.size(); i++) {
815         if (seen.count(h_emit[i]) == 0) sum += slices[i].EmitSize();
816         if (seen.count(h_ops[i]) == 0) sum += slices[i].OpsSize();
817         seen.insert(h_emit[i]);
818         seen.insert(h_ops[i]);
819       }
820       return sum + 3 * 64 * slices.size();
821     }
BuildTableBuilder::Table822     void Build(const TableBuilder* builder, int op_bits) const override {
823       Sink* const global_fns = builder->ctx_->global_fns();
824       Sink* const global_decls = builder->ctx_->global_decls();
825       Sink* const global_values = builder->ctx_->global_values();
826       uint64_t max_op = MaxOp();
827       const int id = builder->id_;
828       std::vector<std::unique_ptr<Array>> emit_names;
829       std::vector<std::unique_ptr<Array>> ops_names;
830       for (size_t i = 0; i < slices.size(); i++) {
831         emit_names.push_back(builder->GenArray(
832             slice_bits != 0, absl::StrCat("table", id, "_", i, "_emit"),
833             "uint8_t", slices[i].emit, true, global_decls, global_values));
834         ops_names.push_back(builder->GenArray(
835             slice_bits != 0, absl::StrCat("table", id, "_", i, "_ops"),
836             TypeForMax(max_op), slices[i].ops, true, global_decls,
837             global_values));
838       }
839       if (slice_bits == 0) {
840         global_fns->Add(absl::StrCat("static inline uint64_t GetOp", id,
841                                      "(size_t i) { return ",
842                                      ops_names[0]->Index("i"), "; }"));
843         global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
844                                      "(size_t, size_t emit) { return ",
845                                      emit_names[0]->Index("emit"), "; }"));
846       } else {
847         GenCompound(id, emit_names, "emit", "uint8_t", global_decls,
848                     global_values);
849         GenCompound(id, ops_names, "ops", TypeForMax(max_op), global_decls,
850                     global_values);
851         global_fns->Add(absl::StrCat(
852             "static inline uint64_t GetOp", id, "(size_t i) { return table", id,
853             "_ops_[i >> ", op_bits - slice_bits, "][i & 0x",
854             absl::Hex((1 << (op_bits - slice_bits)) - 1), "]; }"));
855         global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
856                                      "(size_t i, size_t emit) { return table",
857                                      id, "_emit_[i >> ", op_bits - slice_bits,
858                                      "][emit]; }"));
859       }
860     }
MaxOpTableBuilder::Table861     uint64_t MaxOp() const {
862       if (max_op == 0) {
863         for (size_t i = 0; i < slices.size(); i++) {
864           max_op = std::max(max_op, *std::max_element(slices[i].ops.begin(),
865                                                       slices[i].ops.end()));
866         }
867       }
868       return max_op;
869     }
870     mutable uint64_t max_op = 0;
871     // Convert to a two-level lookup
MakeNestedTableTableBuilder::Table872     std::unique_ptr<NestedTable> MakeNestedTable() {
873       std::unique_ptr<NestedTable> result(new NestedTable);
874       result->slice_bits = slice_bits;
875       for (const auto& slice : slices) {
876         result->slices.push_back(slice.MakeNestedSlice());
877       }
878       return result;
879     }
880   };
881 
882   // Given a number of slices (2**slice_bits), generate a table that uses a
883   // single level lookup for each slice based on our input.
MakeTable(size_t slice_bits) const884   std::unique_ptr<Table> MakeTable(size_t slice_bits) const {
885     std::unique_ptr<Table> table = std::make_unique<Table>();
886     int slices = 1 << slice_bits;
887     table->slices.resize(slices);
888     table->slice_bits = slice_bits;
889     const int pack_consume_bits = ConsumeBits();
890     const int pack_match_bits = MatchBits();
891     for (int i = 0; i < slices; i++) {
892       auto& slice = table->slices[i];
893       for (size_t j = 0; j < elems_.size() / slices; j++) {
894         const auto& elem = elems_[i * elems_.size() / slices + j];
895         slice.ops.push_back(elem.consumed_bits |
896                             (elem.match_case << pack_consume_bits) |
897                             (slice.OffsetOf(elem.emit)
898                              << (pack_consume_bits + pack_match_bits)));
899       }
900     }
901     return table;
902   }
903 
904   class Array {
905    public:
906     virtual ~Array() = default;
907     virtual std::string Index(absl::string_view value) = 0;
908     virtual std::string ArrayName() = 0;
909     virtual int Cost() = 0;
910   };
911 
912   class NamedArray : public Array {
913    public:
NamedArray(std::string name)914     explicit NamedArray(std::string name) : name_(std::move(name)) {}
Index(absl::string_view value)915     std::string Index(absl::string_view value) override {
916       return absl::StrCat(name_, "[", value, "]");
917     }
ArrayName()918     std::string ArrayName() override { return name_; }
Cost()919     int Cost() override { abort(); }
920 
921    private:
922     std::string name_;
923   };
924 
925   class IdentityArray : public Array {
926    public:
Index(absl::string_view value)927     std::string Index(absl::string_view value) override {
928       return std::string(value);
929     }
ArrayName()930     std::string ArrayName() override { abort(); }
Cost()931     int Cost() override { return 0; }
932   };
933 
934   class ConstantArray : public Array {
935    public:
ConstantArray(std::string value)936     explicit ConstantArray(std::string value) : value_(std::move(value)) {}
Index(absl::string_view index)937     std::string Index(absl::string_view index) override {
938       return absl::StrCat("((void)", index, ", ", value_, ")");
939     }
ArrayName()940     std::string ArrayName() override { abort(); }
Cost()941     int Cost() override { return 0; }
942 
943    private:
944     std::string value_;
945   };
946 
947   class OffsetArray : public Array {
948    public:
OffsetArray(int offset)949     explicit OffsetArray(int offset) : offset_(offset) {}
Index(absl::string_view value)950     std::string Index(absl::string_view value) override {
951       return absl::StrCat(value, " + ", offset_);
952     }
ArrayName()953     std::string ArrayName() override { abort(); }
Cost()954     int Cost() override { return 10; }
955 
956    private:
957     int offset_;
958   };
959 
960   class LinearDivideArray : public Array {
961    public:
LinearDivideArray(int offset,int divisor)962     LinearDivideArray(int offset, int divisor)
963         : offset_(offset), divisor_(divisor) {}
Index(absl::string_view value)964     std::string Index(absl::string_view value) override {
965       return absl::StrCat(value, "/", divisor_, " + ", offset_);
966     }
ArrayName()967     std::string ArrayName() override { abort(); }
Cost()968     int Cost() override { return 20 + (offset_ != 0 ? 10 : 0); }
969 
970    private:
971     int offset_;
972     int divisor_;
973   };
974 
975   class TwoElemArray : public Array {
976    public:
TwoElemArray(std::string value0,std::string value1)977     TwoElemArray(std::string value0, std::string value1)
978         : value0_(std::move(value0)), value1_(std::move(value1)) {}
Index(absl::string_view value)979     std::string Index(absl::string_view value) override {
980       return absl::StrCat(value, " ? ", value1_, " : ", value0_);
981     }
ArrayName()982     std::string ArrayName() override { abort(); }
Cost()983     int Cost() override { return 40; }
984 
985    private:
986     std::string value0_;
987     std::string value1_;
988   };
989 
990   class Composite2Array : public Array {
991    public:
Composite2Array(std::unique_ptr<Array> a,std::unique_ptr<Array> b,int split)992     Composite2Array(std::unique_ptr<Array> a, std::unique_ptr<Array> b,
993                     int split)
994         : a_(std::move(a)), b_(std::move(b)), split_(split) {}
Index(absl::string_view value)995     std::string Index(absl::string_view value) override {
996       return absl::StrCat(
997           "(", value, " < ", split_, " ? (", a_->Index(value), ") : (",
998           b_->Index(absl::StrCat("(", value, "-", split_, ")")), "))");
999     }
ArrayName()1000     std::string ArrayName() override { abort(); }
Cost()1001     int Cost() override { return 40 + a_->Cost() + b_->Cost(); }
1002 
1003    private:
1004     std::unique_ptr<Array> a_;
1005     std::unique_ptr<Array> b_;
1006     int split_;
1007   };
1008 
1009   // Helper to generate a compound table (an array of arrays)
GenCompound(int id,const std::vector<std::unique_ptr<Array>> & arrays,std::string ext,std::string type,Sink * global_decls,Sink * global_values)1010   static void GenCompound(int id,
1011                           const std::vector<std::unique_ptr<Array>>& arrays,
1012                           std::string ext, std::string type, Sink* global_decls,
1013                           Sink* global_values) {
1014     global_decls->Add(absl::StrCat("static const ", type, "* const table", id,
1015                                    "_", ext, "_[", arrays.size(), "];"));
1016     global_values->Add(absl::StrCat("const ", type,
1017                                     "* const HuffDecoderCommon::table", id, "_",
1018                                     ext, "_[", arrays.size(), "] = {"));
1019     for (const std::unique_ptr<Array>& array : arrays) {
1020       global_values->Add(absl::StrCat("  ", array->ArrayName(), ","));
1021     }
1022     global_values->Add("};");
1023   }
1024 
1025   // Try to create a simple function equivalent to a mapping implied by a set of
1026   // values.
1027   static const int kMaxArrayToFunctionRecursions = 1;
1028   template <typename T>
ArrayToFunction(const std::vector<T> & values,int recurse=kMaxArrayToFunctionRecursions)1029   static std::unique_ptr<Array> ArrayToFunction(
1030       const std::vector<T>& values,
1031       int recurse = kMaxArrayToFunctionRecursions) {
1032     std::unique_ptr<Array> best = nullptr;
1033     auto note_solution = [&best](std::unique_ptr<Array> a) {
1034       if (best != nullptr && best->Cost() <= a->Cost()) return;
1035       best = std::move(a);
1036     };
1037     // constant => k,k,k,k,...
1038     bool is_constant = true;
1039     for (size_t i = 1; i < values.size(); i++) {
1040       if (values[i] != values[0]) {
1041         is_constant = false;
1042         break;
1043       }
1044     }
1045     if (is_constant) {
1046       note_solution(std::make_unique<ConstantArray>(absl::StrCat(values[0])));
1047     }
1048     // identity => 0,1,2,3,...
1049     bool is_identity = true;
1050     for (size_t i = 0; i < values.size(); i++) {
1051       if (static_cast<size_t>(values[i]) != i) {
1052         is_identity = false;
1053         break;
1054       }
1055     }
1056     if (is_identity) {
1057       note_solution(std::make_unique<IdentityArray>());
1058     }
1059     // offset => k,k+1,k+2,k+3,...
1060     bool is_offset = true;
1061     for (size_t i = 1; i < values.size(); i++) {
1062       if (static_cast<size_t>(values[i] - values[0]) != i) {
1063         is_offset = false;
1064         break;
1065       }
1066     }
1067     if (is_offset) {
1068       note_solution(std::make_unique<OffsetArray>(values[0]));
1069     }
1070     // offset => k,k,k+1,k+1,...
1071     for (size_t d = 2; d < 32; d++) {
1072       bool is_linear = true;
1073       for (size_t i = 1; i < values.size(); i++) {
1074         if (static_cast<size_t>(values[i] - values[0]) != (i / d)) {
1075           is_linear = false;
1076           break;
1077         }
1078       }
1079       if (is_linear) {
1080         note_solution(std::make_unique<LinearDivideArray>(values[0], d));
1081       }
1082     }
1083     // Two items can be resolved with a conditional
1084     if (values.size() == 2) {
1085       note_solution(std::make_unique<TwoElemArray>(absl::StrCat(values[0]),
1086                                                    absl::StrCat(values[1])));
1087     }
1088     if ((recurse > 0 && values.size() >= 6) ||
1089         (recurse == kMaxArrayToFunctionRecursions)) {
1090       for (size_t i = 1; i < values.size() - 1; i++) {
1091         std::vector<T> left(values.begin(), values.begin() + i);
1092         std::vector<T> right(values.begin() + i, values.end());
1093         std::unique_ptr<Array> left_array = ArrayToFunction(left, recurse - 1);
1094         std::unique_ptr<Array> right_array =
1095             ArrayToFunction(right, recurse - 1);
1096         if (left_array && right_array) {
1097           note_solution(std::make_unique<Composite2Array>(
1098               std::move(left_array), std::move(right_array), i));
1099         }
1100       }
1101     }
1102     return best;
1103   }
1104 
1105   // Helper to generate an array of values
1106   template <typename T>
GenArray(bool force_array,std::string name,std::string type,const std::vector<T> & values,bool hex,Sink * global_decls,Sink * global_values) const1107   std::unique_ptr<Array> GenArray(bool force_array, std::string name,
1108                                   std::string type,
1109                                   const std::vector<T>& values, bool hex,
1110                                   Sink* global_decls,
1111                                   Sink* global_values) const {
1112     if (values.empty()) return std::make_unique<NamedArray>("nullptr");
1113     if (!force_array) {
1114       auto fn = ArrayToFunction(values);
1115       if (fn != nullptr) return fn;
1116     }
1117     auto previous_name =
1118         ctx_->PreviousNameForArtifact(name, HashVec(type, values));
1119     if (previous_name.has_value()) {
1120       return std::make_unique<NamedArray>(absl::StrCat(*previous_name, "_"));
1121     }
1122     std::vector<std::string> elems;
1123     elems.reserve(values.size());
1124     for (const auto& elem : values) {
1125       if (hex) {
1126         if (type == "uint8_t") {
1127           elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad2)));
1128         } else if (type == "uint16_t") {
1129           elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad4)));
1130         } else {
1131           elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad8)));
1132         }
1133       } else {
1134         elems.push_back(absl::StrCat(elem));
1135       }
1136     }
1137     std::string data = absl::StrJoin(elems, ", ");
1138     global_decls->Add(absl::StrCat("static const ", type, " ", name, "_[",
1139                                    values.size(), "];"));
1140     global_values->Add(absl::StrCat("const ", type, " HuffDecoderCommon::",
1141                                     name, "_[", values.size(), "] = {"));
1142     global_values->Add(absl::StrCat("  ", data));
1143     global_values->Add("};");
1144     return std::make_unique<NamedArray>(absl::StrCat(name, "_"));
1145   }
1146 
1147   // Choose an encoding for this set of tables.
1148   // We try all available values for slice count and choose the one that gives
1149   // the smallest footprint.
Choose() const1150   std::unique_ptr<EncodeOption> Choose() const {
1151     std::unique_ptr<EncodeOption> chosen;
1152     size_t best_size = std::numeric_limits<size_t>::max();
1153     for (size_t slice_bits = 0; (1 << slice_bits) < elems_.size();
1154          slice_bits++) {
1155       auto raw = MakeTable(slice_bits);
1156       size_t raw_size = raw->Size();
1157       auto nested = raw->MakeNestedTable();
1158       size_t nested_size = nested->Size();
1159       if (raw_size < best_size) {
1160         chosen = std::move(raw);
1161         best_size = raw_size;
1162       }
1163       if (nested_size < best_size) {
1164         chosen = std::move(nested);
1165         best_size = nested_size;
1166       }
1167     }
1168     return chosen;
1169   }
1170 
1171   BuildCtx* const ctx_;
1172   std::vector<Elem> elems_;
1173   int max_consumed_bits_ = 0;
1174   int max_match_case_ = 0;
1175   const int id_;
1176 };
1177 
1178 ///////////////////////////////////////////////////////////////////////////////
1179 // FunMaker
1180 // Handles generating the code for various functions.
1181 
1182 class FunMaker {
1183  public:
FunMaker(Sink * sink)1184   explicit FunMaker(Sink* sink) : sink_(sink) {}
1185 
1186   // Generate a refill function - that ensures the incoming bitmask has enough
1187   // bits for the next step.
RefillTo(int n)1188   std::string RefillTo(int n) {
1189     if (have_refills_.count(n) == 0) {
1190       have_refills_.insert(n);
1191       auto fn = NewFun(absl::StrCat("RefillTo", n), "bool");
1192       auto s = fn->Add<Switch>("buffer_len_");
1193       for (int i = 0; i < n; i++) {
1194         auto c = s->Case(i);
1195         const int bytes_needed = (n - i + 7) / 8;
1196         const int bytes_allowed = (64 - i) / 8;
1197         c->Add(absl::StrCat("return ", ReadBytes(bytes_needed, bytes_allowed),
1198                             ";"));
1199       }
1200       fn->Add("return true;");
1201     }
1202     return absl::StrCat("RefillTo", n, "()");
1203   }
1204 
1205   // At callsite, generate a call to a new function with base name
1206   // base_name (new functions get a suffix of how many instances of base_name
1207   // there have been).
1208   // Return a sink to fill in the body of the new function.
CallNewFun(std::string base_name,Sink * callsite)1209   Sink* CallNewFun(std::string base_name, Sink* callsite) {
1210     std::string name = absl::StrCat(base_name, have_funs_[base_name]++);
1211     callsite->Add(absl::StrCat(name, "();"));
1212     return NewFun(name, "void");
1213   }
1214 
FillFromInput(int bytes_needed)1215   std::string FillFromInput(int bytes_needed) {
1216     auto fn_name = absl::StrCat("Fill", bytes_needed);
1217     if (have_fill_from_input_.count(bytes_needed) == 0) {
1218       have_fill_from_input_.insert(bytes_needed);
1219       auto fn = NewFun(fn_name, "void");
1220       std::string new_value;
1221       if (bytes_needed == 8) {
1222         new_value = "0";
1223       } else {
1224         new_value = absl::StrCat("(buffer_ << ", 8 * bytes_needed, ")");
1225       }
1226       for (int i = 0; i < bytes_needed; i++) {
1227         absl::StrAppend(&new_value, "| (static_cast<uint64_t>(begin_[", i,
1228                         "]) << ", 8 * (bytes_needed - i - 1), ")");
1229       }
1230       fn->Add(absl::StrCat("buffer_ = ", new_value, ";"));
1231       fn->Add(absl::StrCat("begin_ += ", bytes_needed, ";"));
1232       fn->Add(absl::StrCat("buffer_len_ += ", 8 * bytes_needed, ";"));
1233     }
1234     return fn_name;
1235   }
1236 
1237  private:
NewFun(std::string name,std::string returns)1238   Sink* NewFun(std::string name, std::string returns) {
1239     sink_->Add(absl::StrCat(returns, " ", name, "() {"));
1240     auto fn = sink_->Add<Indent>();
1241     sink_->Add("}");
1242     return fn;
1243   }
1244 
1245   // Bring in some number of bytes from the input stream to our current read
1246   // bits.
ReadBytes(int bytes_needed,int bytes_allowed)1247   std::string ReadBytes(int bytes_needed, int bytes_allowed) {
1248     auto fn_name =
1249         absl::StrCat("Read", bytes_needed, "to", bytes_allowed, "Bytes");
1250     if (have_reads_.count(std::make_pair(bytes_needed, bytes_allowed)) == 0) {
1251       have_reads_.insert(std::make_pair(bytes_needed, bytes_allowed));
1252       auto fn = NewFun(fn_name, "bool");
1253       auto s = fn->Add<Switch>("end_ - begin_");
1254       for (int i = 0; i <= bytes_allowed; i++) {
1255         auto c = i == bytes_allowed ? s->Case(Switch::Default{}) : s->Case(i);
1256         if (i < bytes_needed) {
1257           c->Add(absl::StrCat("return false;"));
1258         } else {
1259           c->Add(absl::StrCat(FillFromInput(i), "();"));
1260           c->Add("return true;");
1261         }
1262       }
1263     }
1264     return absl::StrCat(fn_name, "()");
1265   }
1266 
1267   std::set<int> have_refills_;
1268   std::set<std::pair<int, int>> have_reads_;
1269   std::set<int> have_fill_from_input_;
1270   std::map<std::string, int> have_funs_;
1271   Sink* sink_;
1272 };
1273 
1274 ///////////////////////////////////////////////////////////////////////////////
1275 // BuildCtx implementation
1276 
AddDone(SymSet start_syms,int num_bits,bool all_ones_so_far,Sink * out)1277 void BuildCtx::AddDone(SymSet start_syms, int num_bits, bool all_ones_so_far,
1278                        Sink* out) {
1279   out->Add("done_ = true;");
1280   if (num_bits == 1) {
1281     if (!all_ones_so_far) out->Add("ok_ = false;");
1282     return;
1283   }
1284   if (num_bits > 7) {
1285     auto consume_rest = out->Add<Switch>("end_ - begin_");
1286     for (int i = 1; i < (num_bits + 7) / 8; i++) {
1287       auto c = consume_rest->Case(i);
1288       c->Add(absl::StrCat(fun_maker_->FillFromInput(i), "();"));
1289       c->Add("break;");
1290     }
1291   }
1292   // we must have 0 < buffer_len_ < num_bits
1293   auto s = out->Add<Switch>("buffer_len_");
1294   auto c0 = s->Case("0");
1295   if (!all_ones_so_far) c0->Add("ok_ = false;");
1296   c0->Add("return;");
1297   for (int i = 1; i < num_bits; i++) {
1298     auto c = s->Case(i);
1299     SymSet maybe;
1300     for (auto sym : start_syms) {
1301       if (sym.bits.length() > i) continue;
1302       maybe.push_back(sym);
1303     }
1304     if (maybe.empty()) {
1305       if (all_ones_so_far) {
1306         c->Add("ok_ = (buffer_ & ((1<<buffer_len_)-1)) == (1<<buffer_len_)-1;");
1307       } else {
1308         c->Add("ok_ = false;");
1309       }
1310       c->Add("return;");
1311       continue;
1312     }
1313     TableBuilder table_builder(this);
1314     std::map<absl::optional<int>, int> cases;
1315     for (size_t n = 0; n < (1 << i); n++) {
1316       AddDoneCase(n, i, all_ones_so_far, maybe, {}, &table_builder, &cases);
1317     }
1318     table_builder.Build();
1319     c->Add(absl::StrCat("const auto index = buffer_ & ", (1 << i) - 1, ";"));
1320     c->Add(absl::StrCat("const auto op = ", table_builder.OpAccessor("index"),
1321                         ";"));
1322     if (table_builder.ConsumeBits() != 0) {
1323       fprintf(stderr, "consume bits = %d\n", table_builder.ConsumeBits());
1324       abort();
1325     }
1326     auto s_fin = c->Add<Switch>(
1327         absl::StrCat("op & ", (1 << table_builder.MatchBits()) - 1));
1328     for (auto& kv : cases) {
1329       if (kv.first.has_value()) {
1330         if (*kv.first == 0) continue;
1331         auto emit_ok = s_fin->Case(kv.second);
1332         for (int i = 0; i < *kv.first; i++) {
1333           emit_ok->Add(absl::StrCat(
1334               "sink_(",
1335               table_builder.EmitAccessor(
1336                   "index", absl::StrCat("(op >> ", table_builder.MatchBits(),
1337                                         ") + ", i)),
1338               ");"));
1339         }
1340         emit_ok->Add("break;");
1341       } else {
1342         auto fail = s_fin->Case(kv.second);
1343         fail->Add("ok_ = false;");
1344         fail->Add("break;");
1345       }
1346     }
1347     c->Add("return;");
1348   }
1349 }
1350 
AddDoneCase(size_t n,size_t n_bits,bool all_ones_so_far,SymSet syms,std::vector<uint8_t> emit,TableBuilder * table_builder,std::map<absl::optional<int>,int> * cases)1351 void BuildCtx::AddDoneCase(size_t n, size_t n_bits, bool all_ones_so_far,
1352                            SymSet syms, std::vector<uint8_t> emit,
1353                            TableBuilder* table_builder,
1354                            std::map<absl::optional<int>, int>* cases) {
1355   auto add_case = [cases](absl::optional<int> which) {
1356     auto it = cases->find(which);
1357     if (it == cases->end()) {
1358       it = cases->emplace(which, cases->size()).first;
1359     }
1360     return it->second;
1361   };
1362   if (all_ones_so_far && n == (1 << n_bits) - 1) {
1363     table_builder->Add(add_case(emit.size()), emit, 0);
1364     return;
1365   }
1366   for (auto sym : syms) {
1367     if ((n >> (n_bits - sym.bits.length())) == sym.bits.mask()) {
1368       emit.push_back(sym.symbol);
1369       int bits_left = n_bits - sym.bits.length();
1370       if (bits_left == 0) {
1371         table_builder->Add(add_case(emit.size()), emit, 0);
1372         return;
1373       }
1374       SymSet next_syms;
1375       for (auto sym : AllSyms()) {
1376         if (sym.bits.length() > bits_left) continue;
1377         next_syms.push_back(sym);
1378       }
1379       AddDoneCase(n & ((1 << bits_left) - 1), n_bits - sym.bits.length(), true,
1380                   std::move(next_syms), std::move(emit), table_builder, cases);
1381       return;
1382     }
1383   }
1384   table_builder->Add(add_case(absl::nullopt), {}, 0);
1385 }
1386 
AddStep(SymSet start_syms,int num_bits,bool is_top,bool refill,int depth,Sink * out)1387 void BuildCtx::AddStep(SymSet start_syms, int num_bits, bool is_top,
1388                        bool refill, int depth, Sink* out) {
1389   TableBuilder table_builder(this);
1390   if (refill) {
1391     out->Add(absl::StrCat("if (!", fun_maker_->RefillTo(num_bits), ") {"));
1392     auto ifblk = out->Add<Indent>();
1393     if (!is_top) {
1394       Sym some = start_syms[0];
1395       auto sym = grpc_chttp2_huffsyms[some.symbol];
1396       int consumed_len = (sym.length - some.bits.length());
1397       uint32_t consumed_mask = sym.bits >> some.bits.length();
1398       bool all_ones_so_far = consumed_mask == ((1 << consumed_len) - 1);
1399       AddDone(start_syms, num_bits, all_ones_so_far,
1400               fun_maker_->CallNewFun("Done", ifblk));
1401       ifblk->Add("return;");
1402     } else {
1403       AddDone(start_syms, num_bits, true,
1404               fun_maker_->CallNewFun("Done", ifblk));
1405       ifblk->Add("break;");
1406     }
1407     out->Add("}");
1408   }
1409   out->Add(absl::StrCat("const auto index = (buffer_ >> (buffer_len_ - ",
1410                         num_bits, ")) & 0x", absl::Hex((1 << num_bits) - 1),
1411                         ";"));
1412   std::map<MatchCase, int> match_cases;
1413   for (int i = 0; i < (1 << num_bits); i++) {
1414     auto actions = ActionsFor(BitQueue(i, num_bits), start_syms, is_top);
1415     auto add_case = [&match_cases](MatchCase match_case) {
1416       if (match_cases.find(match_case) == match_cases.end()) {
1417         match_cases[match_case] = match_cases.size();
1418       }
1419       return match_cases[match_case];
1420     };
1421     if (actions.emit.size() == 1 && actions.emit[0] == 256) {
1422       table_builder.Add(add_case(End{}), {}, actions.consumed);
1423     } else if (actions.consumed == 0) {
1424       table_builder.Add(add_case(Unmatched{std::move(actions.remaining)}), {},
1425                         num_bits);
1426     } else {
1427       std::vector<uint8_t> emit;
1428       for (auto sym : actions.emit) emit.push_back(sym);
1429       table_builder.Add(
1430           add_case(Matched{static_cast<int>(actions.emit.size())}),
1431           std::move(emit), actions.consumed);
1432     }
1433   }
1434   table_builder.Build();
1435   out->Add(
1436       absl::StrCat("const auto op = ", table_builder.OpAccessor("index"), ";"));
1437   out->Add(absl::StrCat("const int consumed = op & ",
1438                         (1 << table_builder.ConsumeBits()) - 1, ";"));
1439   out->Add("buffer_len_ -= consumed;");
1440   out->Add(absl::StrCat("const auto emit_ofs = op >> ",
1441                         table_builder.ConsumeBits() + table_builder.MatchBits(),
1442                         ";"));
1443   if (match_cases.size() == 1) {
1444     AddMatchBody(&table_builder, "index", "emit_ofs",
1445                  match_cases.begin()->first, refill, depth, out);
1446   } else {
1447     auto s = out->Add<Switch>(
1448         absl::StrCat("(op >> ", table_builder.ConsumeBits(), ") & ",
1449                      (1 << table_builder.MatchBits()) - 1));
1450     for (auto kv : match_cases) {
1451       auto c = s->Case(kv.second);
1452       AddMatchBody(&table_builder, "index", "emit_ofs", kv.first, refill, depth,
1453                    c);
1454       c->Add("break;");
1455     }
1456   }
1457 }
1458 
AddMatchBody(TableBuilder * table_builder,std::string index,std::string ofs,const MatchCase & match_case,bool refill,int depth,Sink * out)1459 void BuildCtx::AddMatchBody(TableBuilder* table_builder, std::string index,
1460                             std::string ofs, const MatchCase& match_case,
1461                             bool refill, int depth, Sink* out) {
1462   if (absl::holds_alternative<End>(match_case)) {
1463     out->Add("begin_ = end_;");
1464     out->Add("buffer_len_ = 0;");
1465     return;
1466   }
1467   if (auto* p = absl::get_if<Unmatched>(&match_case)) {
1468     if (refill) {
1469       int max_bits = 0;
1470       for (auto sym : p->syms) max_bits = std::max(max_bits, sym.bits.length());
1471       AddStep(p->syms,
1472               static_cast<size_t>(depth + 1) >= max_bits_for_depth_.size()
1473                   ? max_bits
1474                   : std::min(max_bits, max_bits_for_depth_[depth + 1]),
1475               false, true, depth + 1,
1476               fun_maker_->CallNewFun("DecodeStep", out));
1477     }
1478     return;
1479   }
1480   const auto& matched = absl::get<Matched>(match_case);
1481   for (int i = 0; i < matched.emits; i++) {
1482     out->Add(absl::StrCat(
1483         "sink_(",
1484         table_builder->EmitAccessor(index, absl::StrCat(ofs, " + ", i)), ");"));
1485   }
1486 }
1487 
1488 ///////////////////////////////////////////////////////////////////////////////
1489 // Driver code
1490 
// Generated header and source code
struct FileSet {
  // Accumulated text of the generated header file.
  std::string header;
  // Accumulated text of the generated source file.
  std::string source;
  // Path stem of the generated files; the source includes "<base_name>.h" and
  // the include guard is derived from it.
  const std::string base_name;
  std::vector<std::string> all_ns;

  // Takes the name by value and moves it into place, so callers passing a
  // temporary pay for no extra copy.
  explicit FileSet(std::string base_name) : base_name(std::move(base_name)) {}

  // Append the leading boilerplate to both files.
  void AddFrontMatter(int copyright_year);
  // Append one decoder build for the given per-depth bit widths.
  void AddBuild(std::vector<int> max_bits_for_depth, bool selected_version);
  // Append the trailing boilerplate to both files.
  void AddTailMatter();
};
1503 
AddFrontMatter(int copyright_year)1504 void FileSet::AddFrontMatter(int copyright_year) {
1505   std::string guard = absl::StrCat(
1506       "GRPC_",
1507       absl::AsciiStrToUpper(absl::StrReplaceAll(base_name, {{"/", "_"}})),
1508       "_H");
1509   auto hdr = std::make_unique<Sink>();
1510   auto src = std::make_unique<Sink>();
1511   hdr->Add<Prelude>("//", copyright_year);
1512   src->Add<Prelude>("//", copyright_year);
1513   hdr->Add(absl::StrCat("#ifndef ", guard));
1514   hdr->Add(absl::StrCat("#define ", guard));
1515   header += hdr->ToString();
1516   source += src->ToString();
1517 }
1518 
AddTailMatter()1519 void FileSet::AddTailMatter() {
1520   auto hdr = std::make_unique<Sink>();
1521   auto src = std::make_unique<Sink>();
1522   hdr->Add("#endif");
1523   header += hdr->ToString();
1524   source += src->ToString();
1525 }
1526 
1527 // Given max_bits_for_depth = {n1,n2,n3,...}
1528 // Build a decoder that first considers n1 bits, then n2, then n3, ...
AddBuild(std::vector<int> max_bits_for_depth,bool selected_version)1529 void FileSet::AddBuild(std::vector<int> max_bits_for_depth,
1530                        bool selected_version) {
1531   auto hdr = std::make_unique<Sink>();
1532   auto src = std::make_unique<Sink>();
1533   src->Add(absl::StrCat("#include \"", base_name, ".h\""));
1534   hdr->Add("#include <cstddef>");
1535   hdr->Add("#include <grpc/support/port_platform.h>");
1536   src->Add("#include <grpc/support/port_platform.h>");
1537   hdr->Add("#include <cstdint>");
1538   hdr->Add("namespace grpc_core {");
1539   src->Add("namespace grpc_core {");
1540   std::string ns;
1541   if (!selected_version) {
1542     ns = absl::StrCat("geometry_", absl::StrJoin(max_bits_for_depth, "_"));
1543     hdr->Add(absl::StrCat("namespace ", ns, " {"));
1544     src->Add(absl::StrCat("namespace ", ns, " {"));
1545   }
1546   hdr->Add("class HuffDecoderCommon {");
1547   hdr->Add(" protected:");
1548   auto global_fns = hdr->Add<Indent>();
1549   hdr->Add(" private:");
1550   auto global_decls = hdr->Add<Indent>();
1551   hdr->Add("};");
1552   hdr->Add(
1553       "template<typename F> class HuffDecoder : public HuffDecoderCommon {");
1554   hdr->Add(" public:");
1555   auto pub = hdr->Add<Indent>();
1556   hdr->Add(" private:");
1557   auto prv = hdr->Add<Indent>();
1558   FunMaker fun_maker(prv->Add<Sink>());
1559   hdr->Add("};");
1560   if (!ns.empty()) {
1561     hdr->Add("}  // namespace geometry");
1562   }
1563   hdr->Add("}  // namespace grpc_core");
1564   auto global_values = src->Add<Indent>();
1565   if (!ns.empty()) {
1566     src->Add("}  // namespace geometry");
1567   }
1568   src->Add("}  // namespace grpc_core");
1569   BuildCtx ctx(std::move(max_bits_for_depth), global_fns, global_decls,
1570                global_values, &fun_maker);
1571   // constructor
1572   pub->Add(
1573       "HuffDecoder(F sink, const uint8_t* begin, const uint8_t* end) : "
1574       "sink_(sink), begin_(begin), end_(end) {}");
1575   // members
1576   prv->Add("F sink_;");
1577   prv->Add("const uint8_t* begin_;");
1578   prv->Add("const uint8_t* const end_;");
1579   prv->Add("uint64_t buffer_ = 0;");
1580   prv->Add("int buffer_len_ = 0;");
1581   prv->Add("bool ok_ = true;");
1582   prv->Add("bool done_ = false;");
1583   // main fn
1584   pub->Add("bool Run() {");
1585   auto body = pub->Add<Indent>();
1586   body->Add("while (!done_) {");
1587   ctx.AddStep(AllSyms(), ctx.MaxBitsForTop(), true, true, 0,
1588               body->Add<Indent>());
1589   body->Add("}");
1590   body->Add("return ok_;");
1591   pub->Add("}");
1592   header += hdr->ToString();
1593   source += src->ToString();
1594   all_ns.push_back(std::move(ns));
1595 }
1596 
// Enumerates candidate values of max_bits_for_depth for the Build function.
//
// Each candidate is a sequence of step sizes. Every step is at least 5 bits
// (needed for http2 I think) and at most 16 bits, the first step is at least
// 7 bits, and a sequence is emitted once fewer than 5 of the 30 bits remain.
// Sequences that reach the maximum depth are kept only if they use exactly
// 30 bits.
class PermutationBuilder {
 public:
  explicit PermutationBuilder(int max_depth) : max_depth_(max_depth) {}

  // Runs the enumeration and hands the collected sequences to the caller.
  std::vector<std::vector<int>> Run() {
    Step({});
    return std::move(perms_);
  }

 private:
  // Extends `so_far` by every permissible next step, recording the sequence
  // itself once it cannot be extended any further.
  void Step(std::vector<int> so_far) {
    constexpr int kTotalBits = 30;
    constexpr int kMinStep = 5;
    // Restrict first step to 7 bits - smaller is known to generate simply
    // terrible code.
    constexpr int kMinFirstStep = 7;
    constexpr int kMaxStep = 16;

    int bits_used = 0;
    for (int step : so_far) bits_used += step;

    const size_t depth = so_far.size();
    if (depth > max_depth_) return;
    // At full depth only sequences consuming every bit are acceptable.
    if (depth == max_depth_ && bits_used != kTotalBits) return;

    const int bits_left = kTotalBits - bits_used;
    if (bits_left < kMinStep) {
      // No room for another step: this sequence is complete.
      perms_.emplace_back(std::move(so_far));
      return;
    }

    const int first_candidate = so_far.empty() ? kMinFirstStep : kMinStep;
    const int last_candidate = std::min(bits_left, kMaxStep);
    for (int step = first_candidate; step <= last_candidate; ++step) {
      std::vector<int> extended = so_far;
      extended.push_back(step);
      Step(std::move(extended));
    }
  }

  const size_t max_depth_;
  std::vector<std::vector<int>> perms_;
};
1632 
1633 // Split after c
SplitAfter(absl::string_view input,char c)1634 std::string SplitAfter(absl::string_view input, char c) {
1635   return std::vector<std::string>(absl::StrSplit(input, c)).back();
1636 }
SplitBefore(absl::string_view input,char c)1637 std::string SplitBefore(absl::string_view input, char c) {
1638   return std::vector<std::string>(absl::StrSplit(input, c)).front();
1639 }
1640 
1641 // Does what it says.
WriteFile(std::string filename,std::string content)1642 void WriteFile(std::string filename, std::string content) {
1643   auto out = grpc_core::GetEnv("GEN_OUT");
1644   if (out.has_value()) {
1645     filename = absl::StrCat(*out, "/", filename);
1646   }
1647   std::ofstream ofs(filename);
1648   ofs << content;
1649   if (ofs.bad()) {
1650     fprintf(stderr, "Failed to write %s\n", filename.c_str());
1651     abort();
1652   }
1653 }
1654 
GenMicrobenchmarks()1655 void GenMicrobenchmarks() {
1656   std::queue<std::thread> threads;
1657   // Generate all permutations of max_bits_for_depth for the Build function.
1658   // Then generate all variations of the code.
1659   static constexpr int kNumShards = 100;
1660   std::unique_ptr<FileSet> results[kNumShards];
1661   std::mutex results_mutexes[kNumShards];
1662   for (int i = 0; i < kNumShards; i++) {
1663     results[i] = std::make_unique<FileSet>(
1664         absl::StrCat("test/cpp/microbenchmarks/huffman_geometries/shard_", i));
1665     results[i]->AddFrontMatter(2024);
1666   }
1667   int r = 0;
1668   for (const auto& perm : PermutationBuilder(3).Run()) {
1669     int shard = r++ % kNumShards;
1670     threads.emplace(
1671         [perm, fileset = results[shard].get(), mu = &results_mutexes[shard]] {
1672           std::lock_guard<std::mutex> lock(*mu);
1673           fileset->AddBuild(perm, false);
1674         });
1675   }
1676   while (!threads.empty()) {
1677     threads.front().join();
1678     threads.pop();
1679   }
1680   auto index_hdr = std::make_unique<Sink>();
1681   index_hdr->Add<Prelude>("//", 2023);
1682   index_hdr->Add(
1683       "#ifndef GRPC_TEST_CPP_MICROBENCHMARKS_HUFFMAN_GEOMETRIES_INDEX_H");
1684   index_hdr->Add(
1685       "#define GRPC_TEST_CPP_MICROBENCHMARKS_HUFFMAN_GEOMETRIES_INDEX_H");
1686   auto index_includes = index_hdr->Add<Sink>();
1687   index_hdr->Add("#define DECL_HUFFMAN_VARIANTS() \\");
1688   auto index_decls = index_hdr->Add<Sink>();
1689   index_hdr->Add("  DECL_BENCHMARK(grpc_core::HuffDecoder, Selected)");
1690   index_hdr->Add(
1691       "#endif  // GRPC_TEST_CPP_MICROBENCHMARKS_HUFFMAN_GEOMETRIES_INDEX_H");
1692 
1693   for (auto& r : results) {
1694     r->AddTailMatter();
1695     index_includes->Add(absl::StrCat("#include \"", r->base_name, ".h\""));
1696     for (const auto& ns : r->all_ns) {
1697       index_decls->Add(absl::StrCat("  DECL_BENCHMARK(grpc_core::", ns,
1698                                     "::HuffDecoder, ", ns, "); \\"));
1699     }
1700     WriteFile(r->base_name + ".h", r->header);
1701     WriteFile(r->base_name + ".cc", r->source);
1702   }
1703   WriteFile("test/cpp/microbenchmarks/huffman_geometries/index.h",
1704             index_hdr->ToString());
1705 }
1706 
GenSelected()1707 void GenSelected() {
1708   FileSet selected("src/core/ext/transport/chttp2/transport/decode_huff");
1709   selected.AddFrontMatter(2023);
1710   selected.AddBuild(std::vector<int>({15, 7, 8}), true);
1711   selected.AddTailMatter();
1712   WriteFile(selected.base_name + ".h", selected.header);
1713   WriteFile(selected.base_name + ".cc", selected.source);
1714 }
1715 
main(int argc,char ** argv)1716 int main(int argc, char** argv) {
1717   if (argc < 2) {
1718     fprintf(stderr, "No generators specified\n");
1719     return 1;
1720   }
1721   std::map<std::string, std::function<void()>> generators = {
1722       {"microbenchmarks", GenMicrobenchmarks}, {"selected", GenSelected}};
1723   for (int i = 1; i < argc; i++) {
1724     auto it = generators.find(argv[i]);
1725     if (it == generators.end()) {
1726       fprintf(stderr, "Unknown generator: %s\n", argv[i]);
1727       return 1;
1728     }
1729     it->second();
1730   }
1731   return 0;
1732 }
1733