1 // Copyright 2022 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <atomic>
16 #include <cstdint>
17 #include <fstream>
18 #include <limits>
19 #include <map>
20 #include <memory>
21 #include <mutex>
22 #include <numeric>
23 #include <queue>
24 #include <set>
25 #include <string>
26 #include <thread>
27 #include <vector>
28
29 #include <openssl/sha.h>
30
31 #include "absl/memory/memory.h"
32 #include "absl/strings/ascii.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_replace.h"
36 #include "absl/strings/str_split.h"
37 #include "absl/types/optional.h"
38 #include "absl/types/variant.h"
39
40 #include "src/core/ext/transport/chttp2/transport/huffsyms.h"
41 #include "src/core/util/env.h"
42 #include "src/core/util/match.h"
43
44 ///////////////////////////////////////////////////////////////////////////////
45 // SHA256 hash handling
46 // We need strong uniqueness checks of some very long strings - so we hash
47 // them with SHA256 and compare.
48 struct Hash {
49 uint8_t bytes[SHA256_DIGEST_LENGTH];
operator ==Hash50 bool operator==(const Hash& other) const {
51 return memcmp(bytes, other.bytes, SHA256_DIGEST_LENGTH) == 0;
52 }
operator <Hash53 bool operator<(const Hash& other) const {
54 return memcmp(bytes, other.bytes, SHA256_DIGEST_LENGTH) < 0;
55 }
ToStringHash56 std::string ToString() const {
57 std::string result;
58 for (int i = 0; i < SHA256_DIGEST_LENGTH; i++) {
59 absl::StrAppend(&result, absl::Hex(bytes[i], absl::kZeroPad2));
60 }
61 return result;
62 }
63 };
64
65 // Given a vector of ints (T), return a Hash object with the sha256
66 template <typename T>
HashVec(absl::string_view type,const std::vector<T> & v)67 Hash HashVec(absl::string_view type, const std::vector<T>& v) {
68 Hash h;
69 std::string text = absl::StrCat(type, ":", absl::StrJoin(v, ","));
70 SHA256(reinterpret_cast<const uint8_t*>(text.data()), text.size(), h.bytes);
71 return h;
72 }
73
74 ///////////////////////////////////////////////////////////////////////////////
75 // BitQueue
76 // A utility that treats a sequence of bits like a queue
77 class BitQueue {
78 public:
BitQueue(unsigned mask,int len)79 BitQueue(unsigned mask, int len) : mask_(mask), len_(len) {}
BitQueue()80 BitQueue() : BitQueue(0, 0) {}
81
82 // Return the most significant bit (the front of the queue)
Front() const83 int Front() const { return (mask_ >> (len_ - 1)) & 1; }
84 // Pop one bit off the queue
Pop()85 void Pop() {
86 mask_ &= ~(1 << (len_ - 1));
87 len_--;
88 }
Empty() const89 bool Empty() const { return len_ == 0; }
length() const90 int length() const { return len_; }
mask() const91 unsigned mask() const { return mask_; }
92
93 // Text representation of the queue
ToString() const94 std::string ToString() const {
95 return absl::StrCat(absl::Hex(mask_), "/", len_);
96 }
97
98 // Comparisons so that we can use BitQueue as a key in a std::map
operator <(const BitQueue & other) const99 bool operator<(const BitQueue& other) const {
100 return std::tie(mask_, len_) < std::tie(other.mask_, other.len_);
101 }
102
103 private:
104 // The bits
105 unsigned mask_;
106 // How many bits have we
107 int len_;
108 };
109
110 ///////////////////////////////////////////////////////////////////////////////
111 // Symbol sets for the huffman tree
112
113 // A Sym is one symbol in the tree, and the bits that we need to read to decode
114 // that symbol. As we progress through decoding we remove bits from the symbol,
115 // but also condense the number of symbols we're considering.
116 struct Sym {
117 BitQueue bits;
118 int symbol;
119
operator <Sym120 bool operator<(const Sym& other) const {
121 return std::tie(bits, symbol) < std::tie(other.bits, other.symbol);
122 }
123 };
124
// A SymSet is all the symbols we're considering at some time.
// It is a plain vector, so element order is preserved and duplicates are not
// removed.
using SymSet = std::vector<Sym>;
127
128 // Debug utility to turn a SymSet into a string
SymSetString(const SymSet & syms)129 std::string SymSetString(const SymSet& syms) {
130 std::vector<std::string> parts;
131 for (const Sym& sym : syms) {
132 parts.push_back(absl::StrCat(sym.symbol, ":", sym.bits.ToString()));
133 }
134 return absl::StrJoin(parts, ",");
135 }
136
137 // Initial SymSet - all the symbols [0..256] with their bits initialized from
138 // the http2 static huffman tree.
AllSyms()139 SymSet AllSyms() {
140 SymSet syms;
141 for (int i = 0; i < GRPC_CHTTP2_NUM_HUFFSYMS; i++) {
142 Sym sym;
143 sym.bits =
144 BitQueue(grpc_chttp2_huffsyms[i].bits, grpc_chttp2_huffsyms[i].length);
145 sym.symbol = i;
146 syms.push_back(sym);
147 }
148 return syms;
149 }
150
// What would we do after reading a set of bits?
// This is the result type of ActionsFor().
struct ReadActions {
  // Emit these symbols (in the order they were decoded)
  std::vector<int> emit;
  // Number of bits that were consumed by the read
  int consumed;
  // Remaining SymSet that we need to consider on the next read action
  SymSet remaining;
};
160
161 // Given a SymSet \a pending, read through the bits in \a index and determine
162 // what actions the decoder should take.
163 // allow_multiple controls the behavior should we get to the last bit in pending
164 // and hence know which symbol to emit, but we still have bits in index.
165 // We could either start decoding the next symbol (allow_multiple == true), or
166 // we could stop (allow_multiple == false).
167 // If allow_multiple is true we tend to emit more per read op, but generate
168 // bigger tables.
ActionsFor(BitQueue index,SymSet pending,bool allow_multiple)169 ReadActions ActionsFor(BitQueue index, SymSet pending, bool allow_multiple) {
170 std::vector<int> emit;
171 int len_start = index.length();
172 int len_consume = len_start;
173
174 // We read one bit in index at a time, so whilst we have bits...
175 while (!index.Empty()) {
176 SymSet next_pending;
177 // For each symbol in the pending set
178 for (auto sym : pending) {
179 // If the first bit doesn't match, then that symbol is not part of our
180 // remaining set.
181 if (sym.bits.Front() != index.Front()) continue;
182 sym.bits.Pop();
183 next_pending.push_back(sym);
184 }
185 switch (next_pending.size()) {
186 case 0:
187 // There should be no bit patterns that are undecodable.
188 abort();
189 case 1:
190 // If we have one symbol left, we need to have decoded all of it.
191 if (!next_pending[0].bits.Empty()) abort();
192 // Emit that symbol
193 emit.push_back(next_pending[0].symbol);
194 // Track how many bits we've read.
195 len_consume = index.length() - 1;
196 // If we allow multiple, reprime pending and continue, otherwise stop.
197 if (!allow_multiple) goto done;
198 pending = AllSyms();
199 break;
200 default:
201 pending = std::move(next_pending);
202 break;
203 }
204 // Finished with this bit, continue with next
205 index.Pop();
206 }
207 done:
208 return ReadActions{std::move(emit), len_start - len_consume, pending};
209 }
210
211 ///////////////////////////////////////////////////////////////////////////////
212 // MatchCase
213 // A variant that helps us bunch together related ReadActions
214
// MatchCase alternative saying that some number of symbols must be emitted.
struct Matched {
  // How many symbols to emit.
  int emits;

