# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
mindspore.multiprocessing is a wrapper around the native `multiprocessing` module.
Some methods are overridden to support fork-based multiprocessing.
"""
import types
import signal
import multiprocessing as mp
# `import multiprocessing` does not load the `pool` submodule; import it
# explicitly so `mp.pool.Pool` below does not depend on some other module
# having imported it first.
import multiprocessing.pool  # pylint: disable=unused-import
from multiprocessing import *
from mindspore._c_expression import fork_utils

# Re-export every public name of the native module; `Process` and `Pool`
# resolve to the overriding classes defined below.
__all__ = []
__all__ += mp.__all__


class Process(mp.Process):  # pylint: disable=function-redefined
    """
    Trigger fork callbacks by overriding native multiprocessing methods.

    When the start method is "fork", `start` brackets the fork with
    `fork_utils.prepare_before_fork` / `fork_utils.parent_at_fork`, and `run`
    (executed in the child) calls `fork_utils.child_at_fork` and the optional
    class-level hook `_child_at_fork_func` before running the target.
    """
    # Optional zero-argument hook run in the child after fork; it is only
    # invoked when it is callable.
    _child_at_fork_func = None

    def run(self):
        """
        Trigger the child_at_fork callback functions after fork, then run the
        process target via multiprocessing.Process.run.
        """
        if mp.get_start_method() == "fork":
            # NOTE(review): presumably wraps prctl(PR_SET_PDEATHSIG) so this
            # child receives SIGINT when its parent dies - confirm in fork_utils.
            fork_utils.prctl_set_pdeathsig(signal.SIGINT)
            fork_utils.child_at_fork()
            hook = Process._child_at_fork_func
            if callable(hook):
                hook()  # pylint: disable=not-callable
        super().run()

    def start(self):
        """
        Trigger prepare_before_fork and parent_at_fork callback functions
        around multiprocessing.Process.start when the start method is "fork".
        """
        if mp.get_start_method() != "fork":
            super().start()
            return
        fork_utils.prepare_before_fork()
        try:
            super().start()
        finally:
            # Keep the prepare/parent callbacks balanced even when start()
            # raises (e.g. a daemonic process may not have children, or fork()
            # itself fails); CPython's os.fork() likewise runs its parent-side
            # hooks on failure. The forked child never gets here: CPython's fork
            # Popen ends the child with os._exit().
            fork_utils.parent_at_fork()


# Module-level alias used by Pool.Process, whose own name shadows `Process`
# when read inside the Pool class body.
_MsProcess = Process


class Pool(mp.pool.Pool):  # pylint: disable=function-redefined, abstract-method
    """
    Trigger fork callbacks by overriding native multiprocessing methods.

    Under a "fork" context the pool's workers are created as this module's
    `Process`, so they run the fork callbacks.
    """
    def Process(self, *args, **kwds):
        """
        Create a worker process for the pool.

        Returns this module's `Process` when the pool's context uses the
        "fork" start method; otherwise defers to multiprocessing.pool.Pool.Process.
        """
        if self._ctx.get_start_method() == "fork":
            # Process() becomes a staticmethod function of Pool with first argument 'ctx' in python 3.8.0 and later;
            # in that case the pool passes ctx as the leading positional argument, so drop it.
            if isinstance(super().Process, types.FunctionType):
                args = args[1:]
            # NOTE(review): the pool's ctx is discarded here and _MsProcess starts
            # through the default context, so a "fork" ctx combined with a
            # non-fork global start method would not actually fork - confirm.
            return _MsProcess(*args, **kwds)
        return super().Process(*args, **kwds)