# Copyright (c) 2019-2020 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

#!/usr/bin/python3

import argparse
import csv
import json
import logging
import math
import os
from collections import Counter, defaultdict, deque, namedtuple
from enum import Enum
from pathlib import Path
from typing import Deque, Dict, Generator, List, NamedTuple, Set, Tuple, Union

################################################################################
# Types
################################################################################

# Gemm strategy
Strategy = Enum("Strategy", ["Native", "ReshapedOnlyRHS", "Reshaped"])

# Gemm parameter


class GEMMParam(NamedTuple):
    M: int  # Number of lhs matrix rows
    N: int  # Number of rhs matrix columns
    K: int  # Number of lhs matrix columns/rhs matrix rows
    B: int  # Batch size
    data_type: str  # Data type

    @classmethod
    def parse_from_strs(cls, *M_N_K_B, data_type):
        """ Build a GEMMParam from the string fields M, N, K, B (converted to int) and a data type string.
        """
        return cls(*map(int, M_N_K_B), str(data_type))

    def __str__(self):
        return ",".join(map(str, self))


# Gemm configuration for strategy Native
class NativeGEMMConfig(NamedTuple):
    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication

    @classmethod
    def parse_from_strs(cls, *args):
        """ Build a NativeGEMMConfig from the string fields m0, n0, k0 (all converted to int).
        """
        return cls(*map(int, args))

    def __str__(self):
        return ",".join(map(str, self))


# Gemm configuration for strategy Reshaped Only RHS
class ReshapedOnlyRHSGEMMConfig(NamedTuple):
    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    # Number of horizontal blocks of size (k0xn0) stored on the same output row
    h0: int
    # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    interleave_rhs: bool
    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
    transpose_rhs: bool
    # Export rhs matrix to cl_image (1) / Do not export rhs matrix to cl_image (0)
    export_to_cl_image_rhs: bool

    @classmethod
    def parse_from_strs(cls, *args):
        """ Build a config from string fields; the last three fields are flags where "1" means True.
        """
        (*mnkh, interleave_rhs, transpose_rhs, export_to_cl_image_rhs,) = map(int, args)
        interleave_rhs = interleave_rhs == 1
        transpose_rhs = transpose_rhs == 1
        export_to_cl_image_rhs = export_to_cl_image_rhs == 1
        return cls(*mnkh, interleave_rhs, transpose_rhs, export_to_cl_image_rhs)

    def __str__(self):
        return ",".join(map(str, self))


# Gemm configuration for strategy Reshaped
class ReshapedGEMMConfig(NamedTuple):
    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    # Number of vertical blocks of size (m0xk0) stored on the same output row
    v0: int
    # Number of horizontal blocks of size (k0xn0) stored on the same output row
    h0: int
    # Interleave lhs matrix (1) / Do not interleave lhs matrix (0)
    interleave_lhs: bool
    # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    interleave_rhs: bool
    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
    transpose_rhs: bool
    # Export rhs matrix to cl_image (1) / Do not export rhs matrix to cl_image (0)
    export_to_cl_image_rhs: bool

    @classmethod
    def parse_from_strs(cls, *args):
        """ Build a config from string fields; the last four fields are flags where "1" means True.
        """
        (*mnkvh, interleave_lhs, interleave_rhs, transpose_rhs, export_to_cl_image_rhs,) = map(int, args)
        interleave_lhs = interleave_lhs == 1
        interleave_rhs = interleave_rhs == 1
        transpose_rhs = transpose_rhs == 1
        export_to_cl_image_rhs = export_to_cl_image_rhs == 1
        return cls(*mnkvh, interleave_lhs, interleave_rhs, transpose_rhs, export_to_cl_image_rhs)

    def __str__(self):
        return ",".join(map(str, self))


