// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 32// AOSP change: ignore this file, since AOSP does not include 33// golang.org/x/sync/errgroup 34// +build ignore 35 36package proto_test 37 38import ( 39 "bytes" 40 "fmt" 41 "io" 42 "reflect" 43 "sort" 44 "strings" 45 "testing" 46 47 "github.com/golang/protobuf/proto" 48 pb "github.com/golang/protobuf/proto/test_proto" 49 "golang.org/x/sync/errgroup" 50) 51 52func TestGetExtensionsWithMissingExtensions(t *testing.T) { 53 msg := &pb.MyMessage{} 54 ext1 := &pb.Ext{} 55 if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { 56 t.Fatalf("Could not set ext1: %s", err) 57 } 58 exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{ 59 pb.E_Ext_More, 60 pb.E_Ext_Text, 61 }) 62 if err != nil { 63 t.Fatalf("GetExtensions() failed: %s", err) 64 } 65 if exts[0] != ext1 { 66 t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0]) 67 } 68 if exts[1] != nil { 69 t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1]) 70 } 71} 72 73func TestGetExtensionWithEmptyBuffer(t *testing.T) { 74 // Make sure that GetExtension returns an error if its 75 // undecoded buffer is empty. 
76 msg := &pb.MyMessage{} 77 proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{}) 78 _, err := proto.GetExtension(msg, pb.E_Ext_More) 79 if want := io.ErrUnexpectedEOF; err != want { 80 t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want) 81 } 82} 83 84func TestGetExtensionForIncompleteDesc(t *testing.T) { 85 msg := &pb.MyMessage{Count: proto.Int32(0)} 86 extdesc1 := &proto.ExtensionDesc{ 87 ExtendedType: (*pb.MyMessage)(nil), 88 ExtensionType: (*bool)(nil), 89 Field: 123456789, 90 Name: "a.b", 91 Tag: "varint,123456789,opt", 92 } 93 ext1 := proto.Bool(true) 94 if err := proto.SetExtension(msg, extdesc1, ext1); err != nil { 95 t.Fatalf("Could not set ext1: %s", err) 96 } 97 extdesc2 := &proto.ExtensionDesc{ 98 ExtendedType: (*pb.MyMessage)(nil), 99 ExtensionType: ([]byte)(nil), 100 Field: 123456790, 101 Name: "a.c", 102 Tag: "bytes,123456790,opt", 103 } 104 ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7} 105 if err := proto.SetExtension(msg, extdesc2, ext2); err != nil { 106 t.Fatalf("Could not set ext2: %s", err) 107 } 108 extdesc3 := &proto.ExtensionDesc{ 109 ExtendedType: (*pb.MyMessage)(nil), 110 ExtensionType: (*pb.Ext)(nil), 111 Field: 123456791, 112 Name: "a.d", 113 Tag: "bytes,123456791,opt", 114 } 115 ext3 := &pb.Ext{Data: proto.String("foo")} 116 if err := proto.SetExtension(msg, extdesc3, ext3); err != nil { 117 t.Fatalf("Could not set ext3: %s", err) 118 } 119 120 b, err := proto.Marshal(msg) 121 if err != nil { 122 t.Fatalf("Could not marshal msg: %v", err) 123 } 124 if err := proto.Unmarshal(b, msg); err != nil { 125 t.Fatalf("Could not unmarshal into msg: %v", err) 126 } 127 128 var expected proto.Buffer 129 if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil { 130 t.Fatalf("failed to compute expected prefix for ext1: %s", err) 131 } 132 if err := expected.EncodeVarint(1 /* bool true */); err != nil { 133 t.Fatalf("failed to compute expected value for ext1: %s", err) 134 } 
135 136 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil { 137 t.Fatalf("Failed to get raw value for ext1: %s", err) 138 } else if !reflect.DeepEqual(b, expected.Bytes()) { 139 t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes()) 140 } 141 142 expected = proto.Buffer{} // reset 143 if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil { 144 t.Fatalf("failed to compute expected prefix for ext2: %s", err) 145 } 146 if err := expected.EncodeRawBytes(ext2); err != nil { 147 t.Fatalf("failed to compute expected value for ext2: %s", err) 148 } 149 150 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil { 151 t.Fatalf("Failed to get raw value for ext2: %s", err) 152 } else if !reflect.DeepEqual(b, expected.Bytes()) { 153 t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes()) 154 } 155 156 expected = proto.Buffer{} // reset 157 if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil { 158 t.Fatalf("failed to compute expected prefix for ext3: %s", err) 159 } 160 if b, err := proto.Marshal(ext3); err != nil { 161 t.Fatalf("failed to compute expected value for ext3: %s", err) 162 } else if err := expected.EncodeRawBytes(b); err != nil { 163 t.Fatalf("failed to compute expected value for ext3: %s", err) 164 } 165 166 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil { 167 t.Fatalf("Failed to get raw value for ext3: %s", err) 168 } else if !reflect.DeepEqual(b, expected.Bytes()) { 169 t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes()) 170 } 171} 172 173func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) { 174 msg := &pb.MyMessage{Count: proto.Int32(0)} 175 extdesc1 := pb.E_Ext_More 176 if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil { 177 t.Errorf("proto.ExtensionDescs: got %d descs, error 
%v; want 0, nil", len(descs), err) 178 } 179 180 ext1 := &pb.Ext{} 181 if err := proto.SetExtension(msg, extdesc1, ext1); err != nil { 182 t.Fatalf("Could not set ext1: %s", err) 183 } 184 extdesc2 := &proto.ExtensionDesc{ 185 ExtendedType: (*pb.MyMessage)(nil), 186 ExtensionType: (*bool)(nil), 187 Field: 123456789, 188 Name: "a.b", 189 Tag: "varint,123456789,opt", 190 } 191 ext2 := proto.Bool(false) 192 if err := proto.SetExtension(msg, extdesc2, ext2); err != nil { 193 t.Fatalf("Could not set ext2: %s", err) 194 } 195 196 b, err := proto.Marshal(msg) 197 if err != nil { 198 t.Fatalf("Could not marshal msg: %v", err) 199 } 200 if err := proto.Unmarshal(b, msg); err != nil { 201 t.Fatalf("Could not unmarshal into msg: %v", err) 202 } 203 204 descs, err := proto.ExtensionDescs(msg) 205 if err != nil { 206 t.Fatalf("proto.ExtensionDescs: got error %v", err) 207 } 208 sortExtDescs(descs) 209 wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}} 210 if !reflect.DeepEqual(descs, wantDescs) { 211 t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs) 212 } 213} 214 215type ExtensionDescSlice []*proto.ExtensionDesc 216 217func (s ExtensionDescSlice) Len() int { return len(s) } 218func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field } 219func (s ExtensionDescSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 220 221func sortExtDescs(s []*proto.ExtensionDesc) { 222 sort.Sort(ExtensionDescSlice(s)) 223} 224 225func TestGetExtensionStability(t *testing.T) { 226 check := func(m *pb.MyMessage) bool { 227 ext1, err := proto.GetExtension(m, pb.E_Ext_More) 228 if err != nil { 229 t.Fatalf("GetExtension() failed: %s", err) 230 } 231 ext2, err := proto.GetExtension(m, pb.E_Ext_More) 232 if err != nil { 233 t.Fatalf("GetExtension() failed: %s", err) 234 } 235 return ext1 == ext2 236 } 237 msg := &pb.MyMessage{Count: proto.Int32(4)} 238 ext0 := &pb.Ext{} 239 if err := proto.SetExtension(msg, 
pb.E_Ext_More, ext0); err != nil { 240 t.Fatalf("Could not set ext1: %s", ext0) 241 } 242 if !check(msg) { 243 t.Errorf("GetExtension() not stable before marshaling") 244 } 245 bb, err := proto.Marshal(msg) 246 if err != nil { 247 t.Fatalf("Marshal() failed: %s", err) 248 } 249 msg1 := &pb.MyMessage{} 250 err = proto.Unmarshal(bb, msg1) 251 if err != nil { 252 t.Fatalf("Unmarshal() failed: %s", err) 253 } 254 if !check(msg1) { 255 t.Errorf("GetExtension() not stable after unmarshaling") 256 } 257} 258 259func TestGetExtensionDefaults(t *testing.T) { 260 var setFloat64 float64 = 1 261 var setFloat32 float32 = 2 262 var setInt32 int32 = 3 263 var setInt64 int64 = 4 264 var setUint32 uint32 = 5 265 var setUint64 uint64 = 6 266 var setBool = true 267 var setBool2 = false 268 var setString = "Goodnight string" 269 var setBytes = []byte("Goodnight bytes") 270 var setEnum = pb.DefaultsMessage_TWO 271 272 type testcase struct { 273 ext *proto.ExtensionDesc // Extension we are testing. 274 want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail). 275 def interface{} // Expected value of extension after ClearExtension(). 
276 } 277 tests := []testcase{ 278 {pb.E_NoDefaultDouble, setFloat64, nil}, 279 {pb.E_NoDefaultFloat, setFloat32, nil}, 280 {pb.E_NoDefaultInt32, setInt32, nil}, 281 {pb.E_NoDefaultInt64, setInt64, nil}, 282 {pb.E_NoDefaultUint32, setUint32, nil}, 283 {pb.E_NoDefaultUint64, setUint64, nil}, 284 {pb.E_NoDefaultSint32, setInt32, nil}, 285 {pb.E_NoDefaultSint64, setInt64, nil}, 286 {pb.E_NoDefaultFixed32, setUint32, nil}, 287 {pb.E_NoDefaultFixed64, setUint64, nil}, 288 {pb.E_NoDefaultSfixed32, setInt32, nil}, 289 {pb.E_NoDefaultSfixed64, setInt64, nil}, 290 {pb.E_NoDefaultBool, setBool, nil}, 291 {pb.E_NoDefaultBool, setBool2, nil}, 292 {pb.E_NoDefaultString, setString, nil}, 293 {pb.E_NoDefaultBytes, setBytes, nil}, 294 {pb.E_NoDefaultEnum, setEnum, nil}, 295 {pb.E_DefaultDouble, setFloat64, float64(3.1415)}, 296 {pb.E_DefaultFloat, setFloat32, float32(3.14)}, 297 {pb.E_DefaultInt32, setInt32, int32(42)}, 298 {pb.E_DefaultInt64, setInt64, int64(43)}, 299 {pb.E_DefaultUint32, setUint32, uint32(44)}, 300 {pb.E_DefaultUint64, setUint64, uint64(45)}, 301 {pb.E_DefaultSint32, setInt32, int32(46)}, 302 {pb.E_DefaultSint64, setInt64, int64(47)}, 303 {pb.E_DefaultFixed32, setUint32, uint32(48)}, 304 {pb.E_DefaultFixed64, setUint64, uint64(49)}, 305 {pb.E_DefaultSfixed32, setInt32, int32(50)}, 306 {pb.E_DefaultSfixed64, setInt64, int64(51)}, 307 {pb.E_DefaultBool, setBool, true}, 308 {pb.E_DefaultBool, setBool2, true}, 309 {pb.E_DefaultString, setString, "Hello, string,def=foo"}, 310 {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")}, 311 {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE}, 312 } 313 314 checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error { 315 val, err := proto.GetExtension(msg, test.ext) 316 if err != nil { 317 if valWant != nil { 318 return fmt.Errorf("GetExtension(): %s", err) 319 } 320 if want := proto.ErrMissingExtension; err != want { 321 return fmt.Errorf("Unexpected error: got %v, want %v", err, want) 322 } 323 
return nil 324 } 325 326 // All proto2 extension values are either a pointer to a value or a slice of values. 327 ty := reflect.TypeOf(val) 328 tyWant := reflect.TypeOf(test.ext.ExtensionType) 329 if got, want := ty, tyWant; got != want { 330 return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want) 331 } 332 tye := ty.Elem() 333 tyeWant := tyWant.Elem() 334 if got, want := tye, tyeWant; got != want { 335 return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want) 336 } 337 338 // Check the name of the type of the value. 339 // If it is an enum it will be type int32 with the name of the enum. 340 if got, want := tye.Name(), tye.Name(); got != want { 341 return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want) 342 } 343 344 // Check that value is what we expect. 345 // If we have a pointer in val, get the value it points to. 346 valExp := val 347 if ty.Kind() == reflect.Ptr { 348 valExp = reflect.ValueOf(val).Elem().Interface() 349 } 350 if got, want := valExp, valWant; !reflect.DeepEqual(got, want) { 351 return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want) 352 } 353 354 return nil 355 } 356 357 setTo := func(test testcase) interface{} { 358 setTo := reflect.ValueOf(test.want) 359 if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr { 360 setTo = reflect.New(typ).Elem() 361 setTo.Set(reflect.New(setTo.Type().Elem())) 362 setTo.Elem().Set(reflect.ValueOf(test.want)) 363 } 364 return setTo.Interface() 365 } 366 367 for _, test := range tests { 368 msg := &pb.DefaultsMessage{} 369 name := test.ext.Name 370 371 // Check the initial value. 372 if err := checkVal(test, msg, test.def); err != nil { 373 t.Errorf("%s: %v", name, err) 374 } 375 376 // Set the per-type value and check value. 
377 name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want) 378 if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil { 379 t.Errorf("%s: SetExtension(): %v", name, err) 380 continue 381 } 382 if err := checkVal(test, msg, test.want); err != nil { 383 t.Errorf("%s: %v", name, err) 384 continue 385 } 386 387 // Set and check the value. 388 name += " (cleared)" 389 proto.ClearExtension(msg, test.ext) 390 if err := checkVal(test, msg, test.def); err != nil { 391 t.Errorf("%s: %v", name, err) 392 } 393 } 394} 395 396func TestNilMessage(t *testing.T) { 397 name := "nil interface" 398 if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil { 399 t.Errorf("%s: got %T %v, expected to fail", name, got, got) 400 } else if !strings.Contains(err.Error(), "extendable") { 401 t.Errorf("%s: got error %v, expected not-extendable error", name, err) 402 } 403 404 // Regression tests: all functions of the Extension API 405 // used to panic when passed (*M)(nil), where M is a concrete message 406 // type. Now they handle this gracefully as a no-op or reported error. 
407 var nilMsg *pb.MyMessage 408 desc := pb.E_Ext_More 409 410 isNotExtendable := func(err error) bool { 411 return strings.Contains(fmt.Sprint(err), "not extendable") 412 } 413 414 if proto.HasExtension(nilMsg, desc) { 415 t.Error("HasExtension(nil) = true") 416 } 417 418 if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) { 419 t.Errorf("GetExtensions(nil) = %q (wrong error)", err) 420 } 421 422 if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) { 423 t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err) 424 } 425 426 if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) { 427 t.Errorf("SetExtension(nil) = %q (wrong error)", err) 428 } 429 430 proto.ClearExtension(nilMsg, desc) // no-op 431 proto.ClearAllExtensions(nilMsg) // no-op 432} 433 434func TestExtensionsRoundTrip(t *testing.T) { 435 msg := &pb.MyMessage{} 436 ext1 := &pb.Ext{ 437 Data: proto.String("hi"), 438 } 439 ext2 := &pb.Ext{ 440 Data: proto.String("there"), 441 } 442 exists := proto.HasExtension(msg, pb.E_Ext_More) 443 if exists { 444 t.Error("Extension More present unexpectedly") 445 } 446 if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { 447 t.Error(err) 448 } 449 if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil { 450 t.Error(err) 451 } 452 e, err := proto.GetExtension(msg, pb.E_Ext_More) 453 if err != nil { 454 t.Error(err) 455 } 456 x, ok := e.(*pb.Ext) 457 if !ok { 458 t.Errorf("e has type %T, expected test_proto.Ext", e) 459 } else if *x.Data != "there" { 460 t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x) 461 } 462 proto.ClearExtension(msg, pb.E_Ext_More) 463 if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension { 464 t.Errorf("got %v, expected ErrMissingExtension", e) 465 } 466 if _, err := proto.GetExtension(msg, pb.E_X215); err == nil { 467 t.Error("expected bad extension error, got nil") 468 } 469 if err := 
proto.SetExtension(msg, pb.E_X215, 12); err == nil { 470 t.Error("expected extension err") 471 } 472 if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil { 473 t.Error("expected some sort of type mismatch error, got nil") 474 } 475} 476 477func TestNilExtension(t *testing.T) { 478 msg := &pb.MyMessage{ 479 Count: proto.Int32(1), 480 } 481 if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil { 482 t.Fatal(err) 483 } 484 if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil { 485 t.Error("expected SetExtension to fail due to a nil extension") 486 } else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want { 487 t.Errorf("expected error %v, got %v", want, err) 488 } 489 // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update 490 // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal. 491} 492 493func TestMarshalUnmarshalRepeatedExtension(t *testing.T) { 494 // Add a repeated extension to the result. 495 tests := []struct { 496 name string 497 ext []*pb.ComplexExtension 498 }{ 499 { 500 "two fields", 501 []*pb.ComplexExtension{ 502 {First: proto.Int32(7)}, 503 {Second: proto.Int32(11)}, 504 }, 505 }, 506 { 507 "repeated field", 508 []*pb.ComplexExtension{ 509 {Third: []int32{1000}}, 510 {Third: []int32{2000}}, 511 }, 512 }, 513 { 514 "two fields and repeated field", 515 []*pb.ComplexExtension{ 516 {Third: []int32{1000}}, 517 {First: proto.Int32(9)}, 518 {Second: proto.Int32(21)}, 519 {Third: []int32{2000}}, 520 }, 521 }, 522 } 523 for _, test := range tests { 524 // Marshal message with a repeated extension. 
525 msg1 := new(pb.OtherMessage) 526 err := proto.SetExtension(msg1, pb.E_RComplex, test.ext) 527 if err != nil { 528 t.Fatalf("[%s] Error setting extension: %v", test.name, err) 529 } 530 b, err := proto.Marshal(msg1) 531 if err != nil { 532 t.Fatalf("[%s] Error marshaling message: %v", test.name, err) 533 } 534 535 // Unmarshal and read the merged proto. 536 msg2 := new(pb.OtherMessage) 537 err = proto.Unmarshal(b, msg2) 538 if err != nil { 539 t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err) 540 } 541 e, err := proto.GetExtension(msg2, pb.E_RComplex) 542 if err != nil { 543 t.Fatalf("[%s] Error getting extension: %v", test.name, err) 544 } 545 ext := e.([]*pb.ComplexExtension) 546 if ext == nil { 547 t.Fatalf("[%s] Invalid extension", test.name) 548 } 549 if len(ext) != len(test.ext) { 550 t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext)) 551 } 552 for i := range test.ext { 553 if !proto.Equal(ext[i], test.ext[i]) { 554 t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i]) 555 } 556 } 557 } 558} 559 560func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) { 561 // We may see multiple instances of the same extension in the wire 562 // format. For example, the proto compiler may encode custom options in 563 // this way. Here, we verify that we merge the extensions together. 
564 tests := []struct { 565 name string 566 ext []*pb.ComplexExtension 567 }{ 568 { 569 "two fields", 570 []*pb.ComplexExtension{ 571 {First: proto.Int32(7)}, 572 {Second: proto.Int32(11)}, 573 }, 574 }, 575 { 576 "repeated field", 577 []*pb.ComplexExtension{ 578 {Third: []int32{1000}}, 579 {Third: []int32{2000}}, 580 }, 581 }, 582 { 583 "two fields and repeated field", 584 []*pb.ComplexExtension{ 585 {Third: []int32{1000}}, 586 {First: proto.Int32(9)}, 587 {Second: proto.Int32(21)}, 588 {Third: []int32{2000}}, 589 }, 590 }, 591 } 592 for _, test := range tests { 593 var buf bytes.Buffer 594 var want pb.ComplexExtension 595 596 // Generate a serialized representation of a repeated extension 597 // by catenating bytes together. 598 for i, e := range test.ext { 599 // Merge to create the wanted proto. 600 proto.Merge(&want, e) 601 602 // serialize the message 603 msg := new(pb.OtherMessage) 604 err := proto.SetExtension(msg, pb.E_Complex, e) 605 if err != nil { 606 t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err) 607 } 608 b, err := proto.Marshal(msg) 609 if err != nil { 610 t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err) 611 } 612 buf.Write(b) 613 } 614 615 // Unmarshal and read the merged proto. 
616 msg2 := new(pb.OtherMessage) 617 err := proto.Unmarshal(buf.Bytes(), msg2) 618 if err != nil { 619 t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err) 620 } 621 e, err := proto.GetExtension(msg2, pb.E_Complex) 622 if err != nil { 623 t.Fatalf("[%s] Error getting extension: %v", test.name, err) 624 } 625 ext := e.(*pb.ComplexExtension) 626 if ext == nil { 627 t.Fatalf("[%s] Invalid extension", test.name) 628 } 629 if !proto.Equal(ext, &want) { 630 t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, &want) 631 } 632 } 633} 634 635func TestClearAllExtensions(t *testing.T) { 636 // unregistered extension 637 desc := &proto.ExtensionDesc{ 638 ExtendedType: (*pb.MyMessage)(nil), 639 ExtensionType: (*bool)(nil), 640 Field: 101010100, 641 Name: "emptyextension", 642 Tag: "varint,0,opt", 643 } 644 m := &pb.MyMessage{} 645 if proto.HasExtension(m, desc) { 646 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m)) 647 } 648 if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil { 649 t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err) 650 } 651 if !proto.HasExtension(m, desc) { 652 t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m)) 653 } 654 proto.ClearAllExtensions(m) 655 if proto.HasExtension(m, desc) { 656 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m)) 657 } 658} 659 660func TestMarshalRace(t *testing.T) { 661 ext := &pb.Ext{} 662 m := &pb.MyMessage{Count: proto.Int32(4)} 663 if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil { 664 t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err) 665 } 666 667 b, err := proto.Marshal(m) 668 if err != nil { 669 t.Fatalf("Could not marshal message: %v", err) 670 } 671 if err := proto.Unmarshal(b, m); err != nil { 672 t.Fatalf("Could not unmarshal message: %v", err) 673 } 674 // after Unmarshal, the extension is in 
undecoded form. 675 // GetExtension will decode it lazily. Make sure this does 676 // not race against Marshal. 677 678 var g errgroup.Group 679 for n := 3; n > 0; n-- { 680 g.Go(func() error { 681 _, err := proto.Marshal(m) 682 return err 683 }) 684 g.Go(func() error { 685 _, err := proto.GetExtension(m, pb.E_Ext_More) 686 return err 687 }) 688 } 689 if err := g.Wait(); err != nil { 690 t.Fatal(err) 691 } 692} 693