1// Copyright 2019 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package impl 6 7import ( 8 "math/bits" 9 10 "google.golang.org/protobuf/encoding/protowire" 11 "google.golang.org/protobuf/internal/errors" 12 "google.golang.org/protobuf/internal/flags" 13 "google.golang.org/protobuf/proto" 14 "google.golang.org/protobuf/reflect/protoreflect" 15 preg "google.golang.org/protobuf/reflect/protoregistry" 16 "google.golang.org/protobuf/runtime/protoiface" 17 piface "google.golang.org/protobuf/runtime/protoiface" 18) 19 20var errDecode = errors.New("cannot parse invalid wire-format data") 21var errRecursionDepth = errors.New("exceeded maximum recursion depth") 22 23type unmarshalOptions struct { 24 flags protoiface.UnmarshalInputFlags 25 resolver interface { 26 FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) 27 FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) 28 } 29 depth int 30} 31 32func (o unmarshalOptions) Options() proto.UnmarshalOptions { 33 return proto.UnmarshalOptions{ 34 Merge: true, 35 AllowPartial: true, 36 DiscardUnknown: o.DiscardUnknown(), 37 Resolver: o.resolver, 38 } 39} 40 41func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 } 42 43func (o unmarshalOptions) IsDefault() bool { 44 return o.flags == 0 && o.resolver == preg.GlobalTypes 45} 46 47var lazyUnmarshalOptions = unmarshalOptions{ 48 resolver: preg.GlobalTypes, 49 depth: protowire.DefaultRecursionLimit, 50} 51 52type unmarshalOutput struct { 53 n int // number of bytes consumed 54 initialized bool 55} 56 57// unmarshal is protoreflect.Methods.Unmarshal. 
//
// It decodes the wire-format bytes in in.Buf into the message referenced by
// in.Message. in.Message must be a *messageState or a *messageReflectWrapper;
// any other dynamic type makes the unchecked type assertion below panic.
//
// The input flags, resolver and depth are forwarded to unmarshalPointer, which
// decrements the depth before use. in.Depth must therefore be at least 1, or
// decoding fails immediately with errRecursionDepth.
//
// Only the "initialized" result is reported back, through
// UnmarshalOutput.Flags. The consumed byte count is not returned.
func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
	var p pointer
	if ms, ok := in.Message.(*messageState); ok {
		p = ms.pointer()
	} else {
		p = in.Message.(*messageReflectWrapper).pointer()
	}
	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
		flags:    in.Flags,
		resolver: in.Resolver,
		depth:    in.Depth,
	})
	// Note: this local shadows the imported internal/flags package within
	// this function.
	var flags piface.UnmarshalOutputFlags
	if out.initialized {
		flags |= piface.UnmarshalInitialized
	}
	// The output is built even when err != nil; callers should check err first.
	return piface.UnmarshalOutput{
		Flags: flags,
	}, err
}

// errUnknown is returned during unmarshaling to indicate a parse error that
// should result in a field being placed in the unknown fields section (for example,
// when the wire type doesn't match) as opposed to the entire unmarshal operation
// failing (for example, when a field extends past the available input).
//
// This is a sentinel error which should never be visible to the user.
var errUnknown = errors.New("unknown")

// unmarshalPointer decodes wire-format data from b into the message at p,
// whose layout is described by mi.
//
// groupTag is the field number of the enclosing group when decoding a
// group-encoded message, or 0 at the top level. Decoding stops at an END_GROUP
// tag whose number equals groupTag. The following are reported as errDecode:
//   - an END_GROUP with any other number (including any END_GROUP at the top
//     level, since valid field numbers are never 0);
//   - running out of input while still inside a group.
//
// opts.depth is decremented on entry, and errRecursionDepth is returned once it
// drops below zero. Because opts is passed by value, the decrement is only
// visible to nested calls made from here.
//
// A field is treated as unknown in three cases:
//   - it has no coder or no unmarshal func;
//   - it looks like an extension but the message has no extension storage;
//   - its decoder returns errUnknown.
//
// Unknown fields are skipped with protowire.ConsumeFieldValue. Unless
// DiscardUnknown is set or the message has no unknown-fields storage, they are
// appended verbatim (tag followed by value) to the message's unknown bytes.
//
// On success:
//   - out.n is the number of bytes consumed, including the terminating
//     END_GROUP tag if any;
//   - out.initialized is false if a decoded field with an isInit check, or an
//     extension, reported itself uninitialized, or if the count of distinct
//     required bits seen does not match mi.numRequiredFields.
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	mi.init()
	opts.depth--
	if opts.depth < 0 {
		return out, errRecursionDepth
	}
	if flags.ProtoLegacy && mi.isMessageSet {
		return unmarshalMessageSet(mi, b, p, opts)
	}
	initialized := true
	var requiredMask uint64
	// exts is fetched (and its map allocated) lazily, only once an
	// unrecognized field number is seen.
	var exts *map[int32]ExtensionField
	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// Fast paths for one- and two-byte varint tags; anything longer (or a
		// truncated two-byte tag) goes through protowire.ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		// Reject field numbers outside [MinValidNumber, MaxValidNumber].
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errDecode
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			if num != groupTag {
				return out, errDecode
			}
			// Clearing groupTag marks the group as properly closed for the
			// check after the loop; break exits the for loop.
			groupTag = 0
			break
		}

		// Small field numbers index the dense slice; others fall back to
		// mi.coderFields (presumably a map keyed by field number).
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// err starts as errUnknown (shadowing the named result), so any path
		// that does not decode the field falls into the unknown-field
		// handling below.
		err := errUnknown
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			// n is taken before checking err; on errUnknown it is recomputed
			// below, and on any other error the function returns.
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
		default:
			// Possible extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			// Skip the raw field value and, unless discarding, keep it
			// (re-prefixed with its tag) in the unknown fields.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	// Input ended while still inside a group: the END_GROUP tag is missing.
	if groupTag != 0 {
		return out, errDecode
	}
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}