  // Ordered by the number of emitted symbols.
  bool operator<(const Matched& other) const { return other.emits > emits; }
};
223
224 // Unmatched says we didn't emit anything and we need to keep decoding
225 struct Unmatched {
226 SymSet syms;
227
operator <Unmatched228 bool operator<(const Unmatched& other) const { return syms < other.syms; }
229 };
230
// MatchCase alternative signalling end of stream.
struct End {
  // End carries no state, so no End ever orders before another.
  bool operator<(End /*other*/) const { return false; }
};
235
// One decode outcome: emit symbols (Matched), keep decoding (Unmatched), or
// end of stream (End).
using MatchCase = absl::variant<Matched, Unmatched, End>;
237
238 ///////////////////////////////////////////////////////////////////////////////
239 // Text & numeric helper functions
240
// Prefix every line in \a lines with \a n indents (two spaces each) and
// return the result. Empty lines are prefixed as well.
std::vector<std::string> IndentLines(std::vector<std::string> lines,
                                     int n = 1) {
  const std::string prefix(2 * n, ' ');
  for (std::string& line : lines) {
    line.insert(0, prefix);
  }
  return lines;
}
251
// Given a snake_case_name return a PascalCaseName.
// Underscores are dropped, and the first character after each underscore (as
// well as the very first character) is upper-cased. All other characters are
// copied unchanged.
std::string ToPascalCase(const std::string& in) {
  std::string out;
  out.reserve(in.size());
  bool next_upper = true;
  for (const char c : in) {
    if (c == '_') {
      next_upper = true;
      continue;
    }
    if (next_upper) {
      // toupper() is only defined for values representable as unsigned char
      // (or EOF); passing a negative plain char is undefined behavior.
      out.push_back(static_cast<char>(toupper(static_cast<unsigned char>(c))));
      next_upper = false;
    } else {
      out.push_back(c);
    }
  }
  return out;
}
270
// Name of the unsigned integer type with \a bits bits
// (16 -> "uint16_t", 32 -> "uint32_t").
std::string Uint(int bits) {
  std::string name = "uint";
  name += std::to_string(bits);
  name += "_t";
  return name;
}
273
// Width in bits (8, 16 or 32) of the smallest uint type that can hold \a max.
int TypeBitsForMax(int max) {
  if (max > 65535) return 32;
  if (max > 255) return 16;
  return 8;
}
284
285 // Combine Uint & TypeBitsForMax to make for more concise code
TypeForMax(int max)286 std::string TypeForMax(int max) { return Uint(TypeBitsForMax(max)); }
287
// How many bits are needed to encode a value: the smallest n such that
// x < 2^n. Returns 0 for x <= 0.
int BitsForMaxValue(int x) {
  int n = 0;
  // Probe with a 64-bit one: for x >= 2^30 the loop has to test bit 31, and
  // shifting a 32-bit int that far overflows (the loop would not terminate
  // correctly).
  while (x >= (int64_t{1} << n)) n++;
  return n;
}
294
295 ///////////////////////////////////////////////////////////////////////////////
296 // Codegen framework
297 // Some helpers so we don't need to generate all the code linearly, which helps
298 // organize this a little more nicely.
299
300 // An Item is our primitive for code generation, it can generate some lines
301 // that it would like to emit - those lines are fed to a parent item that might
302 // generate more lines or mutate the ones we return, and so on until codegen
303 // is complete.
304 class Item {
305 public:
306 virtual ~Item() = default;
307 virtual std::vector<std::string> ToLines() const = 0;
ToString() const308 std::string ToString() const {
309 return absl::StrCat(absl::StrJoin(ToLines(), "\n"), "\n");
310 }
311 };
312 using ItemPtr = std::unique_ptr<Item>;
313
314 // An item that emits one line (the one given as an argument!)
315 class String : public Item {
316 public:
String(std::string s)317 explicit String(std::string s) : s_(std::move(s)) {}
ToLines() const318 std::vector<std::string> ToLines() const override { return {s_}; }
319
320 private:
321 std::string s_;
322 };
323
324 // An item that returns a fixed copyright notice and autogenerated note text.
325 class Prelude final : public Item {
326 public:
Prelude(absl::string_view comment_prefix,int copyright_year)327 explicit Prelude(absl::string_view comment_prefix, int copyright_year)
328 : comment_prefix_(comment_prefix), copyright_year_(copyright_year) {}
ToLines() const329 std::vector<std::string> ToLines() const override {
330 auto line = [this](absl::string_view text) {
331 return absl::StrCat(comment_prefix_, " ", text);
332 };
333 return {
334 line(absl::StrCat("Copyright ", copyright_year_, " gRPC authors.")),
335 line(""),
336 line("Licensed under the Apache License, Version 2.0 (the "
337 "\"License\");"),
338 line(
339 "you may not use this file except in compliance with the License."),
340 line("You may obtain a copy of the License at"),
341 line(""),
342 line(" http://www.apache.org/licenses/LICENSE-2.0"),
343 line(""),
344 line("Unless required by applicable law or agreed to in writing, "
345 "software"),
346 line("distributed under the License is distributed on an \"AS IS\" "
347 "BASIS,"),
348 line("WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or "
349 "implied."),
350 line("See the License for the specific language governing permissions "
351 "and"),
352 line("limitations under the License."),
353 "",
354 line("This file is autogenerated: see "
355 "tools/codegen/core/gen_huffman_decompressor.cc"),
356 ""};
357 }
358
359 private:
360 absl::string_view comment_prefix_;
361 int copyright_year_;
362 };
363
// Forward declaration; the Switch item is defined further down in this file.
class Switch;
365
366 // A Sink is an Item that we can add more Items to.
367 // At codegen time it calls each of its children in turn and concatenates
368 // their results together.
369 class Sink : public Item {
370 public:
ToLines() const371 std::vector<std::string> ToLines() const override {
372 std::vector<std::string> lines;
373 for (const auto& item : children_) {
374 for (const auto& line : item->ToLines()) {
375 lines.push_back(line);
376 }
377 }
378 return lines;
379 }
380
381 // Add one string to our output.
Add(std::string s)382 void Add(std::string s) {
383 children_.push_back(std::make_unique<String>(std::move(s)));
384 }
385
386 // Add an item of type T to our output (constructing it with args).
387 template <typename T, typename... Args>
Add(Args &&...args)388 T* Add(Args&&... args) {
389 auto v = std::make_unique<T>(std::forward<Args>(args)...);
390 auto* r = v.get();
391 children_.push_back(std::move(v));
392 return r;
393 }
394
395 private:
396 std::vector<ItemPtr> children_;
397 };
398
399 // A sink that indents its lines by one indent (2 spaces)
400 class Indent : public Sink {
401 public:
ToLines() const402 std::vector<std::string> ToLines() const override {
403 return IndentLines(Sink::ToLines());
404 }
405 };
406
407 // A Sink that wraps its lines in a while block
408 class While : public Sink {
409 public:
While(std::string cond)410 explicit While(std::string cond) : cond_(std::move(cond)) {}
ToLines() const411 std::vector<std::string> ToLines() const override {
412 std::vector<std::string> lines;
413 lines.push_back(absl::StrCat("while (", cond_, ") {"));
414 for (const auto& line : IndentLines(Sink::ToLines())) {
415 lines.push_back(line);
416 }
417 lines.push_back("}");
418 return lines;
419 }
420
421 private:
422 std::string cond_;
423 };
424
425 // A switch statement.
426 // Cases can be modified by calling the Case member.
427 // Identical cases are collapsed into 'case X: case Y:' type blocks.
428 class Switch : public Item {
429 public:
430 struct Default {
operator <Switch::Default431 bool operator<(const Default&) const { return false; }
operator ==Switch::Default432 bool operator==(const Default&) const { return true; }
433 };
434 using CaseLabel = absl::variant<int, std::string, Default>;
435 // \a cond is the condition to place at the head of the switch statement.
436 // eg. "switch (cond) {".
Switch(std::string cond)437 explicit Switch(std::string cond) : cond_(std::move(cond)) {}
ToLines() const438 std::vector<std::string> ToLines() const override {
439 std::map<std::string, std::vector<CaseLabel>> reverse_map;
440 for (const auto& kv : cases_) {
441 reverse_map[kv.second.ToString()].push_back(kv.first);
442 }
443 std::vector<std::pair<std::string, std::vector<CaseLabel>>>
444 sorted_reverse_map;
445 sorted_reverse_map.reserve(reverse_map.size());
446 for (auto& kv : reverse_map) {
447 sorted_reverse_map.push_back(kv);
448 }
449 for (auto& e : sorted_reverse_map) {
450 std::sort(e.second.begin(), e.second.end());
451 }
452 std::sort(sorted_reverse_map.begin(), sorted_reverse_map.end(),
453 [](const auto& a, const auto& b) { return a.second < b.second; });
454 std::vector<std::string> lines;
455 lines.push_back(absl::StrCat("switch (", cond_, ") {"));
456 for (const auto& kv : sorted_reverse_map) {
457 for (const auto& cond : kv.second) {
458 lines.push_back(absl::StrCat(
459 " ",
460 grpc_core::Match(
461 cond, [](Default) -> std::string { return "default"; },
462 [](int i) { return absl::StrCat("case ", i); },
463 [](const std::string& s) { return absl::StrCat("case ", s); }),
464 ":"));
465 }
466 lines.back().append(" {");
467 for (const auto& case_line :
468 IndentLines(cases_.find(kv.second[0])->second.ToLines(), 2)) {
469 lines.push_back(case_line);
470 }
471 lines.push_back(" }");
472 }
473 lines.push_back("}");
474 return lines;
475 }
476
Case(CaseLabel cond)477 Sink* Case(CaseLabel cond) { return &cases_[cond]; }
478
479 private:
480 std::string cond_;
481 std::map<CaseLabel, Sink> cases_;
482 };
483
484 ///////////////////////////////////////////////////////////////////////////////
485 // BuildCtx declaration
486 // Shared state for one code gen attempt
487
// Forward declarations: BuildCtx only refers to these types through pointers.
class TableBuilder;
class FunMaker;
490
491 class BuildCtx {
492 public:
BuildCtx(std::vector<int> max_bits_for_depth,Sink * global_fns,Sink * global_decls,Sink * global_values,FunMaker * fun_maker)493 BuildCtx(std::vector<int> max_bits_for_depth, Sink* global_fns,
494 Sink* global_decls, Sink* global_values, FunMaker* fun_maker)
495 : max_bits_for_depth_(std::move(max_bits_for_depth)),
496 global_fns_(global_fns),
497 global_decls_(global_decls),
498 global_values_(global_values),
499 fun_maker_(fun_maker) {}
500
501 void AddStep(SymSet start_syms, int num_bits, bool is_top, bool refill,
502 int depth, Sink* out);
503 void AddMatchBody(TableBuilder* table_builder, std::string index,
504 std::string ofs, const MatchCase& match_case, bool refill,
505 int depth, Sink* out);
506 void AddDone(SymSet start_syms, int num_bits, bool all_ones_so_far,
507 Sink* out);
508
NewId()509 int NewId() { return next_id_++; }
MaxBitsForTop() const510 int MaxBitsForTop() const { return max_bits_for_depth_[0]; }
511
PreviousNameForArtifact(std::string proposed_name,Hash hash)512 absl::optional<std::string> PreviousNameForArtifact(std::string proposed_name,
513 Hash hash) {
514 auto it = arrays_.find(hash);
515 if (it == arrays_.end()) {
516 arrays_.emplace(hash, proposed_name);
517 return absl::nullopt;
518 }
519 return it->second;
520 }
521
global_fns() const522 Sink* global_fns() const { return global_fns_; }
global_decls() const523 Sink* global_decls() const { return global_decls_; }
global_values() const524 Sink* global_values() const { return global_values_; }
525
526 private:
527 void AddDoneCase(size_t n, size_t n_bits, bool all_ones_so_far, SymSet syms,
528 std::vector<uint8_t> emit, TableBuilder* table_builder,
529 std::map<absl::optional<int>, int>* cases);
530
531 const std::vector<int> max_bits_for_depth_;
532 std::map<Hash, std::string> arrays_;
533 int next_id_ = 1;
534 Sink* const global_fns_;
535 Sink* const global_decls_;
536 Sink* const global_values_;
537 FunMaker* const fun_maker_;
538 };
539
540 ///////////////////////////////////////////////////////////////////////////////
541 // TableBuilder
542 // All our magic for building decode tables.
// We have two kinds of tables to generate:
544 // 1. op tables that translate a bit sequence to which decode case we should
545 // execute (and arguments to it), and
546 // 2. emit tables that translate an index given by the op table and tell us
547 // which symbols to emit
548 // Op table format
549 // Our opcodes contain an offset into an emit table, a number of bits consumed
550 // and an operation. The consumed bits are how many of the presented to us bits
551 // we actually took. The operation tells whether to emit some symbols (and how
552 // many) or to keep decoding.
553 // Optimization 1:
554 // op tables are essentially dense maps of bits -> opcode, and it turns out
555 // that *many* of the opcodes repeat across index bits for some of our tables
556 // so for those we split the table into two levels: first level indexes into
557 // a child table, and the child table contains the deduped opcodes.
558 // Optimization 2:
559 // Emit tables are a bit list of uint8_ts, and are indexed into by the op
560 // table (with an offset and length) - since many symbols get repeated, we try
561 // to overlay the symbols in the emit table to reduce the size.
562 // Optimization 3:
563 // We shard the table into some number of slices and use the top bits of the
564 // incoming lookup to select the shard. This tends to allow us to use smaller
565 // types to represent the table, saving on footprint.
566
567 class TableBuilder {
568 public:
TableBuilder(BuildCtx * ctx)569 explicit TableBuilder(BuildCtx* ctx) : ctx_(ctx), id_(ctx->NewId()) {}
570
571 // Append one case to the table
Add(int match_case,std::vector<uint8_t> emit,int consumed_bits)572 void Add(int match_case, std::vector<uint8_t> emit, int consumed_bits) {
573 elems_.push_back({match_case, std::move(emit), consumed_bits});
574 max_consumed_bits_ = std::max(max_consumed_bits_, consumed_bits);
575 max_match_case_ = std::max(max_match_case_, match_case);
576 }
577
578 // Build the table
Build() const579 void Build() const {
580 Choose()->Build(this, BitsForMaxValue(elems_.size() - 1));
581 }
582
583 // Generate a call to the accessor function for the emit table
EmitAccessor(std::string index,std::string offset)584 std::string EmitAccessor(std::string index, std::string offset) {
585 return absl::StrCat("GetEmit", id_, "(", index, ", ", offset, ")");
586 }
587
588 // Generate a call to the accessor function for the op table
OpAccessor(std::string index)589 std::string OpAccessor(std::string index) {
590 return absl::StrCat("GetOp", id_, "(", index, ")");
591 }
592
ConsumeBits() const593 int ConsumeBits() const { return BitsForMaxValue(max_consumed_bits_); }
MatchBits() const594 int MatchBits() const { return BitsForMaxValue(max_match_case_); }
595
596 private:
// One element in the op table, as recorded by Add().
struct Elem {
  // The match case to execute; MakeTable packs it above the consumed bits.
  int match_case;
  // Symbols for this element; MakeTable places them in the slice's emit
  // table and packs the resulting offset into the op.
  std::vector<uint8_t> emit;
  // Number of bits consumed; MakeTable packs it into the low bits of the op.
  int consumed_bits;
};
603
604 // A nested slice is one slice of a table using two level lookup
605 // - i.e. we look at an outer table to get an index into the inner table,
606 // and then fetch the result from there.
607 struct NestedSlice {
608 std::vector<uint8_t> emit;
609 std::vector<uint64_t> inner;
610 std::vector<int> outer;
611
612 // Various sizes return number of bits to be generated
613
InnerSizeTableBuilder::NestedSlice614 size_t InnerSize() const {
615 return inner.size() *
616 TypeBitsForMax(*std::max_element(inner.begin(), inner.end()));
617 }
618
OuterSizeTableBuilder::NestedSlice619 size_t OuterSize() const {
620 return outer.size() *
621 TypeBitsForMax(*std::max_element(outer.begin(), outer.end()));
622 }
623
EmitSizeTableBuilder::NestedSlice624 size_t EmitSize() const { return emit.size() * 8; }
625 };
626
627 // A slice is one part of a larger table.
628 struct Slice {
629 std::vector<uint8_t> emit;
630 std::vector<uint64_t> ops;
631
632 // Various sizes return number of bits to be generated
633
OpsSizeTableBuilder::Slice634 size_t OpsSize() const {
635 return ops.size() *
636 TypeBitsForMax(*std::max_element(ops.begin(), ops.end()));
637 }
638
EmitSizeTableBuilder::Slice639 size_t EmitSize() const { return emit.size() * 8; }
640
641 // Given a vector of symbols to emit, return the offset into the emit table
642 // that they're at (adding them to the emit table if necessary).
OffsetOfTableBuilder::Slice643 int OffsetOf(const std::vector<uint8_t>& x) {
644 if (x.empty()) return 0;
645 auto r = std::search(emit.begin(), emit.end(), x.begin(), x.end());
646 if (r == emit.end()) {
647 // look for a partial match @ end
648 for (size_t check_len = x.size() - 1; check_len > 0; check_len--) {
649 if (emit.size() < check_len) continue;
650 bool matches = true;
651 for (size_t i = 0; matches && i < check_len; i++) {
652 if (emit[emit.size() - check_len + i] != x[i]) matches = false;
653 }
654 if (matches) {
655 int offset = emit.size() - check_len;
656 for (size_t i = check_len; i < x.size(); i++) {
657 emit.push_back(x[i]);
658 }
659 for (size_t i = 0; i < x.size(); i++) {
660 if (emit[offset + i] != x[i]) {
661 abort();
662 }
663 }
664 return offset;
665 }
666 }
667 // add new
668 int result = emit.size();
669 for (auto v : x) emit.push_back(v);
670 return result;
671 }
672 return r - emit.begin();
673 }
674
675 // Convert this slice to a nested slice.
MakeNestedSliceTableBuilder::Slice676 NestedSlice MakeNestedSlice() const {
677 NestedSlice result;
678 result.emit = emit;
679 std::map<uint64_t, int> op_to_inner;
680 for (auto v : ops) {
681 auto it = op_to_inner.find(v);
682 if (it == op_to_inner.end()) {
683 it = op_to_inner.emplace(v, op_to_inner.size()).first;
684 result.inner.push_back(v);
685 }
686 result.outer.push_back(it->second);
687 }
688 return result;
689 }
690 };
691
692 // An EncodeOption is a potential way of encoding a table.
693 struct EncodeOption {
694 // Overall size (in bits) of the table encoding
695 virtual size_t Size() const = 0;
696 // Generate the code
697 virtual void Build(const TableBuilder* builder, int op_bits) const = 0;
~EncodeOptionTableBuilder::EncodeOption698 virtual ~EncodeOption() {}
699 };
700
701 // NestedTable is a table that uses two level lookup for each slice
702 struct NestedTable : public EncodeOption {
703 std::vector<NestedSlice> slices;
704 int slice_bits;
SizeTableBuilder::NestedTable705 size_t Size() const override {
706 size_t sum = 0;
707 std::vector<Hash> h_emit;
708 std::vector<Hash> h_inner;
709 std::vector<Hash> h_outer;
710 for (size_t i = 0; i < slices.size(); i++) {
711 h_emit.push_back(HashVec("uint8_t", slices[i].emit));
712 h_inner.push_back(HashVec(TypeForMax(MaxInner()), slices[i].inner));
713 h_outer.push_back(HashVec(TypeForMax(MaxOuter()), slices[i].outer));
714 }
715 std::set<Hash> seen;
716 for (size_t i = 0; i < slices.size(); i++) {
717 // Try to account for deduplication in the size calculation.
718 if (seen.count(h_emit[i]) == 0) sum += slices[i].EmitSize();
719 if (seen.count(h_outer[i]) == 0) sum += slices[i].OuterSize();
720 if (seen.count(h_inner[i]) == 0) sum += slices[i].OuterSize();
721 seen.insert(h_emit[i]);
722 seen.insert(h_outer[i]);
723 seen.insert(h_inner[i]);
724 }
725 if (slice_bits != 0) sum += 3 * 64 * slices.size();
726 return sum;
727 }
BuildTableBuilder::NestedTable728 void Build(const TableBuilder* builder, int op_bits) const override {
729 Sink* const global_fns = builder->ctx_->global_fns();
730 Sink* const global_decls = builder->ctx_->global_decls();
731 Sink* const global_values = builder->ctx_->global_values();
732 const int id = builder->id_;
733 std::vector<std::string> lines;
734 const uint64_t max_inner = MaxInner();
735 const uint64_t max_outer = MaxOuter();
736 std::vector<std::unique_ptr<Array>> emit_names;
737 std::vector<std::unique_ptr<Array>> inner_names;
738 std::vector<std::unique_ptr<Array>> outer_names;
739 for (size_t i = 0; i < slices.size(); i++) {
740 emit_names.push_back(builder->GenArray(
741 slice_bits != 0, absl::StrCat("table", id, "_", i, "_emit"),
742 "uint8_t", slices[i].emit, true, global_decls, global_values));
743 inner_names.push_back(builder->GenArray(
744 slice_bits != 0, absl::StrCat("table", id, "_", i, "_inner"),
745 TypeForMax(max_inner), slices[i].inner, true, global_decls,
746 global_values));
747 outer_names.push_back(builder->GenArray(
748 slice_bits != 0, absl::StrCat("table", id, "_", i, "_outer"),
749 TypeForMax(max_outer), slices[i].outer, false, global_decls,
750 global_values));
751 }
752 if (slice_bits == 0) {
753 global_fns->Add(absl::StrCat(
754 "static inline uint64_t GetOp", id, "(size_t i) { return ",
755 inner_names[0]->Index(outer_names[0]->Index("i")), "; }"));
756 global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
757 "(size_t, size_t emit) { return ",
758 emit_names[0]->Index("emit"), "; }"));
759 } else {
760 GenCompound(id, emit_names, "emit", "uint8_t", global_decls,
761 global_values);
762 GenCompound(id, inner_names, "inner", TypeForMax(max_inner),
763 global_decls, global_values);
764 GenCompound(id, outer_names, "outer", TypeForMax(max_outer),
765 global_decls, global_values);
766 global_fns->Add(absl::StrCat(
767 "static inline uint64_t GetOp", id, "(size_t i) { return table", id,
768 "_inner_[i >> ", op_bits - slice_bits, "][table", id,
769 "_outer_[i >> ", op_bits - slice_bits, "][i & 0x",
770 absl::Hex((1 << (op_bits - slice_bits)) - 1), "]]; }"));
771 global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
772 "(size_t i, size_t emit) { return table",
773 id, "_emit_[i >> ", op_bits - slice_bits,
774 "][emit]; }"));
775 }
776 }
MaxInnerTableBuilder::NestedTable777 uint64_t MaxInner() const {
778 if (max_inner == 0) {
779 for (size_t i = 0; i < slices.size(); i++) {
780 max_inner =
781 std::max(max_inner, *std::max_element(slices[i].inner.begin(),
782 slices[i].inner.end()));
783 }
784 }
785 return max_inner;
786 }
MaxOuterTableBuilder::NestedTable787 int MaxOuter() const {
788 if (max_outer == 0) {
789 for (size_t i = 0; i < slices.size(); i++) {
790 max_outer =
791 std::max(max_outer, *std::max_element(slices[i].outer.begin(),
792 slices[i].outer.end()));
793 }
794 }
795 return max_outer;
796 }
797 mutable uint64_t max_inner = 0;
798 mutable int max_outer = 0;
799 };
800
801 // Encoding that uses single level lookup for each slice.
802 struct Table : public EncodeOption {
803 std::vector<Slice> slices;
804 int slice_bits;
SizeTableBuilder::Table805 size_t Size() const override {
806 size_t sum = 0;
807 std::vector<Hash> h_emit;
808 std::vector<Hash> h_ops;
809 for (size_t i = 0; i < slices.size(); i++) {
810 h_emit.push_back(HashVec("uint8_t", slices[i].emit));
811 h_ops.push_back(HashVec(TypeForMax(MaxOp()), slices[i].ops));
812 }
813 std::set<Hash> seen;
814 for (size_t i = 0; i < slices.size(); i++) {
815 if (seen.count(h_emit[i]) == 0) sum += slices[i].EmitSize();
816 if (seen.count(h_ops[i]) == 0) sum += slices[i].OpsSize();
817 seen.insert(h_emit[i]);
818 seen.insert(h_ops[i]);
819 }
820 return sum + 3 * 64 * slices.size();
821 }
BuildTableBuilder::Table822 void Build(const TableBuilder* builder, int op_bits) const override {
823 Sink* const global_fns = builder->ctx_->global_fns();
824 Sink* const global_decls = builder->ctx_->global_decls();
825 Sink* const global_values = builder->ctx_->global_values();
826 uint64_t max_op = MaxOp();
827 const int id = builder->id_;
828 std::vector<std::unique_ptr<Array>> emit_names;
829 std::vector<std::unique_ptr<Array>> ops_names;
830 for (size_t i = 0; i < slices.size(); i++) {
831 emit_names.push_back(builder->GenArray(
832 slice_bits != 0, absl::StrCat("table", id, "_", i, "_emit"),
833 "uint8_t", slices[i].emit, true, global_decls, global_values));
834 ops_names.push_back(builder->GenArray(
835 slice_bits != 0, absl::StrCat("table", id, "_", i, "_ops"),
836 TypeForMax(max_op), slices[i].ops, true, global_decls,
837 global_values));
838 }
839 if (slice_bits == 0) {
840 global_fns->Add(absl::StrCat("static inline uint64_t GetOp", id,
841 "(size_t i) { return ",
842 ops_names[0]->Index("i"), "; }"));
843 global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
844 "(size_t, size_t emit) { return ",
845 emit_names[0]->Index("emit"), "; }"));
846 } else {
847 GenCompound(id, emit_names, "emit", "uint8_t", global_decls,
848 global_values);
849 GenCompound(id, ops_names, "ops", TypeForMax(max_op), global_decls,
850 global_values);
851 global_fns->Add(absl::StrCat(
852 "static inline uint64_t GetOp", id, "(size_t i) { return table", id,
853 "_ops_[i >> ", op_bits - slice_bits, "][i & 0x",
854 absl::Hex((1 << (op_bits - slice_bits)) - 1), "]; }"));
855 global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
856 "(size_t i, size_t emit) { return table",
857 id, "_emit_[i >> ", op_bits - slice_bits,
858 "][emit]; }"));
859 }
860 }
MaxOpTableBuilder::Table861 uint64_t MaxOp() const {
862 if (max_op == 0) {
863 for (size_t i = 0; i < slices.size(); i++) {
864 max_op = std::max(max_op, *std::max_element(slices[i].ops.begin(),
865 slices[i].ops.end()));
866 }
867 }
868 return max_op;
869 }
870 mutable uint64_t max_op = 0;
871 // Convert to a two-level lookup
MakeNestedTableTableBuilder::Table872 std::unique_ptr<NestedTable> MakeNestedTable() {
873 std::unique_ptr<NestedTable> result(new NestedTable);
874 result->slice_bits = slice_bits;
875 for (const auto& slice : slices) {
876 result->slices.push_back(slice.MakeNestedSlice());
877 }
878 return result;
879 }
880 };
881
882 // Given a number of slices (2**slice_bits), generate a table that uses a
883 // single level lookup for each slice based on our input.
MakeTable(size_t slice_bits) const884 std::unique_ptr<Table> MakeTable(size_t slice_bits) const {
885 std::unique_ptr<Table> table = std::make_unique<Table>();
886 int slices = 1 << slice_bits;
887 table->slices.resize(slices);
888 table->slice_bits = slice_bits;
889 const int pack_consume_bits = ConsumeBits();
890 const int pack_match_bits = MatchBits();
891 for (int i = 0; i < slices; i++) {
892 auto& slice = table->slices[i];
893 for (size_t j = 0; j < elems_.size() / slices; j++) {
894 const auto& elem = elems_[i * elems_.size() / slices + j];
895 slice.ops.push_back(elem.consumed_bits |
896 (elem.match_case << pack_consume_bits) |
897 (slice.OffsetOf(elem.emit)
898 << (pack_consume_bits + pack_match_bits)));
899 }
900 }
901 return table;
902 }
903
904 class Array {
905 public:
906 virtual ~Array() = default;
907 virtual std::string Index(absl::string_view value) = 0;
908 virtual std::string ArrayName() = 0;
909 virtual int Cost() = 0;
910 };
911
912 class NamedArray : public Array {
913 public:
NamedArray(std::string name)914 explicit NamedArray(std::string name) : name_(std::move(name)) {}
Index(absl::string_view value)915 std::string Index(absl::string_view value) override {
916 return absl::StrCat(name_, "[", value, "]");
917 }
ArrayName()918 std::string ArrayName() override { return name_; }
Cost()919 int Cost() override { abort(); }
920
921 private:
922 std::string name_;
923 };
924
925 class IdentityArray : public Array {
926 public:
Index(absl::string_view value)927 std::string Index(absl::string_view value) override {
928 return std::string(value);
929 }
ArrayName()930 std::string ArrayName() override { abort(); }
Cost()931 int Cost() override { return 0; }
932 };
933
934 class ConstantArray : public Array {
935 public:
ConstantArray(std::string value)936 explicit ConstantArray(std::string value) : value_(std::move(value)) {}
Index(absl::string_view index)937 std::string Index(absl::string_view index) override {
938 return absl::StrCat("((void)", index, ", ", value_, ")");
939 }
ArrayName()940 std::string ArrayName() override { abort(); }
Cost()941 int Cost() override { return 0; }
942
943 private:
944 std::string value_;
945 };
946
947 class OffsetArray : public Array {
948 public:
OffsetArray(int offset)949 explicit OffsetArray(int offset) : offset_(offset) {}
Index(absl::string_view value)950 std::string Index(absl::string_view value) override {
951 return absl::StrCat(value, " + ", offset_);
952 }
ArrayName()953 std::string ArrayName() override { abort(); }
Cost()954 int Cost() override { return 10; }
955
956 private:
957 int offset_;
958 };
959
960 class LinearDivideArray : public Array {
961 public:
LinearDivideArray(int offset,int divisor)962 LinearDivideArray(int offset, int divisor)
963 : offset_(offset), divisor_(divisor) {}
Index(absl::string_view value)964 std::string Index(absl::string_view value) override {
965 return absl::StrCat(value, "/", divisor_, " + ", offset_);
966 }
ArrayName()967 std::string ArrayName() override { abort(); }
Cost()968 int Cost() override { return 20 + (offset_ != 0 ? 10 : 0); }
969
970 private:
971 int offset_;
972 int divisor_;
973 };
974
975 class TwoElemArray : public Array {
976 public:
TwoElemArray(std::string value0,std::string value1)977 TwoElemArray(std::string value0, std::string value1)
978 : value0_(std::move(value0)), value1_(std::move(value1)) {}
Index(absl::string_view value)979 std::string Index(absl::string_view value) override {
980 return absl::StrCat(value, " ? ", value1_, " : ", value0_);
981 }
ArrayName()982 std::string ArrayName() override { abort(); }
Cost()983 int Cost() override { return 40; }
984
985 private:
986 std::string value0_;
987 std::string value1_;
988 };
989
990 class Composite2Array : public Array {
991 public:
Composite2Array(std::unique_ptr<Array> a,std::unique_ptr<Array> b,int split)992 Composite2Array(std::unique_ptr<Array> a, std::unique_ptr<Array> b,
993 int split)
994 : a_(std::move(a)), b_(std::move(b)), split_(split) {}
Index(absl::string_view value)995 std::string Index(absl::string_view value) override {
996 return absl::StrCat(
997 "(", value, " < ", split_, " ? (", a_->Index(value), ") : (",
998 b_->Index(absl::StrCat("(", value, "-", split_, ")")), "))");
999 }
ArrayName()1000 std::string ArrayName() override { abort(); }
Cost()1001 int Cost() override { return 40 + a_->Cost() + b_->Cost(); }
1002
1003 private:
1004 std::unique_ptr<Array> a_;
1005 std::unique_ptr<Array> b_;
1006 int split_;
1007 };
1008
1009 // Helper to generate a compound table (an array of arrays)
GenCompound(int id,const std::vector<std::unique_ptr<Array>> & arrays,std::string ext,std::string type,Sink * global_decls,Sink * global_values)1010 static void GenCompound(int id,
1011 const std::vector<std::unique_ptr<Array>>& arrays,
1012 std::string ext, std::string type, Sink* global_decls,
1013 Sink* global_values) {
1014 global_decls->Add(absl::StrCat("static const ", type, "* const table", id,
1015 "_", ext, "_[", arrays.size(), "];"));
1016 global_values->Add(absl::StrCat("const ", type,
1017 "* const HuffDecoderCommon::table", id, "_",
1018 ext, "_[", arrays.size(), "] = {"));
1019 for (const std::unique_ptr<Array>& array : arrays) {
1020 global_values->Add(absl::StrCat(" ", array->ArrayName(), ","));
1021 }
1022 global_values->Add("};");
1023 }
1024
1025 // Try to create a simple function equivalent to a mapping implied by a set of
1026 // values.
1027 static const int kMaxArrayToFunctionRecursions = 1;
1028 template <typename T>
ArrayToFunction(const std::vector<T> & values,int recurse=kMaxArrayToFunctionRecursions)1029 static std::unique_ptr<Array> ArrayToFunction(
1030 const std::vector<T>& values,
1031 int recurse = kMaxArrayToFunctionRecursions) {
1032 std::unique_ptr<Array> best = nullptr;
1033 auto note_solution = [&best](std::unique_ptr<Array> a) {
1034 if (best != nullptr && best->Cost() <= a->Cost()) return;
1035 best = std::move(a);
1036 };
1037 // constant => k,k,k,k,...
1038 bool is_constant = true;
1039 for (size_t i = 1; i < values.size(); i++) {
1040 if (values[i] != values[0]) {
1041 is_constant = false;
1042 break;
1043 }
1044 }
1045 if (is_constant) {
1046 note_solution(std::make_unique<ConstantArray>(absl::StrCat(values[0])));
1047 }
1048 // identity => 0,1,2,3,...
1049 bool is_identity = true;
1050 for (size_t i = 0; i < values.size(); i++) {
1051 if (static_cast<size_t>(values[i]) != i) {
1052 is_identity = false;
1053 break;
1054 }
1055 }
1056 if (is_identity) {
1057 note_solution(std::make_unique<IdentityArray>());
1058 }
1059 // offset => k,k+1,k+2,k+3,...
1060 bool is_offset = true;
1061 for (size_t i = 1; i < values.size(); i++) {
1062 if (static_cast<size_t>(values[i] - values[0]) != i) {
1063 is_offset = false;
1064 break;
1065 }
1066 }
1067 if (is_offset) {
1068 note_solution(std::make_unique<OffsetArray>(values[0]));
1069 }
1070 // offset => k,k,k+1,k+1,...
1071 for (size_t d = 2; d < 32; d++) {
1072 bool is_linear = true;
1073 for (size_t i = 1; i < values.size(); i++) {
1074 if (static_cast<size_t>(values[i] - values[0]) != (i / d)) {
1075 is_linear = false;
1076 break;
1077 }
1078 }
1079 if (is_linear) {
1080 note_solution(std::make_unique<LinearDivideArray>(values[0], d));
1081 }
1082 }
1083 // Two items can be resolved with a conditional
1084 if (values.size() == 2) {
1085 note_solution(std::make_unique<TwoElemArray>(absl::StrCat(values[0]),
1086 absl::StrCat(values[1])));
1087 }
1088 if ((recurse > 0 && values.size() >= 6) ||
1089 (recurse == kMaxArrayToFunctionRecursions)) {
1090 for (size_t i = 1; i < values.size() - 1; i++) {
1091 std::vector<T> left(values.begin(), values.begin() + i);
1092 std::vector<T> right(values.begin() + i, values.end());
1093 std::unique_ptr<Array> left_array = ArrayToFunction(left, recurse - 1);
1094 std::unique_ptr<Array> right_array =
1095 ArrayToFunction(right, recurse - 1);
1096 if (left_array && right_array) {
1097 note_solution(std::make_unique<Composite2Array>(
1098 std::move(left_array), std::move(right_array), i));
1099 }
1100 }
1101 }
1102 return best;
1103 }
1104
1105 // Helper to generate an array of values
1106 template <typename T>
GenArray(bool force_array,std::string name,std::string type,const std::vector<T> & values,bool hex,Sink * global_decls,Sink * global_values) const1107 std::unique_ptr<Array> GenArray(bool force_array, std::string name,
1108 std::string type,
1109 const std::vector<T>& values, bool hex,
1110 Sink* global_decls,
1111 Sink* global_values) const {
1112 if (values.empty()) return std::make_unique<NamedArray>("nullptr");
1113 if (!force_array) {
1114 auto fn = ArrayToFunction(values);
1115 if (fn != nullptr) return fn;
1116 }
1117 auto previous_name =
1118 ctx_->PreviousNameForArtifact(name, HashVec(type, values));
1119 if (previous_name.has_value()) {
1120 return std::make_unique<NamedArray>(absl::StrCat(*previous_name, "_"));
1121 }
1122 std::vector<std::string> elems;
1123 elems.reserve(values.size());
1124 for (const auto& elem : values) {
1125 if (hex) {
1126 if (type == "uint8_t") {
1127 elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad2)));
1128 } else if (type == "uint16_t") {
1129 elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad4)));
1130 } else {
1131 elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad8)));
1132 }
1133 } else {
1134 elems.push_back(absl::StrCat(elem));
1135 }
1136 }
1137 std::string data = absl::StrJoin(elems, ", ");
1138 global_decls->Add(absl::StrCat("static const ", type, " ", name, "_[",
1139 values.size(), "];"));
1140 global_values->Add(absl::StrCat("const ", type, " HuffDecoderCommon::",
1141 name, "_[", values.size(), "] = {"));
1142 global_values->Add(absl::StrCat(" ", data));
1143 global_values->Add("};");
1144 return std::make_unique<NamedArray>(absl::StrCat(name, "_"));
1145 }
1146
1147 // Choose an encoding for this set of tables.
1148 // We try all available values for slice count and choose the one that gives
1149 // the smallest footprint.
Choose() const1150 std::unique_ptr<EncodeOption> Choose() const {
1151 std::unique_ptr<EncodeOption> chosen;
1152 size_t best_size = std::numeric_limits<size_t>::max();
1153 for (size_t slice_bits = 0; (1 << slice_bits) < elems_.size();
1154 slice_bits++) {
1155 auto raw = MakeTable(slice_bits);
1156 size_t raw_size = raw->Size();
1157 auto nested = raw->MakeNestedTable();
1158 size_t nested_size = nested->Size();
1159 if (raw_size < best_size) {
1160 chosen = std::move(raw);
1161 best_size = raw_size;
1162 }
1163 if (nested_size < best_size) {
1164 chosen = std::move(nested);
1165 best_size = nested_size;
1166 }
1167 }
1168 return chosen;
1169 }
1170
1171 BuildCtx* const ctx_;
1172 std::vector<Elem> elems_;
1173 int max_consumed_bits_ = 0;
1174 int max_match_case_ = 0;
1175 const int id_;
1176 };
1177
1178 ///////////////////////////////////////////////////////////////////////////////
1179 // FunMaker
1180 // Handles generating the code for various functions.
1181
1182 class FunMaker {
1183 public:
FunMaker(Sink * sink)1184 explicit FunMaker(Sink* sink) : sink_(sink) {}
1185
1186 // Generate a refill function - that ensures the incoming bitmask has enough
1187 // bits for the next step.
RefillTo(int n)1188 std::string RefillTo(int n) {
1189 if (have_refills_.count(n) == 0) {
1190 have_refills_.insert(n);
1191 auto fn = NewFun(absl::StrCat("RefillTo", n), "bool");
1192 auto s = fn->Add<Switch>("buffer_len_");
1193 for (int i = 0; i < n; i++) {
1194 auto c = s->Case(i);
1195 const int bytes_needed = (n - i + 7) / 8;
1196 const int bytes_allowed = (64 - i) / 8;
1197 c->Add(absl::StrCat("return ", ReadBytes(bytes_needed, bytes_allowed),
1198 ";"));
1199 }
1200 fn->Add("return true;");
1201 }
1202 return absl::StrCat("RefillTo", n, "()");
1203 }
1204
1205 // At callsite, generate a call to a new function with base name
1206 // base_name (new functions get a suffix of how many instances of base_name
1207 // there have been).
1208 // Return a sink to fill in the body of the new function.
CallNewFun(std::string base_name,Sink * callsite)1209 Sink* CallNewFun(std::string base_name, Sink* callsite) {
1210 std::string name = absl::StrCat(base_name, have_funs_[base_name]++);
1211 callsite->Add(absl::StrCat(name, "();"));
1212 return NewFun(name, "void");
1213 }
1214
FillFromInput(int bytes_needed)1215 std::string FillFromInput(int bytes_needed) {
1216 auto fn_name = absl::StrCat("Fill", bytes_needed);
1217 if (have_fill_from_input_.count(bytes_needed) == 0) {
1218 have_fill_from_input_.insert(bytes_needed);
1219 auto fn = NewFun(fn_name, "void");
1220 std::string new_value;
1221 if (bytes_needed == 8) {
1222 new_value = "0";
1223 } else {
1224 new_value = absl::StrCat("(buffer_ << ", 8 * bytes_needed, ")");
1225 }
1226 for (int i = 0; i < bytes_needed; i++) {
1227 absl::StrAppend(&new_value, "| (static_cast<uint64_t>(begin_[", i,
1228 "]) << ", 8 * (bytes_needed - i - 1), ")");
1229 }
1230 fn->Add(absl::StrCat("buffer_ = ", new_value, ";"));
1231 fn->Add(absl::StrCat("begin_ += ", bytes_needed, ";"));
1232 fn->Add(absl::StrCat("buffer_len_ += ", 8 * bytes_needed, ";"));
1233 }
1234 return fn_name;
1235 }
1236
1237 private:
NewFun(std::string name,std::string returns)1238 Sink* NewFun(std::string name, std::string returns) {
1239 sink_->Add(absl::StrCat(returns, " ", name, "() {"));
1240 auto fn = sink_->Add<Indent>();
1241 sink_->Add("}");
1242 return fn;
1243 }
1244
1245 // Bring in some number of bytes from the input stream to our current read
1246 // bits.
ReadBytes(int bytes_needed,int bytes_allowed)1247 std::string ReadBytes(int bytes_needed, int bytes_allowed) {
1248 auto fn_name =
1249 absl::StrCat("Read", bytes_needed, "to", bytes_allowed, "Bytes");
1250 if (have_reads_.count(std::make_pair(bytes_needed, bytes_allowed)) == 0) {
1251 have_reads_.insert(std::make_pair(bytes_needed, bytes_allowed));
1252 auto fn = NewFun(fn_name, "bool");
1253 auto s = fn->Add<Switch>("end_ - begin_");
1254 for (int i = 0; i <= bytes_allowed; i++) {
1255 auto c = i == bytes_allowed ? s->Case(Switch::Default{}) : s->Case(i);
1256 if (i < bytes_needed) {
1257 c->Add(absl::StrCat("return false;"));
1258 } else {
1259 c->Add(absl::StrCat(FillFromInput(i), "();"));
1260 c->Add("return true;");
1261 }
1262 }
1263 }
1264 return absl::StrCat(fn_name, "()");
1265 }
1266
1267 std::set<int> have_refills_;
1268 std::set<std::pair<int, int>> have_reads_;
1269 std::set<int> have_fill_from_input_;
1270 std::map<std::string, int> have_funs_;
1271 Sink* sink_;
1272 };
1273
1274 ///////////////////////////////////////////////////////////////////////////////
1275 // BuildCtx implementation
1276
AddDone(SymSet start_syms,int num_bits,bool all_ones_so_far,Sink * out)1277 void BuildCtx::AddDone(SymSet start_syms, int num_bits, bool all_ones_so_far,
1278 Sink* out) {
1279 out->Add("done_ = true;");
1280 if (num_bits == 1) {
1281 if (!all_ones_so_far) out->Add("ok_ = false;");
1282 return;
1283 }
1284 if (num_bits > 7) {
1285 auto consume_rest = out->Add<Switch>("end_ - begin_");
1286 for (int i = 1; i < (num_bits + 7) / 8; i++) {
1287 auto c = consume_rest->Case(i);
1288 c->Add(absl::StrCat(fun_maker_->FillFromInput(i), "();"));
1289 c->Add("break;");
1290 }
1291 }
1292 // we must have 0 < buffer_len_ < num_bits
1293 auto s = out->Add<Switch>("buffer_len_");
1294 auto c0 = s->Case("0");
1295 if (!all_ones_so_far) c0->Add("ok_ = false;");
1296 c0->Add("return;");
1297 for (int i = 1; i < num_bits; i++) {
1298 auto c = s->Case(i);
1299 SymSet maybe;
1300 for (auto sym : start_syms) {
1301 if (sym.bits.length() > i) continue;
1302 maybe.push_back(sym);
1303 }
1304 if (maybe.empty()) {
1305 if (all_ones_so_far) {
1306 c->Add("ok_ = (buffer_ & ((1<<buffer_len_)-1)) == (1<<buffer_len_)-1;");
1307 } else {
1308 c->Add("ok_ = false;");
1309 }
1310 c->Add("return;");
1311 continue;
1312 }
1313 TableBuilder table_builder(this);
1314 std::map<absl::optional<int>, int> cases;
1315 for (size_t n = 0; n < (1 << i); n++) {
1316 AddDoneCase(n, i, all_ones_so_far, maybe, {}, &table_builder, &cases);
1317 }
1318 table_builder.Build();
1319 c->Add(absl::StrCat("const auto index = buffer_ & ", (1 << i) - 1, ";"));
1320 c->Add(absl::StrCat("const auto op = ", table_builder.OpAccessor("index"),
1321 ";"));
1322 if (table_builder.ConsumeBits() != 0) {
1323 fprintf(stderr, "consume bits = %d\n", table_builder.ConsumeBits());
1324 abort();
1325 }
1326 auto s_fin = c->Add<Switch>(
1327 absl::StrCat("op & ", (1 << table_builder.MatchBits()) - 1));
1328 for (auto& kv : cases) {
1329 if (kv.first.has_value()) {
1330 if (*kv.first == 0) continue;
1331 auto emit_ok = s_fin->Case(kv.second);
1332 for (int i = 0; i < *kv.first; i++) {
1333 emit_ok->Add(absl::StrCat(
1334 "sink_(",
1335 table_builder.EmitAccessor(
1336 "index", absl::StrCat("(op >> ", table_builder.MatchBits(),
1337 ") + ", i)),
1338 ");"));
1339 }
1340 emit_ok->Add("break;");
1341 } else {
1342 auto fail = s_fin->Case(kv.second);
1343 fail->Add("ok_ = false;");
1344 fail->Add("break;");
1345 }
1346 }
1347 c->Add("return;");
1348 }
1349 }
1350
AddDoneCase(size_t n,size_t n_bits,bool all_ones_so_far,SymSet syms,std::vector<uint8_t> emit,TableBuilder * table_builder,std::map<absl::optional<int>,int> * cases)1351 void BuildCtx::AddDoneCase(size_t n, size_t n_bits, bool all_ones_so_far,
1352 SymSet syms, std::vector<uint8_t> emit,
1353 TableBuilder* table_builder,
1354 std::map<absl::optional<int>, int>* cases) {
1355 auto add_case = [cases](absl::optional<int> which) {
1356 auto it = cases->find(which);
1357 if (it == cases->end()) {
1358 it = cases->emplace(which, cases->size()).first;
1359 }
1360 return it->second;
1361 };
1362 if (all_ones_so_far && n == (1 << n_bits) - 1) {
1363 table_builder->Add(add_case(emit.size()), emit, 0);
1364 return;
1365 }
1366 for (auto sym : syms) {
1367 if ((n >> (n_bits - sym.bits.length())) == sym.bits.mask()) {
1368 emit.push_back(sym.symbol);
1369 int bits_left = n_bits - sym.bits.length();
1370 if (bits_left == 0) {
1371 table_builder->Add(add_case(emit.size()), emit, 0);
1372 return;
1373 }
1374 SymSet next_syms;
1375 for (auto sym : AllSyms()) {
1376 if (sym.bits.length() > bits_left) continue;
1377 next_syms.push_back(sym);
1378 }
1379 AddDoneCase(n & ((1 << bits_left) - 1), n_bits - sym.bits.length(), true,
1380 std::move(next_syms), std::move(emit), table_builder, cases);
1381 return;
1382 }
1383 }
1384 table_builder->Add(add_case(absl::nullopt), {}, 0);
1385 }
1386
AddStep(SymSet start_syms,int num_bits,bool is_top,bool refill,int depth,Sink * out)1387 void BuildCtx::AddStep(SymSet start_syms, int num_bits, bool is_top,
1388 bool refill, int depth, Sink* out) {
1389 TableBuilder table_builder(this);
1390 if (refill) {
1391 out->Add(absl::StrCat("if (!", fun_maker_->RefillTo(num_bits), ") {"));
1392 auto ifblk = out->Add<Indent>();
1393 if (!is_top) {
1394 Sym some = start_syms[0];
1395 auto sym = grpc_chttp2_huffsyms[some.symbol];
1396 int consumed_len = (sym.length - some.bits.length());
1397 uint32_t consumed_mask = sym.bits >> some.bits.length();
1398 bool all_ones_so_far = consumed_mask == ((1 << consumed_len) - 1);
1399 AddDone(start_syms, num_bits, all_ones_so_far,
1400 fun_maker_->CallNewFun("Done", ifblk));
1401 ifblk->Add("return;");
1402 } else {
1403 AddDone(start_syms, num_bits, true,
1404 fun_maker_->CallNewFun("Done", ifblk));
1405 ifblk->Add("break;");
1406 }
1407 out->Add("}");
1408 }
1409 out->Add(absl::StrCat("const auto index = (buffer_ >> (buffer_len_ - ",
1410 num_bits, ")) & 0x", absl::Hex((1 << num_bits) - 1),
1411 ";"));
1412 std::map<MatchCase, int> match_cases;
1413 for (int i = 0; i < (1 << num_bits); i++) {
1414 auto actions = ActionsFor(BitQueue(i, num_bits), start_syms, is_top);
1415 auto add_case = [&match_cases](MatchCase match_case) {
1416 if (match_cases.find(match_case) == match_cases.end()) {
1417 match_cases[match_case] = match_cases.size();
1418 }
1419 return match_cases[match_case];
1420 };
1421 if (actions.emit.size() == 1 && actions.emit[0] == 256) {
1422 table_builder.Add(add_case(End{}), {}, actions.consumed);
1423 } else if (actions.consumed == 0) {
1424 table_builder.Add(add_case(Unmatched{std::move(actions.remaining)}), {},
1425 num_bits);
1426 } else {
1427 std::vector<uint8_t> emit;
1428 for (auto sym : actions.emit) emit.push_back(sym);
1429 table_builder.Add(
1430 add_case(Matched{static_cast<int>(actions.emit.size())}),
1431 std::move(emit), actions.consumed);
1432 }
1433 }
1434 table_builder.Build();
1435 out->Add(
1436 absl::StrCat("const auto op = ", table_builder.OpAccessor("index"), ";"));
1437 out->Add(absl::StrCat("const int consumed = op & ",
1438 (1 << table_builder.ConsumeBits()) - 1, ";"));
1439 out->Add("buffer_len_ -= consumed;");
1440 out->Add(absl::StrCat("const auto emit_ofs = op >> ",
1441 table_builder.ConsumeBits() + table_builder.MatchBits(),
1442 ";"));
1443 if (match_cases.size() == 1) {
1444 AddMatchBody(&table_builder, "index", "emit_ofs",
1445 match_cases.begin()->first, refill, depth, out);
1446 } else {
1447 auto s = out->Add<Switch>(
1448 absl::StrCat("(op >> ", table_builder.ConsumeBits(), ") & ",
1449 (1 << table_builder.MatchBits()) - 1));
1450 for (auto kv : match_cases) {
1451 auto c = s->Case(kv.second);
1452 AddMatchBody(&table_builder, "index", "emit_ofs", kv.first, refill, depth,
1453 c);
1454 c->Add("break;");
1455 }
1456 }
1457 }
1458
AddMatchBody(TableBuilder * table_builder,std::string index,std::string ofs,const MatchCase & match_case,bool refill,int depth,Sink * out)1459 void BuildCtx::AddMatchBody(TableBuilder* table_builder, std::string index,
1460 std::string ofs, const MatchCase& match_case,
1461 bool refill, int depth, Sink* out) {
1462 if (absl::holds_alternative<End>(match_case)) {
1463 out->Add("begin_ = end_;");
1464 out->Add("buffer_len_ = 0;");
1465 return;
1466 }
1467 if (auto* p = absl::get_if<Unmatched>(&match_case)) {
1468 if (refill) {
1469 int max_bits = 0;
1470 for (auto sym : p->syms) max_bits = std::max(max_bits, sym.bits.length());
1471 AddStep(p->syms,
1472 static_cast<size_t>(depth + 1) >= max_bits_for_depth_.size()
1473 ? max_bits
1474 : std::min(max_bits, max_bits_for_depth_[depth + 1]),
1475 false, true, depth + 1,
1476 fun_maker_->CallNewFun("DecodeStep", out));
1477 }
1478 return;
1479 }
1480 const auto& matched = absl::get<Matched>(match_case);
1481 for (int i = 0; i < matched.emits; i++) {
1482 out->Add(absl::StrCat(
1483 "sink_(",
1484 table_builder->EmitAccessor(index, absl::StrCat(ofs, " + ", i)), ");"));
1485 }
1486 }
1487
1488 ///////////////////////////////////////////////////////////////////////////////
1489 // Driver code
1490
// Generated header and source code for one output file pair.
struct FileSet {
  // Accumulated text of the .h and .cc file respectively.
  std::string header;
  std::string source;
  // Path of the output files without extension.
  const std::string base_name;
  // One entry per AddBuild() call: the namespace the build was placed in
  // (empty for the selected version).
  std::vector<std::string> all_ns;

