• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Copyright (c) 2024-2025 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import os
19import json
20import logging
21import sys
22import random
23import string
24import shutil
25from pathlib import Path
26from string import Template
27from typing import Union, Dict, Any, Callable, List, Iterable, Optional
28from time import time
29from enum import Enum
30from datetime import datetime, timezone, timedelta
31from importlib.util import spec_from_file_location, module_from_spec
32
# Custom level one step above ERROR; used by `log_pass` and ColorFormatter.
PASS_LOG_LEVEL: int = logging.ERROR + 1
# Custom level one step below DEBUG; used by `log_trace` and ColorFormatter.
TRACE_LOG_LEVEL: int = logging.DEBUG - 1
# Logger shared by this module.
log = logging.getLogger(name='vmb')
36
37
def ensure_env_var(var: str) -> str:
    """Return the value of environment variable `var`.

    When the variable is unset or empty, `die` is triggered with a
    message asking the user to set it.
    """
    value = os.environ.get(var) or ''
    die(value == '', 'Please set "%s" env var!', var)
    return value
43
44
def log_pass(self, message, *args, **kws):
    """Log `message` at PASS_LOG_LEVEL (one step above ERROR).

    Takes the logger as `self`, so it is presumably installed as a
    `logging.Logger` method elsewhere; arguments mirror `Logger.info`.
    """
    if not self.isEnabledFor(PASS_LOG_LEVEL):
        return
    # pylint: disable-next=protected-access
    self._log(PASS_LOG_LEVEL, message, args, **kws)
50
51
def log_trace(self, message, *args, **kws):
    """Log `message` at TRACE_LOG_LEVEL (one step below DEBUG).

    Takes the logger as `self`, so it is presumably installed as a
    `logging.Logger` method elsewhere; arguments mirror `Logger.debug`.
    """
    if self.isEnabledFor(TRACE_LOG_LEVEL):
        # pylint: disable-next=protected-access
        self._log(TRACE_LOG_LEVEL, message, args, **kws)
57
58
def rnd_string(size=8):
    """Return `size` random characters drawn from 'A'-'Z' and '0'-'9'.

    Uses the `random` module, which is not cryptographically secure.
    """
    chars = []
    for _ in range(size):
        alphabet = string.ascii_uppercase + string.digits
        chars.append(random.choice(alphabet))
    return ''.join(chars)
64
65
def pad_left(s: str, ln: int, limit: int = 80) -> str:
    """Return `s` left-aligned in a field exactly `ln` characters wide.

    Longer strings are truncated to the field width and shorter ones are
    padded with trailing spaces. The width is capped at `limit` and never
    drops below zero (a negative width yields an empty string instead of
    a format-spec ValueError).
    """
    width = max(0, min(ln, limit))
    return f'{s[:width]:<{width}}'
70
71
def remove_prefix(s: str, prefix: str) -> str:
    """Return `s` with a leading `prefix` removed, or `s` unchanged if absent."""
    if not s.startswith(prefix):
        return s
    return s[len(prefix):]
75
76
def split_params(line: str) -> List[str]:
    """Split a comma-separated string into trimmed, non-empty items."""
    items: List[str] = []
    for chunk in line.split(','):
        token = chunk.strip()
        if token:
            items.append(token)
    return items
80
81
def norm_list(it: Optional[Iterable[str]]) -> List[str]:
    """Return the items of `it` lower-cased, de-duplicated and sorted.

    Items are lower-cased before de-duplication, so entries that differ
    only in case (e.g. 'Foo' and 'foo') collapse into a single entry.
    Returns an empty list for None or an empty iterable.
    """
    if not it:
        return []
    return sorted({t.lower() for t in it})