# Measurement we take from the benchmark result.
class Measurement(NamedTuple):
    opencl_timer_ms_reshape: float  # OpenCL timer value (ms) attributed to the reshape step
    opencl_timer_ms_kernel: float  # OpenCL timer value (ms) attributed to the GEMM kernel

    def get_total_ms(self):
        """ Return the total time in ms (reshape + kernel).
        """
        return self.opencl_timer_ms_reshape + self.opencl_timer_ms_kernel

    def is_close_to(self, other, tol):
        """ Return True if the total times of self and other differ by strictly less than tol ms.
        """
        return math.fabs(self.get_total_ms() - other.get_total_ms()) < tol

    def is_better_than(self, other, tol):
        """ Return True if self is faster than other by at least tol ms.
        """
        # Previously called is_close_to(other) without tol, which always raised TypeError.
        return self.get_total_ms() < other.get_total_ms() and not self.is_close_to(
            other, tol
        )

    def __add__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape + other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel + other.opencl_timer_ms_kernel,
        )

    def __sub__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape - other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel - other.opencl_timer_ms_kernel,
        )

    def __mul__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape * other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel * other.opencl_timer_ms_kernel,
        )

    def __floordiv__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape // other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel // other.opencl_timer_ms_kernel,
        )

    def __truediv__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape / other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel / other.opencl_timer_ms_kernel,
        )

    def __pow__(self, power):
        return Measurement(
            self.opencl_timer_ms_reshape ** power, self.opencl_timer_ms_kernel ** power
        )

    def __str__(self):
        return ",".join(map(str, self))


# GEMMConfig Type
GEMMConfigT = Union[NativeGEMMConfig,
                    ReshapedOnlyRHSGEMMConfig, ReshapedGEMMConfig]


# Representation of the benchmark result from a single experiment
class BenchmarkResult(NamedTuple):
    """ Representation of the benchmark result from a single experiment
    """

    gemm_param: GEMMParam
    strategy: Strategy
    gemm_config: GEMMConfigT
    measurement: Measurement


