• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Copyright (c) 2025 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18from __future__ import annotations
19
20import re
21import logging
22import argparse
23from typing import List, Dict, Optional, Iterable, Any
24from itertools import combinations, product
25from dataclasses import dataclass, field
26from collections import namedtuple
27from vmb.lang import LangBase
28from vmb.helpers import StringEnum, split_params
29from vmb.cli import Args, add_measurement_opts
30
# Logger shared by the whole `vmb` package.
log = logging.getLogger('vmb')
# One `@Name value` doclet: `name` is the word after '@', `value` the rest of
# the line (None when the doclet has no argument).
NameVal = namedtuple("NameVal", ["name", "value"])
33
34
class LineParser:
    """Forward-only cursor over a list of text lines.

    ``pos`` is the index of the current line. Instances built with
    :meth:`create` start at -1, i.e. before the first line, so the first
    call to :meth:`next` lands on line 0.
    """

    def __init__(self, lines, pos, lang):
        self.lines = lines
        self.pos = pos
        self.lang = lang

    @property
    def current(self) -> str:
        """Line under the cursor, or '' when the cursor is out of range."""
        inside = 0 <= self.pos < len(self.lines)
        return self.lines[self.pos] if inside else ''

    @classmethod
    def create(cls, text: str, lang: LangBase):
        """Split ``text`` on newlines and position the cursor before line 0."""
        return cls(text.split("\n"), -1, lang)

    def next(self) -> None:
        """Move the cursor one line forward.

        The cursor is allowed to step exactly one position past the last
        line (where ``current`` is ''); any further step raises IndexError.
        """
        if self.pos >= len(self.lines):
            raise IndexError('No lines left')
        self.pos += 1

    def skip_empty(self) -> None:
        """Advance to the next line that is neither blank nor a `//` comment.

        At most 4 lines are examined; if none qualifies, the cursor stays on
        the 4th one. IndexError from :meth:`next` propagates to the caller.
        """
        remaining = 4
        while remaining:
            remaining -= 1
            self.next()
            stripped = self.current.strip()
            if stripped and not stripped.startswith('//'):
                return
63
64
class Doclet(StringEnum):
    """Names of the `@Name` annotations recognised inside doc comments."""

    # --- Lang-dependent ---
    STATE = "State"
    BENCHMARK = "Benchmark"
    PARAM = "Param"
    SETUP = "Setup"
    RETURNS = "returns"
    # --- Lang-agnostic ---
    IMPORT = "Import"
    INCLUDE = "Include"
    TAGS = "Tags"
    BUGS = "Bugs"
    GENERATOR = "Generator"  # legacy code generation

    @staticmethod
    def exclusive_doclets() -> Iterable[Doclet]:
        """Return the doclets that are treated as mutually exclusive.

        DocletParser.validate_comment rejects a comment that repeats one of
        these or combines two of them (except @State + @Benchmark).
        """
        return (Doclet.STATE, Doclet.BENCHMARK, Doclet.SETUP, Doclet.PARAM)
82
83
@dataclass
class BenchFunc:
    """One `@Benchmark` method discovered in the benchmark source."""

    # method name reported by the language parser
    name: str
    # return type from the declaration, possibly replaced by `@returns {T}`
    return_type: Optional[str] = None
    # measurement overrides given as `@Benchmark <options>`
    args: Optional[argparse.Namespace] = None
    # values of the @Tags / @Bugs doclets collected for this method
    tags: List[str] = field(default_factory=list)
    bugs: List[str] = field(default_factory=list)
