• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
// Package protowire parses and formats the raw wire encoding.
// See https://protobuf.dev/programming-guides/encoding.
//
// For marshaling and unmarshaling entire protobuf messages,
// use the "google.golang.org/protobuf/proto" package instead.
10package protowire
11
12import (
13	"io"
14	"math"
15	"math/bits"
16
17	"google.golang.org/protobuf/internal/errors"
18)
19
// Number represents the field number carried in a wire-format tag.
// Use IsValid to check whether a value is semantically valid.
type Number int32
22
23const (
24	MinValidNumber        Number = 1
25	FirstReservedNumber   Number = 19000
26	LastReservedNumber    Number = 19999
27	MaxValidNumber        Number = 1<<29 - 1
28	DefaultRecursionLimit        = 10000
29)
30
31// IsValid reports whether the field number is semantically valid.
32//
33// Note that while numbers within the reserved range are semantically invalid,
34// they are syntactically valid in the wire format.
35// Implementations may treat records with reserved field numbers as unknown.
36func (n Number) IsValid() bool {
37	return MinValidNumber <= n && n < FirstReservedNumber || LastReservedNumber < n && n <= MaxValidNumber
38}
39
// Type represents the wire type: the low three bits of a tag, which select
// how the value that follows the tag is encoded.
type Type int8
42
43const (
44	VarintType     Type = 0
45	Fixed32Type    Type = 5
46	Fixed64Type    Type = 1
47	BytesType      Type = 2
48	StartGroupType Type = 3
49	EndGroupType   Type = 4
50)
51
// Negative error codes returned in place of a length by the Consume
// functions. The first entry burns the value 0 (a valid length), so the
// codes count down from -1 in declaration order; the order must not change.
// ParseError translates a code into an error value.
const (
	_ = -iota
	errCodeTruncated
	errCodeFieldNumber
	errCodeOverflow
	errCodeReserved
	errCodeEndGroup
	errCodeRecursionDepth
)
61
62var (
63	errFieldNumber = errors.New("invalid field number")
64	errOverflow    = errors.New("variable length integer overflow")
65	errReserved    = errors.New("cannot parse reserved wire type")
66	errEndGroup    = errors.New("mismatching end group marker")
67	errParse       = errors.New("parse error")
68)
69
70// ParseError converts an error code into an error value.
71// This returns nil if n is a non-negative number.
72func ParseError(n int) error {
73	if n >= 0 {
74		return nil
75	}
76	switch n {
77	case errCodeTruncated:
78		return io.ErrUnexpectedEOF
79	case errCodeFieldNumber:
80		return errFieldNumber
81	case errCodeOverflow:
82		return errOverflow
83	case errCodeReserved:
84		return errReserved
85	case errCodeEndGroup:
86		return errEndGroup
87	default:
88		return errParse
89	}
90}
91
92// ConsumeField parses an entire field record (both tag and value) and returns
93// the field number, the wire type, and the total length.
94// This returns a negative length upon an error (see ParseError).
95//
96// The total length includes the tag header and the end group marker (if the
97// field is a group).
98func ConsumeField(b []byte) (Number, Type, int) {
99	num, typ, n := ConsumeTag(b)
100	if n < 0 {
101		return 0, 0, n // forward error code
102	}
103	m := ConsumeFieldValue(num, typ, b[n:])
104	if m < 0 {
105		return 0, 0, m // forward error code
106	}
107	return num, typ, n + m
108}
109
110// ConsumeFieldValue parses a field value and returns its length.
111// This assumes that the field Number and wire Type have already been parsed.
112// This returns a negative length upon an error (see ParseError).
113//
114// When parsing a group, the length includes the end group marker and
115// the end group is verified to match the starting field number.
116func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
117	return consumeFieldValueD(num, typ, b, DefaultRecursionLimit)
118}
119
120func consumeFieldValueD(num Number, typ Type, b []byte, depth int) (n int) {
121	switch typ {
122	case VarintType:
123		_, n = ConsumeVarint(b)
124		return n
125	case Fixed32Type:
126		_, n = ConsumeFixed32(b)
127		return n
128	case Fixed64Type:
129		_, n = ConsumeFixed64(b)
130		return n
131	case BytesType:
132		_, n = ConsumeBytes(b)
133		return n
134	case StartGroupType:
135		if depth < 0 {
136			return errCodeRecursionDepth
137		}
138		n0 := len(b)
139		for {
140			num2, typ2, n := ConsumeTag(b)
141			if n < 0 {
142				return n // forward error code
143			}
144			b = b[n:]
145			if typ2 == EndGroupType {
146				if num != num2 {
147					return errCodeEndGroup
148				}
149				return n0 - len(b)
150			}
151
152			n = consumeFieldValueD(num2, typ2, b, depth-1)
153			if n < 0 {
154				return n // forward error code
155			}
156			b = b[n:]
157		}
158	case EndGroupType:
159		return errCodeEndGroup
160	default:
161		return errCodeReserved
162	}
163}
164
165// AppendTag encodes num and typ as a varint-encoded tag and appends it to b.
166func AppendTag(b []byte, num Number, typ Type) []byte {
167	return AppendVarint(b, EncodeTag(num, typ))
168}
169
170// ConsumeTag parses b as a varint-encoded tag, reporting its length.
171// This returns a negative length upon an error (see ParseError).
172func ConsumeTag(b []byte) (Number, Type, int) {
173	v, n := ConsumeVarint(b)
174	if n < 0 {
175		return 0, 0, n // forward error code
176	}
177	num, typ := DecodeTag(v)
178	if num < MinValidNumber {
179		return 0, 0, errCodeFieldNumber
180	}
181	return num, typ, n
182}
183
184func SizeTag(num Number) int {
185	return SizeVarint(EncodeTag(num, 0)) // wire type has no effect on size
186}
187
// AppendVarint appends v to b as a varint-encoded uint64 and returns the
// extended slice.
//
// The encoding emits v in 7-bit groups, least significant group first, with
// the high bit set on every byte except the last. The switch picks the
// output length from the magnitude of v so that each case is a single
// append of between one and ten bytes.
func AppendVarint(b []byte, v uint64) []byte {
	switch {
	case v < 1<<7:
		b = append(b, byte(v))
	case v < 1<<14:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte(v>>7))
	case v < 1<<21:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte(v>>14))
	case v < 1<<28:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte(v>>21))
	case v < 1<<35:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte(v>>28))
	case v < 1<<42:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte((v>>28)&0x7f|0x80),
			byte(v>>35))
	case v < 1<<49:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte((v>>28)&0x7f|0x80),
			byte((v>>35)&0x7f|0x80),
			byte(v>>42))
	case v < 1<<56:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte((v>>28)&0x7f|0x80),
			byte((v>>35)&0x7f|0x80),
			byte((v>>42)&0x7f|0x80),
			byte(v>>49))
	case v < 1<<63:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte((v>>28)&0x7f|0x80),
			byte((v>>35)&0x7f|0x80),
			byte((v>>42)&0x7f|0x80),
			byte((v>>49)&0x7f|0x80),
			byte(v>>56))
	default:
		// Only bit 63 is left for the tenth byte, and this case is
		// reached only when that bit is set, so the byte is always 1.
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte((v>>28)&0x7f|0x80),
			byte((v>>35)&0x7f|0x80),
			byte((v>>42)&0x7f|0x80),
			byte((v>>49)&0x7f|0x80),
			byte((v>>56)&0x7f|0x80),
			1)
	}
	return b
}
268
269// ConsumeVarint parses b as a varint-encoded uint64, reporting its length.
270// This returns a negative length upon an error (see ParseError).
271func ConsumeVarint(b []byte) (v uint64, n int) {
272	var y uint64
273	if len(b) <= 0 {
274		return 0, errCodeTruncated
275	}
276	v = uint64(b[0])
277	if v < 0x80 {
278		return v, 1
279	}
280	v -= 0x80
281
282	if len(b) <= 1 {
283		return 0, errCodeTruncated
284	}
285	y = uint64(b[1])
286	v += y << 7
287	if y < 0x80 {
288		return v, 2
289	}
290	v -= 0x80 << 7
291
292	if len(b) <= 2 {
293		return 0, errCodeTruncated
294	}
295	y = uint64(b[2])
296	v += y << 14
297	if y < 0x80 {
298		return v, 3
299	}
300	v -= 0x80 << 14
301
302	if len(b) <= 3 {
303		return 0, errCodeTruncated
304	}
305	y = uint64(b[3])
306	v += y << 21
307	if y < 0x80 {
308		return v, 4
309	}
310	v -= 0x80 << 21
311
312	if len(b) <= 4 {
313		return 0, errCodeTruncated
314	}
315	y = uint64(b[4])
316	v += y << 28
317	if y < 0x80 {
318		return v, 5
319	}
320	v -= 0x80 << 28
321
322	if len(b) <= 5 {
323		return 0, errCodeTruncated
324	}
325	y = uint64(b[5])
326	v += y << 35
327	if y < 0x80 {
328		return v, 6
329	}
330	v -= 0x80 << 35
331
332	if len(b) <= 6 {
333		return 0, errCodeTruncated
334	}
335	y = uint64(b[6])
336	v += y << 42
337	if y < 0x80 {
338		return v, 7
339	}
340	v -= 0x80 << 42
341
342	if len(b) <= 7 {
343		return 0, errCodeTruncated
344	}
345	y = uint64(b[7])
346	v += y << 49
347	if y < 0x80 {
348		return v, 8
349	}
350	v -= 0x80 << 49
351
352	if len(b) <= 8 {
353		return 0, errCodeTruncated
354	}
355	y = uint64(b[8])
356	v += y << 56
357	if y < 0x80 {
358		return v, 9
359	}
360	v -= 0x80 << 56
361
362	if len(b) <= 9 {
363		return 0, errCodeTruncated
364	}
365	y = uint64(b[9])
366	v += y << 63
367	if y < 2 {
368		return v, 10
369	}
370	return 0, errCodeOverflow
371}
372
// SizeVarint returns the encoded size of a varint.
// The size is guaranteed to be within 1 and 10, inclusive.
func SizeVarint(v uint64) int {
	// This computes 1 + (bits.Len64(v)-1)/7.
	// 9/64 is a good enough approximation of 1/7
	// over the only possible inputs, bit lengths 0 through 64.
	return int(9*uint32(bits.Len64(v))+64) / 64
}
380
// AppendFixed32 appends v to b as a little-endian uint32 (four bytes, least
// significant byte first) and returns the extended slice.
func AppendFixed32(b []byte, v uint32) []byte {
	return append(b,
		byte(v>>0),
		byte(v>>8),
		byte(v>>16),
		byte(v>>24))
}
389
390// ConsumeFixed32 parses b as a little-endian uint32, reporting its length.
391// This returns a negative length upon an error (see ParseError).
392func ConsumeFixed32(b []byte) (v uint32, n int) {
393	if len(b) < 4 {
394		return 0, errCodeTruncated
395	}
396	v = uint32(b[0])<<0 | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
397	return v, 4
398}
399
// SizeFixed32 returns the encoded size of a fixed32, which is always 4.
func SizeFixed32() int {
	return 4
}
404
// AppendFixed64 appends v to b as a little-endian uint64 (eight bytes, least
// significant byte first) and returns the extended slice.
func AppendFixed64(b []byte, v uint64) []byte {
	return append(b,
		byte(v>>0),
		byte(v>>8),
		byte(v>>16),
		byte(v>>24),
		byte(v>>32),
		byte(v>>40),
		byte(v>>48),
		byte(v>>56))
}
417
418// ConsumeFixed64 parses b as a little-endian uint64, reporting its length.
419// This returns a negative length upon an error (see ParseError).
420func ConsumeFixed64(b []byte) (v uint64, n int) {
421	if len(b) < 8 {
422		return 0, errCodeTruncated
423	}
424	v = uint64(b[0])<<0 | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
425	return v, 8
426}
427
// SizeFixed64 returns the encoded size of a fixed64, which is always 8.
func SizeFixed64() int {
	return 8
}
432
433// AppendBytes appends v to b as a length-prefixed bytes value.
434func AppendBytes(b []byte, v []byte) []byte {
435	return append(AppendVarint(b, uint64(len(v))), v...)
436}
437
438// ConsumeBytes parses b as a length-prefixed bytes value, reporting its length.
439// This returns a negative length upon an error (see ParseError).
440func ConsumeBytes(b []byte) (v []byte, n int) {
441	m, n := ConsumeVarint(b)
442	if n < 0 {
443		return nil, n // forward error code
444	}
445	if m > uint64(len(b[n:])) {
446		return nil, errCodeTruncated
447	}
448	return b[n:][:m], n + int(m)
449}
450
451// SizeBytes returns the encoded size of a length-prefixed bytes value,
452// given only the length.
453func SizeBytes(n int) int {
454	return SizeVarint(uint64(n)) + n
455}
456
457// AppendString appends v to b as a length-prefixed bytes value.
458func AppendString(b []byte, v string) []byte {
459	return append(AppendVarint(b, uint64(len(v))), v...)
460}
461
462// ConsumeString parses b as a length-prefixed bytes value, reporting its length.
463// This returns a negative length upon an error (see ParseError).
464func ConsumeString(b []byte) (v string, n int) {
465	bb, n := ConsumeBytes(b)
466	return string(bb), n
467}
468
469// AppendGroup appends v to b as group value, with a trailing end group marker.
470// The value v must not contain the end marker.
471func AppendGroup(b []byte, num Number, v []byte) []byte {
472	return AppendVarint(append(b, v...), EncodeTag(num, EndGroupType))
473}
474
475// ConsumeGroup parses b as a group value until the trailing end group marker,
476// and verifies that the end marker matches the provided num. The value v
477// does not contain the end marker, while the length does contain the end marker.
478// This returns a negative length upon an error (see ParseError).
479func ConsumeGroup(num Number, b []byte) (v []byte, n int) {
480	n = ConsumeFieldValue(num, StartGroupType, b)
481	if n < 0 {
482		return nil, n // forward error code
483	}
484	b = b[:n]
485
486	// Truncate off end group marker, but need to handle denormalized varints.
487	// Assuming end marker is never 0 (which is always the case since
488	// EndGroupType is non-zero), we can truncate all trailing bytes where the
489	// lower 7 bits are all zero (implying that the varint is denormalized).
490	for len(b) > 0 && b[len(b)-1]&0x7f == 0 {
491		b = b[:len(b)-1]
492	}
493	b = b[:len(b)-SizeTag(num)]
494	return b, n
495}
496
497// SizeGroup returns the encoded size of a group, given only the length.
498func SizeGroup(num Number, n int) int {
499	return n + SizeTag(num)
500}
501
502// DecodeTag decodes the field Number and wire Type from its unified form.
503// The Number is -1 if the decoded field number overflows int32.
504// Other than overflow, this does not check for field number validity.
505func DecodeTag(x uint64) (Number, Type) {
506	// NOTE: MessageSet allows for larger field numbers than normal.
507	if x>>3 > uint64(math.MaxInt32) {
508		return -1, 0
509	}
510	return Number(x >> 3), Type(x & 7)
511}
512
513// EncodeTag encodes the field Number and wire Type into its unified form.
514func EncodeTag(num Number, typ Type) uint64 {
515	return uint64(num)<<3 | uint64(typ&7)
516}
517
// DecodeZigZag decodes a zig-zag-encoded uint64 as an int64.
//
//	Input:  {…,  5,  3,  1,  0,  2,  4,  6, …}
//	Output: {…, -3, -2, -1,  0, +1, +2, +3, …}
func DecodeZigZag(x uint64) int64 {
	// The low bit carries the sign: shifting it up to bit 63 and back down
	// arithmetically gives all ones for odd x and zero for even x, which
	// then conditionally inverts the magnitude held in x>>1.
	return int64(x>>1) ^ int64(x)<<63>>63
}
524
// EncodeZigZag encodes an int64 as a zig-zag-encoded uint64.
//
//	Input:  {…, -3, -2, -1,  0, +1, +2, +3, …}
//	Output: {…,  5,  3,  1,  0,  2,  4,  6, …}
func EncodeZigZag(x int64) uint64 {
	// x>>63 is an arithmetic shift: all ones for negative x, zero
	// otherwise. XOR-ing it into x<<1 inverts the bits for negative values.
	return uint64(x<<1) ^ uint64(x>>63)
}
531
// DecodeBool decodes a uint64 as a bool. Any non-zero value is true.
//
//	Input:  {    0,    1,    2, …}
//	Output: {false, true, true, …}
func DecodeBool(x uint64) bool {
	return x != 0
}
538
// EncodeBool encodes a bool as a uint64.
//
//	Input:  {false, true}
//	Output: {    0,    1}
func EncodeBool(x bool) uint64 {
	if x {
		return 1
	}
	return 0
}
548