• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Copyright (c) 2024 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
import os
import json
import logging
import sys
import random
import string
from pathlib import Path
from typing import Union, Dict, Any, Callable, List, Iterable, Optional
from time import time
from enum import Enum
from datetime import datetime, timezone, timedelta
from functools import wraps
from importlib.util import spec_from_file_location, module_from_spec
30
# Custom level that ranks one step above ERROR (used by log_pass).
PASS_LOG_LEVEL: int = 1 + logging.ERROR
# Custom level that ranks one step below DEBUG (used by log_trace).
TRACE_LOG_LEVEL: int = logging.DEBUG - 1
# Shared logger for this package.
log: logging.Logger = logging.getLogger('vmb')
34
35
def ensure_env_var(var: str) -> str:
    """Return the value of environment variable ``var``.

    If the variable is unset or set to an empty string, the process is
    terminated through die() with a hint asking the user to set it.
    """
    value = os.environ.get(var, '')
    missing = value == ''
    die(missing, 'Please set "%s" env var!', var)
    return value
41
42
def log_pass(self, message, *args, **kws):
    """Logger method for the custom PASS level (ERROR + 1).

    Takes the logger as ``self``. Nothing happens unless PASS_LOG_LEVEL is
    enabled. Otherwise ``message`` and its %-style ``args`` are handed to
    the logger's internal ``_log``.
    """
    if not self.isEnabledFor(PASS_LOG_LEVEL):
        return
    # pylint: disable-next=protected-access
    self._log(PASS_LOG_LEVEL, message, args, **kws)
48
49
def log_trace(self, message, *args, **kws):
    """Logger method for the custom TRACE level (DEBUG - 1, i.e. below debug).

    Takes the logger as ``self``. When TRACE_LOG_LEVEL is enabled, it
    forwards ``message`` and its %-style ``args`` to the logger's internal
    ``_log``, in the same way the built-in level methods do.
    """
    if self.isEnabledFor(TRACE_LOG_LEVEL):
        # pylint: disable-next=protected-access
        self._log(TRACE_LOG_LEVEL, message, args, **kws)
55
56
def rnd_string(size=8):
    """Return a string of ``size`` random characters from A-Z and 0-9.

    Uses the ``random`` module, so the result is not suitable for secrets.
    """
    alphabet = string.ascii_uppercase + string.digits
    picked = [random.choice(alphabet) for _ in range(size)]
    return ''.join(picked)
62
63
def pad_left(s: str, ln: int, limit: int = 80) -> str:
    """Return ``s`` fitted to a fixed-width column.

    The column width is ``min(ln, limit)``, clamped to be non-negative.
    ``s`` is truncated when it is longer than the width. Otherwise it is
    left-aligned and padded with trailing spaces (despite the function name,
    the padding goes on the right). A negative width yields ''.
    """
    # Clamp at 0: a negative width would otherwise make the format spec
    # invalid ('<-1') and raise ValueError.
    width = max(0, min(ln, limit))
    return f'{s[:width]:<{width}}'
68
69
def remove_prefix(s: str, prefix: str) -> str:
    """Return ``s`` without a leading ``prefix``.

    ``s`` is returned unchanged when it does not start with ``prefix``.
    """
    if not s.startswith(prefix):
        return s
    return s[len(prefix):]
73
74
def split_params(line: str) -> List[str]:
    """Split a comma-separated string into trimmed, non-empty items."""
    trimmed = (part.strip() for part in line.split(','))
    return [part for part in trimmed if part]
78
79
def norm_list(it: Optional[Iterable[str]]) -> List[str]:
    """Lower-case the items, remove duplicates and return them sorted.

    Duplicates are detected after lower-casing, so 'A' and 'a' collapse into
    a single 'a'. A falsy ``it`` (None or an empty collection) gives [].
    """
    # Lower-case first, then deduplicate: deduplicating the raw strings would
    # keep case variants such as 'A'/'a' as separate entries.
    return sorted({t.lower() for t in it}) if it else []
