1 // Copyright 2018 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef OSP_PUBLIC_MESSAGE_DEMUXER_H_ 6 #define OSP_PUBLIC_MESSAGE_DEMUXER_H_ 7 8 #include <map> 9 #include <memory> 10 #include <vector> 11 12 #include "osp/msgs/osp_messages.h" 13 #include "platform/api/time.h" 14 #include "platform/base/error.h" 15 16 namespace openscreen { 17 namespace osp { 18 19 class QuicStream; 20 21 // This class separates QUIC stream data into CBOR messages by reading a type 22 // prefix from the stream and passes those messages to any callback matching the 23 // source endpoint and message type. If there is no callback for a given 24 // message type, it will also try a default message listener. 25 class MessageDemuxer { 26 public: 27 class MessageCallback { 28 public: 29 virtual ~MessageCallback() = default; 30 31 // |buffer| contains data for a message of type |message_type|. However, 32 // the data may be incomplete, in which case the callback should return an 33 // error code of Error::Code::kCborIncompleteMessage. This way, 34 // the MessageDemuxer knows to neither consume the data nor discard it as 35 // bad. 36 virtual ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id, 37 uint64_t connection_id, 38 msgs::Type message_type, 39 const uint8_t* buffer, 40 size_t buffer_size, 41 Clock::time_point now) = 0; 42 }; 43 44 class MessageWatch { 45 public: 46 MessageWatch(); 47 MessageWatch(MessageDemuxer* parent, 48 bool is_default, 49 uint64_t endpoint_id, 50 msgs::Type message_type); 51 MessageWatch(MessageWatch&&) noexcept; 52 ~MessageWatch(); 53 MessageWatch& operator=(MessageWatch&&) noexcept; 54 55 explicit operator bool() const { return parent_; } 56 57 private: 58 MessageDemuxer* parent_ = nullptr; 59 bool is_default_; 60 uint64_t endpoint_id_; 61 msgs::Type message_type_; 62 }; 63 64 static constexpr size_t kDefaultBufferLimit = 1 << 16; 65 66 MessageDemuxer(ClockNowFunctionPtr now_function, size_t buffer_limit); 67 ~MessageDemuxer(); 68 69 // Starts watching for messages of type |message_type| from the endpoint 70 // identified by |endpoint_id|. When such a message arrives, or if some are 71 // already buffered, |callback| will be called with the message data. 72 MessageWatch WatchMessageType(uint64_t endpoint_id, 73 msgs::Type message_type, 74 MessageCallback* callback); 75 76 // Starts watching for messages of type |message_type| from any endpoint when 77 // there is not callback set for its specific endpoint ID. 78 MessageWatch SetDefaultMessageTypeWatch(msgs::Type message_type, 79 MessageCallback* callback); 80 81 // Gives data from |endpoint_id| to the demuxer for processing. 82 // TODO(btolsch): It'd be nice if errors could propagate out of here to close 83 // the stream. 84 void OnStreamData(uint64_t endpoint_id, 85 uint64_t connection_id, 86 const uint8_t* data, 87 size_t data_size); 88 89 private: 90 struct HandleStreamBufferResult { 91 bool handled; 92 size_t consumed; 93 }; 94 95 void StopWatchingMessageType(uint64_t endpoint_id, msgs::Type message_type); 96 void StopDefaultMessageTypeWatch(msgs::Type message_type); 97 98 HandleStreamBufferResult HandleStreamBufferLoop( 99 uint64_t endpoint_id, 100 uint64_t connection_id, 101 std::map<uint64_t, std::map<msgs::Type, MessageCallback*>>::iterator 102 endpoint_entry, 103 std::vector<uint8_t>* buffer); 104 105 HandleStreamBufferResult HandleStreamBuffer( 106 uint64_t endpoint_id, 107 uint64_t connection_id, 108 std::map<msgs::Type, MessageCallback*>* message_callbacks, 109 std::vector<uint8_t>* buffer); 110 111 const ClockNowFunctionPtr now_function_; 112 const size_t buffer_limit_; 113 std::map<uint64_t, std::map<msgs::Type, MessageCallback*>> message_callbacks_; 114 std::map<msgs::Type, MessageCallback*> default_callbacks_; 115 116 // Map<endpoint_id, Map<connection_id, data_buffer>> 117 std::map<uint64_t, std::map<uint64_t, std::vector<uint8_t>>> buffers_; 118 }; 119 120 // TODO(btolsch): Make sure all uses of MessageWatch are converted to this 121 // resest function for readability. 122 void StopWatching(MessageDemuxer::MessageWatch* watch); 123 124 class MessageTypeDecoder { 125 public: 126 static ErrorOr<msgs::Type> DecodeType(const std::vector<uint8_t>& buffer, 127 size_t* num_bytes_decoded); 128 129 private: 130 static ErrorOr<uint64_t> DecodeVarUint(const std::vector<uint8_t>& buffer, 131 size_t* num_bytes_decoded); 132 }; 133 134 } // namespace osp 135 } // namespace openscreen 136 137 #endif // OSP_PUBLIC_MESSAGE_DEMUXER_H_ 138