• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Record the summary event."""
16import atexit
17import os
18import re
19import threading
20import time
21from collections import defaultdict
22
23from mindspore import log as logger
24from mindspore.nn import Cell
25
26from ..._c_expression import Tensor, security
27from ..._checkparam import Validator
28from .._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type
29from ._summary_adapter import get_event_file_name, package_graph_event
30from ._explain_adapter import check_explain_proto
31from ._writer_pool import WriterPool
32
# For the moment, this lock is for caution's sake;
# there are actually no concurrent accesses happening.
_summary_lock = threading.Lock()
# Cache for summary data bubbled up from the network's summary operators,
# keyed by tag name; consumed (and cleared) by _get_summary_tensor_data.
_summary_tensor_cache = {}
# Supported `export_options` keys mapped to the set of accepted values.
# None means the corresponding data is not exported.
_DEFAULT_EXPORT_OPTIONS = {
    'tensor_format': {'npy', None},
}
41
42
def _cache_summary_tensor_data(summary):
    """
    Cache summary tensor data under the module-level lock.

    Args:
         summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...].
    """
    with _summary_lock:
        # Later entries with the same name overwrite earlier ones, as dict update does.
        _summary_tensor_cache.update((entry['name'], entry['data']) for entry in summary)
        return True
54
55
def _get_summary_tensor_data():
    """Atomically take ownership of the cached summary data and reset the cache."""
    global _summary_tensor_cache
    with _summary_lock:
        # Swap in a fresh dict so the caller gets exclusive access to the old one.
        taken, _summary_tensor_cache = _summary_tensor_cache, {}
        return taken
62
63
def process_export_options(export_options):
    """
    Validate `export_options` and return it unchanged.

    Args:
        export_options (Union[None, dict]): Mapping from an export option name
            (e.g. 'tensor_format') to the desired format (str) or None.

    Returns:
        Union[None, dict]: The validated options, or None when none were given.

    Raises:
        TypeError: If the options mapping, a key, or a value has the wrong type.
        ValueError: If a key or a format value is not supported.
    """
    if export_options is None:
        return None

    check_value_type('export_options', export_options, [dict, type(None)])

    # Every key must be a str and every value a str or None.
    for export_option, export_format in export_options.items():
        check_value_type('export_option', export_option, [str])
        check_value_type('export_format', export_format, [str, type(None)])

    unexpected_params = set(export_options) - set(_DEFAULT_EXPORT_OPTIONS)
    if unexpected_params:
        raise ValueError(f'For `export_options` the keys {unexpected_params} are unsupported, '
                         f'expect the follow keys: {list(_DEFAULT_EXPORT_OPTIONS.keys())}')

    # Each value must be one of the formats supported for its option.
    for export_option, export_format in export_options.items():
        unexpected_format = {export_format} - _DEFAULT_EXPORT_OPTIONS.get(export_option)
        if unexpected_format:
            raise ValueError(
                f'For `export_options`, the export_format {unexpected_format} are unsupported for {export_option}, '
                f'expect the follow values: {list(_DEFAULT_EXPORT_OPTIONS.get(export_option))}')

    # NOTE: the original trailing loop re-checking value types was removed —
    # the first loop already raises for any non-str/None value, so it was dead code.
    return export_options
91
92
class SummaryRecord:
    """
    SummaryRecord is used to record the summary data and lineage data.

    The API will create a summary file and lineage files lazily in a given directory and writes data to them.
    It writes the data to files by executing the 'record' method. In addition to recording the data bubbled up from
    the network by defining the summary operators, SummaryRecord also supports to record extra data which
    can be added by calling add_value.

    Note:
        1. Make sure to close the SummaryRecord at the end, otherwise the process will not exit.
           Please see the Example section below to learn how to close properly in two ways.
        2. Only one SummaryRecord instance is allowed at a time, otherwise it will cause data writing problems.
        3. SummaryRecord only supports Linux systems.

    Args:
        log_dir (str): The log_dir is a directory location to save the summary.
        file_prefix (str): The prefix of file. Default: "events".
        file_suffix (str): The suffix of file. Default: "_MS".
        network (Cell): Obtain a pipeline through network for saving graph summary. Default: None.
        max_file_size (int, optional): The maximum size of each file that can be written to disk (in bytes).
            For example, to write not larger than 4GB, specify `max_file_size=4*1024**3`.
            Default: None, which means no limit.
        raise_exception (bool, optional): Sets whether to throw an exception when a RuntimeError or OSError exception
            occurs in recording data. Default: False, this means that error logs are printed and no exception is thrown.
        export_options (Union[None, dict]): Perform custom operations on the export data.
            Note that the size of export files is not limited by the max_file_size.
            You can customize the export data with a dictionary. For example, you can set {'tensor_format': 'npy'}
            to export tensor as npy file. The data that supports control is shown below. Default: None, it means that
            the data is not exported.

            - tensor_format (Union[str, None]): Customize the export tensor format. Supports ["npy", None].
              Default: None, it means that the tensor is not exported.

              - npy: export tensor as npy file.

    Raises:
        TypeError: If the parameter type is incorrect.

    Examples:
        >>> from mindspore.train.summary import SummaryRecord
        >>> if __name__ == '__main__':
        ...     # use in with statement to auto close
        ...     with SummaryRecord(log_dir="./summary_dir") as summary_record:
        ...         pass
        ...
        ...     # use in try .. finally .. to ensure closing
        ...     try:
        ...         summary_record = SummaryRecord(log_dir="./summary_dir")
        ...     finally:
        ...         summary_record.close()
    """

    def __init__(self, log_dir, file_prefix="events", file_suffix="_MS",
                 network=None, max_file_size=None, raise_exception=False, export_options=None):

        # Summary recording is unavailable in security-enabled MindSpore builds.
        if security.enable_security():
            raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')

        self._event_writer = None
        # _data_pool maps a plugin name to the list of pending value dicts;
        # it is flushed to the writer pool by `record`.
        self._mode, self._data_pool = 'train', defaultdict(list)
        self._status = {
            'closed': False,
            'has_graph': False
        }
        self.file_info = {
            'file_name': None,
            'file_path': None
        }
        Validator.check_str_by_regular(file_prefix)
        Validator.check_str_by_regular(file_suffix)

        log_path = _make_directory(log_dir, "log_dir")

        if not isinstance(max_file_size, (int, type(None))):
            raise TypeError("The 'max_file_size' should be int type.")

        if not isinstance(file_prefix, str) or not isinstance(file_suffix, str):
            raise TypeError("`file_prefix` and `file_suffix`  should be str.")

        # A negative limit is treated as "no limit" with a warning rather than an error.
        if max_file_size is not None and max_file_size < 0:
            logger.warning("The 'max_file_size' should be greater than 0.")
            max_file_size = None

        Validator.check_value_type(arg_name='raise_exception', arg_value=raise_exception, valid_types=bool)

        self.network = network

        time_second = str(int(time.time()))
        # create the summary writer file
        self.file_info['file_name'] = get_event_file_name(file_prefix, file_suffix, time_second)
        self.file_info['file_path'] = os.path.join(log_path, self.file_info.get('file_name'))

        self._export_options = process_export_options(export_options)
        export_dir = ''
        if self._export_options is not None:
            export_dir = "export_{}".format(time_second)

        # One writer per data kind; the exporter entry is a directory name, not a file name.
        filename_dict = dict(summary=self.file_info.get('file_name'),
                             lineage=get_event_file_name(file_prefix, '_lineage', time_second),
                             explainer=get_event_file_name(file_prefix, '_explain', time_second),
                             exporter=export_dir)
        self._event_writer = WriterPool(log_dir,
                                        max_file_size,
                                        raise_exception,
                                        **filename_dict)
        # Discard any summary data cached before this record existed.
        _get_summary_tensor_data()
        # Make sure the writer pool is closed even if the user forgets to call `close`.
        atexit.register(self.close)

    def __enter__(self):
        """Enter the context manager."""
        # A closed record cannot be reused: the underlying writer pool is gone.
        if self._status.get('closed'):
            raise ValueError('SummaryRecord has been closed.')
        return self

    def __exit__(self, *err):
        """Exit the context manager."""
        self.close()

    def set_mode(self, mode):
        """
        Set the training phase. Different training phases affect data recording.

        Args:
            mode (str): The mode to be set, which should be 'train' or 'eval'. When the mode is 'eval',
                summary_record will not record the data of summary operators.

        Raises:
            ValueError: When the mode is not recognized.

        Examples:
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
            ...         summary_record.set_mode('eval')
        """
        mode_spec = 'train', 'eval'
        if mode not in mode_spec:
            raise ValueError(f'{repr(mode)} is not a recognized mode.')
        self._mode = mode

    def add_value(self, plugin, name, value):
        """
        Add value to be recorded later.

        Args:
            plugin (str): The value of the plugin.
            name (str): The value of the name.
            value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \
                The value to store.

                - The data type of value should be 'GraphProto' (see mindspore/ccsrc/anf_ir.proto) object
                  when the plugin is 'graph'.
                - The data type of value should be 'Tensor' object when the plugin is 'scalar', 'image', 'tensor'
                  or 'histogram'.
                - The data type of value should be a 'TrainLineage' object when the plugin is 'train_lineage',
                  see mindspore/ccsrc/lineage.proto.
                - The data type of value should be a 'EvaluationLineage' object when the plugin is 'eval_lineage',
                  see mindspore/ccsrc/lineage.proto.
                - The data type of value should be a 'DatasetGraph' object when the plugin is 'dataset_graph',
                  see mindspore/ccsrc/lineage.proto.
                - The data type of value should be a 'UserDefinedInfo' object when the plugin is 'custom_lineage_data',
                  see mindspore/ccsrc/lineage.proto.
                - The data type of value should be a 'Explain' object when the plugin is 'explainer',
                  see mindspore/ccsrc/summary.proto.
        Raises:
            ValueError: If the parameter value is invalid.
            TypeError: If the parameter type is error.

        Examples:
            >>> from mindspore import Tensor
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
            ...         summary_record.add_value('scalar', 'loss', Tensor(0.1))
        """
        if plugin in ('tensor', 'scalar', 'image', 'histogram'):
            if not name or not isinstance(name, str):
                raise ValueError(f'{repr(name)} is not a valid tag name.')
            if not isinstance(value, Tensor):
                raise TypeError(f'Expect the value to be Tensor, but got {type(value).__name__}')
            np_value = _check_to_numpy(plugin, value)
            # Duplicate tags within the same pending pool: warn, but still append —
            # the newest entry wins when the pool is written out.
            if name in {item['tag'] for item in self._data_pool[plugin]}:
                entry = repr(f'{name}/{plugin}')
                logger.warning(f'{entry} has duplicate values. Only the newest one will be recorded.')
            data = dict(tag=name, value=np_value)
            # Attach the per-plugin export option (e.g. 'tensor_format') if configured.
            export_plugin = '{}_format'.format(plugin)
            if self._export_options is not None and export_plugin in self._export_options:
                data['export_option'] = self._export_options.get(export_plugin)
            self._data_pool[plugin].append(data)

        elif plugin in ('train_lineage', 'eval_lineage', 'dataset_graph', 'custom_lineage_data'):
            _check_lineage_value(plugin, value)
            # Lineage protos are serialized immediately; only the bytes are pooled.
            self._data_pool[plugin].append(dict(value=value.SerializeToString()))
        elif plugin == 'graph':
            # package_graph_event validates the proto; the raw value is pooled.
            package_graph_event(value)
            self._data_pool[plugin].append(dict(value=value))
        elif plugin == 'explainer':
            check_explain_proto(value)
            self._data_pool[plugin].append(dict(value=value.SerializeToString()))
        else:
            raise ValueError(f'No such plugin of {repr(plugin)}')

    def record(self, step, train_network=None, plugin_filter=None):
        """
        Record the summary.

        Args:
            step (int): Represents training step number.
            train_network (Cell): The spare network for saving graph.
                Default: None, it means just do not saving the graph summary when the original network graph is None.
            plugin_filter (Optional[Callable[[str], bool]]): The filter function, \
                which is used to filter out plugins from being written by returning False. Default: None.

        Returns:
            bool, whether the record process is successful or not.

        Raises:
            TypeError: If the parameter type is error.
            RuntimeError: If the disk space is insufficient.

        Examples:
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
            ...         summary_record.record(step=2)
            ...
            True
        """
        logger.debug("SummaryRecord step is %r.", step)
        Validator.check_value_type(arg_name='step', arg_value=step, valid_types=int)
        Validator.check_value_type(arg_name='train_network', arg_value=train_network, valid_types=[Cell, type(None)])

        if self._status.get('closed'):
            logger.error("The record writer is closed.")
            return False
        # Set the current summary of train step
        # The graph is written at most once (guarded by 'has_graph').
        if self.network is not None and not self._status.get('has_graph'):
            graph_proto = self.network.get_func_graph_proto()
            if graph_proto is None and train_network is not None:
                # Fall back to the spare network when the original has no graph proto.
                graph_proto = train_network.get_func_graph_proto()
            if graph_proto is None:
                logger.error("Failed to get proto for graph")
            else:
                self._event_writer.write({'graph': [{'step': step, 'value': graph_proto}]})
                self._status['has_graph'] = True
                # With no cached summary data, writing the graph alone counts as success.
                if not _summary_tensor_cache:
                    return True

        # Summary-operator data is only collected in 'train' mode (see set_mode).
        if self._mode == 'train':
            self._add_summary_tensor_data()

        if not plugin_filter:
            self._event_writer.write(self._consume_data_pool(step))
        else:
            # Only write plugins for which the filter returns a truthy value.
            filtered = {}
            for plugin, datalist in self._consume_data_pool(step).items():
                if plugin_filter(plugin):
                    filtered[plugin] = datalist
            self._event_writer.write(filtered)
        return True

    def _add_summary_tensor_data(self):
        # Drain the module-level cache filled by the network's summary operators.
        summary_data = _get_summary_tensor_data()
        if not summary_data:
            logger.debug(f'No summary data bubbled from the network.')
        for name, tensor in summary_data.items():
            tag, plugin = SummaryRecord._parse_from(name)
            if (tag, plugin) == (None, None):
                logger.warning("The name(%r) is invalid, expected 'TAG[:TYPE]'.", name)
            else:
                self.add_value(plugin.lower(), tag, tensor)

    def _consume_data_pool(self, step):
        # Stamp every pending value with the step, hand the pool to the caller,
        # and start a fresh pool (the finally runs before the return value is used).
        try:
            for values in self._data_pool.values():
                for value in values:
                    value['step'] = step
            return self._data_pool
        finally:
            self._data_pool = defaultdict(list)

    @property
    def log_dir(self):
        """
        Get the full path of the log file.

        Returns:
            str, the full path of log file.

        Examples:
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
            ...         log_dir = summary_record.log_dir
        """
        return self.file_info['file_path']

    def flush(self):
        """
        Flush the event file to disk.

        Call it to make sure that all pending events have been written to disk.

        Examples:
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
            ...         summary_record.flush()
        """
        if self._status.get('closed'):
            logger.error("The record writer is closed and can not flush.")
        elif self._event_writer:
            self._event_writer.flush()

    def close(self):
        """
        Flush all events and close summary records. Please use the statement to autoclose.

        Examples:
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     try:
            ...         summary_record = SummaryRecord(log_dir="./summary_dir")
            ...     finally:
            ...         summary_record.close()
        """
        if not self._status.get('closed') and self._event_writer:
            # event writer flush and close
            logger.info('Please wait it may take quite some time to finish writing and closing.')
            # Unregister first so an explicit close is not repeated at interpreter exit.
            atexit.unregister(self.close)
            self._event_writer.close()
            self._event_writer.join()
            self._status['closed'] = True

    @staticmethod
    def _parse_from(name: str = None):
        """Parse the tag and type from name."""
        # Expects the literal form 'TAG[:TYPE]', brackets included,
        # e.g. 'loss[:Scalar]' -> ('loss', 'Scalar'); anything else yields (None, None).
        if not isinstance(name, str):
            return None, None
        match = re.match(r'(.+)\[:(.+)\]', name)
        if match:
            return match.groups()
        return None, None
437