• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Write events to disk in a base directory."""
16import os
17import time
18import signal
19import queue
20from collections import deque
21
22import psutil
23
24import mindspore.log as logger
25from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum
26
27from ._lineage_adapter import serialize_to_lineage_event
28from ._summary_adapter import package_graph_event, package_summary_event
29from ._explain_adapter import package_explain_event
30from .writer import LineageWriter, SummaryWriter, ExplainWriter, ExportWriter
31
32try:
33    from multiprocessing import get_context
34    ctx = get_context('forkserver')
35except ValueError:
36    import multiprocessing as ctx
37
38
def _pack_data(datadict, wall_time):
    """Serialize the collected plugin data into a list of writable events.

    Args:
        datadict (dict): Mapping from plugin name to a list of raw data items.
        wall_time (float): Timestamp attached to the packed summary event.

    Returns:
        list: ``[plugin, serialized_data]`` pairs ready for the writers.
    """
    # Plugins whose values are merged into a single summary event.
    summary_plugins = (PluginEnum.SCALAR.value, PluginEnum.TENSOR.value,
                       PluginEnum.HISTOGRAM.value, PluginEnum.IMAGE.value)
    # Plugins serialized through the lineage adapter.
    lineage_plugins = (PluginEnum.TRAIN_LINEAGE.value, PluginEnum.EVAL_LINEAGE.value,
                       PluginEnum.CUSTOM_LINEAGE_DATA.value, PluginEnum.DATASET_GRAPH.value)

    packed = []
    summaries = []
    step = None
    for plugin, datalist in datadict.items():
        for data in datalist:
            value = data.get('value')
            if plugin == PluginEnum.GRAPH.value:
                packed.append([plugin, package_graph_event(value).SerializeToString()])
            elif plugin in lineage_plugins:
                packed.append([plugin, serialize_to_lineage_event(plugin, value)])
            elif plugin in summary_plugins:
                summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': value})
                # The last seen step is used for the combined summary event.
                step = data.get('step')
            elif plugin == PluginEnum.EXPLAINER.value:
                packed.append([plugin, package_explain_event(value)])

            # Any item may additionally request export, independent of its plugin.
            if 'export_option' in data:
                packed.append([WriterPluginEnum.EXPORTER.value, data])

    if summaries:
        packed.append(
            [WriterPluginEnum.SUMMARY.value,
             package_summary_event(summaries, step, wall_time).SerializeToString()])
    return packed
63
64
class WriterPool(ctx.Process):
    """
    Use a set of pooled resident processes for writing a list of file.

    The instance itself runs as a separate process (``run`` executes in the
    child); the training process communicates with it through a bounded queue
    of ('WRITE' | 'FLUSH' | 'END', payload) commands.

    Args:
        base_dir (str): The base directory to hold all the files.
        max_file_size (Optional[int]): The maximum size of each file that can be written to disk in bytes.
        raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs
            in recording data. Default: False, this means that error logs are printed and no exception is thrown.
        export_options (Union[None, dict]): Perform custom operations on the export data. Default: None.
        filedict (dict): The mapping from plugin to filename.
    """

    def __init__(self, base_dir, max_file_size, raise_exception=False, **filedict) -> None:
        super().__init__()
        self._base_dir, self._filedict = base_dir, filedict
        # Bounded queue (2 slots per CPU) so a stalled writer process applies
        # back-pressure on the producer instead of growing without limit.
        self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None
        self._max_file_size = max_file_size
        self._raise_exception = raise_exception
        # Pid of the training (parent) process, captured before start(); the
        # child uses it to detect that the parent died (_check_heartbeat).
        self._training_pid = os.getpid()
        # Launch the writer subprocess immediately; run() executes there.
        self.start()

    def run(self):
        """Child-process main loop: pack incoming data in a pool, write results."""
        # Environment variables are used to specify a maximum number of OpenBLAS threads:
        # In ubuntu(GPU) environment, numpy will use too many threads for computing,
        # it may affect the start of the summary process.
        # Notice: At present, the performance of setting the thread to 2 has been tested to be more suitable.
        # If it is to be adjusted, it is recommended to test according to the scenario first
        os.environ['OPENBLAS_NUM_THREADS'] = '2'
        os.environ['GOTO_NUM_THREADS'] = '2'
        os.environ['OMP_NUM_THREADS'] = '2'

        # Prevent the multiprocess from capturing KeyboardInterrupt,
        # which causes the main process to fail to exit.
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        # Worker pool for CPU-heavy serialization (_pack_data); writing to
        # disk happens in this process so file access stays single-writer.
        with ctx.Pool(min(ctx.cpu_count(), 32)) as pool:
            # deq holds AsyncResult objects in submission order, so events are
            # written in the order their data was received.
            deq = deque()
            while True:
                # Exit if the training process died or no writer is usable.
                if self._check_heartbeat():
                    self._close()
                    return

                # Drain finished packing jobs from the front, preserving order.
                while deq and deq[0].ready():
                    for plugin, data in deq.popleft().get():
                        self._write(plugin, data)

                try:
                    # Non-blocking get: on an empty queue we loop back to the
                    # heartbeat/result checks instead of blocking forever.
                    # NOTE(review): this is effectively a busy-wait when idle.
                    action, data = self._queue.get(block=False)
                    if action == 'WRITE':
                        deq.append(pool.apply_async(_pack_data, (data, time.time())))
                    elif action == 'FLUSH':
                        self._flush()
                    elif action == 'END':
                        break
                except queue.Empty:
                    continue

            # 'END' received: block on the remaining jobs and write them all.
            for result in deq:
                for plugin, data in result.get():
                    self._write(plugin, data)

            self._close()

    @property
    def _writers(self):
        """Get the writers in the subprocess.

        Lazily instantiated on first access (inside the child process) from
        the plugin-to-filename mapping given at construction time.
        """
        if self._writers_ is not None:
            return self._writers_
        self._writers_ = []
        for plugin, filename in self._filedict.items():
            filepath = os.path.join(self._base_dir, filename)
            if plugin == WriterPluginEnum.SUMMARY.value:
                self._writers_.append(SummaryWriter(filepath, self._max_file_size))
            elif plugin == WriterPluginEnum.LINEAGE.value:
                self._writers_.append(LineageWriter(filepath, self._max_file_size))
            elif plugin == WriterPluginEnum.EXPLAINER.value:
                self._writers_.append(ExplainWriter(filepath, self._max_file_size))
            elif plugin == WriterPluginEnum.EXPORTER.value:
                self._writers_.append(ExportWriter(filepath, self._max_file_size))
        return self._writers_

    def _write(self, plugin, data):
        """Write the data in the subprocess.

        A writer that raises is closed and removed from the pool (iteration is
        over a copy so removal during the loop is safe); with
        ``raise_exception`` set, errors are re-raised after cleanup.
        """
        for writer in self._writers[:]:
            try:
                writer.write(plugin, data)
            except (RuntimeError, OSError) as exc:
                logger.error(str(exc))
                self._writers.remove(writer)
                writer.close()
                if self._raise_exception:
                    raise
            except RuntimeWarning as exc:
                # Warnings drop the writer too, but never propagate.
                logger.warning(str(exc))
                self._writers.remove(writer)
                writer.close()

    def _flush(self):
        """Flush the writers in the subprocess."""
        for writer in self._writers:
            writer.flush()

    def _close(self):
        """Close the writers in the subprocess, then release process resources."""
        for writer in self._writers:
            writer.close()
        super().close()

    def write(self, data) -> None:
        """
        Write the event to file.

        Called from the training process; enqueues the data for asynchronous
        packing and writing in the subprocess.

        Args:
            data (Optional[str, Tuple[list, int]]): The data to write.
        """
        self._queue.put(('WRITE', data))

    def flush(self):
        """Flush the writer and sync data to disk."""
        self._queue.put(('FLUSH', None))

    def close(self) -> None:
        """Close the writer.

        Sends 'END'; the subprocess finishes pending work and then shuts down.
        """
        self._queue.put(('END', None))

    def _check_heartbeat(self):
        """Check if the summary process should survive.

        Returns:
            bool: True when the subprocess should exit — either the training
            process is gone or no writer could be created.
        """
        is_exit = False
        if not psutil.pid_exists(self._training_pid):
            logger.warning("The training process %d has exited, summary process will exit.", self._training_pid)
            is_exit = True

        # Accessing self._writers here also triggers its lazy creation.
        if not self._writers:
            logger.warning("Can not find any writer to write summary data, "
                           "so SummaryRecord will not record data.")
            is_exit = True

        return is_exit
204