85
86
def log_time(f: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate `f` so that each call is traced together with its duration.

    Every call emits one `log.trace` message when `f` starts and one with
    the elapsed wall-clock time (rendered as a `timedelta`) when it returns.
    The wrapper keeps `f`'s metadata (`__name__`, `__doc__`, ...) via
    `functools.wraps`, so introspection and log messages still see the
    original function.
    """
    # pylint: disable-next=import-outside-toplevel
    from functools import wraps

    @wraps(f)
    def f1(*args: Any, **kwargs: Any) -> Any:
        log.trace('%s started', f.__name__)
        start = time()
        ret = f(*args, **kwargs)
        log.trace('%s finished in %s',
                  f.__name__,
                  str(timedelta(seconds=time() - start)))
        return ret
    return f1
98
99
def read_list_file(list_file: Union[str, Path]) -> List[str]:
    """Read a list file into a list of its entries.

    Empty lines and lines starting with '#' are dropped; the remaining
    lines are returned verbatim (not stripped). A missing file is logged
    as an error and yields an empty list.
    """
    path = Path(list_file)
    if path.is_file():
        with open(path, 'r', encoding="utf-8") as f:
            content = f.read()
        return [line for line in content.splitlines()
                if line and not line.startswith('#')]
    log.error('List file not found: %s', path)
    return []
111
112
def import_module(module_name, path):
    """Load the Python file at `path` and execute it as module `module_name`.

    The resulting module is returned; this function does not add it to
    `sys.modules`.

    Raises:
        RuntimeError: if no import spec, or no loader for it, can be built
            for `path` (e.g. the file suffix is not a Python source suffix).
    """
    spec = spec_from_file_location(module_name, path)
    # A spec without a loader would produce an empty, never-executed module;
    # report that as a failure instead of silently returning it.
    if spec is None or spec.loader is None:
        raise RuntimeError(f'Import module from "{path}" failed')
    module = module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
123
124
def get_plugin(plug_type: str,
               plug_name: str,
               extra: Optional[Path] = None) -> Any:
    """Load plugin `plug_name` of kind `plug_type` and return its module.

    Lookup order: `<extra>/<plug_type>/<plug_name>.py` when `extra` is
    given, then `plugins/<plug_type>/<plug_name>.py` next to this file.
    `die` fires if `extra` is not a directory. If neither location holds
    the plugin, a fatal message is logged and the process exits with 1.
    """
    file_name = f'{plug_name}.py'
    if extra:
        die(not extra.is_dir(),
            'Extra plugins dir "%s" does not exist!', extra)
        override = extra.joinpath(plug_type, file_name)
        if override.is_file():
            return import_module(plug_name, str(override))
    # Fall back to the plugins bundled alongside this module.
    bundled = Path(__file__).parent.resolve().joinpath(
        'plugins', plug_type, file_name)
    if not bundled.is_file():
        log.fatal('No such plugin: "%s"\n'
                  'Searching here: "%s"\n'
                  'To see available plugins: `vmb list`', plug_name, bundled)
        sys.exit(1)
    return import_module(plug_name, str(bundled))
145
146
def get_plugins(plug_type: str,
                plugins: List[str],
                extra: Optional[Path]) -> Dict[str, Any]:
    """Load each plugin named in `plugins` and map the name to its module."""
    return {name: get_plugin(plug_type, name, extra) for name in plugins}
156
157
class Timer:
    """Pair of UTC timestamps marking the begin and end of an interval."""

    # Local timezone, captured once when the class body runs; `format`
    # converts timestamps into it.
    tz = datetime.now(timezone.utc).astimezone().tzinfo
    # Output pattern for `format`; the fractional part is the literal text
    # '.00000', not real sub-second precision.
    tm_format = "%Y-%m-%dT%H:%M:%S.00000%z"

    def __init__(self) -> None:
        now = datetime.now(timezone.utc)
        self.begin = now
        self.end = now

    @staticmethod
    def format(t) -> str:
        """Render `t` in `Timer.tz` using `Timer.tm_format`.

        Anything that is not a `datetime` yields the string 'unknown'.
        """
        if isinstance(t, datetime):
            local = t.astimezone(Timer.tz)
            return local.strftime(Timer.tm_format)
        return 'unknown'

    def start(self) -> None:
        """Restart the interval: begin and end both become the current UTC time."""
        now = datetime.now(timezone.utc)
        self.begin = now
        self.end = now

    def finish(self) -> None:
        """Record the current UTC time as the end of the interval."""
        self.end = datetime.now(timezone.utc)

    def elapsed(self) -> timedelta:
        """Return the interval length, `end - begin`."""
        return self.end - self.begin
183
184
class Singleton(type):
    """Metaclass that gives each class using it exactly one instance.

    The first call constructs the instance with that call's arguments;
    later calls return the cached object and ignore their arguments.
    The check-then-create step is not guarded by any lock.
    """

    # Shared registry mapping each class to its single instance.
    __instances: Dict[Any, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        """Return the cached instance of `cls`, creating it on first use."""
        registry = cls.__instances
        if cls in registry:
            return registry[cls]
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance
196
197
class StringEnum(Enum):
    """Enum base whose members are ordered by their `value`.

    Presumably used with string values (per the name); `getall` stringifies
    values regardless.
    """

    def __lt__(self, other):
        # Per the data model, return NotImplemented for foreign operands so
        # Python falls back to the reflected operation and ultimately raises
        # TypeError, rather than failing with AttributeError on `.value`.
        if not isinstance(other, StringEnum):
            return NotImplemented
        return self.value < other.value

    @classmethod
    def getall(cls) -> List[str]:
        """Return every member's value as a string, in definition order."""
        return [str(x.value) for x in cls]
207
208
class Jsonable:
    """Base class for objects that serialize themselves to JSON.

    The output covers public instance attributes (names not starting with
    '_') plus every `property` declared directly on the object's class.
    """

    @staticmethod
    def get_props(obj):
        """Collect `obj`'s public attributes and class-level properties.

        Serves as the `default` hook of `json.dumps`, so it is also applied
        to nested values that json cannot encode natively.
        """
        result = {}
        for key, val in obj.__dict__.items():
            if str(key).startswith('_'):
                continue
            result[key] = val
        # Only properties found in the class's own namespace are visited;
        # properties inherited from base classes are not included.
        computed = {}
        for name, attr in vars(obj.__class__).items():
            if isinstance(attr, property):
                computed[name] = attr.fget(obj)
        result.update(computed)
        return result

    def js(self, sort_keys=False, indent=4) -> str:
        """Return this object serialized as a JSON string."""
        return json.dumps(self, indent=indent, sort_keys=sort_keys,
                          default=Jsonable.get_props)

    def save(self, json_file: Union[Path, str]) -> None:
        """Write the JSON form of this object to `json_file` via `create_file`."""
        with create_file(json_file) as out:
            out.write(self.js())
236
237
class ColorFormatter(logging.Formatter):
    """Log formatter that wraps each message in an ANSI color picked by level."""

    # ANSI SGR escape sequences. `grey` and `orange` are not referenced in
    # FORMATS below (orange has the same code as yellow); presumably kept
    # for external use. TODO confirm.
    red = "\x1b[31;20m"
    green = "\x1b[32;20m"
    grey = "\x1b[38;20m"
    magenta = "\x1b[35;20m"
    cyan = "\x1b[36;1m"
    yellow = "\x1b[33;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    orange = "\x1b[33;20m"
    bold_blue = "\x1b[34;1m"
    fmt = '%(message)s'

    # Per-level format string: color prefix + bare message + color reset.
    FORMATS = {
        TRACE_LOG_LEVEL: magenta + fmt + reset,
        logging.DEBUG: bold_blue + fmt + reset,
        logging.INFO: cyan + fmt + reset,
        PASS_LOG_LEVEL: green + fmt + reset,
        logging.WARNING: yellow + fmt + reset,
        logging.ERROR: red + fmt + reset,
        logging.CRITICAL: bold_red + fmt + reset
    }

    def format(self, record):
        """Format `record` with the colored format registered for its level.

        A level missing from FORMATS gets `logging.Formatter(None)`, i.e. the
        stock '%(message)s' format with no color.
        """
        level_fmt = self.FORMATS.get(record.levelno)
        return logging.Formatter(level_fmt).format(record)
268
269
def create_file(path: Union[str, Path]):
    """Open `path` for writing UTF-8 text, creating parent directories first.

    An existing file is truncated; a new one is created with mode 0o664
    (subject to the process umask). The open file object is returned, so
    callers should use it as a context manager.
    """
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    descriptor = os.open(path, flags, 0o664)
    return os.fdopen(descriptor, mode='w', encoding='utf-8')
276
277
def copy_file(src: Union[str, Path], dst: Union[str, Path]) -> None:
    """Copy `src` to `dst`, creating `dst`'s parent directories as needed.

    Raises:
        RuntimeError: if `src` does not exist.
    """
    log.trace('Copy: %s -> %s', str(src), str(dst))
    source, target = Path(src), Path(dst)
    if not source.exists():
        raise RuntimeError(f'File not found: {src}')
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(source, target)
286
287
def copy_files(src: Union[str, Path], dst: Union[str, Path], pattern: str = '*') -> None:
    """Copy the files in `src` that match glob `pattern` into directory `dst`.

    Each match is copied flat into `dst` under its base name. Matches that
    are not regular files (e.g. subdirectories picked up by the default
    '*') are skipped, because `shutil.copy` cannot copy a directory and
    would abort the whole operation.

    Raises:
        RuntimeError: if `src` does not exist.
    """
    log.trace('Copy: %s%s -> %s', str(src), pattern, str(dst))
    s = Path(src)
    d = Path(dst)
    if not s.exists():
        raise RuntimeError(f'File not found: {src}')
    d.mkdir(parents=True, exist_ok=True)
    for f in s.glob(pattern):
        if not f.is_file():
            continue
        shutil.copy(f, d.joinpath(f.name))
297
298
def create_file_from_template(tpl: Union[str, Path], dst: Union[str, Path], **kwargs) -> None:
    """Render the `string.Template` file `tpl` with `kwargs` and write it to `dst`.

    The template is fully rendered before `dst` is opened. A missing
    placeholder (KeyError from `Template.substitute`) or a malformed
    template (ValueError) therefore propagates without truncating or
    creating `dst`.
    """
    with open(tpl, 'r', encoding="utf-8") as src:
        template = Template(src.read())
    rendered = template.substitute(**kwargs)
    with create_file(dst) as f:
        f.write(rendered)
304
305
def load_file(path: Union[str, Path]) -> str:
    """Read the whole file at `path` as UTF-8 text and return its contents.

    The file is opened through `os.open`/`os.fdopen` and closed before
    returning, so no descriptor is left open.
    """
    with os.fdopen(os.open(path, os.O_RDONLY), mode='r', encoding='utf-8') as fd:
        return fd.read()
312
313
def load_json(path: Union[str, Path]) -> Any:
    """Parse the JSON file at `path` and return the decoded object.

    Raises:
        RuntimeError: if the file does not exist, or if its content is not
            valid JSON. In the latter case the error is also logged and the
            original `json.JSONDecodeError` is chained as the cause.
    """
    json_path = Path(path)
    if not json_path.exists():
        raise RuntimeError(f'File not found: {path}')
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        log.error('Bad json: %s\n%s', str(path), str(e))
        # Give the re-raised error a message; a bare RuntimeError prints empty.
        raise RuntimeError(f'Bad json: {path}') from e
325
326
def force_link(link: Path, dest: Path) -> None:
    """Make `link` a symlink to `dest`, removing whatever is currently at `link`.

    `Path.exists()` follows symlinks and reports False for a dangling link,
    so `is_symlink()` is checked as well. Otherwise a stale link would be
    left in place and `symlink_to` would fail with FileExistsError.
    """
    log.trace('Force link: %s -> %s', str(link), str(dest))
    if link.is_symlink() or link.exists():
        link.unlink()
    link.symlink_to(dest)
332
333
def die(condition: bool, *msg) -> None:
    """Log `msg` as fatal and exit with status 1 when `condition` is truthy.

    `msg` is forwarded unchanged to `log.fatal` (format string plus args).
    Returns None when `condition` is falsy.
    """
    if not condition:
        return
    log.fatal(*msg)
    sys.exit(1)
338