• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package protowire parses and formats the raw wire encoding.
6// See https://protobuf.dev/programming-guides/encoding.
7//
8// For marshaling and unmarshaling entire protobuf messages,
9// use the "google.golang.org/protobuf/proto" package instead.
10package protowire
11
12import (
13	"io"
14	"math"
15	"math/bits"
16
17	"google.golang.org/protobuf/internal/errors"
18)
19
// Number is a protobuf field number as it appears in a wire-format tag.
type Number int32

// Field number bounds, plus the default nesting budget used when
// skipping over groups.
const (
	// MinValidNumber is the smallest field number accepted by IsValid.
	MinValidNumber Number = 1
	// FirstReservedNumber and LastReservedNumber delimit the block of field
	// numbers that the protobuf language reserves for implementation use.
	// IsValid does not treat this block specially.
	FirstReservedNumber Number = 19000
	LastReservedNumber  Number = 19999
	// MaxValidNumber is the largest field number accepted by IsValid.
	MaxValidNumber Number = 1<<29 - 1
	// DefaultRecursionLimit is the depth budget that ConsumeFieldValue
	// starts with when descending into nested groups.
	DefaultRecursionLimit = 10000
)

// IsValid reports whether n lies in the inclusive range
// [MinValidNumber, MaxValidNumber].
func (n Number) IsValid() bool {
	if n < MinValidNumber {
		return false
	}
	return n <= MaxValidNumber
}
35
// Type is the wire type carried in the low three bits of a field tag.
type Type int8

