• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 //! auth
16 #![allow(missing_docs)]
17 
18 use crate::config::*;
19 
20 use hdc::config;
21 use hdc::config::TaskMessage;
22 use hdc::serializer::serialize::{Serialization, SessionHandShake};
23 use hdc::transfer;
24 
25 use std::io::{self, Error, ErrorKind};
26 use std::path::Path;
27 
28 use openssl::base64;
29 use openssl::rsa::{Padding, Rsa};
30 use ylong_runtime::net::SplitReadHalf;
31 
handshake_with_daemon( connect_key: String, session_id: u32, channel_id: u32, rd: &mut SplitReadHalf, ) -> io::Result<(String, String)>32 pub async fn handshake_with_daemon(
33     connect_key: String,
34     session_id: u32,
35     channel_id: u32,
36     rd: &mut SplitReadHalf,
37 ) -> io::Result<(String, String)> {
38     let rsa = load_or_create_prikey()?;
39 
40     let mut handshake = SessionHandShake {
41         banner: HANDSHAKE_MESSAGE.to_string(),
42         session_id,
43         connect_key,
44         version: config::get_version(),
45         ..Default::default()
46     };
47 
48     send_handshake_to_daemon(&handshake, channel_id).await;
49     loop {
50         let msg = transfer::tcp::unpack_task_message(rd).await?;
51         if msg.command == config::HdcCommand::KernelHandshake {
52             let mut recv = SessionHandShake::default();
53             recv.parse(msg.payload)?;
54 
55             hdc::info!("recv handshake: {:#?}", recv);
56             if recv.banner != config::HANDSHAKE_MESSAGE {
57                 return Err(Error::new(ErrorKind::Other, "Recv server-hello failed"));
58             }
59 
60             if recv.auth_type == config::AuthType::OK as u8 {
61                 return Ok((recv.buf, recv.version));
62             } else if recv.auth_type == config::AuthType::Publickey as u8 {
63                 // send public key
64                 handshake.auth_type = config::AuthType::Publickey as u8;
65                 handshake.buf = get_hostname()?;
66                 handshake.buf.push(char::from_u32(12).unwrap());
67                 let pubkey_pem = get_pubkey_pem(&rsa)?;
68                 handshake.buf.push_str(pubkey_pem.as_str());
69                 send_handshake_to_daemon(&handshake, channel_id).await;
70 
71                 // send signature
72                 handshake.auth_type = config::AuthType::Signature as u8;
73                 handshake.buf = get_signature_b64(&rsa, recv.buf)?;
74                 send_handshake_to_daemon(&handshake, channel_id).await;
75             } else if recv.auth_type == config::AuthType::Fail as u8 {
76                 return Err(Error::new(ErrorKind::Other, recv.buf.as_str()));
77             } else {
78                 return Err(Error::new(ErrorKind::Other, "unknown auth type"));
79             }
80         } else {
81             return Err(Error::new(ErrorKind::Other, "unknown command flag"));
82         }
83     }
84 }
85 
load_or_create_prikey() -> io::Result<Rsa<openssl::pkey::Private>>86 fn load_or_create_prikey() -> io::Result<Rsa<openssl::pkey::Private>> {
87     let file = Path::new(&get_home_dir())
88         .join(config::RSA_PRIKEY_PATH)
89         .join(config::RSA_PRIKEY_NAME);
90 
91     if let Ok(pem) = std::fs::read(&file) {
92         if let Ok(prikey) = Rsa::private_key_from_pem(&pem) {
93             hdc::info!("found existed private key");
94             return Ok(prikey);
95         } else {
96             hdc::error!("found broken private key, regenerating...");
97         }
98     }
99 
100     hdc::info!("create private key at {:#?}", file);
101     create_prikey()
102 }
103 
create_prikey() -> io::Result<Rsa<openssl::pkey::Private>>104 pub fn create_prikey() -> io::Result<Rsa<openssl::pkey::Private>> {
105     let prikey = Rsa::generate(config::RSA_BIT_NUM as u32).unwrap();
106     let pem = prikey.private_key_to_pem().unwrap();
107     let path = Path::new(&get_home_dir()).join(config::RSA_PRIKEY_PATH);
108     let file = path.join(config::RSA_PRIKEY_NAME);
109 
110     let _ = std::fs::create_dir_all(&path);
111     if let Err(_) = std::fs::write(file, pem) {
112         hdc::error!("write private key failed");
113         return Err(Error::new(ErrorKind::Other, "write private key failed"));
114     } else {
115         return Ok(prikey);
116     }
117 }
118 
get_pubkey_pem(rsa: &Rsa<openssl::pkey::Private>) -> io::Result<String>119 fn get_pubkey_pem(rsa: &Rsa<openssl::pkey::Private>) -> io::Result<String> {
120     if let Ok(pubkey) = rsa.public_key_to_pem() {
121         if let Ok(buf) = String::from_utf8(pubkey) {
122             Ok(buf)
123         } else {
124             Err(Error::new(
125                 ErrorKind::Other,
126                 "convert public key to pem string failed",
127             ))
128         }
129     } else {
130         Err(Error::new(
131             ErrorKind::Other,
132             "convert public key to pem string failed",
133         ))
134     }
135 }
136 
get_signature_b64(rsa: &Rsa<openssl::pkey::Private>, plain: String) -> io::Result<String>137 fn get_signature_b64(rsa: &Rsa<openssl::pkey::Private>, plain: String) -> io::Result<String> {
138     let mut enc = vec![0_u8; config::RSA_BIT_NUM];
139     match rsa.private_encrypt(plain.as_bytes(), &mut enc, Padding::PKCS1) {
140         Ok(size) => Ok(base64::encode_block(&enc[..size])),
141         Err(_) => Err(Error::new(ErrorKind::Other, "rsa private encrypt failed")),
142     }
143 }
144 
send_handshake_to_daemon(handshake: &SessionHandShake, channel_id: u32)145 async fn send_handshake_to_daemon(handshake: &SessionHandShake, channel_id: u32) {
146     transfer::put(
147         handshake.session_id,
148         TaskMessage {
149             channel_id,
150             command: config::HdcCommand::KernelHandshake,
151             payload: handshake.serialize(),
152         },
153     )
154     .await;
155 }
156 
get_home_dir() -> String157 fn get_home_dir() -> String {
158     use std::process::Command;
159 
160     let output = if cfg!(target_os = "windows") {
161         Command::new("cmd")
162             .args(["/c", "echo %USERPROFILE%"])
163             .output()
164     } else {
165         Command::new("sh").args(["-c", "echo ~"]).output()
166     };
167 
168     if let Ok(result) = output {
169         String::from_utf8(result.stdout).unwrap().trim().to_string()
170     } else {
171         hdc::warn!("get home dir failed, use current dir instead");
172         ".".to_string()
173     }
174 }
175 
/// Returns this machine's host name as printed by the `hostname` command.
///
/// Surrounding whitespace, such as the command's trailing newline (`\r\n` on
/// Windows), is trimmed because the caller appends further data directly after the
/// name. Output that is not valid UTF-8 is converted lossily instead of panicking.
///
/// # Errors
/// Returns an error if the shell used to run `hostname` cannot be spawned.
fn get_hostname() -> io::Result<String> {
    use std::process::Command;

    // `cmd` takes `/c`; POSIX systems need `sh -c`.
    let output = if cfg!(target_os = "windows") {
        Command::new("cmd").args(["/c", "hostname"]).output()
    } else {
        Command::new("sh").args(["-c", "hostname"]).output()
    };

    match output {
        Ok(result) => Ok(String::from_utf8_lossy(&result.stdout).trim().to_string()),
        Err(_) => Err(Error::new(ErrorKind::Other, "get hostname failed")),
    }
}
190