1// Copyright 2018 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Package protowire parses and formats the raw wire encoding. 6// See https://protobuf.dev/programming-guides/encoding. 7// 8// For marshaling and unmarshaling entire protobuf messages, 9// use the "google.golang.org/protobuf/proto" package instead. 10package protowire 11 12import ( 13 "io" 14 "math" 15 "math/bits" 16 17 "google.golang.org/protobuf/internal/errors" 18) 19 20// Number represents the field number. 21type Number int32 22 23const ( 24 MinValidNumber Number = 1 25 FirstReservedNumber Number = 19000 26 LastReservedNumber Number = 19999 27 MaxValidNumber Number = 1<<29 - 1 28 DefaultRecursionLimit = 10000 29) 30 31// IsValid reports whether the field number is semantically valid. 32func (n Number) IsValid() bool { 33 return MinValidNumber <= n && n <= MaxValidNumber 34} 35 36// Type represents the wire type. 37type Type int8 38 39const ( 40 VarintType Type = 0 41 Fixed32Type Type = 5 42 Fixed64Type Type = 1 43 BytesType Type = 2 44 StartGroupType Type = 3 45 EndGroupType Type = 4 46) 47 48const ( 49 _ = -iota 50 errCodeTruncated 51 errCodeFieldNumber 52 errCodeOverflow 53 errCodeReserved 54 errCodeEndGroup 55 errCodeRecursionDepth 56) 57 58var ( 59 errFieldNumber = errors.New("invalid field number") 60 errOverflow = errors.New("variable length integer overflow") 61 errReserved = errors.New("cannot parse reserved wire type") 62 errEndGroup = errors.New("mismatching end group marker") 63 errParse = errors.New("parse error") 64) 65 66// ParseError converts an error code into an error value. 67// This returns nil if n is a non-negative number. 
68func ParseError(n int) error { 69 if n >= 0 { 70 return nil 71 } 72 switch n { 73 case errCodeTruncated: 74 return io.ErrUnexpectedEOF 75 case errCodeFieldNumber: 76 return errFieldNumber 77 case errCodeOverflow: 78 return errOverflow 79 case errCodeReserved: 80 return errReserved 81 case errCodeEndGroup: 82 return errEndGroup 83 default: 84 return errParse 85 } 86} 87 88// ConsumeField parses an entire field record (both tag and value) and returns 89// the field number, the wire type, and the total length. 90// This returns a negative length upon an error (see ParseError). 91// 92// The total length includes the tag header and the end group marker (if the 93// field is a group). 94func ConsumeField(b []byte) (Number, Type, int) { 95 num, typ, n := ConsumeTag(b) 96 if n < 0 { 97 return 0, 0, n // forward error code 98 } 99 m := ConsumeFieldValue(num, typ, b[n:]) 100 if m < 0 { 101 return 0, 0, m // forward error code 102 } 103 return num, typ, n + m 104} 105 106// ConsumeFieldValue parses a field value and returns its length. 107// This assumes that the field Number and wire Type have already been parsed. 108// This returns a negative length upon an error (see ParseError). 109// 110// When parsing a group, the length includes the end group marker and 111// the end group is verified to match the starting field number. 
112func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) { 113 return consumeFieldValueD(num, typ, b, DefaultRecursionLimit) 114} 115 116func consumeFieldValueD(num Number, typ Type, b []byte, depth int) (n int) { 117 switch typ { 118 case VarintType: 119 _, n = ConsumeVarint(b) 120 return n 121 case Fixed32Type: 122 _, n = ConsumeFixed32(b) 123 return n 124 case Fixed64Type: 125 _, n = ConsumeFixed64(b) 126 return n 127 case BytesType: 128 _, n = ConsumeBytes(b) 129 return n 130 case StartGroupType: 131 if depth < 0 { 132 return errCodeRecursionDepth 133 } 134 n0 := len(b) 135 for { 136 num2, typ2, n := ConsumeTag(b) 137 if n < 0 { 138 return n // forward error code 139 } 140 b = b[n:] 141 if typ2 == EndGroupType { 142 if num != num2 { 143 return errCodeEndGroup 144 } 145 return n0 - len(b) 146 } 147 148 n = consumeFieldValueD(num2, typ2, b, depth-1) 149 if n < 0 { 150 return n // forward error code 151 } 152 b = b[n:] 153 } 154 case EndGroupType: 155 return errCodeEndGroup 156 default: 157 return errCodeReserved 158 } 159} 160 161// AppendTag encodes num and typ as a varint-encoded tag and appends it to b. 162func AppendTag(b []byte, num Number, typ Type) []byte { 163 return AppendVarint(b, EncodeTag(num, typ)) 164} 165 166// ConsumeTag parses b as a varint-encoded tag, reporting its length. 167// This returns a negative length upon an error (see ParseError). 168func ConsumeTag(b []byte) (Number, Type, int) { 169 v, n := ConsumeVarint(b) 170 if n < 0 { 171 return 0, 0, n // forward error code 172 } 173 num, typ := DecodeTag(v) 174 if num < MinValidNumber { 175 return 0, 0, errCodeFieldNumber 176 } 177 return num, typ, n 178} 179 180func SizeTag(num Number) int { 181 return SizeVarint(EncodeTag(num, 0)) // wire type has no effect on size 182} 183 184// AppendVarint appends v to b as a varint-encoded uint64. 
func AppendVarint(b []byte, v uint64) []byte {
	// One case per encoded length (1 through 10 bytes), so that every value
	// is written with a single append call. Each byte except the last holds
	// seven payload bits plus the 0x80 continuation flag; the last byte holds
	// the remaining high bits with the flag clear.
	switch {
	case v < 1<<7:
		return append(b, byte(v))
	case v < 1<<14:
		return append(b,
			byte(v)|0x80,
			byte(v>>7))
	case v < 1<<21:
		return append(b,
			byte(v)|0x80,
			byte(v>>7)|0x80,
			byte(v>>14))
	case v < 1<<28:
		return append(b,
			byte(v)|0x80,
			byte(v>>7)|0x80,
			byte(v>>14)|0x80,
			byte(v>>21))
	case v < 1<<35:
		return append(b,
			byte(v)|0x80,
			byte(v>>7)|0x80,
			byte(v>>14)|0x80,
			byte(v>>21)|0x80,
			byte(v>>28))
	case v < 1<<42:
		return append(b,
			byte(v)|0x80,
			byte(v>>7)|0x80,
			byte(v>>14)|0x80,
			byte(v>>21)|0x80,
			byte(v>>28)|0x80,
			byte(v>>35))
	case v < 1<<49:
		return append(b,
			byte(v)|0x80,
			byte(v>>7)|0x80,
			byte(v>>14)|0x80,
			byte(v>>21)|0x80,
			byte(v>>28)|0x80,
			byte(v>>35)|0x80,
			byte(v>>42))
	case v < 1<<56:
		return append(b,
			byte(v)|0x80,
			byte(v>>7)|0x80,
			byte(v>>14)|0x80,
			byte(v>>21)|0x80,
			byte(v>>28)|0x80,
			byte(v>>35)|0x80,
			byte(v>>42)|0x80,
			byte(v>>49))
	case v < 1<<63:
		return append(b,
			byte(v)|0x80,
			byte(v>>7)|0x80,
			byte(v>>14)|0x80,
			byte(v>>21)|0x80,
			byte(v>>28)|0x80,
			byte(v>>35)|0x80,
			byte(v>>42)|0x80,
			byte(v>>49)|0x80,
			byte(v>>56))
	}
	// v >= 1<<63: bit 63 is the only payload bit left for the tenth byte,
	// and it is known to be set, so that byte is always 1.
	return append(b,
		byte(v)|0x80,
		byte(v>>7)|0x80,
		byte(v>>14)|0x80,
		byte(v>>21)|0x80,
		byte(v>>28)|0x80,
		byte(v>>35)|0x80,
		byte(v>>42)|0x80,
		byte(v>>49)|0x80,
		byte(v>>56)|0x80,
		1)
}

// ConsumeVarint parses b as a varint-encoded uint64, reporting its length.
266// This returns a negative length upon an error (see ParseError). 267func ConsumeVarint(b []byte) (v uint64, n int) { 268 var y uint64 269 if len(b) <= 0 { 270 return 0, errCodeTruncated 271 } 272 v = uint64(b[0]) 273 if v < 0x80 { 274 return v, 1 275 } 276 v -= 0x80 277 278 if len(b) <= 1 { 279 return 0, errCodeTruncated 280 } 281 y = uint64(b[1]) 282 v += y << 7 283 if y < 0x80 { 284 return v, 2 285 } 286 v -= 0x80 << 7 287 288 if len(b) <= 2 { 289 return 0, errCodeTruncated 290 } 291 y = uint64(b[2]) 292 v += y << 14 293 if y < 0x80 { 294 return v, 3 295 } 296 v -= 0x80 << 14 297 298 if len(b) <= 3 { 299 return 0, errCodeTruncated 300 } 301 y = uint64(b[3]) 302 v += y << 21 303 if y < 0x80 { 304 return v, 4 305 } 306 v -= 0x80 << 21 307 308 if len(b) <= 4 { 309 return 0, errCodeTruncated 310 } 311 y = uint64(b[4]) 312 v += y << 28 313 if y < 0x80 { 314 return v, 5 315 } 316 v -= 0x80 << 28 317 318 if len(b) <= 5 { 319 return 0, errCodeTruncated 320 } 321 y = uint64(b[5]) 322 v += y << 35 323 if y < 0x80 { 324 return v, 6 325 } 326 v -= 0x80 << 35 327 328 if len(b) <= 6 { 329 return 0, errCodeTruncated 330 } 331 y = uint64(b[6]) 332 v += y << 42 333 if y < 0x80 { 334 return v, 7 335 } 336 v -= 0x80 << 42 337 338 if len(b) <= 7 { 339 return 0, errCodeTruncated 340 } 341 y = uint64(b[7]) 342 v += y << 49 343 if y < 0x80 { 344 return v, 8 345 } 346 v -= 0x80 << 49 347 348 if len(b) <= 8 { 349 return 0, errCodeTruncated 350 } 351 y = uint64(b[8]) 352 v += y << 56 353 if y < 0x80 { 354 return v, 9 355 } 356 v -= 0x80 << 56 357 358 if len(b) <= 9 { 359 return 0, errCodeTruncated 360 } 361 y = uint64(b[9]) 362 v += y << 63 363 if y < 2 { 364 return v, 10 365 } 366 return 0, errCodeOverflow 367} 368 369// SizeVarint returns the encoded size of a varint. 370// The size is guaranteed to be within 1 and 10, inclusive. 371func SizeVarint(v uint64) int { 372 // This computes 1 + (bits.Len64(v)-1)/7. 
373 // 9/64 is a good enough approximation of 1/7 374 return int(9*uint32(bits.Len64(v))+64) / 64 375} 376 377// AppendFixed32 appends v to b as a little-endian uint32. 378func AppendFixed32(b []byte, v uint32) []byte { 379 return append(b, 380 byte(v>>0), 381 byte(v>>8), 382 byte(v>>16), 383 byte(v>>24)) 384} 385 386// ConsumeFixed32 parses b as a little-endian uint32, reporting its length. 387// This returns a negative length upon an error (see ParseError). 388func ConsumeFixed32(b []byte) (v uint32, n int) { 389 if len(b) < 4 { 390 return 0, errCodeTruncated 391 } 392 v = uint32(b[0])<<0 | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 393 return v, 4 394} 395 396// SizeFixed32 returns the encoded size of a fixed32; which is always 4. 397func SizeFixed32() int { 398 return 4 399} 400 401// AppendFixed64 appends v to b as a little-endian uint64. 402func AppendFixed64(b []byte, v uint64) []byte { 403 return append(b, 404 byte(v>>0), 405 byte(v>>8), 406 byte(v>>16), 407 byte(v>>24), 408 byte(v>>32), 409 byte(v>>40), 410 byte(v>>48), 411 byte(v>>56)) 412} 413 414// ConsumeFixed64 parses b as a little-endian uint64, reporting its length. 415// This returns a negative length upon an error (see ParseError). 416func ConsumeFixed64(b []byte) (v uint64, n int) { 417 if len(b) < 8 { 418 return 0, errCodeTruncated 419 } 420 v = uint64(b[0])<<0 | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56 421 return v, 8 422} 423 424// SizeFixed64 returns the encoded size of a fixed64; which is always 8. 425func SizeFixed64() int { 426 return 8 427} 428 429// AppendBytes appends v to b as a length-prefixed bytes value. 430func AppendBytes(b []byte, v []byte) []byte { 431 return append(AppendVarint(b, uint64(len(v))), v...) 432} 433 434// ConsumeBytes parses b as a length-prefixed bytes value, reporting its length. 435// This returns a negative length upon an error (see ParseError). 
436func ConsumeBytes(b []byte) (v []byte, n int) { 437 m, n := ConsumeVarint(b) 438 if n < 0 { 439 return nil, n // forward error code 440 } 441 if m > uint64(len(b[n:])) { 442 return nil, errCodeTruncated 443 } 444 return b[n:][:m], n + int(m) 445} 446 447// SizeBytes returns the encoded size of a length-prefixed bytes value, 448// given only the length. 449func SizeBytes(n int) int { 450 return SizeVarint(uint64(n)) + n 451} 452 453// AppendString appends v to b as a length-prefixed bytes value. 454func AppendString(b []byte, v string) []byte { 455 return append(AppendVarint(b, uint64(len(v))), v...) 456} 457 458// ConsumeString parses b as a length-prefixed bytes value, reporting its length. 459// This returns a negative length upon an error (see ParseError). 460func ConsumeString(b []byte) (v string, n int) { 461 bb, n := ConsumeBytes(b) 462 return string(bb), n 463} 464 465// AppendGroup appends v to b as group value, with a trailing end group marker. 466// The value v must not contain the end marker. 467func AppendGroup(b []byte, num Number, v []byte) []byte { 468 return AppendVarint(append(b, v...), EncodeTag(num, EndGroupType)) 469} 470 471// ConsumeGroup parses b as a group value until the trailing end group marker, 472// and verifies that the end marker matches the provided num. The value v 473// does not contain the end marker, while the length does contain the end marker. 474// This returns a negative length upon an error (see ParseError). 475func ConsumeGroup(num Number, b []byte) (v []byte, n int) { 476 n = ConsumeFieldValue(num, StartGroupType, b) 477 if n < 0 { 478 return nil, n // forward error code 479 } 480 b = b[:n] 481 482 // Truncate off end group marker, but need to handle denormalized varints. 483 // Assuming end marker is never 0 (which is always the case since 484 // EndGroupType is non-zero), we can truncate all trailing bytes where the 485 // lower 7 bits are all zero (implying that the varint is denormalized). 
486 for len(b) > 0 && b[len(b)-1]&0x7f == 0 { 487 b = b[:len(b)-1] 488 } 489 b = b[:len(b)-SizeTag(num)] 490 return b, n 491} 492 493// SizeGroup returns the encoded size of a group, given only the length. 494func SizeGroup(num Number, n int) int { 495 return n + SizeTag(num) 496} 497 498// DecodeTag decodes the field Number and wire Type from its unified form. 499// The Number is -1 if the decoded field number overflows int32. 500// Other than overflow, this does not check for field number validity. 501func DecodeTag(x uint64) (Number, Type) { 502 // NOTE: MessageSet allows for larger field numbers than normal. 503 if x>>3 > uint64(math.MaxInt32) { 504 return -1, 0 505 } 506 return Number(x >> 3), Type(x & 7) 507} 508 509// EncodeTag encodes the field Number and wire Type into its unified form. 510func EncodeTag(num Number, typ Type) uint64 { 511 return uint64(num)<<3 | uint64(typ&7) 512} 513 514// DecodeZigZag decodes a zig-zag-encoded uint64 as an int64. 515// 516// Input: {…, 5, 3, 1, 0, 2, 4, 6, …} 517// Output: {…, -3, -2, -1, 0, +1, +2, +3, …} 518func DecodeZigZag(x uint64) int64 { 519 return int64(x>>1) ^ int64(x)<<63>>63 520} 521 522// EncodeZigZag encodes an int64 as a zig-zag-encoded uint64. 523// 524// Input: {…, -3, -2, -1, 0, +1, +2, +3, …} 525// Output: {…, 5, 3, 1, 0, 2, 4, 6, …} 526func EncodeZigZag(x int64) uint64 { 527 return uint64(x<<1) ^ uint64(x>>63) 528} 529 530// DecodeBool decodes a uint64 as a bool. 531// 532// Input: { 0, 1, 2, …} 533// Output: {false, true, true, …} 534func DecodeBool(x uint64) bool { 535 return x != 0 536} 537 538// EncodeBool encodes a bool as a uint64. 539// 540// Input: {false, true} 541// Output: { 0, 1} 542func EncodeBool(x bool) uint64 { 543 if x { 544 return 1 545 } 546 return 0 547} 548