1// Copyright 2014 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package hpack 6 7import ( 8 "bytes" 9 "errors" 10 "io" 11 "sync" 12) 13 14var bufPool = sync.Pool{ 15 New: func() interface{} { return new(bytes.Buffer) }, 16} 17 18// HuffmanDecode decodes the string in v and writes the expanded 19// result to w, returning the number of bytes written to w and the 20// Write call's return value. At most one Write call is made. 21func HuffmanDecode(w io.Writer, v []byte) (int, error) { 22 buf := bufPool.Get().(*bytes.Buffer) 23 buf.Reset() 24 defer bufPool.Put(buf) 25 if err := huffmanDecode(buf, 0, v); err != nil { 26 return 0, err 27 } 28 return w.Write(buf.Bytes()) 29} 30 31// HuffmanDecodeToString decodes the string in v. 32func HuffmanDecodeToString(v []byte) (string, error) { 33 buf := bufPool.Get().(*bytes.Buffer) 34 buf.Reset() 35 defer bufPool.Put(buf) 36 if err := huffmanDecode(buf, 0, v); err != nil { 37 return "", err 38 } 39 return buf.String(), nil 40} 41 42// ErrInvalidHuffman is returned for errors found decoding 43// Huffman-encoded strings. 44var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data") 45 46// huffmanDecode decodes v to buf. 47// If maxLen is greater than 0, attempts to write more to buf than 48// maxLen bytes will return ErrStringLength. 49func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error { 50 n := rootHuffmanNode 51 // cur is the bit buffer that has not been fed into n. 52 // cbits is the number of low order bits in cur that are valid. 53 // sbits is the number of bits of the symbol prefix being decoded. 54 cur, cbits, sbits := uint(0), uint8(0), uint8(0) 55 for _, b := range v { 56 cur = cur<<8 | uint(b) 57 cbits += 8 58 sbits += 8 59 for cbits >= 8 { 60 idx := byte(cur >> (cbits - 8)) 61 n = n.children[idx] 62 if n == nil { 63 return ErrInvalidHuffman 64 } 65 if n.children == nil { 66 if maxLen != 0 && buf.Len() == maxLen { 67 return ErrStringLength 68 } 69 buf.WriteByte(n.sym) 70 cbits -= n.codeLen 71 n = rootHuffmanNode 72 sbits = cbits 73 } else { 74 cbits -= 8 75 } 76 } 77 } 78 for cbits > 0 { 79 n = n.children[byte(cur<<(8-cbits))] 80 if n == nil { 81 return ErrInvalidHuffman 82 } 83 if n.children != nil || n.codeLen > cbits { 84 break 85 } 86 if maxLen != 0 && buf.Len() == maxLen { 87 return ErrStringLength 88 } 89 buf.WriteByte(n.sym) 90 cbits -= n.codeLen 91 n = rootHuffmanNode 92 sbits = cbits 93 } 94 if sbits > 7 { 95 // Either there was an incomplete symbol, or overlong padding. 96 // Both are decoding errors per RFC 7541 section 5.2. 97 return ErrInvalidHuffman 98 } 99 if mask := uint(1<<cbits - 1); cur&mask != mask { 100 // Trailing bits must be a prefix of EOS per RFC 7541 section 5.2. 101 return ErrInvalidHuffman 102 } 103 104 return nil 105} 106 107type node struct { 108 // children is non-nil for internal nodes 109 children []*node 110 111 // The following are only valid if children is nil: 112 codeLen uint8 // number of bits that led to the output of sym 113 sym byte // output symbol 114} 115 116func newInternalNode() *node { 117 return &node{children: make([]*node, 256)} 118} 119 120var rootHuffmanNode = newInternalNode() 121 122func init() { 123 if len(huffmanCodes) != 256 { 124 panic("unexpected size") 125 } 126 for i, code := range huffmanCodes { 127 addDecoderNode(byte(i), code, huffmanCodeLen[i]) 128 } 129} 130 131func addDecoderNode(sym byte, code uint32, codeLen uint8) { 132 cur := rootHuffmanNode 133 for codeLen > 8 { 134 codeLen -= 8 135 i := uint8(code >> codeLen) 136 if cur.children[i] == nil { 137 cur.children[i] = newInternalNode() 138 } 139 cur = cur.children[i] 140 } 141 shift := 8 - codeLen 142 start, end := int(uint8(code<<shift)), int(1<<shift) 143 for i := start; i < start+end; i++ { 144 cur.children[i] = &node{sym: sym, codeLen: codeLen} 145 } 146} 147 148// AppendHuffmanString appends s, as encoded in Huffman codes, to dst 149// and returns the extended buffer. 150func AppendHuffmanString(dst []byte, s string) []byte { 151 rembits := uint8(8) 152 153 for i := 0; i < len(s); i++ { 154 if rembits == 8 { 155 dst = append(dst, 0) 156 } 157 dst, rembits = appendByteToHuffmanCode(dst, rembits, s[i]) 158 } 159 160 if rembits < 8 { 161 // special EOS symbol 162 code := uint32(0x3fffffff) 163 nbits := uint8(30) 164 165 t := uint8(code >> (nbits - rembits)) 166 dst[len(dst)-1] |= t 167 } 168 169 return dst 170} 171 172// HuffmanEncodeLength returns the number of bytes required to encode 173// s in Huffman codes. The result is round up to byte boundary. 174func HuffmanEncodeLength(s string) uint64 { 175 n := uint64(0) 176 for i := 0; i < len(s); i++ { 177 n += uint64(huffmanCodeLen[s[i]]) 178 } 179 return (n + 7) / 8 180} 181 182// appendByteToHuffmanCode appends Huffman code for c to dst and 183// returns the extended buffer and the remaining bits in the last 184// element. The appending is not byte aligned and the remaining bits 185// in the last element of dst is given in rembits. 186func appendByteToHuffmanCode(dst []byte, rembits uint8, c byte) ([]byte, uint8) { 187 code := huffmanCodes[c] 188 nbits := huffmanCodeLen[c] 189 190 for { 191 if rembits > nbits { 192 t := uint8(code << (rembits - nbits)) 193 dst[len(dst)-1] |= t 194 rembits -= nbits 195 break 196 } 197 198 t := uint8(code >> (nbits - rembits)) 199 dst[len(dst)-1] |= t 200 201 nbits -= rembits 202 rembits = 8 203 204 if nbits == 0 { 205 break 206 } 207 208 dst = append(dst, 0) 209 } 210 211 return dst, rembits 212} 213