1// Go support for Protocol Buffers - Google's data interchange format 2// 3// Copyright 2010 The Go Authors. All rights reserved. 4// https://github.com/golang/protobuf 5// 6// Redistribution and use in source and binary forms, with or without 7// modification, are permitted provided that the following conditions are 8// met: 9// 10// * Redistributions of source code must retain the above copyright 11// notice, this list of conditions and the following disclaimer. 12// * Redistributions in binary form must reproduce the above 13// copyright notice, this list of conditions and the following disclaimer 14// in the documentation and/or other materials provided with the 15// distribution. 16// * Neither the name of Google Inc. nor the names of its 17// contributors may be used to endorse or promote products derived from 18// this software without specific prior written permission. 19// 20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 32package proto 33 34/* 35 * Types and routines for supporting protocol buffer extensions. 36 */ 37 38import ( 39 "errors" 40 "fmt" 41 "reflect" 42 "strconv" 43 "sync" 44) 45 46// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message. 47var ErrMissingExtension = errors.New("proto: missing extension") 48 49// ExtensionRange represents a range of message extensions for a protocol buffer. 50// Used in code generated by the protocol compiler. 51type ExtensionRange struct { 52 Start, End int32 // both inclusive 53} 54 55// extendableProto is an interface implemented by any protocol buffer generated by the current 56// proto compiler that may be extended. 57type extendableProto interface { 58 Message 59 ExtensionRangeArray() []ExtensionRange 60 extensionsWrite() map[int32]Extension 61 extensionsRead() (map[int32]Extension, sync.Locker) 62} 63 64// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous 65// version of the proto compiler that may be extended. 66type extendableProtoV1 interface { 67 Message 68 ExtensionRangeArray() []ExtensionRange 69 ExtensionMap() map[int32]Extension 70} 71 72// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto. 73type extensionAdapter struct { 74 extendableProtoV1 75} 76 77func (e extensionAdapter) extensionsWrite() map[int32]Extension { 78 return e.ExtensionMap() 79} 80 81func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { 82 return e.ExtensionMap(), notLocker{} 83} 84 85// notLocker is a sync.Locker whose Lock and Unlock methods are nops. 86type notLocker struct{} 87 88func (n notLocker) Lock() {} 89func (n notLocker) Unlock() {} 90 91// extendable returns the extendableProto interface for the given generated proto message. 92// If the proto message has the old extension format, it returns a wrapper that implements 93// the extendableProto interface. 94func extendable(p interface{}) (extendableProto, bool) { 95 if ep, ok := p.(extendableProto); ok { 96 return ep, ok 97 } 98 if ep, ok := p.(extendableProtoV1); ok { 99 return extensionAdapter{ep}, ok 100 } 101 return nil, false 102} 103 104// XXX_InternalExtensions is an internal representation of proto extensions. 105// 106// Each generated message struct type embeds an anonymous XXX_InternalExtensions field, 107// thus gaining the unexported 'extensions' method, which can be called only from the proto package. 108// 109// The methods of XXX_InternalExtensions are not concurrency safe in general, 110// but calls to logically read-only methods such as has and get may be executed concurrently. 111type XXX_InternalExtensions struct { 112 // The struct must be indirect so that if a user inadvertently copies a 113 // generated message and its embedded XXX_InternalExtensions, they 114 // avoid the mayhem of a copied mutex. 115 // 116 // The mutex serializes all logically read-only operations to p.extensionMap. 117 // It is up to the client to ensure that write operations to p.extensionMap are 118 // mutually exclusive with other accesses. 119 p *struct { 120 mu sync.Mutex 121 extensionMap map[int32]Extension 122 } 123} 124 125// extensionsWrite returns the extension map, creating it on first use. 126func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension { 127 if e.p == nil { 128 e.p = new(struct { 129 mu sync.Mutex 130 extensionMap map[int32]Extension 131 }) 132 e.p.extensionMap = make(map[int32]Extension) 133 } 134 return e.p.extensionMap 135} 136 137// extensionsRead returns the extensions map for read-only use. It may be nil. 138// The caller must hold the returned mutex's lock when accessing Elements within the map. 139func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) { 140 if e.p == nil { 141 return nil, nil 142 } 143 return e.p.extensionMap, &e.p.mu 144} 145 146var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem() 147var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem() 148 149// ExtensionDesc represents an extension specification. 150// Used in generated code from the protocol compiler. 151type ExtensionDesc struct { 152 ExtendedType Message // nil pointer to the type that is being extended 153 ExtensionType interface{} // nil pointer to the extension type 154 Field int32 // field number 155 Name string // fully-qualified name of extension, for text formatting 156 Tag string // protobuf tag style 157 Filename string // name of the file in which the extension is defined 158} 159 160func (ed *ExtensionDesc) repeated() bool { 161 t := reflect.TypeOf(ed.ExtensionType) 162 return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 163} 164 165// Extension represents an extension in a message. 166type Extension struct { 167 // When an extension is stored in a message using SetExtension 168 // only desc and value are set. When the message is marshaled 169 // enc will be set to the encoded form of the message. 170 // 171 // When a message is unmarshaled and contains extensions, each 172 // extension will have only enc set. When such an extension is 173 // accessed using GetExtension (or GetExtensions) desc and value 174 // will be set. 175 desc *ExtensionDesc 176 value interface{} 177 enc []byte 178} 179 180// SetRawExtension is for testing only. 181func SetRawExtension(base Message, id int32, b []byte) { 182 epb, ok := extendable(base) 183 if !ok { 184 return 185 } 186 extmap := epb.extensionsWrite() 187 extmap[id] = Extension{enc: b} 188} 189 190// isExtensionField returns true iff the given field number is in an extension range. 191func isExtensionField(pb extendableProto, field int32) bool { 192 for _, er := range pb.ExtensionRangeArray() { 193 if er.Start <= field && field <= er.End { 194 return true 195 } 196 } 197 return false 198} 199 200// checkExtensionTypes checks that the given extension is valid for pb. 201func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { 202 var pbi interface{} = pb 203 // Check the extended type. 204 if ea, ok := pbi.(extensionAdapter); ok { 205 pbi = ea.extendableProtoV1 206 } 207 if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b { 208 return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String()) 209 } 210 // Check the range. 211 if !isExtensionField(pb, extension.Field) { 212 return errors.New("proto: bad extension number; not in declared ranges") 213 } 214 return nil 215} 216 217// extPropKey is sufficient to uniquely identify an extension. 218type extPropKey struct { 219 base reflect.Type 220 field int32 221} 222 223var extProp = struct { 224 sync.RWMutex 225 m map[extPropKey]*Properties 226}{ 227 m: make(map[extPropKey]*Properties), 228} 229 230func extensionProperties(ed *ExtensionDesc) *Properties { 231 key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field} 232 233 extProp.RLock() 234 if prop, ok := extProp.m[key]; ok { 235 extProp.RUnlock() 236 return prop 237 } 238 extProp.RUnlock() 239 240 extProp.Lock() 241 defer extProp.Unlock() 242 // Check again. 243 if prop, ok := extProp.m[key]; ok { 244 return prop 245 } 246 247 prop := new(Properties) 248 prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil) 249 extProp.m[key] = prop 250 return prop 251} 252 253// encode encodes any unmarshaled (unencoded) extensions in e. 254func encodeExtensions(e *XXX_InternalExtensions) error { 255 m, mu := e.extensionsRead() 256 if m == nil { 257 return nil // fast path 258 } 259 mu.Lock() 260 defer mu.Unlock() 261 return encodeExtensionsMap(m) 262} 263 264// encode encodes any unmarshaled (unencoded) extensions in e. 265func encodeExtensionsMap(m map[int32]Extension) error { 266 for k, e := range m { 267 if e.value == nil || e.desc == nil { 268 // Extension is only in its encoded form. 269 continue 270 } 271 272 // We don't skip extensions that have an encoded form set, 273 // because the extension value may have been mutated after 274 // the last time this function was called. 275 276 et := reflect.TypeOf(e.desc.ExtensionType) 277 props := extensionProperties(e.desc) 278 279 p := NewBuffer(nil) 280 // If e.value has type T, the encoder expects a *struct{ X T }. 281 // Pass a *T with a zero field and hope it all works out. 282 x := reflect.New(et) 283 x.Elem().Set(reflect.ValueOf(e.value)) 284 if err := props.enc(p, props, toStructPointer(x)); err != nil { 285 return err 286 } 287 e.enc = p.buf 288 m[k] = e 289 } 290 return nil 291} 292 293func extensionsSize(e *XXX_InternalExtensions) (n int) { 294 m, mu := e.extensionsRead() 295 if m == nil { 296 return 0 297 } 298 mu.Lock() 299 defer mu.Unlock() 300 return extensionsMapSize(m) 301} 302 303func extensionsMapSize(m map[int32]Extension) (n int) { 304 for _, e := range m { 305 if e.value == nil || e.desc == nil { 306 // Extension is only in its encoded form. 307 n += len(e.enc) 308 continue 309 } 310 311 // We don't skip extensions that have an encoded form set, 312 // because the extension value may have been mutated after 313 // the last time this function was called. 314 315 et := reflect.TypeOf(e.desc.ExtensionType) 316 props := extensionProperties(e.desc) 317 318 // If e.value has type T, the encoder expects a *struct{ X T }. 319 // Pass a *T with a zero field and hope it all works out. 320 x := reflect.New(et) 321 x.Elem().Set(reflect.ValueOf(e.value)) 322 n += props.size(props, toStructPointer(x)) 323 } 324 return 325} 326 327// HasExtension returns whether the given extension is present in pb. 328func HasExtension(pb Message, extension *ExtensionDesc) bool { 329 // TODO: Check types, field numbers, etc.? 330 epb, ok := extendable(pb) 331 if !ok { 332 return false 333 } 334 extmap, mu := epb.extensionsRead() 335 if extmap == nil { 336 return false 337 } 338 mu.Lock() 339 _, ok = extmap[extension.Field] 340 mu.Unlock() 341 return ok 342} 343 344// ClearExtension removes the given extension from pb. 345func ClearExtension(pb Message, extension *ExtensionDesc) { 346 epb, ok := extendable(pb) 347 if !ok { 348 return 349 } 350 // TODO: Check types, field numbers, etc.? 351 extmap := epb.extensionsWrite() 352 delete(extmap, extension.Field) 353} 354 355// GetExtension parses and returns the given extension of pb. 356// If the extension is not present and has no default value it returns ErrMissingExtension. 357func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { 358 epb, ok := extendable(pb) 359 if !ok { 360 return nil, errors.New("proto: not an extendable proto") 361 } 362 363 if err := checkExtensionTypes(epb, extension); err != nil { 364 return nil, err 365 } 366 367 emap, mu := epb.extensionsRead() 368 if emap == nil { 369 return defaultExtensionValue(extension) 370 } 371 mu.Lock() 372 defer mu.Unlock() 373 e, ok := emap[extension.Field] 374 if !ok { 375 // defaultExtensionValue returns the default value or 376 // ErrMissingExtension if there is no default. 377 return defaultExtensionValue(extension) 378 } 379 380 if e.value != nil { 381 // Already decoded. Check the descriptor, though. 382 if e.desc != extension { 383 // This shouldn't happen. If it does, it means that 384 // GetExtension was called twice with two different 385 // descriptors with the same field number. 386 return nil, errors.New("proto: descriptor conflict") 387 } 388 return e.value, nil 389 } 390 391 v, err := decodeExtension(e.enc, extension) 392 if err != nil { 393 return nil, err 394 } 395 396 // Remember the decoded version and drop the encoded version. 397 // That way it is safe to mutate what we return. 398 e.value = v 399 e.desc = extension 400 e.enc = nil 401 emap[extension.Field] = e 402 return e.value, nil 403} 404 405// defaultExtensionValue returns the default value for extension. 406// If no default for an extension is defined ErrMissingExtension is returned. 407func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) { 408 t := reflect.TypeOf(extension.ExtensionType) 409 props := extensionProperties(extension) 410 411 sf, _, err := fieldDefault(t, props) 412 if err != nil { 413 return nil, err 414 } 415 416 if sf == nil || sf.value == nil { 417 // There is no default value. 418 return nil, ErrMissingExtension 419 } 420 421 if t.Kind() != reflect.Ptr { 422 // We do not need to return a Ptr, we can directly return sf.value. 423 return sf.value, nil 424 } 425 426 // We need to return an interface{} that is a pointer to sf.value. 427 value := reflect.New(t).Elem() 428 value.Set(reflect.New(value.Type().Elem())) 429 if sf.kind == reflect.Int32 { 430 // We may have an int32 or an enum, but the underlying data is int32. 431 // Since we can't set an int32 into a non int32 reflect.value directly 432 // set it as a int32. 433 value.Elem().SetInt(int64(sf.value.(int32))) 434 } else { 435 value.Elem().Set(reflect.ValueOf(sf.value)) 436 } 437 return value.Interface(), nil 438} 439 440// decodeExtension decodes an extension encoded in b. 441func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { 442 o := NewBuffer(b) 443 444 t := reflect.TypeOf(extension.ExtensionType) 445 446 props := extensionProperties(extension) 447 448 // t is a pointer to a struct, pointer to basic type or a slice. 449 // Allocate a "field" to store the pointer/slice itself; the 450 // pointer/slice will be stored here. We pass 451 // the address of this field to props.dec. 452 // This passes a zero field and a *t and lets props.dec 453 // interpret it as a *struct{ x t }. 454 value := reflect.New(t).Elem() 455 456 for { 457 // Discard wire type and field number varint. It isn't needed. 458 if _, err := o.DecodeVarint(); err != nil { 459 return nil, err 460 } 461 462 if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil { 463 return nil, err 464 } 465 466 if o.index >= len(o.buf) { 467 break 468 } 469 } 470 return value.Interface(), nil 471} 472 473// GetExtensions returns a slice of the extensions present in pb that are also listed in es. 474// The returned slice has the same length as es; missing extensions will appear as nil elements. 475func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { 476 epb, ok := extendable(pb) 477 if !ok { 478 return nil, errors.New("proto: not an extendable proto") 479 } 480 extensions = make([]interface{}, len(es)) 481 for i, e := range es { 482 extensions[i], err = GetExtension(epb, e) 483 if err == ErrMissingExtension { 484 err = nil 485 } 486 if err != nil { 487 return 488 } 489 } 490 return 491} 492 493// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order. 494// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing 495// just the Field field, which defines the extension's field number. 496func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) { 497 epb, ok := extendable(pb) 498 if !ok { 499 return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb) 500 } 501 registeredExtensions := RegisteredExtensions(pb) 502 503 emap, mu := epb.extensionsRead() 504 if emap == nil { 505 return nil, nil 506 } 507 mu.Lock() 508 defer mu.Unlock() 509 extensions := make([]*ExtensionDesc, 0, len(emap)) 510 for extid, e := range emap { 511 desc := e.desc 512 if desc == nil { 513 desc = registeredExtensions[extid] 514 if desc == nil { 515 desc = &ExtensionDesc{Field: extid} 516 } 517 } 518 519 extensions = append(extensions, desc) 520 } 521 return extensions, nil 522} 523 524// SetExtension sets the specified extension of pb to the specified value. 525func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { 526 epb, ok := extendable(pb) 527 if !ok { 528 return errors.New("proto: not an extendable proto") 529 } 530 if err := checkExtensionTypes(epb, extension); err != nil { 531 return err 532 } 533 typ := reflect.TypeOf(extension.ExtensionType) 534 if typ != reflect.TypeOf(value) { 535 return errors.New("proto: bad extension value type") 536 } 537 // nil extension values need to be caught early, because the 538 // encoder can't distinguish an ErrNil due to a nil extension 539 // from an ErrNil due to a missing field. Extensions are 540 // always optional, so the encoder would just swallow the error 541 // and drop all the extensions from the encoded message. 542 if reflect.ValueOf(value).IsNil() { 543 return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) 544 } 545 546 extmap := epb.extensionsWrite() 547 extmap[extension.Field] = Extension{desc: extension, value: value} 548 return nil 549} 550 551// ClearAllExtensions clears all extensions from pb. 552func ClearAllExtensions(pb Message) { 553 epb, ok := extendable(pb) 554 if !ok { 555 return 556 } 557 m := epb.extensionsWrite() 558 for k := range m { 559 delete(m, k) 560 } 561} 562 563// A global registry of extensions. 564// The generated code will register the generated descriptors by calling RegisterExtension. 565 566var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc) 567 568// RegisterExtension is called from the generated code. 569func RegisterExtension(desc *ExtensionDesc) { 570 st := reflect.TypeOf(desc.ExtendedType).Elem() 571 m := extensionMaps[st] 572 if m == nil { 573 m = make(map[int32]*ExtensionDesc) 574 extensionMaps[st] = m 575 } 576 if _, ok := m[desc.Field]; ok { 577 panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field))) 578 } 579 m[desc.Field] = desc 580} 581 582// RegisteredExtensions returns a map of the registered extensions of a 583// protocol buffer struct, indexed by the extension number. 584// The argument pb should be a nil pointer to the struct type. 585func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc { 586 return extensionMaps[reflect.TypeOf(pb).Elem()] 587} 588