1 // Copyright 2024 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
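//! Launcher of forwarder_guest: spawns a `forwarder_guest` process for each port-forwarding
//! request received from the host, and reports the VM's listening TCP ports to the host.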
16
17 use anyhow::{anyhow, Context};
18 use clap::Parser;
19 use csv_async::AsyncReader;
20 use debian_service::debian_service_client::DebianServiceClient;
21 use debian_service::{ActivePort, QueueOpeningRequest, ReportVmActivePortsRequest};
22 use futures::stream::StreamExt;
23 use log::{debug, error};
24 use serde::Deserialize;
25 use std::collections::HashMap;
26 use std::process::Stdio;
27 use tokio::io::BufReader;
28 use tokio::process::Command;
29 use tokio::try_join;
30 use tonic::transport::{Channel, Endpoint};
31 use tonic::Request;
32
/// gRPC bindings for the `com.android.virtualization.terminal.proto` package: the
/// `DebianService` client (`DebianServiceClient`) and its request/response message types
/// (`ActivePort`, `QueueOpeningRequest`, `ReportVmActivePortsRequest`, ...), included from the
/// code generated at build time.
mod debian_service {
    tonic::include_proto!("com.android.virtualization.terminal.proto");
}
36
/// First port outside the privileged range (ports below 1024 need privileges to bind).
/// Listening ports below this value are never reported to the host for forwarding.
const NON_PREVILEGED_PORT_RANGE_START: i32 = 1024;
/// Port that is never reported for forwarding. Presumably the in-VM ttyd terminal server
/// (going by the name); NOTE(review): confirm.
const TTYD_PORT: i32 = 7681;
/// `NEWSTATE` value in `tcpstates-bpfcc` output marking a transition to the CLOSE state.
const TCPSTATES_STATE_CLOSE: &str = "CLOSE";
/// `NEWSTATE` value in `tcpstates-bpfcc` output marking a transition to the LISTEN state.
const TCPSTATES_STATE_LISTEN: &str = "LISTEN";
41
/// One CSV row printed by `tcpstates-bpfcc -s`, describing a single TCP state transition.
///
/// Columns are matched by their upper-case header names (`LPORT`, `RPORT`, `C_COMM`,
/// `NEWSTATE`). The command column is also accepted under the `C-COMM` header through the
/// serde alias.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
struct TcpStateRow {
    /// Local port of the socket.
    lport: i32,
    /// Remote port of the socket. Rows with a positive value are ignored by the port tracker.
    /// NOTE(review): presumably 0 for sockets without a peer, such as listeners.
    rport: i32,
    /// Command name reported for the socket; stored as the `comm` of the reported port.
    #[serde(alias = "C-COMM")]
    c_comm: String,
    /// State the socket transitioned into (compared against `"LISTEN"` / `"CLOSE"`).
    newstate: String,
}
51
#[derive(Parser)]
/// Flags for running command
// NOTE: clap turns the `///` doc comments in this struct into `--help` text, so they are
// user-visible at runtime. Maintainer notes belong in plain `//` comments like this one.
pub struct Args {
    /// path to a file where grpc port number is written
    // Exposed on the command line as `--grpc-port-file` (clap's kebab-case long name).
    // `main` waits for this file to appear, then reads the host gRPC server's port from it.
    #[arg(long)]
    grpc_port_file: String,
}
59
process_forwarding_request_queue( mut client: DebianServiceClient<Channel>, ) -> Result<(), Box<dyn std::error::Error>>60 async fn process_forwarding_request_queue(
61 mut client: DebianServiceClient<Channel>,
62 ) -> Result<(), Box<dyn std::error::Error>> {
63 let cid = vsock::get_local_cid().context("Failed to get CID of VM")?;
64 let mut res_stream = client
65 .open_forwarding_request_queue(Request::new(QueueOpeningRequest { cid: cid as i32 }))
66 .await?
67 .into_inner();
68
69 while let Some(response) = res_stream.message().await? {
70 let tcp_port = i16::try_from(response.guest_tcp_port)
71 .context("Failed to convert guest_tcp_port as i16")?;
72 let vsock_port = response.vsock_port as u32;
73
74 debug!(
75 "executing forwarder_guest with guest_tcp_port: {:?}, vsock_port: {:?}",
76 &tcp_port, &vsock_port
77 );
78
79 let _ = Command::new("forwarder_guest")
80 .arg("--local")
81 .arg(format!("127.0.0.1:{}", tcp_port))
82 .arg("--remote")
83 .arg(format!("vsock:2:{}", vsock_port))
84 .spawn();
85 }
86 Err(anyhow!("process_forwarding_request_queue is terminated").into())
87 }
88
send_active_ports_report( listening_ports: HashMap<i32, ActivePort>, client: &mut DebianServiceClient<Channel>, ) -> Result<(), Box<dyn std::error::Error>>89 async fn send_active_ports_report(
90 listening_ports: HashMap<i32, ActivePort>,
91 client: &mut DebianServiceClient<Channel>,
92 ) -> Result<(), Box<dyn std::error::Error>> {
93 let res = client
94 .report_vm_active_ports(Request::new(ReportVmActivePortsRequest {
95 ports: listening_ports.values().cloned().collect(),
96 }))
97 .await?
98 .into_inner();
99 if res.success {
100 debug!("Successfully reported active ports to the host");
101 } else {
102 error!("Failure response received from the host for reporting active ports");
103 }
104 Ok(())
105 }
106
is_forwardable_port(port: i32) -> bool107 fn is_forwardable_port(port: i32) -> bool {
108 port >= NON_PREVILEGED_PORT_RANGE_START && port != TTYD_PORT
109 }
110
report_active_ports( mut client: DebianServiceClient<Channel>, ) -> Result<(), Box<dyn std::error::Error>>111 async fn report_active_ports(
112 mut client: DebianServiceClient<Channel>,
113 ) -> Result<(), Box<dyn std::error::Error>> {
114 // TODO: we can remove python3 -u when https://github.com/iovisor/bcc/pull/5142 is deployed
115 let mut cmd = Command::new("python3")
116 .arg("-u")
117 .arg("/usr/sbin/tcpstates-bpfcc")
118 .arg("-s")
119 .stdout(Stdio::piped())
120 .spawn()?;
121 let stdout = cmd.stdout.take().context("Failed to get stdout of tcpstates")?;
122 let mut csv_reader = AsyncReader::from_reader(BufReader::new(stdout));
123 let header = csv_reader.headers().await?.clone();
124
125 // TODO(b/340126051): Consider using NETLINK_SOCK_DIAG for the optimization.
126 let listeners = listeners::get_all()?;
127 let mut listening_ports: HashMap<_, _> = listeners
128 .iter()
129 .map(|x| {
130 (
131 x.socket.port().into(),
132 ActivePort { port: x.socket.port().into(), comm: x.process.name.to_string() },
133 )
134 })
135 .filter(|(x, _)| is_forwardable_port(*x))
136 .collect();
137 send_active_ports_report(listening_ports.clone(), &mut client).await?;
138
139 let mut records = csv_reader.records();
140 while let Some(record) = records.next().await {
141 let row: TcpStateRow = record?.deserialize(Some(&header))?;
142 if !is_forwardable_port(row.lport) {
143 continue;
144 }
145 if row.rport > 0 {
146 continue;
147 }
148 match row.newstate.as_str() {
149 TCPSTATES_STATE_LISTEN => {
150 listening_ports.insert(row.lport, ActivePort { port: row.lport, comm: row.c_comm });
151 }
152 TCPSTATES_STATE_CLOSE => {
153 listening_ports.remove(&row.lport);
154 }
155 _ => continue,
156 }
157 send_active_ports_report(listening_ports.clone(), &mut client).await?;
158 }
159
160 Err(anyhow!("report_active_ports is terminated").into())
161 }
162
163 #[tokio::main]
main() -> Result<(), Box<dyn std::error::Error>>164 async fn main() -> Result<(), Box<dyn std::error::Error>> {
165 env_logger::builder().filter_level(log::LevelFilter::Debug).init();
166 debug!("Starting forwarder_guest_launcher");
167 let args = Args::parse();
168 let gateway_ip_addr = netdev::get_default_gateway()?.ipv4[0];
169
170 // Wait for `grpc_port_file` becomes available.
171 const GRPC_PORT_MAX_RETRY_COUNT: u32 = 10;
172 for _ in 0..GRPC_PORT_MAX_RETRY_COUNT {
173 if std::path::Path::new(&args.grpc_port_file).exists() {
174 break;
175 }
176 debug!("{} does not exist. Wait 1 second", args.grpc_port_file);
177 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
178 }
179 let grpc_port = std::fs::read_to_string(&args.grpc_port_file)?.trim().to_string();
180
181 let addr = format!("https://{}:{}", gateway_ip_addr.to_string(), grpc_port);
182 let channel = Endpoint::from_shared(addr)?.connect().await?;
183 let client = DebianServiceClient::new(channel);
184
185 try_join!(process_forwarding_request_queue(client.clone()), report_active_ports(client))?;
186 Ok(())
187 }
188