#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (c) 2024-2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import json
import logging
import sys
import random
import string
import shutil
from functools import wraps
from pathlib import Path
from string import Template
from typing import Union, Dict, Any, Callable, List, Iterable, Optional
from time import time
from enum import Enum
from datetime import datetime, timezone, timedelta
from importlib.util import spec_from_file_location, module_from_spec

# Custom log levels: PASS sits just above ERROR, TRACE just below DEBUG.
PASS_LOG_LEVEL = logging.ERROR + 1
TRACE_LOG_LEVEL = logging.DEBUG - 1
log = logging.getLogger('vmb')


def ensure_env_var(var: str) -> str:
    """Return the value of env variable `var`.

    Terminates the process through `die` when the variable is unset or empty.
    """
    val = os.environ.get(var, '')
    die(not val, 'Please set "%s" env var!', var)
    return val


def log_pass(self, message, *args, **kws):
    """Logger method for the custom PASS level (just above ERROR).

    Written as an unbound method: `self` is the logger instance.
    """
    if self.isEnabledFor(PASS_LOG_LEVEL):
        # pylint: disable-next=protected-access
        self._log(PASS_LOG_LEVEL, message, args, **kws)


def log_trace(self, message, *args, **kws):
    """Logger method for the custom TRACE level (just below DEBUG).

    Written as an unbound method: `self` is the logger instance.
    """
    if self.isEnabledFor(TRACE_LOG_LEVEL):
        # pylint: disable-next=protected-access
        self._log(TRACE_LOG_LEVEL, message, args, **kws)


def rnd_string(size: int = 8) -> str:
    """Return a random string of `size` upper-case letters and digits."""
    alphabet = string.ascii_uppercase + string.digits
    return ''.join(random.choice(alphabet) for _ in range(size))


def pad_left(s: str, ln: int, limit: int = 80) -> str:
    """Return `s` left-aligned in a field of exactly `ln` characters.

    The string is truncated to `ln` characters if longer and padded with
    trailing spaces if shorter. `ln` is capped at `limit`; a negative
    width is treated as zero.
    """
    ln = max(0, min(ln, limit))
    return s[:ln].ljust(ln)


def remove_prefix(s: str, prefix: str) -> str:
    """Return `s` without leading `prefix` (unchanged if not prefixed)."""
    return s[len(prefix):] if s.startswith(prefix) else s


def split_params(line: str) -> List[str]:
    """Split comma-separated string into a list of non-empty trimmed items."""
    return [t.strip() for t in line.split(',') if t.strip()]


def norm_list(it: Optional[Iterable[str]]) -> List[str]:
    """Return sorted, lower-cased items of `it` without duplicates.

    Duplicates are removed after lower-casing, so items that differ only
    in case collapse into one entry. A falsy `it` yields an empty list.
    """
    return sorted({t.lower() for t in it}) if it else []


