1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2023 Google LLC. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7
8 #include "google/protobuf/hpb/hpb.h"
9
10 #include <atomic>
11 #include <cstddef>
12
13 #include "absl/status/status.h"
14 #include "absl/status/statusor.h"
15 #include "absl/strings/str_format.h"
16 #include "absl/strings/string_view.h"
17 #include "google/protobuf/hpb/internal/message_lock.h"
18 #include "upb/mem/arena.h"
19 #include "upb/message/accessors.h"
20 #include "upb/message/copy.h"
21 #include "upb/message/message.h"
22 #include "upb/message/promote.h"
23 #include "upb/message/value.h"
24 #include "upb/mini_table/extension.h"
25 #include "upb/mini_table/extension_registry.h"
26 #include "upb/mini_table/message.h"
27 #include "upb/wire/decode.h"
28 #include "upb/wire/encode.h"
29
30 namespace hpb {
31
MessageAllocationError(SourceLocation loc)32 absl::Status MessageAllocationError(SourceLocation loc) {
33 return absl::Status(absl::StatusCode::kUnknown,
34 "Upb message allocation error");
35 }
36
ExtensionNotFoundError(int ext_number,SourceLocation loc)37 absl::Status ExtensionNotFoundError(int ext_number, SourceLocation loc) {
38 return absl::Status(absl::StatusCode::kUnknown,
39 absl::StrFormat("Extension %d not found", ext_number));
40 }
41
MessageEncodeError(upb_EncodeStatus s,SourceLocation loc)42 absl::Status MessageEncodeError(upb_EncodeStatus s, SourceLocation loc) {
43 return absl::Status(absl::StatusCode::kUnknown, "Encoding error");
44 }
45
MessageDecodeError(upb_DecodeStatus status,SourceLocation loc)46 absl::Status MessageDecodeError(upb_DecodeStatus status, SourceLocation loc
47
48 ) {
49 return absl::Status(absl::StatusCode::kUnknown, "Upb message parse error");
50 }
51
52 namespace internal {
53
54 /**
55 * MessageLock(msg) acquires lock on msg when constructed and releases it when
56 * destroyed.
57 */
58 class MessageLock {
59 public:
MessageLock(const upb_Message * msg)60 explicit MessageLock(const upb_Message* msg) : msg_(msg) {
61 UpbExtensionLocker locker =
62 upb_extension_locker_global.load(std::memory_order_acquire);
63 unlocker_ = (locker != nullptr) ? locker(msg) : nullptr;
64 }
65 MessageLock(const MessageLock&) = delete;
66 void operator=(const MessageLock&) = delete;
~MessageLock()67 ~MessageLock() {
68 if (unlocker_ != nullptr) {
69 unlocker_(msg_);
70 }
71 }
72
73 private:
74 const upb_Message* msg_;
75 UpbExtensionUnlocker unlocker_;
76 };
77
HasExtensionOrUnknown(const upb_Message * msg,const upb_MiniTableExtension * eid)78 bool HasExtensionOrUnknown(const upb_Message* msg,
79 const upb_MiniTableExtension* eid) {
80 MessageLock msg_lock(msg);
81 if (upb_Message_HasExtension(msg, eid)) return true;
82
83 const int number = upb_MiniTableExtension_Number(eid);
84 return upb_Message_FindUnknown(msg, number, 0).status == kUpb_FindUnknown_Ok;
85 }
86
GetOrPromoteExtension(upb_Message * msg,const upb_MiniTableExtension * eid,upb_Arena * arena,upb_MessageValue * value)87 bool GetOrPromoteExtension(upb_Message* msg, const upb_MiniTableExtension* eid,
88 upb_Arena* arena, upb_MessageValue* value) {
89 MessageLock msg_lock(msg);
90 upb_GetExtension_Status ext_status = upb_Message_GetOrPromoteExtension(
91 (upb_Message*)msg, eid, 0, arena, value);
92 return ext_status == kUpb_GetExtension_Ok;
93 }
94
Serialize(const upb_Message * message,const upb_MiniTable * mini_table,upb_Arena * arena,int options)95 absl::StatusOr<absl::string_view> Serialize(const upb_Message* message,
96 const upb_MiniTable* mini_table,
97 upb_Arena* arena, int options) {
98 MessageLock msg_lock(message);
99 size_t len;
100 char* ptr;
101 upb_EncodeStatus status =
102 upb_Encode(message, mini_table, options, arena, &ptr, &len);
103 if (status == kUpb_EncodeStatus_Ok) {
104 return absl::string_view(ptr, len);
105 }
106 return MessageEncodeError(status);
107 }
108
DeepCopy(upb_Message * target,const upb_Message * source,const upb_MiniTable * mini_table,upb_Arena * arena)109 void DeepCopy(upb_Message* target, const upb_Message* source,
110 const upb_MiniTable* mini_table, upb_Arena* arena) {
111 MessageLock msg_lock(source);
112 upb_Message_DeepCopy(target, source, mini_table, arena);
113 }
114
// Returns a deep clone of `source` (laid out per `mini_table`) created with
// upb_Message_DeepClone on `arena`. The clone is made while a MessageLock is
// held on `source`.
upb_Message* DeepClone(const upb_Message* source,
                       const upb_MiniTable* mini_table, upb_Arena* arena) {
  MessageLock source_lock(source);
  upb_Message* const clone = upb_Message_DeepClone(source, mini_table, arena);
  return clone;
}
120
MoveExtension(upb_Message * message,upb_Arena * message_arena,const upb_MiniTableExtension * ext,upb_Message * extension,upb_Arena * extension_arena)121 absl::Status MoveExtension(upb_Message* message, upb_Arena* message_arena,
122 const upb_MiniTableExtension* ext,
123 upb_Message* extension, upb_Arena* extension_arena) {
124 if (message_arena != extension_arena &&
125 // Try fuse, if fusing is not allowed or fails, create copy of extension.
126 !upb_Arena_Fuse(message_arena, extension_arena)) {
127 extension = DeepClone(extension, upb_MiniTableExtension_GetSubMessage(ext),
128 message_arena);
129 }
130 return upb_Message_SetExtension(message, ext, &extension, message_arena)
131 ? absl::OkStatus()
132 : MessageAllocationError();
133 }
134
SetExtension(upb_Message * message,upb_Arena * message_arena,const upb_MiniTableExtension * ext,const upb_Message * extension)135 absl::Status SetExtension(upb_Message* message, upb_Arena* message_arena,
136 const upb_MiniTableExtension* ext,
137 const upb_Message* extension) {
138 // Clone extension into target message arena.
139 extension = DeepClone(extension, upb_MiniTableExtension_GetSubMessage(ext),
140 message_arena);
141 return upb_Message_SetExtension(message, ext, &extension, message_arena)
142 ? absl::OkStatus()
143 : MessageAllocationError();
144 }
145
146 } // namespace internal
147
148 } // namespace hpb
149