  // Takes base_name by value and moves it into place, so callers passing a
  // temporary do not pay for a second copy.
  explicit FileSet(std::string base_name) : base_name(std::move(base_name)) {}
  void AddFrontMatter(int copyright_year);
  void AddBuild(std::vector<int> max_bits_for_depth, bool selected_version);
  void AddTailMatter();
};
1503
AddFrontMatter(int copyright_year)1504 void FileSet::AddFrontMatter(int copyright_year) {
1505 std::string guard = absl::StrCat(
1506 "GRPC_",
1507 absl::AsciiStrToUpper(absl::StrReplaceAll(base_name, {{"/", "_"}})),
1508 "_H");
1509 auto hdr = std::make_unique<Sink>();
1510 auto src = std::make_unique<Sink>();
1511 hdr->Add<Prelude>("//", copyright_year);
1512 src->Add<Prelude>("//", copyright_year);
1513 hdr->Add(absl::StrCat("#ifndef ", guard));
1514 hdr->Add(absl::StrCat("#define ", guard));
1515 header += hdr->ToString();
1516 source += src->ToString();
1517 }
1518
AddTailMatter()1519 void FileSet::AddTailMatter() {
1520 auto hdr = std::make_unique<Sink>();
1521 auto src = std::make_unique<Sink>();
1522 hdr->Add("#endif");
1523 header += hdr->ToString();
1524 source += src->ToString();
1525 }
1526
1527 // Given max_bits_for_depth = {n1,n2,n3,...}
1528 // Build a decoder that first considers n1 bits, then n2, then n3, ...
AddBuild(std::vector<int> max_bits_for_depth,bool selected_version)1529 void FileSet::AddBuild(std::vector<int> max_bits_for_depth,
1530 bool selected_version) {
1531 auto hdr = std::make_unique<Sink>();
1532 auto src = std::make_unique<Sink>();
1533 src->Add(absl::StrCat("#include \"", base_name, ".h\""));
1534 hdr->Add("#include <cstddef>");
1535 hdr->Add("#include <grpc/support/port_platform.h>");
1536 src->Add("#include <grpc/support/port_platform.h>");
1537 hdr->Add("#include <cstdint>");
1538 hdr->Add("namespace grpc_core {");
1539 src->Add("namespace grpc_core {");
1540 std::string ns;
1541 if (!selected_version) {
1542 ns = absl::StrCat("geometry_", absl::StrJoin(max_bits_for_depth, "_"));
1543 hdr->Add(absl::StrCat("namespace ", ns, " {"));
1544 src->Add(absl::StrCat("namespace ", ns, " {"));
1545 }
1546 hdr->Add("class HuffDecoderCommon {");
1547 hdr->Add(" protected:");
1548 auto global_fns = hdr->Add<Indent>();
1549 hdr->Add(" private:");
1550 auto global_decls = hdr->Add<Indent>();
1551 hdr->Add("};");
1552 hdr->Add(
1553 "template<typename F> class HuffDecoder : public HuffDecoderCommon {");
1554 hdr->Add(" public:");
1555 auto pub = hdr->Add<Indent>();
1556 hdr->Add(" private:");
1557 auto prv = hdr->Add<Indent>();
1558 FunMaker fun_maker(prv->Add<Sink>());
1559 hdr->Add("};");
1560 if (!ns.empty()) {
1561 hdr->Add("} // namespace geometry");
1562 }
1563 hdr->Add("} // namespace grpc_core");
1564 auto global_values = src->Add<Indent>();
1565 if (!ns.empty()) {
1566 src->Add("} // namespace geometry");
1567 }
1568 src->Add("} // namespace grpc_core");
1569 BuildCtx ctx(std::move(max_bits_for_depth), global_fns, global_decls,
1570 global_values, &fun_maker);
1571 // constructor
1572 pub->Add(
1573 "HuffDecoder(F sink, const uint8_t* begin, const uint8_t* end) : "
1574 "sink_(sink), begin_(begin), end_(end) {}");
1575 // members
1576 prv->Add("F sink_;");
1577 prv->Add("const uint8_t* begin_;");
1578 prv->Add("const uint8_t* const end_;");
1579 prv->Add("uint64_t buffer_ = 0;");
1580 prv->Add("int buffer_len_ = 0;");
1581 prv->Add("bool ok_ = true;");
1582 prv->Add("bool done_ = false;");
1583 // main fn
1584 pub->Add("bool Run() {");
1585 auto body = pub->Add<Indent>();
1586 body->Add("while (!done_) {");
1587 ctx.AddStep(AllSyms(), ctx.MaxBitsForTop(), true, true, 0,
1588 body->Add<Indent>());
1589 body->Add("}");
1590 body->Add("return ok_;");
1591 pub->Add("}");
1592 header += hdr->ToString();
1593 source += src->ToString();
1594 all_ns.push_back(std::move(ns));
1595 }
1596
// Generate all permutations of max_bits_for_depth for the Build function.
//
// A permutation is a list of step sizes. The first step is 7..16 bits, later
// steps 5..16 bits, and the running total never exceeds 30. A list is
// reported once fewer than 5 bits remain below 30; lists that reach the
// maximum depth are only reported if they total exactly 30.
class PermutationBuilder {
 public:
  explicit PermutationBuilder(int max_depth) : max_depth_(max_depth) {}

