#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

import re
import logging
import argparse
from typing import List, Dict, Optional, Iterable, Any
from itertools import combinations, product
from dataclasses import dataclass, field
from collections import namedtuple
from vmb.lang import LangBase
from vmb.helpers import StringEnum, split_params
from vmb.cli import Args, add_measurement_opts

log = logging.getLogger('vmb')
# One parsed doclet: `name` is the word following '@', `value` is the rest of
# that line (None when nothing follows the name).
NameVal = namedtuple("NameVal", "name value")


class LineParser:
    """Forward-only cursor over the lines of a source text.

    `pos` is the index of the current line; -1 means "before the first line".
    """

    def __init__(self, lines: List[str], pos: int, lang: LangBase) -> None:
        self.lines = lines
        self.pos = pos
        self.lang = lang

    @property
    def current(self) -> str:
        """Return the line under the cursor, or '' if the cursor is out of range."""
        if 0 <= self.pos < len(self.lines):
            return self.lines[self.pos]
        return ''

    @classmethod
    def create(cls, text: str, lang: LangBase):
        """Build a parser over `text`, positioned before its first line."""
        return cls(text.split("\n"), -1, lang)

    def next(self) -> None:
        """Advance the cursor by one line.

        The cursor may step one position past the last line (where `current`
        is ''); trying to advance beyond that raises IndexError.
        """
        if self.pos + 1 > len(self.lines):
            raise IndexError('No lines left')
        self.pos += 1

    def skip_empty(self) -> None:
        """Advance to the next line that is neither blank nor a `//` comment.

        At most 4 lines are consumed; if none of them qualifies, the cursor is
        left on the 4th one. Raises IndexError when input runs out.
        """
        for _ in range(4):
            self.next()
            line = self.current.strip()
            if line and not line.startswith('//'):
                break


class Doclet(StringEnum):
    """Names of the `@Doclet` annotations recognized in bench sources."""

    # Lang-dependent
    STATE = "State"
    BENCHMARK = "Benchmark"
    PARAM = "Param"
    SETUP = "Setup"
    RETURNS = "returns"
    # Lang-agnostic
    IMPORT = "Import"
    INCLUDE = "Include"
    TAGS = "Tags"
    BUGS = "Bugs"
    GENERATOR = "Generator"  # Legacy code generation

    @staticmethod
    def exclusive_doclets() -> Iterable[Doclet]:
        """Doclets that may appear at most once per comment and not together."""
        return Doclet.STATE, Doclet.BENCHMARK, Doclet.SETUP, Doclet.PARAM


@dataclass
class BenchFunc:
    """A single @Benchmark method."""

    name: str
    return_type: Optional[str] = None
    # measurement overrides given after @Benchmark on this method
    args: Optional[argparse.Namespace] = None
    tags: List[str] = field(default_factory=list)
    bugs: List[str] = field(default_factory=list)


@dataclass
class BenchClass:
    """A @State class together with everything attached to it."""

    name: str
    # name of the @Setup method, if any
    setup: Optional[str] = None
    # @Param field name -> list of values to iterate over
    params: Dict[str, List[str]] = field(default_factory=dict)
    benches: List[BenchFunc] = field(default_factory=list)
    # class-wide measurement overrides (@Benchmark in the @State comment)
    bench_args: Optional[argparse.Namespace] = None
    imports: List[str] = field(default_factory=list)
    includes: List[str] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)
    bugs: List[str] = field(default_factory=list)
    generator: Optional[str] = None


