1 // Copyright 2019 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ 6 #define OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ 7 8 #include <cstddef> 9 #include <cstdint> 10 #include <type_traits> 11 #include <utility> 12 #include <vector> 13 14 #include "absl/types/optional.h" 15 #include "osp/public/message_demuxer.h" 16 #include "osp/public/network_service_manager.h" 17 #include "osp/public/protocol_connection.h" 18 #include "platform/base/error.h" 19 #include "platform/base/macros.h" 20 #include "util/osp_logging.h" 21 22 namespace openscreen { 23 namespace osp { 24 25 template <typename T> 26 using MessageDecodingFunction = ssize_t (*)(const uint8_t*, size_t, T*); 27 28 // Provides a uniform way of accessing import properties of a request/response 29 // message pair from a template: request encode function, response decode 30 // function, request serializable data member. 31 template <typename T> 32 struct DefaultRequestCoderTraits { 33 public: 34 using RequestMsgType = typename T::RequestMsgType; 35 static constexpr MessageEncodingFunction<RequestMsgType> kEncoder = 36 T::kEncoder; 37 static constexpr MessageDecodingFunction<typename T::ResponseMsgType> 38 kDecoder = T::kDecoder; 39 serial_requestDefaultRequestCoderTraits40 static const RequestMsgType* serial_request(const T& data) { 41 return &data.request; 42 } serial_requestDefaultRequestCoderTraits43 static RequestMsgType* serial_request(T& data) { return &data.request; } 44 }; 45 46 // Provides a wrapper for the common pattern of sending a request message and 47 // waiting for a response message with a matching |request_id| field. It also 48 // handles the business of queueing messages to be sent until a protocol 49 // connection is available. 50 // 51 // Messages are written using WriteMessage. This will queue messages if there 52 // is no protocol connection or write them immediately if there is. When a 53 // matching response is received via the MessageDemuxer (taken from the global 54 // ProtocolConnectionClient), OnMatchedResponse is called on the provided 55 // Delegate object along with the original request that it matches. 56 template <typename RequestT, 57 typename RequestCoderTraits = DefaultRequestCoderTraits<RequestT>> 58 class RequestResponseHandler : public MessageDemuxer::MessageCallback { 59 public: 60 class Delegate { 61 public: 62 virtual ~Delegate() = default; 63 64 virtual void OnMatchedResponse(RequestT* request, 65 typename RequestT::ResponseMsgType* response, 66 uint64_t endpoint_id) = 0; 67 virtual void OnError(RequestT* request, Error error) = 0; 68 }; 69 RequestResponseHandler(Delegate * delegate)70 explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {} ~RequestResponseHandler()71 ~RequestResponseHandler() { Reset(); } 72 Reset()73 void Reset() { 74 connection_ = nullptr; 75 for (auto& message : to_send_) { 76 delegate_->OnError(&message.request, Error::Code::kRequestCancelled); 77 } 78 to_send_.clear(); 79 for (auto& message : sent_) { 80 delegate_->OnError(&message.request, Error::Code::kRequestCancelled); 81 } 82 sent_.clear(); 83 response_watch_ = MessageDemuxer::MessageWatch(); 84 } 85 86 // Write a message to the underlying protocol connection, or queue it until 87 // one is provided via SetConnection. If |id| is provided, it can be used to 88 // cancel the message via CancelMessage. 89 template <typename RequestTRval> 90 typename std::enable_if< 91 !std::is_lvalue_reference<RequestTRval>::value && 92 std::is_same<typename std::decay<RequestTRval>::type, 93 RequestT>::value, 94 Error>::type WriteMessage(absl::optional<uint64_t> id,RequestTRval && message)95 WriteMessage(absl::optional<uint64_t> id, RequestTRval&& message) { 96 auto* request_msg = RequestCoderTraits::serial_request(message); 97 if (connection_) { 98 request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); 99 Error result = 100 connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); 101 if (!result.ok()) { 102 return result; 103 } 104 sent_.emplace_back(RequestWithId{id, std::move(message)}); 105 EnsureResponseWatch(); 106 } else { 107 to_send_.emplace_back(RequestWithId{id, std::move(message)}); 108 } 109 return Error::None(); 110 } 111 112 template <typename RequestTRval> 113 typename std::enable_if< 114 !std::is_lvalue_reference<RequestTRval>::value && 115 std::is_same<typename std::decay<RequestTRval>::type, 116 RequestT>::value, 117 Error>::type WriteMessage(RequestTRval && message)118 WriteMessage(RequestTRval&& message) { 119 return WriteMessage(absl::nullopt, std::move(message)); 120 } 121 122 // Remove the message that was originally written with |id| from the send and 123 // sent queues so that we are no longer looking for a response. CancelMessage(uint64_t id)124 void CancelMessage(uint64_t id) { 125 to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(), 126 [&id](const RequestWithId& msg) { 127 return id == msg.id; 128 }), 129 to_send_.end()); 130 sent_.erase(std::remove_if( 131 sent_.begin(), sent_.end(), 132 [&id](const RequestWithId& msg) { return id == msg.id; }), 133 sent_.end()); 134 if (sent_.empty()) { 135 response_watch_ = MessageDemuxer::MessageWatch(); 136 } 137 } 138 139 // Assign a ProtocolConnection to this handler for writing messages. SetConnection(ProtocolConnection * connection)140 void SetConnection(ProtocolConnection* connection) { 141 connection_ = connection; 142 for (auto& message : to_send_) { 143 auto* request_msg = RequestCoderTraits::serial_request(message.request); 144 request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); 145 Error result = 146 connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); 147 if (result.ok()) { 148 sent_.emplace_back(std::move(message)); 149 } else { 150 delegate_->OnError(&message.request, result); 151 } 152 } 153 if (!to_send_.empty()) { 154 EnsureResponseWatch(); 155 } 156 to_send_.clear(); 157 } 158 159 // MessageDemuxer::MessageCallback overrides. OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)160 ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id, 161 uint64_t connection_id, 162 msgs::Type message_type, 163 const uint8_t* buffer, 164 size_t buffer_size, 165 Clock::time_point now) override { 166 if (message_type != RequestT::kResponseType) { 167 return 0; 168 } 169 typename RequestT::ResponseMsgType response; 170 ssize_t result = 171 RequestCoderTraits::kDecoder(buffer, buffer_size, &response); 172 if (result < 0) { 173 return 0; 174 } 175 auto it = std::find_if( 176 sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) { 177 return RequestCoderTraits::serial_request(msg.request)->request_id == 178 response.request_id; 179 }); 180 if (it != sent_.end()) { 181 delegate_->OnMatchedResponse(&it->request, &response, 182 connection_->endpoint_id()); 183 sent_.erase(it); 184 if (sent_.empty()) { 185 response_watch_ = MessageDemuxer::MessageWatch(); 186 } 187 } else { 188 OSP_LOG_WARN << "got response for unknown request id: " 189 << response.request_id; 190 } 191 return result; 192 } 193 194 private: 195 struct RequestWithId { 196 absl::optional<uint64_t> id; 197 RequestT request; 198 }; 199 EnsureResponseWatch()200 void EnsureResponseWatch() { 201 if (!response_watch_) { 202 response_watch_ = NetworkServiceManager::Get() 203 ->GetProtocolConnectionClient() 204 ->message_demuxer() 205 ->WatchMessageType(connection_->endpoint_id(), 206 RequestT::kResponseType, this); 207 } 208 } 209 GetNextRequestId(uint64_t endpoint_id)210 uint64_t GetNextRequestId(uint64_t endpoint_id) { 211 return NetworkServiceManager::Get() 212 ->GetProtocolConnectionClient() 213 ->endpoint_request_ids() 214 ->GetNextRequestId(endpoint_id); 215 } 216 217 ProtocolConnection* connection_ = nullptr; 218 Delegate* const delegate_; 219 std::vector<RequestWithId> to_send_; 220 std::vector<RequestWithId> sent_; 221 MessageDemuxer::MessageWatch response_watch_; 222 223 OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler); 224 }; 225 226 } // namespace osp 227 } // namespace openscreen 228 229 #endif // OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ 230