def log_time(f: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator: write start/finish and elapsed time to the trace log."""
    @wraps(f)  # keep name and docstring of the decorated function
    def f1(*args: Any, **kwargs: Any) -> Any:
        log.trace('%s started', f.__name__)
        start = time()
        ret = f(*args, **kwargs)
        log.trace('%s finished in %s',
                  f.__name__,
                  str(timedelta(seconds=time() - start)))
        return ret
    return f1


def read_list_file(list_file: Union[str, Path]) -> List[str]:
    """Read a list file into a list of lines.

    Empty lines and lines starting with '#' are skipped.
    Returns an empty list (and logs an error) if the file does not exist.
    """
    path = Path(list_file)
    if not path.is_file():
        log.error('List file not found: %s', path)
        return []
    with open(path, 'r', encoding="utf-8") as f:
        return [line for line in f.read().splitlines()
                if line and not line.startswith('#')]


def import_module(module_name, path):
    """Import py file as a module.

    Raises RuntimeError if no import spec or loader can be obtained
    for `path`.
    """
    spec = spec_from_file_location(module_name, path)
    # without a loader the module body could never be executed
    if spec is None or spec.loader is None:
        raise RuntimeError(f'Import module from "{path}" failed')
    module = module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def get_plugin(plug_type: str,
               plug_name: str,
               extra: Optional[Path] = None) -> Any:
    """Return plugin module `plug_name` of kind `plug_type`.

    `<extra>/<plug_type>/<plug_name>.py` is tried first (if `extra` is
    given), then `plugins/<plug_type>/<plug_name>.py` next to this file.
    Exits the process if `extra` is not a directory or no plugin is found.
    """
    # try extra dir first
    if extra:
        die(not extra.is_dir(),
            'Extra plugins dir "%s" does not exist!', extra)
        py = extra.joinpath(plug_type, f'{plug_name}.py')
        if py.is_file():
            return import_module(plug_name, str(py))
    # load default in case there is no extra one
    py = Path(__file__).parent.resolve().joinpath(
        'plugins', plug_type, f'{plug_name}.py')
    if py.is_file():
        return import_module(plug_name, str(py))
    log.critical('No such plugin: "%s"\n'
                 'Searching here: "%s"\n'
                 'To see available plugins: `vmb list`', plug_name, py)
    sys.exit(1)


def get_plugins(plug_type: str,
                plugins: List[str],
                extra: Optional[Path]) -> Dict[str, Any]:
    """Return dict of plugin modules keyed by plugin name."""
    return {plug_name: get_plugin(plug_type, plug_name, extra)
            for plug_name in plugins}


class Timer:
    """Simple struct for begin-end."""

    # local timezone, resolved once at import time
    tz = datetime.now(timezone.utc).astimezone().tzinfo
    # sub-second part is a fixed literal, not the real fraction
    tm_format = "%Y-%m-%dT%H:%M:%S.00000%z"

    def __init__(self) -> None:
        self.begin = datetime.now(timezone.utc)
        self.end = self.begin

    @staticmethod
    def format(t) -> str:
        """Format `t` in the local timezone; 'unknown' for non-datetime."""
        if not isinstance(t, datetime):
            return 'unknown'
        return t.astimezone(Timer.tz).strftime(Timer.tm_format)

    def start(self) -> None:
        """Reset both begin and end to the current time."""
        self.begin = datetime.now(timezone.utc)
        self.end = self.begin

    def finish(self) -> None:
        """Set end to the current time."""
        self.end = datetime.now(timezone.utc)

    def elapsed(self) -> timedelta:
        """Return the time between begin and end."""
        return self.end - self.begin


class Singleton(type):
    """Metaclass which creates at most one instance per class."""

    __instances: Dict[Any, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        """Instantiate on the first call, return cached instance later.

        Arguments of the second and subsequent calls are ignored.
        """
        if cls not in cls.__instances:
            cls.__instances[cls] = super().__call__(*args, **kwargs)
        return cls.__instances[cls]


class StringEnum(Enum):
    """Enum ordered by member value."""

    def __lt__(self, other):
        return self.value < other.value

    @classmethod
    def getall(cls) -> List[str]:
        """Return all member values as strings, in definition order."""
        return [str(x.value) for x in cls]


class Jsonable:
    """Base class (abstract) for json-serialization."""

    @staticmethod
    def get_props(obj):
        """Collect serializable properties of `obj`.

        These are the instance attributes not starting with '_' plus all
        @property members declared directly on the class of `obj`.
        """
        # add all 'public' fields
        props = {k: v for k, v in obj.__dict__.items()
                 if not str(k).startswith('_')}
        # add all @property-decorated fields
        props.update(
            {name: value.fget(obj) for name, value
             in vars(obj.__class__).items()
             if isinstance(value, property)})
        return props

    def js(self, sort_keys=False, indent=4) -> str:
        """Serialize object to json."""
        return json.dumps(
            self,
            default=Jsonable.get_props,
            sort_keys=sort_keys,
            indent=indent)

    def save(self, json_file: Union[Path, str]) -> None:
        """Write json representation of the object to `json_file`."""
        with create_file(json_file) as f:
            f.write(self.js())


class ColorFormatter(logging.Formatter):
    """Colorful log."""

    red = "\x1b[31;20m"
    green = "\x1b[32;20m"
    grey = "\x1b[38;20m"
    magenta = "\x1b[35;20m"
    cyan = "\x1b[36;1m"
    yellow = "\x1b[33;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    orange = "\x1b[33;20m"
    bold_blue = "\x1b[34;1m"
    fmt = '%(message)s'

    FORMATS = {
        TRACE_LOG_LEVEL: magenta + fmt + reset,
        logging.DEBUG: bold_blue + fmt + reset,
        logging.INFO: cyan + fmt + reset,
        PASS_LOG_LEVEL: green + fmt + reset,
        logging.WARNING: yellow + fmt + reset,
        logging.ERROR: red + fmt + reset,
        logging.CRITICAL: bold_red + fmt + reset
    }

    def format(self, record):
        """Format the record with the color assigned to its level.

        Levels missing in FORMATS get the default (uncolored) format.
        """
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)


def create_file(path: Union[str, Path]):
    """Create (or truncate) file for writing in `safe` manner.

    Missing parent directories are created. The file is opened with
    permissions 0o664 (subject to umask) and utf-8 encoding.
    The caller is responsible for closing the returned file object.
    """
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    return os.fdopen(
        os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o664),
        mode='w', encoding='utf-8')


def copy_file(src: Union[str, Path], dst: Union[str, Path]) -> None:
    """Copy file `src` to `dst`, creating parent directories of `dst`.

    Raises RuntimeError if `src` does not exist.
    """
    log.trace('Copy: %s -> %s', str(src), str(dst))
    s = Path(src)
    d = Path(dst)
    if not s.exists():
        raise RuntimeError(f'File not found: {src}')
    d.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(s, d)


def copy_files(src: Union[str, Path],
               dst: Union[str, Path],
               pattern: str = '*') -> None:
    """Copy entries of dir `src` matching glob `pattern` into dir `dst`.

    `dst` is created if missing. Raises RuntimeError if `src` does not exist.
    """
    log.trace('Copy: %s%s -> %s', str(src), pattern, str(dst))
    s = Path(src)
    d = Path(dst)
    if not s.exists():
        raise RuntimeError(f'File not found: {src}')
    d.mkdir(parents=True, exist_ok=True)
    for f in s.glob(pattern):
        shutil.copy(f, d.joinpath(f.name))


def create_file_from_template(tpl: Union[str, Path],
                              dst: Union[str, Path],
                              **kwargs) -> None:
    """Render `string.Template` file `tpl` with `kwargs` into file `dst`.

    Uses strict substitution: a placeholder without a matching keyword
    raises KeyError.
    """
    with open(tpl, 'r', encoding="utf-8") as src:
        template = Template(src.read())
    with create_file(dst) as f:
        f.write(template.substitute(**kwargs))


def load_file(path: Union[str, Path]) -> str:
    """Read the whole utf-8 file to string."""
    # context manager guarantees the descriptor is closed after reading
    with os.fdopen(os.open(path, os.O_RDONLY),
                   mode="r", encoding='utf-8') as fd:
        return fd.read()


def load_json(path: Union[str, Path]) -> Any:
    """Parse json file and return the decoded object.

    Raises RuntimeError if the file does not exist or is not valid json.
    """
    json_path = Path(path)
    if not json_path.exists():
        raise RuntimeError(f'File not found: {path}')
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        log.error('Bad json: %s\n%s', str(path), str(e))
        raise RuntimeError(f'Bad json: {path}: {e}') from e


def force_link(link: Path, dest: Path) -> None:
    """Create symlink `link` pointing to `dest`, replacing an existing one."""
    log.trace('Force link: %s -> %s', str(link), str(dest))
    # exists() follows symlinks and is False for a dangling link,
    # so the link itself has to be checked as well
    if link.is_symlink() or link.exists():
        link.unlink()
    link.symlink_to(dest)


def die(condition: bool, *msg) -> None:
    """Log `msg` as critical and exit with code 1 if `condition` holds.

    `msg` is a logging format string followed by its arguments.
    """
    if condition:
        log.critical(*msg)
        sys.exit(1)