• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #region Copyright notice and license
2 // Protocol Buffers - Google's data interchange format
3 // Copyright 2008 Google Inc.  All rights reserved.
4 //
5 // Use of this source code is governed by a BSD-style
6 // license that can be found in the LICENSE file or at
7 // https://developers.google.com/open-source/licenses/bsd
8 #endregion
9 
10 using Google.Protobuf.Collections;
11 using System;
12 using System.Collections.Generic;
13 using System.Linq;
14 using System.Reflection;
15 using System.Security;
16 
17 namespace Google.Protobuf
18 {
/// <summary>
/// Methods for managing <see cref="ExtensionSet{TTarget}"/>s with null checking.
///
/// Most users will not use this class directly and its API is experimental and subject to change.
/// </summary>
/// <remarks>
/// Sets are passed by <c>ref</c> because a null set means "no extensions". Mutating methods
/// lazily create the set. The <c>Clear</c> methods reset it to null once its last value is removed.
/// </remarks>
public static class ExtensionSet
{
    /// <summary>
    /// Looks up the stored value for <paramref name="extension"/>, treating a null set as empty.
    /// </summary>
    private static bool TryGetValue<TTarget>(ref ExtensionSet<TTarget> set, Extension extension, out IExtensionValue value) where TTarget : IExtendableMessage<TTarget>
    {
        if (set == null)
        {
            value = null;
            return false;
        }
        return set.ValuesByNumber.TryGetValue(extension.FieldNumber, out value);
    }

    /// <summary>
    /// Returns the value stored for <paramref name="extension"/>. If no value is stored yet, it
    /// creates and registers a new one via <c>CreateValue()</c>, creating the set first if it is null.
    /// </summary>
    private static IExtensionValue GetOrCreateValue<TTarget>(ref ExtensionSet<TTarget> set, Extension extension) where TTarget : IExtendableMessage<TTarget>
    {
        if (set != null && set.ValuesByNumber.TryGetValue(extension.FieldNumber, out IExtensionValue existing))
        {
            return existing;
        }
        // The value is created before the set, so a throwing CreateValue leaves a null set untouched.
        IExtensionValue created = extension.CreateValue();
        set ??= new ExtensionSet<TTarget>();
        set.ValuesByNumber.Add(extension.FieldNumber, created);
        return created;
    }

    /// <summary>
    /// Removes the value stored for <paramref name="extension"/>. The set is reset to null when
    /// no values remain. A null set is left unchanged.
    /// </summary>
    private static void ClearValue<TTarget>(ref ExtensionSet<TTarget> set, Extension extension) where TTarget : IExtendableMessage<TTarget>
    {
        if (set == null)
        {
            return;
        }
        set.ValuesByNumber.Remove(extension.FieldNumber);
        if (set.ValuesByNumber.Count == 0)
        {
            set = null;
        }
    }

    /// <summary>
    /// Builds the exception reported when a stored value cannot be returned as the requested type.
    /// </summary>
    /// <param name="value">The stored value that could not be converted.</param>
    /// <param name="wrapperDefinition">The open generic wrapper type expected for this kind of extension.</param>
    /// <param name="requestedType">The value type the caller asked for.</param>
    private static InvalidOperationException CreateTypeMismatchException(IExtensionValue value, Type wrapperDefinition, Type requestedType)
    {
        TypeInfo valueType = value.GetType().GetTypeInfo();
        if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == wrapperDefinition)
        {
            Type storedType = valueType.GenericTypeArguments[0];
            return new InvalidOperationException(
                "The stored extension value has a type of '" + storedType.AssemblyQualifiedName + "'. " +
                "This is different from the requested type of '" + requestedType.AssemblyQualifiedName + "'.");
        }
        return new InvalidOperationException("Unexpected extension value type: " + valueType.AssemblyQualifiedName);
    }

    /// <summary>
    /// Gets the value of the specified extension, or the extension's default value when the set holds no value for it.
    /// </summary>
    /// <param name="set">The set to read from; may be null.</param>
    /// <param name="extension">The extension to read.</param>
    /// <returns>The stored value, or <c>extension.DefaultValue</c> when absent.</returns>
    /// <exception cref="InvalidOperationException">The stored value's type is incompatible with <typeparamref name="TValue"/>.</exception>
    public static TValue Get<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
    {
        if (!TryGetValue(ref set, extension, out IExtensionValue value))
        {
            return extension.DefaultValue;
        }

        // The stored value can have a different type from the one requested. This happens when the
        // same extension proto is compiled in different assemblies. To still allow access:
        // 1. Try the expected ExtensionValue<TValue> first. This is the usual case and avoids boxing.
        // 2. Otherwise fetch the value as object and cast it. This lets callers request TValue = object
        //    and re-serialize and re-parse the value themselves.
        // 3. If neither works, throw an error naming both types.
        if (value is ExtensionValue<TValue> extensionValue)
        {
            return extensionValue.GetValue();
        }
        if (value.GetValue() is TValue underlyingValue)
        {
            return underlyingValue;
        }
        throw CreateTypeMismatchException(value, typeof(ExtensionValue<>), typeof(TValue));
    }

    /// <summary>
    /// Gets the value of the specified repeated extension or null if it doesn't exist in this set
    /// </summary>
    /// <param name="set">The set to read from; may be null.</param>
    /// <param name="extension">The repeated extension to read.</param>
    /// <returns>The stored field, or null when absent.</returns>
    /// <exception cref="InvalidOperationException">The stored value's element type differs from <typeparamref name="TValue"/>.</exception>
    public static RepeatedField<TValue> Get<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
    {
        if (!TryGetValue(ref set, extension, out IExtensionValue value))
        {
            return null;
        }
        if (value is RepeatedExtensionValue<TValue> extensionValue)
        {
            return extensionValue.GetValue();
        }
        throw CreateTypeMismatchException(value, typeof(RepeatedExtensionValue<>), typeof(TValue));
    }

    /// <summary>
    /// Gets the value of the specified repeated extension, registering it if it doesn't exist.
    /// This will make a new instance of ExtensionSet if the set is null.
    /// </summary>
    /// <param name="set">The set to read from or initialize; may be null.</param>
    /// <param name="extension">The repeated extension to read.</param>
    /// <returns>The stored (possibly newly created) field.</returns>
    public static RepeatedField<TValue> GetOrInitialize<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
    {
        IExtensionValue value = GetOrCreateValue(ref set, extension);
        return ((RepeatedExtensionValue<TValue>)value).GetValue();
    }

    /// <summary>
    /// Sets the value of the specified extension. This will make a new instance of ExtensionSet if the set is null.
    /// </summary>
    /// <param name="set">The set to write to; may be null.</param>
    /// <param name="extension">The extension to set.</param>
    /// <param name="value">The new value. It is rejected by <c>ProtoPreconditions.CheckNotNullUnconstrained</c> when null.</param>
    public static void Set<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension, TValue value) where TTarget : IExtendableMessage<TTarget>
    {
        ProtoPreconditions.CheckNotNullUnconstrained(value, nameof(value));

        IExtensionValue extensionValue = GetOrCreateValue(ref set, extension);
        ((ExtensionValue<TValue>)extensionValue).SetValue(value);
    }

    /// <summary>
    /// Gets whether the value of the specified extension is set
    /// </summary>
    /// <returns>True when the set is non-null and holds a value for the extension's field number.</returns>
    public static bool Has<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
    {
        return TryGetValue(ref set, extension, out IExtensionValue _);
    }

    /// <summary>
    /// Clears the value of the specified extension. The set becomes null when its last value is removed.
    /// </summary>
    public static void Clear<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
    {
        ClearValue(ref set, extension);
    }

    /// <summary>
    /// Clears the value of the specified repeated extension. The set becomes null when its last value is removed.
    /// </summary>
    public static void Clear<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
    {
        ClearValue(ref set, extension);
    }

    /// <summary>
    /// Tries to merge a field from the coded input, returning true if the field was merged.
    /// The field is merged when its number is already present in the set, or when the input's
    /// extension registry knows it for <typeparamref name="TTarget"/>. In the second case a null
    /// set is created. Otherwise this returns false.
    /// </summary>
    public static bool TryMergeFieldFrom<TTarget>(ref ExtensionSet<TTarget> set, CodedInputStream stream) where TTarget : IExtendableMessage<TTarget>
    {
        ParseContext.Initialize(stream, out ParseContext ctx);
        try
        {
            return TryMergeFieldFrom<TTarget>(ref set, ref ctx);
        }
        finally
        {
            // Propagate the context's position back into the stream even if merging throws.
            ctx.CopyStateTo(stream);
        }
    }

    /// <summary>
    /// Tries to merge the field identified by <c>ctx.LastTag</c>, returning true if the field was merged.
    /// The field is merged when its number is already present in the set, or when
    /// <c>ctx.ExtensionRegistry</c> knows it for <typeparamref name="TTarget"/>. In the second case a
    /// null set is created. Otherwise this returns false.
    /// </summary>
    public static bool TryMergeFieldFrom<TTarget>(ref ExtensionSet<TTarget> set, ref ParseContext ctx) where TTarget : IExtendableMessage<TTarget>
    {
        int lastFieldNumber = WireFormat.GetTagFieldNumber(ctx.LastTag);

        if (set != null && set.ValuesByNumber.TryGetValue(lastFieldNumber, out IExtensionValue extensionValue))
        {
            extensionValue.MergeFrom(ref ctx);
            return true;
        }
        else if (ctx.ExtensionRegistry != null && ctx.ExtensionRegistry.ContainsInputField(ctx.LastTag, typeof(TTarget), out Extension extension))
        {
            IExtensionValue value = extension.CreateValue();
            value.MergeFrom(ref ctx);
            set ??= new ExtensionSet<TTarget>();
            set.ValuesByNumber.Add(extension.FieldNumber, value);
            return true;
        }
        else
        {
            return false;
        }
    }

    /// <summary>
    /// Merges the second set into the first set, creating a new instance if first is null.
    /// Values present in both sets are merged. Values only in <paramref name="second"/> are cloned
    /// into <paramref name="first"/>. A null <paramref name="second"/> is a no-op.
    /// </summary>
    public static void MergeFrom<TTarget>(ref ExtensionSet<TTarget> first, ExtensionSet<TTarget> second) where TTarget : IExtendableMessage<TTarget>
    {
        if (second == null)
        {
            return;
        }
        if (first == null)
        {
            first = new ExtensionSet<TTarget>();
        }
        foreach (var pair in second.ValuesByNumber)
        {
            if (first.ValuesByNumber.TryGetValue(pair.Key, out IExtensionValue value))
            {
                value.MergeFrom(pair.Value);
            }
            else
            {
                // Clone so the two sets never share a mutable value instance.
                var cloned = pair.Value.Clone();
                first.ValuesByNumber[pair.Key] = cloned;
            }
        }
    }

    /// <summary>
    /// Clones the set into a new set by cloning every stored value. If the set is null, this returns null.
    /// </summary>
    public static ExtensionSet<TTarget> Clone<TTarget>(ExtensionSet<TTarget> set) where TTarget : IExtendableMessage<TTarget>
    {
        if (set == null)
        {
            return null;
        }

        var newSet = new ExtensionSet<TTarget>();
        foreach (var pair in set.ValuesByNumber)
        {
            var cloned = pair.Value.Clone();
            newSet.ValuesByNumber[pair.Key] = cloned;
        }
        return newSet;
    }
}
296 
/// <summary>
/// Used for keeping track of extensions in messages.
/// <see cref="IExtendableMessage{T}"/> methods route to this set.
///
/// Most users will not need to use this class directly
/// </summary>
/// <typeparam name="TTarget">The message type that extensions in this set target</typeparam>
public sealed class ExtensionSet<TTarget> where TTarget : IExtendableMessage<TTarget>
{
    /// <summary>
    /// The stored extension values, keyed by field number.
    /// </summary>
    internal Dictionary<int, IExtensionValue> ValuesByNumber { get; } = new Dictionary<int, IExtensionValue>();

