// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package proto

import (
	"google.golang.org/protobuf/encoding/protowire"
	"google.golang.org/protobuf/internal/encoding/messageset"
	"google.golang.org/protobuf/internal/errors"
	"google.golang.org/protobuf/internal/flags"
	"google.golang.org/protobuf/internal/genid"
	"google.golang.org/protobuf/internal/pragma"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"
	"google.golang.org/protobuf/runtime/protoiface"
)

// UnmarshalOptions configures the unmarshaler.
//
// The zero value is usable: the destination is reset before decoding,
// missing required fields are reported as an error, unknown fields are kept,
// extensions are resolved with protoregistry.GlobalTypes, and nesting is
// bounded by protowire.DefaultRecursionLimit.
//
// Example usage:
//
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	// Embedded marker; presumably rejects unkeyed composite literals so new
	// options can be added compatibly — confirm in internal/pragma. Construct
	// UnmarshalOptions with field names, as in the example above.
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message (via Reset) before
	// unmarshaling, unless Merge is specified. Independently of this option,
	// a message-typed field that appears more than once in the input is
	// merged into the existing sub-message rather than replacing it.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields. The check is
	// made once, after the whole message has been decoded; m has already been
	// populated when that error is returned.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored: their bytes are
	// dropped instead of being appended to the message's unknown-field set.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	//
	// A field number inside a message's extension ranges is looked up with
	// FindExtensionByNumber; a protoregistry.NotFound result makes the field
	// unknown, while any other error aborts unmarshaling. The resolver is also
	// forwarded to a message's fast-path Unmarshal method.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}

	// RecursionLimit limits how deeply messages may be nested.
	// If zero, protowire.DefaultRecursionLimit is applied. Exceeding the limit
	// makes unmarshaling fail with an "exceeded max recursion depth" error.
	RecursionLimit int
}

// Unmarshal parses the wire-format message in b and places the result in m.
52// The provided message must be mutable (e.g., a non-nil pointer to a message). 53func Unmarshal(b []byte, m Message) error { 54 _, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect()) 55 return err 56} 57 58// Unmarshal parses the wire-format message in b and places the result in m. 59// The provided message must be mutable (e.g., a non-nil pointer to a message). 60func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error { 61 if o.RecursionLimit == 0 { 62 o.RecursionLimit = protowire.DefaultRecursionLimit 63 } 64 _, err := o.unmarshal(b, m.ProtoReflect()) 65 return err 66} 67 68// UnmarshalState parses a wire-format message and places the result in m. 69// 70// This method permits fine-grained control over the unmarshaler. 71// Most users should use Unmarshal instead. 72func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) { 73 if o.RecursionLimit == 0 { 74 o.RecursionLimit = protowire.DefaultRecursionLimit 75 } 76 return o.unmarshal(in.Buf, in.Message) 77} 78 79// unmarshal is a centralized function that all unmarshal operations go through. 80// For profiling purposes, avoid changing the name of this function or 81// introducing other code paths for unmarshal that do not go through this. 
func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
	// o is a value receiver, so the defaults and overrides applied below stay
	// local to this call and never leak back into the caller's options.
	if o.Resolver == nil {
		o.Resolver = protoregistry.GlobalTypes
	}
	if !o.Merge {
		Reset(m.Interface())
	}
	// Capture the caller's AllowPartial. Then force Merge and AllowPartial on
	// for the nested calls the slow path makes through o, because:
	//   - sub-messages must be merged into rather than reset;
	//   - the required-field check at the end of this function should run
	//     only in the outermost call.
	allowPartial := o.AllowPartial
	o.Merge = true
	o.AllowPartial = true
	methods := protoMethods(m)
	// Prefer the message's fast-path Unmarshal method. Skip it when
	// DiscardUnknown was requested and the fast path does not advertise
	// support for discarding unknown fields.
	if methods != nil && methods.Unmarshal != nil &&
		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
		in := protoiface.UnmarshalInput{
			Message:  m,
			Buf:      b,
			Resolver: o.Resolver,
			Depth:    o.RecursionLimit,
		}
		if o.DiscardUnknown {
			in.Flags |= protoiface.UnmarshalDiscardUnknown
		}
		out, err = methods.Unmarshal(in)
	} else {
		// Reflection-based slow path: each level of message nesting decoded
		// here consumes one unit of the recursion budget before recursing.
		o.RecursionLimit--
		if o.RecursionLimit < 0 {
			return out, errors.New("exceeded max recursion depth")
		}
		err = o.unmarshalMessageSlow(b, m)
	}
	if err != nil {
		return out, err
	}
	// Skip the required-field check when partial messages are allowed or the
	// fast path already reported m as fully initialized.
	if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
		return out, nil
	}
	return out, checkInitialized(m)
}

// unmarshalMessage decodes b into m through unmarshal, discarding the
// UnmarshalOutput. It is used for nested message and group values.
func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
	_, err := o.unmarshal(b, m)
	return err
}

