1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "osp/public/message_demuxer.h"
6
7 #include <memory>
8 #include <utility>
9
10 #include "osp/impl/quic/quic_connection.h"
11 #include "platform/base/error.h"
12 #include "util/big_endian.h"
13 #include "util/osp_logging.h"
14
15 namespace openscreen {
16 namespace osp {
17
18 // static
19 // Decodes a varUint, expecting it to follow the encoding format described here:
20 // https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16
DecodeVarUint(const std::vector<uint8_t> & buffer,size_t * num_bytes_decoded)21 ErrorOr<uint64_t> MessageTypeDecoder::DecodeVarUint(
22 const std::vector<uint8_t>& buffer,
23 size_t* num_bytes_decoded) {
24 if (buffer.size() == 0) {
25 return Error::Code::kCborIncompleteMessage;
26 }
27
28 uint8_t num_type_bytes = static_cast<uint8_t>(buffer[0] >> 6 & 0x03);
29 *num_bytes_decoded = 0x1 << num_type_bytes;
30
31 // Ensure that ReadBigEndian won't read beyond the end of the buffer. Also,
32 // since we expect the id to be followed by the message, equality is not valid
33 if (buffer.size() <= *num_bytes_decoded) {
34 return Error::Code::kCborIncompleteMessage;
35 }
36
37 switch (num_type_bytes) {
38 case 0:
39 return ReadBigEndian<uint8_t>(&buffer[0]) & ~0xC0;
40 case 1:
41 return ReadBigEndian<uint16_t>(&buffer[0]) & ~(0xC0 << 8);
42 case 2:
43 return ReadBigEndian<uint32_t>(&buffer[0]) & ~(0xC0 << 24);
44 case 3:
45 return ReadBigEndian<uint64_t>(&buffer[0]) & ~(uint64_t{0xC0} << 56);
46 default:
47 OSP_NOTREACHED();
48 }
49 }
50
51 // static
52 // Decodes the Type of message, expecting it to follow the encoding format
53 // described here:
54 // https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16
DecodeType(const std::vector<uint8_t> & buffer,size_t * num_bytes_decoded)55 ErrorOr<msgs::Type> MessageTypeDecoder::DecodeType(
56 const std::vector<uint8_t>& buffer,
57 size_t* num_bytes_decoded) {
58 ErrorOr<uint64_t> message_type =
59 MessageTypeDecoder::DecodeVarUint(buffer, num_bytes_decoded);
60 if (message_type.is_error()) {
61 return message_type.error();
62 }
63
64 msgs::Type parsed_type =
65 msgs::TypeEnumValidator::SafeCast(message_type.value());
66 if (parsed_type == msgs::Type::kUnknown) {
67 return Error::Code::kCborInvalidMessage;
68 }
69
70 return parsed_type;
71 }
72
// static
// Out-of-line definition of the in-class static constexpr member (its
// initializer lives on the declaration). Before C++17 this definition is
// required whenever kDefaultBufferLimit is ODR-used; under C++17, where such
// members are implicitly inline, it is redundant but harmless.
constexpr size_t MessageDemuxer::kDefaultBufferLimit;
75
76 MessageDemuxer::MessageWatch::MessageWatch() = default;
77
MessageWatch(MessageDemuxer * parent,bool is_default,uint64_t endpoint_id,msgs::Type message_type)78 MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer* parent,
79 bool is_default,
80 uint64_t endpoint_id,
81 msgs::Type message_type)
82 : parent_(parent),
83 is_default_(is_default),
84 endpoint_id_(endpoint_id),
85 message_type_(message_type) {}
86
MessageWatch(MessageDemuxer::MessageWatch && other)87 MessageDemuxer::MessageWatch::MessageWatch(
88 MessageDemuxer::MessageWatch&& other) noexcept
89 : parent_(other.parent_),
90 is_default_(other.is_default_),
91 endpoint_id_(other.endpoint_id_),
92 message_type_(other.message_type_) {
93 other.parent_ = nullptr;
94 }
95
~MessageWatch()96 MessageDemuxer::MessageWatch::~MessageWatch() {
97 if (parent_) {
98 if (is_default_) {
99 OSP_VLOG << "dropping default handler for type: "
100 << static_cast<int>(message_type_);
101 parent_->StopDefaultMessageTypeWatch(message_type_);
102 } else {
103 OSP_VLOG << "dropping handler for type: "
104 << static_cast<int>(message_type_);
105 parent_->StopWatchingMessageType(endpoint_id_, message_type_);
106 }
107 }
108 }
109
operator =(MessageWatch && other)110 MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=(
111 MessageWatch&& other) noexcept {
112 using std::swap;
113 swap(parent_, other.parent_);
114 swap(is_default_, other.is_default_);
115 swap(endpoint_id_, other.endpoint_id_);
116 swap(message_type_, other.message_type_);
117 return *this;
118 }
119
MessageDemuxer(ClockNowFunctionPtr now_function,size_t buffer_limit=kDefaultBufferLimit)120 MessageDemuxer::MessageDemuxer(ClockNowFunctionPtr now_function,
121 size_t buffer_limit = kDefaultBufferLimit)
122 : now_function_(now_function), buffer_limit_(buffer_limit) {
123 OSP_DCHECK(now_function_);
124 }
125
126 MessageDemuxer::~MessageDemuxer() = default;
127
WatchMessageType(uint64_t endpoint_id,msgs::Type message_type,MessageCallback * callback)128 MessageDemuxer::MessageWatch MessageDemuxer::WatchMessageType(
129 uint64_t endpoint_id,
130 msgs::Type message_type,
131 MessageCallback* callback) {
132 auto callbacks_entry = message_callbacks_.find(endpoint_id);
133 if (callbacks_entry == message_callbacks_.end()) {
134 callbacks_entry =
135 message_callbacks_
136 .emplace(endpoint_id, std::map<msgs::Type, MessageCallback*>{})
137 .first;
138 }
139 auto emplace_result = callbacks_entry->second.emplace(message_type, callback);
140 if (!emplace_result.second)
141 return MessageWatch();
142 auto endpoint_entry = buffers_.find(endpoint_id);
143 if (endpoint_entry != buffers_.end()) {
144 for (auto& buffer : endpoint_entry->second) {
145 if (buffer.second.empty())
146 continue;
147 auto buffered_type = static_cast<msgs::Type>(buffer.second[0]);
148 if (message_type == buffered_type) {
149 HandleStreamBufferLoop(endpoint_id, buffer.first, callbacks_entry,
150 &buffer.second);
151 }
152 }
153 }
154 return MessageWatch(this, false, endpoint_id, message_type);
155 }
156
SetDefaultMessageTypeWatch(msgs::Type message_type,MessageCallback * callback)157 MessageDemuxer::MessageWatch MessageDemuxer::SetDefaultMessageTypeWatch(
158 msgs::Type message_type,
159 MessageCallback* callback) {
160 auto emplace_result = default_callbacks_.emplace(message_type, callback);
161 if (!emplace_result.second)
162 return MessageWatch();
163 for (auto& endpoint_buffers : buffers_) {
164 auto endpoint_id = endpoint_buffers.first;
165 for (auto& stream_map : endpoint_buffers.second) {
166 if (stream_map.second.empty())
167 continue;
168 auto buffered_type = static_cast<msgs::Type>(stream_map.second[0]);
169 if (message_type == buffered_type) {
170 auto connection_id = stream_map.first;
171 auto callbacks_entry = message_callbacks_.find(endpoint_id);
172 HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry,
173 &stream_map.second);
174 }
175 }
176 }
177 return MessageWatch(this, true, 0, message_type);
178 }
179
OnStreamData(uint64_t endpoint_id,uint64_t connection_id,const uint8_t * data,size_t data_size)180 void MessageDemuxer::OnStreamData(uint64_t endpoint_id,
181 uint64_t connection_id,
182 const uint8_t* data,
183 size_t data_size) {
184 OSP_VLOG << __func__ << ": [" << endpoint_id << ", " << connection_id
185 << "] - (" << data_size << ")";
186 auto& stream_map = buffers_[endpoint_id];
187 if (!data_size) {
188 stream_map.erase(connection_id);
189 if (stream_map.empty())
190 buffers_.erase(endpoint_id);
191 return;
192 }
193 std::vector<uint8_t>& buffer = stream_map[connection_id];
194 buffer.insert(buffer.end(), data, data + data_size);
195
196 auto callbacks_entry = message_callbacks_.find(endpoint_id);
197 HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, &buffer);
198
199 if (buffer.size() > buffer_limit_)
200 stream_map.erase(connection_id);
201 }
202
StopWatchingMessageType(uint64_t endpoint_id,msgs::Type message_type)203 void MessageDemuxer::StopWatchingMessageType(uint64_t endpoint_id,
204 msgs::Type message_type) {
205 auto& message_map = message_callbacks_[endpoint_id];
206 auto it = message_map.find(message_type);
207 message_map.erase(it);
208 }
209
StopDefaultMessageTypeWatch(msgs::Type message_type)210 void MessageDemuxer::StopDefaultMessageTypeWatch(msgs::Type message_type) {
211 default_callbacks_.erase(message_type);
212 }
213
HandleStreamBufferLoop(uint64_t endpoint_id,uint64_t connection_id,std::map<uint64_t,std::map<msgs::Type,MessageCallback * >>::iterator callbacks_entry,std::vector<uint8_t> * buffer)214 MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBufferLoop(
215 uint64_t endpoint_id,
216 uint64_t connection_id,
217 std::map<uint64_t, std::map<msgs::Type, MessageCallback*>>::iterator
218 callbacks_entry,
219 std::vector<uint8_t>* buffer) {
220 HandleStreamBufferResult result;
221 do {
222 result = {false, 0};
223 if (callbacks_entry != message_callbacks_.end()) {
224 OSP_VLOG << "attempting endpoint-specific handling";
225 result = HandleStreamBuffer(endpoint_id, connection_id,
226 &callbacks_entry->second, buffer);
227 }
228 if (!result.handled) {
229 if (!default_callbacks_.empty()) {
230 OSP_VLOG << "attempting generic message handling";
231 result = HandleStreamBuffer(endpoint_id, connection_id,
232 &default_callbacks_, buffer);
233 }
234 }
235 OSP_VLOG_IF(!result.handled) << "no message handler matched";
236 } while (result.consumed && !buffer->empty());
237 return result;
238 }
239
240 // TODO(rwkeane) Use absl::Span for the buffer
HandleStreamBuffer(uint64_t endpoint_id,uint64_t connection_id,std::map<msgs::Type,MessageCallback * > * message_callbacks,std::vector<uint8_t> * buffer)241 MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBuffer(
242 uint64_t endpoint_id,
243 uint64_t connection_id,
244 std::map<msgs::Type, MessageCallback*>* message_callbacks,
245 std::vector<uint8_t>* buffer) {
246 size_t consumed = 0;
247 size_t total_consumed = 0;
248 bool handled = false;
249 do {
250 consumed = 0;
251 size_t msg_type_byte_length;
252 ErrorOr<msgs::Type> message_type =
253 MessageTypeDecoder::DecodeType(*buffer, &msg_type_byte_length);
254 if (message_type.is_error()) {
255 buffer->clear();
256 break;
257 }
258 auto callback_entry = message_callbacks->find(message_type.value());
259 if (callback_entry == message_callbacks->end())
260 break;
261 handled = true;
262 OSP_VLOG << "handling message type "
263 << static_cast<int>(message_type.value());
264 auto consumed_or_error = callback_entry->second->OnStreamMessage(
265 endpoint_id, connection_id, message_type.value(),
266 buffer->data() + msg_type_byte_length,
267 buffer->size() - msg_type_byte_length, now_function_());
268 if (!consumed_or_error) {
269 if (consumed_or_error.error().code() !=
270 Error::Code::kCborIncompleteMessage) {
271 buffer->clear();
272 break;
273 }
274 } else {
275 consumed = consumed_or_error.value();
276 buffer->erase(buffer->begin(),
277 buffer->begin() + consumed + msg_type_byte_length);
278 }
279 total_consumed += consumed;
280 } while (consumed && !buffer->empty());
281 return HandleStreamBufferResult{handled, total_consumed};
282 }
283
StopWatching(MessageDemuxer::MessageWatch * watch)284 void StopWatching(MessageDemuxer::MessageWatch* watch) {
285 *watch = MessageDemuxer::MessageWatch();
286 }
287
288 } // namespace osp
289 } // namespace openscreen
290