# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The parser for parsing framework files."""
import csv
import enum
import json
import os
import re
import stat

from mindspore.profiler.common.exceptions.exceptions import \
    ProfilerPathErrorException, ProfilerDirNotFoundException, \
    ProfilerFileNotFoundException, ProfilerDeviceIdMismatchException, \
    ProfilerRawFileException, ProfilerParamValueErrorException
from mindspore.profiler.common.validator.validate_path import \
    validate_and_normalize_path


class VmDataType(enum.IntEnum):
    """Definition of vm data type."""
    NUMBER_TYPE_BEGIN = 30
    NUMBER_TYPE_BOOL = 31
    NUMBER_TYPE_INT = 32
    NUMBER_TYPE_INT8 = 33
    NUMBER_TYPE_INT16 = 34
    NUMBER_TYPE_INT32 = 35
    NUMBER_TYPE_INT64 = 36
    NUMBER_TYPE_UINT = 37
    NUMBER_TYPE_UINT8 = 38
    NUMBER_TYPE_UINT16 = 39
    NUMBER_TYPE_UINT32 = 40
    NUMBER_TYPE_UINT64 = 41
    NUMBER_TYPE_FLOAT = 42
    NUMBER_TYPE_FLOAT16 = 43
    NUMBER_TYPE_FLOAT32 = 44
    NUMBER_TYPE_FLOAT64 = 45
    NUMBER_TYPE_COMPLEX = 46
    NUMBER_TYPE_END = 47

    @classmethod
    def get_data_type_name(cls, num):
        """
        Get the name of data type by enum number.

        Args:
            num (int): Enum number.

        Returns:
            str, the name of data type, or 'UNKNOWN' if the number is not
                a member of this enum.
        """
        data_type = cls._value2member_map_.get(num)
        return 'UNKNOWN' if data_type is None else data_type.name


class GeDataType(enum.IntEnum):
    """Definition of ge data type."""
    DT_FLOAT = 0
    DT_FLOAT16 = 1
    DT_INT8 = 2
    DT_INT16 = 6
    DT_UINT16 = 7
    DT_UINT8 = 4
    DT_INT32 = 3
    DT_INT64 = 9
    DT_UINT32 = 8
    DT_UINT64 = 10
    DT_BOOL = 12
    DT_DOUBLE = 11
    DT_STRING = 13
    DT_DUAL_SUB_INT8 = 14
    DT_DUAL_SUB_UINT8 = 15
    DT_COMPLEX64 = 16
    DT_COMPLEX128 = 17
    DT_QINT8 = 18
    DT_QINT16 = 19
    DT_QINT32 = 20
    DT_QUINT8 = 21
    DT_QUINT16 = 22
    DT_RESOURCE = 23
    DT_STRING_REF = 24
    DT_DUAL = 25
    DT_UNDEFINED = 26

    @classmethod
    def get_data_type_name(cls, num):
        """
        Get the name of data type by enum number.

        Args:
            num (int): Enum number.

        Returns:
            str, the name of data type, or 'UNKNOWN' if the number is not
                a member of this enum.
        """
        data_type = cls._value2member_map_.get(num)
        return 'UNKNOWN' if data_type is None else data_type.name