83
84
def log_time(f: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator that logs how long each call to ``f`` takes.

    Before the call it logs '<name> started' through ``log.trace``. After
    the call returns, it logs '<name> finished in <duration>', where the
    duration is rendered by ``str(timedelta)``. If ``f`` raises, only the
    'started' message is logged and the exception propagates.
    The wrapper copies ``f``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__`` ...) via functools.wraps, so introspection and
    tracebacks still refer to the decorated function.
    """
    @wraps(f)
    def f1(*args: Any, **kwargs: Any) -> Any:
        log.trace('%s started', f.__name__)
        start = time()
        ret = f(*args, **kwargs)
        log.trace('%s finished in %s',
                  f.__name__,
                  str(timedelta(seconds=time() - start)))
        return ret
    return f1
96
97
def read_list_file(list_file: Union[str, Path]) -> List[str]:
    """Read a list file and return its meaningful lines.

    Empty lines and lines whose first character is '#' are dropped. Other
    lines are returned verbatim, without their line terminators. If the
    path is not a regular file, an error is logged and [] is returned.
    """
    list_path = Path(list_file)
    if list_path.is_file():
        with open(list_path, 'r', encoding="utf-8") as handle:
            raw_lines = handle.read().splitlines()
        return [entry for entry in raw_lines
                if entry and not entry.startswith('#')]
    log.error('List file not found: %s', list_path)
    return []
109
110
def import_module(module_name: str, path: Union[str, Path]) -> Any:
    """Load a Python source file as a module object.

    Args:
        module_name: name given to the new module (its ``__name__``).
        path: path to the ``.py`` file to execute.

    Returns:
        The freshly executed module. It is not inserted into ``sys.modules``.

    Raises:
        RuntimeError: if no import spec, or no loader, can be created for
            ``path`` (e.g. an unsupported file suffix).
    """
    spec = spec_from_file_location(module_name, path)
    # Without a loader the module could never be executed. Returning the
    # empty shell would only defer the failure to a confusing AttributeError
    # later, so fail here with the same error as for a missing spec.
    if spec is None or spec.loader is None:
        raise RuntimeError(f'Import module from "{path}" failed')
    module = module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
121
122
def get_plugin(plug_type: str,
               plug_name: str,
               extra: Optional[Path] = None) -> Any:
    """Load plugin ``plug_name`` of kind ``plug_type`` and return its module.

    Lookup order:
      1. ``<extra>/<plug_type>/<plug_name>.py``, when ``extra`` is given.
         The process exits via die() if ``extra`` is not a directory.
      2. ``plugins/<plug_type>/<plug_name>.py`` next to this source file.
    If neither file exists, a fatal message is logged and the process exits
    with status 1.
    """
    file_name = f'{plug_name}.py'
    if extra:
        die(not extra.is_dir(),
            'Extra plugins dir "%s" does not exist!', extra)
        extra_py = extra / plug_type / file_name
        if extra_py.is_file():
            return import_module(plug_name, str(extra_py))
    default_py = Path(__file__).parent.resolve() / 'plugins' / plug_type / file_name
    if not default_py.is_file():
        log.fatal('No such plugin: "%s"\n'
                  'Searching here: "%s"\n'
                  'To see available plugins: `vmb list`', plug_name, default_py)
        sys.exit(1)
    return import_module(plug_name, str(default_py))
143
144
def get_plugins(plug_type: str,
                plugins: List[str],
                extra: Optional[Path]) -> Dict[str, Any]:
    """Load every plugin named in ``plugins`` via get_plugin.

    Returns a dict that maps each requested name to its loaded module, in
    the order the names were requested.
    """
    return {name: get_plugin(plug_type, name, extra) for name in plugins}
154
155
class Timer:
    """Stopwatch that records a begin and an end instant.

    Both instants are timezone-aware datetimes in UTC.
    """

    # Local timezone, captured once when the class is created. It is used
    # only when rendering timestamps in format().
    tz = datetime.now(timezone.utc).astimezone().tzinfo
    # NOTE(review): the fractional-seconds part is the literal text '.00000',
    # so formatted timestamps never carry sub-second precision.
    tm_format = "%Y-%m-%dT%H:%M:%S.00000%z"

    def __init__(self) -> None:
        now = datetime.now(timezone.utc)
        self.begin = now
        self.end = now

    @staticmethod
    def format(t) -> str:
        """Render ``t`` in the local timezone using ``tm_format``.

        Any value that is not a datetime is rendered as 'unknown'.
        """
        if isinstance(t, datetime):
            return t.astimezone(Timer.tz).strftime(Timer.tm_format)
        return 'unknown'

    def start(self) -> None:
        """Restart: set ``begin`` to now and make ``end`` equal to it."""
        now = datetime.now(timezone.utc)
        self.begin = now
        self.end = now

    def finish(self) -> None:
        """Stamp ``end`` with the current UTC time."""
        self.end = datetime.now(timezone.utc)

    def elapsed(self) -> timedelta:
        """Return ``end - begin``. This is zero until finish() is called."""
        return self.end - self.begin
181
182
class Singleton(type):
    """Metaclass that gives each class using it exactly one shared instance.

    The first call constructs the instance with the supplied arguments.
    Later calls return that same object and ignore their arguments.
    """

    # Registry mapping class -> its single instance. The double underscore
    # name-mangles the attribute to _Singleton__instances.
    __instances: Dict[Any, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        """Return the cached instance of ``cls``, creating it on first use."""
        registry = cls.__instances
        if cls in registry:
            return registry[cls]
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance
194
195
class StringEnum(Enum):
    """Enum base class whose members are ordered by their ``.value``.

    The ordering lets members be passed to sorted()/min()/max().
    """

    def __lt__(self, other):
        """Compare by ``.value``; defer to Python for non-enum operands."""
        if not isinstance(other, Enum):
            # NotImplemented lets Python try the reflected comparison and
            # then raise a proper TypeError, instead of leaking an
            # AttributeError from ``other.value``.
            return NotImplemented
        return self.value < other.value

    @classmethod
    def getall(cls) -> List[str]:
        """Return ``str()`` of every member's value, in definition order."""
        return [str(x.value) for x in cls]