  // Enumerate the permutations and hand them over to the caller.
  std::vector<std::vector<int>> Run() {
    Step({});
    return std::move(perms_);
  }

 private:
  // Extend the partial permutation `so_far` in every permitted way.
  void Step(std::vector<int> so_far) {
    const size_t depth = so_far.size();
    const int total = std::accumulate(so_far.begin(), so_far.end(), 0);
    if (depth > max_depth_) return;
    if (depth == max_depth_ && total != 30) return;
    // No room for another step of at least 5 bits: this one is complete.
    if (total > 25) {
      perms_.emplace_back(std::move(so_far));
      return;
    }
    // Restrict first step to 7 bits - smaller is known to generate simply
    // terrible code.
    const int smallest = depth == 0 ? 7 : 5;
    const int largest = std::min(30 - total, 16);
    for (int step = smallest; step <= largest; ++step) {
      std::vector<int> extended = so_far;
      extended.push_back(step);
      Step(std::move(extended));
    }
  }

  const size_t max_depth_;
  std::vector<std::vector<int>> perms_;
};
1632
1633 // Split after c
SplitAfter(absl::string_view input,char c)1634 std::string SplitAfter(absl::string_view input, char c) {
1635 return std::vector<std::string>(absl::StrSplit(input, c)).back();
1636 }
SplitBefore(absl::string_view input,char c)1637 std::string SplitBefore(absl::string_view input, char c) {
1638 return std::vector<std::string>(absl::StrSplit(input, c)).front();
1639 }
1640
1641 // Does what it says.
WriteFile(std::string filename,std::string content)1642 void WriteFile(std::string filename, std::string content) {
1643 auto out = grpc_core::GetEnv("GEN_OUT");
1644 if (out.has_value()) {
1645 filename = absl::StrCat(*out, "/", filename);
1646 }
1647 std::ofstream ofs(filename);
1648 ofs << content;
1649 if (ofs.bad()) {
1650 fprintf(stderr, "Failed to write %s\n", filename.c_str());
1651 abort();
1652 }
1653 }
1654
GenMicrobenchmarks()1655 void GenMicrobenchmarks() {
1656 std::queue<std::thread> threads;
1657 // Generate all permutations of max_bits_for_depth for the Build function.
1658 // Then generate all variations of the code.
1659 static constexpr int kNumShards = 100;
1660 std::unique_ptr<FileSet> results[kNumShards];
1661 std::mutex results_mutexes[kNumShards];
1662 for (int i = 0; i < kNumShards; i++) {
1663 results[i] = std::make_unique<FileSet>(
1664 absl::StrCat("test/cpp/microbenchmarks/huffman_geometries/shard_", i));
1665 results[i]->AddFrontMatter(2024);
1666 }
1667 int r = 0;
1668 for (const auto& perm : PermutationBuilder(3).Run()) {
1669 int shard = r++ % kNumShards;
1670 threads.emplace(
1671 [perm, fileset = results[shard].get(), mu = &results_mutexes[shard]] {
1672 std::lock_guard<std::mutex> lock(*mu);
1673 fileset->AddBuild(perm, false);
1674 });
1675 }
1676 while (!threads.empty()) {
1677 threads.front().join();
1678 threads.pop();
1679 }
1680 auto index_hdr = std::make_unique<Sink>();
1681 index_hdr->Add<Prelude>("//", 2023);
1682 index_hdr->Add(
1683 "#ifndef GRPC_TEST_CPP_MICROBENCHMARKS_HUFFMAN_GEOMETRIES_INDEX_H");
1684 index_hdr->Add(
1685 "#define GRPC_TEST_CPP_MICROBENCHMARKS_HUFFMAN_GEOMETRIES_INDEX_H");
1686 auto index_includes = index_hdr->Add<Sink>();
1687 index_hdr->Add("#define DECL_HUFFMAN_VARIANTS() \\");
1688 auto index_decls = index_hdr->Add<Sink>();
1689 index_hdr->Add(" DECL_BENCHMARK(grpc_core::HuffDecoder, Selected)");
1690 index_hdr->Add(
1691 "#endif // GRPC_TEST_CPP_MICROBENCHMARKS_HUFFMAN_GEOMETRIES_INDEX_H");
1692
1693 for (auto& r : results) {
1694 r->AddTailMatter();
1695 index_includes->Add(absl::StrCat("#include \"", r->base_name, ".h\""));
1696 for (const auto& ns : r->all_ns) {
1697 index_decls->Add(absl::StrCat(" DECL_BENCHMARK(grpc_core::", ns,
1698 "::HuffDecoder, ", ns, "); \\"));
1699 }
1700 WriteFile(r->base_name + ".h", r->header);
1701 WriteFile(r->base_name + ".cc", r->source);
1702 }
1703 WriteFile("test/cpp/microbenchmarks/huffman_geometries/index.h",
1704 index_hdr->ToString());
1705 }
1706
GenSelected()1707 void GenSelected() {
1708 FileSet selected("src/core/ext/transport/chttp2/transport/decode_huff");
1709 selected.AddFrontMatter(2023);
1710 selected.AddBuild(std::vector<int>({15, 7, 8}), true);
1711 selected.AddTailMatter();
1712 WriteFile(selected.base_name + ".h", selected.header);
1713 WriteFile(selected.base_name + ".cc", selected.source);
1714 }
1715
main(int argc,char ** argv)1716 int main(int argc, char** argv) {
1717 if (argc < 2) {
1718 fprintf(stderr, "No generators specified\n");
1719 return 1;
1720 }
1721 std::map<std::string, std::function<void()>> generators = {
1722 {"microbenchmarks", GenMicrobenchmarks}, {"selected", GenSelected}};
1723 for (int i = 1; i < argc; i++) {
1724 auto it = generators.find(argv[i]);
1725 if (it == generators.end()) {
1726 fprintf(stderr, "Unknown generator: %s\n", argv[i]);
1727 return 1;
1728 }
1729 it->second();
1730 }
1731 return 0;
1732 }
1733