/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! Provides a backing task to implement a Dispatcher
18
19 use crate::boot_time::{BootTime, Duration};
20 use anyhow::{bail, Result};
21 use log::{debug, trace, warn};
22 use std::collections::HashMap;
23 use tokio::sync::{mpsc, oneshot};
24
25 use super::{Command, QueryError, Response};
26 use crate::network::{Network, ServerInfo, SocketTagger, ValidationReporter};
27 use crate::{config, network};
28
29 pub struct Driver {
30 command_rx: mpsc::Receiver<Command>,
31 networks: HashMap<u32, Network>,
32 validation: ValidationReporter,
33 tagger: SocketTagger,
34 config_cache: config::Cache,
35 }
36
debug_err(r: Result<()>)37 fn debug_err(r: Result<()>) {
38 if let Err(e) = r {
39 debug!("Dispatcher loop got {:?}", e);
40 }
41 }
42
43 impl Driver {
new( command_rx: mpsc::Receiver<Command>, validation: ValidationReporter, tagger: SocketTagger, ) -> Self44 pub fn new(
45 command_rx: mpsc::Receiver<Command>,
46 validation: ValidationReporter,
47 tagger: SocketTagger,
48 ) -> Self {
49 Self {
50 command_rx,
51 networks: HashMap::new(),
52 validation,
53 tagger,
54 config_cache: config::Cache::new(),
55 }
56 }
57
drive(mut self) -> Result<()>58 pub async fn drive(mut self) -> Result<()> {
59 loop {
60 self.drive_once().await?
61 }
62 }
63
drive_once(&mut self) -> Result<()>64 async fn drive_once(&mut self) -> Result<()> {
65 if let Some(command) = self.command_rx.recv().await {
66 trace!("dispatch command: {:?}", command);
67 match command {
68 Command::Probe { info, timeout } => debug_err(self.probe(info, timeout).await),
69 Command::Query { net_id, base64_query, expired_time, resp } => {
70 debug_err(self.query(net_id, base64_query, expired_time, resp).await)
71 }
72 Command::Clear { net_id } => {
73 self.networks.remove(&net_id);
74 self.config_cache.garbage_collect();
75 }
76 Command::Exit => {
77 bail!("Death due to Exit")
78 }
79 }
80 Ok(())
81 } else {
82 bail!("Death due to command_tx dying")
83 }
84 }
85
query( &mut self, net_id: u32, query: String, expiry: BootTime, response: oneshot::Sender<Response>, ) -> Result<()>86 async fn query(
87 &mut self,
88 net_id: u32,
89 query: String,
90 expiry: BootTime,
91 response: oneshot::Sender<Response>,
92 ) -> Result<()> {
93 if let Some(network) = self.networks.get_mut(&net_id) {
94 network.query(network::Query { query, response, expiry }).await?;
95 } else {
96 warn!("Tried to send a query to non-existent network net_id={}", net_id);
97 response.send(Response::Error { error: QueryError::Unexpected }).unwrap_or_else(|_| {
98 warn!("Unable to send reply for non-existent network net_id={}", net_id);
99 })
100 }
101 Ok(())
102 }
103
probe(&mut self, info: ServerInfo, timeout: Duration) -> Result<()>104 async fn probe(&mut self, info: ServerInfo, timeout: Duration) -> Result<()> {
105 use std::collections::hash_map::Entry;
106 if !self.networks.get(&info.net_id).map_or(true, |net| net.get_info() == &info) {
107 // If we have a network registered to the provided net_id, but the server info doesn't
108 // match, our API has been used incorrectly. Attempt to recover by deleting the old
109 // network and recreating it according to the probe request.
110 warn!("Probing net_id={} with mismatched server info {:?}", info.net_id, info);
111 self.networks.remove(&info.net_id);
112 }
113 // Can't use or_insert_with because creating a network may fail
114 let net = match self.networks.entry(info.net_id) {
115 Entry::Occupied(network) => network.into_mut(),
116 Entry::Vacant(vacant) => {
117 let key = config::Key {
118 cert_path: info.cert_path.clone(),
119 max_idle_timeout: info.idle_timeout_ms,
120 };
121 let config = self.config_cache.get(&key)?;
122 vacant.insert(
123 Network::new(info, config, self.validation.clone(), self.tagger.clone())
124 .await?,
125 )
126 }
127 };
128 net.probe(timeout).await?;
129 Ok(())
130 }
131 }
132