• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2009-2021, Google LLC
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *     * Redistributions of source code must retain the above copyright
8  *       notice, this list of conditions and the following disclaimer.
9  *     * Redistributions in binary form must reproduce the above copyright
10  *       notice, this list of conditions and the following disclaimer in the
11  *       documentation and/or other materials provided with the distribution.
12  *     * Neither the name of Google LLC nor the
13  *       names of its contributors may be used to endorse or promote products
14  *       derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  * ARE DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY DIRECT,
20  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27 
28 #ifndef UPB_PROTOS_PROTOS_H_
29 #define UPB_PROTOS_PROTOS_H_
30 
#include <cstring>
#include <type_traits>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "upb/message/extension_internal.h"
#include "upb/mini_table/types.h"
#include "upb/upb.hpp"
#include "upb/wire/decode.h"
#include "upb/wire/encode.h"
41 
42 namespace protos {
43 
44 using Arena = ::upb::Arena;
45 class ExtensionRegistry;
46 
47 template <typename T>
48 using Proxy = std::conditional_t<std::is_const<T>::value,
49                                  typename std::remove_const_t<T>::CProxy,
50                                  typename T::Proxy>;
51 
52 // Provides convenient access to Proxy and CProxy message types.
53 //
54 // Using rebinding and handling of const, Ptr<Message> and Ptr<const Message>
55 // allows copying const with T* const and avoids using non-copyable Proxy types
56 // directly.
57 template <typename T>
58 class Ptr final {
59  public:
60   Ptr() = delete;
61 
62   // Implicit conversions
Ptr(T * m)63   Ptr(T* m) : p_(m) {}                // NOLINT
Ptr(const Proxy<T> * p)64   Ptr(const Proxy<T>* p) : p_(*p) {}  // NOLINT
Ptr(Proxy<T> p)65   Ptr(Proxy<T> p) : p_(p) {}          // NOLINT
66   Ptr(const Ptr& m) = default;
67 
68   Ptr& operator=(Ptr v) & {
69     Proxy<T>::Rebind(p_, v.p_);
70     return *this;
71   }
72 
73   Proxy<T> operator*() const { return p_; }
74   Proxy<T>* operator->() const {
75     return const_cast<Proxy<T>*>(std::addressof(p_));
76   }
77 
78 #ifdef __clang__
79 #pragma clang diagnostic push
80 #pragma clang diagnostic ignored "-Wclass-conversion"
81 #endif
82   template <typename U = T, std::enable_if_t<!std::is_const<U>::value, int> = 0>
83   operator Ptr<const T>() const {
84     Proxy<const T> p(p_);
85     return Ptr<const T>(&p);
86   }
87 #ifdef __clang__
88 #pragma clang diagnostic pop
89 #endif
90 
91  private:
Ptr(void * msg,upb_Arena * arena)92   Ptr(void* msg, upb_Arena* arena) : p_(msg, arena) {}  // NOLINT
93 
94   friend class Ptr<const T>;
95   friend typename T::Access;
96 
97   Proxy<T> p_;
98 };
99 
UpbStrToStringView(upb_StringView str)100 inline absl::string_view UpbStrToStringView(upb_StringView str) {
101   return absl::string_view(str.data, str.size);
102 }
103 
104 // TODO: update bzl and move to upb runtime / protos.cc.
UpbStrFromStringView(absl::string_view str,upb_Arena * arena)105 inline upb_StringView UpbStrFromStringView(absl::string_view str,
106                                            upb_Arena* arena) {
107   const size_t str_size = str.size();
108   char* buffer = static_cast<char*>(upb_Arena_Malloc(arena, str_size));
109   memcpy(buffer, str.data(), str_size);
110   return upb_StringView_FromDataAndSize(buffer, str_size);
111 }
112 
113 template <typename T>
CreateMessage(::protos::Arena & arena)114 typename T::Proxy CreateMessage(::protos::Arena& arena) {
115   return typename T::Proxy(upb_Message_New(T::minitable(), arena.ptr()),
116                            arena.ptr());
117 }
118 
119 template <typename T>
CloneMessage(Ptr<T> message,upb::Arena & arena)120 typename T::Proxy CloneMessage(Ptr<T> message, upb::Arena& arena) {
121   return typename T::Proxy(
122       upb_Message_DeepClone(message, T::minitable(), arena.ptr()), arena.ptr());
123 }
124 
125 // begin:github_only
126 // This type exists to work around an absl type that has not yet been
127 // released.
128 struct SourceLocation {
currentSourceLocation129   static SourceLocation current() { return {}; }
file_nameSourceLocation130   absl::string_view file_name() { return "<unknown>"; }
lineSourceLocation131   int line() { return 0; }
132 };
133 // end:github_only
134 
135 // begin:google_only
136 // using SourceLocation = absl::SourceLocation;
137 // end:google_only
138 
139 absl::Status MessageAllocationError(
140     SourceLocation loc = SourceLocation::current());
141 
142 absl::Status ExtensionNotFoundError(
143     int extension_number, SourceLocation loc = SourceLocation::current());
144 
145 absl::Status MessageDecodeError(upb_DecodeStatus status,
146                                 SourceLocation loc = SourceLocation::current());
147 
148 absl::Status MessageEncodeError(upb_EncodeStatus status,
149                                 SourceLocation loc = SourceLocation::current());
150 
151 namespace internal {
// Returns a value-initialized T.
template <typename T>
T CreateMessage() {
  return T();
}
156 
157 template <typename T>
CreateMessageProxy(void * msg,upb_Arena * arena)158 typename T::Proxy CreateMessageProxy(void* msg, upb_Arena* arena) {
159   return typename T::Proxy(msg, arena);
160 }
161 
162 template <typename T>
CreateMessage(upb_Message * msg)163 typename T::CProxy CreateMessage(upb_Message* msg) {
164   return typename T::CProxy(msg);
165 }
166 
167 class ExtensionMiniTableProvider {
168  public:
ExtensionMiniTableProvider(const upb_MiniTableExtension * mini_table_ext)169   constexpr explicit ExtensionMiniTableProvider(
170       const upb_MiniTableExtension* mini_table_ext)
171       : mini_table_ext_(mini_table_ext) {}
mini_table_ext()172   const upb_MiniTableExtension* mini_table_ext() const {
173     return mini_table_ext_;
174   }
175 
176  private:
177   const upb_MiniTableExtension* mini_table_ext_;
178 };
179 
180 // -------------------------------------------------------------------
181 // ExtensionIdentifier
182 // This is the type of actual extension objects.  E.g. if you have:
183 //   extend Foo {
184 //     optional MyExtension bar = 1234;
185 //   }
186 // then "bar" will be defined in C++ as:
187 //   ExtensionIdentifier<Foo, MyExtension> bar(&namespace_bar_ext);
188 template <typename ExtendeeType, typename ExtensionType>
189 class ExtensionIdentifier : public ExtensionMiniTableProvider {
190  public:
191   using Extension = ExtensionType;
192   using Extendee = ExtendeeType;
193 
ExtensionIdentifier(const upb_MiniTableExtension * mini_table_ext)194   constexpr explicit ExtensionIdentifier(
195       const upb_MiniTableExtension* mini_table_ext)
196       : ExtensionMiniTableProvider(mini_table_ext) {}
197 };
198 
// Returns the raw message pointer exposed by `message.msg()`.
template <typename T>
void* GetInternalMsg(const T& message) {
  void* const raw = message.msg();
  return raw;
}
203 
204 template <typename T>
GetInternalMsg(const Ptr<T> & message)205 void* GetInternalMsg(const Ptr<T>& message) {
206   return message->msg();
207 }
208 
209 template <typename T>
GetArena(const T & message)210 upb_Arena* GetArena(const T& message) {
211   return static_cast<upb_Arena*>(message.GetInternalArena());
212 }
213 
214 template <typename T>
GetArena(const Ptr<T> & message)215 upb_Arena* GetArena(const Ptr<T>& message) {
216   return static_cast<upb_Arena*>(message->GetInternalArena());
217 }
218 
219 upb_ExtensionRegistry* GetUpbExtensions(
220     const ExtensionRegistry& extension_registry);
221 
222 absl::StatusOr<absl::string_view> Serialize(const upb_Message* message,
223                                             const upb_MiniTable* mini_table,
224                                             upb_Arena* arena, int options);
225 
226 }  // namespace internal
227 
228 class ExtensionRegistry {
229  public:
ExtensionRegistry(const std::vector<const::protos::internal::ExtensionMiniTableProvider * > & extensions,const upb::Arena & arena)230   ExtensionRegistry(
231       const std::vector<const ::protos::internal::ExtensionMiniTableProvider*>&
232           extensions,
233       const upb::Arena& arena)
234       : registry_(upb_ExtensionRegistry_New(arena.ptr())) {
235     if (registry_) {
236       for (const auto& ext_provider : extensions) {
237         const auto* ext = ext_provider->mini_table_ext();
238         bool success = upb_ExtensionRegistry_AddArray(registry_, &ext, 1);
239         if (!success) {
240           registry_ = nullptr;
241           break;
242         }
243       }
244     }
245   }
246 
247  private:
248   friend upb_ExtensionRegistry* ::protos::internal::GetUpbExtensions(
249       const ExtensionRegistry& extension_registry);
250   upb_ExtensionRegistry* registry_;
251 };
252 
// SFINAE guard: well-formed (void) only when both T::ExtendableType and T
// derive from T::Access — presumably the marker of generated protos classes.
template <typename T>
using EnableIfProtosClass = std::enable_if_t<
    std::is_base_of<typename T::Access, typename T::ExtendableType>::value &&
    std::is_base_of<typename T::Access, T>::value>;

// SFINAE guard: well-formed (void) only for a non-const T.
template <typename T>
using EnableIfMutableProto = std::enable_if_t<!std::is_const<T>::value>;
260 
261 template <typename T, typename Extendee, typename Extension,
262           typename = EnableIfProtosClass<T>>
HasExtension(const T & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)263 bool HasExtension(
264     const T& message,
265     const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
266   return _upb_Message_Getext(message.msg(), id.mini_table_ext()) != nullptr;
267 }
268 
269 template <typename T, typename Extendee, typename Extension,
270           typename = EnableIfProtosClass<T>>
HasExtension(const Ptr<T> & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)271 bool HasExtension(
272     const Ptr<T>& message,
273     const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
274   return _upb_Message_Getext(message->msg(), id.mini_table_ext()) != nullptr;
275 }
276 
277 template <typename T, typename Extendee, typename Extension,
278           typename = EnableIfProtosClass<T>, typename = EnableIfMutableProto<T>>
ClearExtension(const Ptr<T> & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)279 void ClearExtension(
280     const Ptr<T>& message,
281     const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
282   _upb_Message_ClearExtensionField(message->msg(), id.mini_table_ext());
283 }
284 
285 template <typename T, typename Extendee, typename Extension,
286           typename = EnableIfProtosClass<T>>
ClearExtension(const T & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)287 void ClearExtension(
288     const T& message,
289     const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
290   _upb_Message_ClearExtensionField(message.msg(), id.mini_table_ext());
291 }
292 
293 template <typename T, typename Extendee, typename Extension,
294           typename = EnableIfProtosClass<T>>
SetExtension(const T & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id,Extension & value)295 absl::Status SetExtension(
296     const T& message,
297     const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
298     Extension& value) {
299   auto* message_arena = static_cast<upb_Arena*>(message.GetInternalArena());
300   upb_Message_Extension* msg_ext = _upb_Message_GetOrCreateExtension(
301       message.msg(), id.mini_table_ext(), message_arena);
302   if (!msg_ext) {
303     return MessageAllocationError();
304   }
305   auto* extension_arena = static_cast<upb_Arena*>(value.GetInternalArena());
306   if (message_arena != extension_arena) {
307     upb_Arena_Fuse(message_arena, extension_arena);
308   }
309   msg_ext->data.ptr = value.msg();
310   return absl::OkStatus();
311 }
312 
313 template <typename T, typename Extendee, typename Extension,
314           typename = EnableIfProtosClass<T>, typename = EnableIfMutableProto<T>>
SetExtension(const Ptr<T> & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id,Extension & value)315 absl::Status SetExtension(
316     const Ptr<T>& message,
317     const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
318     Extension& value) {
319   auto* message_arena = static_cast<upb_Arena*>(message->GetInternalArena());
320   upb_Message_Extension* msg_ext = _upb_Message_GetOrCreateExtension(
321       message->msg(), id.mini_table_ext(), message_arena);
322   if (!msg_ext) {
323     return MessageAllocationError();
324   }
325   auto* extension_arena = static_cast<upb_Arena*>(message->GetInternalArena());
326   if (message_arena != extension_arena) {
327     upb_Arena_Fuse(message_arena, extension_arena);
328   }
329   msg_ext->data.ptr = ::protos::internal::GetInternalMsg(value);
330   return absl::OkStatus();
331 }
332 
333 template <typename T, typename Extendee, typename Extension,
334           typename = EnableIfProtosClass<T>>
GetExtension(const T & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)335 absl::StatusOr<Ptr<const Extension>> GetExtension(
336     const T& message,
337     const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
338   const upb_Message_Extension* ext =
339       _upb_Message_Getext(message.msg(), id.mini_table_ext());
340   if (!ext) {
341     return ExtensionNotFoundError(id.mini_table_ext()->field.number);
342   }
343   return Ptr<const Extension>(
344       ::protos::internal::CreateMessage<Extension>(ext->data.ptr));
345 }
346 
347 template <typename T, typename Extendee, typename Extension,
348           typename = EnableIfProtosClass<T>>
GetExtension(const Ptr<T> & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)349 absl::StatusOr<Ptr<const Extension>> GetExtension(
350     const Ptr<T>& message,
351     const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
352   const upb_Message_Extension* ext =
353       _upb_Message_Getext(message->msg(), id.mini_table_ext());
354   if (!ext) {
355     return ExtensionNotFoundError(id.mini_table_ext()->field.number);
356   }
357   return Ptr<const Extension>(
358       ::protos::internal::CreateMessage<Extension>(ext->data.ptr));
359 }
360 
361 template <typename T>
Parse(T & message,absl::string_view bytes)362 bool Parse(T& message, absl::string_view bytes) {
363   _upb_Message_Clear(message.msg(), T::minitable());
364   auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
365   return upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(),
366                     /* extreg= */ nullptr, /* options= */ 0,
367                     arena) == kUpb_DecodeStatus_Ok;
368 }
369 
370 template <typename T>
Parse(T & message,absl::string_view bytes,const::protos::ExtensionRegistry & extension_registry)371 bool Parse(T& message, absl::string_view bytes,
372            const ::protos::ExtensionRegistry& extension_registry) {
373   _upb_Message_Clear(message.msg(), T::minitable());
374   auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
375   return upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(),
376                     /* extreg= */
377                     ::protos::internal::GetUpbExtensions(extension_registry),
378                     /* options= */ 0, arena) == kUpb_DecodeStatus_Ok;
379 }
380 
381 template <typename T>
Parse(Ptr<T> & message,absl::string_view bytes)382 bool Parse(Ptr<T>& message, absl::string_view bytes) {
383   _upb_Message_Clear(message->msg(), T::minitable());
384   auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
385   return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
386                     /* extreg= */ nullptr, /* options= */ 0,
387                     arena) == kUpb_DecodeStatus_Ok;
388 }
389 
390 template <typename T>
Parse(Ptr<T> & message,absl::string_view bytes,const::protos::ExtensionRegistry & extension_registry)391 bool Parse(Ptr<T>& message, absl::string_view bytes,
392            const ::protos::ExtensionRegistry& extension_registry) {
393   _upb_Message_Clear(message->msg(), T::minitable());
394   auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
395   return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
396                     /* extreg= */
397                     ::protos::internal::GetUpbExtensions(extension_registry),
398                     /* options= */ 0, arena) == kUpb_DecodeStatus_Ok;
399 }
400 
401 template <typename T>
Parse(std::unique_ptr<T> & message,absl::string_view bytes)402 bool Parse(std::unique_ptr<T>& message, absl::string_view bytes) {
403   _upb_Message_Clear(message->msg(), T::minitable());
404   auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
405   return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
406                     /* extreg= */ nullptr, /* options= */ 0,
407                     arena) == kUpb_DecodeStatus_Ok;
408 }
409 
410 template <typename T>
Parse(std::unique_ptr<T> & message,absl::string_view bytes,const::protos::ExtensionRegistry & extension_registry)411 bool Parse(std::unique_ptr<T>& message, absl::string_view bytes,
412            const ::protos::ExtensionRegistry& extension_registry) {
413   _upb_Message_Clear(message->msg(), T::minitable());
414   auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
415   return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
416                     /* extreg= */
417                     ::protos::internal::GetUpbExtensions(extension_registry),
418                     /* options= */ 0, arena) == kUpb_DecodeStatus_Ok;
419 }
420 
421 template <typename T>
Parse(std::shared_ptr<T> & message,absl::string_view bytes)422 bool Parse(std::shared_ptr<T>& message, absl::string_view bytes) {
423   _upb_Message_Clear(message->msg(), T::minitable());
424   auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
425   return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
426                     /* extreg= */ nullptr, /* options= */ 0,
427                     arena) == kUpb_DecodeStatus_Ok;
428 }
429 
430 template <typename T>
Parse(std::shared_ptr<T> & message,absl::string_view bytes,const::protos::ExtensionRegistry & extension_registry)431 bool Parse(std::shared_ptr<T>& message, absl::string_view bytes,
432            const ::protos::ExtensionRegistry& extension_registry) {
433   _upb_Message_Clear(message->msg(), T::minitable());
434   auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
435   return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
436                     /* extreg= */
437                     ::protos::internal::GetUpbExtensions(extension_registry),
438                     /* options= */ 0, arena) == kUpb_DecodeStatus_Ok;
439 }
440 
441 template <typename T>
442 absl::StatusOr<T> Parse(absl::string_view bytes, int options = 0) {
443   T message;
444   auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
445   upb_DecodeStatus status =
446       upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(),
447                  /* extreg= */ nullptr, /* options= */ 0, arena);
448   if (status == kUpb_DecodeStatus_Ok) {
449     return message;
450   }
451   return MessageDecodeError(status);
452 }
453 
454 template <typename T>
455 absl::StatusOr<T> Parse(absl::string_view bytes,
456                         const ::protos::ExtensionRegistry& extension_registry,
457                         int options = 0) {
458   T message;
459   auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
460   upb_DecodeStatus status =
461       upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(),
462                  ::protos::internal::GetUpbExtensions(extension_registry),
463                  /* options= */ 0, arena);
464   if (status == kUpb_DecodeStatus_Ok) {
465     return message;
466   }
467   return MessageDecodeError(status);
468 }
469 
470 template <typename T>
471 absl::StatusOr<absl::string_view> Serialize(const T& message, upb::Arena& arena,
472                                             int options = 0) {
473   return ::protos::internal::Serialize(
474       ::protos::internal::GetInternalMsg(message), T::minitable(), arena.ptr(),
475       options);
476 }
477 
478 template <typename T>
479 absl::StatusOr<absl::string_view> Serialize(std::unique_ptr<T>& message,
480                                             upb::Arena& arena,
481                                             int options = 0) {
482   return ::protos::internal::Serialize(message->msg(), T::minitable(),
483                                        arena.ptr(), options);
484 }
485 
486 template <typename T>
487 absl::StatusOr<absl::string_view> Serialize(std::shared_ptr<T>& message,
488                                             upb::Arena& arena,
489                                             int options = 0) {
490   return ::protos::internal::Serialize(message->msg(), T::minitable(),
491                                        arena.ptr(), options);
492 }
493 
494 template <typename T>
495 absl::StatusOr<absl::string_view> Serialize(Ptr<T> message, upb::Arena& arena,
496                                             int options = 0) {
497   return ::protos::internal::Serialize(message->msg(), T::minitable(),
498                                        arena.ptr(), options);
499 }
500 
501 }  // namespace protos
502 
503 #endif  // UPB_PROTOS_PROTOS_H_
504