1#!/usr/bin/env python3 2# -*- coding: utf-8 -*- 3 4# Copyright (c) 2024 Huawei Device Co., Ltd. 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# 17 18from __future__ import annotations 19 20import re 21import logging 22import argparse 23from typing import List, Dict, Optional, Iterable, Any 24from itertools import combinations, product 25from dataclasses import dataclass, field 26from collections import namedtuple 27from vmb.lang import LangBase 28from vmb.helpers import StringEnum, split_params 29from vmb.cli import Args, add_measurement_opts 30 31log = logging.getLogger('vmb') 32NameVal = namedtuple("NameVal", "name value") 33 34 35class LineParser: 36 37 def __init__(self, lines, pos, lang): 38 self.lines = lines 39 self.pos = pos 40 self.lang = lang 41 42 @property 43 def current(self) -> str: 44 if 0 <= self.pos < len(self.lines): 45 return self.lines[self.pos] 46 return '' 47 48 @classmethod 49 def create(cls, text: str, lang: LangBase): 50 return cls(text.split("\n"), -1, lang) 51 52 def next(self) -> None: 53 if self.pos + 1 > len(self.lines): 54 raise IndexError('No lines left') 55 self.pos += 1 56 57 def skip_empty(self) -> None: 58 for _ in range(4): 59 self.next() 60 line = self.current.strip() 61 if line and not line.startswith('//'): 62 break 63 64 65class Doclet(StringEnum): 66 # Lang-dependent 67 STATE = "State" 68 BENCHMARK = "Benchmark" 69 PARAM = "Param" 70 SETUP = "Setup" 71 RETURNS = "returns" 72 # Lang-agnostic 73 IMPORT = "Import" 74 TAGS = "Tags" 75 BUGS = "Bugs" 76 GENERATOR = "Generator" # Legacy code generation 77 78 @staticmethod 79 def exclusive_doclets() -> Iterable[Doclet]: 80 return Doclet.STATE, Doclet.BENCHMARK, Doclet.SETUP, Doclet.PARAM 81 82 83@dataclass 84class BenchFunc: 85 name: str 86 return_type: Optional[str] = None 87 args: Optional[argparse.Namespace] = None 88 tags: List[str] = field(default_factory=list) 89 bugs: List[str] = field(default_factory=list) 90 91 92@dataclass 93class BenchClass: 94 name: str 95 setup: Optional[str] = None 96 params: Dict[str, List[str]] = field(default_factory=dict) 97 benches: List[BenchFunc] = field(default_factory=list) 98 bench_args: Optional[argparse.Namespace] = None 99 imports: List[str] = field(default_factory=list) 100 tags: List[str] = field(default_factory=list) 101 bugs: List[str] = field(default_factory=list) 102 generator: Optional[str] = None 103 104 105class DocletParser(LineParser): 106 107 # Note: strictly this should be `/**` not just `/*` 108 re_comment_open = re.compile(r'^\s*/\*(\*)?\s*$') 109 re_comment_close = re.compile(r'^\s*\*/\s*$') 110 re_doclet = re.compile(r'@+(\w+)\s*(.+)?$') 111 re_returns = re.compile(r'{+(\w+)}') 112 113 def __init__(self, lines, pos, lang) -> None: 114 super().__init__(lines, pos, lang) 115 self.state: Optional[BenchClass] = None 116 self.opts_parser = argparse.ArgumentParser() 117 add_measurement_opts(self.opts_parser) 118 self.__pending_tags: List[str] = [] 119 self.__pending_bugs: List[str] = [] 120 self.__pending_imports: List[str] = [] 121 122 @staticmethod 123 def validate_comment(doclets: List[NameVal]) -> None: 124 doclet_names = [v.name for v in doclets] 125 for d in Doclet.exclusive_doclets(): 126 if doclet_names.count(d.value) > 1: 127 raise ValueError(f'Multiple @{d} doclets in same comment') 128 for d1, d2 in combinations(Doclet.exclusive_doclets(), 2): 129 # allow @State + @Benchmark for compatibility reasons 130 if sorted((d1, d2)) == [Doclet.BENCHMARK, Doclet.STATE]: 131 continue 132 if doclet_names.count(d1.value) > 0 and \ 133 doclet_names.count(d2.value) > 0: 134 raise ValueError( 135 f'@{d1.value} and @{d2.value} doclets in same comment') 136 137 @staticmethod 138 def ensure_value(val: Optional[str]) -> str: 139 if not val: 140 raise ValueError('Empty value!') 141 return val 142 143 @staticmethod 144 def get_rettype(value: str) -> str: 145 m = DocletParser.re_returns.search(value) 146 if m: 147 return m.group(1) 148 return '' 149 150 def doclist(self, comment: str) -> List[NameVal]: 151 doclets = [] 152 for line in comment.split("\n"): 153 m = self.re_doclet.search(line) 154 if m: 155 doclets.append(NameVal(*m.groups())) 156 return doclets 157 158 def ensure_state(self) -> BenchClass: 159 if not self.state: 160 raise ValueError('No state found!') 161 return self.state 162 163 def parse_bench_overrides(self, line: str) -> Optional[argparse.Namespace]: 164 """Parse @Benchmark options.""" 165 overrides = None 166 if line: 167 overrides, unknown = \ 168 self.opts_parser.parse_known_args(line.split()) 169 if unknown: 170 raise ValueError(f'Unknown arg to @Benchmark: {unknown}') 171 return overrides 172 173 def process_state(self, benchmarks: List[NameVal], 174 generators: List[NameVal]) -> None: 175 self.skip_empty() 176 class_name = self.lang.parse_state(self.current) 177 if not class_name: 178 raise ValueError('Bench class declaration not found!') 179 self.state = BenchClass(name=class_name, tags=self.__pending_tags, 180 bugs=self.__pending_bugs, 181 imports=self.__pending_imports) 182 self.__pending_tags, self.__pending_bugs = [], [] 183 self.__pending_imports = [] 184 # check if there are overrides for whole class 185 for _, value in benchmarks: 186 self.state.bench_args = self.parse_bench_overrides(value) 187 if generators: 188 self.state.generator = generators[0].value 189 190 def process_benchmark(self, value: str, returns: List[NameVal]) -> None: 191 self.skip_empty() 192 f = self.lang.parse_func(self.current) 193 if not f: 194 raise ValueError('Bench func declaration not found!') 195 overrides = self.parse_bench_overrides(value) 196 ret_type = f[1] 197 if returns: 198 typ = self.get_rettype(returns[0].value) 199 if typ: 200 ret_type = typ 201 self.ensure_state().benches.append( 202 BenchFunc(name=f[0], return_type=ret_type, args=overrides, 203 tags=self.__pending_tags, bugs=self.__pending_bugs)) 204 self.__pending_tags, self.__pending_bugs = [], [] 205 206 def process_param(self, param_values: str) -> None: 207 self.skip_empty() 208 p = self.lang.parse_param(self.current) 209 if not p: 210 raise ValueError('Param declaration not found!') 211 self.ensure_state().params[p[0]] = \ 212 split_params(self.ensure_value(param_values)) 213 214 def process_setup(self) -> None: 215 self.skip_empty() 216 f = self.lang.parse_func(self.current) 217 if not f: 218 raise ValueError('Setup func declaration not found!') 219 self.ensure_state().setup = f[0] 220 221 def process_tag(self, value: str, states: List[NameVal]) -> None: 222 self.__pending_tags += split_params(value) 223 if self.state and states: 224 # only for @State + @Tags 225 self.state.tags += self.__pending_tags 226 self.__pending_tags = [] 227 228 def process_bug(self, value: str, states: List[NameVal]) -> None: 229 self.__pending_bugs += split_params(value) 230 if self.state and states: 231 # only for @State + @Bugs 232 self.state.bugs += self.__pending_bugs 233 self.__pending_bugs = [] 234 235 def parse_comment(self, comment: str) -> None: 236 """Process all the @Stuff in multiline comment. 237 238 Assuming only one exclusive doclet in same comment 239 Except for @State + @Benchmark (wich is allowed by legacy) 240 @Tags, @Bugs, @Import could co-exist with other @Stuff 241 """ 242 doclets = self.doclist(comment) 243 self.validate_comment(doclets) 244 245 def filter_doclets(t: Doclet) -> List[NameVal]: 246 return list(filter(lambda x: x.name == t.value, doclets)) 247 248 states = filter_doclets(Doclet.STATE)[:1] 249 benchmarks = filter_doclets(Doclet.BENCHMARK)[:1] 250 generators = filter_doclets(Doclet.GENERATOR)[:1] 251 for _, value in filter_doclets(Doclet.TAGS): 252 self.process_tag(value, states) 253 for _, value in filter_doclets(Doclet.BUGS): 254 self.process_bug(value, states) 255 for _, value in filter_doclets(Doclet.IMPORT): 256 if self.state: 257 self.state.imports.append(value) 258 else: 259 self.__pending_imports.append(value) 260 for _ in states: 261 self.process_state(benchmarks, generators) 262 return 263 for _, param_values in filter_doclets(Doclet.PARAM)[:1]: 264 self.process_param(param_values) 265 return 266 for _ in filter_doclets(Doclet.SETUP)[:1]: 267 self.process_setup() 268 return 269 for _, value in benchmarks: 270 self.process_benchmark(value, filter_doclets(Doclet.RETURNS)) 271 return 272 273 def parse(self) -> DocletParser: 274 """Search and parse doclet comments.""" 275 comment = '' 276 try: 277 while True: 278 self.next() 279 if not comment and \ 280 re.search(self.re_comment_open, self.current): 281 comment += "\n" 282 continue 283 if comment and re.search(self.re_comment_close, self.current): 284 self.parse_comment(comment) 285 comment = '' 286 continue 287 if comment and '@' in self.current: 288 comment += "\n" + self.current 289 except IndexError: 290 pass 291 return self 292 293 294@dataclass 295class TemplateVars: # pylint: disable=invalid-name 296 """Params for bench template. 297 298 Names of class props are same as of variables inside template, 299 so this could be provided `as dict` to Template 300 """ 301 302 # Full bench source to be pasted into template 303 src: str 304 # Name of main class 305 state_name: str = '' 306 # Setup method call: bench.SomeMethod();' 307 state_setup: str = '' 308 # '\n'-joined list of 'bench.param1=5;' 309 state_params: str = '' 310 # ';'-joined param list of 'param1=5' 311 fixture: str = '' 312 fix_id: int = 0 313 # Name of test method 314 method_name: str = '' 315 method_rettype: str = '' 316 method_call: str = '' 317 bench_name: str = '' 318 bench_path: str = '' 319 common: str = '' # common feature is obsoleted 320 # this should be the only place with defaults 321 mi: int = 3 322 wi: int = 2 323 it: int = 1 324 wt: int = 1 325 fi: int = 0 326 gc: int = -1 327 tags: Any = None 328 bugs: Any = None 329 imports: Any = None 330 generator: str = '' 331 config: Dict[str, Any] = field(default_factory=dict) 332 aot_opts: str = '' 333 disable_inlining: bool = False 334 335 @classmethod 336 def params_from_parsed(cls, 337 src: str, 338 parsed: BenchClass, 339 args: Optional[Args] = None 340 ) -> Iterable[TemplateVars]: 341 """Produce all combinations of Benches and Params.""" 342 tags_filter = args.tags if args else [] 343 tests_filter = args.tests if args else [] 344 skip_tags = args.skip_tags if args else set() 345 # list of lists of tuples (param_name, param_value) 346 # sorting by param name to keep fixture indexing 347 params = [ 348 [(p, v) for v in vals] 349 for p, vals 350 in sorted(parsed.params.items()) 351 ] 352 fixtures = list(product(*params)) 353 for b in parsed.benches: 354 # check tags filter: 355 tags = set(parsed.tags + b.tags) # @State::@Tags + @Bench::@Tags 356 if skip_tags and set.intersection(tags, skip_tags): 357 continue 358 if tags_filter and not set.intersection(tags, tags_filter): 359 continue 360 # if no params fixtures will be [()] 361 fix_id = 0 362 for f in fixtures: 363 tp = cls(src, parsed.name) 364 fix_str = ';'.join([f'{x[0]}={x[1]}' for x in f]) 365 tp.generator = parsed.generator if parsed.generator else '' 366 tp.config = {x[0]: x[1] for x in f} 367 if not fix_str: 368 fix_str = 'No Params' 369 tp.method_name = b.name 370 tp.method_rettype = b.return_type if b.return_type else '' 371 # strictly speaking, this should be lang specific 372 tp.state_params = "\n ".join([ 373 f'bench.{p} = {v};' for p, v in f]) 374 tp.fixture = fix_str 375 tp.fix_id = fix_id 376 tp.bench_name = f'{parsed.name}_{b.name}' 377 if tp.fix_id > 0: 378 tp.bench_name = f'{tp.bench_name}_{tp.fix_id}' 379 if tests_filter and \ 380 not any((x in tp.bench_name) for x in tests_filter): 381 fix_id += 1 382 continue 383 tp.state_setup = f'bench.{parsed.setup}();' \ 384 if parsed.setup else '' 385 tp.tags = tags 386 tp.bugs = set(parsed.bugs + b.bugs) 387 # Override measure settings in following order: 388 # Defaults -> CmdLine -> Class -> Bench 389 tp.set_measure_overrides(args, parsed.bench_args, b.args) 390 tp.imports = parsed.imports 391 yield tp 392 fix_id += 1 393 394 def set_measure_overrides(self, *overrides) -> None: 395 """Override all measurement options.""" 396 for ovr in overrides: 397 if not ovr: 398 continue 399 if ovr.measure_iters is not None: 400 self.mi = ovr.measure_iters 401 if ovr.warmup_iters is not None: 402 self.wi = ovr.warmup_iters 403 if ovr.iter_time is not None: 404 self.it = ovr.iter_time 405 if ovr.warmup_time is not None: 406 self.wt = ovr.warmup_time 407 if ovr.fast_iters is not None: 408 self.fi = ovr.fast_iters 409 if ovr.sys_gc_pause is not None: 410 self.gc = ovr.sys_gc_pause 411 if ovr.compiler_inlining == 'false': 412 self.disable_inlining = True 413 self.config.update({'disable_inlining': True}) 414 if ovr.aot_compiler_options: 415 opts = ' '.join(ovr.aot_compiler_options) 416 self.aot_opts = f'{self.aot_opts} {opts} ' 417 self.config.update({'aot_opts': self.aot_opts}) 418