• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import errno
2import os
3import selectors
4import signal
5import socket
6import struct
7import sys
8import threading
9import warnings
10
11from . import connection
12from . import process
13from .context import reduction
14from . import resource_tracker
15from . import spawn
16from . import util
17
__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
# Constants shared by the client side (ForkServer) and the server (main)
#

# Upper bound on the number of fds carried by one fork request: main()
# raises RuntimeError if it receives more than this, and
# connect_to_new_process() raises ValueError when its 4 bookkeeping fds
# plus the caller's fds would reach this limit.
MAXFDS_TO_SEND = 256
# Wire format used by read_signed()/write_signed() to send pids and
# return codes over pipes.
SIGNED_STRUCT = struct.Struct('q')     # large enough for pid_t
27
28#
29# Forkserver class
30#
31
class ForkServer:
    '''Client-side controller for one forkserver process.

    ensure_running() launches the server (a fresh interpreter running
    main()) on demand; connect_to_new_process() asks it to fork a child.
    All server-related attributes are None until the server is started.
    '''

    def __init__(self):
        self._forkserver_address = None   # path of the server's AF_UNIX socket
        self._forkserver_alive_fd = None  # write end of the "alive" pipe
        self._forkserver_pid = None       # pid of the forkserver process
        self._inherited_fds = None        # filled in for children by _serve_one
        self._lock = threading.Lock()     # serializes starting/stopping
        self._preload_modules = ['__main__']

    def _stop(self):
        # Method used by unit tests to stop the server
        with self._lock:
            self._stop_unlocked()

    def _stop_unlocked(self):
        '''Stop the forkserver, if any; the caller must hold self._lock.'''
        if self._forkserver_pid is None:
            return

        # close the "alive" file descriptor asks the server to stop: the
        # server exits once the read end of the pipe reports EOF.
        os.close(self._forkserver_alive_fd)
        self._forkserver_alive_fd = None

        os.waitpid(self._forkserver_pid, 0)
        self._forkserver_pid = None

        os.unlink(self._forkserver_address)
        self._forkserver_address = None

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        Raises TypeError if any entry of *modules_names* is not a str.
        '''
        # Validate the new value, not the currently stored list; otherwise
        # invalid input would slip through and be %r-formatted into the
        # command line built by ensure_running().
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if *fds* is too long to send in a single request.
        '''
        self.ensure_running()
        # Four bookkeeping fds are always sent ahead of the caller's fds.
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            child_r, parent_w = os.pipe()
            # Order matters: the server unpacks child_r and child_w first and
            # hands the rest (alive fd, resource-tracker fd, caller fds) to
            # the new child.
            allfds = [child_r, child_w, self._forkserver_alive_fd,
                      resource_tracker.getfd()]
            allfds += fds
            try:
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child-side ends are only meant for the server; close
                # our copies whether or not sending succeeded.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            resource_tracker.ensure_running()
            if self._forkserver_pid is not None:
                # forkserver was launched before, is it still running?
                pid, _status = os.waitpid(self._forkserver_pid, os.WNOHANG)
                if not pid:
                    # still alive
                    return
                # dead, launch it again
                os.close(self._forkserver_alive_fd)
                self._forkserver_address = None
                self._forkserver_alive_fd = None
                self._forkserver_pid = None

            # Program for the new interpreter: main(listener_fd, alive_r,
            # preload_modules, **data), where data holds at most the
            # main_path / sys_path preparation entries.
            cmd = ('from multiprocessing.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {x: y for x, y in data.items() if x in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                os.chmod(address, 0o600)  # restrict the socket file to its owner
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    os.close(alive_w)
                    raise
                finally:
                    # alive_r was passed to the server via fds_to_pass; this
                    # process only keeps the write end.
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
                self._forkserver_pid = pid
160
161#
162#
163#
164
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    Entry point executed in the forkserver interpreter (see the command
    built by ForkServer.ensure_running()).

    listener_fd -- fd of the listening AF_UNIX socket clients connect to.
    alive_r     -- read end of the "alive" pipe; EOF on it means no client
                   holds the write end any more, and the server exits.
    preload     -- module names to import before serving requests
                   ('__main__' is loaded via main_path when given).
    main_path   -- path passed to spawn.import_main_path() for '__main__'.
    sys_path    -- accepted because ensure_running() may pass it, but not
                   used by this function.

    Does not return normally: it loops until alive_r reaches EOF and then
    raises SystemExit.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                # Preloading is best effort; unimportable names are skipped.
                pass

    util._close_stdin()

    # Self-pipe used as the signal wakeup fd so SIGCHLD can wake select().
    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Dummy signal handler, doesn't do anything
        pass

    handlers = {
        # unblocking SIGCHLD allows the wakeup fd to notify our event loop
        signal.SIGCHLD: sigchld_handler,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    # Remember the previous handlers so forked children can restore them.
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # calling os.write() in the Python signal handler is racy
    signal.set_wakeup_fd(sig_w)

    # map child pids to client fds
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                # Block until at least one registered fd is readable.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE(review): under `python -O` this assert (including
                    # the read inside it) is stripped; SystemExit still follows.
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # Got SIGCHLD
                    os.read(sig_r, 65536)  # exhaust
                    while True:
                        # Scan for child processes
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            # no children left at all
                            break
                        if pid == 0:
                            # children exist but none has exited yet
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            # Killed by a signal -> negative signal number,
                            # otherwise the normal exit status.
                            if os.WIFSIGNALED(sts):
                                returncode = -os.WTERMSIG(sts)
                            else:
                                if not os.WIFEXITED(sts):
                                    raise AssertionError(
                                        "Child {0:n} status is {1:n}".format(
                                            pid,sts))
                                returncode = os.WEXITSTATUS(sts)
                            # Send exit code to client process
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            os.close(child_w)
                        else:
                            # This shouldn't happen really
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request
                    with listener.accept()[0] as s:
                        # Receive fds from client
                        # (one extra slot so an oversized request is detected)
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        # Layout from connect_to_new_process(): child_r,
                        # child_w, then [alive fd, tracker fd, *user fds].
                        child_r, child_w, *fds = fds
                        # Close the connection before forking so the child
                        # does not inherit it.
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child
                            code = 1  # exit status if setup below raises
                            try:
                                # Drop everything that belongs to the server.
                                listener.close()
                                selector.close()
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                # Never return into the server loop.
                                os._exit(code)
                        else:
                            # Send pid to client process
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            # Keep child_w to report the exit code when the
                            # child is reaped; the rest belong to the child.
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                # A client that aborts its connection must not kill the server.
                if e.errno != errno.ECONNABORTED:
                    raise
301
302
def _serve_one(child_r, fds, unused_fds, handlers):
    '''Finish setting up a freshly forked child and run the process in it.

    Called in the child right after os.fork() in main().  *fds* is
    [alive_fd, resource_tracker_fd, *inherited_fds] as received from the
    client, *unused_fds* are forkserver-only descriptors to close, and
    *handlers* maps signal numbers to the handlers to reinstall.  Returns
    the exit code produced by spawn._main().
    '''
    # Undo the forkserver's signal configuration before running user code.
    signal.set_wakeup_fd(-1)
    for signum, handler in handlers.items():
        signal.signal(signum, handler)
    for stale_fd in unused_fds:
        os.close(stale_fd)

    # Wire the received descriptors into this process's module singletons.
    alive_fd, tracker_fd, *inherited = fds
    _forkserver._forkserver_alive_fd = alive_fd
    resource_tracker._resource_tracker._fd = tracker_fd
    _forkserver._inherited_fds = inherited

    # child_r carries the pickled preparation and process data written by
    # the client; a duplicate of it is passed as the parent sentinel.
    sentinel = os.dup(child_r)
    return spawn._main(child_r, sentinel)
320
321
322#
323# Read and write signed numbers
324#
325
def read_signed(fd):
    '''Read exactly one SIGNED_STRUCT-encoded integer from *fd*.

    Keeps reading until SIGNED_STRUCT.size bytes have arrived, because
    os.read() may return fewer bytes than requested.  Raises EOFError if
    the descriptor hits end-of-file before a full value is read.
    '''
    needed = SIGNED_STRUCT.size
    buf = bytearray()
    while len(buf) < needed:
        chunk = os.read(fd, needed - len(buf))
        if not chunk:
            raise EOFError('unexpected EOF')
        buf += chunk
    (value,) = SIGNED_STRUCT.unpack(buf)
    return value
335
def write_signed(fd, n):
    '''Write *n* to *fd* as one SIGNED_STRUCT-encoded integer.

    Retries until every byte is written, because os.write() may perform a
    partial write.  A write that makes no progress is treated as an
    internal error (RuntimeError).
    '''
    pending = memoryview(SIGNED_STRUCT.pack(n))
    while len(pending) > 0:
        written = os.write(fd, pending)
        if not written:
            raise RuntimeError('should not get here')
        pending = pending[written:]
343
344#
345#
346#
347
# One ForkServer per process; the public module-level API is made of its
# bound methods.
_forkserver = ForkServer()
set_forkserver_preload = _forkserver.set_forkserver_preload
connect_to_new_process = _forkserver.connect_to_new_process
get_inherited_fds = _forkserver.get_inherited_fds
ensure_running = _forkserver.ensure_running
353