• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "osp/impl/presentation/url_availability_requester.h"
6 
7 #include <algorithm>
8 #include <chrono>
9 #include <memory>
10 #include <utility>
11 #include <vector>
12 
13 #include "osp/impl/presentation/presentation_common.h"
14 #include "osp/public/network_service_manager.h"
15 #include "util/chrono_helpers.h"
16 #include "util/osp_logging.h"
17 
18 using std::chrono::seconds;
19 
20 namespace openscreen {
21 namespace osp {
22 namespace {
23 
24 static constexpr Clock::duration kWatchDuration = seconds(20);
25 static constexpr Clock::duration kWatchRefreshPadding = seconds(2);
26 
PartitionUrlsBySetMembership(std::vector<std::string> * urls,const std::set<std::string> & membership_test)27 std::vector<std::string>::iterator PartitionUrlsBySetMembership(
28     std::vector<std::string>* urls,
29     const std::set<std::string>& membership_test) {
30   return std::partition(
31       urls->begin(), urls->end(), [&membership_test](const std::string& url) {
32         return membership_test.find(url) == membership_test.end();
33       });
34 }
35 
MoveVectorSegment(std::vector<std::string>::iterator first,std::vector<std::string>::iterator last,std::set<std::string> * target)36 void MoveVectorSegment(std::vector<std::string>::iterator first,
37                        std::vector<std::string>::iterator last,
38                        std::set<std::string>* target) {
39   for (auto it = first; it != last; ++it)
40     target->emplace(std::move(*it));
41 }
42 
GetNextRequestId(const uint64_t endpoint_id)43 uint64_t GetNextRequestId(const uint64_t endpoint_id) {
44   return NetworkServiceManager::Get()
45       ->GetProtocolConnectionClient()
46       ->endpoint_request_ids()
47       ->GetNextRequestId(endpoint_id);
48 }
49 
50 }  // namespace
51 
UrlAvailabilityRequester(ClockNowFunctionPtr now_function)52 UrlAvailabilityRequester::UrlAvailabilityRequester(
53     ClockNowFunctionPtr now_function)
54     : now_function_(now_function) {
55   OSP_DCHECK(now_function_);
56 }
57 
58 UrlAvailabilityRequester::~UrlAvailabilityRequester() = default;
59 
AddObserver(const std::vector<std::string> & urls,ReceiverObserver * observer)60 void UrlAvailabilityRequester::AddObserver(const std::vector<std::string>& urls,
61                                            ReceiverObserver* observer) {
62   for (const auto& url : urls) {
63     observers_by_url_[url].push_back(observer);
64   }
65   for (auto& entry : receiver_by_service_id_) {
66     auto& receiver = entry.second;
67     receiver->GetOrRequestAvailabilities(urls, observer);
68   }
69 }
70 
RemoveObserverUrls(const std::vector<std::string> & urls,ReceiverObserver * observer)71 void UrlAvailabilityRequester::RemoveObserverUrls(
72     const std::vector<std::string>& urls,
73     ReceiverObserver* observer) {
74   std::set<std::string> unobserved_urls;
75   for (const auto& url : urls) {
76     auto observer_entry = observers_by_url_.find(url);
77     if (observer_entry == observers_by_url_.end())
78       continue;
79     auto& observers = observer_entry->second;
80     observers.erase(std::remove(observers.begin(), observers.end(), observer),
81                     observers.end());
82     if (observers.empty()) {
83       unobserved_urls.emplace(std::move(observer_entry->first));
84       observers_by_url_.erase(observer_entry);
85       for (auto& entry : receiver_by_service_id_) {
86         auto& receiver = entry.second;
87         receiver->known_availability_by_url.erase(url);
88       }
89     }
90   }
91 
92   for (auto& entry : receiver_by_service_id_) {
93     auto& receiver = entry.second;
94     receiver->RemoveUnobservedRequests(unobserved_urls);
95     receiver->RemoveUnobservedWatches(unobserved_urls);
96   }
97 }
98 
RemoveObserver(ReceiverObserver * observer)99 void UrlAvailabilityRequester::RemoveObserver(ReceiverObserver* observer) {
100   std::set<std::string> unobserved_urls;
101   for (auto& entry : observers_by_url_) {
102     auto& observer_list = entry.second;
103     auto it = std::remove(observer_list.begin(), observer_list.end(), observer);
104     if (it != observer_list.end()) {
105       observer_list.erase(it);
106       if (observer_list.empty())
107         unobserved_urls.insert(entry.first);
108     }
109   }
110 
111   for (auto& entry : receiver_by_service_id_) {
112     auto& receiver = entry.second;
113     receiver->RemoveUnobservedRequests(unobserved_urls);
114     receiver->RemoveUnobservedWatches(unobserved_urls);
115   }
116 }
117 
AddReceiver(const ServiceInfo & info)118 void UrlAvailabilityRequester::AddReceiver(const ServiceInfo& info) {
119   auto result = receiver_by_service_id_.emplace(
120       info.service_id,
121       std::make_unique<ReceiverRequester>(
122           this, info.service_id,
123           info.v4_endpoint.address ? info.v4_endpoint : info.v6_endpoint));
124   std::unique_ptr<ReceiverRequester>& receiver = result.first->second;
125   std::vector<std::string> urls;
126   urls.reserve(observers_by_url_.size());
127   for (const auto& url : observers_by_url_)
128     urls.push_back(url.first);
129   receiver->RequestUrlAvailabilities(std::move(urls));
130 }
131 
ChangeReceiver(const ServiceInfo & info)132 void UrlAvailabilityRequester::ChangeReceiver(const ServiceInfo& info) {}
133 
RemoveReceiver(const ServiceInfo & info)134 void UrlAvailabilityRequester::RemoveReceiver(const ServiceInfo& info) {
135   auto receiver_entry = receiver_by_service_id_.find(info.service_id);
136   if (receiver_entry != receiver_by_service_id_.end()) {
137     auto& receiver = receiver_entry->second;
138     receiver->RemoveReceiver();
139     receiver_by_service_id_.erase(receiver_entry);
140   }
141 }
142 
RemoveAllReceivers()143 void UrlAvailabilityRequester::RemoveAllReceivers() {
144   for (auto& entry : receiver_by_service_id_) {
145     auto& receiver = entry.second;
146     receiver->RemoveReceiver();
147   }
148   receiver_by_service_id_.clear();
149 }
150 
RefreshWatches()151 Clock::time_point UrlAvailabilityRequester::RefreshWatches() {
152   const Clock::time_point now = now_function_();
153   Clock::time_point minimum_schedule_time = now + kWatchDuration;
154   for (auto& entry : receiver_by_service_id_) {
155     auto& receiver = entry.second;
156     const Clock::time_point requested_schedule_time =
157         receiver->RefreshWatches(now);
158     if (requested_schedule_time < minimum_schedule_time)
159       minimum_schedule_time = requested_schedule_time;
160   }
161   return minimum_schedule_time;
162 }
163 
ReceiverRequester(UrlAvailabilityRequester * listener,const std::string & service_id,const IPEndpoint & endpoint)164 UrlAvailabilityRequester::ReceiverRequester::ReceiverRequester(
165     UrlAvailabilityRequester* listener,
166     const std::string& service_id,
167     const IPEndpoint& endpoint)
168     : listener(listener),
169       service_id(service_id),
170       connect_request(
171           NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
172               endpoint,
173               this)) {}
174 
175 UrlAvailabilityRequester::ReceiverRequester::~ReceiverRequester() = default;
176 
GetOrRequestAvailabilities(const std::vector<std::string> & requested_urls,ReceiverObserver * observer)177 void UrlAvailabilityRequester::ReceiverRequester::GetOrRequestAvailabilities(
178     const std::vector<std::string>& requested_urls,
179     ReceiverObserver* observer) {
180   std::vector<std::string> unknown_urls;
181   for (const auto& url : requested_urls) {
182     auto availability_entry = known_availability_by_url.find(url);
183     if (availability_entry == known_availability_by_url.end()) {
184       unknown_urls.emplace_back(url);
185       continue;
186     }
187 
188     msgs::UrlAvailability availability = availability_entry->second;
189     if (observer) {
190       switch (availability) {
191         case msgs::UrlAvailability::kAvailable:
192           observer->OnReceiverAvailable(url, service_id);
193           break;
194         case msgs::UrlAvailability::kUnavailable:
195         case msgs::UrlAvailability::kInvalid:
196           observer->OnReceiverUnavailable(url, service_id);
197           break;
198       }
199     }
200   }
201   if (!unknown_urls.empty()) {
202     RequestUrlAvailabilities(std::move(unknown_urls));
203   }
204 }
205 
RequestUrlAvailabilities(std::vector<std::string> urls)206 void UrlAvailabilityRequester::ReceiverRequester::RequestUrlAvailabilities(
207     std::vector<std::string> urls) {
208   if (urls.empty())
209     return;
210   const uint64_t request_id = GetNextRequestId(endpoint_id);
211   ErrorOr<uint64_t> watch_id_or_error(0);
212   if (!connection || (watch_id_or_error = SendRequest(request_id, urls))) {
213     request_by_id.emplace(request_id,
214                           Request{watch_id_or_error.value(), std::move(urls)});
215   } else {
216     for (const auto& url : urls)
217       for (auto& observer : listener->observers_by_url_[url])
218         observer->OnRequestFailed(url, service_id);
219   }
220 }
221 
SendRequest(uint64_t request_id,const std::vector<std::string> & urls)222 ErrorOr<uint64_t> UrlAvailabilityRequester::ReceiverRequester::SendRequest(
223     uint64_t request_id,
224     const std::vector<std::string>& urls) {
225   uint64_t watch_id = next_watch_id++;
226   msgs::PresentationUrlAvailabilityRequest cbor_request;
227   cbor_request.request_id = request_id;
228   cbor_request.urls = urls;
229   cbor_request.watch_id = watch_id;
230   cbor_request.watch_duration = to_microseconds(kWatchDuration).count();
231 
232   msgs::CborEncodeBuffer buffer;
233   if (msgs::EncodePresentationUrlAvailabilityRequest(cbor_request, &buffer)) {
234     OSP_VLOG << "writing presentation-url-availability-request";
235     connection->Write(buffer.data(), buffer.size());
236     watch_by_id.emplace(
237         watch_id, Watch{listener->now_function_() + kWatchDuration, urls});
238     if (!event_watch) {
239       event_watch = GetClientDemuxer()->WatchMessageType(
240           endpoint_id, msgs::Type::kPresentationUrlAvailabilityEvent, this);
241     }
242     if (!response_watch) {
243       response_watch = GetClientDemuxer()->WatchMessageType(
244           endpoint_id, msgs::Type::kPresentationUrlAvailabilityResponse, this);
245     }
246     return watch_id;
247   }
248   return Error::Code::kCborEncoding;
249 }
250 
RefreshWatches(Clock::time_point now)251 Clock::time_point UrlAvailabilityRequester::ReceiverRequester::RefreshWatches(
252     Clock::time_point now) {
253   Clock::time_point minimum_schedule_time = now + kWatchDuration;
254   std::vector<std::vector<std::string>> new_requests;
255   for (auto entry = watch_by_id.begin(); entry != watch_by_id.end();) {
256     Watch& watch = entry->second;
257     const Clock::time_point buffered_deadline =
258         watch.deadline - kWatchRefreshPadding;
259     if (now > buffered_deadline) {
260       new_requests.emplace_back(std::move(watch.urls));
261       entry = watch_by_id.erase(entry);
262     } else {
263       ++entry;
264       if (buffered_deadline < minimum_schedule_time)
265         minimum_schedule_time = buffered_deadline;
266     }
267   }
268   if (watch_by_id.empty())
269     StopWatching(&event_watch);
270 
271   for (auto& request : new_requests)
272     RequestUrlAvailabilities(std::move(request));
273 
274   return minimum_schedule_time;
275 }
276 
UpdateAvailabilities(const std::vector<std::string> & urls,const std::vector<msgs::UrlAvailability> & availabilities)277 Error::Code UrlAvailabilityRequester::ReceiverRequester::UpdateAvailabilities(
278     const std::vector<std::string>& urls,
279     const std::vector<msgs::UrlAvailability>& availabilities) {
280   auto availability_it = availabilities.begin();
281   if (urls.size() != availabilities.size()) {
282     return Error::Code::kCborInvalidMessage;
283   }
284   for (const auto& url : urls) {
285     auto observer_entry = listener->observers_by_url_.find(url);
286     if (observer_entry == listener->observers_by_url_.end())
287       continue;
288     std::vector<ReceiverObserver*>& observers = observer_entry->second;
289     auto result = known_availability_by_url.emplace(url, *availability_it);
290     auto entry = result.first;
291     bool inserted = result.second;
292     bool updated = (entry->second != *availability_it);
293     if (inserted || updated) {
294       switch (*availability_it) {
295         case msgs::UrlAvailability::kAvailable:
296           for (auto* observer : observers)
297             observer->OnReceiverAvailable(url, service_id);
298           break;
299         case msgs::UrlAvailability::kUnavailable:
300         case msgs::UrlAvailability::kInvalid:
301           for (auto* observer : observers)
302             observer->OnReceiverUnavailable(url, service_id);
303           break;
304         default:
305           break;
306       }
307     }
308     ++availability_it;
309   }
310   return Error::Code::kNone;
311 }
312 
RemoveUnobservedRequests(const std::set<std::string> & unobserved_urls)313 void UrlAvailabilityRequester::ReceiverRequester::RemoveUnobservedRequests(
314     const std::set<std::string>& unobserved_urls) {
315   std::map<uint64_t, Request> new_requests;
316   std::set<std::string> still_observed_urls;
317   for (auto entry = request_by_id.begin(); entry != request_by_id.end();
318        ++entry) {
319     Request& request = entry->second;
320     auto split = PartitionUrlsBySetMembership(&request.urls, unobserved_urls);
321     if (split == request.urls.end())
322       continue;
323     MoveVectorSegment(request.urls.begin(), split, &still_observed_urls);
324     if (connection)
325       watch_by_id.erase(request.watch_id);
326   }
327   if (!still_observed_urls.empty()) {
328     const uint64_t new_request_id = GetNextRequestId(endpoint_id);
329     ErrorOr<uint64_t> watch_id_or_error(0);
330     std::vector<std::string> urls;
331     urls.reserve(still_observed_urls.size());
332     for (auto& url : still_observed_urls)
333       urls.emplace_back(std::move(url));
334     if (!connection ||
335         (watch_id_or_error = SendRequest(new_request_id, urls))) {
336       new_requests.emplace(new_request_id,
337                            Request{watch_id_or_error.value(), std::move(urls)});
338     } else {
339       for (const auto& url : urls)
340         for (auto& observer : listener->observers_by_url_[url])
341           observer->OnRequestFailed(url, service_id);
342     }
343   }
344 
345   for (auto& entry : new_requests)
346     request_by_id.emplace(entry.first, std::move(entry.second));
347 
348   if (request_by_id.empty())
349     StopWatching(&response_watch);
350 }
351 
RemoveUnobservedWatches(const std::set<std::string> & unobserved_urls)352 void UrlAvailabilityRequester::ReceiverRequester::RemoveUnobservedWatches(
353     const std::set<std::string>& unobserved_urls) {
354   std::set<std::string> still_observed_urls;
355   for (auto entry = watch_by_id.begin(); entry != watch_by_id.end();) {
356     Watch& watch = entry->second;
357     auto split = PartitionUrlsBySetMembership(&watch.urls, unobserved_urls);
358     if (split == watch.urls.end()) {
359       ++entry;
360       continue;
361     }
362     MoveVectorSegment(watch.urls.begin(), split, &still_observed_urls);
363     entry = watch_by_id.erase(entry);
364   }
365 
366   std::vector<std::string> urls;
367   urls.reserve(still_observed_urls.size());
368   for (auto& url : still_observed_urls)
369     urls.emplace_back(std::move(url));
370   RequestUrlAvailabilities(std::move(urls));
371   // TODO(btolsch): These message watch cancels could be tested by expecting
372   // messages to fall through to the default watch.
373   if (watch_by_id.empty())
374     StopWatching(&event_watch);
375 }
376 
RemoveReceiver()377 void UrlAvailabilityRequester::ReceiverRequester::RemoveReceiver() {
378   for (const auto& availability : known_availability_by_url) {
379     if (availability.second == msgs::UrlAvailability::kAvailable) {
380       const std::string& url = availability.first;
381       for (auto& observer : listener->observers_by_url_[url])
382         observer->OnReceiverUnavailable(url, service_id);
383     }
384   }
385 }
386 
OnConnectionOpened(uint64_t request_id,std::unique_ptr<ProtocolConnection> connection)387 void UrlAvailabilityRequester::ReceiverRequester::OnConnectionOpened(
388     uint64_t request_id,
389     std::unique_ptr<ProtocolConnection> connection) {
390   connect_request.MarkComplete();
391   // TODO(btolsch): This is one place where we need to make sure the QUIC
392   // connection stays alive, even without constant traffic.
393   endpoint_id = connection->endpoint_id();
394   this->connection = std::move(connection);
395   ErrorOr<uint64_t> watch_id_or_error(0);
396   for (auto entry = request_by_id.begin(); entry != request_by_id.end();) {
397     if ((watch_id_or_error = SendRequest(entry->first, entry->second.urls))) {
398       entry->second.watch_id = watch_id_or_error.value();
399       ++entry;
400     } else {
401       entry = request_by_id.erase(entry);
402     }
403   }
404 }
405 
OnConnectionFailed(uint64_t request_id)406 void UrlAvailabilityRequester::ReceiverRequester::OnConnectionFailed(
407     uint64_t request_id) {
408   connect_request.MarkComplete();
409 
410   std::set<std::string> waiting_urls;
411   for (auto& entry : request_by_id) {
412     Request& request = entry.second;
413     for (auto& url : request.urls) {
414       waiting_urls.emplace(std::move(url));
415     }
416   }
417   for (const auto& url : waiting_urls)
418     for (auto& observer : listener->observers_by_url_[url])
419       observer->OnRequestFailed(url, service_id);
420 
421   std::string id = std::move(service_id);
422   listener->receiver_by_service_id_.erase(id);
423 }
424 
OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)425 ErrorOr<size_t> UrlAvailabilityRequester::ReceiverRequester::OnStreamMessage(
426     uint64_t endpoint_id,
427     uint64_t connection_id,
428     msgs::Type message_type,
429     const uint8_t* buffer,
430     size_t buffer_size,
431     Clock::time_point now) {
432   switch (message_type) {
433     case msgs::Type::kPresentationUrlAvailabilityResponse: {
434       msgs::PresentationUrlAvailabilityResponse response;
435       ssize_t result = msgs::DecodePresentationUrlAvailabilityResponse(
436           buffer, buffer_size, &response);
437       if (result < 0) {
438         if (result == msgs::kParserEOF)
439           return Error::Code::kCborIncompleteMessage;
440         OSP_LOG_WARN << "parse error: " << result;
441         return Error::Code::kCborParsing;
442       } else {
443         auto request_entry = request_by_id.find(response.request_id);
444         if (request_entry == request_by_id.end()) {
445           OSP_LOG_ERROR << "bad response id: " << response.request_id;
446           return Error::Code::kCborInvalidResponseId;
447         }
448         std::vector<std::string>& urls = request_entry->second.urls;
449         if (urls.size() != response.url_availabilities.size()) {
450           OSP_LOG_WARN << "bad response size: expected " << urls.size()
451                        << " but got " << response.url_availabilities.size();
452           return Error::Code::kCborInvalidMessage;
453         }
454         Error::Code update_result =
455             UpdateAvailabilities(urls, response.url_availabilities);
456         if (update_result != Error::Code::kNone) {
457           return update_result;
458         }
459         request_by_id.erase(response.request_id);
460         if (request_by_id.empty())
461           StopWatching(&response_watch);
462         return result;
463       }
464     } break;
465     case msgs::Type::kPresentationUrlAvailabilityEvent: {
466       msgs::PresentationUrlAvailabilityEvent event;
467       ssize_t result = msgs::DecodePresentationUrlAvailabilityEvent(
468           buffer, buffer_size, &event);
469       if (result < 0) {
470         if (result == msgs::kParserEOF)
471           return Error::Code::kCborIncompleteMessage;
472         OSP_LOG_WARN << "parse error: " << result;
473         return Error::Code::kCborParsing;
474       } else {
475         auto watch_entry = watch_by_id.find(event.watch_id);
476         if (watch_entry != watch_by_id.end()) {
477           std::vector<std::string> urls = watch_entry->second.urls;
478           Error::Code update_result =
479               UpdateAvailabilities(urls, event.url_availabilities);
480           if (update_result != Error::Code::kNone) {
481             return update_result;
482           }
483         }
484         return result;
485       }
486     } break;
487     default:
488       break;
489   }
490   return Error::Code::kCborParsing;
491 }
492 
493 }  // namespace osp
494 }  // namespace openscreen
495