/*
 * Copyright (C) 2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15 use crate::parser;
16 use crate::task;
17
18 use hdc::config;
19 use hdc::config::HdcCommand;
20 use hdc::config::TaskMessage;
21 use hdc::transfer;
22 use hdc::utils;
23 use std::process;
24 use std::str::FromStr;
25 use std::time::Duration;
26
27 use std::io::{self, Error, ErrorKind};
28
29 use ylong_runtime::net::{SplitReadHalf, SplitWriteHalf, TcpListener, TcpStream};
30
run_server_mode(addr_str: String) -> io::Result<()>31 pub async fn run_server_mode(addr_str: String) -> io::Result<()> {
32 let saddr = addr_str;
33 let listener = TcpListener::bind(saddr.clone()).await?;
34 hdc::info!("server binds on {saddr}");
35
36 loop {
37 let (stream, addr) = listener.accept().await?;
38 hdc::info!("accepted client {addr}");
39 ylong_runtime::spawn(handle_client(stream));
40 }
41 }
42
get_process_pids() -> Vec<u32>43 pub async fn get_process_pids() -> Vec<u32> {
44 let mut pids: Vec<u32> = Vec::new();
45 if cfg!(target_os = "windows") {
46 let output = utils::execute_cmd(format!("tasklist | findstr hdc"));
47 let output_str = String::from_utf8_lossy(&output);
48 let mut get_pid = false;
49 for token in output_str.split_whitespace() {
50 if get_pid {
51 pids.push(u32::from_str(token).unwrap());
52 get_pid = false;
53 }
54 if token.contains("exe") {
55 get_pid = true;
56 }
57 }
58 } else {
59 let output = utils::execute_cmd(format!("ps -ef | grep hdc | awk '{{print $2}}'"));
60 let output_str = String::from_utf8_lossy(&output);
61 for pid in output_str.split_whitespace() {
62 pids.push(u32::from_str(pid).unwrap());
63 }
64 }
65 pids
66 }
67
68 // 跨平台命令
check_allow_fork() -> bool69 pub async fn check_allow_fork() -> bool {
70 let pids = get_process_pids().await;
71 for pid in pids {
72 if pid != process::id() {
73 false;
74 }
75 }
76 true
77 }
78
79 // 跨平台命令
server_fork(addr_str: String)80 pub async fn server_fork(addr_str: String) {
81 let current_exe = std::env::current_exe().unwrap().display().to_string();
82 let result = process::Command::new(¤t_exe)
83 .args(["-b", "-m", "-s", addr_str.as_str()])
84 .spawn();
85 match result {
86 Ok(_) => ylong_runtime::time::sleep(Duration::from_millis(1000)).await,
87 Err(_) => hdc::info!("server fork failed"),
88 }
89 }
90
server_kill()91 pub async fn server_kill() {
92 // TODO: check mac & win
93 let pids = get_process_pids().await;
94
95 for pid in pids {
96 if pid != process::id() {
97 if cfg!(target_os = "windows") {
98 utils::execute_cmd(format!("taskkill /pid {} /f", pid));
99 } else {
100 utils::execute_cmd(format!("kill -9 {}", pid));
101 }
102 }
103 }
104 }
105
/// Per-connection state tracked by `handle_client` while serving a client channel.
#[allow(unused)] // `File` and `App` are declared but not constructed anywhere yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChannelState {
    /// Every incoming message is forwarded verbatim as shell input.
    InteractiveShell,
    File,
    App,
    /// Initial state: each incoming message is parsed as a command line.
    None,
}
114
handle_client(stream: TcpStream) -> io::Result<()>115 async fn handle_client(stream: TcpStream) -> io::Result<()> {
116 let (mut rd, wr) = stream.into_split();
117 let (connect_key, channel_id) = handshake_with_client(&mut rd, wr).await?;
118 let mut channel_state = ChannelState::None;
119
120 loop {
121 let recv_opt = transfer::tcp::recv_channel_message(&mut rd).await;
122 if recv_opt.is_err() {
123 let session_id = match task::ConnectMap::get_session_id(connect_key.clone()).await {
124 Some(seid) => seid,
125 None => return Ok(()),
126 };
127 let message = TaskMessage {
128 channel_id,
129 command: HdcCommand::KernelChannelClose,
130 payload: vec![0],
131 };
132 transfer::put(session_id, message).await;
133 return Ok(());
134 }
135 let recv = recv_opt.unwrap();
136 hdc::debug!(
137 "recv hex: {}",
138 recv.iter()
139 .map(|c| format!("{c:02x}"))
140 .collect::<Vec<_>>()
141 .join(" ")
142 );
143
144 let recv_str = String::from_utf8(recv.clone()).unwrap();
145 hdc::debug!("recv str: {}", recv_str.clone());
146 let mut parsed = parser::split_opt_and_cmd(
147 String::from_utf8(recv)
148 .unwrap()
149 .split(" ")
150 .map(|s| s.trim_end_matches('\0').to_string())
151 .collect::<Vec<_>>(),
152 );
153
154 if channel_state == ChannelState::InteractiveShell {
155 parsed.command = Some(HdcCommand::ShellData);
156 parsed.parameters = vec![recv_str];
157 }
158
159 if parsed.command == Some(HdcCommand::UnityExecute) {
160 channel_state = ChannelState::InteractiveShell;
161 if parsed.parameters.len() == 1 {
162 parsed.command = Some(HdcCommand::ShellInit);
163 }
164 }
165
166 hdc::debug!("parsed cmd: {:#?}", parsed);
167
168 if let Some(cmd) = parsed.command {
169 if let Err(e) = task::channel_task_dispatch(task::TaskInfo {
170 command: cmd,
171 connect_key: connect_key.clone(),
172 channel_id,
173 params: parsed.parameters,
174 })
175 .await
176 {
177 hdc::error!("{e}");
178 }
179 } else {
180 return Err(Error::new(ErrorKind::Other, "command not found"));
181 }
182 }
183 }
184
handshake_with_client( rd: &mut SplitReadHalf, wr: SplitWriteHalf, ) -> io::Result<(String, u32)>185 async fn handshake_with_client(
186 rd: &mut SplitReadHalf,
187 wr: SplitWriteHalf,
188 ) -> io::Result<(String, u32)> {
189 let channel_id = utils::get_pseudo_random_u32();
190 transfer::TcpMap::start(channel_id, wr).await;
191
192 let buf = [
193 config::HANDSHAKE_MESSAGE.as_bytes(),
194 vec![0_u8; config::BANNER_SIZE - config::HANDSHAKE_MESSAGE.len()].as_slice(),
195 u32::to_le_bytes(channel_id).as_slice(),
196 vec![0_u8; config::KEY_MAX_SIZE - std::mem::size_of::<u32>()].as_slice(),
197 ]
198 .concat();
199
200 transfer::send_channel_data(channel_id, buf).await;
201 let recv = transfer::tcp::recv_channel_message(rd).await.unwrap();
202 let connect_key = unpack_channel_handshake(recv)?;
203 Ok((connect_key, channel_id))
204 }
205
unpack_channel_handshake(recv: Vec<u8>) -> io::Result<String>206 fn unpack_channel_handshake(recv: Vec<u8>) -> io::Result<String> {
207 let msg = std::str::from_utf8(&recv[..config::HANDSHAKE_MESSAGE.len()]).unwrap();
208 if msg != config::HANDSHAKE_MESSAGE {
209 return Err(Error::new(ErrorKind::Other, "Recv server-hello failed"));
210 }
211 let key_buf = &recv[config::BANNER_SIZE..];
212 let pos = match key_buf.iter().position(|c| *c == 0) {
213 Some(p) => p,
214 None => key_buf.len(),
215 };
216 if let Ok(connect_key) = String::from_utf8(key_buf[..pos].to_vec()) {
217 return Ok(connect_key);
218 } else {
219 return Err(Error::new(ErrorKind::Other, "unpack connect key failed"));
220 }
221 }
222