1// Go support for Protocol Buffers - Google's data interchange format 2// 3// Copyright 2010 The Go Authors. All rights reserved. 4// https://github.com/golang/protobuf 5// 6// Redistribution and use in source and binary forms, with or without 7// modification, are permitted provided that the following conditions are 8// met: 9// 10// * Redistributions of source code must retain the above copyright 11// notice, this list of conditions and the following disclaimer. 12// * Redistributions in binary form must reproduce the above 13// copyright notice, this list of conditions and the following disclaimer 14// in the documentation and/or other materials provided with the 15// distribution. 16// * Neither the name of Google Inc. nor the names of its 17// contributors may be used to endorse or promote products derived from 18// this software without specific prior written permission. 19// 20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 32package proto 33 34/* 35 * Types and routines for supporting protocol buffer extensions. 36 */ 37 38import ( 39 "errors" 40 "fmt" 41 "io" 42 "reflect" 43 "strconv" 44 "sync" 45) 46 47// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message. 48var ErrMissingExtension = errors.New("proto: missing extension") 49 50// ExtensionRange represents a range of message extensions for a protocol buffer. 51// Used in code generated by the protocol compiler. 52type ExtensionRange struct { 53 Start, End int32 // both inclusive 54} 55 56// extendableProto is an interface implemented by any protocol buffer generated by the current 57// proto compiler that may be extended. 58type extendableProto interface { 59 Message 60 ExtensionRangeArray() []ExtensionRange 61 extensionsWrite() map[int32]Extension 62 extensionsRead() (map[int32]Extension, sync.Locker) 63} 64 65// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous 66// version of the proto compiler that may be extended. 67type extendableProtoV1 interface { 68 Message 69 ExtensionRangeArray() []ExtensionRange 70 ExtensionMap() map[int32]Extension 71} 72 73// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto. 74type extensionAdapter struct { 75 extendableProtoV1 76} 77 78func (e extensionAdapter) extensionsWrite() map[int32]Extension { 79 return e.ExtensionMap() 80} 81 82func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { 83 return e.ExtensionMap(), notLocker{} 84} 85 86// notLocker is a sync.Locker whose Lock and Unlock methods are nops. 87type notLocker struct{} 88 89func (n notLocker) Lock() {} 90func (n notLocker) Unlock() {} 91 92// extendable returns the extendableProto interface for the given generated proto message. 93// If the proto message has the old extension format, it returns a wrapper that implements 94// the extendableProto interface. 95func extendable(p interface{}) (extendableProto, error) { 96 switch p := p.(type) { 97 case extendableProto: 98 if isNilPtr(p) { 99 return nil, fmt.Errorf("proto: nil %T is not extendable", p) 100 } 101 return p, nil 102 case extendableProtoV1: 103 if isNilPtr(p) { 104 return nil, fmt.Errorf("proto: nil %T is not extendable", p) 105 } 106 return extensionAdapter{p}, nil 107 } 108 // Don't allocate a specific error containing %T: 109 // this is the hot path for Clone and MarshalText. 110 return nil, errNotExtendable 111} 112 113var errNotExtendable = errors.New("proto: not an extendable proto.Message") 114 115func isNilPtr(x interface{}) bool { 116 v := reflect.ValueOf(x) 117 return v.Kind() == reflect.Ptr && v.IsNil() 118} 119 120// XXX_InternalExtensions is an internal representation of proto extensions. 121// 122// Each generated message struct type embeds an anonymous XXX_InternalExtensions field, 123// thus gaining the unexported 'extensions' method, which can be called only from the proto package. 124// 125// The methods of XXX_InternalExtensions are not concurrency safe in general, 126// but calls to logically read-only methods such as has and get may be executed concurrently. 127type XXX_InternalExtensions struct { 128 // The struct must be indirect so that if a user inadvertently copies a 129 // generated message and its embedded XXX_InternalExtensions, they 130 // avoid the mayhem of a copied mutex. 131 // 132 // The mutex serializes all logically read-only operations to p.extensionMap. 133 // It is up to the client to ensure that write operations to p.extensionMap are 134 // mutually exclusive with other accesses. 135 p *struct { 136 mu sync.Mutex 137 extensionMap map[int32]Extension 138 } 139} 140 141// extensionsWrite returns the extension map, creating it on first use. 142func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension { 143 if e.p == nil { 144 e.p = new(struct { 145 mu sync.Mutex 146 extensionMap map[int32]Extension 147 }) 148 e.p.extensionMap = make(map[int32]Extension) 149 } 150 return e.p.extensionMap 151} 152 153// extensionsRead returns the extensions map for read-only use. It may be nil. 154// The caller must hold the returned mutex's lock when accessing Elements within the map. 155func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) { 156 if e.p == nil { 157 return nil, nil 158 } 159 return e.p.extensionMap, &e.p.mu 160} 161 162// ExtensionDesc represents an extension specification. 163// Used in generated code from the protocol compiler. 164type ExtensionDesc struct { 165 ExtendedType Message // nil pointer to the type that is being extended 166 ExtensionType interface{} // nil pointer to the extension type 167 Field int32 // field number 168 Name string // fully-qualified name of extension, for text formatting 169 Tag string // protobuf tag style 170 Filename string // name of the file in which the extension is defined 171} 172 173func (ed *ExtensionDesc) repeated() bool { 174 t := reflect.TypeOf(ed.ExtensionType) 175 return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 176} 177 178// Extension represents an extension in a message. 179type Extension struct { 180 // When an extension is stored in a message using SetExtension 181 // only desc and value are set. When the message is marshaled 182 // enc will be set to the encoded form of the message. 183 // 184 // When a message is unmarshaled and contains extensions, each 185 // extension will have only enc set. When such an extension is 186 // accessed using GetExtension (or GetExtensions) desc and value 187 // will be set. 188 desc *ExtensionDesc 189 value interface{} 190 enc []byte 191} 192 193// SetRawExtension is for testing only. 194func SetRawExtension(base Message, id int32, b []byte) { 195 epb, err := extendable(base) 196 if err != nil { 197 return 198 } 199 extmap := epb.extensionsWrite() 200 extmap[id] = Extension{enc: b} 201} 202 203// isExtensionField returns true iff the given field number is in an extension range. 204func isExtensionField(pb extendableProto, field int32) bool { 205 for _, er := range pb.ExtensionRangeArray() { 206 if er.Start <= field && field <= er.End { 207 return true 208 } 209 } 210 return false 211} 212 213// checkExtensionTypes checks that the given extension is valid for pb. 214func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { 215 var pbi interface{} = pb 216 // Check the extended type. 217 if ea, ok := pbi.(extensionAdapter); ok { 218 pbi = ea.extendableProtoV1 219 } 220 if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b { 221 return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a) 222 } 223 // Check the range. 224 if !isExtensionField(pb, extension.Field) { 225 return errors.New("proto: bad extension number; not in declared ranges") 226 } 227 return nil 228} 229 230// extPropKey is sufficient to uniquely identify an extension. 231type extPropKey struct { 232 base reflect.Type 233 field int32 234} 235 236var extProp = struct { 237 sync.RWMutex 238 m map[extPropKey]*Properties 239}{ 240 m: make(map[extPropKey]*Properties), 241} 242 243func extensionProperties(ed *ExtensionDesc) *Properties { 244 key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field} 245 246 extProp.RLock() 247 if prop, ok := extProp.m[key]; ok { 248 extProp.RUnlock() 249 return prop 250 } 251 extProp.RUnlock() 252 253 extProp.Lock() 254 defer extProp.Unlock() 255 // Check again. 256 if prop, ok := extProp.m[key]; ok { 257 return prop 258 } 259 260 prop := new(Properties) 261 prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil) 262 extProp.m[key] = prop 263 return prop 264} 265 266// HasExtension returns whether the given extension is present in pb. 267func HasExtension(pb Message, extension *ExtensionDesc) bool { 268 // TODO: Check types, field numbers, etc.? 269 epb, err := extendable(pb) 270 if err != nil { 271 return false 272 } 273 extmap, mu := epb.extensionsRead() 274 if extmap == nil { 275 return false 276 } 277 mu.Lock() 278 _, ok := extmap[extension.Field] 279 mu.Unlock() 280 return ok 281} 282 283// ClearExtension removes the given extension from pb. 284func ClearExtension(pb Message, extension *ExtensionDesc) { 285 epb, err := extendable(pb) 286 if err != nil { 287 return 288 } 289 // TODO: Check types, field numbers, etc.? 290 extmap := epb.extensionsWrite() 291 delete(extmap, extension.Field) 292} 293 294// GetExtension retrieves a proto2 extended field from pb. 295// 296// If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil), 297// then GetExtension parses the encoded field and returns a Go value of the specified type. 298// If the field is not present, then the default value is returned (if one is specified), 299// otherwise ErrMissingExtension is reported. 300// 301// If the descriptor is not type complete (i.e., ExtensionDesc.ExtensionType is nil), 302// then GetExtension returns the raw encoded bytes of the field extension. 303func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { 304 epb, err := extendable(pb) 305 if err != nil { 306 return nil, err 307 } 308 309 if extension.ExtendedType != nil { 310 // can only check type if this is a complete descriptor 311 if err := checkExtensionTypes(epb, extension); err != nil { 312 return nil, err 313 } 314 } 315 316 emap, mu := epb.extensionsRead() 317 if emap == nil { 318 return defaultExtensionValue(extension) 319 } 320 mu.Lock() 321 defer mu.Unlock() 322 e, ok := emap[extension.Field] 323 if !ok { 324 // defaultExtensionValue returns the default value or 325 // ErrMissingExtension if there is no default. 326 return defaultExtensionValue(extension) 327 } 328 329 if e.value != nil { 330 // Already decoded. Check the descriptor, though. 331 if e.desc != extension { 332 // This shouldn't happen. If it does, it means that 333 // GetExtension was called twice with two different 334 // descriptors with the same field number. 335 return nil, errors.New("proto: descriptor conflict") 336 } 337 return e.value, nil 338 } 339 340 if extension.ExtensionType == nil { 341 // incomplete descriptor 342 return e.enc, nil 343 } 344 345 v, err := decodeExtension(e.enc, extension) 346 if err != nil { 347 return nil, err 348 } 349 350 // Remember the decoded version and drop the encoded version. 351 // That way it is safe to mutate what we return. 352 e.value = v 353 e.desc = extension 354 e.enc = nil 355 emap[extension.Field] = e 356 return e.value, nil 357} 358 359// defaultExtensionValue returns the default value for extension. 360// If no default for an extension is defined ErrMissingExtension is returned. 361func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) { 362 if extension.ExtensionType == nil { 363 // incomplete descriptor, so no default 364 return nil, ErrMissingExtension 365 } 366 367 t := reflect.TypeOf(extension.ExtensionType) 368 props := extensionProperties(extension) 369 370 sf, _, err := fieldDefault(t, props) 371 if err != nil { 372 return nil, err 373 } 374 375 if sf == nil || sf.value == nil { 376 // There is no default value. 377 return nil, ErrMissingExtension 378 } 379 380 if t.Kind() != reflect.Ptr { 381 // We do not need to return a Ptr, we can directly return sf.value. 382 return sf.value, nil 383 } 384 385 // We need to return an interface{} that is a pointer to sf.value. 386 value := reflect.New(t).Elem() 387 value.Set(reflect.New(value.Type().Elem())) 388 if sf.kind == reflect.Int32 { 389 // We may have an int32 or an enum, but the underlying data is int32. 390 // Since we can't set an int32 into a non int32 reflect.value directly 391 // set it as a int32. 392 value.Elem().SetInt(int64(sf.value.(int32))) 393 } else { 394 value.Elem().Set(reflect.ValueOf(sf.value)) 395 } 396 return value.Interface(), nil 397} 398 399// decodeExtension decodes an extension encoded in b. 400func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { 401 t := reflect.TypeOf(extension.ExtensionType) 402 unmarshal := typeUnmarshaler(t, extension.Tag) 403 404 // t is a pointer to a struct, pointer to basic type or a slice. 405 // Allocate space to store the pointer/slice. 406 value := reflect.New(t).Elem() 407 408 var err error 409 for { 410 x, n := decodeVarint(b) 411 if n == 0 { 412 return nil, io.ErrUnexpectedEOF 413 } 414 b = b[n:] 415 wire := int(x) & 7 416 417 b, err = unmarshal(b, valToPointer(value.Addr()), wire) 418 if err != nil { 419 return nil, err 420 } 421 422 if len(b) == 0 { 423 break 424 } 425 } 426 return value.Interface(), nil 427} 428 429// GetExtensions returns a slice of the extensions present in pb that are also listed in es. 430// The returned slice has the same length as es; missing extensions will appear as nil elements. 431func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { 432 epb, err := extendable(pb) 433 if err != nil { 434 return nil, err 435 } 436 extensions = make([]interface{}, len(es)) 437 for i, e := range es { 438 extensions[i], err = GetExtension(epb, e) 439 if err == ErrMissingExtension { 440 err = nil 441 } 442 if err != nil { 443 return 444 } 445 } 446 return 447} 448 449// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order. 450// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing 451// just the Field field, which defines the extension's field number. 452func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) { 453 epb, err := extendable(pb) 454 if err != nil { 455 return nil, err 456 } 457 registeredExtensions := RegisteredExtensions(pb) 458 459 emap, mu := epb.extensionsRead() 460 if emap == nil { 461 return nil, nil 462 } 463 mu.Lock() 464 defer mu.Unlock() 465 extensions := make([]*ExtensionDesc, 0, len(emap)) 466 for extid, e := range emap { 467 desc := e.desc 468 if desc == nil { 469 desc = registeredExtensions[extid] 470 if desc == nil { 471 desc = &ExtensionDesc{Field: extid} 472 } 473 } 474 475 extensions = append(extensions, desc) 476 } 477 return extensions, nil 478} 479 480// SetExtension sets the specified extension of pb to the specified value. 481func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { 482 epb, err := extendable(pb) 483 if err != nil { 484 return err 485 } 486 if err := checkExtensionTypes(epb, extension); err != nil { 487 return err 488 } 489 typ := reflect.TypeOf(extension.ExtensionType) 490 if typ != reflect.TypeOf(value) { 491 return errors.New("proto: bad extension value type") 492 } 493 // nil extension values need to be caught early, because the 494 // encoder can't distinguish an ErrNil due to a nil extension 495 // from an ErrNil due to a missing field. Extensions are 496 // always optional, so the encoder would just swallow the error 497 // and drop all the extensions from the encoded message. 498 if reflect.ValueOf(value).IsNil() { 499 return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) 500 } 501 502 extmap := epb.extensionsWrite() 503 extmap[extension.Field] = Extension{desc: extension, value: value} 504 return nil 505} 506 507// ClearAllExtensions clears all extensions from pb. 508func ClearAllExtensions(pb Message) { 509 epb, err := extendable(pb) 510 if err != nil { 511 return 512 } 513 m := epb.extensionsWrite() 514 for k := range m { 515 delete(m, k) 516 } 517} 518 519// A global registry of extensions. 520// The generated code will register the generated descriptors by calling RegisterExtension. 521 522var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc) 523 524// RegisterExtension is called from the generated code. 525func RegisterExtension(desc *ExtensionDesc) { 526 st := reflect.TypeOf(desc.ExtendedType).Elem() 527 m := extensionMaps[st] 528 if m == nil { 529 m = make(map[int32]*ExtensionDesc) 530 extensionMaps[st] = m 531 } 532 if _, ok := m[desc.Field]; ok { 533 panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field))) 534 } 535 m[desc.Field] = desc 536} 537 538// RegisteredExtensions returns a map of the registered extensions of a 539// protocol buffer struct, indexed by the extension number. 540// The argument pb should be a nil pointer to the struct type. 541func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc { 542 return extensionMaps[reflect.TypeOf(pb).Elem()] 543} 544