// unmarshalExtension decodes one occurrence of extension field num (wire type
// wtyp) from the start of b and stores the result in exts.
//
// The extension type comes from an existing entry in exts if present;
// otherwise it is resolved via opts.resolver. errUnknown is returned, so the
// caller keeps the bytes as an unknown field, in two cases:
//   - the resolver returns preg.NotFound;
//   - the resolved type has no unmarshal func.
//
// Other resolver errors are wrapped and returned.
//
// Lazy decoding applies when three conditions hold:
//   - flags.LazyUnmarshalExtensions is enabled;
//   - opts.IsDefault() is true;
//   - x.canLazy(xt) allows it.
//
// In that case the bytes are first checked with skipExtension. If they are
// valid and initialized, they are recorded via x.appendLazyBytes (presumably
// for deferred decoding) instead of being decoded now. Invalid bytes yield
// errDecode. An unknown validation result, or valid but uninitialized bytes,
// fall through to an eager decode.
func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
	x := exts[int32(num)]
	xt := x.Type()
	if xt == nil {
		var err error
		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
		if err != nil {
			if err == preg.NotFound {
				return out, errUnknown
			}
			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
		}
	}
	xi := getExtensionFieldInfo(xt)
	if xi.funcs.unmarshal == nil {
		return out, errUnknown
	}
	if flags.LazyUnmarshalExtensions {
		if opts.IsDefault() && x.canLazy(xt) {
			// out is deliberately shadowed here; it is only returned from
			// within this block.
			out, valid := skipExtension(b, xi, num, wtyp, opts)
			switch valid {
			case ValidationValid:
				if out.initialized {
					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
					exts[int32(num)] = x
					return out, nil
				}
			case ValidationInvalid:
				return out, errDecode
			case ValidationUnknown:
			}
		}
	}
	ival := x.Value()
	if !ival.IsValid() && xi.unmarshalNeedsValue {
		// Create a new message, list, or map value to fill in.
		// For enums, create a prototype value to let the unmarshal func know the
		// concrete type.
		ival = xt.New()
	}
	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
	if err != nil {
		return out, err
	}
	// Types without an isInit check are always considered initialized.
	if xi.funcs.isInit == nil {
		out.initialized = true
	}
	x.Set(xt, v)
	// x is a value copy of the map entry, so it must be written back.
	exts[int32(num)] = x
	return out, nil
}

// skipExtension checks the extension value at the start of b using
// xi.validation.mi.validate, rather than decoding it.
//
// Only message (BytesType) and group (StartGroupType) extensions with a known
// validation MessageInfo are handled. In every other case it returns
// ValidationUnknown so the caller falls back to a regular decode. That includes
// a wire-type mismatch and a malformed length prefix.
//
// For message extensions out.n is set to the full length of the
// length-delimited field. For groups, out is whatever validate returns.
func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
	if xi.validation.mi == nil {
		return out, ValidationUnknown
	}
	xi.validation.mi.init()
	switch xi.validation.typ {
	case validationTypeMessage:
		if wtyp != protowire.BytesType {
			return out, ValidationUnknown
		}
		v, n := protowire.ConsumeBytes(b)
		if n < 0 {
			return out, ValidationUnknown
		}
		out, st := xi.validation.mi.validate(v, 0, opts)
		// Report the whole length-delimited field (prefix + payload) as consumed.
		out.n = n
		return out, st
	case validationTypeGroup:
		if wtyp != protowire.StartGroupType {
			return out, ValidationUnknown
		}
		out, st := xi.validation.mi.validate(b, num, opts)
		return out, st
	default:
		return out, ValidationUnknown
	}
}