# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Record the summary event."""
import atexit
import os
import re
import threading
import time
from collections import defaultdict

from mindspore import log as logger
from mindspore.nn import Cell

from ..._c_expression import Tensor, security
from ..._checkparam import Validator
from .._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type
from ._summary_adapter import get_event_file_name, package_graph_event
from ._explain_adapter import check_explain_proto
from ._writer_pool import WriterPool

# for the moment, this lock is for caution's sake,
# there are actually no any concurrences happening.
_summary_lock = threading.Lock()
# cache the summary data
_summary_tensor_cache = {}
_DEFAULT_EXPORT_OPTIONS = {
    'tensor_format': {'npy', None},
}


def _cache_summary_tensor_data(summary):
    """
    Cache summary tensor data bubbled up from the network, keyed by tag name.

    A later entry with the same name overwrites the earlier one, so only the
    newest value per tag survives until the next drain.

    Args:
        summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...].

    Returns:
        bool, always True.
    """
    with _summary_lock:
        for item in summary:
            _summary_tensor_cache[item['name']] = item['data']
        return True


def _get_summary_tensor_data():
    """Atomically drain and return the cached summary tensor data."""
    global _summary_tensor_cache
    with _summary_lock:
        data = _summary_tensor_cache
        # Swap in a fresh dict so callers own the returned snapshot exclusively.
        _summary_tensor_cache = {}
        return data


def process_export_options(export_options):
    """Check specified data type and value."""
    if export_options is None:
        return None

    check_value_type('export_options', export_options, [dict, type(None)])

    for export_option, export_format in export_options.items():
        check_value_type('export_option', export_option, [str])
        check_value_type('export_format', export_format, [str, type(None)])

    unexpected_params = set(export_options) - set(_DEFAULT_EXPORT_OPTIONS)
    if unexpected_params:
        raise ValueError(f'For `export_options` the keys {unexpected_params} are unsupported, '
                         f'expect the follow keys: {list(_DEFAULT_EXPORT_OPTIONS.keys())}')

    for export_option, export_format in export_options.items():
        unexpected_format = {export_format} - _DEFAULT_EXPORT_OPTIONS.get(export_option)
        if unexpected_format:
            raise ValueError(
                f'For `export_options`, the export_format {unexpected_format} are unsupported for {export_option}, '
                f'expect the follow values: {list(_DEFAULT_EXPORT_OPTIONS.get(export_option))}')

    for item in set(export_options):
        check_value_type(item, export_options.get(item), [str, type(None)])

    return export_options


