#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import json
import logging
import sys
import random
import string
from functools import wraps
from pathlib import Path
from typing import Union, Dict, Any, Callable, List, Iterable, Optional
from time import time
from enum import Enum
from datetime import datetime, timezone, timedelta
from importlib.util import spec_from_file_location, module_from_spec

# Custom log levels: PASS sits just above ERROR, TRACE just below DEBUG.
PASS_LOG_LEVEL = logging.ERROR + 1
TRACE_LOG_LEVEL = logging.DEBUG - 1
log = logging.getLogger('vmb')


def ensure_env_var(var: str) -> str:
    """Return the value of env variable ``var``; exit(1) if unset or empty."""
    val = os.environ.get(var, '')
    die(not val, 'Please set "%s" env var!', var)
    return val


def log_pass(self, message, *args, **kws):
    """Log ``message`` at PASS_LOG_LEVEL (one step above ERROR).

    Takes ``self`` so it can be installed as a ``logging.Logger`` method;
    presumably that is done elsewhere -- TODO confirm.
    """
    if self.isEnabledFor(PASS_LOG_LEVEL):
        # pylint: disable-next=protected-access
        self._log(PASS_LOG_LEVEL, message, args, **kws)


def log_trace(self, message, *args, **kws):
    """Log ``message`` at TRACE_LOG_LEVEL (one step below DEBUG).

    Takes ``self`` so it can be installed as a ``logging.Logger`` method;
    presumably that is done elsewhere -- TODO confirm.
    """
    if self.isEnabledFor(TRACE_LOG_LEVEL):
        # pylint: disable-next=protected-access
        self._log(TRACE_LOG_LEVEL, message, args, **kws)


def rnd_string(size=8):
    """Return a random string of ``size`` uppercase letters and digits.

    Uses ``random``, so it is not suitable for security-sensitive tokens.
    """
    return ''.join(
        random.choice(
            string.ascii_uppercase + string.digits) for _ in range(size))


def pad_left(s: str, ln: int, limit: int = 80) -> str:
    """Return ``s`` left-aligned in a field of exactly ``min(ln, limit)`` chars.

    Longer strings are truncated; shorter ones are padded with spaces on
    the right. A negative width is treated as 0 (empty result).
    """
    width = max(0, min(ln, limit))
    return s[:width].ljust(width)


def remove_prefix(s: str, prefix: str) -> str:
    """Return ``s`` without leading ``prefix`` (unchanged if absent)."""
    return s[len(prefix):] if s.startswith(prefix) else s


def split_params(line: str) -> List[str]:
    """Split a comma-separated string into stripped, non-empty items."""
    return [t.strip() for t in line.split(',') if t.strip()]


def norm_list(it: Optional[Iterable[str]]) -> List[str]:
    """Return the sorted, lower-cased items of ``it`` without duplicates.

    Items are lower-cased *before* de-duplication, so ``['A', 'a']``
    yields ``['a']``. ``None`` or an empty collection yields ``[]``.
    """
    return sorted({t.lower() for t in it}) if it else []