class DocletParser(LineParser):
    """Collect @Doclet annotations from a bench source.

    After `parse()` the result is available as `self.state` (a BenchClass).
    """

    # Note: strictly this should be `/**` not just `/*`
    re_comment_open = re.compile(r'^\s*/\*(\*)?\s*$')
    re_comment_close = re.compile(r'^\s*\*/\s*$')
    re_doclet = re.compile(r'@+(\w+)\s*(.+)?$')
    re_returns = re.compile(r'{+(\w+)}')

    def __init__(self, lines, pos, lang) -> None:
        super().__init__(lines, pos, lang)
        self.state: Optional[BenchClass] = None
        # parser for measurement options written after @Benchmark
        self.opts_parser = argparse.ArgumentParser()
        add_measurement_opts(self.opts_parser)
        # Doclet data seen before the element it belongs to:
        # tags/bugs are consumed by the next @State or @Benchmark,
        # imports/includes by the next @State.
        self.__pending_tags: List[str] = []
        self.__pending_bugs: List[str] = []
        self.__pending_imports: List[str] = []
        self.__pending_includes: List[str] = []

    @staticmethod
    def validate_comment(doclets: List[NameVal]) -> None:
        """Reject comments that combine exclusive doclets.

        Raises:
            ValueError: if an exclusive doclet appears more than once, or two
                different exclusive doclets appear together. The legacy
                @State + @Benchmark combination is allowed.
        """
        doclet_names = [v.name for v in doclets]
        for d in Doclet.exclusive_doclets():
            if doclet_names.count(d.value) > 1:
                raise ValueError(f'Multiple @{d.value} doclets in same comment')
        # allow @State + @Benchmark for compatibility reasons;
        # compared as a set of names so no ordering of members is required
        legacy_pair = {Doclet.BENCHMARK.value, Doclet.STATE.value}
        for d1, d2 in combinations(Doclet.exclusive_doclets(), 2):
            if {d1.value, d2.value} == legacy_pair:
                continue
            if d1.value in doclet_names and d2.value in doclet_names:
                raise ValueError(
                    f'@{d1.value} and @{d2.value} doclets in same comment')

    @staticmethod
    def ensure_value(val: Optional[str]) -> str:
        """Return `val`, raising ValueError if it is None or empty."""
        if not val:
            raise ValueError('Empty value!')
        return val

    @staticmethod
    def get_rettype(value: str) -> str:
        """Extract the type name from a `{type}` marker; '' if there is none."""
        m = DocletParser.re_returns.search(value)
        if m:
            return m.group(1)
        return ''

    def doclist(self, comment: str) -> List[NameVal]:
        """Return one NameVal for each line of `comment` holding `@name [value]`."""
        matches = (self.re_doclet.search(line) for line in comment.split("\n"))
        return [NameVal(*m.groups()) for m in matches if m]

    def ensure_state(self) -> BenchClass:
        """Return the current @State class, raising ValueError if none yet."""
        if not self.state:
            raise ValueError('No state found!')
        return self.state

    def parse_bench_overrides(self, line: Optional[str]) -> Optional[argparse.Namespace]:
        """Parse @Benchmark options.

        Returns None when `line` is empty. Raises ValueError on options the
        measurement parser does not know.
        """
        overrides = None
        if line:
            overrides, unknown = \
                self.opts_parser.parse_known_args(line.split())
            if unknown:
                raise ValueError(f'Unknown arg to @Benchmark: {unknown}')
        return overrides

    def process_state(self, benchmarks: List[NameVal],
                      generators: List[NameVal]) -> None:
        """Handle @State: start a new BenchClass from the following declaration.

        Pending tags, bugs, imports and includes are moved onto the new class.
        A @Benchmark in the same comment gives class-wide overrides; a
        @Generator sets the (legacy) generator.
        """
        self.skip_empty()
        class_name = self.lang.parse_state(self.current)
        if not class_name:
            raise ValueError('Bench class declaration not found!')
        self.state = BenchClass(name=class_name,
                                tags=self.__pending_tags, bugs=self.__pending_bugs,
                                imports=self.__pending_imports, includes=self.__pending_includes)
        self.__pending_tags, self.__pending_bugs = [], []
        self.__pending_imports, self.__pending_includes = [], []
        # check if there are overrides for whole class
        for _, value in benchmarks:
            self.state.bench_args = self.parse_bench_overrides(value)
        if generators:
            self.state.generator = generators[0].value

    def process_benchmark(self, value: str, returns: List[NameVal]) -> None:
        """Handle @Benchmark: register the following function on the current state.

        A `{type}` in an @returns doclet overrides the return type reported
        by the language parser. Pending tags and bugs go to this bench.
        """
        self.skip_empty()
        f = self.lang.parse_func(self.current)
        if not f:
            raise ValueError('Bench func declaration not found!')
        overrides = self.parse_bench_overrides(value)
        # as used here: f[0] is the function name, f[1] its return type
        ret_type = f[1]
        if returns:
            typ = self.get_rettype(returns[0].value)
            if typ:
                ret_type = typ
        self.ensure_state().benches.append(
            BenchFunc(name=f[0], return_type=ret_type, args=overrides,
                      tags=self.__pending_tags, bugs=self.__pending_bugs))
        self.__pending_tags, self.__pending_bugs = [], []

    def process_param(self, param_values: str) -> None:
        """Handle @Param: the following declaration names the field, the doclet
        value lists the values it takes."""
        self.skip_empty()
        p = self.lang.parse_param(self.current)
        if not p:
            raise ValueError('Param declaration not found!')
        self.ensure_state().params[p[0]] = \
            split_params(self.ensure_value(param_values))

    def process_setup(self) -> None:
        """Handle @Setup: remember the name of the following function."""
        self.skip_empty()
        f = self.lang.parse_func(self.current)
        if not f:
            raise ValueError('Setup func declaration not found!')
        self.ensure_state().setup = f[0]

    def process_tag(self, value: str, states: List[NameVal]) -> None:
        """Queue @Tags values; with @State in the same comment and a state
        already parsed, attach them to that existing state immediately."""
        # NOTE(review): parse_comment calls this before process_state, so the
        # "existing state" here is a previously parsed class — confirm intent.
        self.__pending_tags += split_params(value)
        if self.state and states:
            # only for @State + @Tags
            self.state.tags += self.__pending_tags
            self.__pending_tags = []

    def process_bug(self, value: str, states: List[NameVal]) -> None:
        """Queue @Bugs values; same attachment rule as process_tag."""
        self.__pending_bugs += split_params(value)
        if self.state and states:
            # only for @State + @Bugs
            self.state.bugs += self.__pending_bugs
            self.__pending_bugs = []

    def parse_comment(self, comment: str) -> None:
        """Process all the @Stuff in multiline comment.

        Assuming only one exclusive doclet in same comment,
        except for @State + @Benchmark (which is allowed by legacy).
        @Tags, @Bugs, @Import, @Include can co-exist with other @Stuff.
        """
        doclets = self.doclist(comment)
        self.validate_comment(doclets)

        def filter_doclets(t: Doclet) -> List[NameVal]:
            return [x for x in doclets if x.name == t.value]

        states = filter_doclets(Doclet.STATE)[:1]
        benchmarks = filter_doclets(Doclet.BENCHMARK)[:1]
        generators = filter_doclets(Doclet.GENERATOR)[:1]
        for _, value in filter_doclets(Doclet.TAGS):
            self.process_tag(value, states)
        for _, value in filter_doclets(Doclet.BUGS):
            self.process_bug(value, states)
        for _, value in filter_doclets(Doclet.IMPORT):
            # a bare `@Import` has value None; reject it instead of storing None
            value = self.ensure_value(value)
            if self.state:
                self.state.imports.append(value)
            else:
                self.__pending_imports.append(value)
        for _, value in filter_doclets(Doclet.INCLUDE):
            # trim whitespace (incl. '\r') first, otherwise a trailing blank
            # keeps the closing quote from being stripped
            value = self.ensure_value(value).strip().strip("\'\"")
            if self.state:
                self.state.includes.append(value)
            else:
                self.__pending_includes.append(value)
        # at most one of the following branches runs (first match wins)
        for _ in states:
            self.process_state(benchmarks, generators)
            return
        for _, param_values in filter_doclets(Doclet.PARAM)[:1]:
            self.process_param(param_values)
            return
        for _ in filter_doclets(Doclet.SETUP)[:1]:
            self.process_setup()
            return
        for _, value in benchmarks:
            self.process_benchmark(value, filter_doclets(Doclet.RETURNS))
            return

    def parse(self) -> DocletParser:
        """Search and parse doclet comments.

        A comment opens at a line holding only `/*` or `/**` and closes at a
        line holding only `*/`; of the lines in between only those containing
        '@' are collected and handed to `parse_comment`.

        Returns:
            self, with the parsed class in `self.state`.
        """
        comment = ''
        try:
            while True:
                self.next()
                line = self.current
                if not comment and self.re_comment_open.search(line):
                    # a non-empty value marks "inside a comment"
                    comment += "\n"
                    continue
                if comment and self.re_comment_close.search(line):
                    self.parse_comment(comment)
                    comment = ''
                    continue
                if comment and '@' in line:
                    comment += "\n" + line
        except IndexError:
            # Running out of input (cursor past the last line) is the normal
            # way to stop, also when a declaration lookup hits end of input.
            # An IndexError raised while lines remain is a real defect.
            if self.pos < len(self.lines):
                raise
        return self


