1 /*
2 * Copyright (c) 2009-2021, Google LLC
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * * Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * * Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * * Neither the name of Google LLC nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY DIRECT,
20 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #ifndef UPB_PROTOS_PROTOS_H_
29 #define UPB_PROTOS_PROTOS_H_
30
31 #include <type_traits>
32 #include <vector>
33
34 #include "absl/status/status.h"
35 #include "absl/status/statusor.h"
36 #include "upb/message/extension_internal.h"
37 #include "upb/mini_table/types.h"
38 #include "upb/upb.hpp"
39 #include "upb/wire/decode.h"
40 #include "upb/wire/encode.h"
41
42 namespace protos {
43
44 using Arena = ::upb::Arena;
45 class ExtensionRegistry;
46
47 template <typename T>
48 using Proxy = std::conditional_t<std::is_const<T>::value,
49 typename std::remove_const_t<T>::CProxy,
50 typename T::Proxy>;
51
52 // Provides convenient access to Proxy and CProxy message types.
53 //
54 // Using rebinding and handling of const, Ptr<Message> and Ptr<const Message>
55 // allows copying const with T* const and avoids using non-copyable Proxy types
56 // directly.
57 template <typename T>
58 class Ptr final {
59 public:
60 Ptr() = delete;
61
62 // Implicit conversions
Ptr(T * m)63 Ptr(T* m) : p_(m) {} // NOLINT
Ptr(const Proxy<T> * p)64 Ptr(const Proxy<T>* p) : p_(*p) {} // NOLINT
Ptr(Proxy<T> p)65 Ptr(Proxy<T> p) : p_(p) {} // NOLINT
66 Ptr(const Ptr& m) = default;
67
68 Ptr& operator=(Ptr v) & {
69 Proxy<T>::Rebind(p_, v.p_);
70 return *this;
71 }
72
73 Proxy<T> operator*() const { return p_; }
74 Proxy<T>* operator->() const {
75 return const_cast<Proxy<T>*>(std::addressof(p_));
76 }
77
78 #ifdef __clang__
79 #pragma clang diagnostic push
80 #pragma clang diagnostic ignored "-Wclass-conversion"
81 #endif
82 template <typename U = T, std::enable_if_t<!std::is_const<U>::value, int> = 0>
83 operator Ptr<const T>() const {
84 Proxy<const T> p(p_);
85 return Ptr<const T>(&p);
86 }
87 #ifdef __clang__
88 #pragma clang diagnostic pop
89 #endif
90
91 private:
Ptr(void * msg,upb_Arena * arena)92 Ptr(void* msg, upb_Arena* arena) : p_(msg, arena) {} // NOLINT
93
94 friend class Ptr<const T>;
95 friend typename T::Access;
96
97 Proxy<T> p_;
98 };
99
UpbStrToStringView(upb_StringView str)100 inline absl::string_view UpbStrToStringView(upb_StringView str) {
101 return absl::string_view(str.data, str.size);
102 }
103
104 // TODO: update bzl and move to upb runtime / protos.cc.
UpbStrFromStringView(absl::string_view str,upb_Arena * arena)105 inline upb_StringView UpbStrFromStringView(absl::string_view str,
106 upb_Arena* arena) {
107 const size_t str_size = str.size();
108 char* buffer = static_cast<char*>(upb_Arena_Malloc(arena, str_size));
109 memcpy(buffer, str.data(), str_size);
110 return upb_StringView_FromDataAndSize(buffer, str_size);
111 }
112
113 template <typename T>
CreateMessage(::protos::Arena & arena)114 typename T::Proxy CreateMessage(::protos::Arena& arena) {
115 return typename T::Proxy(upb_Message_New(T::minitable(), arena.ptr()),
116 arena.ptr());
117 }
118
119 template <typename T>
CloneMessage(Ptr<T> message,upb::Arena & arena)120 typename T::Proxy CloneMessage(Ptr<T> message, upb::Arena& arena) {
121 return typename T::Proxy(
122 upb_Message_DeepClone(message, T::minitable(), arena.ptr()), arena.ptr());
123 }
124
125 // begin:github_only
// Minimal stand-in for absl::SourceLocation, which has not yet been released
// in open-source Abseil.
128 struct SourceLocation {
currentSourceLocation129 static SourceLocation current() { return {}; }
file_nameSourceLocation130 absl::string_view file_name() { return "<unknown>"; }
lineSourceLocation131 int line() { return 0; }
132 };
133 // end:github_only
134
135 // begin:google_only
136 // using SourceLocation = absl::SourceLocation;
137 // end:google_only
138
139 absl::Status MessageAllocationError(
140 SourceLocation loc = SourceLocation::current());
141
142 absl::Status ExtensionNotFoundError(
143 int extension_number, SourceLocation loc = SourceLocation::current());
144
145 absl::Status MessageDecodeError(upb_DecodeStatus status,
146 SourceLocation loc = SourceLocation::current());
147
148 absl::Status MessageEncodeError(upb_EncodeStatus status,
149 SourceLocation loc = SourceLocation::current());
150
151 namespace internal {
// Returns a value-initialized T.
template <typename T>
T CreateMessage() {
  using Result = T;
  return Result();
}
156
157 template <typename T>
CreateMessageProxy(void * msg,upb_Arena * arena)158 typename T::Proxy CreateMessageProxy(void* msg, upb_Arena* arena) {
159 return typename T::Proxy(msg, arena);
160 }
161
162 template <typename T>
CreateMessage(upb_Message * msg)163 typename T::CProxy CreateMessage(upb_Message* msg) {
164 return typename T::CProxy(msg);
165 }
166
167 class ExtensionMiniTableProvider {
168 public:
ExtensionMiniTableProvider(const upb_MiniTableExtension * mini_table_ext)169 constexpr explicit ExtensionMiniTableProvider(
170 const upb_MiniTableExtension* mini_table_ext)
171 : mini_table_ext_(mini_table_ext) {}
mini_table_ext()172 const upb_MiniTableExtension* mini_table_ext() const {
173 return mini_table_ext_;
174 }
175
176 private:
177 const upb_MiniTableExtension* mini_table_ext_;
178 };
179
180 // -------------------------------------------------------------------
181 // ExtensionIdentifier
182 // This is the type of actual extension objects. E.g. if you have:
183 // extend Foo {
184 // optional MyExtension bar = 1234;
185 // }
186 // then "bar" will be defined in C++ as:
187 // ExtensionIdentifier<Foo, MyExtension> bar(&namespace_bar_ext);
188 template <typename ExtendeeType, typename ExtensionType>
189 class ExtensionIdentifier : public ExtensionMiniTableProvider {
190 public:
191 using Extension = ExtensionType;
192 using Extendee = ExtendeeType;
193
ExtensionIdentifier(const upb_MiniTableExtension * mini_table_ext)194 constexpr explicit ExtensionIdentifier(
195 const upb_MiniTableExtension* mini_table_ext)
196 : ExtensionMiniTableProvider(mini_table_ext) {}
197 };
198
199 template <typename T>
GetInternalMsg(const T & message)200 void* GetInternalMsg(const T& message) {
201 return message.msg();
202 }
203
204 template <typename T>
GetInternalMsg(const Ptr<T> & message)205 void* GetInternalMsg(const Ptr<T>& message) {
206 return message->msg();
207 }
208
209 template <typename T>
GetArena(const T & message)210 upb_Arena* GetArena(const T& message) {
211 return static_cast<upb_Arena*>(message.GetInternalArena());
212 }
213
214 template <typename T>
GetArena(const Ptr<T> & message)215 upb_Arena* GetArena(const Ptr<T>& message) {
216 return static_cast<upb_Arena*>(message->GetInternalArena());
217 }
218
219 upb_ExtensionRegistry* GetUpbExtensions(
220 const ExtensionRegistry& extension_registry);
221
222 absl::StatusOr<absl::string_view> Serialize(const upb_Message* message,
223 const upb_MiniTable* mini_table,
224 upb_Arena* arena, int options);
225
226 } // namespace internal
227
228 class ExtensionRegistry {
229 public:
ExtensionRegistry(const std::vector<const::protos::internal::ExtensionMiniTableProvider * > & extensions,const upb::Arena & arena)230 ExtensionRegistry(
231 const std::vector<const ::protos::internal::ExtensionMiniTableProvider*>&
232 extensions,
233 const upb::Arena& arena)
234 : registry_(upb_ExtensionRegistry_New(arena.ptr())) {
235 if (registry_) {
236 for (const auto& ext_provider : extensions) {
237 const auto* ext = ext_provider->mini_table_ext();
238 bool success = upb_ExtensionRegistry_AddArray(registry_, &ext, 1);
239 if (!success) {
240 registry_ = nullptr;
241 break;
242 }
243 }
244 }
245 }
246
247 private:
248 friend upb_ExtensionRegistry* ::protos::internal::GetUpbExtensions(
249 const ExtensionRegistry& extension_registry);
250 upb_ExtensionRegistry* registry_;
251 };
252
// SFINAE helper that is well-formed (as void) only when both T and
// T::ExtendableType derive from T::Access. It is intended to match generated
// protos message classes.
template <typename T>
using EnableIfProtosClass = std::enable_if_t<
    std::is_base_of<typename T::Access, T>::value &&
        std::is_base_of<typename T::Access, typename T::ExtendableType>::value,
    void>;

// SFINAE helper that is well-formed (as void) only when T is not
// const-qualified.
template <typename T>
using EnableIfMutableProto = std::enable_if_t<!std::is_const<T>::value, void>;
260
261 template <typename T, typename Extendee, typename Extension,
262 typename = EnableIfProtosClass<T>>
HasExtension(const T & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)263 bool HasExtension(
264 const T& message,
265 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
266 return _upb_Message_Getext(message.msg(), id.mini_table_ext()) != nullptr;
267 }
268
269 template <typename T, typename Extendee, typename Extension,
270 typename = EnableIfProtosClass<T>>
HasExtension(const Ptr<T> & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)271 bool HasExtension(
272 const Ptr<T>& message,
273 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
274 return _upb_Message_Getext(message->msg(), id.mini_table_ext()) != nullptr;
275 }
276
277 template <typename T, typename Extendee, typename Extension,
278 typename = EnableIfProtosClass<T>, typename = EnableIfMutableProto<T>>
ClearExtension(const Ptr<T> & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)279 void ClearExtension(
280 const Ptr<T>& message,
281 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
282 _upb_Message_ClearExtensionField(message->msg(), id.mini_table_ext());
283 }
284
285 template <typename T, typename Extendee, typename Extension,
286 typename = EnableIfProtosClass<T>>
ClearExtension(const T & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)287 void ClearExtension(
288 const T& message,
289 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
290 _upb_Message_ClearExtensionField(message.msg(), id.mini_table_ext());
291 }
292
293 template <typename T, typename Extendee, typename Extension,
294 typename = EnableIfProtosClass<T>>
SetExtension(const T & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id,Extension & value)295 absl::Status SetExtension(
296 const T& message,
297 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
298 Extension& value) {
299 auto* message_arena = static_cast<upb_Arena*>(message.GetInternalArena());
300 upb_Message_Extension* msg_ext = _upb_Message_GetOrCreateExtension(
301 message.msg(), id.mini_table_ext(), message_arena);
302 if (!msg_ext) {
303 return MessageAllocationError();
304 }
305 auto* extension_arena = static_cast<upb_Arena*>(value.GetInternalArena());
306 if (message_arena != extension_arena) {
307 upb_Arena_Fuse(message_arena, extension_arena);
308 }
309 msg_ext->data.ptr = value.msg();
310 return absl::OkStatus();
311 }
312
313 template <typename T, typename Extendee, typename Extension,
314 typename = EnableIfProtosClass<T>, typename = EnableIfMutableProto<T>>
SetExtension(const Ptr<T> & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id,Extension & value)315 absl::Status SetExtension(
316 const Ptr<T>& message,
317 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
318 Extension& value) {
319 auto* message_arena = static_cast<upb_Arena*>(message->GetInternalArena());
320 upb_Message_Extension* msg_ext = _upb_Message_GetOrCreateExtension(
321 message->msg(), id.mini_table_ext(), message_arena);
322 if (!msg_ext) {
323 return MessageAllocationError();
324 }
325 auto* extension_arena = static_cast<upb_Arena*>(message->GetInternalArena());
326 if (message_arena != extension_arena) {
327 upb_Arena_Fuse(message_arena, extension_arena);
328 }
329 msg_ext->data.ptr = ::protos::internal::GetInternalMsg(value);
330 return absl::OkStatus();
331 }
332
333 template <typename T, typename Extendee, typename Extension,
334 typename = EnableIfProtosClass<T>>
GetExtension(const T & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)335 absl::StatusOr<Ptr<const Extension>> GetExtension(
336 const T& message,
337 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
338 const upb_Message_Extension* ext =
339 _upb_Message_Getext(message.msg(), id.mini_table_ext());
340 if (!ext) {
341 return ExtensionNotFoundError(id.mini_table_ext()->field.number);
342 }
343 return Ptr<const Extension>(
344 ::protos::internal::CreateMessage<Extension>(ext->data.ptr));
345 }
346
347 template <typename T, typename Extendee, typename Extension,
348 typename = EnableIfProtosClass<T>>
GetExtension(const Ptr<T> & message,const::protos::internal::ExtensionIdentifier<Extendee,Extension> & id)349 absl::StatusOr<Ptr<const Extension>> GetExtension(
350 const Ptr<T>& message,
351 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
352 const upb_Message_Extension* ext =
353 _upb_Message_Getext(message->msg(), id.mini_table_ext());
354 if (!ext) {
355 return ExtensionNotFoundError(id.mini_table_ext()->field.number);
356 }
357 return Ptr<const Extension>(
358 ::protos::internal::CreateMessage<Extension>(ext->data.ptr));
359 }
360
361 template <typename T>
Parse(T & message,absl::string_view bytes)362 bool Parse(T& message, absl::string_view bytes) {
363 _upb_Message_Clear(message.msg(), T::minitable());
364 auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
365 return upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(),
366 /* extreg= */ nullptr, /* options= */ 0,
367 arena) == kUpb_DecodeStatus_Ok;
368 }
369
370 template <typename T>
Parse(T & message,absl::string_view bytes,const::protos::ExtensionRegistry & extension_registry)371 bool Parse(T& message, absl::string_view bytes,
372 const ::protos::ExtensionRegistry& extension_registry) {
373 _upb_Message_Clear(message.msg(), T::minitable());
374 auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
375 return upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(),
376 /* extreg= */
377 ::protos::internal::GetUpbExtensions(extension_registry),
378 /* options= */ 0, arena) == kUpb_DecodeStatus_Ok;
379 }
380
381 template <typename T>
Parse(Ptr<T> & message,absl::string_view bytes)382 bool Parse(Ptr<T>& message, absl::string_view bytes) {
383 _upb_Message_Clear(message->msg(), T::minitable());
384 auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
385 return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
386 /* extreg= */ nullptr, /* options= */ 0,
387 arena) == kUpb_DecodeStatus_Ok;
388 }
389
390 template <typename T>
Parse(Ptr<T> & message,absl::string_view bytes,const::protos::ExtensionRegistry & extension_registry)391 bool Parse(Ptr<T>& message, absl::string_view bytes,
392 const ::protos::ExtensionRegistry& extension_registry) {
393 _upb_Message_Clear(message->msg(), T::minitable());
394 auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
395 return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
396 /* extreg= */
397 ::protos::internal::GetUpbExtensions(extension_registry),
398 /* options= */ 0, arena) == kUpb_DecodeStatus_Ok;
399 }
400
401 template <typename T>
Parse(std::unique_ptr<T> & message,absl::string_view bytes)402 bool Parse(std::unique_ptr<T>& message, absl::string_view bytes) {
403 _upb_Message_Clear(message->msg(), T::minitable());
404 auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
405 return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
406 /* extreg= */ nullptr, /* options= */ 0,
407 arena) == kUpb_DecodeStatus_Ok;
408 }
409
410 template <typename T>
Parse(std::unique_ptr<T> & message,absl::string_view bytes,const::protos::ExtensionRegistry & extension_registry)411 bool Parse(std::unique_ptr<T>& message, absl::string_view bytes,
412 const ::protos::ExtensionRegistry& extension_registry) {
413 _upb_Message_Clear(message->msg(), T::minitable());
414 auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
415 return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
416 /* extreg= */
417 ::protos::internal::GetUpbExtensions(extension_registry),
418 /* options= */ 0, arena) == kUpb_DecodeStatus_Ok;
419 }
420
421 template <typename T>
Parse(std::shared_ptr<T> & message,absl::string_view bytes)422 bool Parse(std::shared_ptr<T>& message, absl::string_view bytes) {
423 _upb_Message_Clear(message->msg(), T::minitable());
424 auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
425 return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
426 /* extreg= */ nullptr, /* options= */ 0,
427 arena) == kUpb_DecodeStatus_Ok;
428 }
429
430 template <typename T>
Parse(std::shared_ptr<T> & message,absl::string_view bytes,const::protos::ExtensionRegistry & extension_registry)431 bool Parse(std::shared_ptr<T>& message, absl::string_view bytes,
432 const ::protos::ExtensionRegistry& extension_registry) {
433 _upb_Message_Clear(message->msg(), T::minitable());
434 auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
435 return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(),
436 /* extreg= */
437 ::protos::internal::GetUpbExtensions(extension_registry),
438 /* options= */ 0, arena) == kUpb_DecodeStatus_Ok;
439 }
440
441 template <typename T>
442 absl::StatusOr<T> Parse(absl::string_view bytes, int options = 0) {
443 T message;
444 auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
445 upb_DecodeStatus status =
446 upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(),
447 /* extreg= */ nullptr, /* options= */ 0, arena);
448 if (status == kUpb_DecodeStatus_Ok) {
449 return message;
450 }
451 return MessageDecodeError(status);
452 }
453
454 template <typename T>
455 absl::StatusOr<T> Parse(absl::string_view bytes,
456 const ::protos::ExtensionRegistry& extension_registry,
457 int options = 0) {
458 T message;
459 auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
460 upb_DecodeStatus status =
461 upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(),
462 ::protos::internal::GetUpbExtensions(extension_registry),
463 /* options= */ 0, arena);
464 if (status == kUpb_DecodeStatus_Ok) {
465 return message;
466 }
467 return MessageDecodeError(status);
468 }
469
470 template <typename T>
471 absl::StatusOr<absl::string_view> Serialize(const T& message, upb::Arena& arena,
472 int options = 0) {
473 return ::protos::internal::Serialize(
474 ::protos::internal::GetInternalMsg(message), T::minitable(), arena.ptr(),
475 options);
476 }
477
478 template <typename T>
479 absl::StatusOr<absl::string_view> Serialize(std::unique_ptr<T>& message,
480 upb::Arena& arena,
481 int options = 0) {
482 return ::protos::internal::Serialize(message->msg(), T::minitable(),
483 arena.ptr(), options);
484 }
485
486 template <typename T>
487 absl::StatusOr<absl::string_view> Serialize(std::shared_ptr<T>& message,
488 upb::Arena& arena,
489 int options = 0) {
490 return ::protos::internal::Serialize(message->msg(), T::minitable(),
491 arena.ptr(), options);
492 }
493
494 template <typename T>
495 absl::StatusOr<absl::string_view> Serialize(Ptr<T> message, upb::Arena& arena,
496 int options = 0) {
497 return ::protos::internal::Serialize(message->msg(), T::minitable(),
498 arena.ptr(), options);
499 }
500
501 } // namespace protos
502
503 #endif // UPB_PROTOS_PROTOS_H_
504