• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 use crate::parser;
16 use crate::task;
17 
18 use hdc::config;
19 use hdc::config::HdcCommand;
20 use hdc::config::TaskMessage;
21 use hdc::transfer;
22 use hdc::utils;
23 use std::process;
24 use std::str::FromStr;
25 use std::time::Duration;
26 
27 use std::io::{self, Error, ErrorKind};
28 
29 use ylong_runtime::net::{SplitReadHalf, SplitWriteHalf, TcpListener, TcpStream};
30 
run_server_mode(addr_str: String) -> io::Result<()>31 pub async fn run_server_mode(addr_str: String) -> io::Result<()> {
32     let saddr = addr_str;
33     let listener = TcpListener::bind(saddr.clone()).await?;
34     hdc::info!("server binds on {saddr}");
35 
36     loop {
37         let (stream, addr) = listener.accept().await?;
38         hdc::info!("accepted client {addr}");
39         ylong_runtime::spawn(handle_client(stream));
40     }
41 }
42 
get_process_pids() -> Vec<u32>43 pub async fn get_process_pids() -> Vec<u32> {
44     let mut pids: Vec<u32> = Vec::new();
45     if cfg!(target_os = "windows") {
46         let output = utils::execute_cmd(format!("tasklist | findstr hdc"));
47         let output_str = String::from_utf8_lossy(&output);
48         let mut get_pid = false;
49         for token in output_str.split_whitespace() {
50             if get_pid {
51                 pids.push(u32::from_str(token).unwrap());
52                 get_pid = false;
53             }
54             if token.contains("exe") {
55                 get_pid = true;
56             }
57         }
58     } else {
59         let output = utils::execute_cmd(format!("ps -ef | grep hdc | awk '{{print $2}}'"));
60         let output_str = String::from_utf8_lossy(&output);
61         for pid in output_str.split_whitespace() {
62             pids.push(u32::from_str(pid).unwrap());
63         }
64     }
65     pids
66 }
67 
68 // 跨平台命令
check_allow_fork() -> bool69 pub async fn check_allow_fork() -> bool {
70     let pids = get_process_pids().await;
71     for pid in pids {
72         if pid != process::id() {
73             false;
74         }
75     }
76     true
77 }
78 
79 // 跨平台命令
server_fork(addr_str: String)80 pub async fn server_fork(addr_str: String) {
81     let current_exe = std::env::current_exe().unwrap().display().to_string();
82     let result = process::Command::new(&current_exe)
83         .args(["-b", "-m", "-s", addr_str.as_str()])
84         .spawn();
85     match result {
86         Ok(_) => ylong_runtime::time::sleep(Duration::from_millis(1000)).await,
87         Err(_) => hdc::info!("server fork failed"),
88     }
89 }
90 
server_kill()91 pub async fn server_kill() {
92     // TODO: check mac & win
93     let pids = get_process_pids().await;
94 
95     for pid in pids {
96         if pid != process::id() {
97             if cfg!(target_os = "windows") {
98                 utils::execute_cmd(format!("taskkill /pid {} /f", pid));
99             } else {
100                 utils::execute_cmd(format!("kill -9 {}", pid));
101             }
102         }
103     }
104 }
105 
/// Per-connection mode of a client channel.
///
/// `handle_client` starts every channel in `None` and switches to
/// `InteractiveShell` after a `UnityExecute` command. From then on, incoming
/// data is forwarded as shell input.
// `File` and `App` are not constructed yet; silence only the dead-code lint.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChannelState {
    InteractiveShell,
    File,
    App,
    None,
}
114 
handle_client(stream: TcpStream) -> io::Result<()>115 async fn handle_client(stream: TcpStream) -> io::Result<()> {
116     let (mut rd, wr) = stream.into_split();
117     let (connect_key, channel_id) = handshake_with_client(&mut rd, wr).await?;
118     let mut channel_state = ChannelState::None;
119 
120     loop {
121         let recv_opt = transfer::tcp::recv_channel_message(&mut rd).await;
122         if recv_opt.is_err() {
123             let session_id = match task::ConnectMap::get_session_id(connect_key.clone()).await {
124                 Some(seid) => seid,
125                 None => return Ok(()),
126             };
127             let message = TaskMessage {
128                 channel_id,
129                 command: HdcCommand::KernelChannelClose,
130                 payload: vec![0],
131             };
132             transfer::put(session_id, message).await;
133             return Ok(());
134         }
135         let recv = recv_opt.unwrap();
136         hdc::debug!(
137             "recv hex: {}",
138             recv.iter()
139                 .map(|c| format!("{c:02x}"))
140                 .collect::<Vec<_>>()
141                 .join(" ")
142         );
143 
144         let recv_str = String::from_utf8(recv.clone()).unwrap();
145         hdc::debug!("recv str: {}", recv_str.clone());
146         let mut parsed = parser::split_opt_and_cmd(
147             String::from_utf8(recv)
148                 .unwrap()
149                 .split(" ")
150                 .map(|s| s.trim_end_matches('\0').to_string())
151                 .collect::<Vec<_>>(),
152         );
153 
154         if channel_state == ChannelState::InteractiveShell {
155             parsed.command = Some(HdcCommand::ShellData);
156             parsed.parameters = vec![recv_str];
157         }
158 
159         if parsed.command == Some(HdcCommand::UnityExecute) {
160             channel_state = ChannelState::InteractiveShell;
161             if parsed.parameters.len() == 1 {
162                 parsed.command = Some(HdcCommand::ShellInit);
163             }
164         }
165 
166         hdc::debug!("parsed cmd: {:#?}", parsed);
167 
168         if let Some(cmd) = parsed.command {
169             if let Err(e) = task::channel_task_dispatch(task::TaskInfo {
170                 command: cmd,
171                 connect_key: connect_key.clone(),
172                 channel_id,
173                 params: parsed.parameters,
174             })
175             .await
176             {
177                 hdc::error!("{e}");
178             }
179         } else {
180             return Err(Error::new(ErrorKind::Other, "command not found"));
181         }
182     }
183 }
184 
handshake_with_client( rd: &mut SplitReadHalf, wr: SplitWriteHalf, ) -> io::Result<(String, u32)>185 async fn handshake_with_client(
186     rd: &mut SplitReadHalf,
187     wr: SplitWriteHalf,
188 ) -> io::Result<(String, u32)> {
189     let channel_id = utils::get_pseudo_random_u32();
190     transfer::TcpMap::start(channel_id, wr).await;
191 
192     let buf = [
193         config::HANDSHAKE_MESSAGE.as_bytes(),
194         vec![0_u8; config::BANNER_SIZE - config::HANDSHAKE_MESSAGE.len()].as_slice(),
195         u32::to_le_bytes(channel_id).as_slice(),
196         vec![0_u8; config::KEY_MAX_SIZE - std::mem::size_of::<u32>()].as_slice(),
197     ]
198     .concat();
199 
200     transfer::send_channel_data(channel_id, buf).await;
201     let recv = transfer::tcp::recv_channel_message(rd).await.unwrap();
202     let connect_key = unpack_channel_handshake(recv)?;
203     Ok((connect_key, channel_id))
204 }
205 
unpack_channel_handshake(recv: Vec<u8>) -> io::Result<String>206 fn unpack_channel_handshake(recv: Vec<u8>) -> io::Result<String> {
207     let msg = std::str::from_utf8(&recv[..config::HANDSHAKE_MESSAGE.len()]).unwrap();
208     if msg != config::HANDSHAKE_MESSAGE {
209         return Err(Error::new(ErrorKind::Other, "Recv server-hello failed"));
210     }
211     let key_buf = &recv[config::BANNER_SIZE..];
212     let pos = match key_buf.iter().position(|c| *c == 0) {
213         Some(p) => p,
214         None => key_buf.len(),
215     };
216     if let Ok(connect_key) = String::from_utf8(key_buf[..pos].to_vec()) {
217         return Ok(connect_key);
218     } else {
219         return Err(Error::new(ErrorKind::Other, "unpack connect key failed"));
220     }
221 }
222