• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
"""
mindspore.multiprocessing is a wrapper around the native `multiprocessing` module.
Its Process and Pool classes override some native methods to support fork-based
multiprocessing.
"""
import types
import signal
import multiprocessing as mp
# ``import multiprocessing`` does not load the ``pool`` submodule; import it
# explicitly so ``mp.pool.Pool`` used below never depends on import side effects.
import multiprocessing.pool
from multiprocessing import *

from mindspore._c_expression import fork_utils
24
# Re-export the complete public API of the native ``multiprocessing`` module as
# a fresh list (so it never aliases ``mp.__all__``). ``Process`` and ``Pool``
# are among these names and are redefined by the classes in this module.
__all__ = list(mp.__all__)
27
28
class Process(mp.Process):  # pylint: disable=function-redefined
    """
    ``multiprocessing.Process`` subclass that triggers MindSpore's fork callbacks.

    The callbacks only run when ``multiprocessing.get_start_method()`` (the
    default context's start method) is ``"fork"``. In that case ``start``
    wraps the native ``start`` with ``fork_utils.prepare_before_fork`` and
    ``fork_utils.parent_at_fork`` in the parent, and ``run`` calls
    ``fork_utils.child_at_fork`` in the child. For any other start method
    both methods defer straight to the native implementation.
    """
    # Optional extra hook called in the forked child after child_at_fork().
    # It is looked up on this class (not on ``type(self)``), so it must be
    # assigned as ``Process._child_at_fork_func``.
    _child_at_fork_func = None

    def run(self):
        """
        Run the child-side fork callbacks, then the process target.

        Overrides ``multiprocessing.Process.run``, which executes in the
        child process. Under the ``"fork"`` start method it first calls
        ``fork_utils.prctl_set_pdeathsig(signal.SIGINT)`` (by its name, asks
        for SIGINT when the parent dies), then ``fork_utils.child_at_fork()``,
        then ``Process._child_at_fork_func`` if it is set and callable.
        """
        if mp.get_start_method() == "fork":
            fork_utils.prctl_set_pdeathsig(signal.SIGINT)
            fork_utils.child_at_fork()
            child_hook = Process._child_at_fork_func
            if callable(child_hook):
                child_hook()  # pylint: disable=not-callable
        super().run()

    def start(self):
        """
        Start the process, wrapping the fork with the parent-side callbacks.

        Under the ``"fork"`` start method ``fork_utils.prepare_before_fork()``
        is called before the native ``start`` and ``fork_utils.parent_at_fork()``
        after it. ``parent_at_fork`` runs even when the native ``start``
        raises (e.g. fork failure, or starting the same process twice), so
        the prepare/parent pair is never left unbalanced; the exception still
        propagates to the caller.
        """
        if mp.get_start_method() != "fork":
            super().start()
            return
        fork_utils.prepare_before_fork()
        try:
            super().start()
        finally:
            # Like a pthread_atfork parent handler: runs once fork processing
            # is over, whether or not it succeeded. The forked child does not
            # get here, since CPython's fork launcher ends it with os._exit().
            fork_utils.parent_at_fork()


# Module-private alias of the class above. ``Pool.Process`` builds its
# fork-aware workers through this name, which avoids confusion with the
# ``Process`` method it is used inside.
_MsProcess = Process
60
61
class Pool(mp.pool.Pool):  # pylint: disable=function-redefined, abstract-method
    """
    ``multiprocessing.pool.Pool`` whose workers trigger MindSpore's fork callbacks.

    When the pool's context uses the ``"fork"`` start method, worker processes
    are created from this module's fork-aware ``Process`` subclass (via
    ``_MsProcess``); otherwise the native worker factory is used unchanged.
    """
    def Process(self, *args, **kwds):
        """
        Create one worker process for the pool.

        NOTE(review): the fork-aware worker is a ``multiprocessing.Process``
        subclass, so it is launched with the *default* start method
        (``multiprocessing.get_start_method()``), which is not necessarily the
        pool context's. The two agree unless a ``"fork"`` context is passed
        while the global default is different — confirm whether that case
        needs handling.
        """
        native_factory = super().Process
        if self._ctx.get_start_method() != "fork":
            return native_factory(*args, **kwds)
        # Since Python 3.8 the base ``Pool.Process`` is a staticmethod whose
        # first positional argument is the context, so ``super()`` yields a
        # plain function here; that leading ``ctx`` must not reach the
        # Process constructor.
        takes_ctx = isinstance(native_factory, types.FunctionType)
        ctor_args = args[1:] if takes_ctx else args
        return _MsProcess(*ctor_args, **kwds)
73