def log_time(f: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate ``f`` to trace its start and wall-clock duration.

    The wrapper calls ``log.trace``, which is not defined in this module;
    it is presumably attached to ``logging.Logger`` elsewhere (see
    ``log_trace``) -- TODO confirm.
    """
    @wraps(f)  # keep f's __name__/__doc__ on the wrapper
    def f1(*args: Any, **kwargs: Any) -> Any:
        log.trace('%s started', f.__name__)
        start = time()
        ret = f(*args, **kwargs)
        log.trace('%s finished in %s',
                  f.__name__,
                  str(timedelta(seconds=time() - start)))
        return ret
    return f1


def read_list_file(list_file: Union[str, Path]) -> List[str]:
    """Return the lines of ``list_file``, skipping empty and '#' lines.

    Logs an error and returns ``[]`` if the file does not exist.
    """
    path = Path(list_file)
    if not path.is_file():
        log.error('List file not found: %s', path)
        return []
    with open(path, 'r', encoding="utf-8") as f:
        return [line for line in f.read().splitlines()
                if line and not line.startswith('#')]


def import_module(module_name, path):
    """Load the Python file at ``path`` as a module named ``module_name``.

    Raises RuntimeError if no import spec can be built for ``path``.
    If the spec has no loader, the module object is returned unexecuted.
    """
    spec = spec_from_file_location(module_name, path)
    if spec is None:
        raise RuntimeError(f'Import module from "{path}" failed')
    module = module_from_spec(spec)
    loader = spec.loader
    if loader:
        loader.exec_module(module)
    return module


def get_plugin(plug_type: str,
               plug_name: str,
               extra: Optional[Path] = None) -> Any:
    """Load plugin ``<plug_type>/<plug_name>.py`` as a module.

    The ``extra`` directory (if given) is searched first, then the
    ``plugins`` directory next to this file. Exits with status 1 if
    ``extra`` is not a directory or the plugin is found nowhere.
    """
    # try extra dir first
    if extra:
        die(not extra.is_dir(),
            'Extra plugins dir "%s" does not exist!', extra)
        py = extra.joinpath(plug_type, f'{plug_name}.py')
        if py.is_file():
            return import_module(plug_name, str(py))
    # load default in case there is no extra one
    py = Path(__file__).parent.resolve().joinpath(
        'plugins', plug_type, f'{plug_name}.py')
    if py.is_file():
        return import_module(plug_name, str(py))
    log.fatal('No such plugin: "%s"\n'
              'Searching here: "%s"\n'
              'To see available plugins: `vmb list`', plug_name, py)
    sys.exit(1)


def get_plugins(plug_type: str,
                plugins: List[str],
                extra: Optional[Path]) -> Dict[str, Any]:
    """Return a dict mapping each name in ``plugins`` to its loaded module."""
    return {plug_name: get_plugin(plug_type, plug_name, extra)
            for plug_name in plugins}


class Timer:
    """Simple begin/end timestamp pair (UTC datetimes)."""

    # Local timezone, captured once at class-definition time.
    tz = datetime.now(timezone.utc).astimezone().tzinfo
    # NOTE(review): fractional seconds are the literal '00000', not real
    # sub-second digits -- presumably intentional for a fixed-width format.
    tm_format = "%Y-%m-%dT%H:%M:%S.00000%z"

    def __init__(self) -> None:
        self.begin = datetime.now(timezone.utc)
        self.end = self.begin

    @staticmethod
    def format(t) -> str:
        """Format datetime ``t`` in local time; 'unknown' if not a datetime."""
        if not isinstance(t, datetime):
            return 'unknown'
        return t.astimezone(Timer.tz).strftime(Timer.tm_format)

    def start(self) -> None:
        """Reset both begin and end to the current UTC time."""
        self.begin = datetime.now(timezone.utc)
        self.end = self.begin

    def finish(self) -> None:
        """Set end to the current UTC time."""
        self.end = datetime.now(timezone.utc)

    def elapsed(self) -> timedelta:
        """Return ``end - begin``."""
        return self.end - self.begin


class Singleton(type):
    """Metaclass that creates at most one instance per class."""

    __instances: Dict[Any, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        """Instantiate the singleton on first call; return it afterwards."""
        if cls not in cls.__instances:
            cls.__instances[cls] = \
                super(Singleton, cls).__call__(*args, **kwargs)
        return cls.__instances[cls]


class StringEnum(Enum):
    """Enum ordered by (and listable via) its values."""

    def __lt__(self, other):
        return self.value < other.value

    @classmethod
    def getall(cls) -> List[str]:
        """Return ``str(value)`` of every member, in definition order."""
        return [str(x.value) for x in cls]


class Jsonable:
    """Base class (abstract) for json-serialization."""

    @staticmethod
    def get_props(obj):
        """Return a dict of ``obj``'s serializable properties.

        Includes instance attributes not starting with '_' plus the
        values of ``@property`` attributes defined directly on
        ``obj``'s class (properties inherited from bases are not read).
        """
        # add all 'public' fields
        props = {k: v for k, v in obj.__dict__.items()
                 if not str(k).startswith('_')}
        # add all @property-decorated fields
        props.update(
            {name: value.fget(obj) for name, value
             in vars(obj.__class__).items()
             if isinstance(value, property)})
        return props

    def js(self, sort_keys=False, indent=4) -> str:
        """Serialize object to json."""
        return json.dumps(
            self,
            default=Jsonable.get_props,
            sort_keys=sort_keys,
            indent=indent)

    def save(self, json_file: Union[Path, str]) -> None:
        """Write ``self.js()`` to ``json_file`` (created or truncated)."""
        with create_file(json_file) as f:
            f.write(self.js())


class ColorFormatter(logging.Formatter):
    """Colorful log: message-only format colored by log level."""

    red = "\x1b[31;20m"
    green = "\x1b[32;20m"
    grey = "\x1b[38;20m"
    magenta = "\x1b[35;20m"
    cyan = "\x1b[36;1m"
    yellow = "\x1b[33;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    orange = "\x1b[33;20m"
    bold_blue = "\x1b[34;1m"
    fmt = '%(message)s'

    FORMATS = {
        TRACE_LOG_LEVEL: magenta + fmt + reset,
        logging.DEBUG: bold_blue + fmt + reset,
        logging.INFO: cyan + fmt + reset,
        PASS_LOG_LEVEL: green + fmt + reset,
        logging.WARNING: yellow + fmt + reset,
        logging.ERROR: red + fmt + reset,
        logging.CRITICAL: bold_red + fmt + reset
    }

    def format(self, record):
        """Format ``record`` with the color format for its level.

        Levels missing from FORMATS get ``None``, i.e. logging's default.
        """
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)


def create_file(path: Union[str, Path]):
    """Open ``path`` for UTF-8 text writing, creating or truncating it.

    The file is created with mode 0o664 (subject to umask) rather than
    ``open()``'s default. Returns a file object; the caller closes it.
    """
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o664)
    try:
        return os.fdopen(fd, mode='w', encoding='utf-8')
    except Exception:
        # fdopen did not take ownership of fd; close it to avoid a leak.
        os.close(fd)
        raise


def load_file(path: Union[str, Path]) -> str:
    """Return the whole content of UTF-8 text file ``path``."""
    # `with` guarantees the descriptor is closed (the old code leaked it).
    with os.fdopen(os.open(path, os.O_RDONLY),
                   mode='r', encoding='utf-8') as f:
        return f.read()


def force_link(link: Path, dest: Path) -> None:
    """Create symlink ``link`` -> ``dest``, replacing whatever is at ``link``."""
    # Path.exists() follows symlinks and is False for a dangling link, so
    # also check is_symlink(); otherwise symlink_to() raises FileExistsError.
    if link.is_symlink() or link.exists():
        link.unlink()
    link.symlink_to(dest)


def die(condition: bool, *msg) -> None:
    """If ``condition`` is true, log ``msg`` via ``log.fatal`` and exit(1).

    ``msg`` is forwarded as-is: a %-format string followed by its args.
    """
    if condition:
        log.fatal(*msg)
        sys.exit(1)