• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Go support for Protocol Buffers - Google's data interchange format
2//
3// Copyright 2014 The Go Authors.  All rights reserved.
4// https://github.com/golang/protobuf
5//
6// Redistribution and use in source and binary forms, with or without
7// modification, are permitted provided that the following conditions are
8// met:
9//
10//     * Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12//     * Redistributions in binary form must reproduce the above
13// copyright notice, this list of conditions and the following disclaimer
14// in the documentation and/or other materials provided with the
15// distribution.
16//     * Neither the name of Google Inc. nor the names of its
17// contributors may be used to endorse or promote products derived from
18// this software without specific prior written permission.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32// AOSP change: ignore this file, since AOSP does not include
33// golang.org/x/sync/errgroup
34// +build ignore
35
36package proto_test
37
38import (
39	"bytes"
40	"fmt"
41	"io"
42	"reflect"
43	"sort"
44	"strings"
45	"testing"
46
47	"github.com/golang/protobuf/proto"
48	pb "github.com/golang/protobuf/proto/test_proto"
49	"golang.org/x/sync/errgroup"
50)
51
52func TestGetExtensionsWithMissingExtensions(t *testing.T) {
53	msg := &pb.MyMessage{}
54	ext1 := &pb.Ext{}
55	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
56		t.Fatalf("Could not set ext1: %s", err)
57	}
58	exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
59		pb.E_Ext_More,
60		pb.E_Ext_Text,
61	})
62	if err != nil {
63		t.Fatalf("GetExtensions() failed: %s", err)
64	}
65	if exts[0] != ext1 {
66		t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
67	}
68	if exts[1] != nil {
69		t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
70	}
71}
72
73func TestGetExtensionWithEmptyBuffer(t *testing.T) {
74	// Make sure that GetExtension returns an error if its
75	// undecoded buffer is empty.
76	msg := &pb.MyMessage{}
77	proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{})
78	_, err := proto.GetExtension(msg, pb.E_Ext_More)
79	if want := io.ErrUnexpectedEOF; err != want {
80		t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want)
81	}
82}
83
84func TestGetExtensionForIncompleteDesc(t *testing.T) {
85	msg := &pb.MyMessage{Count: proto.Int32(0)}
86	extdesc1 := &proto.ExtensionDesc{
87		ExtendedType:  (*pb.MyMessage)(nil),
88		ExtensionType: (*bool)(nil),
89		Field:         123456789,
90		Name:          "a.b",
91		Tag:           "varint,123456789,opt",
92	}
93	ext1 := proto.Bool(true)
94	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
95		t.Fatalf("Could not set ext1: %s", err)
96	}
97	extdesc2 := &proto.ExtensionDesc{
98		ExtendedType:  (*pb.MyMessage)(nil),
99		ExtensionType: ([]byte)(nil),
100		Field:         123456790,
101		Name:          "a.c",
102		Tag:           "bytes,123456790,opt",
103	}
104	ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7}
105	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
106		t.Fatalf("Could not set ext2: %s", err)
107	}
108	extdesc3 := &proto.ExtensionDesc{
109		ExtendedType:  (*pb.MyMessage)(nil),
110		ExtensionType: (*pb.Ext)(nil),
111		Field:         123456791,
112		Name:          "a.d",
113		Tag:           "bytes,123456791,opt",
114	}
115	ext3 := &pb.Ext{Data: proto.String("foo")}
116	if err := proto.SetExtension(msg, extdesc3, ext3); err != nil {
117		t.Fatalf("Could not set ext3: %s", err)
118	}
119
120	b, err := proto.Marshal(msg)
121	if err != nil {
122		t.Fatalf("Could not marshal msg: %v", err)
123	}
124	if err := proto.Unmarshal(b, msg); err != nil {
125		t.Fatalf("Could not unmarshal into msg: %v", err)
126	}
127
128	var expected proto.Buffer
129	if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil {
130		t.Fatalf("failed to compute expected prefix for ext1: %s", err)
131	}
132	if err := expected.EncodeVarint(1 /* bool true */); err != nil {
133		t.Fatalf("failed to compute expected value for ext1: %s", err)
134	}
135
136	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil {
137		t.Fatalf("Failed to get raw value for ext1: %s", err)
138	} else if !reflect.DeepEqual(b, expected.Bytes()) {
139		t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes())
140	}
141
142	expected = proto.Buffer{} // reset
143	if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil {
144		t.Fatalf("failed to compute expected prefix for ext2: %s", err)
145	}
146	if err := expected.EncodeRawBytes(ext2); err != nil {
147		t.Fatalf("failed to compute expected value for ext2: %s", err)
148	}
149
150	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil {
151		t.Fatalf("Failed to get raw value for ext2: %s", err)
152	} else if !reflect.DeepEqual(b, expected.Bytes()) {
153		t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes())
154	}
155
156	expected = proto.Buffer{} // reset
157	if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil {
158		t.Fatalf("failed to compute expected prefix for ext3: %s", err)
159	}
160	if b, err := proto.Marshal(ext3); err != nil {
161		t.Fatalf("failed to compute expected value for ext3: %s", err)
162	} else if err := expected.EncodeRawBytes(b); err != nil {
163		t.Fatalf("failed to compute expected value for ext3: %s", err)
164	}
165
166	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil {
167		t.Fatalf("Failed to get raw value for ext3: %s", err)
168	} else if !reflect.DeepEqual(b, expected.Bytes()) {
169		t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes())
170	}
171}
172
173func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) {
174	msg := &pb.MyMessage{Count: proto.Int32(0)}
175	extdesc1 := pb.E_Ext_More
176	if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
177		t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
178	}
179
180	ext1 := &pb.Ext{}
181	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
182		t.Fatalf("Could not set ext1: %s", err)
183	}
184	extdesc2 := &proto.ExtensionDesc{
185		ExtendedType:  (*pb.MyMessage)(nil),
186		ExtensionType: (*bool)(nil),
187		Field:         123456789,
188		Name:          "a.b",
189		Tag:           "varint,123456789,opt",
190	}
191	ext2 := proto.Bool(false)
192	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
193		t.Fatalf("Could not set ext2: %s", err)
194	}
195
196	b, err := proto.Marshal(msg)
197	if err != nil {
198		t.Fatalf("Could not marshal msg: %v", err)
199	}
200	if err := proto.Unmarshal(b, msg); err != nil {
201		t.Fatalf("Could not unmarshal into msg: %v", err)
202	}
203
204	descs, err := proto.ExtensionDescs(msg)
205	if err != nil {
206		t.Fatalf("proto.ExtensionDescs: got error %v", err)
207	}
208	sortExtDescs(descs)
209	wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}}
210	if !reflect.DeepEqual(descs, wantDescs) {
211		t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
212	}
213}
214
215type ExtensionDescSlice []*proto.ExtensionDesc
216
217func (s ExtensionDescSlice) Len() int           { return len(s) }
218func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
219func (s ExtensionDescSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
220
221func sortExtDescs(s []*proto.ExtensionDesc) {
222	sort.Sort(ExtensionDescSlice(s))
223}
224
225func TestGetExtensionStability(t *testing.T) {
226	check := func(m *pb.MyMessage) bool {
227		ext1, err := proto.GetExtension(m, pb.E_Ext_More)
228		if err != nil {
229			t.Fatalf("GetExtension() failed: %s", err)
230		}
231		ext2, err := proto.GetExtension(m, pb.E_Ext_More)
232		if err != nil {
233			t.Fatalf("GetExtension() failed: %s", err)
234		}
235		return ext1 == ext2
236	}
237	msg := &pb.MyMessage{Count: proto.Int32(4)}
238	ext0 := &pb.Ext{}
239	if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
240		t.Fatalf("Could not set ext1: %s", ext0)
241	}
242	if !check(msg) {
243		t.Errorf("GetExtension() not stable before marshaling")
244	}
245	bb, err := proto.Marshal(msg)
246	if err != nil {
247		t.Fatalf("Marshal() failed: %s", err)
248	}
249	msg1 := &pb.MyMessage{}
250	err = proto.Unmarshal(bb, msg1)
251	if err != nil {
252		t.Fatalf("Unmarshal() failed: %s", err)
253	}
254	if !check(msg1) {
255		t.Errorf("GetExtension() not stable after unmarshaling")
256	}
257}
258
259func TestGetExtensionDefaults(t *testing.T) {
260	var setFloat64 float64 = 1
261	var setFloat32 float32 = 2
262	var setInt32 int32 = 3
263	var setInt64 int64 = 4
264	var setUint32 uint32 = 5
265	var setUint64 uint64 = 6
266	var setBool = true
267	var setBool2 = false
268	var setString = "Goodnight string"
269	var setBytes = []byte("Goodnight bytes")
270	var setEnum = pb.DefaultsMessage_TWO
271
272	type testcase struct {
273		ext  *proto.ExtensionDesc // Extension we are testing.
274		want interface{}          // Expected value of extension, or nil (meaning that GetExtension will fail).
275		def  interface{}          // Expected value of extension after ClearExtension().
276	}
277	tests := []testcase{
278		{pb.E_NoDefaultDouble, setFloat64, nil},
279		{pb.E_NoDefaultFloat, setFloat32, nil},
280		{pb.E_NoDefaultInt32, setInt32, nil},
281		{pb.E_NoDefaultInt64, setInt64, nil},
282		{pb.E_NoDefaultUint32, setUint32, nil},
283		{pb.E_NoDefaultUint64, setUint64, nil},
284		{pb.E_NoDefaultSint32, setInt32, nil},
285		{pb.E_NoDefaultSint64, setInt64, nil},
286		{pb.E_NoDefaultFixed32, setUint32, nil},
287		{pb.E_NoDefaultFixed64, setUint64, nil},
288		{pb.E_NoDefaultSfixed32, setInt32, nil},
289		{pb.E_NoDefaultSfixed64, setInt64, nil},
290		{pb.E_NoDefaultBool, setBool, nil},
291		{pb.E_NoDefaultBool, setBool2, nil},
292		{pb.E_NoDefaultString, setString, nil},
293		{pb.E_NoDefaultBytes, setBytes, nil},
294		{pb.E_NoDefaultEnum, setEnum, nil},
295		{pb.E_DefaultDouble, setFloat64, float64(3.1415)},
296		{pb.E_DefaultFloat, setFloat32, float32(3.14)},
297		{pb.E_DefaultInt32, setInt32, int32(42)},
298		{pb.E_DefaultInt64, setInt64, int64(43)},
299		{pb.E_DefaultUint32, setUint32, uint32(44)},
300		{pb.E_DefaultUint64, setUint64, uint64(45)},
301		{pb.E_DefaultSint32, setInt32, int32(46)},
302		{pb.E_DefaultSint64, setInt64, int64(47)},
303		{pb.E_DefaultFixed32, setUint32, uint32(48)},
304		{pb.E_DefaultFixed64, setUint64, uint64(49)},
305		{pb.E_DefaultSfixed32, setInt32, int32(50)},
306		{pb.E_DefaultSfixed64, setInt64, int64(51)},
307		{pb.E_DefaultBool, setBool, true},
308		{pb.E_DefaultBool, setBool2, true},
309		{pb.E_DefaultString, setString, "Hello, string,def=foo"},
310		{pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
311		{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
312	}
313
314	checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
315		val, err := proto.GetExtension(msg, test.ext)
316		if err != nil {
317			if valWant != nil {
318				return fmt.Errorf("GetExtension(): %s", err)
319			}
320			if want := proto.ErrMissingExtension; err != want {
321				return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
322			}
323			return nil
324		}
325
326		// All proto2 extension values are either a pointer to a value or a slice of values.
327		ty := reflect.TypeOf(val)
328		tyWant := reflect.TypeOf(test.ext.ExtensionType)
329		if got, want := ty, tyWant; got != want {
330			return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
331		}
332		tye := ty.Elem()
333		tyeWant := tyWant.Elem()
334		if got, want := tye, tyeWant; got != want {
335			return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
336		}
337
338		// Check the name of the type of the value.
339		// If it is an enum it will be type int32 with the name of the enum.
340		if got, want := tye.Name(), tye.Name(); got != want {
341			return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
342		}
343
344		// Check that value is what we expect.
345		// If we have a pointer in val, get the value it points to.
346		valExp := val
347		if ty.Kind() == reflect.Ptr {
348			valExp = reflect.ValueOf(val).Elem().Interface()
349		}
350		if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
351			return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
352		}
353
354		return nil
355	}
356
357	setTo := func(test testcase) interface{} {
358		setTo := reflect.ValueOf(test.want)
359		if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
360			setTo = reflect.New(typ).Elem()
361			setTo.Set(reflect.New(setTo.Type().Elem()))
362			setTo.Elem().Set(reflect.ValueOf(test.want))
363		}
364		return setTo.Interface()
365	}
366
367	for _, test := range tests {
368		msg := &pb.DefaultsMessage{}
369		name := test.ext.Name
370
371		// Check the initial value.
372		if err := checkVal(test, msg, test.def); err != nil {
373			t.Errorf("%s: %v", name, err)
374		}
375
376		// Set the per-type value and check value.
377		name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
378		if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
379			t.Errorf("%s: SetExtension(): %v", name, err)
380			continue
381		}
382		if err := checkVal(test, msg, test.want); err != nil {
383			t.Errorf("%s: %v", name, err)
384			continue
385		}
386
387		// Set and check the value.
388		name += " (cleared)"
389		proto.ClearExtension(msg, test.ext)
390		if err := checkVal(test, msg, test.def); err != nil {
391			t.Errorf("%s: %v", name, err)
392		}
393	}
394}
395
396func TestNilMessage(t *testing.T) {
397	name := "nil interface"
398	if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil {
399		t.Errorf("%s: got %T %v, expected to fail", name, got, got)
400	} else if !strings.Contains(err.Error(), "extendable") {
401		t.Errorf("%s: got error %v, expected not-extendable error", name, err)
402	}
403
404	// Regression tests: all functions of the Extension API
405	// used to panic when passed (*M)(nil), where M is a concrete message
406	// type.  Now they handle this gracefully as a no-op or reported error.
407	var nilMsg *pb.MyMessage
408	desc := pb.E_Ext_More
409
410	isNotExtendable := func(err error) bool {
411		return strings.Contains(fmt.Sprint(err), "not extendable")
412	}
413
414	if proto.HasExtension(nilMsg, desc) {
415		t.Error("HasExtension(nil) = true")
416	}
417
418	if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) {
419		t.Errorf("GetExtensions(nil) = %q (wrong error)", err)
420	}
421
422	if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) {
423		t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err)
424	}
425
426	if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) {
427		t.Errorf("SetExtension(nil) = %q (wrong error)", err)
428	}
429
430	proto.ClearExtension(nilMsg, desc) // no-op
431	proto.ClearAllExtensions(nilMsg)   // no-op
432}
433
434func TestExtensionsRoundTrip(t *testing.T) {
435	msg := &pb.MyMessage{}
436	ext1 := &pb.Ext{
437		Data: proto.String("hi"),
438	}
439	ext2 := &pb.Ext{
440		Data: proto.String("there"),
441	}
442	exists := proto.HasExtension(msg, pb.E_Ext_More)
443	if exists {
444		t.Error("Extension More present unexpectedly")
445	}
446	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
447		t.Error(err)
448	}
449	if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
450		t.Error(err)
451	}
452	e, err := proto.GetExtension(msg, pb.E_Ext_More)
453	if err != nil {
454		t.Error(err)
455	}
456	x, ok := e.(*pb.Ext)
457	if !ok {
458		t.Errorf("e has type %T, expected test_proto.Ext", e)
459	} else if *x.Data != "there" {
460		t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
461	}
462	proto.ClearExtension(msg, pb.E_Ext_More)
463	if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
464		t.Errorf("got %v, expected ErrMissingExtension", e)
465	}
466	if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
467		t.Error("expected bad extension error, got nil")
468	}
469	if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
470		t.Error("expected extension err")
471	}
472	if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
473		t.Error("expected some sort of type mismatch error, got nil")
474	}
475}
476
477func TestNilExtension(t *testing.T) {
478	msg := &pb.MyMessage{
479		Count: proto.Int32(1),
480	}
481	if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
482		t.Fatal(err)
483	}
484	if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
485		t.Error("expected SetExtension to fail due to a nil extension")
486	} else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want {
487		t.Errorf("expected error %v, got %v", want, err)
488	}
489	// Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
490	// this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
491}
492
493func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
494	// Add a repeated extension to the result.
495	tests := []struct {
496		name string
497		ext  []*pb.ComplexExtension
498	}{
499		{
500			"two fields",
501			[]*pb.ComplexExtension{
502				{First: proto.Int32(7)},
503				{Second: proto.Int32(11)},
504			},
505		},
506		{
507			"repeated field",
508			[]*pb.ComplexExtension{
509				{Third: []int32{1000}},
510				{Third: []int32{2000}},
511			},
512		},
513		{
514			"two fields and repeated field",
515			[]*pb.ComplexExtension{
516				{Third: []int32{1000}},
517				{First: proto.Int32(9)},
518				{Second: proto.Int32(21)},
519				{Third: []int32{2000}},
520			},
521		},
522	}
523	for _, test := range tests {
524		// Marshal message with a repeated extension.
525		msg1 := new(pb.OtherMessage)
526		err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
527		if err != nil {
528			t.Fatalf("[%s] Error setting extension: %v", test.name, err)
529		}
530		b, err := proto.Marshal(msg1)
531		if err != nil {
532			t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
533		}
534
535		// Unmarshal and read the merged proto.
536		msg2 := new(pb.OtherMessage)
537		err = proto.Unmarshal(b, msg2)
538		if err != nil {
539			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
540		}
541		e, err := proto.GetExtension(msg2, pb.E_RComplex)
542		if err != nil {
543			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
544		}
545		ext := e.([]*pb.ComplexExtension)
546		if ext == nil {
547			t.Fatalf("[%s] Invalid extension", test.name)
548		}
549		if len(ext) != len(test.ext) {
550			t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext))
551		}
552		for i := range test.ext {
553			if !proto.Equal(ext[i], test.ext[i]) {
554				t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i])
555			}
556		}
557	}
558}
559
560func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
561	// We may see multiple instances of the same extension in the wire
562	// format. For example, the proto compiler may encode custom options in
563	// this way. Here, we verify that we merge the extensions together.
564	tests := []struct {
565		name string
566		ext  []*pb.ComplexExtension
567	}{
568		{
569			"two fields",
570			[]*pb.ComplexExtension{
571				{First: proto.Int32(7)},
572				{Second: proto.Int32(11)},
573			},
574		},
575		{
576			"repeated field",
577			[]*pb.ComplexExtension{
578				{Third: []int32{1000}},
579				{Third: []int32{2000}},
580			},
581		},
582		{
583			"two fields and repeated field",
584			[]*pb.ComplexExtension{
585				{Third: []int32{1000}},
586				{First: proto.Int32(9)},
587				{Second: proto.Int32(21)},
588				{Third: []int32{2000}},
589			},
590		},
591	}
592	for _, test := range tests {
593		var buf bytes.Buffer
594		var want pb.ComplexExtension
595
596		// Generate a serialized representation of a repeated extension
597		// by catenating bytes together.
598		for i, e := range test.ext {
599			// Merge to create the wanted proto.
600			proto.Merge(&want, e)
601
602			// serialize the message
603			msg := new(pb.OtherMessage)
604			err := proto.SetExtension(msg, pb.E_Complex, e)
605			if err != nil {
606				t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
607			}
608			b, err := proto.Marshal(msg)
609			if err != nil {
610				t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
611			}
612			buf.Write(b)
613		}
614
615		// Unmarshal and read the merged proto.
616		msg2 := new(pb.OtherMessage)
617		err := proto.Unmarshal(buf.Bytes(), msg2)
618		if err != nil {
619			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
620		}
621		e, err := proto.GetExtension(msg2, pb.E_Complex)
622		if err != nil {
623			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
624		}
625		ext := e.(*pb.ComplexExtension)
626		if ext == nil {
627			t.Fatalf("[%s] Invalid extension", test.name)
628		}
629		if !proto.Equal(ext, &want) {
630			t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, &want)
631		}
632	}
633}
634
635func TestClearAllExtensions(t *testing.T) {
636	// unregistered extension
637	desc := &proto.ExtensionDesc{
638		ExtendedType:  (*pb.MyMessage)(nil),
639		ExtensionType: (*bool)(nil),
640		Field:         101010100,
641		Name:          "emptyextension",
642		Tag:           "varint,0,opt",
643	}
644	m := &pb.MyMessage{}
645	if proto.HasExtension(m, desc) {
646		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
647	}
648	if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
649		t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
650	}
651	if !proto.HasExtension(m, desc) {
652		t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
653	}
654	proto.ClearAllExtensions(m)
655	if proto.HasExtension(m, desc) {
656		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
657	}
658}
659
660func TestMarshalRace(t *testing.T) {
661	ext := &pb.Ext{}
662	m := &pb.MyMessage{Count: proto.Int32(4)}
663	if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil {
664		t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
665	}
666
667	b, err := proto.Marshal(m)
668	if err != nil {
669		t.Fatalf("Could not marshal message: %v", err)
670	}
671	if err := proto.Unmarshal(b, m); err != nil {
672		t.Fatalf("Could not unmarshal message: %v", err)
673	}
674	// after Unmarshal, the extension is in undecoded form.
675	// GetExtension will decode it lazily. Make sure this does
676	// not race against Marshal.
677
678	var g errgroup.Group
679	for n := 3; n > 0; n-- {
680		g.Go(func() error {
681			_, err := proto.Marshal(m)
682			return err
683		})
684		g.Go(func() error {
685			_, err := proto.GetExtension(m, pb.E_Ext_More)
686			return err
687		})
688	}
689	if err := g.Wait(); err != nil {
690		t.Fatal(err)
691	}
692}
693