class GEMMBenchmarkResultRecorder:
    """ A recorder that records and organises GEMM Benchmark results, and produces various reports on the record.
    """

    SummaryLevel = Enum("SummaryLevel", ["Short", "Detailed"])

    def __init__(self, tol=0.01):
        """ Initializer

        :param tol: Tolerance (OpenCL timer milliseconds) within which two measurements are considered
                    equivalent when selecting the best GEMMConfigs (see Measurement.is_close_to).
        """
        self._benchmark_result_record: List[BenchmarkResult] = []
        # Strategies recorded
        self._strategies = set()
        self._tol = tol

    def add(self, benchmark_result: BenchmarkResult):
        """ Add a benchmark result to the record.
        """
        # Update strategies encountered
        self._strategies.add(benchmark_result.strategy)
        self._benchmark_result_record.append(benchmark_result)

    def get_record(self) -> Generator[BenchmarkResult, None, None]:
        """ Return an iterator that iterates over the record.
        """
        yield from self._benchmark_result_record

    def get_best_gemm_configs(self):
        """ Get the best GEMMConfig set per GEMMParam per Strategy

        :return: Mapping (GEMMParam, Strategy) -> list of (GEMMConfig, Measurement). The fastest config comes
                 first, followed by every other config whose total time is within the recorder tolerance of it,
                 in ascending order of total time (ties keep recording order).
        """
        # Group everything first, then sort/filter each group once. Re-sorting on every insertion was quadratic.
        best_gc_sets: Dict[
            Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfigT, Measurement]]
        ] = defaultdict(list)
        for gemm_param, strategy, gemm_config, measurement in self.get_record():
            best_gc_sets[(gemm_param, strategy)].append((gemm_config, measurement))

        for key, gc_set in best_gc_sets.items():
            # Sort the config set (list); sorted() is stable so equal times keep recording order
            gc_set = sorted(gc_set, key=lambda gc_and_m: gc_and_m[1].get_total_ms())
            # Filter out configs that are beyond tolerance to the best GEMMConfig's measurement
            best_gc, best_m = gc_set[0]
            best_gc_sets[key] = [(best_gc, best_m)] + [
                (gemm_config, measurement)
                for gemm_config, measurement in gc_set[1:]
                if measurement.is_close_to(best_m, self._tol)
            ]

        return best_gc_sets

    def get_best_gemm_configs_as_sequence(self):
        """ Get the best GEMMConfig set per GEMMParam per Strategy, and flatten the result into a sequence
        of BenchmarkResults
        """
        for (
            (gemm_param, strategy),
            best_gc_sets,
        ) in self.get_best_gemm_configs().items():
            for best_gemm_config, best_measurement in best_gc_sets:
                yield BenchmarkResult(
                    gemm_param, strategy, best_gemm_config, best_measurement
                )

    def get_config_distributions(self):
        """ Return GEMMConfigDistribution for each strategy
        """
        gemm_config_distributions: Dict[Strategy, GEMMConfigDistribution] = defaultdict(
            GEMMConfigDistribution
        )
        for benchmark_result in self.get_best_gemm_configs_as_sequence():
            _, strategy, _, _ = benchmark_result
            gemm_config_distributions[strategy].add(benchmark_result)

        return gemm_config_distributions

    def get_best_gemm_strategies(self):
        """ Get the best Strategy per GEMMParam

        :return: Mapping GEMMParam -> Strategy of the single recorded result with the lowest total time
                 (the earliest recorded wins on ties).
        """
        all_results: Dict[GEMMParam, List[Tuple[Strategy, Measurement]]] = defaultdict(
            list
        )

        best_strategies: Dict[GEMMParam, Strategy] = {}

        for gemm_param, strategy, gemm_config, measurement in self.get_record():
            all_results[gemm_param].append((strategy, measurement))

        for gemm_param, results_set in all_results.items():
            # min() returns the first minimal item, same as taking [0] of a stable sort
            best_s, _ = min(results_set, key=lambda s_and_m: s_and_m[1].get_total_ms())
            best_strategies[gemm_param] = best_s

        return best_strategies

    def save_to_jsons(self, out_dir, only_best_config=True):
        """ Save records to an output directory of JSON files.
        The directory is organized such that each strategy gets its own JSON file.
        The directory also includes a JSON file to define the best strategy per GEMM Param.

        :param out_dir:          Output directory; created (including missing parents) if absent.
        :param only_best_config: If True save only the best configs, otherwise save the whole record.
        """
        if not os.path.exists(out_dir):
            logging.info(
                "Output directory {} does not exist. Creating...".format(
                    out_dir)
            )
            # makedirs (unlike mkdir) also creates missing parent directories
            os.makedirs(out_dir, exist_ok=True)

        out_json_path = os.path.join(out_dir, "gemm_type_selection.json")
        if check_out_path(out_json_path):
            results = self.get_best_gemm_strategies()
            results = {str(key): value.name for key, value in results.items()}
            dump_json(out_json_path, results)

        for strategy in self._strategies:
            out_json_path = os.path.join(
                out_dir, ("gemm_config_" + strategy.name.lower() + ".json")
            )
            if check_out_path(out_json_path):
                record = (
                    self.get_best_gemm_configs_as_sequence()
                    if only_best_config
                    else self.get_record()
                )
                results = defaultdict(list)
                for res in record:
                    if res.strategy == strategy:
                        results[str(res.gemm_param)].append(
                            {
                                "GEMMConfig": str(res.gemm_config),
                                "OpenCL_Timer_ms_reshape": str(
                                    res.measurement.opencl_timer_ms_reshape
                                ),
                                "OpenCL_Timer_ms_kernel": str(
                                    res.measurement.opencl_timer_ms_kernel
                                ),
                            }
                        )
                dump_json(out_json_path, results)

    def summary(self, sum_level=SummaryLevel.Short):
        """ Return the summary string of the record
        """
        num_raw_records = sum(1 for _ in self.get_record())
        gemm_params_per_strategy = defaultdict(list)
        for gemm_param, strategy in self.get_best_gemm_configs().keys():
            gemm_params_per_strategy[strategy].append(gemm_param)
        global_summary = f"""
=== {self.__class__.__name__} Summary ===
[Global]
Strategies recorded: {", ".join(map(lambda s: s.name, self._strategies))}
Total number of results recorded: {num_raw_records}

[Per strategy]
        """
        strategy_summaries = []
        for strategy in gemm_params_per_strategy:
            summary = f"""
Strategy {strategy.name}:
GEMM parameters:
    Number of: {len(gemm_params_per_strategy[strategy])}
            """
            if sum_level == self.__class__.SummaryLevel.Detailed:
                summary += f"""
    Content: {gemm_params_per_strategy[strategy]}
                """
            strategy_summaries.append(summary)
        return global_summary + "".join(strategy_summaries)