class GeFormat(enum.IntEnum):
    """Definition of ge format type."""
    FORMAT_NCHW = 0
    FORMAT_NHWC = 1
    FORMAT_ND = 2
    FORMAT_NC1HWC0 = 3
    FORMAT_FRACTAL_Z = 4
    FORMAT_NC1C0HWPAD = 5
    FORMAT_NHWC1C0 = 6
    FORMAT_FSR_NCHW = 7
    FORMAT_FRACTAL_DECONV = 8
    FORMAT_C1HWNC0 = 9
    FORMAT_FRACTAL_DECONV_TRANSPOSE = 10
    FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11
    FORMAT_NC1HWC0_C04 = 12
    FORMAT_FRACTAL_Z_C04 = 13
    FORMAT_CHWN = 14
    FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15
    FORMAT_HWCN = 16
    FORMAT_NC1KHKWHWC0 = 17
    FORMAT_BN_WEIGHT = 18
    FORMAT_FILTER_HWCK = 19
    FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20
    FORMAT_HASHTABLE_LOOKUP_KEYS = 21
    FORMAT_HASHTABLE_LOOKUP_VALUE = 22
    FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23
    FORMAT_HASHTABLE_LOOKUP_HITS = 24
    FORMAT_C1HWNCOC0 = 25
    FORMAT_MD = 26
    FORMAT_NDHWC = 27
    FORMAT_FRACTAL_ZZ = 28
    FORMAT_FRACTAL_NZ = 29
    FORMAT_NCDHW = 30
    FORMAT_DHWCN = 31
    FORMAT_NDC1HWC0 = 32
    FORMAT_FRACTAL_Z_3D = 33
    FORMAT_CN = 34
    FORMAT_NC = 35
    FORMAT_DHWNC = 36
    FORMAT_FRACTAL_Z_3D_TRANSPOSE = 37
    FORMAT_RESERVED = 38
    FORMAT_ALL = 39

    @classmethod
    def get_format_name(cls, num):
        """
        Get the name of format type by enum number.

        Args:
            num (int): Enum number.

        Returns:
            str, the name of format type, or 'UNKNOWN' if the number is not
                a member of this enum.
        """
        format_type = cls._value2member_map_.get(num)
        return 'UNKNOWN' if format_type is None else format_type.name


