• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Thr parser for parsing framework files."""
16import csv
17import enum
18import json
19import os
20import re
21import stat
22
23from mindspore.profiler.common.exceptions.exceptions import \
24    ProfilerPathErrorException, ProfilerDirNotFoundException, \
25    ProfilerFileNotFoundException, ProfilerDeviceIdMismatchException, \
26    ProfilerRawFileException, ProfilerParamValueErrorException
27from mindspore.profiler.common.validator.validate_path import \
28    validate_and_normalize_path
29
30
class VmDataType(enum.IntEnum):
    """Definition of vm data type."""
    NUMBER_TYPE_BEGIN = 30
    NUMBER_TYPE_BOOL = 31
    NUMBER_TYPE_INT = 32
    NUMBER_TYPE_INT8 = 33
    NUMBER_TYPE_INT16 = 34
    NUMBER_TYPE_INT32 = 35
    NUMBER_TYPE_INT64 = 36
    NUMBER_TYPE_UINT = 37
    NUMBER_TYPE_UINT8 = 38
    NUMBER_TYPE_UINT16 = 39
    NUMBER_TYPE_UINT32 = 40
    NUMBER_TYPE_UINT64 = 41
    NUMBER_TYPE_FLOAT = 42
    NUMBER_TYPE_FLOAT16 = 43
    NUMBER_TYPE_FLOAT32 = 44
    NUMBER_TYPE_FLOAT64 = 45
    NUMBER_TYPE_COMPLEX = 46
    NUMBER_TYPE_END = 47

    @classmethod
    def get_data_type_name(cls, num):
        """
        Look up the name of a vm data type by its enum number.

        Args:
            num (int): Enum number.

        Returns:
            str, the member name for `num`, or 'UNKNOWN' when the number
            does not map to any member.
        """
        member = cls._value2member_map_.get(num)
        if member is None:
            return 'UNKNOWN'
        return member.name
65
66
class GeDataType(enum.IntEnum):
    """Definition of ge data type."""
    DT_FLOAT = 0
    DT_FLOAT16 = 1
    DT_INT8 = 2
    DT_INT16 = 6
    DT_UINT16 = 7
    DT_UINT8 = 4
    DT_INT32 = 3
    DT_INT64 = 9
    DT_UINT32 = 8
    DT_UINT64 = 10
    DT_BOOL = 12
    DT_DOUBLE = 11
    DT_STRING = 13
    DT_DUAL_SUB_INT8 = 14
    DT_DUAL_SUB_UINT8 = 15
    DT_COMPLEX64 = 16
    DT_COMPLEX128 = 17
    DT_QINT8 = 18
    DT_QINT16 = 19
    DT_QINT32 = 20
    DT_QUINT8 = 21
    DT_QUINT16 = 22
    DT_RESOURCE = 23
    DT_STRING_REF = 24
    DT_DUAL = 25
    DT_UNDEFINED = 26

    @classmethod
    def get_data_type_name(cls, num):
        """
        Look up the name of a ge data type by its enum number.

        Args:
            num (int): Enum number.

        Returns:
            str, the member name for `num`, or 'UNKNOWN' when the number
            does not map to any member.
        """
        member = cls._value2member_map_.get(num)
        if member is None:
            return 'UNKNOWN'
        return member.name
109
110
class GeFormat(enum.IntEnum):
    """Definition of ge format type."""
    FORMAT_NCHW = 0
    FORMAT_NHWC = 1
    FORMAT_ND = 2
    FORMAT_NC1HWC0 = 3
    FORMAT_FRACTAL_Z = 4
    FORMAT_NC1C0HWPAD = 5
    FORMAT_NHWC1C0 = 6
    FORMAT_FSR_NCHW = 7
    FORMAT_FRACTAL_DECONV = 8
    FORMAT_C1HWNC0 = 9
    FORMAT_FRACTAL_DECONV_TRANSPOSE = 10
    FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11
    FORMAT_NC1HWC0_C04 = 12
    FORMAT_FRACTAL_Z_C04 = 13
    FORMAT_CHWN = 14
    FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15
    FORMAT_HWCN = 16
    FORMAT_NC1KHKWHWC0 = 17
    FORMAT_BN_WEIGHT = 18
    FORMAT_FILTER_HWCK = 19
    FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20
    FORMAT_HASHTABLE_LOOKUP_KEYS = 21
    FORMAT_HASHTABLE_LOOKUP_VALUE = 22
    FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23
    FORMAT_HASHTABLE_LOOKUP_HITS = 24
    FORMAT_C1HWNCOC0 = 25
    FORMAT_MD = 26
    FORMAT_NDHWC = 27
    FORMAT_FRACTAL_ZZ = 28
    FORMAT_FRACTAL_NZ = 29
    FORMAT_NCDHW = 30
    FORMAT_DHWCN = 31
    FORMAT_NDC1HWC0 = 32
    FORMAT_FRACTAL_Z_3D = 33
    FORMAT_CN = 34
    FORMAT_NC = 35
    FORMAT_DHWNC = 36
    FORMAT_FRACTAL_Z_3D_TRANSPOSE = 37
    FORMAT_RESERVED = 38
    FORMAT_ALL = 39

    @classmethod
    def get_format_name(cls, num):
        """
        Look up the name of a ge format type by its enum number.

        Args:
            num (int): Enum number.

        Returns:
            str, the member name for `num`, or 'UNKNOWN' when the number
            does not map to any member.
        """
        member = cls._value2member_map_.get(num)
        if member is None:
            return 'UNKNOWN'
        return member.name
167
168
class FrameworkParser:
    """
    The parser for parsing framework files.

    Args:
        profiling_id (str): The profiling ID.
        device_id (str): The device ID.
        rank_id (str): The rank ID.
        output_path (str): The directory of the parsed file. Default: `./`.
    """
    _regex_framework = r'Framework\.(?P<data_type>.+)\.(?P<device_id>\d).+'
    _regex_framework_in_data = r'Framework\.(?P<data_type>.+)\.' \
                               r'(?P<device_id>\d)\.(?P<profiling_id>[a-zA-Z0-9]+).+'
    _col_names = [
        'task_id', 'stream_id', 'block_dim', 'full_op_name', 'op_name',
        'op_type', 'subgraph', 'op_info'
    ]
    _graph_attr_name = [
        'input_format', 'input_data_type', 'input_shape', 'output_format',
        'output_data_type', 'output_shape'
    ]

    # if the task id is less than the task id threshold, the combination of
    # task id and stream id represents one operator, else the task id represents
    # one operator
    _task_id_threshold = 25000

    def __init__(self, profiling_id, device_id, rank_id, output_path='./'):
        self._raw_data_dir = output_path
        self._profiling_path = self._get_raw_profiling_path(profiling_id)
        self._backend_type = None
        self._framework_path = {'graph': [], 'task': [], 'point': []}
        self._search_file(profiling_id, device_id)
        self._device_id = device_id
        self._save_path = self._get_save_path(rank_id, output_path)
        self._task_id_full_op_name_dict = {}
        self._task_cache = {}
        self._point_info = {}
        self._parse_task_files()
        self._parse_point_files()

    @property
    def save_path(self):
        """
        The property of save path.

        Returns:
            str, the save path.
        """
        return self._save_path

    @property
    def point_info(self):
        """
        The property of the framework point information.

        Returns:
            dict, the framework point information.
        """
        return self._point_info

    def to_task_id_full_op_name_dict(self):
        """
        Get the task id and full operator name dict.

        Returns:
            dict, the task id and full operator name dict.
        """
        return self._task_id_full_op_name_dict

    def parse(self):
        """Parse the framework files."""
        self._parse_graph_files_and_save(self._task_cache)
        del self._task_cache

    def check_op_name(self, op_name, is_prefix=True):
        """
        Check whether the operator name exists.

        Args:
            op_name (str): The operator name or operator name prefix.
            is_prefix (bool): `True` if the op_name is prefix, else `False`.
                Default: True.

        Returns:
            bool, `True` if the operator name does exist in framework file, else
            `False`.

        Raises:
            ProfilerParamValueErrorException: If `op_name` is empty.
        """
        if not op_name:
            raise ProfilerParamValueErrorException('The op_name should exist.')
        for full_op_name in self._task_id_full_op_name_dict.values():
            if full_op_name:
                if is_prefix and full_op_name.startswith(op_name):
                    return True
                if not is_prefix and op_name == full_op_name:
                    return True
        return False

    def _get_raw_profiling_path(self, profiling_id):
        """
        Get raw profiling path.

        Args:
            profiling_id (str): The profiling ID.

        Returns:
            str, the raw profiling path.

        Raises:
            ProfilerPathErrorException: If the profiling path is invalid.
            ProfilerDirNotFoundException: If the profiling dir is not found.
        """
        profiling_path = os.path.join(self._raw_data_dir, profiling_id)
        try:
            profiling_path = validate_and_normalize_path(profiling_path)
        except RuntimeError as err:
            # chain the cause so the original validation failure is kept
            raise ProfilerPathErrorException('Profiling path is invalid.') from err
        if not os.path.isdir(profiling_path):
            raise ProfilerDirNotFoundException(profiling_path)
        return profiling_path

    def _search_file(self, profiling_id, device_id):
        """
        Search all framework files in raw profiling path.

        Args:
            profiling_id (str): The profiling ID.
            device_id (str): The device ID.

        Raises:
            ProfilerFileNotFoundException: If the framework files are not found.
        """
        # first search in the JOB dir, and if not, search in the sub directory
        # in the JOB
        self._search_file_from_job_path(device_id, search_in_sub_path=False)
        if self._backend_type is None:
            self._search_file_from_job_path(device_id, search_in_sub_path=True)
        self._search_file_from_data_path(profiling_id, device_id)

        if self._backend_type is None:
            raise ProfilerFileNotFoundException('Framework')
        # sort so that multi-slice files are parsed in a deterministic order
        self._framework_path['graph'].sort()
        self._framework_path['task'].sort()

    def _search_file_from_job_path(self, device_id, search_in_sub_path=False):
        """
        Search framework files from job path.

        Args:
            device_id (str): The device ID.
            search_in_sub_path (bool): `True` if search file in profiling dir,
                else search in profiling sub dir. Default: False.

        Raises:
            ProfilerRawFileException: If the framework file type is inconsistent.
            ProfilerDeviceIdMismatchException: If the device id is mismatch
                with framework in the raw dir.
        """
        profiling_dir = os.path.join(self._profiling_path, 'data') \
            if search_in_sub_path else self._profiling_path
        if not os.path.isdir(profiling_dir):
            return

        for file in os.listdir(profiling_dir):
            pattern = re.search(self._regex_framework, file)
            # skip non-framework files and '.done' completion markers
            if not pattern or file.endswith('.done'):
                continue
            attrs = pattern.groupdict()

            if attrs.get('device_id') != device_id:
                raise ProfilerDeviceIdMismatchException()

            self._record_framework_file(
                attrs.get('data_type'), os.path.join(profiling_dir, file)
            )

    def _search_file_from_data_path(self, profiling_id, device_id):
        """
        Search framework files from data path.

        Args:
            profiling_id (str): The profiling ID.
            device_id (str): The device ID.

        Raises:
            ProfilerRawFileException: If the framework file type is inconsistent.
            ProfilerDeviceIdMismatchException: If the device id is mismatch
                with framework in the raw dir.
        """
        profiling_data_path = os.path.join(
            self._raw_data_dir, 'container', device_id, 'data'
        )
        if not os.path.isdir(profiling_data_path):
            return

        for file in os.listdir(profiling_data_path):
            pattern = re.search(self._regex_framework_in_data, file)
            # skip non-framework files, '.done' markers and zipped archives
            if not pattern or file.endswith('.done') or file.endswith('.zip'):
                continue
            attrs = pattern.groupdict()

            # files from other profiling runs may share this dir; ignore them
            if attrs.get('profiling_id') != profiling_id:
                continue

            if attrs.get('device_id') != device_id:
                raise ProfilerDeviceIdMismatchException()

            self._record_framework_file(
                attrs.get('data_type'), os.path.join(profiling_data_path, file)
            )

    def _record_framework_file(self, data_type, file_path):
        """
        Check backend consistency and record a framework file path by its type.

        Args:
            data_type (str): The data type field parsed from the file name.
            file_path (str): The absolute path of the framework file.

        Raises:
            ProfilerRawFileException: If the framework file type is inconsistent
                with the backend type detected from previous files.
        """
        data_type = data_type.replace("host.", "")
        # a 'vm_' prefix marks the vm backend; anything else is the ge backend.
        # all framework files of one run must come from the same backend.
        if data_type.startswith('vm_'):
            if self._backend_type and self._backend_type != 'vm':
                raise ProfilerRawFileException('Backend type is inconsistent.')
            self._backend_type = 'vm'
            _, data_type = data_type.split('_', 1)
        else:
            if self._backend_type and self._backend_type != 'ge':
                raise ProfilerRawFileException('Backend type is inconsistent.')
            self._backend_type = 'ge'
        if data_type.startswith('graph_desc_info'):
            self._framework_path['graph'].append(file_path)
        elif data_type.startswith('task_desc_info'):
            self._framework_path['task'].append(file_path)
        elif data_type.startswith('point'):
            self._framework_path['point'].append(file_path)

    def _get_save_path(self, rank_id, output_path):
        """
        Get the save path.

        Args:
            rank_id (str): The rank ID.
            output_path (str): The output dir.

        Returns:
            str, the save path.

        Raises:
            ProfilerPathErrorException: If the output path is invalid.
            ProfilerDirNotFoundException: If the output dir is not found.
        """
        try:
            output_dir = validate_and_normalize_path(output_path)
        except RuntimeError as err:
            # chain the cause so the original validation failure is kept
            raise ProfilerPathErrorException('Output path is invalid.') from err
        if not os.path.isdir(output_dir):
            raise ProfilerDirNotFoundException(output_dir)
        return os.path.join(
            output_dir, '_'.join(['framework', 'raw', rank_id]) + '.csv'
        )

    def _parse_task_files(self):
        """Parse the framework task files."""
        for path in self._framework_path['task']:
            path = validate_and_normalize_path(path)
            with open(path, 'r') as file:
                for task_info in file:
                    infos = task_info.strip('\n').split(' ')
                    # some formats carry a leading extra field; drop it
                    infos = infos[1:] if len(infos) == 5 else infos
                    # key is op name, values is task id, stream id, block_dim
                    self._task_cache[infos[0]] = [infos[2], infos[3], infos[1]]

                    # if the task id is less than the task id threshold, the
                    # stream id and task id correspond to an operator
                    task_id = infos[2]
                    if int(task_id) < self._task_id_threshold:
                        task_id = '_'.join([infos[3], task_id])
                    self._task_id_full_op_name_dict[task_id] = infos[0]

    def _parse_graph_files_and_save(self, task_cache):
        """
        Parse the framework graph files and save the framework information.

        Args:
            task_cache (dict): The task information cache.
        """
        with open(self._save_path, 'w') as save_file:
            csv_writer = csv.writer(save_file)
            csv_writer.writerow(self._col_names)
            pre_graph_info = None
            for path in self._framework_path['graph']:
                first_row = True
                with open(path, 'r') as graph_file:
                    for graph_info in graph_file:
                        if first_row is True:
                            first_row = False
                            # The last row of the previous file and the first row of the current file may need
                            # to be combined to one row
                            if graph_info.startswith("op_name:") is False:
                                pre_graph_info = pre_graph_info + graph_info
                                continue
                        if pre_graph_info is not None:
                            self._parse_graph_row_and_save(task_cache, csv_writer, pre_graph_info)
                        pre_graph_info = graph_info

            # flush the last buffered row once all files are consumed
            if pre_graph_info is not None:
                self._parse_graph_row_and_save(task_cache, csv_writer, pre_graph_info)

            # tasks with no matching graph row are still written out,
            # padded with empty graph columns
            none_list = [None, None, None, None]
            for key, value in task_cache.items():
                value.append(key)
                value.extend(none_list)
                csv_writer.writerow(value)
        # restrict the result file to owner read/write only
        os.chmod(self._save_path, stat.S_IREAD | stat.S_IWRITE)

    def _parse_graph_row_and_save(self, task_cache, csv_writer, graph_info):
        """
        Parse the framework graph row and save the framework information.

        Args:
            task_cache (dict): The task information cache.
            csv_writer (csv): Csv writer.
            graph_info (str): Row info of graph.
        """
        result = self._parse_one_row_graph_info(graph_info)
        task_info = task_cache.get(result[0])
        if task_info:
            task_info.extend(result)
            csv_writer.writerow(task_info)
            # consumed entries must not be re-emitted in the padding pass
            del task_cache[result[0]]
        else:
            # no matching task: pad task_id/stream_id/block_dim columns
            save_info = [None, None, None]
            save_info.extend(result)
            csv_writer.writerow(save_info)

    def _parse_one_row_graph_info(self, row_info):
        """
        Parse the graph information in one row.

        Args:
            row_info (str): One row graph information.

        Returns:
            list[str], the parsed graph information.
        """
        full_op_name = None
        op_name = None
        subgraph_name = None
        op_type = None
        op_info = {}
        cur_op_info_key = None

        infos = row_info.strip('\n').split(' ')
        for info in infos:
            attr_name, attr_value = info.split(':', 1)
            if attr_name == 'op_name':
                full_op_name = attr_value
                subgraph_name = self._get_subgraph_name(full_op_name)
                op_name = self._get_op_name(full_op_name, subgraph_name)
            elif attr_name == 'op_type':
                op_type = attr_value
            elif attr_name in ['input_id', 'output_id']:
                # start a new input_N/output_N sub-dict; following attrs
                # belong to it until the next input_id/output_id
                cur_op_info_key = '{}_{}'.format(
                    attr_name.split('_')[0], attr_value
                )
                op_info[cur_op_info_key] = {}
            elif attr_name in self._graph_attr_name:
                op_attr = attr_name.split('_', 1)[1]
                if op_attr == 'shape':
                    attr_value = attr_value.strip('"')
                # numeric type/format codes are translated to readable names
                # according to the backend that produced the file
                if self._backend_type == 'vm':
                    if op_attr == 'data_type':
                        attr_value = VmDataType.get_data_type_name(
                            int(attr_value)
                        )
                else:
                    if op_attr == 'data_type':
                        attr_value = GeDataType.get_data_type_name(
                            int(attr_value)
                        )
                    elif op_attr == 'format':
                        attr_value = GeFormat.get_format_name(int(attr_value))

                op_info[cur_op_info_key][op_attr] = attr_value

        # the list info are full_op_name, op_name, op_type, subgraph, op_info
        return [full_op_name, op_name, op_type, subgraph_name,
                json.dumps(op_info)]

    def _get_subgraph_name(self, full_op_name):
        """
        Get subgraph name.

        Args:
            full_op_name (str): The full operator name.

        Returns:
            str, the subgraph name, or None when the first path segment is
            not a known subgraph.
        """
        subgraph_name = full_op_name.split('/', 1)[0]
        if subgraph_name in ['Default', 'Gradients']:
            return subgraph_name
        return None

    def _get_op_name(self, full_op_name, subgraph_name):
        """
        Get operator name.

        Args:
            full_op_name (str): The full operator name.
            subgraph_name (str): The subgraph name.

        Returns:
            str, the operator name.
        """
        if subgraph_name is None:
            return full_op_name

        if self._backend_type == 'vm':
            return full_op_name.split('/')[-1]

        # ge backend: the subgraph prefix may occur several times; join the
        # leaf name of each remaining segment with '+'
        strs = full_op_name.split(subgraph_name + '/')
        op_name = None
        for name_str in strs:
            if not name_str:
                continue
            if op_name is None:
                op_name = name_str.split('/')[-1]
            else:
                op_name = '+'.join([op_name, name_str.split('/')[-1]])
        return op_name

    def _parse_point_files(self):
        """Parse the framework point files."""
        for path in self._framework_path['point']:
            path = validate_and_normalize_path(path)
            with open(path, 'r') as file:
                for point_info in file:
                    # each row is '<point id> <op name>'
                    infos = point_info.strip('\n').split(' ')
                    self._point_info[int(infos[0])] = infos[1]
629