• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Copyright (c) 2024 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import logging
19import json
20from typing import List, Iterable, Set, Optional
21from pathlib import Path
22from shutil import rmtree
23from string import Template
24from dataclasses import asdict
25from collections import namedtuple
26from vmb.helpers import get_plugin, read_list_file, \
27    log_time, create_file, die, force_link
28from vmb.unit import BenchUnit, BENCH_PREFIX
29from vmb.cli import Args
30from vmb.lang import LangBase
31from vmb.doclet import DocletParser, TemplateVars
32from vmb.gensettings import GenSettings
33from vmb.shell import ShellUnix
34
# Source location pair used throughout generation:
#   full - resolved path of the source file
#   rel  - directory of the source relative to the search root
#          (Path('.') when the file was given explicitly)
SrcPath = namedtuple("SrcPath", ["full", "rel"])
# Shared 'vmb' logger
log = logging.getLogger("vmb")
37
38
class BenchGenerator:
    """Generate bench units from doclet-annotated benchmark sources.

    For every source file found under the requested paths the doclets are
    parsed, and for each resulting variant a `bu_<name>` directory holding
    `bench_<name><ext>` (rendered from a language template) is created
    under `<outdir>/benches`.

    Attributes:
        args: full command line options (passed on to TemplateVars).
        paths: files, dirs or `.lst` files to search sources in.
        out_dir: resolved `<outdir>/benches` directory.
        override_src_ext: source extensions set in cmdline (if any).
        extra_plug_dir: optional dir with extra plugins and templates.
        template_dirs: dirs to look templates up in (extra one goes first).
        abort: flag passed to die() on generation errors.
    """

    def __init__(self, args: Args) -> None:
        """Store options and remove output left from a previous run.

        NOTE: an existing `<outdir>/benches` directory is deleted here.
        """
        self.args = args  # need to keep full cmdline for measure overrides
        self.paths: List[Path] = args.paths
        self.out_dir: Path = Path(args.outdir).joinpath('benches').resolve()
        self.override_src_ext: Set[str] = args.src_langs
        self.extra_plug_dir = args.extra_plugins
        self.template_dirs: List[Path] = [
            p.joinpath('templates')
            # extra first (if exists)
            for p in (self.extra_plug_dir, Path(__file__).parent) if p]
        self.abort = args.abort_on_fail
        if self.out_dir.is_dir():
            rmtree(str(self.out_dir))

    @staticmethod
    def search_test_files_in_dir(d: Path,
                                 root: Path,
                                 ext: Iterable[str] = ()) -> List[SrcPath]:
        """Recursively collect source files with given suffixes under `d`.

        Anything located two levels below a dir named 'common'
        (i.e. common/<any>/<file>) is skipped.

        Args:
            d: directory to search in.
            root: directory the relative part of SrcPath is computed from.
            ext: accepted file suffixes (with leading dot).

        Returns:
            List of SrcPath(resolved file, file dir relative to root).

        Raises:
            ValueError: if a resolved source is not located under root.
        """
        # Sources are resolved below, so root has to be resolved as well:
        # otherwise relative_to() fails for roots with '..' or symlinks
        # (e.g. entries of a list file joined to cwd).
        root = root.resolve()
        files = []
        for p in d.glob('**/*'):
            if p.parent.parent.name == 'common':
                continue
            if not p.suffix or p.suffix not in ext:
                continue
            # glob yields directories too: 'foo.sts/' is not a source
            if not p.is_file():
                continue
            log.trace('Src: %s', str(p))
            full = p.resolve()
            files.append(
                SrcPath(full, full.parent.relative_to(root)))
        return files

    @staticmethod
    def process_test_list(lst: Path, ext: Iterable[str] = ()) -> List[SrcPath]:
        """Collect sources from every directory named in a list file.

        Entries are joined to the current working directory
        and each of them is searched as its own root.
        """
        cwd = Path.cwd().resolve()
        files = []
        for entry in read_list_file(lst):
            p = cwd.joinpath(entry)
            files += BenchGenerator.search_test_files_in_dir(p, p, ext)
        return files

    @staticmethod
    def search_test_files(paths: List[Path],
                          ext: Iterable[str] = ()) -> List[SrcPath]:
        """Collect all src files to gen process.

        Each path could be a `.lst` file (processed as list of dirs),
        a single source file (taken if its suffix is in `ext`)
        or a directory (searched recursively).
        Missing paths are reported with a warning and skipped.

        Returns flat list of (Full, Relative) paths
        """
        log.debug('Searching sources: **/*%r', ext)
        files = []
        for d in paths:
            root = d.resolve()
            # if file name provided add it if suffix matches
            if root.is_file():
                if '.lst' == root.suffix:
                    log.debug('Processing list file: %s', root)
                    files += BenchGenerator.process_test_list(root, ext)
                    continue
                if root.suffix not in ext:
                    continue
                files.append(SrcPath(root, Path('.')))
            # in case of dir search by file extension
            elif root.is_dir():
                files += BenchGenerator.search_test_files_in_dir(d, root, ext)
            else:
                log.warning('Src: %s not found!', root)
        return files

    @staticmethod
    def process_imports(lang_impl: LangBase, imports: Iterable[str],
                        bench_dir: Path, src: SrcPath) -> str:
        """Process @Import and return `import` statement(s).

        Each doclet is parsed by `lang_impl.parse_import`; the result is
        used as (lib path relative to the source dir, import statement).
        Existing libs are linked into `bench_dir`; bad doclets and
        missing libs are only warned about.

        Returns:
            Concatenated import statements of all the linked libs.
        """
        import_lines = []
        for import_doclet in imports:
            m = lang_impl.parse_import(import_doclet)
            if not m:
                log.warning('Bad import: %s', import_doclet)
                continue
            libpath = src.full.parent.joinpath(m[0])
            if not libpath.is_file():
                log.warning('Lib does not exist: %s', libpath)
                continue
            force_link(bench_dir.joinpath(libpath.name), libpath)
            import_lines.append(m[1])
        return ''.join(import_lines)

    @staticmethod
    def check_common_files(full: Path, lang_name: str) -> str:
        """Check if there is 'common' code at ../common/<lang>/*.<lang>.

        Applies only to sources residing in a dir named after the lang.
        If ../common/<lang> is missing, ../../common/<lang> is tried.

        This feature is actually meaningless now
        and added only for the compatibility with existing tests

        Returns:
            Concatenated content of the common files or empty string.
        """
        if full.parent.name != lang_name:
            return ''
        common = full.parent.parent.joinpath('common', lang_name)
        if not common.is_dir():
            common = full.parent.parent.parent.joinpath('common', lang_name)
        if not common.is_dir():
            return ''
        log.trace('Common dir: %s', common)
        chunks = []
        # glob order is arbitrary: sort to get the same source on each run
        for p in sorted(common.glob(f'*.{lang_name}')):
            log.trace('Common file: %s', p)
            chunks.append(p.read_text(encoding="utf-8"))
        return ''.join(chunks)

    @staticmethod
    def check_resources(full: Path, lang_name: str, dest: Path) -> bool:
        """Check 'resources' dir at ../resources and link it to destdir.

        Applies only to sources residing in a dir named after the lang.

        Returns:
            True if the dir was found and linked as `dest`/resources.
        """
        if full.parent.name != lang_name:
            return False
        resources = full.parent.parent.joinpath('resources')
        if resources.is_dir():
            log.trace('Resources: %s', resources)
            force_link(dest.joinpath('resources'), resources)
            return True
        return False

    @staticmethod
    def check_native(full: Path, dest: Path) -> bool:
        """Check 'native' near the source and link to destdir.

        Returns:
            True if the dir was found and linked as `dest`/native.
        """
        native = full.parent.joinpath('native')
        if native.is_dir():
            log.trace('Native: %s', native)
            force_link(dest.joinpath('native'), native)
            return True
        return False

    @staticmethod
    def write_config(bench_dir: Path, values: TemplateVars):
        """Dump `values.config` as json into `bench_dir`/config.json."""
        with create_file(bench_dir.joinpath('config.json')) as f:
            f.write(json.dumps(values.config))

    @staticmethod
    def process_generator(src_full: Path, bench_dir: Path,
                          values: TemplateVars, ext: str):
        """Run the generator script declared for the bench.

        The script is looked up near the source and is invoked as
        `<script> <bench_dir> bench_<bench_name><ext>`.
        """
        script = src_full.parent.joinpath(values.generator)
        # NOTE(review): the command is a plain unquoted string, so paths
        # with spaces or shell metacharacters will be split by the shell.
        cmd = f'{script} {bench_dir} bench_{values.bench_name}{ext}'
        log.trace('Test generator: %s', script)
        ShellUnix().run(cmd)

    @staticmethod
    def emit_bench_variant(values: TemplateVars,
                           template: Template,
                           lang_impl: LangBase,
                           src: SrcPath,
                           outdir: Path,
                           outext: str) -> BenchUnit:
        """Create bench unit dir and write the bench source into it.

        `values` is updated in place (bench_path, tags, bugs, method_call,
        imports, common) and then substituted into the template.

        Args:
            values: template variables of one bench variant.
            template: language template to fill.
            lang_impl: language plugin.
            src: source the variant comes from.
            outdir: root of generated benches.
            outext: extension of the generated bench file.

        Returns:
            BenchUnit for `outdir`/<src.rel>/bu_<bench_name>.

        Raises:
            KeyError: (from Template.substitute) if the template refers
                to a variable which is not provided.
        """
        log.trace('Bench Variant: %s @ %s',
                  values.bench_name, values.fixture)
        # create bench unit dir
        bench_dir = outdir.joinpath(src.rel, f'bu_{values.bench_name}')
        bench_dir.mkdir(parents=True, exist_ok=True)
        values.bench_path = str(bench_dir)
        # process template values
        # (sorted, as set order of strings differs from run to run
        #  and generated source should be reproducible)
        tags = set(values.tags)
        values.tags = ';'.join(sorted(str(t) for t in tags))
        bugs = set(values.bugs)
        values.bugs = ';'.join(sorted(str(t) for t in bugs))
        values.method_call = lang_impl.get_method_call(
            values.method_name, values.method_rettype)
        values.imports = BenchGenerator.process_imports(
            lang_impl, values.imports, bench_dir, src)
        values.common = BenchGenerator.check_common_files(
            src.full, lang_impl.short_name)
        # create links to extra dirs if any
        custom_values = {
            'resources': BenchGenerator.check_resources(
                src.full, lang_impl.short_name, bench_dir),
            'native': BenchGenerator.check_native(
                src.full, bench_dir)
        }
        tpl_values = asdict(values)
        tpl_values.update(
            lang_impl.get_custom_fields(tpl_values, custom_values))
        # fill template with values
        bench = template.substitute(tpl_values)
        bench_file = bench_dir.joinpath(f'bench_{values.bench_name}{outext}')
        log.trace('Bench: %s', bench_file)
        with create_file(bench_file) as f:
            f.write(bench)
        # config.json is needed only if some of these options are set
        if values.generator or values.disable_inlining or values.aot_opts:
            BenchGenerator.write_config(bench_dir, values)
        if values.generator:
            BenchGenerator.process_generator(
                src.full, bench_dir, values, outext)
        return BenchUnit(bench_dir, src=src.full, tags=tags, bugs=bugs)

    def get_lang(self, lang: str) -> LangBase:
        """Load lang plugin (extra plugin dir included) and instantiate it."""
        lang_plugin = get_plugin('langs', lang, extra=self.extra_plug_dir)
        lang_impl: LangBase = lang_plugin.Lang()
        log.info('Using lang: %s', lang_impl.name)
        return lang_impl

    def get_template(self, name: str) -> Template:
        """Read the first template file `name` found in template dirs.

        Raises:
            RuntimeError: if there is no such file in any template dir.
        """
        for d in self.template_dirs:
            template_path = d.joinpath(name)
            # is_file() rather than exists(): a directory with the same
            # name should not stop the search with an open() failure
            if not template_path.is_file():
                continue
            log.debug('Using template: %s', template_path)
            with open(template_path, 'r', encoding="utf-8") as f:
                tpl = Template(f.read())
            return tpl
        raise RuntimeError(f'Template {name} not found!')

    def process_source_file(self, src: Path, lang: LangBase
                            ) -> Iterable[TemplateVars]:
        """Parse doclets of one source file into bench variants.

        Sources without '@Benchmark' text are skipped without parsing.
        On ValueError from the parser the error is logged and die() is
        called with the abort flag (presumably it exits if the flag is
        set - see vmb.helpers); otherwise the file yields no variants.

        Returns:
            Template variables, one item per bench variant (could be empty).
        """
        with open(src, 'r', encoding="utf-8") as f:
            full_src = f.read()
        if '@Benchmark' not in full_src:
            return []
        try:
            parser = DocletParser.create(full_src, lang).parse()
            if not parser.state:
                return []
        except ValueError as e:
            log.error('%s in %s', e, str(src))
            die(self.abort, 'Aborting on first error...')
            return []
        return TemplateVars.params_from_parsed(
            full_src, parser.state, args=self.args)

    def add_bu(self, bus: List[BenchUnit], template: Template,
               lang_impl: LangBase, src: SrcPath, variant: TemplateVars,
               settings: Optional[GenSettings], out_ext: str) -> None:
        """Emit one bench variant and append resulting unit to `bus`.

        If `settings.link_to_src` is set, the original source is also
        linked into the unit dir as `<BENCH_PREFIX><name><src suffix>`.
        Any failure is logged and passed to die() with the abort flag,
        so by default a broken variant does not stop the generation.
        """
        try:
            bu = BenchGenerator.emit_bench_variant(
                variant, template, lang_impl, src, self.out_dir, out_ext)
            if settings and settings.link_to_src:
                link = bu.path.joinpath(
                    f'{BENCH_PREFIX}{bu.name}{Path(src.full).suffix}')
                force_link(link, src.full)
            bus.append(bu)
        # pylint: disable-next=broad-exception-caught
        except Exception as e:
            log.error(e)
            die(self.abort, 'Aborting on first fail...')

    def generate(self, lang: str,
                 settings: Optional[GenSettings] = None) -> List[BenchUnit]:
        """Generate benchmark sources for requested language.

        Source extensions, output extension and template name come from
        the lang plugin, then are overridden by platform `settings`
        (if provided); source extensions set in cmdline win over both.

        Returns:
            List of successfully generated bench units.

        Raises:
            RuntimeError: if the template is not found.
        """
        bus: List[BenchUnit] = []
        lang_impl = self.get_lang(lang)
        src_ext = lang_impl.src
        out_ext = lang_impl.ext
        template_name = f'Template{lang_impl.ext}'
        if settings:  # override if set in platform
            src_ext = settings.src
            out_ext = settings.out
            template_name = settings.template
        if self.override_src_ext:  # override if set in cmdline
            src_ext = self.override_src_ext
        template = self.get_template(template_name)
        for src in BenchGenerator.search_test_files(self.paths, ext=src_ext):
            for variant in self.process_source_file(src.full, lang_impl):
                self.add_bu(bus, template, lang_impl, src,
                            variant, settings, out_ext)
        return bus
298
299
@log_time
def generate_main(args: Args,
                  settings: Optional[GenSettings] = None) -> List[BenchUnit]:
    """Command: Generate benches from doclets.

    Runs the generator once per each lang of `args.langs`
    (with the same optional platform `settings`).

    Returns:
        Bench units generated for all the langs, in order of `args.langs`.
    """
    log.info("Starting GEN phase...")
    bench_gen = BenchGenerator(args)
    units: List[BenchUnit] = []
    for lang_name in args.langs:
        units.extend(bench_gen.generate(lang_name, settings=settings))
    log.passed('Generated %d bench units', len(units))
    return units
311
312
# Script entry point: generate benches using default-constructed Args
# (presumably parsed from the command line - see vmb.cli)
if __name__ == "__main__":
    cli_args = Args()
    generate_main(cli_args)
315