• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Copyright (c) 2024-2025 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import logging
19import json
20from typing import List, Iterable, Set, Optional, Dict, Any
21from pathlib import Path
22from shutil import rmtree
23from string import Template
24from dataclasses import asdict
25from collections import namedtuple
26from vmb.helpers import get_plugin, read_list_file, log_time, create_file, die, force_link
27from vmb.unit import BenchUnit, BENCH_PREFIX
28from vmb.cli import Args
29from vmb.lang import LangBase
30from vmb.doclet import DocletParser, TemplateVars
31from vmb.gensettings import GenSettings
32from vmb.shell import ShellUnix
33
# One benchmark source: `full` is the source file path, `rel` is the
# directory (relative to the search root) the bench unit is emitted under;
# files given explicitly use Path('.') as `rel`.
SrcPath = namedtuple("SrcPath", ["full", "rel"])
# Shared logger of the vmb tool.
log = logging.getLogger('vmb')
36
37
class BenchGenerator:
    """Generate bench units from doclet-annotated benchmark sources.

    Sources are searched in ``args.paths`` and parsed for doclets. Each
    resulting variant is rendered through a language template into its own
    ``bu_<bench_name>`` directory under ``<args.outdir>/benches``.
    """

    def __init__(self, args: Args) -> None:
        """Capture generator settings from the parsed command line.

        Note: an already existing ``<outdir>/benches`` directory is deleted.
        """
        self.args = args  # need to keep full cmdline for measure overrides
        self.paths: List[Path] = args.paths
        self.out_dir: Path = Path(args.outdir).joinpath('benches').resolve()
        self.override_src_ext: Set[str] = args.src_langs
        self.extra_plug_dir = args.extra_plugins
        self.template_dirs: List[Path] = [
            p.joinpath('templates')
            # extra first (if exists)
            for p in (self.extra_plug_dir, Path(__file__).parent) if p]
        self.abort = args.abort_on_fail
        # Always start from a clean output tree.
        if self.out_dir.is_dir():
            rmtree(str(self.out_dir))

    @staticmethod
    def search_test_files_in_dir(d: Path,
                                 root: Path,
                                 ext: Iterable[str] = (),
                                 allowed_dir_name: Optional[str] = None) -> List[SrcPath]:
        """Recursively collect source files under ``d`` whose suffix is in ``ext``.

        Files living in a ``common/<subdir>/`` directory are skipped; shared
        code is read separately by ``check_common_files``. If
        ``allowed_dir_name`` is set, only files whose immediate parent
        directory has that name are taken.

        Returns a list of SrcPath(resolved file, its parent dir relative to
        the resolved ``root``).
        """
        if allowed_dir_name:
            log.trace('Search test files, allowed dir name: %s', allowed_dir_name)
        root_resolved = root.resolve()
        files = []
        for p in d.glob('**/*'):
            if allowed_dir_name and p.parent.name != allowed_dir_name:
                continue
            if p.parent.parent.name == 'common':
                continue
            # A directory may carry a matching suffix (e.g. `foo.ets/`);
            # only regular files can be read as benchmark sources.
            if p.suffix and p.suffix in ext and p.is_file():
                log.trace('Src: %s', str(p))
                full = p.resolve()
                files.append(
                    SrcPath(full, full.parent.relative_to(root_resolved)))
        return files

    @staticmethod
    def process_test_list(lst: Path, ext: Iterable[str] = (),
                          allowed_dir_name: Optional[str] = None) -> List[SrcPath]:
        """Collect sources named in a test-list file.

        Each entry is resolved against the current working directory.
        Listed files are added regardless of their extension. Listed
        directories are searched like ``search_test_files_in_dir``, with the
        directory itself as the root. Missing entries are logged and skipped.
        """
        cwd = Path.cwd().resolve()
        paths = [cwd.joinpath(p) for p in read_list_file(lst)]
        files = []
        for p in paths:
            if not p.exists():
                log.error('Path `%s` not found!', str(p))
            elif p.is_file():  # add file from test list unconditionally
                files.append(SrcPath(p, Path('.')))
            else:
                x = BenchGenerator.search_test_files_in_dir(p, p, ext, allowed_dir_name)
                files += x
        return files

    @staticmethod
    def search_test_files(paths: List[Path],
                          ext: Iterable[str] = (),
                          allowed_dir_name: Optional[str] = None) -> List[SrcPath]:
        """Collect all src files to gen process.

        Each path may be:
        * a ``.lst`` file, which is expanded via ``process_test_list``;
        * a source file, which is taken only if its suffix is in ``ext``;
        * a directory, which is searched recursively by extension.
        Non-existent paths are logged and skipped.

        Returns flat list of (Full, Relative) paths
        """
        log.debug('Searching sources: **/*%r', ext)
        files = []
        for d in paths:
            root = d.resolve()
            # if file name provided add it if suffix matches
            if root.is_file():
                if root.suffix == '.lst':
                    log.debug('Processing list file: %s', root)
                    files += BenchGenerator.process_test_list(root, ext, allowed_dir_name)
                    continue
                if root.suffix not in ext:
                    continue
                files.append(SrcPath(root, Path('.')))
            # in case of dir search by file extension
            elif root.is_dir():
                files += BenchGenerator.search_test_files_in_dir(d, root, ext, allowed_dir_name)
            else:
                log.warning('Src: %s not found!', root)
        return files

    @staticmethod
    def process_imports(lang_impl: LangBase, imports: Iterable[str],
                        bench_dir: Path, src: SrcPath) -> str:
        """Process @Import doclets and return the `import` statement(s).

        Each doclet is parsed by ``lang_impl.parse_import``. As used here,
        the result is (library path relative to the source dir, import
        statement text). Each existing library is linked into ``bench_dir``
        and its statement is appended to the result. Unparsable doclets and
        missing libraries are logged and skipped.
        """
        import_lines = ''
        for import_doclet in imports:
            m = lang_impl.parse_import(import_doclet)
            if not m:
                log.warning('Bad import: %s', import_doclet)
                continue
            libpath = src.full.parent.joinpath(m[0])
            if not libpath.is_file():
                log.warning('Lib does not exist: %s', libpath)
            else:
                force_link(bench_dir.joinpath(libpath.name), libpath)
                import_lines += m[1]
        return import_lines

    @staticmethod
    def check_common_files(full: Path, lang_name: str, includes: List[str]) -> str:
        """Return shared source text for the bench's ``common`` template value.

        The text is concatenated in this order:
        1. the contents of every whitespace-separated path listed in
           ``includes``, resolved relative to the source file's directory
           (missing ones are logged and skipped);
        2. only if the source sits in a directory named ``lang_name``: every
           ``*.<lang_name>`` file found in ``../common/<lang_name>/``
           (or, if that does not exist, ``../../common/<lang_name>/``)
           relative to the source file's directory.

        The common-dir lookup (2) is kept only for compatibility with
        existing tests.
        """
        src = ''
        include_paths = [f for sublist in [x.split() for x in includes] for f in sublist]
        if include_paths:
            log.trace("Includes: %s", ';'.join(include_paths))
            for inc in include_paths:
                include = full.parent.joinpath(inc)
                if not include.exists():
                    log.error('Include %s does not exist!', str(include))
                    continue
                with open(include, 'r', encoding="utf-8") as f:
                    src += f.read()
        if full.parent.name != lang_name:
            return src
        common = full.parent.parent.joinpath('common', lang_name)
        common = common if common.is_dir() else \
            full.parent.parent.parent.joinpath('common', lang_name)
        if common.is_dir():
            log.trace('Common dir: %s', common)
            for p in common.glob(f'*.{lang_name}'):
                log.trace('Common file: %s', p)
                with open(p, 'r', encoding="utf-8") as f:
                    src += f.read()
        return src

    @staticmethod
    def check_resources(full: Path, lang_name: str, dest: Path) -> bool:
        """Link ``../resources`` into ``dest`` if present.

        The lookup is only done when the source sits in a directory named
        ``lang_name``. Returns True if the resources link was created.
        """
        if full.parent.name != lang_name:
            return False
        resources = full.parent.parent.joinpath('resources')
        if resources.is_dir():
            log.trace('Resources: %s', resources)
            force_link(dest.joinpath('resources'), resources)
            return True
        return False

    @staticmethod
    def check_native(full: Path, dest: Path, values: Dict[str, Any]) -> bool:
        """Render the ``native`` directory next to the source into ``dest/native``.

        Every regular file in it is read as UTF-8 text, treated as a
        ``string.Template``, and substituted with ``values``. A placeholder
        missing from ``values`` makes ``substitute`` raise KeyError.
        Returns True if the native directory exists.
        """
        native = full.parent.joinpath('native')
        if not native.is_dir():
            return False
        log.debug('Native: %s', native)
        dest_dir = dest.joinpath('native')
        dest_dir.mkdir(parents=True, exist_ok=True)
        for f in native.glob('*'):
            if f.is_file():
                dest_file = dest_dir.joinpath(f.name)
                with open(f, 'r', encoding='utf-8') as t:
                    native_tpl = t.read()
                tpl = Template(native_tpl)
                with create_file(dest_file) as d:
                    d.write(tpl.substitute(values))
        return True

    @staticmethod
    def write_config(bench_dir: Path, values: TemplateVars):
        """Dump ``values.config`` as JSON into ``bench_dir/config.json``."""
        with create_file(bench_dir.joinpath('config.json')) as f:
            f.write(json.dumps(values.config))

    @staticmethod
    def process_generator(src_full: Path, bench_dir: Path,
                          values: TemplateVars, ext: str):
        """Run the custom generator script named by ``values.generator``.

        The script path is resolved relative to the source directory. It is
        invoked with the bench dir and the bench file name as arguments.
        """
        script = src_full.parent.joinpath(values.generator)
        # NOTE(review): arguments are not shell-quoted; paths containing
        # spaces may break depending on how ShellUnix.run executes `cmd`.
        cmd = f'{script} {bench_dir} bench_{values.bench_name}{ext}'
        log.trace('Test generator: %s', script)
        ShellUnix().run(cmd)

    @staticmethod
    def emit_bench_variant(values: TemplateVars,
                           template: Template,
                           lang_impl: LangBase,
                           src: SrcPath,
                           outdir: Path,
                           outext: str) -> BenchUnit:
        """Render one variant into ``outdir/<src.rel>/bu_<bench_name>/``.

        ``values`` is updated in place with the derived template fields:
        joined tags/bugs, method call, imports, and common code. The bench
        source ``bench_<bench_name><outext>`` is written, together with a
        ``config.json`` when needed, and the optional generator script is
        run. Returns the resulting BenchUnit.
        """
        log.trace('Bench Variant: %s @ %s',
                  values.bench_name, values.fixture)
        # create bench unit dir
        bench_dir = outdir.joinpath(src.rel, f'bu_{values.bench_name}')
        bench_dir.mkdir(parents=True, exist_ok=True)
        values.bench_path = str(bench_dir)
        # process template values
        tags = set(values.tags)
        # Sort so the joined strings (and therefore the generated files) are
        # identical across runs; set order depends on per-process hashing.
        values.tags = ';'.join(sorted(str(t) for t in tags))
        bugs = set(values.bugs)
        values.bugs = ';'.join(sorted(str(t) for t in bugs))
        values.method_call = lang_impl.get_method_call(
            values.method_name, values.method_rettype)
        values.imports = BenchGenerator.process_imports(
            lang_impl, values.imports, bench_dir, src)
        values.common = BenchGenerator.check_common_files(
            src.full, lang_impl.short_name, values.includes)
        tpl_values = asdict(values)
        # create links to extra dirs if any
        custom_values = {
            'resources': BenchGenerator.check_resources(
                src.full, lang_impl.short_name, bench_dir),
            'native': BenchGenerator.check_native(
                src.full, bench_dir, tpl_values)
        }
        tpl_values.update(
            lang_impl.get_custom_fields(tpl_values, custom_values))
        # fill template with values
        bench = template.substitute(tpl_values)
        bench_file = bench_dir.joinpath(f'bench_{values.bench_name}{outext}')
        log.trace('Bench: %s', bench_file)
        with create_file(bench_file) as f:
            f.write(bench)
        if values.generator or values.disable_inlining or values.aot_opts:
            BenchGenerator.write_config(bench_dir, values)
        if values.generator:
            BenchGenerator.process_generator(
                src.full, bench_dir, values, outext)
        return BenchUnit(bench_dir, src=src.full, tags=tags, bugs=bugs)

    @staticmethod
    def create_links(bu: BenchUnit, settings: Optional[GenSettings], src: SrcPath) -> None:
        """Link source files into the bench unit dir as requested by ``settings``.

        * ``link_to_src``: link the original source as
          ``<BENCH_PREFIX><bu.name><source suffix>``;
        * ``link_to_other_src``: link every file next to the source whose
          name ends with one of the listed patterns.
        Does nothing if ``settings`` is not given.
        """
        if not settings:
            return
        if settings.link_to_src:
            link = bu.path.joinpath(
                f'{BENCH_PREFIX}{bu.name}{Path(src.full).suffix}')
            force_link(link, src.full)
        for s in settings.link_to_other_src:
            for other in Path(src.full).parent.glob(f'*{s}'):
                force_link(bu.path.joinpath(other.name), other)

    def get_lang(self, lang: str) -> LangBase:
        """Load the language plugin ``lang`` (extra plugin dir included)."""
        lang_plugin = get_plugin('langs', lang, extra=self.extra_plug_dir)
        lang_impl: LangBase = lang_plugin.Lang()
        log.info('Using lang: %s', lang_impl.name)
        return lang_impl

    def get_template(self, name: str) -> Template:
        """Load template ``name`` from the first template dir containing it.

        Calls ``die`` if no template dir has it.
        """
        for d in self.template_dirs:
            template_path = d.joinpath(name)
            if not template_path.exists():
                continue
            log.debug('Using template: %s', template_path)
            with open(template_path, 'r', encoding="utf-8") as f:
                tpl = Template(f.read())
            return tpl
        die(True, f'Template {name} not found!')
        return Template('')  # make mypy happy

    def process_source_file(self, src: Path, lang: LangBase) -> Iterable[TemplateVars]:
        """Parse doclets in ``src`` and return the template variants it defines.

        Files without ``@Benchmark`` yield nothing. A parse error
        (ValueError) is logged; generation is aborted if ``abort_on_fail``
        is set, otherwise the file is skipped.
        """
        with open(src, 'r', encoding="utf-8") as f:
            full_src = f.read()
        if '@Benchmark' not in full_src:
            return []
        try:
            parser = DocletParser.create(full_src, lang).parse()
            if not parser.state:
                return []
        except ValueError as e:
            log.error('%s in %s', e, str(src))
            die(self.abort, 'Aborting on first error...')
            return []
        return TemplateVars.params_from_parsed(
            full_src, parser.state, args=self.args)

    def add_bu(self, bus: List[BenchUnit], template: Template,
               lang_impl: LangBase, src: SrcPath, variant: TemplateVars,
               settings: Optional[GenSettings], out_ext: str) -> None:
        """Emit one variant and append the resulting unit to ``bus``.

        Any failure is logged. Generation is aborted only if
        ``abort_on_fail`` is set.
        """
        try:
            bu = BenchGenerator.emit_bench_variant(
                variant, template, lang_impl, src, self.out_dir, out_ext)
            self.create_links(bu, settings, src)
            bus.append(bu)
        # pylint: disable-next=broad-exception-caught
        except Exception as e:
            log.error(e)
            die(self.abort, 'Aborting on first fail...')

    def generate(self, lang: str,
                 settings: Optional[GenSettings] = None) -> List[BenchUnit]:
        """Generate benchmark sources for requested language.

        Precedence for the source extensions is: command line > platform
        ``settings`` > language plugin defaults. Output extension and
        template name come from ``settings`` if given, else from the plugin.
        """
        bus: List[BenchUnit] = []
        lang_impl = self.get_lang(lang)
        src_ext = lang_impl.src
        out_ext = lang_impl.ext
        template_name = f'Template{lang_impl.ext}'
        if settings:  # override if set in platform
            src_ext = settings.src
            out_ext = settings.out
            template_name = settings.template
        if self.override_src_ext:  # override if set in cmdline
            src_ext = self.override_src_ext
        template = self.get_template(template_name)
        for src in BenchGenerator.search_test_files(self.paths, ext=src_ext):
            for variant in self.process_source_file(src.full, lang_impl):
                self.add_bu(bus, template, lang_impl, src,
                            variant, settings, out_ext)
        return bus
330        template = self.get_template(template_name)
331        for src in BenchGenerator.search_test_files(self.paths, ext=src_ext):
332            for variant in self.process_source_file(src.full, lang_impl):
333                self.add_bu(bus, template, lang_impl, src,
334                            variant, settings, out_ext)
335        return bus
336
337
338@log_time
339def generate_main(args: Args,
340                  settings: Optional[GenSettings] = None) -> List[BenchUnit]:
341    """Command: Generate benches from doclets."""
342    log.info("Starting GEN phase...")
343    log.trace("GEN phase args:  %s", args)
344    generator = BenchGenerator(args)
345    bus: List[BenchUnit] = []
346    for lang in args.langs:
347        bus += generator.generate(lang, settings=settings)
348    log.passed('Generated %d bench units', len(bus))
349    return bus
350
351
352if __name__ == '__main__':
353    generate_main(Args())
354