• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
17
18from __future__ import annotations
19
20import re
21import logging
22import argparse
23from typing import List, Dict, Optional, Iterable, Any
24from itertools import combinations, product
25from dataclasses import dataclass, field
26from collections import namedtuple
27from vmb.lang import LangBase
28from vmb.helpers import StringEnum, split_params
29from vmb.cli import Args, add_measurement_opts
30
# Shared 'vmb' logger
log = logging.getLogger('vmb')
# One parsed `@Name value` doclet: `name` is the word after '@',
# `value` the rest of the line (None when nothing follows)
NameVal = namedtuple("NameVal", ["name", "value"])
33
34
class LineParser:
    """Forward-only cursor over the lines of a source text.

    `pos` is the index of the current line. Parsers built with `create`
    start at -1 (before the first line), so the first `next()` lands on
    line 0. `lang` is kept for subclasses that parse declarations.
    """

    def __init__(self, lines, pos, lang):
        self.lines = lines
        self.pos = pos
        self.lang = lang

    @property
    def current(self) -> str:
        """Return the current line, or '' when the cursor is out of range."""
        if 0 <= self.pos < len(self.lines):
            return self.lines[self.pos]
        return ''

    @classmethod
    def create(cls, text: str, lang: LangBase):
        """Build a parser over `text` split into lines, positioned before line 0."""
        return cls(text.split("\n"), -1, lang)

    def next(self) -> None:
        """Advance the cursor by one line.

        Raises:
            IndexError: if the cursor is already on the last line.
        """
        # Index len(lines) does not exist, so stop before moving onto it.
        if self.pos + 1 >= len(self.lines):
            raise IndexError('No lines left')
        self.pos += 1

    def skip_empty(self, max_lines: int = 4) -> None:
        """Advance to the next line that is neither blank nor a `//` comment.

        At most `max_lines` lines are consumed. If none of them qualifies,
        the cursor stays on the last line consumed.

        Raises:
            IndexError: if the end of the text is reached first.
        """
        for _ in range(max_lines):
            self.next()
            line = self.current.strip()
            if line and not line.startswith('//'):
                break
63
64
class Doclet(StringEnum):
    """Doclet names recognised in comments, written there as `@Name`."""

    # Lang-dependent doclets
    STATE = "State"
    BENCHMARK = "Benchmark"
    PARAM = "Param"
    SETUP = "Setup"
    RETURNS = "returns"
    # Lang-agnostic doclets
    IMPORT = "Import"
    TAGS = "Tags"
    BUGS = "Bugs"
    GENERATOR = "Generator"  # Legacy code generation

    @staticmethod
    def exclusive_doclets() -> Iterable[Doclet]:
        """Return the doclets that may not share one comment with each other.

        DocletParser.validate_comment makes a single exception for the
        legacy @State + @Benchmark pairing.
        """
        exclusive = (
            Doclet.STATE,
            Doclet.BENCHMARK,
            Doclet.SETUP,
            Doclet.PARAM,
        )
        return exclusive
81
82
@dataclass
class BenchFunc:
    """One benchmark method, declared under a @Benchmark doclet."""

    # method name taken from the parsed declaration
    name: str
    # declared return type, or the type from an `@returns {Type}` doclet
    return_type: Optional[str] = field(default=None)
    # measurement overrides given on the @Benchmark line (None if absent)
    args: Optional[argparse.Namespace] = field(default=None)
    # @Tags / @Bugs collected for this method
    tags: List[str] = field(default_factory=lambda: [])
    bugs: List[str] = field(default_factory=lambda: [])
90
91
@dataclass
class BenchClass:
    """A bench class declared under a @State doclet, with its members."""

    # class name taken from the parsed declaration
    name: str
    # name of the @Setup method, if any
    setup: Optional[str] = field(default=None)
    # @Param field name -> list of values to try
    params: Dict[str, List[str]] = field(default_factory=lambda: {})
    # @Benchmark methods in declaration order
    benches: List[BenchFunc] = field(default_factory=lambda: [])
    # class-wide measurement overrides from a @Benchmark on the @State comment
    bench_args: Optional[argparse.Namespace] = field(default=None)
    # @Import / @Tags / @Bugs values for the whole class
    imports: List[str] = field(default_factory=lambda: [])
    tags: List[str] = field(default_factory=lambda: [])
    bugs: List[str] = field(default_factory=lambda: [])
    # first @Generator value (legacy code generation)
    generator: Optional[str] = field(default=None)
