• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8	"google.golang.org/protobuf/encoding/protowire"
9	"google.golang.org/protobuf/internal/encoding/messageset"
10	"google.golang.org/protobuf/internal/errors"
11	"google.golang.org/protobuf/internal/flags"
12	"google.golang.org/protobuf/internal/genid"
13	"google.golang.org/protobuf/internal/pragma"
14	"google.golang.org/protobuf/reflect/protoreflect"
15	"google.golang.org/protobuf/reflect/protoregistry"
16	"google.golang.org/protobuf/runtime/protoiface"
17)
18
// UnmarshalOptions configures the unmarshaler.
//
// The zero value is ready to use: the destination is reset before decoding,
// missing required fields are reported as an error, unknown fields are kept,
// extensions are resolved via protoregistry.GlobalTypes, and the default
// recursion limit applies.
//
// Example usage:
//   err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message before unmarshaling,
	// unless Merge is specified.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored: they are dropped
	// instead of being appended to the message's unknown field set.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}

	// RecursionLimit limits how deeply messages may be nested.
	// If zero, protowire.DefaultRecursionLimit is applied. Only zero is
	// replaced; a negative value is used as given, which makes the
	// reflection-based decoder fail immediately with a recursion-depth error.
	RecursionLimit int
}
50
51// Unmarshal parses the wire-format message in b and places the result in m.
52// The provided message must be mutable (e.g., a non-nil pointer to a message).
53func Unmarshal(b []byte, m Message) error {
54	_, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
55	return err
56}
57
58// Unmarshal parses the wire-format message in b and places the result in m.
59// The provided message must be mutable (e.g., a non-nil pointer to a message).
60func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
61	if o.RecursionLimit == 0 {
62		o.RecursionLimit = protowire.DefaultRecursionLimit
63	}
64	_, err := o.unmarshal(b, m.ProtoReflect())
65	return err
66}
67
68// UnmarshalState parses a wire-format message and places the result in m.
69//
70// This method permits fine-grained control over the unmarshaler.
71// Most users should use Unmarshal instead.
72func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
73	if o.RecursionLimit == 0 {
74		o.RecursionLimit = protowire.DefaultRecursionLimit
75	}
76	return o.unmarshal(in.Buf, in.Message)
77}
78
// unmarshal is a centralized function that all unmarshal operations go through.
// For profiling purposes, avoid changing the name of this function or
// introducing other code paths for unmarshal that do not go through this.
//
// It resets m unless o.Merge is set, then decodes b into m. It uses the
// message's fast-path Unmarshal method when one exists, provided that method
// supports DiscardUnknown if that option is requested. Otherwise it falls back
// to the reflection-based unmarshalMessageSlow.
//
// On success, and unless o.AllowPartial is set, m is checked for missing
// required fields. The check is skipped when the fast path reports
// protoiface.UnmarshalInitialized in out.Flags.
func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
	if o.Resolver == nil {
		o.Resolver = protoregistry.GlobalTypes
	}
	if !o.Merge {
		Reset(m.Interface())
	}
	// Save the caller's AllowPartial, then force Merge and AllowPartial on.
	// The slow path passes this o down to nested messages, so nested calls
	// neither reset their targets nor check required fields. The single
	// required-field check happens at the end of this call.
	allowPartial := o.AllowPartial
	o.Merge = true
	o.AllowPartial = true
	methods := protoMethods(m)
	// Take the fast path unless DiscardUnknown is requested and the fast path
	// does not advertise support for it.
	if methods != nil && methods.Unmarshal != nil &&
		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
		in := protoiface.UnmarshalInput{
			Message:  m,
			Buf:      b,
			Resolver: o.Resolver,
			Depth:    o.RecursionLimit,
		}
		if o.DiscardUnknown {
			in.Flags |= protoiface.UnmarshalDiscardUnknown
		}
		out, err = methods.Unmarshal(in)
	} else {
		// Each trip through the slow path consumes one unit of the recursion
		// budget. Nested messages re-enter unmarshal with this decremented
		// copy of o.
		o.RecursionLimit--
		if o.RecursionLimit < 0 {
			return out, errors.New("exceeded max recursion depth")
		}
		err = o.unmarshalMessageSlow(b, m)
	}
	if err != nil {
		return out, err
	}
	if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
		return out, nil
	}
	return out, checkInitialized(m)
}
120
121func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
122	_, err := o.unmarshal(b, m)
123	return err
124}
125
// unmarshalMessageSlow decodes the wire-format bytes in b into m one field at
// a time using protoreflect. unmarshal uses it as the fallback when m has no
// usable fast-path Unmarshal method. MessageSet messages are handed off to
// unmarshalMessageSet.
//
// Field lookup works as follows:
//   - Each tag's field number is first looked up among the message's declared
//     fields.
//   - If it is not found there but lies in an extension range, it is looked up
//     through o.Resolver.
//
// Some fields are treated as unknown:
//   - fields that cannot be resolved;
//   - when flags.ProtoLegacy is set, weak fields whose message type is only a
//     placeholder;
//   - fields whose sub-decoder returns errUnknown (for example, unmarshalMap
//     on a wire-type mismatch).
//
// Unknown fields are skipped. Unless o.DiscardUnknown is set, their raw bytes
// (tag plus value) are also appended to m's unknown fields.
//
// Errors: a malformed tag, a field number above protowire.MaxValidNumber, or
// an unknown value that cannot be skipped yields errDecode. Any other
// sub-decoder error is returned unchanged.
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
	md := m.Descriptor()
	if messageset.IsMessageSet(md) {
		return o.unmarshalMessageSet(b, m)
	}
	fields := md.Fields()
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		num, wtyp, tagLen := protowire.ConsumeTag(b)
		if tagLen < 0 {
			return errDecode
		}
		if num > protowire.MaxValidNumber {
			return errDecode
		}

		// Find the field descriptor for this field number.
		fd := fields.ByNumber(num)
		if fd == nil && md.ExtensionRanges().Has(num) {
			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
			// NotFound is tolerated: fd stays nil and the field is kept as
			// unknown below. Any other resolver failure aborts decoding.
			if err != nil && err != protoregistry.NotFound {
				return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
			}
			if extType != nil {
				fd = extType.TypeDescriptor()
			}
		}
		// err == errUnknown marks the field to be preserved as an unknown
		// field rather than decoded into m.
		var err error
		if fd == nil {
			err = errUnknown
		} else if flags.ProtoLegacy {
			if fd.IsWeak() && fd.Message().IsPlaceholder() {
				err = errUnknown // weak referent is not linked in
			}
		}

		// Parse the field value.
		var valLen int
		switch {
		case err != nil:
			// Already classified as unknown; handled below.
		case fd.IsList():
			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
		case fd.IsMap():
			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
		default:
			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
		}
		if err != nil {
			if err != errUnknown {
				return err
			}
			// Measure the raw value so it can be skipped and, if wanted,
			// copied into the unknown fields together with its tag.
			valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
			if valLen < 0 {
				return errDecode
			}
			if !o.DiscardUnknown {
				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
			}
		}
		b = b[tagLen+valLen:]
	}
	return nil
}
189
190func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
191	v, n, err := o.unmarshalScalar(b, wtyp, fd)
192	if err != nil {
193		return 0, err
194	}
195	switch fd.Kind() {
196	case protoreflect.GroupKind, protoreflect.MessageKind:
197		m2 := m.Mutable(fd).Message()
198		if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
199			return n, err
200		}
201	default:
202		// Non-message scalars replace the previous value.
203		m.Set(fd, v)
204	}
205	return n, nil
206}
207
// unmarshalMap decodes a single map entry from b and stores it in mapv.
// fd is the map field's descriptor. The entry must be length-delimited
// (protowire.BytesType). The return value is the total number of bytes
// consumed from b, including the length prefix.
//
// If the wire type is anything else, unmarshalMap returns errUnknown so the
// caller can preserve the field as unknown. Within the entry, fields other
// than the key and value are skipped.
//
// Missing parts are filled in as follows:
//   - a missing key falls back to keyField.Default();
//   - a missing scalar value falls back to valField.Default();
//   - a missing message value leaves the empty message allocated by
//     mapv.NewValue().
func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
	if wtyp != protowire.BytesType {
		return 0, errUnknown
	}
	// n is the size of the whole entry and is what this function returns;
	// b is narrowed to the entry's payload.
	b, n = protowire.ConsumeBytes(b)
	if n < 0 {
		return 0, errDecode
	}
	var (
		keyField = fd.MapKey()
		valField = fd.MapValue()
		key      protoreflect.Value
		val      protoreflect.Value
		haveKey  bool
		haveVal  bool
	)
	// Message-valued maps decode into a freshly allocated value up front, so
	// repeated value fields within one entry are decoded into the same message.
	switch valField.Kind() {
	case protoreflect.GroupKind, protoreflect.MessageKind:
		val = mapv.NewValue()
	}
	// Map entries are represented as a two-element message with fields
	// containing the key and value.
	for len(b) > 0 {
		// This n intentionally shadows the outer n: inside the loop it is the
		// length of the current tag/value, while the outer n (the entry size)
		// is returned at the end.
		num, wtyp, n := protowire.ConsumeTag(b)
		if n < 0 {
			return 0, errDecode
		}
		if num > protowire.MaxValidNumber {
			return 0, errDecode
		}
		b = b[n:]
		// Assume the field is unknown. The key/value cases below overwrite
		// err with the result of decoding. Presumably unmarshalScalar also
		// reports errUnknown on a wire-type mismatch, in which case that field
		// is skipped too (TODO confirm against unmarshalScalar).
		err = errUnknown
		switch num {
		case genid.MapEntry_Key_field_number:
			key, n, err = o.unmarshalScalar(b, wtyp, keyField)
			if err != nil {
				break
			}
			haveKey = true
		case genid.MapEntry_Value_field_number:
			var v protoreflect.Value
			v, n, err = o.unmarshalScalar(b, wtyp, valField)
			if err != nil {
				break
			}
			switch valField.Kind() {
			case protoreflect.GroupKind, protoreflect.MessageKind:
				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
					return 0, err
				}
			default:
				val = v
			}
			haveVal = true
		}
		if err == errUnknown {
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return 0, errDecode
			}
		} else if err != nil {
			return 0, err
		}
		b = b[n:]
	}
	// Every map entry should have entries for key and value, but this is not strictly required.
	if !haveKey {
		key = keyField.Default()
	}
	if !haveVal {
		switch valField.Kind() {
		case protoreflect.GroupKind, protoreflect.MessageKind:
			// Keep the empty message allocated above.
		default:
			val = valField.Default()
		}
	}
	mapv.Set(key.MapKey(), val)
	return n, nil
}
287
288// errUnknown is used internally to indicate fields which should be added
289// to the unknown field set of a message. It is never returned from an exported
290// function.
291var errUnknown = errors.New("BUG: internal error (unknown)")
292
293var errDecode = errors.New("cannot parse invalid wire-format data")
294