#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Shell helpers: run commands on the host or on a device (adb / hdc).

Every command produces a ``ShellResult`` carrying the return code, the
captured stdout/stderr and, optionally, the wall-clock time and the peak
resident set size parsed from the output of ``\\time -v``.
"""

import re
import os
import logging
import signal
from typing import Union, Optional
from pathlib import Path
from subprocess import Popen, PIPE
from threading import Thread, Timer
from dataclasses import dataclass
from tempfile import mktemp
from vmb.helpers import Singleton

log = logging.getLogger('vmb')
# Wall-clock line of `time -v`. Groups, in order:
#   1, 2 - optional leading colon-terminated fields (hours and/or minutes)
#   3    - whole seconds
#   4    - optional fraction of a second (after a literal dot)
# so 'm:ss.cc', 'h:mm:ss' and plain 's.frac' values are all recognised.
tm_re = re.compile(
    r"(?:Elapsed.*\(h:mm:ss or m:ss\)|Real time)"
    r"[^:]*:\s*(?:(\d+):)?(?:(\d+):)?(\d*)(?:\.(\d*))?")
# Peak memory line of `time -v`; group 1 is the reported number.
rss_re = re.compile(r"(?:Maximum resident set size|Max RSS)[^:]*:\s*(\d*)")


@dataclass
class ShellResult:
    """Outcome of a single shell command."""

    # Default initial result is 'failure'
    ret: int = -13
    # Captured standard output / standard error (decoded text)
    out: str = ''
    err: str = ''
    # Elapsed time: seconds parsed by set_time() multiplied by 1e9
    tm: float = 0.0
    # Number parsed from the 'Maximum resident set size' / 'Max RSS' line
    # NOTE(review): units are whatever `time -v` reports - confirm
    rss: int = 0

    @staticmethod
    def _parse_elapsed(text: str) -> float:
        """Return elapsed seconds found in *text*, or 0.0 if there are none.

        Handles 'h:mm:ss', 'm:ss.cc' and 's.frac' values; the result is
        rounded to 5 decimal places.
        """
        matched = tm_re.search(text)
        if not matched:
            return 0.0
        first, second, sec, frac = matched.groups()
        # Missing seconds or fraction are treated as zero
        seconds = float(f"{sec or '0'}.{frac or '0'}")
        if second is not None:
            # Two colon-separated prefixes: hours and minutes
            seconds += int(first or 0) * 3600 + int(second) * 60
        elif first is not None:
            # Single prefix: minutes
            seconds += int(first) * 60
        return round(seconds, 5)

    def grep(self, regex: str) -> str:
        """Search stdout, then stderr, line by line for *regex*.

        Returns the first capture group of the first matching line if the
        pattern defines groups, the whole match otherwise, and an empty
        string when no line matches.
        """
        pattern = re.compile(regex)
        for line in self.out.split("\n") + self.err.split("\n"):
            found = pattern.search(line)
            if found:
                return found.group(1) if found.groups() else found.group()
        return ''

    def set_ret_val(self) -> None:
        """Set ``ret`` from the ``__RET_VAL__=<n>`` marker in stdout.

        ``ret`` is left at the failure value (-13) if there is no output
        or the marker (with at least one digit) is not present.
        """
        self.ret = -13
        if not self.out:
            log.error("No shell output")
            return
        # \d+ (not \d*): a marker without digits must not reach int()
        matches = re.search(r"__RET_VAL__=(\d+)", self.out)
        if not matches:
            log.error("No shell ret val; out:")
            return
        self.ret = int(matches.group(1))

    def set_time(self) -> None:
        """Fill ``tm`` and ``rss`` from `time -v` output held in ``err``.

        Does nothing if ``err`` is empty. Values that cannot be found are
        set to zero.
        """
        # expecting output of '\time -v' to stderr
        if not self.err:
            return
        self.tm = self._parse_elapsed(self.err) * 1e9
        rss_val = rss_re.search(self.err)
        # group(1) may be an empty string because of \d*
        self.rss = int(rss_val.group(1) or 0) if rss_val else 0

    def log_output(self) -> None:
        """Log the result: errors for a non-zero ``ret``, debug otherwise.

        On failure stdout is logged in full and only the first three lines
        of stderr are logged.
        """
        if self.ret != 0:
            if self.out:
                log.error(self.out)
            err = self.err.split("\n")[:3] if self.err else []
            for line in err:
                log.error(line.strip())
        else:
            if self.out:
                log.debug(self.out)


class ShellBase(metaclass=Singleton):
    """Common interface of all shells.

    Subclasses are expected to override ``run``, ``run_async``, ``push``
    and ``pull``; the versions here raise ``NotImplementedError``.
    """

    def __init__(self, timeout: Optional[float] = None) -> None:
        # Default timeout for commands, None means 'no limit'
        self._timeout = timeout
        # Command prefix set by set_affinity(); empty until then
        self.taskset = ''

    @staticmethod
    def timed_cmd(cmd: str) -> str:
        """Wrap *cmd* so that it is run under ``\\time -v``."""
        return f"\\time -v env {cmd}"

    def run(self,
            cmd: str,
            measure_time: bool = False,
            timeout: Optional[float] = None,
            cwd: str = '') -> ShellResult:
        """Run *cmd* and return its result (to be implemented)."""
        raise NotImplementedError

    def run_async(self, cmd: str) -> None:
        """Start *cmd* without waiting for it (to be implemented)."""
        raise NotImplementedError

    def push(self,
             src: Union[str, Path],
             dst: Union[str, Path]) -> ShellResult:
        """Copy *src* to *dst* on the target (to be implemented)."""
        raise NotImplementedError

    def pull(self,
             src: Union[str, Path],
             dst: Union[str, Path]) -> ShellResult:
        """Copy *src* from the target to *dst* (to be implemented)."""
        raise NotImplementedError

    def get_filesize(self, filepath: Union[str, Path]) -> int:
        """Return size of a local file in bytes, 0 if it does not exist."""
        if os.path.exists(str(filepath)):
            return os.stat(str(filepath)).st_size
        return 0

    def grep_output(self, cmd: str, regex: str) -> str:
        """Run *cmd* and return ``ShellResult.grep(regex)`` of its result."""
        r = self.run(cmd=cmd)
        return r.grep(regex)

    def set_affinity(self, arg: str) -> None:
        """Set affinity mask for processes.

        Effective only on devices, so hardcoding path
        """
        self.taskset = f'/system/bin/taskset -a {arg}'


class ShellUnix(ShellBase):
    """Shell running commands on the local host via ``subprocess``."""

    def __init__(self, timeout: Optional[float] = None) -> None:
        super().__init__(timeout=timeout)

    def run(self,
            cmd: str,
            measure_time: bool = False,
            timeout: Optional[float] = None,
            cwd: str = '') -> ShellResult:
        """Run *cmd* through the system shell and return its result.

        With ``measure_time`` the command is wrapped by ``timed_cmd`` and
        time/RSS are parsed from its stderr.
        """
        return self.__run(
            cmd, measure_time=measure_time, timeout=timeout, cwd=cwd)

    def push(self,
             src: Union[str, Path],
             dst: Union[str, Path]) -> ShellResult:
        """Not supported for the host shell."""
        raise NotImplementedError

    def pull(self,
             src: Union[str, Path],
             dst: Union[str, Path]) -> ShellResult:
        """Not supported for the host shell."""
        raise NotImplementedError

    def run_async(self, cmd: str) -> None:
        """Start *cmd* from a daemon thread; output is not collected."""
        def run_shell():
            # pylint: disable-next=all
            return Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)  # NOQA

        log.debug('Async cmd: %s', cmd)
        async_thread = Thread(target=run_shell)
        async_thread.daemon = True
        async_thread.start()

    def __run(self,
              cmd: str,
              measure_time: bool = False,
              timeout: Optional[float] = None,
              cwd: str = '') -> ShellResult:
        """Execute *cmd*, optionally timed, and log its output."""
        if measure_time:
            cmd = self.timed_cmd(cmd)
        result = self.__exec_process(cmd, cwd=cwd, timeout=timeout)
        if measure_time:
            result.set_time()
        result.log_output()
        return result

    def __exec_process(self, cmd: str, cwd: str = '',
                       timeout: Optional[float] = None) -> ShellResult:
        """Run *cmd* in its own session and capture stdout/stderr.

        The effective timeout is the larger of *timeout* and the shell
        default when both are set, otherwise whichever one is set. When it
        expires the whole process group is killed with SIGKILL.
        ``communicate`` is given the same timeout, so
        ``subprocess.TimeoutExpired`` may propagate to the caller.
        """
        def kill_group(victim: Popen) -> None:
            # The process may finish just before the timer fires
            try:
                os.killpg(os.getpgid(victim.pid), signal.SIGKILL)
            except ProcessLookupError:
                log.debug('Process %s already finished', victim.pid)

        result = ShellResult()
        # Note: self._timeout=None so default behavior is to wait forever
        to = timeout if timeout else self._timeout
        if timeout is not None and self._timeout is not None:
            to = max(timeout, self._timeout)
        log.debug(cmd)
        # NOTE(review): trace() is not a stdlib Logger method; presumably
        # it is added to the 'vmb' logger elsewhere - confirm
        log.trace('CWD="%s" Timeout=[%s]', cwd, to)
        # start_new_session=True makes the child call setsid() (same as
        # the former preexec_fn=os.setsid) but is safe to use with threads
        # pylint: disable-next=all
        with Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE,  # NOQA
                   cwd=(cwd if cwd else None),
                   start_new_session=True) as proc:
            timer = None
            if to is not None:
                timer = Timer(to, kill_group, [proc])
                timer.start()
            out, err = proc.communicate(timeout=to)
            if timer is not None:
                timer.cancel()
            ret_code = proc.poll()
            if ret_code is not None:
                result.ret = ret_code
            result.out = out.decode('utf-8')
            result.err = err.decode('utf-8')
        return result


class ShellDevice(ShellBase):
    """Base for shells which run commands on a device via a host tool.

    Commands are sent as ``<dev_sh> shell '<cmd>'`` through a ``ShellUnix``
    instance. Subclasses provide ``push`` and ``pull``.
    """

    def __init__(self,
                 dev_sh: str,
                 timeout: Optional[float] = None,
                 tmp_dir: str = '/data/local/tmp/vmb',) -> None:
        super().__init__(timeout=timeout)
        self._sh = ShellUnix()
        # Host command (with options) used to reach the device
        self._devsh = dev_sh
        self.tmp_dir = tmp_dir
        # File on device which receives stderr of timed commands
        self.stderr_out = os.path.join(tmp_dir, 'vmb-stderr.out')

    def run(self, cmd: str,
            measure_time: bool = False,
            timeout: Optional[float] = None,
            cwd: str = '') -> ShellResult:
        """Run *cmd* on the device and return its result.

        The return code is taken from the ``__RET_VAL__=$?`` marker echoed
        after the command. With ``measure_time`` the command is run under
        ``\\time -v`` (prefixed by the affinity setting, if any), its
        stderr is redirected to a file on the device, pulled to the host
        and parsed for time and RSS. Without it ``err`` is always empty.
        """
        redir = ''
        if measure_time:
            cmd = f"\\time -v {self.taskset} env {cmd}"
            redir = f' 2>{self.stderr_out}'
        cd_prefix = f'cd {cwd}; ' if cwd else ''
        device_cmd = (f"{self._devsh} shell "
                      f"'{cd_prefix}({cmd}){redir}; echo __RET_VAL__=$?'")
        res = self._sh.run(device_cmd, timeout=timeout, measure_time=False)
        res.set_ret_val()
        if not measure_time:
            res.err = ''
            return res
        # The name must not exist yet: presence of the file after pull()
        # is the success check below.
        # NOTE(review): mktemp() is deprecated and racy; a path inside a
        # private temporary directory would be a safer choice
        stderr_host = mktemp(prefix='vmb-')
        self.pull(self.stderr_out, stderr_host)
        self._sh.run(f"{self._devsh} shell 'rm -f {self.stderr_out}'")
        if not Path(stderr_host).exists():
            res.err = 'Pull from device failed'
            return res
        with open(stderr_host, 'r', encoding="utf-8") as f:
            res.err = f.read()
        self._sh.run(f'rm -f {stderr_host}')
        res.set_time()
        return res

    def run_async(self, cmd: str) -> None:
        """Start *cmd* on the device without waiting for completion."""
        self._sh.run_async(f"{self._devsh} shell '{cmd}'")

    def get_filesize(self, filepath: Union[str, Path]) -> int:
        """Return size reported by ``stat -c '%s'`` on the device.

        Returns 0 if the command fails or prints nothing.
        """
        res = self.run(f"stat -c '%s' {filepath}")
        if res.ret == 0 and res.out:
            # The first line is the size; the ret val marker follows it
            return int(res.out.split("\n")[0])
        return 0

    def push(self,
             src: Union[str, Path],
             dst: Union[str, Path]) -> ShellResult:
        """Copy host file to the device (to be implemented)."""
        raise NotImplementedError

    def pull(self,
             src: Union[str, Path],
             dst: Union[str, Path]) -> ShellResult:
        """Copy device file to the host (to be implemented)."""
        raise NotImplementedError

    def mk_tmp_dir(self) -> None:
        """Create ``tmp_dir`` on the device.

        Raises:
            RuntimeError: if ``mkdir -p`` does not return 0.
        """
        res = self.run(f'mkdir -p {self.tmp_dir}')
        if res.ret != 0:
            raise RuntimeError('Device connection failed!\n'
                               f'{res.out}\n{res.err}')


class ShellAdb(ShellDevice):
    """Device shell using the tool named by ``binname``.

    The executable can be overridden by the environment variable whose
    name is ``binname`` in upper case; ``-s <serial>`` selects a device.
    """

    binname = f"a{'d'}b"

    def __init__(self,
                 dev_serial: str = '',
                 timeout: Optional[float] = None,
                 tmp_dir: str = '/data/local/tmp/vmb') -> None:
        super().__init__(
            f"{os.environ.get(self.binname.upper(), self.binname)}",
            timeout=timeout,
            tmp_dir=tmp_dir)
        if dev_serial:
            self._devsh = f'{self._devsh} -s {dev_serial}'
        self.mk_tmp_dir()

    def push(self,
             src: Union[str, Path],
             dst: Union[str, Path]) -> ShellResult:
        """Copy host file *src* to *dst* on the device."""
        return self._sh.run(f'{self._devsh} push {src} {dst}',
                            measure_time=False)

    def pull(self,
             src: Union[str, Path],
             dst: Union[str, Path]) -> ShellResult:
        """Copy device file *src* to *dst* on the host."""
        return self._sh.run(f'{self._devsh} pull {src} {dst}',
                            measure_time=False)


class ShellHdc(ShellDevice):
    """Device shell using ``hdc``.

    The executable can be overridden by the ``HDC`` environment variable;
    ``-t <serial>`` selects a device.
    """

    def __init__(self,
                 dev_serial: str = '',
                 timeout: Optional[float] = None,
                 tmp_dir: str = '/data/local/tmp/vmb') -> None:
        # -l0 because of HDC mutex file permission messages
        # -p (undocumented) due to poor hdc performance
        super().__init__(f"{os.environ.get('HDC', 'hdc')} -p -l0",
                         timeout=timeout,
                         tmp_dir=tmp_dir)
        if dev_serial:
            self._devsh = f'{self._devsh} -t {dev_serial}'
        self.mk_tmp_dir()

    def push(self,
             src: Union[str, Path],
             dst: Union[str, Path]) -> ShellResult:
        """Copy host file *src* to *dst* on the device."""
        return self._sh.run(f'{self._devsh} file send {src} {dst}',
                            measure_time=False)

    def pull(self,
             src: Union[str, Path],
             dst: Union[str, Path]) -> ShellResult:
        """Copy device file *src* to *dst* on the host."""
        return self._sh.run(f'{self._devsh} file recv {src} {dst}',
                            measure_time=False)