// Wire types, listed in order of their encoded value.
const (
	VarintType     Type = 0 // variable-length integer
	Fixed64Type    Type = 1 // 8-byte little-endian value
	BytesType      Type = 2 // length-prefixed byte sequence
	StartGroupType Type = 3 // opens a group
	EndGroupType   Type = 4 // closes a group
	Fixed32Type    Type = 5 // 4-byte little-endian value
)
47
// Negative codes that the Consume functions return in place of a length.
// ParseError translates them into error values. The codes count down
// from -1 in declaration order.
const (
	errCodeTruncated = -(iota + 1)
	errCodeFieldNumber
	errCodeOverflow
	errCodeReserved
	errCodeEndGroup
	errCodeRecursionDepth
)
57
// Error values that ParseError returns for the corresponding error codes.
// errParse is the fallback for any negative code without a dedicated error.
var (
	errFieldNumber = errors.New("invalid field number")
	errOverflow    = errors.New("variable length integer overflow")
	errReserved    = errors.New("cannot parse reserved wire type")
	errEndGroup    = errors.New("mismatching end group marker")
	errParse       = errors.New("parse error")
)
65
66// ParseError converts an error code into an error value.
67// This returns nil if n is a non-negative number.
68func ParseError(n int) error {
69	if n >= 0 {
70		return nil
71	}
72	switch n {
73	case errCodeTruncated:
74		return io.ErrUnexpectedEOF
75	case errCodeFieldNumber:
76		return errFieldNumber
77	case errCodeOverflow:
78		return errOverflow
79	case errCodeReserved:
80		return errReserved
81	case errCodeEndGroup:
82		return errEndGroup
83	default:
84		return errParse
85	}
86}
87
88// ConsumeField parses an entire field record (both tag and value) and returns
89// the field number, the wire type, and the total length.
90// This returns a negative length upon an error (see ParseError).
91//
92// The total length includes the tag header and the end group marker (if the
93// field is a group).
94func ConsumeField(b []byte) (Number, Type, int) {
95	num, typ, n := ConsumeTag(b)
96	if n < 0 {
97		return 0, 0, n // forward error code
98	}
99	m := ConsumeFieldValue(num, typ, b[n:])
100	if m < 0 {
101		return 0, 0, m // forward error code
102	}
103	return num, typ, n + m
104}
105
106// ConsumeFieldValue parses a field value and returns its length.
107// This assumes that the field Number and wire Type have already been parsed.
108// This returns a negative length upon an error (see ParseError).
109//
110// When parsing a group, the length includes the end group marker and
111// the end group is verified to match the starting field number.
112func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
113	return consumeFieldValueD(num, typ, b, DefaultRecursionLimit)
114}
115
116func consumeFieldValueD(num Number, typ Type, b []byte, depth int) (n int) {
117	switch typ {
118	case VarintType:
119		_, n = ConsumeVarint(b)
120		return n
121	case Fixed32Type:
122		_, n = ConsumeFixed32(b)
123		return n
124	case Fixed64Type:
125		_, n = ConsumeFixed64(b)
126		return n
127	case BytesType:
128		_, n = ConsumeBytes(b)
129		return n
130	case StartGroupType:
131		if depth < 0 {
132			return errCodeRecursionDepth
133		}
134		n0 := len(b)
135		for {
136			num2, typ2, n := ConsumeTag(b)
137			if n < 0 {
138				return n // forward error code
139			}
140			b = b[n:]
141			if typ2 == EndGroupType {
142				if num != num2 {
143					return errCodeEndGroup
144				}
145				return n0 - len(b)
146			}
147
148			n = consumeFieldValueD(num2, typ2, b, depth-1)
149			if n < 0 {
150				return n // forward error code
151			}
152			b = b[n:]
153		}
154	case EndGroupType:
155		return errCodeEndGroup
156	default:
157		return errCodeReserved
158	}
159}
160
161// AppendTag encodes num and typ as a varint-encoded tag and appends it to b.
162func AppendTag(b []byte, num Number, typ Type) []byte {
163	return AppendVarint(b, EncodeTag(num, typ))
164}
165
166// ConsumeTag parses b as a varint-encoded tag, reporting its length.
167// This returns a negative length upon an error (see ParseError).
168func ConsumeTag(b []byte) (Number, Type, int) {
169	v, n := ConsumeVarint(b)
170	if n < 0 {
171		return 0, 0, n // forward error code
172	}
173	num, typ := DecodeTag(v)
174	if num < MinValidNumber {
175		return 0, 0, errCodeFieldNumber
176	}
177	return num, typ, n
178}
179
180func SizeTag(num Number) int {
181	return SizeVarint(EncodeTag(num, 0)) // wire type has no effect on size
182}
183
// AppendVarint appends the varint encoding of v to b and returns the
// extended slice. The encoding is between 1 and 10 bytes long.
func AppendVarint(b []byte, v uint64) []byte {
	// Encode into a fixed scratch array so b grows with a single append.
	var enc [10]byte
	n := 0
	for ; v >= 0x80; v >>= 7 {
		// Low seven bits of payload, with the continuation bit set.
		enc[n] = byte(v) | 0x80
		n++
	}
	enc[n] = byte(v) // final byte: continuation bit clear
	return append(b, enc[:n+1]...)
}
264
265// ConsumeVarint parses b as a varint-encoded uint64, reporting its length.
266// This returns a negative length upon an error (see ParseError).
267func ConsumeVarint(b []byte) (v uint64, n int) {
268	var y uint64
269	if len(b) <= 0 {
270		return 0, errCodeTruncated
271	}
272	v = uint64(b[0])
273	if v < 0x80 {
274		return v, 1
275	}
276	v -= 0x80
277
278	if len(b) <= 1 {
279		return 0, errCodeTruncated
280	}
281	y = uint64(b[1])
282	v += y << 7
283	if y < 0x80 {
284		return v, 2
285	}
286	v -= 0x80 << 7
287
288	if len(b) <= 2 {
289		return 0, errCodeTruncated
290	}
291	y = uint64(b[2])
292	v += y << 14
293	if y < 0x80 {
294		return v, 3
295	}
296	v -= 0x80 << 14
297
298	if len(b) <= 3 {
299		return 0, errCodeTruncated
300	}
301	y = uint64(b[3])
302	v += y << 21
303	if y < 0x80 {
304		return v, 4
305	}
306	v -= 0x80 << 21
307
308	if len(b) <= 4 {
309		return 0, errCodeTruncated
310	}
311	y = uint64(b[4])
312	v += y << 28
313	if y < 0x80 {
314		return v, 5
315	}
316	v -= 0x80 << 28
317
318	if len(b) <= 5 {
319		return 0, errCodeTruncated
320	}
321	y = uint64(b[5])
322	v += y << 35
323	if y < 0x80 {
324		return v, 6
325	}
326	v -= 0x80 << 35
327
328	if len(b) <= 6 {
329		return 0, errCodeTruncated
330	}
331	y = uint64(b[6])
332	v += y << 42
333	if y < 0x80 {
334		return v, 7
335	}
336	v -= 0x80 << 42
337
338	if len(b) <= 7 {
339		return 0, errCodeTruncated
340	}
341	y = uint64(b[7])
342	v += y << 49
343	if y < 0x80 {
344		return v, 8
345	}
346	v -= 0x80 << 49
347
348	if len(b) <= 8 {
349		return 0, errCodeTruncated
350	}
351	y = uint64(b[8])
352	v += y << 56
353	if y < 0x80 {
354		return v, 9
355	}
356	v -= 0x80 << 56
357
358	if len(b) <= 9 {
359		return 0, errCodeTruncated
360	}
361	y = uint64(b[9])
362	v += y << 63
363	if y < 2 {
364		return v, 10
365	}
366	return 0, errCodeOverflow
367}
368
// SizeVarint returns the number of bytes in the varint encoding of v.
// The result is always between 1 and 10, inclusive.
func SizeVarint(v uint64) int {
	// The exact formula is 1 + (bits.Len64(v)-1)/7. Over the possible bit
	// lengths 0..64, multiplying by 9/64 gives the same result as dividing
	// by 7 and avoids the division.
	bitLen := uint32(bits.Len64(v))
	return int(9*bitLen+64) / 64
}
376
// AppendFixed32 appends the four little-endian bytes of v to b and returns
// the extended slice.
func AppendFixed32(b []byte, v uint32) []byte {
	le := [4]byte{
		byte(v),
		byte(v >> 8),
		byte(v >> 16),
		byte(v >> 24),
	}
	return append(b, le[:]...)
}
385
386// ConsumeFixed32 parses b as a little-endian uint32, reporting its length.
387// This returns a negative length upon an error (see ParseError).
388func ConsumeFixed32(b []byte) (v uint32, n int) {
389	if len(b) < 4 {
390		return 0, errCodeTruncated
391	}
392	v = uint32(b[0])<<0 | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
393	return v, 4
394}
395
// SizeFixed32 returns the encoded size of a fixed32 value, which is
// always 4 bytes.
func SizeFixed32() int {
	const size = 4
	return size
}
400
// AppendFixed64 appends the eight little-endian bytes of v to b and returns
// the extended slice.
func AppendFixed64(b []byte, v uint64) []byte {
	var le [8]byte
	for i := range le {
		le[i] = byte(v >> uint(8*i))
	}
	return append(b, le[:]...)
}
413
414// ConsumeFixed64 parses b as a little-endian uint64, reporting its length.
415// This returns a negative length upon an error (see ParseError).
416func ConsumeFixed64(b []byte) (v uint64, n int) {
417	if len(b) < 8 {
418		return 0, errCodeTruncated
419	}
420	v = uint64(b[0])<<0 | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
421	return v, 8
422}
423
// SizeFixed64 returns the encoded size of a fixed64 value, which is
// always 8 bytes.
func SizeFixed64() int {
	const size = 8
	return size
}
428
429// AppendBytes appends v to b as a length-prefixed bytes value.
430func AppendBytes(b []byte, v []byte) []byte {
431	return append(AppendVarint(b, uint64(len(v))), v...)
432}
433
434// ConsumeBytes parses b as a length-prefixed bytes value, reporting its length.
435// This returns a negative length upon an error (see ParseError).
436func ConsumeBytes(b []byte) (v []byte, n int) {
437	m, n := ConsumeVarint(b)
438	if n < 0 {
439		return nil, n // forward error code
440	}
441	if m > uint64(len(b[n:])) {
442		return nil, errCodeTruncated
443	}
444	return b[n:][:m], n + int(m)
445}
446
447// SizeBytes returns the encoded size of a length-prefixed bytes value,
448// given only the length.
449func SizeBytes(n int) int {
450	return SizeVarint(uint64(n)) + n
451}
452
453// AppendString appends v to b as a length-prefixed bytes value.
454func AppendString(b []byte, v string) []byte {
455	return append(AppendVarint(b, uint64(len(v))), v...)
456}
457
458// ConsumeString parses b as a length-prefixed bytes value, reporting its length.
459// This returns a negative length upon an error (see ParseError).
460func ConsumeString(b []byte) (v string, n int) {
461	bb, n := ConsumeBytes(b)
462	return string(bb), n
463}
464
465// AppendGroup appends v to b as group value, with a trailing end group marker.
466// The value v must not contain the end marker.
467func AppendGroup(b []byte, num Number, v []byte) []byte {
468	return AppendVarint(append(b, v...), EncodeTag(num, EndGroupType))
469}
470
471// ConsumeGroup parses b as a group value until the trailing end group marker,
472// and verifies that the end marker matches the provided num. The value v
473// does not contain the end marker, while the length does contain the end marker.
474// This returns a negative length upon an error (see ParseError).
475func ConsumeGroup(num Number, b []byte) (v []byte, n int) {
476	n = ConsumeFieldValue(num, StartGroupType, b)
477	if n < 0 {
478		return nil, n // forward error code
479	}
480	b = b[:n]
481
482	// Truncate off end group marker, but need to handle denormalized varints.
483	// Assuming end marker is never 0 (which is always the case since
484	// EndGroupType is non-zero), we can truncate all trailing bytes where the
485	// lower 7 bits are all zero (implying that the varint is denormalized).
486	for len(b) > 0 && b[len(b)-1]&0x7f == 0 {
487		b = b[:len(b)-1]
488	}
489	b = b[:len(b)-SizeTag(num)]
490	return b, n
491}
492
493// SizeGroup returns the encoded size of a group, given only the length.
494func SizeGroup(num Number, n int) int {
495	return n + SizeTag(num)
496}
497
498// DecodeTag decodes the field Number and wire Type from its unified form.
499// The Number is -1 if the decoded field number overflows int32.
500// Other than overflow, this does not check for field number validity.
501func DecodeTag(x uint64) (Number, Type) {
502	// NOTE: MessageSet allows for larger field numbers than normal.
503	if x>>3 > uint64(math.MaxInt32) {
504		return -1, 0
505	}
506	return Number(x >> 3), Type(x & 7)
507}
508
509// EncodeTag encodes the field Number and wire Type into its unified form.
510func EncodeTag(num Number, typ Type) uint64 {
511	return uint64(num)<<3 | uint64(typ&7)
512}
513
// DecodeZigZag converts a zig-zag-encoded uint64 back into an int64.
//
//	Input:  {…,  5,  3,  1,  0,  2,  4,  6, …}
//	Output: {…, -3, -2, -1,  0, +1, +2, +3, …}
func DecodeZigZag(x uint64) int64 {
	// The low bit is the sign: negate it to get an all-ones or all-zeros
	// mask, then flip the magnitude bits with it.
	mask := -int64(x & 1)
	return int64(x>>1) ^ mask
}
521
// EncodeZigZag converts an int64 into its zig-zag-encoded uint64 form.
//
//	Input:  {…, -3, -2, -1,  0, +1, +2, +3, …}
//	Output: {…,  5,  3,  1,  0,  2,  4,  6, …}
func EncodeZigZag(x int64) uint64 {
	// The arithmetic shift yields all ones for negative x, all zeros otherwise.
	sign := x >> 63
	return uint64((x << 1) ^ sign)
}
529
// DecodeBool interprets a uint64 as a bool: zero is false and every other
// value is true.
//
//	Input:  {    0,    1,    2, …}
//	Output: {false, true, true, …}
func DecodeBool(x uint64) bool {
	if x == 0 {
		return false
	}
	return true
}
537
// EncodeBool converts a bool into a uint64: 1 for true, 0 for false.
//
//	Input:  {false, true}
//	Output: {    0,    1}
func EncodeBool(x bool) uint64 {
	var out uint64
	if x {
		out = 1
	}
	return out
}
548