• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2024 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 mod manager;
15 
16 use std::collections::HashMap;
17 use std::net::Shutdown;
18 use std::os::fd::AsRawFd;
19 use std::time::Duration;
20 
21 pub(crate) use manager::{ClientManager, ClientManagerEntry};
22 use ylong_http_client::Headers;
23 use ylong_runtime::net::UnixDatagram;
24 use ylong_runtime::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
25 use ylong_runtime::sync::oneshot::{channel, Sender};
26 
27 use crate::config::Version;
28 use crate::error::ErrorCode;
29 use crate::task::notify::{NotifyData, SubscribeType};
30 use crate::utils::{runtime_spawn, Recv};
31 
/// Magic number written at the start of every frame sent to a client socket.
const REQUEST_MAGIC_NUM: u32 = 0x4343_4646;
/// Byte budget applied to header/extras data, and the size to which an HTTP
/// response frame is truncated before being sent.
const HEADERS_MAX_SIZE: u16 = 8192;
/// Byte offset of the `u16` size field inside a frame header:
/// magic (4 bytes) + message id (4 bytes) + message type (2 bytes).
const POSITION_OF_LENGTH: u32 = 4 + 4 + 2;
35 
36 #[derive(Debug)]
37 pub(crate) enum ClientEvent {
38     OpenChannel(u64, Sender<Result<i32, ErrorCode>>),
39     Subscribe(u32, u64, u64, u64, Sender<ErrorCode>),
40     Unsubscribe(u32, Sender<ErrorCode>),
41     TaskFinished(u32),
42     Terminate(u64, Sender<ErrorCode>),
43     SendResponse(u32, String, u32, String, Headers),
44     SendNotifyData(SubscribeType, NotifyData),
45     Shutdown,
46 }
47 
/// Kind of frame sent to a client; serialized as a little-endian `u16`.
pub(crate) enum MessageType {
    /// Frame carrying an HTTP response (status line and headers).
    HttpResponse = 0,
    /// Frame carrying task notification data.
    NotifyData = 1,
}
52 
53 impl ClientManagerEntry {
open_channel(&self, pid: u64) -> Result<i32, ErrorCode>54     pub(crate) fn open_channel(&self, pid: u64) -> Result<i32, ErrorCode> {
55         let (tx, rx) = channel::<Result<i32, ErrorCode>>();
56         let event = ClientEvent::OpenChannel(pid, tx);
57         if !self.send_event(event) {
58             return Err(ErrorCode::Other);
59         }
60         let rx = Recv::new(rx);
61         match rx.get() {
62             Some(ret) => ret,
63             None => {
64                 error!("open channel fail, recv none");
65                 Err(ErrorCode::Other)
66             }
67         }
68     }
69 
subscribe(&self, tid: u32, pid: u64, uid: u64, token_id: u64) -> ErrorCode70     pub(crate) fn subscribe(&self, tid: u32, pid: u64, uid: u64, token_id: u64) -> ErrorCode {
71         let (tx, rx) = channel::<ErrorCode>();
72         let event = ClientEvent::Subscribe(tid, pid, uid, token_id, tx);
73         if !self.send_event(event) {
74             return ErrorCode::Other;
75         }
76         let rx = Recv::new(rx);
77         match rx.get() {
78             Some(ret) => ret,
79             None => {
80                 error!("subscribe fail, recv none");
81                 ErrorCode::Other
82             }
83         }
84     }
85 
unsubscribe(&self, tid: u32) -> ErrorCode86     pub(crate) fn unsubscribe(&self, tid: u32) -> ErrorCode {
87         let (tx, rx) = channel::<ErrorCode>();
88         let event = ClientEvent::Unsubscribe(tid, tx);
89         if !self.send_event(event) {
90             return ErrorCode::Other;
91         }
92         let rx = Recv::new(rx);
93         match rx.get() {
94             Some(ret) => ret,
95             None => {
96                 error!("unsubscribe failed");
97                 ErrorCode::Other
98             }
99         }
100     }
101 
notify_task_finished(&self, tid: u32)102     pub(crate) fn notify_task_finished(&self, tid: u32) {
103         let event = ClientEvent::TaskFinished(tid);
104         self.send_event(event);
105     }
106 
notify_process_terminate(&self, pid: u64) -> ErrorCode107     pub(crate) fn notify_process_terminate(&self, pid: u64) -> ErrorCode {
108         let (tx, rx) = channel::<ErrorCode>();
109         let event = ClientEvent::Terminate(pid, tx);
110         if !self.send_event(event) {
111             return ErrorCode::Other;
112         }
113         let rx = Recv::new(rx);
114         match rx.get() {
115             Some(ret) => ret,
116             None => {
117                 error!("notify_process_terminate failed");
118                 ErrorCode::Other
119             }
120         }
121     }
122 
send_response( &self, tid: u32, version: String, status_code: u32, reason: String, headers: Headers, )123     pub(crate) fn send_response(
124         &self,
125         tid: u32,
126         version: String,
127         status_code: u32,
128         reason: String,
129         headers: Headers,
130     ) {
131         let event = ClientEvent::SendResponse(tid, version, status_code, reason, headers);
132         let _ = self.send_event(event);
133     }
134 
send_notify_data(&self, subscribe_type: SubscribeType, notify_data: NotifyData)135     pub(crate) fn send_notify_data(&self, subscribe_type: SubscribeType, notify_data: NotifyData) {
136         let event = ClientEvent::SendNotifyData(subscribe_type, notify_data);
137         let _ = self.send_event(event);
138     }
139 }
140 
141 // uid and token_id will be used later
142 pub(crate) struct Client {
143     pub(crate) pid: u64,
144     pub(crate) message_id: u32,
145     pub(crate) server_sock_fd: UnixDatagram,
146     pub(crate) client_sock_fd: UnixDatagram,
147     rx: UnboundedReceiver<ClientEvent>,
148 }
149 
150 impl Client {
constructor(pid: u64) -> Option<(UnboundedSender<ClientEvent>, i32)>151     pub(crate) fn constructor(pid: u64) -> Option<(UnboundedSender<ClientEvent>, i32)> {
152         let (tx, rx) = unbounded_channel();
153         let (server_sock_fd, client_sock_fd) = match UnixDatagram::pair() {
154             Ok((server_sock_fd, client_sock_fd)) => (server_sock_fd, client_sock_fd),
155             Err(err) => {
156                 error!("can't create a pair of sockets, {:?}", err);
157                 return None;
158             }
159         };
160         let client = Client {
161             pid,
162             message_id: 1,
163             server_sock_fd,
164             client_sock_fd,
165             rx,
166         };
167         let fd = client.client_sock_fd.as_raw_fd();
168         runtime_spawn(client.run());
169         Some((tx, fd))
170     }
171 
run(mut self)172     async fn run(mut self) {
173         loop {
174             // for one task, only send last progress message
175             let mut progress_index = HashMap::new();
176             let mut temp_notify_data: Vec<(SubscribeType, NotifyData)> = Vec::new();
177             let mut len = self.rx.len();
178             if len == 0 {
179                 len = 1;
180             }
181             for index in 0..len {
182                 let recv = match self.rx.recv().await {
183                     Ok(message) => message,
184                     Err(e) => {
185                         error!("ClientManager recv error {:?}", e);
186                         continue;
187                     }
188                 };
189                 match recv {
190                     ClientEvent::Shutdown => {
191                         let _ = self.client_sock_fd.shutdown(Shutdown::Both);
192                         let _ = self.server_sock_fd.shutdown(Shutdown::Both);
193                         self.rx.close();
194                         info!("client terminate, pid {}", self.pid);
195                         return;
196                     }
197                     ClientEvent::SendResponse(tid, version, status_code, reason, headers) => {
198                         self.handle_send_response(tid, version, status_code, reason, headers)
199                             .await;
200                     }
201                     ClientEvent::SendNotifyData(subscribe_type, notify_data) => {
202                         if subscribe_type == SubscribeType::Progress {
203                             progress_index.insert(notify_data.task_id, index);
204                         }
205                         temp_notify_data.push((subscribe_type, notify_data));
206                     }
207                     _ => {}
208                 }
209             }
210             for (index, (subscribe_type, notify_data)) in temp_notify_data.into_iter().enumerate() {
211                 if subscribe_type != SubscribeType::Progress
212                     || progress_index.get(&notify_data.task_id) == Some(&index)
213                 {
214                     self.handle_send_notify_data(subscribe_type, notify_data)
215                         .await;
216                 }
217             }
218             debug!("Client handle message done");
219         }
220     }
221 
handle_send_response( &mut self, tid: u32, version: String, status_code: u32, reason: String, headers: Headers, )222     async fn handle_send_response(
223         &mut self,
224         tid: u32,
225         version: String,
226         status_code: u32,
227         reason: String,
228         headers: Headers,
229     ) {
230         let mut response = Vec::<u8>::new();
231 
232         response.extend_from_slice(&REQUEST_MAGIC_NUM.to_le_bytes());
233 
234         response.extend_from_slice(&self.message_id.to_le_bytes());
235         self.message_id += 1;
236 
237         let message_type = MessageType::HttpResponse as u16;
238         response.extend_from_slice(&message_type.to_le_bytes());
239 
240         let message_body_size: u16 = 0;
241         response.extend_from_slice(&message_body_size.to_le_bytes());
242 
243         response.extend_from_slice(&tid.to_le_bytes());
244 
245         response.extend_from_slice(&version.into_bytes());
246         response.push(b'\0');
247 
248         response.extend_from_slice(&status_code.to_le_bytes());
249 
250         response.extend_from_slice(&reason.into_bytes());
251         response.push(b'\0');
252 
253         // The maximum length of the headers in uds should not exceed 8192
254         let mut buf_size = 0;
255         for (k, v) in headers {
256             buf_size += k.as_bytes().len() + v.iter().map(|f| f.len()).sum::<usize>();
257             if buf_size > HEADERS_MAX_SIZE as usize {
258                 break;
259             }
260 
261             response.extend_from_slice(k.as_bytes());
262             response.push(b':');
263             for (i, sub_value) in v.iter().enumerate() {
264                 if i != 0 {
265                     response.push(b',');
266                 }
267                 response.extend_from_slice(sub_value);
268             }
269             response.push(b'\n');
270         }
271 
272         let mut size = response.len() as u16;
273         if size > HEADERS_MAX_SIZE {
274             info!("send response too long");
275             response.truncate(HEADERS_MAX_SIZE as usize);
276             size = HEADERS_MAX_SIZE;
277         }
278         debug!("send response size, {:?}", size);
279         let size = size.to_le_bytes();
280         response[POSITION_OF_LENGTH as usize] = size[0];
281         response[(POSITION_OF_LENGTH + 1) as usize] = size[1];
282 
283         self.send_message(response).await;
284     }
285 
handle_send_notify_data( &mut self, subscribe_type: SubscribeType, notify_data: NotifyData, )286     async fn handle_send_notify_data(
287         &mut self,
288         subscribe_type: SubscribeType,
289         notify_data: NotifyData,
290     ) {
291         let mut message = Vec::<u8>::new();
292 
293         message.extend_from_slice(&REQUEST_MAGIC_NUM.to_le_bytes());
294 
295         message.extend_from_slice(&self.message_id.to_le_bytes());
296         self.message_id += 1;
297 
298         let message_type = MessageType::NotifyData as u16;
299         message.extend_from_slice(&message_type.to_le_bytes());
300 
301         let message_body_size: u16 = 0;
302         message.extend_from_slice(&message_body_size.to_le_bytes());
303 
304         message.extend_from_slice(&(subscribe_type as u32).to_le_bytes());
305 
306         message.extend_from_slice(&notify_data.task_id.to_le_bytes());
307 
308         message.extend_from_slice(&(notify_data.progress.common_data.state as u32).to_le_bytes());
309 
310         let index = notify_data.progress.common_data.index;
311         message.extend_from_slice(&(index as u32).to_le_bytes());
312         // for one task, only send last progress message
313         message.extend_from_slice(&(notify_data.progress.processed[index] as u64).to_le_bytes());
314 
315         message.extend_from_slice(
316             &(notify_data.progress.common_data.total_processed as u64).to_le_bytes(),
317         );
318 
319         message.extend_from_slice(&(notify_data.progress.sizes.len() as u32).to_le_bytes());
320         for size in notify_data.progress.sizes {
321             message.extend_from_slice(&size.to_le_bytes());
322         }
323 
324         // The maximum length of the headers in uds should not exceed 8192
325         let mut buf_size = 0;
326         let index = notify_data
327             .progress
328             .extras
329             .iter()
330             .take_while(|x| {
331                 buf_size += x.0.len() + x.1.len();
332                 buf_size < HEADERS_MAX_SIZE as usize
333             })
334             .count();
335 
336         message.extend_from_slice(&(index as u32).to_le_bytes());
337         for (key, value) in notify_data.progress.extras.iter().take(index) {
338             message.extend_from_slice(key.as_bytes());
339             message.push(b'\0');
340             message.extend_from_slice(value.as_bytes());
341             message.push(b'\0');
342         }
343 
344         message.extend_from_slice(&(notify_data.action.repr as u32).to_le_bytes());
345 
346         message.extend_from_slice(&(notify_data.version as u32).to_le_bytes());
347 
348         // Param taskstates used for UploadFile when complete or fail
349         message.extend_from_slice(&(notify_data.each_file_status.len() as u32).to_le_bytes());
350         for status in notify_data.each_file_status {
351             if notify_data.version == Version::API9 {
352                 message.extend_from_slice(&status.path.into_bytes());
353             }
354             message.push(b'\0');
355             message.extend_from_slice(&(status.reason.repr as u32).to_le_bytes());
356             message.extend_from_slice(&status.message.into_bytes());
357             message.push(b'\0');
358         }
359 
360         let size = message.len() as u16;
361         if subscribe_type == SubscribeType::Progress {
362             debug!(
363                 "send tid {} {:?} size {}",
364                 notify_data.task_id, subscribe_type, size
365             );
366         } else {
367             info!(
368                 "send tid {} {:?} size {}",
369                 notify_data.task_id, subscribe_type, size
370             );
371         }
372 
373         let size = size.to_le_bytes();
374         message[POSITION_OF_LENGTH as usize] = size[0];
375         message[(POSITION_OF_LENGTH + 1) as usize] = size[1];
376 
377         self.send_message(message).await;
378     }
379 
send_message(&mut self, message: Vec<u8>)380     async fn send_message(&mut self, message: Vec<u8>) {
381         let ret = self.server_sock_fd.send(&message).await;
382         match ret {
383             Ok(size) => {
384                 debug!("send message ok, pid: {}, size: {}", self.pid, size);
385                 let mut buf: [u8; 4] = [0; 4];
386 
387                 match ylong_runtime::time::timeout(
388                     Duration::from_millis(500),
389                     self.server_sock_fd.recv(&mut buf),
390                 )
391                 .await
392                 {
393                     Ok(ret) => match ret {
394                         Ok(len) => {
395                             debug!("message recv len {:}", len);
396                         }
397                         Err(e) => {
398                             error!("message recv error: {:?}", e);
399                         }
400                     },
401                     Err(e) => {
402                         error!("message recv {}", e);
403                         return;
404                     }
405                 };
406 
407                 let len: u32 = u32::from_le_bytes(buf);
408                 if len != message.len() as u32 {
409                     error!("message len bad, send {:?}, recv {:?}", message.len(), len);
410                 } else {
411                     debug!("notify done, pid: {}", self.pid);
412                 }
413             }
414             Err(err) => {
415                 error!("message send error: {:?}", err);
416             }
417         }
418     }
419 }
420