• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //! Provides a backing task to implement a Dispatcher
18 
19 use crate::boot_time::{BootTime, Duration};
20 use anyhow::{bail, Result};
21 use log::{debug, trace, warn};
22 use std::collections::HashMap;
23 use tokio::sync::{mpsc, oneshot};
24 
25 use super::{Command, QueryError, Response};
26 use crate::network::{Network, ServerInfo, SocketTagger, ValidationReporter};
27 use crate::{config, network};
28 
29 pub struct Driver {
30     command_rx: mpsc::Receiver<Command>,
31     networks: HashMap<u32, Network>,
32     validation: ValidationReporter,
33     tagger: SocketTagger,
34     config_cache: config::Cache,
35 }
36 
debug_err(r: Result<()>)37 fn debug_err(r: Result<()>) {
38     if let Err(e) = r {
39         debug!("Dispatcher loop got {:?}", e);
40     }
41 }
42 
43 impl Driver {
new( command_rx: mpsc::Receiver<Command>, validation: ValidationReporter, tagger: SocketTagger, ) -> Self44     pub fn new(
45         command_rx: mpsc::Receiver<Command>,
46         validation: ValidationReporter,
47         tagger: SocketTagger,
48     ) -> Self {
49         Self {
50             command_rx,
51             networks: HashMap::new(),
52             validation,
53             tagger,
54             config_cache: config::Cache::new(),
55         }
56     }
57 
drive(mut self) -> Result<()>58     pub async fn drive(mut self) -> Result<()> {
59         loop {
60             self.drive_once().await?
61         }
62     }
63 
drive_once(&mut self) -> Result<()>64     async fn drive_once(&mut self) -> Result<()> {
65         if let Some(command) = self.command_rx.recv().await {
66             trace!("dispatch command: {:?}", command);
67             match command {
68                 Command::Probe { info, timeout } => debug_err(self.probe(info, timeout).await),
69                 Command::Query { net_id, base64_query, expired_time, resp } => {
70                     debug_err(self.query(net_id, base64_query, expired_time, resp).await)
71                 }
72                 Command::Clear { net_id } => {
73                     self.networks.remove(&net_id);
74                     self.config_cache.garbage_collect();
75                 }
76                 Command::Exit => {
77                     bail!("Death due to Exit")
78                 }
79             }
80             Ok(())
81         } else {
82             bail!("Death due to command_tx dying")
83         }
84     }
85 
query( &mut self, net_id: u32, query: String, expiry: BootTime, response: oneshot::Sender<Response>, ) -> Result<()>86     async fn query(
87         &mut self,
88         net_id: u32,
89         query: String,
90         expiry: BootTime,
91         response: oneshot::Sender<Response>,
92     ) -> Result<()> {
93         if let Some(network) = self.networks.get_mut(&net_id) {
94             network.query(network::Query { query, response, expiry }).await?;
95         } else {
96             warn!("Tried to send a query to non-existent network net_id={}", net_id);
97             response.send(Response::Error { error: QueryError::Unexpected }).unwrap_or_else(|_| {
98                 warn!("Unable to send reply for non-existent network net_id={}", net_id);
99             })
100         }
101         Ok(())
102     }
103 
probe(&mut self, info: ServerInfo, timeout: Duration) -> Result<()>104     async fn probe(&mut self, info: ServerInfo, timeout: Duration) -> Result<()> {
105         use std::collections::hash_map::Entry;
106         if !self.networks.get(&info.net_id).map_or(true, |net| net.get_info() == &info) {
107             // If we have a network registered to the provided net_id, but the server info doesn't
108             // match, our API has been used incorrectly. Attempt to recover by deleting the old
109             // network and recreating it according to the probe request.
110             warn!("Probing net_id={} with mismatched server info {:?}", info.net_id, info);
111             self.networks.remove(&info.net_id);
112         }
113         // Can't use or_insert_with because creating a network may fail
114         let net = match self.networks.entry(info.net_id) {
115             Entry::Occupied(network) => network.into_mut(),
116             Entry::Vacant(vacant) => {
117                 let key = config::Key {
118                     cert_path: info.cert_path.clone(),
119                     max_idle_timeout: info.idle_timeout_ms,
120                 };
121                 let config = self.config_cache.get(&key)?;
122                 vacant.insert(
123                     Network::new(info, config, self.validation.clone(), self.tagger.clone())
124                         .await?,
125                 )
126             }
127         };
128         net.probe(timeout).await?;
129         Ok(())
130     }
131 }
132