class GEMMConfigDistribution:
    """ A representation of the GEMM Configuration distribution produced by the GEMMBenchmarkResultRecorder.
    """

    def __init__(self):
        """ Initializer
        """
        self._gemm_config_dist: Dict[
            GEMMConfigT, List[Tuple[GEMMParam, Measurement]]
        ] = defaultdict(list)
        self._gemm_config_freq = Counter()

    def add(self, benchmark_result: BenchmarkResult):
        """ Add a benchmark result to the distribution
        """
        gemm_param, _, gemm_config, measurement = benchmark_result
        self._gemm_config_dist[gemm_config].append((gemm_param, measurement))
        self._gemm_config_freq[gemm_config] += 1

    def distribution(self):
        """ Return the mapping GEMMConfig -> list of (GEMMParam, Measurement) recorded for it
        """
        return self._gemm_config_dist

    def frequency(self):
        """ Get the frequency of each (best) gemm config recorded
        """
        return self._gemm_config_freq.most_common()

    def best_config(self):
        """ Get the overall best config, as voted by all benchmark results.

        :return: A list holding at most one (GEMMConfig, votes) pair, as returned by Counter.most_common(1).
        """
        return self._gemm_config_freq.most_common(1)

    def std(self):
        """ Get the standard deviation as a measure of dispersion of the distribution. We should aim for higher values
        as they indicate there is high variation in the distribution. Thus the evidence of the best config is stronger.
        Returns 0 when nothing has been recorded.
        """
        freqs = self._gemm_config_freq.values()
        if len(freqs) == 0:
            return 0
        mean_freq = sum(freqs) / len(freqs)
        return math.sqrt(sum((freq - mean_freq) ** 2 for freq in freqs) / len(freqs))


################################################################################
# Globals
################################################################################

# Gemm config type factory
# Produces a GEMMConfig type specific to a Strategy
GEMM_CONFIG_FACTORY = {
    Strategy.Native: NativeGEMMConfig,
    Strategy.ReshapedOnlyRHS: ReshapedOnlyRHSGEMMConfig,
    Strategy.Reshaped: ReshapedGEMMConfig,
}

# Mapping from example binary name to Strategy
# Assume 1-to-1 mapping
EXAMPLE_FILE_2_STRATEGY = {
    "benchmark_cl_gemm_native": Strategy.Native,
    "benchmark_cl_gemm_reshaped_rhs_only": Strategy.ReshapedOnlyRHS,
    "benchmark_cl_gemm_reshaped": Strategy.Reshaped,
}

# Gemm example arguments type factory
# Produces a Gemm_Example_Args type specific to a Strategy
# Gemm example arguments consist of:
#   GEMMParam + GEMMConfig
#   in that order.
# For example, the example args of running a reshaped rhs only example could be:
#   100,100,100,1, 4, 4, 4, 1,             1,            1,                     0
#   M  ,N  ,K,  B,m0,n0,k0,h0,interleave_rhs,transpose_rhs,export_to_cl_image_rhs
#   <-GEMMParam-><-------------GEMMConfig--------------------------------------->
# Note that the test strategy_name == strategy.name is in place to avoid unwanted enum aliases
GEMM_EXAMPLE_ARGS_FACTORY = {
    # We ignore the data type field from GEMMParam as that is extracted separately
    strategy: namedtuple(
        "{}_Gemm_Example_Args".format(strategy_name),
        GEMMParam._fields[:-1] + GEMM_CONFIG_FACTORY[strategy]._fields,
    )
    for strategy_name, strategy in Strategy.__members__.items()
    if strategy_name == strategy.name
}