@dataclass
class TemplateVars:  # pylint: disable=invalid-name
    """Params for bench template.

    Field names match the variable names used inside the template,
    so an instance can be provided `as dict` to Template.
    """

    # Full bench source to be pasted into template
    src: str
    # Name of main class
    state_name: str = ''
    # Setup method call: bench.SomeMethod();
    state_setup: str = ''
    # '\n'-joined list of 'bench.param1=5;'
    state_params: str = ''
    # ';'-joined param list of 'param1=5'
    fixture: str = ''
    fix_id: int = 0
    # Name of test method
    method_name: str = ''
    method_rettype: str = ''
    method_call: str = ''
    bench_name: str = ''
    bench_path: str = ''
    common: str = ''  # common feature is obsolete
    # this should be the only place with defaults
    mi: int = 3
    wi: int = 2
    it: int = 1
    wt: int = 1
    fi: int = 0
    gc: int = -1
    tags: Any = None
    bugs: Any = None
    imports: Any = None
    includes: Any = None
    generator: str = ''
    config: Dict[str, Any] = field(default_factory=dict)
    aot_opts: str = ''
    disable_inlining: bool = False

    @classmethod
    def params_from_parsed(cls,
                           src: str,
                           parsed: BenchClass,
                           args: Optional[Args] = None
                           ) -> Iterable[TemplateVars]:
        """Produce all combinations of Benches and Params.

        Yields one TemplateVars per (bench method, param fixture) pair that
        passes the tag filters and test-name filter taken from `args`.
        """
        tags_filter = args.tags if args else []
        tests_filter = args.tests if args else []
        skip_tags = args.skip_tags if args else set()
        # list of lists of tuples (param_name, param_value)
        # sorting by param name to keep fixture indexing
        params = [
            [(p, v) for v in vals]
            for p, vals
            in sorted(parsed.params.items())
        ]
        fixtures = list(product(*params))
        for b in parsed.benches:
            # check tags filter:
            tags = set(parsed.tags + b.tags)  # @State::@Tags + @Bench::@Tags
            if skip_tags and set.intersection(tags, skip_tags):
                log.trace("`%s` skipped by tags: Unwanted: %s Tagged: %s", b.name, skip_tags, tags)
                continue
            if tags_filter and not set.intersection(tags, tags_filter):
                log.trace("`%s` skipped by tags: Wanted: %s Tagged: %s", b.name, tags_filter, tags)
                continue
            # if no params fixtures will be [()]
            # fix_id counts every fixture, including ones dropped by the
            # tests filter, so bench names do not depend on the filter
            for fix_id, f in enumerate(fixtures):
                tp = cls(src, parsed.name)
                fix_str = ';'.join([f'{x[0]}={x[1]}' for x in f])
                tp.generator = parsed.generator or ''
                tp.config = {x[0]: x[1] for x in f}
                if not fix_str:
                    fix_str = 'No Params'
                tp.method_name = b.name
                tp.method_rettype = b.return_type or ''
                # strictly speaking, this should be lang specific
                tp.state_params = "\n ".join([
                    f'bench.{p} = {v};' for p, v in f])
                tp.fixture = fix_str
                tp.fix_id = fix_id
                tp.bench_name = f'{parsed.name}_{b.name}'
                if tp.fix_id > 0:
                    tp.bench_name = f'{tp.bench_name}_{tp.fix_id}'
                if tests_filter and \
                        not any((x in tp.bench_name) for x in tests_filter):
                    continue
                tp.state_setup = f'bench.{parsed.setup}();' \
                    if parsed.setup else ''
                tp.tags = tags
                tp.bugs = set(parsed.bugs + b.bugs)
                # Override measure settings in following order:
                # Defaults -> CmdLine -> Class -> Bench
                tp.set_measure_overrides(args, parsed.bench_args, b.args)
                tp.imports = parsed.imports
                tp.includes = parsed.includes
                yield tp

    def set_measure_overrides(self, *overrides) -> None:
        """Override all measurement options.

        Each non-empty namespace is applied in order, so later ones win for
        every option they set (options left as None are not applied).
        AOT compiler options accumulate rather than replace.
        """
        for ovr in overrides:
            if not ovr:
                continue
            if ovr.measure_iters is not None:
                self.mi = ovr.measure_iters
            if ovr.warmup_iters is not None:
                self.wi = ovr.warmup_iters
            if ovr.iter_time is not None:
                self.it = ovr.iter_time
            if ovr.warmup_time is not None:
                self.wt = ovr.warmup_time
            if ovr.fast_iters is not None:
                self.fi = ovr.fast_iters
            if ovr.sys_gc_pause is not None:
                self.gc = ovr.sys_gc_pause
            if ovr.compiler_inlining == 'false':
                self.disable_inlining = True
                self.config.update({'disable_inlining': True})
            if ovr.aot_compiler_options:
                opts = ' '.join(ovr.aot_compiler_options)
                self.aot_opts = f'{self.aot_opts} {opts} '
                self.config.update({'aot_opts': self.aot_opts})