103
104
class DocletParser(LineParser):
    """Collect benchmark metadata from doclet comments.

    `parse()` walks the text line by line. Inside a comment opened by a
    line holding only `/*` or `/**`, it keeps the lines containing '@'.
    When the closing `*/` line is reached, the doclets found are
    dispatched by `parse_comment`. The result accumulates in `self.state`.
    """

    # Note: strictly this should be `/**` not just `/*`
    re_comment_open = re.compile(r'^\s*/\*(\*)?\s*$')
    re_comment_close = re.compile(r'^\s*\*/\s*$')
    re_doclet = re.compile(r'@+(\w+)\s*(.+)?$')
    re_returns = re.compile(r'{+(\w+)}')

    def __init__(self, lines, pos, lang) -> None:
        super().__init__(lines, pos, lang)
        self.state: Optional[BenchClass] = None
        self.opts_parser = argparse.ArgumentParser()
        add_measurement_opts(self.opts_parser)
        # Tags/bugs/imports seen before the declaration that will own them
        self.__pending_tags: List[str] = []
        self.__pending_bugs: List[str] = []
        self.__pending_imports: List[str] = []

    @staticmethod
    def validate_comment(doclets: List[NameVal]) -> None:
        """Reject comments that combine mutually exclusive doclets.

        Raises:
            ValueError: if an exclusive doclet appears more than once, or
                two different exclusive doclets appear together (except
                the legacy @State + @Benchmark pair).
        """
        doclet_names = [v.name for v in doclets]
        # Materialize once: the pairs loop below iterates it again.
        exclusive = tuple(Doclet.exclusive_doclets())
        for d in exclusive:
            if doclet_names.count(d.value) > 1:
                raise ValueError(
                    f'Multiple @{d.value} doclets in same comment')
        legacy_pair = {Doclet.BENCHMARK, Doclet.STATE}
        for d1, d2 in combinations(exclusive, 2):
            # allow @State + @Benchmark for compatibility reasons
            if {d1, d2} == legacy_pair:
                continue
            if d1.value in doclet_names and d2.value in doclet_names:
                raise ValueError(
                    f'@{d1.value} and @{d2.value} doclets in same comment')

    @staticmethod
    def ensure_value(val: Optional[str]) -> str:
        """Return `val`, raising ValueError when it is None or empty."""
        if not val:
            raise ValueError('Empty value!')
        return val

    @staticmethod
    def get_rettype(value: str) -> str:
        """Extract `Type` from an `@returns {Type}` value ('' if absent)."""
        m = DocletParser.re_returns.search(value)
        if m:
            return m.group(1)
        return ''

    def doclist(self, comment: str) -> List[NameVal]:
        """Return a NameVal(name, value) for every `@name value` line."""
        doclets = []
        for line in comment.split("\n"):
            m = self.re_doclet.search(line)
            if m:
                doclets.append(NameVal(*m.groups()))
        return doclets

    def ensure_state(self) -> BenchClass:
        """Return the current bench class, raising ValueError if none yet."""
        if not self.state:
            raise ValueError('No state found!')
        return self.state

    def parse_bench_overrides(self, line: str) -> Optional[argparse.Namespace]:
        """Parse @Benchmark options.

        Returns None for an empty/None line.

        Raises:
            ValueError: on options unknown to the measurement parser.
        """
        overrides = None
        if line:
            overrides, unknown = \
                self.opts_parser.parse_known_args(line.split())
            if unknown:
                raise ValueError(f'Unknown arg to @Benchmark: {unknown}')
        return overrides

    def process_state(self, benchmarks: List[NameVal],
                      generators: List[NameVal]) -> None:
        """Start a new BenchClass from the declaration after a @State comment.

        Pending tags/bugs/imports are moved into the new class. A
        @Benchmark in the same comment gives class-wide overrides, and the
        first @Generator value is recorded.
        """
        self.skip_empty()
        class_name = self.lang.parse_state(self.current)
        if not class_name:
            raise ValueError('Bench class declaration not found!')
        self.state = BenchClass(name=class_name, tags=self.__pending_tags,
                                bugs=self.__pending_bugs,
                                imports=self.__pending_imports)
        self.__pending_tags, self.__pending_bugs = [], []
        self.__pending_imports = []
        # check if there are overrides for whole class
        for _, value in benchmarks:
            self.state.bench_args = self.parse_bench_overrides(value)
        if generators:
            self.state.generator = generators[0].value

    def process_benchmark(self, value: str, returns: List[NameVal]) -> None:
        """Append a BenchFunc for the function declared after @Benchmark.

        An `@returns {Type}` doclet overrides the declared return type.
        Pending tags/bugs are attached to this benchmark.
        """
        self.skip_empty()
        f = self.lang.parse_func(self.current)
        if not f:
            raise ValueError('Bench func declaration not found!')
        overrides = self.parse_bench_overrides(value)
        ret_type = f[1]
        if returns:
            typ = self.get_rettype(returns[0].value)
            if typ:
                ret_type = typ
        self.ensure_state().benches.append(
            BenchFunc(name=f[0], return_type=ret_type, args=overrides,
                      tags=self.__pending_tags, bugs=self.__pending_bugs))
        self.__pending_tags, self.__pending_bugs = [], []

    def process_param(self, param_values: str) -> None:
        """Register the @Param values under the name of the following field."""
        self.skip_empty()
        p = self.lang.parse_param(self.current)
        if not p:
            raise ValueError('Param declaration not found!')
        self.ensure_state().params[p[0]] = \
            split_params(self.ensure_value(param_values))

    def process_setup(self) -> None:
        """Record the function declared after @Setup as the setup method."""
        self.skip_empty()
        f = self.lang.parse_func(self.current)
        if not f:
            raise ValueError('Setup func declaration not found!')
        self.ensure_state().setup = f[0]

    def process_tag(self, value: str, states: List[NameVal]) -> None:
        """Queue @Tags values; in a @State comment apply them to the state."""
        self.__pending_tags += split_params(value)
        if self.state and states:
            # only for @State + @Tags
            self.state.tags += self.__pending_tags
            self.__pending_tags = []

    def process_bug(self, value: str, states: List[NameVal]) -> None:
        """Queue @Bugs values; in a @State comment apply them to the state."""
        self.__pending_bugs += split_params(value)
        if self.state and states:
            # only for @State + @Bugs
            self.state.bugs += self.__pending_bugs
            self.__pending_bugs = []

    def parse_comment(self, comment: str) -> None:
        """Process all the @Stuff in multiline comment.

        Assuming only one exclusive doclet in same comment
        Except for @State + @Benchmark (which is allowed by legacy)
        @Tags, @Bugs, @Import could co-exist with other @Stuff
        """
        doclets = self.doclist(comment)
        self.validate_comment(doclets)

        def filter_doclets(t: Doclet) -> List[NameVal]:
            return [x for x in doclets if x.name == t.value]

        states = filter_doclets(Doclet.STATE)[:1]
        benchmarks = filter_doclets(Doclet.BENCHMARK)[:1]
        generators = filter_doclets(Doclet.GENERATOR)[:1]
        # A @State comment declares a new bench class. Create it before the
        # tags/bugs/imports of the same comment are handled, so they attach
        # to this class and not to a previously parsed one.
        if states:
            self.process_state(benchmarks, generators)
        for _, value in filter_doclets(Doclet.TAGS):
            self.process_tag(value, states)
        for _, value in filter_doclets(Doclet.BUGS):
            self.process_bug(value, states)
        for _, value in filter_doclets(Doclet.IMPORT):
            if self.state:
                self.state.imports.append(value)
            else:
                self.__pending_imports.append(value)
        if states:
            return
        params = filter_doclets(Doclet.PARAM)
        if params:
            self.process_param(params[0].value)
            return
        if filter_doclets(Doclet.SETUP):
            self.process_setup()
            return
        if benchmarks:
            self.process_benchmark(benchmarks[0].value,
                                   filter_doclets(Doclet.RETURNS))

    def parse(self) -> DocletParser:
        """Search and parse doclet comments; return self."""
        comment = ''
        try:
            while True:
                self.next()
                line = self.current
                if not comment and self.re_comment_open.search(line):
                    # a non-empty `comment` marks "inside a comment"
                    comment = "\n"
                    continue
                if comment and self.re_comment_close.search(line):
                    self.parse_comment(comment)
                    comment = ''
                    continue
                if comment and '@' in line:
                    comment += "\n" + line
        except IndexError:
            # End of input: raised by next(), including when a process_*
            # step's skip_empty() runs past the last line.
            pass
        return self