# File extension used for benchmark result json files
BENCHMARK_RESULT_JSON_EXTENSION = "gemmtuner_benchmark"

################################################################################
# Functions
################################################################################


def parse_benchmark_commandline(commandline: str) -> Dict[str, str]:
    """ Parse the benchmark example command-line string into a dictionary of command-line arguments

    Each "--name=value" token (after the program name) becomes an entry {"name": "value"}. Only the
    first '=' separates name from value, so values may contain '='; a token without '=' maps to "".
    """
    # Separate the data type option from the example_args portion of the string
    commandline = commandline.replace(",--type=", " --type=")

    # Discard program name
    args = commandline.split()[1:]

    parsed_args = {}
    for arg in args:
        # Split into (argument name, argument value) on the first '=' only
        name, _, val = arg.partition("=")
        # Strip '-'/"--" if it exists
        parsed_args[name.lstrip("-")] = val
    return parsed_args


def extract_benchmark_results(
    json_results: Dict, measurement_method="avg"
) -> Generator[BenchmarkResult, None, None]:
    """ Parse the benchmark result and extract relevant information, namely:
        GEMM param,
        Strategy,
        GEMM config,
        Measurements

    :param json_results:       Iterable of benchmark result JSON objects (e.g. as yielded by parse_json)
    :param measurement_method: "min" or "avg": how the raw OpenCL timer samples are reduced to one value
    :raises ValueError:        If measurement_method is neither "min" nor "avg" (raised on first iteration)
    """
    # Validate once instead of once per measurement
    if measurement_method not in ("min", "avg"):
        raise ValueError(
            "Invalid measurement method: {}".format(measurement_method)
        )

    for json_res in json_results:
        # Get example test and test data.
        # There should only be 1 test per run
        example_tests = list(json_res["tests"].items())
        assert len(example_tests) == 1
        example_fn, example_test_data = example_tests[0]

        # Process example file name
        example_fn = example_fn.split(os.path.sep)[-1]

        # Get strategy
        strategy = EXAMPLE_FILE_2_STRATEGY[example_fn]

        # Get gemm params + gemm configs from example args
        benchmark_args = parse_benchmark_commandline(json_res["CommandLine"])
        Gemm_Example_Args_T = GEMM_EXAMPLE_ARGS_FACTORY[strategy]
        example_args = Gemm_Example_Args_T(
            *(benchmark_args["example_args"].split(",")))
        # Gemm_Example_Arg consists of GEMMParam first and then GEMMConfig (in that order)
        # However data type option is parsed separately from end of options, hence -1 is applied to fields length
        gemm_param_fields_len = len(GEMMParam._fields) - 1
        gemm_param = GEMMParam.parse_from_strs(
            *example_args[:gemm_param_fields_len],
            data_type=benchmark_args["type"])
        gemm_config_t = GEMM_CONFIG_FACTORY[strategy]
        gemm_config = gemm_config_t.parse_from_strs(
            *example_args[gemm_param_fields_len:])

        # Get OpenCL_Time_Ms stats
        # For reshaped RHS only we have two measurements (one also for the reshape kernel)
        # Hence we must parse and sum them
        measurement_ms_reshape = 0
        measurement_ms_kernel = 0
        for measurement_instrument, data in example_test_data["measurements"].items():
            # Get instrument name and assert that it is the one we expect
            measurement_instrument_name = measurement_instrument.split("/")[0]
            assert measurement_instrument_name == "OpenCLTimer"
            raw = data["raw"]
            # Take either the minimum or the average of the raw data as the measurement value
            if measurement_method == "min":
                measurement_val = min(raw)
            else:
                measurement_val = sum(raw) / len(raw)

            measurement_type = measurement_instrument.split("/")[1]
            if "reshape" in measurement_type.split("_"):
                measurement_ms_reshape = measurement_val
            else:
                measurement_ms_kernel = measurement_val

        measurement = Measurement(
            measurement_ms_reshape, measurement_ms_kernel)

        yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)


