• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"math/bits"
9
10	"google.golang.org/protobuf/encoding/protowire"
11	"google.golang.org/protobuf/internal/errors"
12	"google.golang.org/protobuf/internal/flags"
13	"google.golang.org/protobuf/proto"
14	"google.golang.org/protobuf/reflect/protoreflect"
15	preg "google.golang.org/protobuf/reflect/protoregistry"
16	"google.golang.org/protobuf/runtime/protoiface"
17	piface "google.golang.org/protobuf/runtime/protoiface"
18)
19
20var errDecode = errors.New("cannot parse invalid wire-format data")
21var errRecursionDepth = errors.New("exceeded maximum recursion depth")
22
23type unmarshalOptions struct {
24	flags    protoiface.UnmarshalInputFlags
25	resolver interface {
26		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
27		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
28	}
29	depth int
30}
31
32func (o unmarshalOptions) Options() proto.UnmarshalOptions {
33	return proto.UnmarshalOptions{
34		Merge:          true,
35		AllowPartial:   true,
36		DiscardUnknown: o.DiscardUnknown(),
37		Resolver:       o.resolver,
38	}
39}
40
41func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 }
42
43func (o unmarshalOptions) IsDefault() bool {
44	return o.flags == 0 && o.resolver == preg.GlobalTypes
45}
46
47var lazyUnmarshalOptions = unmarshalOptions{
48	resolver: preg.GlobalTypes,
49	depth:    protowire.DefaultRecursionLimit,
50}
51
// unmarshalOutput is the result reported by a single decoding step.
type unmarshalOutput struct {
	// n is the number of bytes consumed.
	n int

	// initialized records whether the decoded value was found to be
	// initialized; unmarshal translates it into the UnmarshalInitialized
	// output flag.
	initialized bool
}
56
57// unmarshal is protoreflect.Methods.Unmarshal.
58func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
59	var p pointer
60	if ms, ok := in.Message.(*messageState); ok {
61		p = ms.pointer()
62	} else {
63		p = in.Message.(*messageReflectWrapper).pointer()
64	}
65	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
66		flags:    in.Flags,
67		resolver: in.Resolver,
68		depth:    in.Depth,
69	})
70	var flags piface.UnmarshalOutputFlags
71	if out.initialized {
72		flags |= piface.UnmarshalInitialized
73	}
74	return piface.UnmarshalOutput{
75		Flags: flags,
76	}, err
77}
78
79// errUnknown is returned during unmarshaling to indicate a parse error that
80// should result in a field being placed in the unknown fields section (for example,
81// when the wire type doesn't match) as opposed to the entire unmarshal operation
82// failing (for example, when a field extends past the available input).
83//
84// This is a sentinel error which should never be visible to the user.
85var errUnknown = errors.New("unknown")
86
87func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
88	mi.init()
89	opts.depth--
90	if opts.depth < 0 {
91		return out, errRecursionDepth
92	}
93	if flags.ProtoLegacy && mi.isMessageSet {
94		return unmarshalMessageSet(mi, b, p, opts)
95	}
96	initialized := true
97	var requiredMask uint64
98	var exts *map[int32]ExtensionField
99	start := len(b)
100	for len(b) > 0 {
101		// Parse the tag (field number and wire type).
102		var tag uint64
103		if b[0] < 0x80 {
104			tag = uint64(b[0])
105			b = b[1:]
106		} else if len(b) >= 2 && b[1] < 128 {
107			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
108			b = b[2:]
109		} else {
110			var n int
111			tag, n = protowire.ConsumeVarint(b)
112			if n < 0 {
113				return out, errDecode
114			}
115			b = b[n:]
116		}
117		var num protowire.Number
118		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
119			return out, errDecode
120		} else {
121			num = protowire.Number(n)
122		}
123		wtyp := protowire.Type(tag & 7)
124
125		if wtyp == protowire.EndGroupType {
126			if num != groupTag {
127				return out, errDecode
128			}
129			groupTag = 0
130			break
131		}
132
133		var f *coderFieldInfo
134		if int(num) < len(mi.denseCoderFields) {
135			f = mi.denseCoderFields[num]
136		} else {
137			f = mi.coderFields[num]
138		}
139		var n int
140		err := errUnknown
141		switch {
142		case f != nil:
143			if f.funcs.unmarshal == nil {
144				break
145			}
146			var o unmarshalOutput
147			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
148			n = o.n
149			if err != nil {
150				break
151			}
152			requiredMask |= f.validation.requiredBit
153			if f.funcs.isInit != nil && !o.initialized {
154				initialized = false
155			}
156		default:
157			// Possible extension.
158			if exts == nil && mi.extensionOffset.IsValid() {
159				exts = p.Apply(mi.extensionOffset).Extensions()
160				if *exts == nil {
161					*exts = make(map[int32]ExtensionField)
162				}
163			}
164			if exts == nil {
165				break
166			}
167			var o unmarshalOutput
168			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
169			if err != nil {
170				break
171			}
172			n = o.n
173			if !o.initialized {
174				initialized = false
175			}
176		}
177		if err != nil {
178			if err != errUnknown {
179				return out, err
180			}
181			n = protowire.ConsumeFieldValue(num, wtyp, b)
182			if n < 0 {
183				return out, errDecode
184			}
185			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
186				u := mi.mutableUnknownBytes(p)
187				*u = protowire.AppendTag(*u, num, wtyp)
188				*u = append(*u, b[:n]...)
189			}
190		}
191		b = b[n:]
192	}
193	if groupTag != 0 {
194		return out, errDecode
195	}
196	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
197		initialized = false
198	}
199	if initialized {
200		out.initialized = true
201	}
202	out.n = start - len(b)
203	return out, nil
204}
205
206func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
207	x := exts[int32(num)]
208	xt := x.Type()
209	if xt == nil {
210		var err error
211		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
212		if err != nil {
213			if err == preg.NotFound {
214				return out, errUnknown
215			}
216			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
217		}
218	}
219	xi := getExtensionFieldInfo(xt)
220	if xi.funcs.unmarshal == nil {
221		return out, errUnknown
222	}
223	if flags.LazyUnmarshalExtensions {
224		if opts.IsDefault() && x.canLazy(xt) {
225			out, valid := skipExtension(b, xi, num, wtyp, opts)
226			switch valid {
227			case ValidationValid:
228				if out.initialized {
229					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
230					exts[int32(num)] = x
231					return out, nil
232				}
233			case ValidationInvalid:
234				return out, errDecode
235			case ValidationUnknown:
236			}
237		}
238	}
239	ival := x.Value()
240	if !ival.IsValid() && xi.unmarshalNeedsValue {
241		// Create a new message, list, or map value to fill in.
242		// For enums, create a prototype value to let the unmarshal func know the
243		// concrete type.
244		ival = xt.New()
245	}
246	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
247	if err != nil {
248		return out, err
249	}
250	if xi.funcs.isInit == nil {
251		out.initialized = true
252	}
253	x.Set(xt, v)
254	exts[int32(num)] = x
255	return out, nil
256}
257
258func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
259	if xi.validation.mi == nil {
260		return out, ValidationUnknown
261	}
262	xi.validation.mi.init()
263	switch xi.validation.typ {
264	case validationTypeMessage:
265		if wtyp != protowire.BytesType {
266			return out, ValidationUnknown
267		}
268		v, n := protowire.ConsumeBytes(b)
269		if n < 0 {
270			return out, ValidationUnknown
271		}
272		out, st := xi.validation.mi.validate(v, 0, opts)
273		out.n = n
274		return out, st
275	case validationTypeGroup:
276		if wtyp != protowire.StartGroupType {
277			return out, ValidationUnknown
278		}
279		out, st := xi.validation.mi.validate(b, num, opts)
280		return out, st
281	default:
282		return out, ValidationUnknown
283	}
284}
285