1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "osp/impl/presentation/url_availability_requester.h"
6
7 #include <algorithm>
8 #include <chrono>
9 #include <memory>
10 #include <utility>
11 #include <vector>
12
13 #include "osp/impl/presentation/presentation_common.h"
14 #include "osp/public/network_service_manager.h"
15 #include "util/chrono_helpers.h"
16 #include "util/osp_logging.h"
17
18 using std::chrono::seconds;
19
20 namespace openscreen {
21 namespace osp {
22 namespace {
23
24 static constexpr Clock::duration kWatchDuration = seconds(20);
25 static constexpr Clock::duration kWatchRefreshPadding = seconds(2);
26
// Reorders |urls| in place so that every URL absent from |membership_test|
// comes before every URL present in it (relative order within each group is
// not preserved). Returns an iterator to the first URL that is a member of
// |membership_test|, or urls->end() if there is none.
std::vector<std::string>::iterator PartitionUrlsBySetMembership(
    std::vector<std::string>* urls,
    const std::set<std::string>& membership_test) {
  auto is_not_member = [&membership_test](const std::string& candidate) {
    return membership_test.count(candidate) == 0;
  };
  return std::partition(urls->begin(), urls->end(), is_not_member);
}
35
// Moves each string in [first, last) into |target|. Strings already present in
// |target| are not duplicated; the source elements may be left in a
// moved-from state either way.
void MoveVectorSegment(std::vector<std::string>::iterator first,
                       std::vector<std::string>::iterator last,
                       std::set<std::string>* target) {
  std::for_each(first, last, [target](std::string& url) {
    target->emplace(std::move(url));
  });
}
42
GetNextRequestId(const uint64_t endpoint_id)43 uint64_t GetNextRequestId(const uint64_t endpoint_id) {
44 return NetworkServiceManager::Get()
45 ->GetProtocolConnectionClient()
46 ->endpoint_request_ids()
47 ->GetNextRequestId(endpoint_id);
48 }
49
50 } // namespace
51
UrlAvailabilityRequester(ClockNowFunctionPtr now_function)52 UrlAvailabilityRequester::UrlAvailabilityRequester(
53 ClockNowFunctionPtr now_function)
54 : now_function_(now_function) {
55 OSP_DCHECK(now_function_);
56 }
57
58 UrlAvailabilityRequester::~UrlAvailabilityRequester() = default;
59
AddObserver(const std::vector<std::string> & urls,ReceiverObserver * observer)60 void UrlAvailabilityRequester::AddObserver(const std::vector<std::string>& urls,
61 ReceiverObserver* observer) {
62 for (const auto& url : urls) {
63 observers_by_url_[url].push_back(observer);
64 }
65 for (auto& entry : receiver_by_service_id_) {
66 auto& receiver = entry.second;
67 receiver->GetOrRequestAvailabilities(urls, observer);
68 }
69 }
70
RemoveObserverUrls(const std::vector<std::string> & urls,ReceiverObserver * observer)71 void UrlAvailabilityRequester::RemoveObserverUrls(
72 const std::vector<std::string>& urls,
73 ReceiverObserver* observer) {
74 std::set<std::string> unobserved_urls;
75 for (const auto& url : urls) {
76 auto observer_entry = observers_by_url_.find(url);
77 if (observer_entry == observers_by_url_.end())
78 continue;
79 auto& observers = observer_entry->second;
80 observers.erase(std::remove(observers.begin(), observers.end(), observer),
81 observers.end());
82 if (observers.empty()) {
83 unobserved_urls.emplace(std::move(observer_entry->first));
84 observers_by_url_.erase(observer_entry);
85 for (auto& entry : receiver_by_service_id_) {
86 auto& receiver = entry.second;
87 receiver->known_availability_by_url.erase(url);
88 }
89 }
90 }
91
92 for (auto& entry : receiver_by_service_id_) {
93 auto& receiver = entry.second;
94 receiver->RemoveUnobservedRequests(unobserved_urls);
95 receiver->RemoveUnobservedWatches(unobserved_urls);
96 }
97 }
98
RemoveObserver(ReceiverObserver * observer)99 void UrlAvailabilityRequester::RemoveObserver(ReceiverObserver* observer) {
100 std::set<std::string> unobserved_urls;
101 for (auto& entry : observers_by_url_) {
102 auto& observer_list = entry.second;
103 auto it = std::remove(observer_list.begin(), observer_list.end(), observer);
104 if (it != observer_list.end()) {
105 observer_list.erase(it);
106 if (observer_list.empty())
107 unobserved_urls.insert(entry.first);
108 }
109 }
110
111 for (auto& entry : receiver_by_service_id_) {
112 auto& receiver = entry.second;
113 receiver->RemoveUnobservedRequests(unobserved_urls);
114 receiver->RemoveUnobservedWatches(unobserved_urls);
115 }
116 }
117
AddReceiver(const ServiceInfo & info)118 void UrlAvailabilityRequester::AddReceiver(const ServiceInfo& info) {
119 auto result = receiver_by_service_id_.emplace(
120 info.service_id,
121 std::make_unique<ReceiverRequester>(
122 this, info.service_id,
123 info.v4_endpoint.address ? info.v4_endpoint : info.v6_endpoint));
124 std::unique_ptr<ReceiverRequester>& receiver = result.first->second;
125 std::vector<std::string> urls;
126 urls.reserve(observers_by_url_.size());
127 for (const auto& url : observers_by_url_)
128 urls.push_back(url.first);
129 receiver->RequestUrlAvailabilities(std::move(urls));
130 }
131
ChangeReceiver(const ServiceInfo & info)132 void UrlAvailabilityRequester::ChangeReceiver(const ServiceInfo& info) {}
133
RemoveReceiver(const ServiceInfo & info)134 void UrlAvailabilityRequester::RemoveReceiver(const ServiceInfo& info) {
135 auto receiver_entry = receiver_by_service_id_.find(info.service_id);
136 if (receiver_entry != receiver_by_service_id_.end()) {
137 auto& receiver = receiver_entry->second;
138 receiver->RemoveReceiver();
139 receiver_by_service_id_.erase(receiver_entry);
140 }
141 }
142
RemoveAllReceivers()143 void UrlAvailabilityRequester::RemoveAllReceivers() {
144 for (auto& entry : receiver_by_service_id_) {
145 auto& receiver = entry.second;
146 receiver->RemoveReceiver();
147 }
148 receiver_by_service_id_.clear();
149 }
150
RefreshWatches()151 Clock::time_point UrlAvailabilityRequester::RefreshWatches() {
152 const Clock::time_point now = now_function_();
153 Clock::time_point minimum_schedule_time = now + kWatchDuration;
154 for (auto& entry : receiver_by_service_id_) {
155 auto& receiver = entry.second;
156 const Clock::time_point requested_schedule_time =
157 receiver->RefreshWatches(now);
158 if (requested_schedule_time < minimum_schedule_time)
159 minimum_schedule_time = requested_schedule_time;
160 }
161 return minimum_schedule_time;
162 }
163
ReceiverRequester(UrlAvailabilityRequester * listener,const std::string & service_id,const IPEndpoint & endpoint)164 UrlAvailabilityRequester::ReceiverRequester::ReceiverRequester(
165 UrlAvailabilityRequester* listener,
166 const std::string& service_id,
167 const IPEndpoint& endpoint)
168 : listener(listener),
169 service_id(service_id),
170 connect_request(
171 NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
172 endpoint,
173 this)) {}
174
175 UrlAvailabilityRequester::ReceiverRequester::~ReceiverRequester() = default;
176
GetOrRequestAvailabilities(const std::vector<std::string> & requested_urls,ReceiverObserver * observer)177 void UrlAvailabilityRequester::ReceiverRequester::GetOrRequestAvailabilities(
178 const std::vector<std::string>& requested_urls,
179 ReceiverObserver* observer) {
180 std::vector<std::string> unknown_urls;
181 for (const auto& url : requested_urls) {
182 auto availability_entry = known_availability_by_url.find(url);
183 if (availability_entry == known_availability_by_url.end()) {
184 unknown_urls.emplace_back(url);
185 continue;
186 }
187
188 msgs::UrlAvailability availability = availability_entry->second;
189 if (observer) {
190 switch (availability) {
191 case msgs::UrlAvailability::kAvailable:
192 observer->OnReceiverAvailable(url, service_id);
193 break;
194 case msgs::UrlAvailability::kUnavailable:
195 case msgs::UrlAvailability::kInvalid:
196 observer->OnReceiverUnavailable(url, service_id);
197 break;
198 }
199 }
200 }
201 if (!unknown_urls.empty()) {
202 RequestUrlAvailabilities(std::move(unknown_urls));
203 }
204 }
205
RequestUrlAvailabilities(std::vector<std::string> urls)206 void UrlAvailabilityRequester::ReceiverRequester::RequestUrlAvailabilities(
207 std::vector<std::string> urls) {
208 if (urls.empty())
209 return;
210 const uint64_t request_id = GetNextRequestId(endpoint_id);
211 ErrorOr<uint64_t> watch_id_or_error(0);
212 if (!connection || (watch_id_or_error = SendRequest(request_id, urls))) {
213 request_by_id.emplace(request_id,
214 Request{watch_id_or_error.value(), std::move(urls)});
215 } else {
216 for (const auto& url : urls)
217 for (auto& observer : listener->observers_by_url_[url])
218 observer->OnRequestFailed(url, service_id);
219 }
220 }
221
SendRequest(uint64_t request_id,const std::vector<std::string> & urls)222 ErrorOr<uint64_t> UrlAvailabilityRequester::ReceiverRequester::SendRequest(
223 uint64_t request_id,
224 const std::vector<std::string>& urls) {
225 uint64_t watch_id = next_watch_id++;
226 msgs::PresentationUrlAvailabilityRequest cbor_request;
227 cbor_request.request_id = request_id;
228 cbor_request.urls = urls;
229 cbor_request.watch_id = watch_id;
230 cbor_request.watch_duration = to_microseconds(kWatchDuration).count();
231
232 msgs::CborEncodeBuffer buffer;
233 if (msgs::EncodePresentationUrlAvailabilityRequest(cbor_request, &buffer)) {
234 OSP_VLOG << "writing presentation-url-availability-request";
235 connection->Write(buffer.data(), buffer.size());
236 watch_by_id.emplace(
237 watch_id, Watch{listener->now_function_() + kWatchDuration, urls});
238 if (!event_watch) {
239 event_watch = GetClientDemuxer()->WatchMessageType(
240 endpoint_id, msgs::Type::kPresentationUrlAvailabilityEvent, this);
241 }
242 if (!response_watch) {
243 response_watch = GetClientDemuxer()->WatchMessageType(
244 endpoint_id, msgs::Type::kPresentationUrlAvailabilityResponse, this);
245 }
246 return watch_id;
247 }
248 return Error::Code::kCborEncoding;
249 }
250
RefreshWatches(Clock::time_point now)251 Clock::time_point UrlAvailabilityRequester::ReceiverRequester::RefreshWatches(
252 Clock::time_point now) {
253 Clock::time_point minimum_schedule_time = now + kWatchDuration;
254 std::vector<std::vector<std::string>> new_requests;
255 for (auto entry = watch_by_id.begin(); entry != watch_by_id.end();) {
256 Watch& watch = entry->second;
257 const Clock::time_point buffered_deadline =
258 watch.deadline - kWatchRefreshPadding;
259 if (now > buffered_deadline) {
260 new_requests.emplace_back(std::move(watch.urls));
261 entry = watch_by_id.erase(entry);
262 } else {
263 ++entry;
264 if (buffered_deadline < minimum_schedule_time)
265 minimum_schedule_time = buffered_deadline;
266 }
267 }
268 if (watch_by_id.empty())
269 StopWatching(&event_watch);
270
271 for (auto& request : new_requests)
272 RequestUrlAvailabilities(std::move(request));
273
274 return minimum_schedule_time;
275 }
276
UpdateAvailabilities(const std::vector<std::string> & urls,const std::vector<msgs::UrlAvailability> & availabilities)277 Error::Code UrlAvailabilityRequester::ReceiverRequester::UpdateAvailabilities(
278 const std::vector<std::string>& urls,
279 const std::vector<msgs::UrlAvailability>& availabilities) {
280 auto availability_it = availabilities.begin();
281 if (urls.size() != availabilities.size()) {
282 return Error::Code::kCborInvalidMessage;
283 }
284 for (const auto& url : urls) {
285 auto observer_entry = listener->observers_by_url_.find(url);
286 if (observer_entry == listener->observers_by_url_.end())
287 continue;
288 std::vector<ReceiverObserver*>& observers = observer_entry->second;
289 auto result = known_availability_by_url.emplace(url, *availability_it);
290 auto entry = result.first;
291 bool inserted = result.second;
292 bool updated = (entry->second != *availability_it);
293 if (inserted || updated) {
294 switch (*availability_it) {
295 case msgs::UrlAvailability::kAvailable:
296 for (auto* observer : observers)
297 observer->OnReceiverAvailable(url, service_id);
298 break;
299 case msgs::UrlAvailability::kUnavailable:
300 case msgs::UrlAvailability::kInvalid:
301 for (auto* observer : observers)
302 observer->OnReceiverUnavailable(url, service_id);
303 break;
304 default:
305 break;
306 }
307 }
308 ++availability_it;
309 }
310 return Error::Code::kNone;
311 }
312
RemoveUnobservedRequests(const std::set<std::string> & unobserved_urls)313 void UrlAvailabilityRequester::ReceiverRequester::RemoveUnobservedRequests(
314 const std::set<std::string>& unobserved_urls) {
315 std::map<uint64_t, Request> new_requests;
316 std::set<std::string> still_observed_urls;
317 for (auto entry = request_by_id.begin(); entry != request_by_id.end();
318 ++entry) {
319 Request& request = entry->second;
320 auto split = PartitionUrlsBySetMembership(&request.urls, unobserved_urls);
321 if (split == request.urls.end())
322 continue;
323 MoveVectorSegment(request.urls.begin(), split, &still_observed_urls);
324 if (connection)
325 watch_by_id.erase(request.watch_id);
326 }
327 if (!still_observed_urls.empty()) {
328 const uint64_t new_request_id = GetNextRequestId(endpoint_id);
329 ErrorOr<uint64_t> watch_id_or_error(0);
330 std::vector<std::string> urls;
331 urls.reserve(still_observed_urls.size());
332 for (auto& url : still_observed_urls)
333 urls.emplace_back(std::move(url));
334 if (!connection ||
335 (watch_id_or_error = SendRequest(new_request_id, urls))) {
336 new_requests.emplace(new_request_id,
337 Request{watch_id_or_error.value(), std::move(urls)});
338 } else {
339 for (const auto& url : urls)
340 for (auto& observer : listener->observers_by_url_[url])
341 observer->OnRequestFailed(url, service_id);
342 }
343 }
344
345 for (auto& entry : new_requests)
346 request_by_id.emplace(entry.first, std::move(entry.second));
347
348 if (request_by_id.empty())
349 StopWatching(&response_watch);
350 }
351
RemoveUnobservedWatches(const std::set<std::string> & unobserved_urls)352 void UrlAvailabilityRequester::ReceiverRequester::RemoveUnobservedWatches(
353 const std::set<std::string>& unobserved_urls) {
354 std::set<std::string> still_observed_urls;
355 for (auto entry = watch_by_id.begin(); entry != watch_by_id.end();) {
356 Watch& watch = entry->second;
357 auto split = PartitionUrlsBySetMembership(&watch.urls, unobserved_urls);
358 if (split == watch.urls.end()) {
359 ++entry;
360 continue;
361 }
362 MoveVectorSegment(watch.urls.begin(), split, &still_observed_urls);
363 entry = watch_by_id.erase(entry);
364 }
365
366 std::vector<std::string> urls;
367 urls.reserve(still_observed_urls.size());
368 for (auto& url : still_observed_urls)
369 urls.emplace_back(std::move(url));
370 RequestUrlAvailabilities(std::move(urls));
371 // TODO(btolsch): These message watch cancels could be tested by expecting
372 // messages to fall through to the default watch.
373 if (watch_by_id.empty())
374 StopWatching(&event_watch);
375 }
376
RemoveReceiver()377 void UrlAvailabilityRequester::ReceiverRequester::RemoveReceiver() {
378 for (const auto& availability : known_availability_by_url) {
379 if (availability.second == msgs::UrlAvailability::kAvailable) {
380 const std::string& url = availability.first;
381 for (auto& observer : listener->observers_by_url_[url])
382 observer->OnReceiverUnavailable(url, service_id);
383 }
384 }
385 }
386
OnConnectionOpened(uint64_t request_id,std::unique_ptr<ProtocolConnection> connection)387 void UrlAvailabilityRequester::ReceiverRequester::OnConnectionOpened(
388 uint64_t request_id,
389 std::unique_ptr<ProtocolConnection> connection) {
390 connect_request.MarkComplete();
391 // TODO(btolsch): This is one place where we need to make sure the QUIC
392 // connection stays alive, even without constant traffic.
393 endpoint_id = connection->endpoint_id();
394 this->connection = std::move(connection);
395 ErrorOr<uint64_t> watch_id_or_error(0);
396 for (auto entry = request_by_id.begin(); entry != request_by_id.end();) {
397 if ((watch_id_or_error = SendRequest(entry->first, entry->second.urls))) {
398 entry->second.watch_id = watch_id_or_error.value();
399 ++entry;
400 } else {
401 entry = request_by_id.erase(entry);
402 }
403 }
404 }
405
OnConnectionFailed(uint64_t request_id)406 void UrlAvailabilityRequester::ReceiverRequester::OnConnectionFailed(
407 uint64_t request_id) {
408 connect_request.MarkComplete();
409
410 std::set<std::string> waiting_urls;
411 for (auto& entry : request_by_id) {
412 Request& request = entry.second;
413 for (auto& url : request.urls) {
414 waiting_urls.emplace(std::move(url));
415 }
416 }
417 for (const auto& url : waiting_urls)
418 for (auto& observer : listener->observers_by_url_[url])
419 observer->OnRequestFailed(url, service_id);
420
421 std::string id = std::move(service_id);
422 listener->receiver_by_service_id_.erase(id);
423 }
424
OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)425 ErrorOr<size_t> UrlAvailabilityRequester::ReceiverRequester::OnStreamMessage(
426 uint64_t endpoint_id,
427 uint64_t connection_id,
428 msgs::Type message_type,
429 const uint8_t* buffer,
430 size_t buffer_size,
431 Clock::time_point now) {
432 switch (message_type) {
433 case msgs::Type::kPresentationUrlAvailabilityResponse: {
434 msgs::PresentationUrlAvailabilityResponse response;
435 ssize_t result = msgs::DecodePresentationUrlAvailabilityResponse(
436 buffer, buffer_size, &response);
437 if (result < 0) {
438 if (result == msgs::kParserEOF)
439 return Error::Code::kCborIncompleteMessage;
440 OSP_LOG_WARN << "parse error: " << result;
441 return Error::Code::kCborParsing;
442 } else {
443 auto request_entry = request_by_id.find(response.request_id);
444 if (request_entry == request_by_id.end()) {
445 OSP_LOG_ERROR << "bad response id: " << response.request_id;
446 return Error::Code::kCborInvalidResponseId;
447 }
448 std::vector<std::string>& urls = request_entry->second.urls;
449 if (urls.size() != response.url_availabilities.size()) {
450 OSP_LOG_WARN << "bad response size: expected " << urls.size()
451 << " but got " << response.url_availabilities.size();
452 return Error::Code::kCborInvalidMessage;
453 }
454 Error::Code update_result =
455 UpdateAvailabilities(urls, response.url_availabilities);
456 if (update_result != Error::Code::kNone) {
457 return update_result;
458 }
459 request_by_id.erase(response.request_id);
460 if (request_by_id.empty())
461 StopWatching(&response_watch);
462 return result;
463 }
464 } break;
465 case msgs::Type::kPresentationUrlAvailabilityEvent: {
466 msgs::PresentationUrlAvailabilityEvent event;
467 ssize_t result = msgs::DecodePresentationUrlAvailabilityEvent(
468 buffer, buffer_size, &event);
469 if (result < 0) {
470 if (result == msgs::kParserEOF)
471 return Error::Code::kCborIncompleteMessage;
472 OSP_LOG_WARN << "parse error: " << result;
473 return Error::Code::kCborParsing;
474 } else {
475 auto watch_entry = watch_by_id.find(event.watch_id);
476 if (watch_entry != watch_by_id.end()) {
477 std::vector<std::string> urls = watch_entry->second.urls;
478 Error::Code update_result =
479 UpdateAvailabilities(urls, event.url_availabilities);
480 if (update_result != Error::Code::kNone) {
481 return update_result;
482 }
483 }
484 return result;
485 }
486 } break;
487 default:
488 break;
489 }
490 return Error::Code::kCborParsing;
491 }
492
493 } // namespace osp
494 } // namespace openscreen
495