91
92
@dataclass
class BenchClass:
    """The `@State` class of a benchmark source and everything attached to it."""

    # class name reported by the language parser
    name: str
    # name of the `@Setup` method, if any
    setup: Optional[str] = None
    # `@Param` field name -> list of values to iterate over
    params: Dict[str, List[str]] = field(default_factory=dict)
    # every `@Benchmark` method found in the class
    benches: List[BenchFunc] = field(default_factory=list)
    # class-wide measurement overrides (`@Benchmark <opts>` next to `@State`)
    bench_args: Optional[argparse.Namespace] = None
    # values of `@Import` / `@Include` doclets
    imports: List[str] = field(default_factory=list)
    includes: List[str] = field(default_factory=list)
    # class-level `@Tags` / `@Bugs`
    tags: List[str] = field(default_factory=list)
    bugs: List[str] = field(default_factory=list)
    # value of a `@Generator` doclet (legacy code generation)
    generator: Optional[str] = None
105
106
class DocletParser(LineParser):
    """Extract benchmark metadata from doclet comments.

    Scans the source line by line, gathers the ``@Name value`` lines of each
    block comment and, once the comment closes, applies them to the
    declaration that follows it (class, param field, setup or bench method).
    The result is accumulated in ``self.state``.
    """

    # Note: strictly this should be `/**` not just `/*`.
    # Both markers are only recognised when they are alone on their line.
    re_comment_open = re.compile(r'^\s*/\*(\*)?\s*$')
    re_comment_close = re.compile(r'^\s*\*/\s*$')
    # `@Name rest-of-line` -> (name, value); value is None when absent
    re_doclet = re.compile(r'@+(\w+)\s*(.+)?$')
    # word inside the first `{...}` of a @returns value, e.g. `{int}` -> int
    re_returns = re.compile(r'{+(\w+)}')

    def __init__(self, lines, pos, lang) -> None:
        super().__init__(lines, pos, lang)
        self.state: Optional[BenchClass] = None
        # parser for `@Benchmark <measurement options>` overrides
        self.opts_parser = argparse.ArgumentParser()
        add_measurement_opts(self.opts_parser)
        # doclets waiting for the declaration they will be attached to
        self.__pending_tags: List[str] = []
        self.__pending_bugs: List[str] = []
        self.__pending_imports: List[str] = []
        self.__pending_includes: List[str] = []

    @staticmethod
    def validate_comment(doclets: List[NameVal]) -> None:
        """Reject comments that combine exclusive doclets.

        Raises:
            ValueError: an exclusive doclet is repeated, or two different
                exclusive doclets share the comment (@State + @Benchmark is
                the only tolerated pair).
        """
        doclet_names = [v.name for v in doclets]
        for d in Doclet.exclusive_doclets():
            if doclet_names.count(d.value) > 1:
                # use .value explicitly, like the message below
                raise ValueError(f'Multiple @{d.value} doclets in same comment')
        for d1, d2 in combinations(Doclet.exclusive_doclets(), 2):
            # allow @State + @Benchmark for compatibility reasons
            if sorted((d1, d2)) == [Doclet.BENCHMARK, Doclet.STATE]:
                continue
            if d1.value in doclet_names and d2.value in doclet_names:
                raise ValueError(
                    f'@{d1.value} and @{d2.value} doclets in same comment')

    @staticmethod
    def ensure_value(val: Optional[str]) -> str:
        """Return ``val`` unchanged; raise ValueError if it is empty/None."""
        if not val:
            raise ValueError('Empty value!')
        return val

    @staticmethod
    def get_rettype(value: Optional[str]) -> str:
        """Return the word inside the first ``{...}`` of ``value``, or ''.

        A bare ``@returns`` (value None) yields '' instead of the TypeError
        ``re.search`` would raise on None.
        """
        if not value:
            return ''
        m = DocletParser.re_returns.search(value)
        return m.group(1) if m else ''

    def doclist(self, comment: str) -> List[NameVal]:
        """Return the (name, value) pairs of every doclet line in ``comment``."""
        matches = (self.re_doclet.search(line) for line in comment.split("\n"))
        return [NameVal(*m.groups()) for m in matches if m]

    def ensure_state(self) -> BenchClass:
        """Return the current @State class; raise ValueError if none yet."""
        if not self.state:
            raise ValueError('No state found!')
        return self.state

    def parse_bench_overrides(self, line: str) -> Optional[argparse.Namespace]:
        """Parse @Benchmark options.

        Returns None when ``line`` is empty/None.

        Raises:
            ValueError: ``line`` contains arguments the parser does not know.
        """
        # NOTE(review): argparse reports malformed option values by exiting
        # the process (ArgumentParser.error), not by raising ValueError.
        overrides = None
        if line:
            overrides, unknown = \
                self.opts_parser.parse_known_args(line.split())
            if unknown:
                raise ValueError(f'Unknown arg to @Benchmark: {unknown}')
        return overrides

    def process_state(self, benchmarks: List[NameVal],
                      generators: List[NameVal]) -> None:
        """Start a new BenchClass from the class declared after the comment.

        All pending tags/bugs/imports/includes move into the new state.
        ``benchmarks`` holds @Benchmark doclets from the same comment (their
        options become class-wide overrides); ``generators`` @Generator ones.
        """
        self.skip_empty()
        class_name = self.lang.parse_state(self.current)
        if not class_name:
            raise ValueError('Bench class declaration not found!')
        self.state = BenchClass(name=class_name,
                                tags=self.__pending_tags, bugs=self.__pending_bugs,
                                imports=self.__pending_imports, includes=self.__pending_includes)
        self.__pending_tags, self.__pending_bugs = [], []
        self.__pending_imports, self.__pending_includes = [], []
        # check if there are overrides for whole class
        for _, value in benchmarks:
            self.state.bench_args = self.parse_bench_overrides(value)
        if generators:
            self.state.generator = generators[0].value

    def process_benchmark(self, value: str, returns: List[NameVal]) -> None:
        """Register the method declared after the comment as a benchmark.

        Args:
            value: text after @Benchmark (measurement overrides), may be None.
            returns: @returns doclets; a type in the first one's ``{...}``
                replaces the return type reported by the language parser.
        """
        self.skip_empty()
        f = self.lang.parse_func(self.current)
        if not f:
            raise ValueError('Bench func declaration not found!')
        overrides = self.parse_bench_overrides(value)
        # f[0] is used as the method name, f[1] as its default return type
        ret_type = f[1]
        if returns:
            typ = self.get_rettype(returns[0].value)
            if typ:
                ret_type = typ
        self.ensure_state().benches.append(
            BenchFunc(name=f[0], return_type=ret_type, args=overrides,
                      tags=self.__pending_tags, bugs=self.__pending_bugs))
        self.__pending_tags, self.__pending_bugs = [], []

    def process_param(self, param_values: str) -> None:
        """Record the values of the @Param field declared after the comment."""
        self.skip_empty()
        p = self.lang.parse_param(self.current)
        if not p:
            raise ValueError('Param declaration not found!')
        self.ensure_state().params[p[0]] = \
            split_params(self.ensure_value(param_values))

    def process_setup(self) -> None:
        """Record the method declared after the comment as the @Setup one."""
        self.skip_empty()
        f = self.lang.parse_func(self.current)
        if not f:
            raise ValueError('Setup func declaration not found!')
        self.ensure_state().setup = f[0]

    def process_tag(self, value: str, states: List[NameVal]) -> None:  # pylint: disable=unused-argument
        """Queue @Tags values for the next @State or @Benchmark declaration.

        ``states`` is kept for API compatibility only. Tags that share a
        comment with @State are picked up by process_state(), which
        parse_comment() calls right afterwards; adding them to an already
        existing state would put them on the previous class instead.
        """
        self.__pending_tags += split_params(value)

    def process_bug(self, value: str, states: List[NameVal]) -> None:  # pylint: disable=unused-argument
        """Queue @Bugs values for the next @State or @Benchmark declaration.

        See process_tag() for why ``states`` is no longer consulted.
        """
        self.__pending_bugs += split_params(value)

    def parse_comment(self, comment: str) -> None:
        """Process all the @Stuff in multiline comment.

        Assuming only one exclusive doclet in same comment,
        except for @State + @Benchmark (which is allowed by legacy).
        @Tags, @Bugs, @Import and @Include can co-exist with other @Stuff.
        """
        doclets = self.doclist(comment)
        self.validate_comment(doclets)

        def filter_doclets(t: Doclet) -> List[NameVal]:
            return [x for x in doclets if x.name == t.value]

        states = filter_doclets(Doclet.STATE)[:1]
        benchmarks = filter_doclets(Doclet.BENCHMARK)[:1]
        generators = filter_doclets(Doclet.GENERATOR)[:1]
        for _, value in filter_doclets(Doclet.TAGS):
            self.process_tag(value, states)
        for _, value in filter_doclets(Doclet.BUGS):
            self.process_bug(value, states)
        # Imports/includes go straight to the existing state, unless this
        # comment declares a new @State: then they belong to the new class.
        for _, value in filter_doclets(Doclet.IMPORT):
            value = self.ensure_value(value)
            if self.state and not states:
                self.state.imports.append(value)
            else:
                self.__pending_imports.append(value)
        for _, value in filter_doclets(Doclet.INCLUDE):
            # a bare @Include used to be stored as the string 'None'
            value = self.ensure_value(value).strip("\'\"")
            if self.state and not states:
                self.state.includes.append(value)
            else:
                self.__pending_includes.append(value)
        # exactly one declaration handler runs, in this priority order
        if states:
            self.process_state(benchmarks, generators)
            return
        params = filter_doclets(Doclet.PARAM)
        if params:
            self.process_param(params[0].value)
            return
        if filter_doclets(Doclet.SETUP):
            self.process_setup()
            return
        if benchmarks:
            self.process_benchmark(benchmarks[0].value,
                                   filter_doclets(Doclet.RETURNS))

    def parse(self) -> DocletParser:
        """Search and parse doclet comments; return ``self``.

        Only lines containing '@' are kept from a comment body. Scanning
        stops at the first IndexError, normally raised by next() at the
        end of the input.
        """
        comment = ''
        try:
            while True:
                self.next()
                line = self.current
                if not comment:
                    # outside a comment: only look for an opening marker
                    if self.re_comment_open.search(line):
                        comment = "\n"
                    continue
                if self.re_comment_close.search(line):
                    self.parse_comment(comment)
                    comment = ''
                elif '@' in line:
                    comment += "\n" + line
        except IndexError:
            pass
        return self