// unmarshalMessageSlow decodes b into m field by field using protoreflect.
// Message-set messages are delegated to unmarshalMessageSet.
//
// A field is treated as unknown in these cases:
//   - it has no descriptor and no resolvable extension;
//   - under flags.ProtoLegacy, it is a weak field whose message is not
//     linked in;
//   - its field decoder reports errUnknown (e.g. unmarshalMap on a
//     wire-type mismatch).
//
// Unknown fields are skipped and their raw tag+value bytes are appended to
// m's unknown fields unless DiscardUnknown is set. Malformed input yields
// errDecode.
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
	md := m.Descriptor()
	if messageset.IsMessageSet(md) {
		return o.unmarshalMessageSet(b, m)
	}
	fields := md.Fields()
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		num, wtyp, tagLen := protowire.ConsumeTag(b)
		if tagLen < 0 {
			return errDecode
		}
		if num > protowire.MaxValidNumber {
			return errDecode
		}

		// Find the field descriptor for this field number.
		// Fall back to the extension resolver only for numbers declared in
		// one of the message's extension ranges.
		fd := fields.ByNumber(num)
		if fd == nil && md.ExtensionRanges().Has(num) {
			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
			if err != nil && err != protoregistry.NotFound {
				return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
			}
			if extType != nil {
				fd = extType.TypeDescriptor()
			}
		}
		var err error
		if fd == nil {
			err = errUnknown
		} else if flags.ProtoLegacy {
			if fd.IsWeak() && fd.Message().IsPlaceholder() {
				err = errUnknown // weak referent is not linked in
			}
		}

		// Parse the field value.
		var valLen int
		switch {
		case err != nil:
			// Field is already known to be unknown; handled after the switch.
		case fd.IsList():
			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
		case fd.IsMap():
			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
		default:
			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
		}
		if err != nil {
			if err != errUnknown {
				return err
			}
			// Skip the value without interpreting it, then keep the tag and
			// value bytes verbatim as an unknown field.
			valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
			if valLen < 0 {
				return errDecode
			}
			if !o.DiscardUnknown {
				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
			}
		}
		b = b[tagLen+valLen:]
	}
	return nil
}

// unmarshalSingular decodes one value for the non-repeated, non-map field fd
// from b and stores it in m. It returns the number of bytes of b consumed.
// Message and group values are merged into m's existing (mutable)
// sub-message; values of every other kind replace the current value.
func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
	v, n, err := o.unmarshalScalar(b, wtyp, fd)
	if err != nil {
		return 0, err
	}
	switch fd.Kind() {
	case protoreflect.GroupKind, protoreflect.MessageKind:
		// v holds the encoded sub-message; decode it in place so that repeated
		// occurrences of the field merge.
		m2 := m.Mutable(fd).Message()
		if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
			return n, err
		}
	default:
		// Non-message scalars replace the previous value.
		m.Set(fd, v)
	}
	return n, nil
}

// unmarshalMap decodes a single map entry for the map field fd from b and
// inserts it into mapv.
//
// An entry is a length-delimited submessage with two fields:
// genid.MapEntry_Key_field_number holds the key and
// genid.MapEntry_Value_field_number holds the value. Any other fields inside
// the entry are skipped.
//
// Missing parts are filled in as follows:
//   - a missing key, or a missing scalar value, becomes the field's default;
//   - a missing message value becomes an empty message from mapv.NewValue.
//
// It returns errUnknown if wtyp is not protowire.BytesType, so the caller can
// keep the field as unknown. On success it returns the length of the whole
// entry consumed from b.
func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
	if wtyp != protowire.BytesType {
		return 0, errUnknown
	}
	// From here on b is just the entry payload; n is the total entry length
	// and is what this function ultimately returns.
	b, n = protowire.ConsumeBytes(b)
	if n < 0 {
		return 0, errDecode
	}
	var (
		keyField = fd.MapKey()
		valField = fd.MapValue()
		key      protoreflect.Value
		val      protoreflect.Value
		haveKey  bool
		haveVal  bool
	)
	switch valField.Kind() {
	case protoreflect.GroupKind, protoreflect.MessageKind:
		// Allocate the message value up front so repeated value fields in the
		// entry merge into it.
		val = mapv.NewValue()
	}
	// Map entries are represented as a two-element message with fields
	// containing the key and value.
	for len(b) > 0 {
		// NOTE: wtyp and n declared here are new loop-scoped variables that
		// shadow the parameter and the named result. The outer n (total entry
		// length) is left untouched for the final return.
		num, wtyp, n := protowire.ConsumeTag(b)
		if n < 0 {
			return 0, errDecode
		}
		if num > protowire.MaxValidNumber {
			return 0, errDecode
		}
		b = b[n:]
		// Assume the field is unknown until one of the cases below claims it.
		err = errUnknown
		switch num {
		case genid.MapEntry_Key_field_number:
			key, n, err = o.unmarshalScalar(b, wtyp, keyField)
			if err != nil {
				break // leaves only the switch; err is handled below
			}
			haveKey = true
		case genid.MapEntry_Value_field_number:
			var v protoreflect.Value
			v, n, err = o.unmarshalScalar(b, wtyp, valField)
			if err != nil {
				break
			}
			switch valField.Kind() {
			case protoreflect.GroupKind, protoreflect.MessageKind:
				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
					return 0, err
				}
			default:
				val = v
			}
			haveVal = true
		}
		if err == errUnknown {
			// Unrecognized field (or wire type) inside the entry: skip it.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return 0, errDecode
			}
		} else if err != nil {
			return 0, err
		}
		b = b[n:]
	}
	// Every map entry should have entries for key and value, but this is not strictly required.
	if !haveKey {
		key = keyField.Default()
	}
	if !haveVal {
		switch valField.Kind() {
		case protoreflect.GroupKind, protoreflect.MessageKind:
			// val already holds the (empty) message allocated above.
		default:
			val = valField.Default()
		}
	}
	mapv.Set(key.MapKey(), val)
	return n, nil
}

// errUnknown is used internally to indicate fields which should be added
// to the unknown field set of a message. It is never returned from an exported
// function.
var errUnknown = errors.New("BUG: internal error (unknown)")

// errDecode reports malformed wire-format input. Examples are an unparsable
// tag, a field number above protowire.MaxValidNumber, or a value that cannot
// be consumed.
var errDecode = errors.New("cannot parse invalid wire-format data")