205
206
class Jsonable:
    """Base class (abstract) for JSON serialization.

    A subclass instance is serialized as a JSON object built from its public
    instance attributes plus every readable @property, including inherited
    ones.
    """

    @staticmethod
    def get_props(obj):
        """Collect the JSON-visible fields of ``obj``.

        Used as the ``default=`` hook of json.dumps, so it is also invoked for
        any nested object json cannot encode natively.

        Returns a dict made of:
          * instance attributes from ``obj.__dict__`` whose names do not
            start with '_';
          * the values of all @property getters visible on the object's
            class, including ones inherited from base classes. Names are
            resolved along the MRO, so a subclass attribute shadows a
            base-class property of the same name. Write-only properties
            (no getter) are skipped.
        """
        # add all 'public' fields
        props = {k: v for k, v in obj.__dict__.items()
                 if not str(k).startswith('_')}
        # Add all @property-decorated fields. Walk the MRO so inherited
        # properties are included; the first (most-derived) definition of a
        # name wins, as in normal attribute lookup.
        seen = set()
        for klass in obj.__class__.__mro__:
            for name, value in vars(klass).items():
                if name in seen:
                    continue
                seen.add(name)
                if isinstance(value, property) and value.fget is not None:
                    props[name] = value.fget(obj)
        return props

    def js(self, sort_keys=False, indent=4) -> str:
        """Serialize the object to a JSON string (via get_props)."""
        return json.dumps(
            self,
            default=Jsonable.get_props,
            sort_keys=sort_keys,
            indent=indent)

    def save(self, json_file: Union[Path, str]) -> None:
        """Write the output of js() to ``json_file`` using create_file()."""
        with create_file(json_file) as f:
            f.write(self.js())
234
235
class ColorFormatter(logging.Formatter):
    """Formatter that colours the bare log message according to its level.

    For the levels listed in FORMATS, the output is the message wrapped in
    an ANSI colour sequence and a trailing reset. No timestamp or level name
    is added.
    """

    # ANSI SGR escape sequences.
    red = "\x1b[31;20m"
    green = "\x1b[32;20m"
    grey = "\x1b[38;20m"
    magenta = "\x1b[35;20m"
    cyan = "\x1b[36;1m"
    yellow = "\x1b[33;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    orange = "\x1b[33;20m"  # same escape sequence as yellow
    bold_blue = "\x1b[34;1m"
    fmt = '%(message)s'

    # Level number -> colourised format string.
    FORMATS = {
        TRACE_LOG_LEVEL: f'{magenta}{fmt}{reset}',
        logging.DEBUG: f'{bold_blue}{fmt}{reset}',
        logging.INFO: f'{cyan}{fmt}{reset}',
        PASS_LOG_LEVEL: f'{green}{fmt}{reset}',
        logging.WARNING: f'{yellow}{fmt}{reset}',
        logging.ERROR: f'{red}{fmt}{reset}',
        logging.CRITICAL: f'{bold_red}{fmt}{reset}',
    }

    def format(self, record):
        """Format ``record`` with the coloured template for its level.

        Levels missing from FORMATS get ``None`` as the template, which makes
        logging.Formatter fall back to its default '%(message)s'.
        """
        template = self.FORMATS.get(record.levelno)
        plain = logging.Formatter(template)
        return plain.format(record)
266
267
def create_file(path: Union[str, Path]):
    """Open ``path`` for writing UTF-8 text, creating or truncating it.

    The file is opened through os.open with explicit permission bits 0o664.
    These apply only when the file is newly created and are further reduced
    by the process umask. Returns a text file object; the caller is
    responsible for closing it (e.g. by using it as a context manager).
    """
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    descriptor = os.open(path, flags, 0o664)
    return os.fdopen(descriptor, mode='w', encoding='utf-8')
273
274
def load_file(path: Union[str, Path]) -> str:
    """Read the whole file at ``path`` as UTF-8 text and return it.

    The file is closed before returning. Universal-newline translation of
    text mode applies ('\\r\\n' is read back as '\\n').
    """
    # The context manager closes the descriptor deterministically instead of
    # leaving it to garbage collection. A freshly opened file is already
    # positioned at offset 0, so no seek is needed.
    with os.fdopen(os.open(path, os.O_RDONLY), mode='r',
                   encoding='utf-8') as fd:
        return fd.read()
281
282
def force_link(link: Path, dest: Path) -> None:
    """Make ``link`` a symlink pointing at ``dest``, replacing what was there.

    An existing file or symlink at ``link`` is removed first, including a
    dangling symlink whose target no longer exists.
    """
    # Path.exists() follows symlinks and returns False for a dangling link,
    # which would leave it in place and make symlink_to() raise
    # FileExistsError. Check is_symlink() as well.
    if link.is_symlink() or link.exists():
        link.unlink()
    link.symlink_to(dest)
287
288
def die(condition: bool, *msg) -> None:
    """Abort the program when ``condition`` is truthy; otherwise do nothing.

    ``msg`` is forwarded unchanged to ``log.fatal`` (a format string followed
    by its %-style arguments). The process then exits with status 1.
    """
    if not condition:
        return
    log.fatal(*msg)
    sys.exit(1)
293