    /// <summary>
    /// Gets a hash code of the set. The hash combines the target type with every (field number, value) pair
    /// and does not depend on enumeration order.
    /// </summary>
    public override int GetHashCode()
    {
        int ret = typeof(TTarget).GetHashCode();
        foreach (KeyValuePair<int, IExtensionValue> field in ValuesByNumber)
        {
            // Use ^ here to make the field order irrelevant.
            int hash = field.Key.GetHashCode() ^ field.Value.GetHashCode();
            ret ^= hash;
        }
        return ret;
    }

    /// <summary>
    /// Returns whether this set is equal to the other object
    /// </summary>
    /// <param name="other">The object to compare with; may be null.</param>
    /// <returns>
    /// True when <paramref name="other"/> is an <see cref="ExtensionSet{TTarget}"/> holding the same field
    /// numbers with equal values. False for null or any other type.
    /// </returns>
    public override bool Equals(object other)
    {
        if (ReferenceEquals(this, other))
        {
            return true;
        }
        // Object.Equals contract: null or a different type compares unequal instead of throwing.
        if (!(other is ExtensionSet<TTarget> otherSet))
        {
            return false;
        }
        if (ValuesByNumber.Count != otherSet.ValuesByNumber.Count)
        {
            return false;
        }
        foreach (var pair in ValuesByNumber)
        {
            if (!otherSet.ValuesByNumber.TryGetValue(pair.Key, out IExtensionValue secondValue))
            {
                return false;
            }
            if (!pair.Value.Equals(secondValue))
            {
                return false;
            }
        }
        return true;
    }

    /// <summary>
    /// Calculates the size of this extension set as the sum of each stored value's size.
    /// </summary>
    public int CalculateSize()
    {
        int size = 0;
        foreach (var value in ValuesByNumber.Values)
        {
            size += value.CalculateSize();
        }
        return size;
    }

    /// <summary>
    /// Writes the extension values in this set to the output stream
    /// </summary>
    public void WriteTo(CodedOutputStream stream)
    {
        WriteContext.Initialize(stream, out WriteContext ctx);
        try
        {
            WriteTo(ref ctx);
        }
        finally
        {
            // Propagate the context's state back into the stream even if writing throws.
            ctx.CopyStateTo(stream);
        }
    }

    /// <summary>
    /// Writes the extension values in this set to the write context
    /// </summary>
    [SecuritySafeCritical]
    public void WriteTo(ref WriteContext ctx)
    {
        foreach (var value in ValuesByNumber.Values)
        {
            value.WriteTo(ref ctx);
        }
    }

    /// <summary>
    /// Returns true when every stored extension value reports itself as initialized (vacuously true when empty).
    /// </summary>
    internal bool IsInitialized()
    {
        return ValuesByNumber.Values.All(v => v.IsInitialized());
    }
}
398 }
399