# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Write events to disk in a base directory."""
import os
import time
import signal
import queue
from collections import deque

import psutil

import mindspore.log as logger
from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum

from ._lineage_adapter import serialize_to_lineage_event
from ._summary_adapter import package_graph_event, package_summary_event
from ._explain_adapter import package_explain_event
from .writer import LineageWriter, SummaryWriter, ExplainWriter, ExportWriter

try:
    from multiprocessing import get_context
    ctx = get_context('forkserver')
except ValueError:
    import multiprocessing as ctx

# Seconds the subprocess blocks on the command queue before re-checking the
# heartbeat and any finished pack results. A bounded wait keeps the loop
# responsive while avoiding a CPU-burning busy-wait.
_POLL_INTERVAL = 0.1


def _pack_data(datadict, wall_time):
    """
    Pack plugin data into serialized events, dispatching on the plugin kind.

    Args:
        datadict (dict): Mapping from plugin name to a list of data dicts
            (each with at least a 'value'; scalar-like entries also carry
            'tag' and 'step', and any entry may carry 'export_option').
        wall_time (float): Timestamp attached to the packed summary event.

    Returns:
        list[list]: ``[plugin, payload]`` pairs ready to be written.
    """
    result, summaries, step = [], [], None
    for plugin, datalist in datadict.items():
        for data in datalist:
            if plugin == PluginEnum.GRAPH.value:
                result.append([plugin, package_graph_event(data.get('value')).SerializeToString()])
            elif plugin in (PluginEnum.TRAIN_LINEAGE.value, PluginEnum.EVAL_LINEAGE.value,
                            PluginEnum.CUSTOM_LINEAGE_DATA.value, PluginEnum.DATASET_GRAPH.value):
                result.append([plugin, serialize_to_lineage_event(plugin, data.get('value'))])
            elif plugin in (PluginEnum.SCALAR.value, PluginEnum.TENSOR.value, PluginEnum.HISTOGRAM.value,
                            PluginEnum.IMAGE.value):
                # Scalar-like plugins are batched into one summary event below.
                summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')})
                step = data.get('step')
            elif plugin == PluginEnum.EXPLAINER.value:
                result.append([plugin, package_explain_event(data.get('value'))])

            if 'export_option' in data:
                result.append([WriterPluginEnum.EXPORTER.value, data])

    if summaries:
        result.append(
            [WriterPluginEnum.SUMMARY.value, package_summary_event(summaries, step, wall_time).SerializeToString()])
    return result


class WriterPool(ctx.Process):
    """
    Use a set of pooled resident processes for writing a list of file.

    Args:
        base_dir (str): The base directory to hold all the files.
        max_file_size (Optional[int]): The maximum size of each file that can be written to disk in bytes.
        raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs
            in recording data. Default: False, this means that error logs are printed and no exception is thrown.
        export_options (Union[None, dict]): Perform custom operations on the export data. Default: None.
        filedict (dict): The mapping from plugin to filename.
    """

    def __init__(self, base_dir, max_file_size, raise_exception=False, **filedict) -> None:
        super().__init__()
        self._base_dir, self._filedict = base_dir, filedict
        # Bounded queue applies back-pressure to the training process if the
        # writer subprocess falls behind.
        self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None
        self._max_file_size = max_file_size
        self._raise_exception = raise_exception
        # Remembered so the subprocess can detect when training has exited.
        self._training_pid = os.getpid()
        self.start()

    def run(self):
        """Subprocess main loop: pack incoming data in a pool and write results to disk."""
        # Environment variables are used to specify a maximum number of OpenBLAS threads:
        # In ubuntu(GPU) environment, numpy will use too many threads for computing,
        # it may affect the start of the summary process.
        # Notice: At present, the performance of setting the thread to 2 has been tested to be more suitable.
        # If it is to be adjusted, it is recommended to test according to the scenario first
        os.environ['OPENBLAS_NUM_THREADS'] = '2'
        os.environ['GOTO_NUM_THREADS'] = '2'
        os.environ['OMP_NUM_THREADS'] = '2'

        # Prevent the multiprocess from capturing KeyboardInterrupt,
        # which causes the main process to fail to exit.
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        with ctx.Pool(min(ctx.cpu_count(), 32)) as pool:
            deq = deque()
            while True:
                if self._check_heartbeat():
                    self._close()
                    return

                # Flush pack results in submission order as they complete.
                while deq and deq[0].ready():
                    for plugin, data in deq.popleft().get():
                        self._write(plugin, data)

                try:
                    # Block with a short timeout instead of polling with
                    # block=False: the original spin loop pinned a CPU core
                    # whenever the queue was empty.
                    action, data = self._queue.get(block=True, timeout=_POLL_INTERVAL)
                    if action == 'WRITE':
                        deq.append(pool.apply_async(_pack_data, (data, time.time())))
                    elif action == 'FLUSH':
                        self._flush()
                    elif action == 'END':
                        break
                except queue.Empty:
                    continue

            # Drain any still-pending pack results before shutting down.
            for result in deq:
                for plugin, data in result.get():
                    self._write(plugin, data)

            self._close()

    @property
    def _writers(self):
        """Get the writers in the subprocess, creating them lazily on first use."""
        if self._writers_ is not None:
            return self._writers_
        self._writers_ = []
        for plugin, filename in self._filedict.items():
            filepath = os.path.join(self._base_dir, filename)
            if plugin == WriterPluginEnum.SUMMARY.value:
                self._writers_.append(SummaryWriter(filepath, self._max_file_size))
            elif plugin == WriterPluginEnum.LINEAGE.value:
                self._writers_.append(LineageWriter(filepath, self._max_file_size))
            elif plugin == WriterPluginEnum.EXPLAINER.value:
                self._writers_.append(ExplainWriter(filepath, self._max_file_size))
            elif plugin == WriterPluginEnum.EXPORTER.value:
                self._writers_.append(ExportWriter(filepath, self._max_file_size))
        return self._writers_

    def _write(self, plugin, data):
        """Write the data in the subprocess; a failing writer is dropped from the pool."""
        # Iterate over a copy so we can safely remove failed writers in-place.
        for writer in self._writers[:]:
            try:
                writer.write(plugin, data)
            except (RuntimeError, OSError) as exc:
                logger.error(str(exc))
                self._writers.remove(writer)
                writer.close()
                if self._raise_exception:
                    raise
            except RuntimeWarning as exc:
                logger.warning(str(exc))
                self._writers.remove(writer)
                writer.close()

    def _flush(self):
        """Flush the writers in the subprocess."""
        for writer in self._writers:
            writer.flush()

    def _close(self):
        """Close the writers in the subprocess."""
        for writer in self._writers:
            writer.close()
        super().close()

    def write(self, data) -> None:
        """
        Write the event to file.

        Args:
            data (Optional[str, Tuple[list, int]]): The data to write.
        """
        self._queue.put(('WRITE', data))

    def flush(self):
        """Flush the writer and sync data to disk."""
        self._queue.put(('FLUSH', None))

    def close(self) -> None:
        """Close the writer."""
        self._queue.put(('END', None))

    def _check_heartbeat(self):
        """Check if the summary process should survive."""
        is_exit = False
        if not psutil.pid_exists(self._training_pid):
            logger.warning("The training process %d has exited, summary process will exit.", self._training_pid)
            is_exit = True

        if not self._writers:
            logger.warning("Can not find any writer to write summary data, "
                           "so SummaryRecord will not record data.")
            is_exit = True

        return is_exit