def parse_json(dir_name):
    """ Glob all benchmark result json files and parse them into json objects (dicts).
    The search is recursive and matches files ending in BENCHMARK_RESULT_JSON_EXTENSION.
    """
    for res_fn in Path(dir_name).rglob("*.{}".format(BENCHMARK_RESULT_JSON_EXTENSION)):
        with open(res_fn) as res_fp:
            yield json.load(res_fp)


def check_out_path(out_path):
    """ Return True if out_path may be written: it does not exist, or the user confirms overwriting it.
    """
    if os.path.exists(out_path):
        overwrite = (
            input(
                "Output JSON {} already exists. Overwrite? [Y/N]: ".format(
                    out_path)
            ).lower()
            == "y"
        )
        if not overwrite:
            logging.info("Skipping {}".format(out_path))
            return False
    logging.info("Saving JSON file to {}".format(out_path))
    return True


def dump_json(out_path, dict):
    """ Write the given mapping to out_path as JSON.
    The parameter name shadows the builtin `dict`; it is kept for interface compatibility.
    """
    with open(out_path, "w") as f:
        json.dump(dict, f)
        logging.info("Saved")


################################################################################
# Main
################################################################################


def main(args):
    """ Collect benchmark results, log the per-strategy best configs, and optionally save them as JSON.
    """
    logging.info(
        "Searching best gemm configurations from {}".format(
            args.benchmark_results_dir)
    )

    benchmark_results = extract_benchmark_results(
        parse_json(args.benchmark_results_dir)
    )

    # Add all benchmark results to the recorder
    benchmark_result_recorder = GEMMBenchmarkResultRecorder(tol=args.tolerance)
    for benchmark_result in benchmark_results:
        benchmark_result_recorder.add(benchmark_result)

    if args.debug:
        recorder_sum_level = GEMMBenchmarkResultRecorder.SummaryLevel.Detailed
    else:
        recorder_sum_level = GEMMBenchmarkResultRecorder.SummaryLevel.Short

    # Print overall summary of the recorded results
    logging.info(benchmark_result_recorder.summary(
        sum_level=recorder_sum_level))

    # Get GEMM configuration distributions for each strategy
    all_config_dists = benchmark_result_recorder.get_config_distributions()

    logging.info("=== Result ===")
    for strategy, config_dist in all_config_dists.items():
        logging.info("Strategy: {}".format(strategy.name))
        logging.debug("GEMM Config, Votes")
        for config, freq in config_dist.frequency():
            logging.debug("{}, {}".format(config, freq))
        logging.info(
            "Best GEMM Config: {} with std: {}".format(
                config_dist.best_config(), config_dist.std()
            )
        )

    # Save the recorded results to JSON files in output directory
    if args.output_dir is not None:
        benchmark_result_recorder.save_to_jsons(
            args.output_dir, only_best_config=(not args.debug)
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CL GEMM Tuner")
    parser.add_argument(
        "-b",
        "--benchmark_results",
        dest="benchmark_results_dir",
        metavar="PATH",
        action="store",
        type=str,
        help="Path to benchmark result directory, where benchmark result json files have a file "
        "extension of '{}'".format(BENCHMARK_RESULT_JSON_EXTENSION),
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        dest="output_dir",
        metavar="PATH",
        action="store",
        type=str,
        help="Path to directory that holds output JSON files. One for strategy selection and one per strategy for GEMM config selection",
    )
    parser.add_argument(
        "-t",
        "--tolerance",
        action="store",
        type=float,
        default=0.01,
        help="For testing if two GEMMConfigs are equivalent in terms of performance. The tolerance is OpenCL timer in "
        "milliseconds. Recommended value: <= 0.1 ms",
    )
    parser.add_argument(
        "-D",
        "--debug",
        dest="debug",
        action="store_true",
        help="Enable script debugging output",
    )
    args = parser.parse_args()
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=logging_level)
    logging.debug("Arguments: {}".format(args))
    main(args)