• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
6 #define OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
7 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
#include <vector>
13 
14 #include "absl/types/optional.h"
15 #include "osp/public/message_demuxer.h"
16 #include "osp/public/network_service_manager.h"
17 #include "osp/public/protocol_connection.h"
18 #include "platform/base/error.h"
19 #include "platform/base/macros.h"
20 #include "util/osp_logging.h"
21 
22 namespace openscreen {
23 namespace osp {
24 
// Pointer to a decoder that parses |length| bytes starting at |buffer| into
// the message pointed to by its third argument.  A negative return value
// signals a decode failure; RequestResponseHandler hands a non-negative
// result back to the MessageDemuxer as the number of bytes it handled.
template <typename T>
using MessageDecodingFunction =
    typename std::add_pointer<ssize_t(const uint8_t*, size_t, T*)>::type;
27 
28 // Provides a uniform way of accessing import properties of a request/response
29 // message pair from a template: request encode function, response decode
30 // function, request serializable data member.
31 template <typename T>
32 struct DefaultRequestCoderTraits {
33  public:
34   using RequestMsgType = typename T::RequestMsgType;
35   static constexpr MessageEncodingFunction<RequestMsgType> kEncoder =
36       T::kEncoder;
37   static constexpr MessageDecodingFunction<typename T::ResponseMsgType>
38       kDecoder = T::kDecoder;
39 
serial_requestDefaultRequestCoderTraits40   static const RequestMsgType* serial_request(const T& data) {
41     return &data.request;
42   }
serial_requestDefaultRequestCoderTraits43   static RequestMsgType* serial_request(T& data) { return &data.request; }
44 };
45 
46 // Provides a wrapper for the common pattern of sending a request message and
47 // waiting for a response message with a matching |request_id| field.  It also
48 // handles the business of queueing messages to be sent until a protocol
49 // connection is available.
50 //
51 // Messages are written using WriteMessage.  This will queue messages if there
52 // is no protocol connection or write them immediately if there is.  When a
53 // matching response is received via the MessageDemuxer (taken from the global
54 // ProtocolConnectionClient), OnMatchedResponse is called on the provided
55 // Delegate object along with the original request that it matches.
56 template <typename RequestT,
57           typename RequestCoderTraits = DefaultRequestCoderTraits<RequestT>>
58 class RequestResponseHandler : public MessageDemuxer::MessageCallback {
59  public:
60   class Delegate {
61    public:
62     virtual ~Delegate() = default;
63 
64     virtual void OnMatchedResponse(RequestT* request,
65                                    typename RequestT::ResponseMsgType* response,
66                                    uint64_t endpoint_id) = 0;
67     virtual void OnError(RequestT* request, Error error) = 0;
68   };
69 
RequestResponseHandler(Delegate * delegate)70   explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {}
~RequestResponseHandler()71   ~RequestResponseHandler() { Reset(); }
72 
Reset()73   void Reset() {
74     connection_ = nullptr;
75     for (auto& message : to_send_) {
76       delegate_->OnError(&message.request, Error::Code::kRequestCancelled);
77     }
78     to_send_.clear();
79     for (auto& message : sent_) {
80       delegate_->OnError(&message.request, Error::Code::kRequestCancelled);
81     }
82     sent_.clear();
83     response_watch_ = MessageDemuxer::MessageWatch();
84   }
85 
86   // Write a message to the underlying protocol connection, or queue it until
87   // one is provided via SetConnection.  If |id| is provided, it can be used to
88   // cancel the message via CancelMessage.
89   template <typename RequestTRval>
90   typename std::enable_if<
91       !std::is_lvalue_reference<RequestTRval>::value &&
92           std::is_same<typename std::decay<RequestTRval>::type,
93                        RequestT>::value,
94       Error>::type
WriteMessage(absl::optional<uint64_t> id,RequestTRval && message)95   WriteMessage(absl::optional<uint64_t> id, RequestTRval&& message) {
96     auto* request_msg = RequestCoderTraits::serial_request(message);
97     if (connection_) {
98       request_msg->request_id = GetNextRequestId(connection_->endpoint_id());
99       Error result =
100           connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder);
101       if (!result.ok()) {
102         return result;
103       }
104       sent_.emplace_back(RequestWithId{id, std::move(message)});
105       EnsureResponseWatch();
106     } else {
107       to_send_.emplace_back(RequestWithId{id, std::move(message)});
108     }
109     return Error::None();
110   }
111 
112   template <typename RequestTRval>
113   typename std::enable_if<
114       !std::is_lvalue_reference<RequestTRval>::value &&
115           std::is_same<typename std::decay<RequestTRval>::type,
116                        RequestT>::value,
117       Error>::type
WriteMessage(RequestTRval && message)118   WriteMessage(RequestTRval&& message) {
119     return WriteMessage(absl::nullopt, std::move(message));
120   }
121 
122   // Remove the message that was originally written with |id| from the send and
123   // sent queues so that we are no longer looking for a response.
CancelMessage(uint64_t id)124   void CancelMessage(uint64_t id) {
125     to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(),
126                                   [&id](const RequestWithId& msg) {
127                                     return id == msg.id;
128                                   }),
129                    to_send_.end());
130     sent_.erase(std::remove_if(
131                     sent_.begin(), sent_.end(),
132                     [&id](const RequestWithId& msg) { return id == msg.id; }),
133                 sent_.end());
134     if (sent_.empty()) {
135       response_watch_ = MessageDemuxer::MessageWatch();
136     }
137   }
138 
139   // Assign a ProtocolConnection to this handler for writing messages.
SetConnection(ProtocolConnection * connection)140   void SetConnection(ProtocolConnection* connection) {
141     connection_ = connection;
142     for (auto& message : to_send_) {
143       auto* request_msg = RequestCoderTraits::serial_request(message.request);
144       request_msg->request_id = GetNextRequestId(connection_->endpoint_id());
145       Error result =
146           connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder);
147       if (result.ok()) {
148         sent_.emplace_back(std::move(message));
149       } else {
150         delegate_->OnError(&message.request, result);
151       }
152     }
153     if (!to_send_.empty()) {
154       EnsureResponseWatch();
155     }
156     to_send_.clear();
157   }
158 
159   // MessageDemuxer::MessageCallback overrides.
OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)160   ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id,
161                                   uint64_t connection_id,
162                                   msgs::Type message_type,
163                                   const uint8_t* buffer,
164                                   size_t buffer_size,
165                                   Clock::time_point now) override {
166     if (message_type != RequestT::kResponseType) {
167       return 0;
168     }
169     typename RequestT::ResponseMsgType response;
170     ssize_t result =
171         RequestCoderTraits::kDecoder(buffer, buffer_size, &response);
172     if (result < 0) {
173       return 0;
174     }
175     auto it = std::find_if(
176         sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) {
177           return RequestCoderTraits::serial_request(msg.request)->request_id ==
178                  response.request_id;
179         });
180     if (it != sent_.end()) {
181       delegate_->OnMatchedResponse(&it->request, &response,
182                                    connection_->endpoint_id());
183       sent_.erase(it);
184       if (sent_.empty()) {
185         response_watch_ = MessageDemuxer::MessageWatch();
186       }
187     } else {
188       OSP_LOG_WARN << "got response for unknown request id: "
189                    << response.request_id;
190     }
191     return result;
192   }
193 
194  private:
195   struct RequestWithId {
196     absl::optional<uint64_t> id;
197     RequestT request;
198   };
199 
EnsureResponseWatch()200   void EnsureResponseWatch() {
201     if (!response_watch_) {
202       response_watch_ = NetworkServiceManager::Get()
203                             ->GetProtocolConnectionClient()
204                             ->message_demuxer()
205                             ->WatchMessageType(connection_->endpoint_id(),
206                                                RequestT::kResponseType, this);
207     }
208   }
209 
GetNextRequestId(uint64_t endpoint_id)210   uint64_t GetNextRequestId(uint64_t endpoint_id) {
211     return NetworkServiceManager::Get()
212         ->GetProtocolConnectionClient()
213         ->endpoint_request_ids()
214         ->GetNextRequestId(endpoint_id);
215   }
216 
217   ProtocolConnection* connection_ = nullptr;
218   Delegate* const delegate_;
219   std::vector<RequestWithId> to_send_;
220   std::vector<RequestWithId> sent_;
221   MessageDemuxer::MessageWatch response_watch_;
222 
223   OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler);
224 };
225 
226 }  // namespace osp
227 }  // namespace openscreen
228 
229 #endif  // OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
230