292
293
@dataclass
class TemplateVars:  # pylint: disable=invalid-name
    """Values substituted into a benchmark template.

    Field names match the variable names used inside the template, so an
    instance can be handed to the template `as dict`.
    """

    # Full bench source pasted into the template
    src: str
    # Name of the main (@State) class
    state_name: str = ''
    # Setup method call, e.g. 'bench.SomeMethod();'
    state_setup: str = ''
    # Param assignments like 'bench.param1 = 5;', joined with '\n    '
    state_params: str = ''
    # ';'-joined 'param1=5' list, or 'No Params'
    fixture: str = ''
    # Position of the fixture in the sorted parameter cross product
    fix_id: int = 0
    # Name of test method
    method_name: str = ''
    method_rettype: str = ''
    method_call: str = ''
    # '<class>_<method>', plus '_<fix_id>' when fix_id > 0
    bench_name: str = ''
    bench_path: str = ''
    common: str = ''  # common feature is obsoleted
    # Measurement settings; this should be the only place with defaults
    mi: int = 3  # from measure_iters
    wi: int = 2  # from warmup_iters
    it: int = 1  # from iter_time
    wt: int = 1  # from warmup_time
    fi: int = 0  # from fast_iters
    gc: int = -1  # from sys_gc_pause
    tags: Any = None
    bugs: Any = None
    imports: Any = None
    generator: str = ''
    config: Dict[str, Any] = field(default_factory=dict)
    aot_opts: str = ''
    disable_inlining: bool = False

    @classmethod
    def params_from_parsed(cls,
                           src: str,
                           parsed: BenchClass,
                           args: Optional[Args] = None
                           ) -> Iterable[TemplateVars]:
        """Yield template vars for every benchmark x parameter fixture.

        A benchmark is dropped when its tags (class tags + method tags)
        intersect `args.skip_tags`, or miss a non-empty `args.tags`. A
        fixture is dropped when a non-empty `args.tests` has no entry that
        is a substring of its bench name. `fix_id` counts every fixture,
        filtered or not, so indexes do not depend on the filters.
        """
        if args:
            tags_filter = args.tags
            tests_filter = args.tests
            skip_tags = args.skip_tags
        else:
            tags_filter, tests_filter, skip_tags = [], [], set()
        # One axis of (name, value) pairs per param, sorted by name to keep
        # fixture indexing stable; without params the product is [()].
        axes = [[(name, value) for value in values]
                for name, values in sorted(parsed.params.items())]
        fixtures = list(product(*axes))
        for bench in parsed.benches:
            bench_tags = set(parsed.tags + bench.tags)  # @State + @Bench tags
            if skip_tags and bench_tags.intersection(skip_tags):
                continue
            if tags_filter and not bench_tags.intersection(tags_filter):
                continue
            for fix_id, fixture in enumerate(fixtures):
                bench_name = f'{parsed.name}_{bench.name}'
                if fix_id > 0:
                    bench_name = f'{bench_name}_{fix_id}'
                if tests_filter and \
                        not any(t in bench_name for t in tests_filter):
                    continue
                fix_str = ';'.join(f'{name}={value}' for name, value in fixture)
                tp = cls(
                    src=src,
                    state_name=parsed.name,
                    state_setup=(f'bench.{parsed.setup}();'
                                 if parsed.setup else ''),
                    # strictly speaking, this should be lang specific
                    state_params="\n    ".join(
                        f'bench.{name} = {value};' for name, value in fixture),
                    fixture=fix_str or 'No Params',
                    fix_id=fix_id,
                    method_name=bench.name,
                    method_rettype=bench.return_type or '',
                    bench_name=bench_name,
                    tags=bench_tags,
                    bugs=set(parsed.bugs + bench.bugs),
                    imports=parsed.imports,
                    generator=parsed.generator or '',
                    config=dict(fixture),
                )
                # Override measure settings in following order:
                # Defaults -> CmdLine -> Class -> Bench
                tp.set_measure_overrides(args, parsed.bench_args, bench.args)
                yield tp

    def set_measure_overrides(self, *overrides) -> None:
        """Apply measurement overrides in order; later ones win.

        Falsy overrides (e.g. None) are skipped. Numeric settings are
        copied only when not None. `compiler_inlining == 'false'` disables
        inlining (nothing here re-enables it). AOT compiler options
        accumulate across overrides. Both are mirrored into `self.config`.
        """
        numeric = (
            ('measure_iters', 'mi'),
            ('warmup_iters', 'wi'),
            ('iter_time', 'it'),
            ('warmup_time', 'wt'),
            ('fast_iters', 'fi'),
            ('sys_gc_pause', 'gc'),
        )
        for ovr in filter(None, overrides):
            for opt_name, attr in numeric:
                value = getattr(ovr, opt_name)
                if value is not None:
                    setattr(self, attr, value)
            if ovr.compiler_inlining == 'false':
                self.disable_inlining = True
                self.config['disable_inlining'] = True
            if ovr.aot_compiler_options:
                extra = ' '.join(ovr.aot_compiler_options)
                self.aot_opts = f'{self.aot_opts} {extra} '
                self.config['aot_opts'] = self.aot_opts
418