301
302
@dataclass
class TemplateVars:  # pylint: disable=invalid-name
    """Values substituted into a benchmark template.

    Each attribute is named after the template variable it fills, which is
    why an instance can be handed to the Template as a dict.
    """

    # complete benchmark source, pasted into the template verbatim
    src: str
    # name of the @State (main) class
    state_name: str = ''
    # setup call such as `bench.SomeMethod();`
    state_setup: str = ''
    # `bench.param1 = 5;` statements, one per param, newline-joined
    state_params: str = ''
    # `param1=5;param2=7` for this fixture
    fixture: str = ''
    fix_id: int = 0
    # name of the benchmarked method
    method_name: str = ''
    method_rettype: str = ''
    method_call: str = ''
    bench_name: str = ''
    bench_path: str = ''
    common: str = ''  # obsolete "common" feature
    # Measurement settings; keep their defaults here and nowhere else.
    mi: int = 3   # measure_iters
    wi: int = 2   # warmup_iters
    it: int = 1   # iter_time
    wt: int = 1   # warmup_time
    fi: int = 0   # fast_iters
    gc: int = -1  # sys_gc_pause
    tags: Any = None
    bugs: Any = None
    imports: Any = None
    includes: Any = None
    generator: str = ''
    config: Dict[str, Any] = field(default_factory=dict)
    aot_opts: str = ''
    disable_inlining: bool = False

    @classmethod
    def params_from_parsed(cls,
                           src: str,
                           parsed: BenchClass,
                           args: Optional[Args] = None
                           ) -> Iterable[TemplateVars]:
        """Yield one instance per (benchmark method, param fixture) pair.

        A fixture is one element of the cross product of all @Param values
        (params ordered by name, so fixture numbers are stable). Benches are
        dropped by ``args.skip_tags`` / ``args.tags``; individual fixtures by
        ``args.tests`` (substring match on bench_name). Fixture numbers still
        advance for filtered-out fixtures.
        """
        if args:
            wanted_tags, wanted_tests, unwanted_tags = \
                args.tags, args.tests, args.skip_tags
        else:
            wanted_tags, wanted_tests, unwanted_tags = [], [], set()
        # one list of (name, value) pairs per param, sorted by param name
        per_param = [[(name, val) for val in values]
                     for name, values in sorted(parsed.params.items())]
        # without params this is [()]: a single, empty fixture
        all_fixtures = list(product(*per_param))
        setup_call = f'bench.{parsed.setup}();' if parsed.setup else ''
        generator = parsed.generator or ''
        for bench in parsed.benches:
            # @State::@Tags + @Bench::@Tags
            bench_tags = set(parsed.tags + bench.tags)
            if unwanted_tags and bench_tags.intersection(unwanted_tags):
                log.trace("`%s` skipped by tags: Unwanted: %s Tagged: %s", bench.name, unwanted_tags, bench_tags)
                continue
            if wanted_tags and not bench_tags.intersection(wanted_tags):
                log.trace("`%s` skipped by tags: Wanted: %s Tagged: %s", bench.name, wanted_tags, bench_tags)
                continue
            for fix_id, fixture in enumerate(all_fixtures):
                bench_name = f'{parsed.name}_{bench.name}'
                if fix_id > 0:
                    bench_name = f'{bench_name}_{fix_id}'
                if wanted_tests and \
                        not any(t in bench_name for t in wanted_tests):
                    continue
                fix_str = ';'.join(f'{name}={val}' for name, val in fixture)
                result = cls(
                    src,
                    parsed.name,
                    state_setup=setup_call,
                    # strictly speaking, this should be lang specific
                    state_params="\n    ".join(
                        f'bench.{name} = {val};' for name, val in fixture),
                    fixture=fix_str or 'No Params',
                    fix_id=fix_id,
                    method_name=bench.name,
                    method_rettype=bench.return_type or '',
                    bench_name=bench_name,
                    tags=bench_tags,
                    bugs=set(parsed.bugs + bench.bugs),
                    imports=parsed.imports,
                    includes=parsed.includes,
                    generator=generator,
                    config=dict(fixture),
                )
                # Defaults -> CmdLine -> Class -> Bench (later ones win)
                result.set_measure_overrides(args, parsed.bench_args, bench.args)
                yield result

    def set_measure_overrides(self, *overrides) -> None:
        """Apply measurement overrides in order; later ones take precedence.

        Falsy overrides are skipped, as are option values that are None.
        ``compiler_inlining == 'false'`` disables inlining, and
        ``aot_compiler_options`` are appended to ``aot_opts``; both are
        mirrored into ``config``.
        """
        option_to_field = (
            ('measure_iters', 'mi'),
            ('warmup_iters', 'wi'),
            ('iter_time', 'it'),
            ('warmup_time', 'wt'),
            ('fast_iters', 'fi'),
            ('sys_gc_pause', 'gc'),
        )
        for ovr in filter(None, overrides):
            for option, attr in option_to_field:
                setting = getattr(ovr, option)
                if setting is not None:
                    setattr(self, attr, setting)
            if ovr.compiler_inlining == 'false':
                self.disable_inlining = True
                self.config['disable_inlining'] = True
            if ovr.aot_compiler_options:
                extra = ' '.join(ovr.aot_compiler_options)
                self.aot_opts = f'{self.aot_opts} {extra} '
                self.config['aot_opts'] = self.aot_opts
431