class FrameworkParser:
    """
    The parser for parsing framework files.

    Args:
        profiling_id (str): The profiling ID.
        device_id (str): The device ID.
        rank_id (str): The rank ID.
        output_path (str): The directory of the parsed file. Default: `./`.
    """
    _regex_framework = r'Framework\.(?P<data_type>.+)\.(?P<device_id>\d).+'
    _regex_framework_in_data = r'Framework\.(?P<data_type>.+)\.' \
                               r'(?P<device_id>\d)\.(?P<profiling_id>[a-zA-Z0-9]+).+'
    _col_names = [
        'task_id', 'stream_id', 'block_dim', 'full_op_name', 'op_name',
        'op_type', 'subgraph', 'op_info'
    ]
    _graph_attr_name = [
        'input_format', 'input_data_type', 'input_shape', 'output_format',
        'output_data_type', 'output_shape'
    ]

    # if the task id is less than the task id threshold, The combination of
    # task id and Stream id represents one operator, else the task id represents
    # one operator
    _task_id_threshold = 25000

    def __init__(self, profiling_id, device_id, rank_id, output_path='./'):
        self._raw_data_dir = output_path
        self._profiling_path = self._get_raw_profiling_path(profiling_id)
        # 'vm' or 'ge', decided by the names of the framework files found.
        self._backend_type = None
        self._framework_path = {'graph': [], 'task': [], 'point': []}
        self._search_file(profiling_id, device_id)
        self._device_id = device_id
        self._save_path = self._get_save_path(rank_id, output_path)
        self._task_id_full_op_name_dict = {}
        self._task_cache = {}
        self._point_info = {}
        self._parse_task_files()
        self._parse_point_files()

    @property
    def save_path(self):
        """
        The property of save path.

        Returns:
            str, the save path.
        """
        return self._save_path

    @property
    def point_info(self):
        """
        The property of the framework point information.

        Returns:
            dict, the framework point information.
        """
        return self._point_info

    def to_task_id_full_op_name_dict(self):
        """
        Get the task id and full operator name dict.

        Returns:
            dict, the task id and full operator name dict.
        """
        return self._task_id_full_op_name_dict

    def parse(self):
        """Parse the framework files."""
        self._parse_graph_files_and_save(self._task_cache)
        del self._task_cache

    def check_op_name(self, op_name, is_prefix=True):
        """
        Check whether the operator name exists.

        Args:
            op_name (str): The operator name or operator name prefix.
            is_prefix (bool): `True` if the op_name is prefix, else `False`.
                Default: True.

        Returns:
            bool, `True` if the operator name does exist in framework file, else
            `False`.

        Raises:
            ProfilerParamValueErrorException: If `op_name` is empty.
        """
        if not op_name:
            raise ProfilerParamValueErrorException('The op_name should exist.')
        for full_op_name in self._task_id_full_op_name_dict.values():
            if full_op_name:
                if is_prefix and full_op_name.startswith(op_name):
                    return True
                if not is_prefix and op_name == full_op_name:
                    return True
        return False

    def _get_raw_profiling_path(self, profiling_id):
        """
        Get raw profiling path.

        Args:
            profiling_id (str): The profiling ID.

        Returns:
            str, the raw profiling path.

        Raises:
            ProfilerPathErrorException: If the profiling path is invalid.
            ProfilerDirNotFoundException: If the profiling dir is not found.
        """
        profiling_path = os.path.join(self._raw_data_dir, profiling_id)
        try:
            profiling_path = validate_and_normalize_path(profiling_path)
        except RuntimeError as err:
            # Chain the cause so the original validation failure is kept.
            raise ProfilerPathErrorException('Profiling path is invalid.') from err
        if not os.path.isdir(profiling_path):
            raise ProfilerDirNotFoundException(profiling_path)
        return profiling_path

    def _search_file(self, profiling_id, device_id):
        """
        Search all framework files in raw profiling path.

        Args:
            profiling_id (str): The profiling ID.
            device_id (str): The device ID.

        Raises:
            ProfilerFileNotFoundException: If the framework files are not found.
        """
        # first search in the JOB dir, and if not, search in the sub directory
        # in the JOB
        self._search_file_from_job_path(device_id, search_in_sub_path=False)
        if self._backend_type is None:
            self._search_file_from_job_path(device_id, search_in_sub_path=True)
        self._search_file_from_data_path(profiling_id, device_id)

        if self._backend_type is None:
            raise ProfilerFileNotFoundException('Framework')
        self._framework_path['graph'].sort()
        self._framework_path['task'].sort()

    def _set_backend_type(self, backend_type):
        """
        Record the backend type, checking consistency with earlier files.

        Args:
            backend_type (str): The backend type, 'vm' or 'ge'.

        Raises:
            ProfilerRawFileException: If the backend type conflicts with the
                type detected from previously found framework files.
        """
        if self._backend_type and self._backend_type != backend_type:
            raise ProfilerRawFileException('Backend type is inconsistent.')
        self._backend_type = backend_type

    def _classify_framework_file(self, data_type, file_path):
        """
        Classify one framework file by its data type and record its path.

        Args:
            data_type (str): The `data_type` field parsed from the file name.
            file_path (str): The absolute path of the framework file.

        Raises:
            ProfilerRawFileException: If the framework file type is inconsistent.
        """
        data_type = data_type.replace("host.", "")
        # A 'vm_' prefix marks the vm backend; anything else is ge.
        if data_type.startswith('vm_'):
            self._set_backend_type('vm')
            _, data_type = data_type.split('_', 1)
        else:
            self._set_backend_type('ge')
        if data_type.startswith('graph_desc_info'):
            self._framework_path['graph'].append(file_path)
        elif data_type.startswith('task_desc_info'):
            self._framework_path['task'].append(file_path)
        elif data_type.startswith('point'):
            self._framework_path['point'].append(file_path)

    def _search_file_from_job_path(self, device_id, search_in_sub_path=False):
        """
        Search framework files from job path.

        Args:
            device_id (str): The device ID.
            search_in_sub_path (bool): `True` if search file in profiling dir,
                else search in profiling sub dir. Default: False.

        Raises:
            ProfilerRawFileException: If the framework file type is inconsistent.
            ProfilerDeviceIdMismatchException: If the device id is mismatch
                with framework in the raw dir.
        """
        profiling_dir = os.path.join(self._profiling_path, 'data') \
            if search_in_sub_path else self._profiling_path
        if not os.path.isdir(profiling_dir):
            return

        for file in os.listdir(profiling_dir):
            pattern = re.search(self._regex_framework, file)
            if not pattern or file.endswith('.done'):
                continue
            attrs = pattern.groupdict()

            if attrs.get('device_id') != device_id:
                raise ProfilerDeviceIdMismatchException()

            self._classify_framework_file(
                attrs.get('data_type'), os.path.join(profiling_dir, file)
            )

    def _search_file_from_data_path(self, profiling_id, device_id):
        """
        Search framework files from data path.

        Args:
            profiling_id (str): The profiling ID.
            device_id (str): The device ID.

        Raises:
            ProfilerRawFileException: If the framework file type is inconsistent.
            ProfilerDeviceIdMismatchException: If the device id is mismatch
                with framework in the raw dir.
        """
        profiling_data_path = os.path.join(
            self._raw_data_dir, 'container', device_id, 'data'
        )
        if not os.path.isdir(profiling_data_path):
            return

        for file in os.listdir(profiling_data_path):
            pattern = re.search(self._regex_framework_in_data, file)
            if not pattern or file.endswith('.done') or file.endswith('.zip'):
                continue
            attrs = pattern.groupdict()

            # Files of other profiling runs are silently skipped.
            if attrs.get('profiling_id') != profiling_id:
                continue

            if attrs.get('device_id') != device_id:
                raise ProfilerDeviceIdMismatchException()

            self._classify_framework_file(
                attrs.get('data_type'), os.path.join(profiling_data_path, file)
            )

    def _get_save_path(self, rank_id, output_path):
        """
        Get the save path.

        Args:
            rank_id (str): The rank ID.
            output_path (str): The output dir.

        Returns:
            str, the save path.

        Raises:
            ProfilerPathErrorException: If the output path is invalid.
            ProfilerDirNotFoundException: If the output dir is not found.
        """
        try:
            output_dir = validate_and_normalize_path(output_path)
        except RuntimeError as err:
            raise ProfilerPathErrorException('Output path is invalid.') from err
        if not os.path.isdir(output_dir):
            raise ProfilerDirNotFoundException(output_dir)
        return os.path.join(
            output_dir, '_'.join(['framework', 'raw', rank_id]) + '.csv'
        )

    def _parse_task_files(self):
        """Parse the framework task files."""
        for path in self._framework_path['task']:
            path = validate_and_normalize_path(path)
            with open(path, 'r') as file:
                for task_info in file:
                    infos = task_info.strip('\n').split(' ')
                    # A 5-field row carries a leading extra field; drop it so
                    # the remaining fields line up with the 4-field layout.
                    infos = infos[1:] if len(infos) == 5 else infos
                    # key is op name, values is task id, stream id, block_dim
                    self._task_cache[infos[0]] = [infos[2], infos[3], infos[1]]

                    # if the task id is less than the task id threshold, the
                    # stream id and task id correspond to an operator
                    task_id = infos[2]
                    if int(task_id) < self._task_id_threshold:
                        task_id = '_'.join([infos[3], task_id])
                    self._task_id_full_op_name_dict[task_id] = infos[0]

    def _parse_graph_files_and_save(self, task_cache):
        """
        Parse the framework graph files and save the framework information.

        Args:
            task_cache (dict): The task information cache.
        """
        with open(self._save_path, 'w') as save_file:
            csv_writer = csv.writer(save_file)
            csv_writer.writerow(self._col_names)
            pre_graph_info = None
            for path in self._framework_path['graph']:
                first_row = True
                with open(path, 'r') as graph_file:
                    for graph_info in graph_file:
                        if first_row:
                            first_row = False
                            # The last row of the previous file and the first
                            # row of the current file may need to be combined
                            # to one row.
                            if not graph_info.startswith("op_name:"):
                                # Guard against a continuation line appearing
                                # before any complete row has been buffered.
                                pre_graph_info = graph_info if pre_graph_info is None \
                                    else pre_graph_info + graph_info
                                continue
                        if pre_graph_info is not None:
                            self._parse_graph_row_and_save(task_cache, csv_writer, pre_graph_info)
                        pre_graph_info = graph_info

            # Flush the last buffered row.
            if pre_graph_info is not None:
                self._parse_graph_row_and_save(task_cache, csv_writer, pre_graph_info)

            # Tasks never matched by a graph row are written with empty
            # graph columns.
            none_list = [None, None, None, None]
            for key, value in task_cache.items():
                value.append(key)
                value.extend(none_list)
                csv_writer.writerow(value)
        os.chmod(self._save_path, stat.S_IREAD | stat.S_IWRITE)

    def _parse_graph_row_and_save(self, task_cache, csv_writer, graph_info):
        """
        Parse the framework graph row and save the framework information.

        Args:
            task_cache (dict): The task information cache.
            csv_writer (csv): Csv writer.
            graph_info (str): Row info of graph.
        """
        result = self._parse_one_row_graph_info(graph_info)
        task_info = task_cache.get(result[0])
        if task_info:
            task_info.extend(result)
            csv_writer.writerow(task_info)
            # Consumed entries are removed so the remainder can be written
            # with empty graph columns afterwards.
            del task_cache[result[0]]
        else:
            # No matching task: pad the task columns with None.
            save_info = [None, None, None]
            save_info.extend(result)
            csv_writer.writerow(save_info)

    def _parse_one_row_graph_info(self, row_info):
        """
        Parse the graph information in one row.

        Args:
            row_info (str): One row graph information.

        Returns:
            list[str], the parsed graph information.
        """
        full_op_name = None
        op_name = None
        subgraph_name = None
        op_type = None
        op_info = {}
        cur_op_info_key = None

        infos = row_info.strip('\n').split(' ')
        for info in infos:
            attr_name, attr_value = info.split(':', 1)
            if attr_name == 'op_name':
                full_op_name = attr_value
                subgraph_name = self._get_subgraph_name(full_op_name)
                op_name = self._get_op_name(full_op_name, subgraph_name)
            elif attr_name == 'op_type':
                op_type = attr_value
            elif attr_name in ['input_id', 'output_id']:
                # Start a new input_N/output_N section; subsequent graph
                # attributes are attached to this key.
                cur_op_info_key = '{}_{}'.format(
                    attr_name.split('_')[0], attr_value
                )
                op_info[cur_op_info_key] = {}
            elif attr_name in self._graph_attr_name:
                op_attr = attr_name.split('_', 1)[1]
                if op_attr == 'shape':
                    attr_value = attr_value.strip('"')
                if self._backend_type == 'vm':
                    if op_attr == 'data_type':
                        attr_value = VmDataType.get_data_type_name(
                            int(attr_value)
                        )
                else:
                    if op_attr == 'data_type':
                        attr_value = GeDataType.get_data_type_name(
                            int(attr_value)
                        )
                    elif op_attr == 'format':
                        attr_value = GeFormat.get_format_name(int(attr_value))

                op_info[cur_op_info_key][op_attr] = attr_value

        # the list info are full_op_name, op_name, op_type, subgraph, op_info
        return [full_op_name, op_name, op_type, subgraph_name,
                json.dumps(op_info)]

    def _get_subgraph_name(self, full_op_name):
        """
        Get subgraph name.

        Args:
            full_op_name (str): The full operator name.

        Returns:
            str, the subgraph name, or None if the operator does not belong
                to the 'Default' or 'Gradients' subgraph.
        """
        subgraph_name = full_op_name.split('/', 1)[0]
        if subgraph_name in ['Default', 'Gradients']:
            return subgraph_name
        return None

    def _get_op_name(self, full_op_name, subgraph_name):
        """
        Get operator name.

        Args:
            full_op_name (str): The full operator name.
            subgraph_name (str): The subgraph name.

        Returns:
            str, the operator name.
        """
        if subgraph_name is None:
            return full_op_name

        if self._backend_type == 'vm':
            return full_op_name.split('/')[-1]

        # For ge: join the leaf name of every scope segment with '+'.
        strs = full_op_name.split(subgraph_name + '/')
        op_name = None
        for name_str in strs:
            if not name_str:
                continue
            if op_name is None:
                op_name = name_str.split('/')[-1]
            else:
                op_name = '+'.join([op_name, name_str.split('/')[-1]])
        return op_name

    def _parse_point_files(self):
        """Parse the framework point files."""
        for path in self._framework_path['point']:
            path = validate_and_normalize_path(path)
            with open(path, 'r') as file:
                for point_info in file:
                    infos = point_info.strip('\n').split(' ')
                    # key: point id (int), value: op name at that point
                    self._point_info[int(infos[0])] = infos[1]