class SummaryRecord:
    """
    SummaryRecord is used to record the summary data and lineage data.

    The API will create a summary file and lineage files lazily in a given directory and writes data to them.
    It writes the data to files by executing the 'record' method. In addition to recording the data bubbled up from
    the network by defining the summary operators, SummaryRecord also supports to record extra data which
    can be added by calling add_value.

    Note:
        1. Make sure to close the SummaryRecord at the end, otherwise the process will not exit.
           Please see the Example section below to learn how to close properly in two ways.
        2. Only one SummaryRecord instance is allowed at a time, otherwise it will cause data writing problems.
        3. SummaryRecord only supports Linux systems.

    Args:
        log_dir (str): The log_dir is a directory location to save the summary.
        file_prefix (str): The prefix of file. Default: "events".
        file_suffix (str): The suffix of file. Default: "_MS".
        network (Cell): Obtain a pipeline through network for saving graph summary. Default: None.
        max_file_size (int, optional): The maximum size of each file that can be written to disk (in bytes).
            For example, to write not larger than 4GB, specify `max_file_size=4*1024**3`.
            Default: None, which means no limit.
        raise_exception (bool, optional): Sets whether to throw an exception when a RuntimeError or OSError exception
            occurs in recording data. Default: False, this means that error logs are printed and no exception is
            thrown.
        export_options (Union[None, dict]): Perform custom operations on the export data.
            Note that the size of export files is not limited by the max_file_size.
            You can customize the export data with a dictionary. For example, you can set {'tensor_format': 'npy'}
            to export tensor as npy file. The data that supports control is shown below. Default: None, it means that
            the data is not exported.

            - tensor_format (Union[str, None]): Customize the export tensor format. Supports ["npy", None].
              Default: None, it means that the tensor is not exported.

              - npy: export tensor as npy file.

    Raises:
        TypeError: If the parameter type is incorrect.

    Examples:
        >>> from mindspore.train.summary import SummaryRecord
        >>> if __name__ == '__main__':
        ...     # use in with statement to auto close
        ...     with SummaryRecord(log_dir="./summary_dir") as summary_record:
        ...         pass
        ...
        ...     # use in try .. finally .. to ensure closing
        ...     try:
        ...         summary_record = SummaryRecord(log_dir="./summary_dir")
        ...     finally:
        ...         summary_record.close()
    """

    def __init__(self, log_dir, file_prefix="events", file_suffix="_MS",
                 network=None, max_file_size=None, raise_exception=False, export_options=None):

        if security.enable_security():
            raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')

        self._event_writer = None
        self._mode, self._data_pool = 'train', defaultdict(list)
        self._status = {
            'closed': False,
            'has_graph': False
        }
        self.file_info = {
            'file_name': None,
            'file_path': None
        }
        Validator.check_str_by_regular(file_prefix)
        Validator.check_str_by_regular(file_suffix)

        log_path = _make_directory(log_dir, "log_dir")

        if not isinstance(max_file_size, (int, type(None))):
            raise TypeError("The 'max_file_size' should be int type.")

        if not isinstance(file_prefix, str) or not isinstance(file_suffix, str):
            raise TypeError("`file_prefix` and `file_suffix` should be str.")

        # A negative limit is meaningless; fall back to "no limit" with a warning.
        if max_file_size is not None and max_file_size < 0:
            logger.warning("The 'max_file_size' should be greater than 0.")
            max_file_size = None

        Validator.check_value_type(arg_name='raise_exception', arg_value=raise_exception, valid_types=bool)

        self.network = network

        time_second = str(int(time.time()))
        # create the summary writer file
        self.file_info['file_name'] = get_event_file_name(file_prefix, file_suffix, time_second)
        self.file_info['file_path'] = os.path.join(log_path, self.file_info.get('file_name'))

        self._export_options = process_export_options(export_options)
        export_dir = ''
        if self._export_options is not None:
            export_dir = "export_{}".format(time_second)

        filename_dict = dict(summary=self.file_info.get('file_name'),
                             lineage=get_event_file_name(file_prefix, '_lineage', time_second),
                             explainer=get_event_file_name(file_prefix, '_explain', time_second),
                             exporter=export_dir)
        self._event_writer = WriterPool(log_dir,
                                        max_file_size,
                                        raise_exception,
                                        **filename_dict)
        # Drain any summary data cached before this recorder existed so stale
        # values are not written at the first `record` call.
        _get_summary_tensor_data()
        # Ensure the writer pool is flushed and joined even if the user forgets
        # to call close(); unregistered again inside close() itself.
        atexit.register(self.close)

    def __enter__(self):
        """Enter the context manager."""
        if self._status.get('closed'):
            raise ValueError('SummaryRecord has been closed.')
        return self

    def __exit__(self, *err):
        """Exit the context manager."""
        self.close()

    def set_mode(self, mode):
        """
        Set the training phase. Different training phases affect data recording.

        Args:
            mode (str): The mode to be set, which should be 'train' or 'eval'. When the mode is 'eval',
                summary_record will not record the data of summary operators.

        Raises:
            ValueError: When the mode is not recognized.

        Examples:
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
            ...         summary_record.set_mode('eval')
        """
        mode_spec = 'train', 'eval'
        if mode not in mode_spec:
            raise ValueError(f'{repr(mode)} is not a recognized mode.')
        self._mode = mode

    def add_value(self, plugin, name, value):
        """
        Add value to be recorded later.

        Args:
            plugin (str): The value of the plugin.
            name (str): The value of the name.
            value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \
                The value to store.

                - The data type of value should be 'GraphProto' (see mindspore/ccsrc/anf_ir.proto) object
                  when the plugin is 'graph'.
                - The data type of value should be 'Tensor' object when the plugin is 'scalar', 'image', 'tensor'
                  or 'histogram'.
                - The data type of value should be a 'TrainLineage' object when the plugin is 'train_lineage',
                  see mindspore/ccsrc/lineage.proto.
                - The data type of value should be a 'EvaluationLineage' object when the plugin is 'eval_lineage',
                  see mindspore/ccsrc/lineage.proto.
                - The data type of value should be a 'DatasetGraph' object when the plugin is 'dataset_graph',
                  see mindspore/ccsrc/lineage.proto.
                - The data type of value should be a 'UserDefinedInfo' object when the plugin is 'custom_lineage_data',
                  see mindspore/ccsrc/lineage.proto.
                - The data type of value should be a 'Explain' object when the plugin is 'explainer',
                  see mindspore/ccsrc/summary.proto.
        Raises:
            ValueError: If the parameter value is invalid.
            TypeError: If the parameter type is error.

        Examples:
            >>> from mindspore import Tensor
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
            ...         summary_record.add_value('scalar', 'loss', Tensor(0.1))
        """
        if plugin in ('tensor', 'scalar', 'image', 'histogram'):
            if not name or not isinstance(name, str):
                raise ValueError(f'{repr(name)} is not a valid tag name.')
            if not isinstance(value, Tensor):
                raise TypeError(f'Expect the value to be Tensor, but got {type(value).__name__}')
            np_value = _check_to_numpy(plugin, value)
            if name in {item['tag'] for item in self._data_pool[plugin]}:
                entry = repr(f'{name}/{plugin}')
                logger.warning(f'{entry} has duplicate values. Only the newest one will be recorded.')
            data = dict(tag=name, value=np_value)
            export_plugin = '{}_format'.format(plugin)
            if self._export_options is not None and export_plugin in self._export_options:
                data['export_option'] = self._export_options.get(export_plugin)
            self._data_pool[plugin].append(data)

        elif plugin in ('train_lineage', 'eval_lineage', 'dataset_graph', 'custom_lineage_data'):
            _check_lineage_value(plugin, value)
            self._data_pool[plugin].append(dict(value=value.SerializeToString()))
        elif plugin == 'graph':
            package_graph_event(value)
            self._data_pool[plugin].append(dict(value=value))
        elif plugin == 'explainer':
            check_explain_proto(value)
            self._data_pool[plugin].append(dict(value=value.SerializeToString()))
        else:
            raise ValueError(f'No such plugin of {repr(plugin)}')

    def record(self, step, train_network=None, plugin_filter=None):
        """
        Record the summary.

        Args:
            step (int): Represents training step number.
            train_network (Cell): The spare network for saving graph.
                Default: None, it means just do not saving the graph summary when the original network graph is None.
            plugin_filter (Optional[Callable[[str], bool]]): The filter function, \
                which is used to filter out plugins from being written by returning False. Default: None.

        Returns:
            bool, whether the record process is successful or not.

        Raises:
            TypeError: If the parameter type is error.
            RuntimeError: If the disk space is insufficient.

        Examples:
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
            ...         summary_record.record(step=2)
            ...
            True
        """
        logger.debug("SummaryRecord step is %r.", step)
        Validator.check_value_type(arg_name='step', arg_value=step, valid_types=int)
        Validator.check_value_type(arg_name='train_network', arg_value=train_network, valid_types=[Cell, type(None)])

        if self._status.get('closed'):
            logger.error("The record writer is closed.")
            return False
        # Set the current summary of train step
        if self.network is not None and not self._status.get('has_graph'):
            graph_proto = self.network.get_func_graph_proto()
            if graph_proto is None and train_network is not None:
                # fall back to the spare network when the original graph is missing
                graph_proto = train_network.get_func_graph_proto()
            if graph_proto is None:
                logger.error("Failed to get proto for graph")
            else:
                self._event_writer.write({'graph': [{'step': step, 'value': graph_proto}]})
                self._status['has_graph'] = True
                # Graph-only first step: nothing bubbled from the network yet.
                if not _summary_tensor_cache:
                    return True

        if self._mode == 'train':
            self._add_summary_tensor_data()

        if not plugin_filter:
            self._event_writer.write(self._consume_data_pool(step))
        else:
            filtered = {}
            for plugin, datalist in self._consume_data_pool(step).items():
                if plugin_filter(plugin):
                    filtered[plugin] = datalist
            self._event_writer.write(filtered)
        return True

    def _add_summary_tensor_data(self):
        """Move the data drained from the summary-operator cache into the data pool."""
        summary_data = _get_summary_tensor_data()
        if not summary_data:
            logger.debug('No summary data bubbled from the network.')
        for name, tensor in summary_data.items():
            tag, plugin = SummaryRecord._parse_from(name)
            if (tag, plugin) == (None, None):
                logger.warning("The name(%r) is invalid, expected 'TAG[:TYPE]'.", name)
            else:
                self.add_value(plugin.lower(), tag, tensor)

    def _consume_data_pool(self, step):
        """Stamp every pooled value with `step` and hand the pool over, resetting it."""
        try:
            for values in self._data_pool.values():
                for value in values:
                    value['step'] = step
            return self._data_pool
        finally:
            # Reset even if stamping raises, so a failed record cannot leak
            # stale values into the next step.
            self._data_pool = defaultdict(list)

    @property
    def log_dir(self):
        """
        Get the full path of the log file.

        Returns:
            str, the full path of log file.

        Examples:
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
            ...         log_dir = summary_record.log_dir
        """
        return self.file_info['file_path']

    def flush(self):
        """
        Flush the event file to disk.

        Call it to make sure that all pending events have been written to disk.

        Examples:
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
            ...         summary_record.flush()
        """
        if self._status.get('closed'):
            logger.error("The record writer is closed and can not flush.")
        elif self._event_writer:
            self._event_writer.flush()

    def close(self):
        """
        Flush all events and close summary records. Please use the statement to autoclose.

        Examples:
            >>> from mindspore.train.summary import SummaryRecord
            >>> if __name__ == '__main__':
            ...     try:
            ...         summary_record = SummaryRecord(log_dir="./summary_dir")
            ...     finally:
            ...         summary_record.close()
        """
        if not self._status.get('closed') and self._event_writer:
            # event writer flush and close
            logger.info('Please wait it may take quite some time to finish writing and closing.')
            # Already closing now; drop the interpreter-exit hook registered in __init__.
            atexit.unregister(self.close)
            self._event_writer.close()
            self._event_writer.join()
            self._status['closed'] = True

    @staticmethod
    def _parse_from(name: str = None):
        """Parse the tag and type from name, expected in the form 'TAG[:TYPE]'."""
        if not isinstance(name, str):
            return None, None
        match = re.match(r'(.+)\[:(.+)\]', name)
        if match:
            return match.groups()
        return None, None