# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of fork support test methods."""

import enum
import json
import logging
import multiprocessing
import os
import queue
import subprocess
import sys
import tempfile
import threading
import time
import traceback

import grpc

from src.proto.grpc.testing import empty_pb2
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc

_LOGGER = logging.getLogger(__name__)
_RPC_TIMEOUT_S = 10
_CHILD_FINISH_TIMEOUT_S = 20
_GDB_TIMEOUT_S = 60


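# Open a channel to the test server described by the driver-supplied args
# dict ("server_host", "server_port", "use_tls").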
def _channel(args):
    target = "{}:{}".format(args["server_host"], args["server_port"])
    if args["use_tls"]:
        channel_credentials = grpc.ssl_channel_credentials()
        channel = grpc.secure_channel(target, channel_credentials)
    else:
        channel = grpc.insecure_channel(target)
    return channel


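# Raise ValueError unless the response payload has the expected compression
# type and body length.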
def _validate_payload_type_and_length(response, expected_type, expected_length):
    if response.payload.type != expected_type:
        raise ValueError(
            "expected payload type %s, got %s"
            % (expected_type, response.payload.type)
        )
    elif len(response.payload.body) != expected_length:
        raise ValueError(
            "expected payload body size %d, got %d"
            % (expected_length, len(response.payload.body))
        )


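# The request and response sizes below match the standard gRPC interop-test
# payloads (digits of pi and e used as arbitrary sizes).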
def _async_unary(stub):
    size = 314159
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=size,
        payload=messages_pb2.Payload(body=b"\x00" * 271828),
    )

    response_future = stub.UnaryCall.future(request, timeout=_RPC_TIMEOUT_S)
    response = response_future.result()
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)


def _blocking_unary(stub):
    size = 314159
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=size,
        payload=messages_pb2.Payload(body=b"\x00" * 271828),
    )
    response = stub.UnaryCall(request, timeout=_RPC_TIMEOUT_S)
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)


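# Thread-safe request iterator for streaming calls: requests are fed in with
# add() and the request stream is ended by close().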
class _Pipe(object):
    def __init__(self):
        self._condition = threading.Condition()
        self._values = []
        self._open = True

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        with self._condition:
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.pop(0)
            else:
                raise StopIteration()

    def add(self, value):
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        with self._condition:
            self._open = False
            self._condition.notify()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()


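# Runs a task in a forked child process. Exceptions raised by the task are
# reported back through a multiprocessing.Queue, and finish() attaches gdb to
# dump backtraces if the child fails to terminate in time.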
class _ChildProcess(object):
    def __init__(self, task, args=None):
        if args is None:
            args = ()
        self._exceptions = multiprocessing.Queue()
        self._stdout_path = tempfile.mkstemp()[1]
        self._stderr_path = tempfile.mkstemp()[1]
        self._child_pid = None
        self._rc = None
        self._args = args

        self._task = task

    def _child_main(self):
        import faulthandler

        faulthandler.enable(all_threads=True)

        try:
            self._task(*self._args)
        except grpc.RpcError as rpc_error:
            traceback.print_exc()
            self._exceptions.put("RpcError: %s" % rpc_error)
        except Exception as e:  # pylint: disable=broad-except
            traceback.print_exc()
            self._exceptions.put(e)
        sys.exit(0)

    def _orchestrate_child_gdb(self):
        cmd = [
            "gdb",
            "-ex",
            "set confirm off",
            "-ex",
            "attach {}".format(os.getpid()),
            "-ex",
            "set follow-fork-mode child",
            "-ex",
            "continue",
            "-ex",
            "bt",
        ]
        streams = tuple(tempfile.TemporaryFile() for _ in range(2))
        sys.stderr.write("Invoking gdb\n")
        sys.stderr.flush()
        process = subprocess.Popen(cmd, stdout=sys.stderr, stderr=sys.stderr)
        # Give gdb a moment to attach before the fork happens.
        time.sleep(5)

    def start(self):
        # NOTE: Try uncommenting the following line if the child is segfaulting.
        # self._orchestrate_child_gdb()
        ret = os.fork()
        if ret == 0:
            self._child_main()
        else:
            self._child_pid = ret

    def wait(self, timeout):
        total = 0.0
        wait_interval = 1.0
        while total < timeout:
            # termination is the raw wait status from os.waitpid(); any
            # nonzero value indicates an error exit or a fatal signal.
            ret, termination = os.waitpid(self._child_pid, os.WNOHANG)
            if ret == self._child_pid:
                self._rc = termination
                return True
            time.sleep(wait_interval)
            total += wait_interval
        return False

    def _print_backtraces(self):
        cmd = [
            "gdb",
            "-ex",
            "set confirm off",
            "-ex",
            "echo attaching",
            "-ex",
            "attach {}".format(self._child_pid),
            "-ex",
            "echo print_backtrace",
            "-ex",
            "thread apply all bt",
            "-ex",
            "echo printed_backtrace",
            "-ex",
            "quit",
        ]
        streams = tuple(tempfile.TemporaryFile() for _ in range(2))
        sys.stderr.write("Invoking gdb\n")
        sys.stderr.flush()
        process = subprocess.Popen(cmd, stdout=streams[0], stderr=streams[1])
        try:
            process.wait(timeout=_GDB_TIMEOUT_S)
        except subprocess.TimeoutExpired:
            sys.stderr.write("gdb stacktrace generation timed out.\n")
        finally:
            for stream_name, stream in zip(("STDOUT", "STDERR"), streams):
                stream.seek(0)
                sys.stderr.write(
                    "gdb {}:\n{}\n".format(
                        stream_name, stream.read().decode("ascii")
                    )
                )
                stream.close()
            sys.stderr.flush()

    def finish(self):
        terminated = self.wait(_CHILD_FINISH_TIMEOUT_S)
        sys.stderr.write("Exit code: {}\n".format(self._rc))
        if not terminated:
            self._print_backtraces()
            raise RuntimeError("Child process did not terminate")
        if self._rc != 0:
            raise ValueError("Child process failed with exitcode %d" % self._rc)
        try:
            exception = self._exceptions.get(block=False)
            raise ValueError(
                'Child process failed: "%s": "%s"'
                % (repr(exception), exception)
            )
        except queue.Empty:
            pass


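# An RPC on the parent's channel must fail in the forked child, while the
# parent's own RPCs keep working.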
def _async_unary_same_channel(channel):
    def child_target():
        try:
            _async_unary(stub)
            raise Exception(
                "Child should not be able to re-use channel after fork"
            )
        except ValueError as expected_value_error:
            pass

    stub = test_pb2_grpc.TestServiceStub(channel)
    _async_unary(stub)
    child_process = _ChildProcess(child_target)
    child_process.start()
    _async_unary(stub)
    child_process.finish()


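# A channel opened inside the forked child must work even though the parent's
# channel remains in use.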
def _async_unary_new_channel(channel, args):
    def child_target():
        with _channel(args) as child_channel:
            child_stub = test_pb2_grpc.TestServiceStub(child_channel)
            _async_unary(child_stub)
            child_channel.close()

    stub = test_pb2_grpc.TestServiceStub(channel)
    _async_unary(stub)
    child_process = _ChildProcess(child_target)
    child_process.start()
    _async_unary(stub)
    child_process.finish()


def _blocking_unary_same_channel(channel):
    def child_target():
        try:
            _blocking_unary(stub)
            raise Exception(
                "Child should not be able to re-use channel after fork"
            )
        except ValueError as expected_value_error:
            pass

    stub = test_pb2_grpc.TestServiceStub(channel)
    _blocking_unary(stub)
    child_process = _ChildProcess(child_target)
    child_process.start()
    child_process.finish()


def _blocking_unary_new_channel(channel, args):
    def child_target():
        with _channel(args) as child_channel:
            child_stub = test_pb2_grpc.TestServiceStub(child_channel)
            _blocking_unary(child_stub)

    stub = test_pb2_grpc.TestServiceStub(channel)
    _blocking_unary(stub)
    child_process = _ChildProcess(child_target)
    child_process.start()
    _blocking_unary(stub)
    child_process.finish()


# Verify that the fork channel registry can handle already-closed channels.
def _close_channel_before_fork(channel, args):
    def child_target():
        new_channel.close()
        with _channel(args) as child_channel:
            child_stub = test_pb2_grpc.TestServiceStub(child_channel)
            _blocking_unary(child_stub)

    stub = test_pb2_grpc.TestServiceStub(channel)
    _blocking_unary(stub)
    channel.close()

    with _channel(args) as new_channel:
        new_stub = test_pb2_grpc.TestServiceStub(new_channel)
        child_process = _ChildProcess(child_target)
        child_process.start()
        _blocking_unary(new_stub)
        child_process.finish()


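# Connectivity subscriptions registered in the parent must not keep firing in
# the child, and both processes must independently see the channel go READY.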
def _connectivity_watch(channel, args):
    parent_states = []
    parent_channel_ready_event = threading.Event()

    def child_target():
        child_channel_ready_event = threading.Event()

        def child_connectivity_callback(state):
            if state is grpc.ChannelConnectivity.READY:
                child_channel_ready_event.set()

        with _channel(args) as child_channel:
            child_stub = test_pb2_grpc.TestServiceStub(child_channel)
            child_channel.subscribe(child_connectivity_callback)
            _async_unary(child_stub)
            if not child_channel_ready_event.wait(timeout=_RPC_TIMEOUT_S):
                raise ValueError("Channel did not move to READY")
            if len(parent_states) > 1:
                raise ValueError(
                    "Received connectivity updates on parent callback",
                    parent_states,
                )
            child_channel.unsubscribe(child_connectivity_callback)

    def parent_connectivity_callback(state):
        parent_states.append(state)
        if state is grpc.ChannelConnectivity.READY:
            parent_channel_ready_event.set()

    channel.subscribe(parent_connectivity_callback)
    stub = test_pb2_grpc.TestServiceStub(channel)
    child_process = _ChildProcess(child_target)
    child_process.start()
    _async_unary(stub)
    if not parent_channel_ready_event.wait(timeout=_RPC_TIMEOUT_S):
        raise ValueError("Channel did not move to READY")
    channel.unsubscribe(parent_connectivity_callback)
    child_process.finish()


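# Shared driver for the in-progress-bidi cases: runs a ping-pong bidi call in
# the parent and forks a child (running child_target) around each response,
# plus one more after the request stream is closed when run_after_close is set.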
def _ping_pong_with_child_processes_after_first_response(
    channel, args, child_target, run_after_close=True
):
    request_response_sizes = (
        31415,
        9,
        2653,
        58979,
    )
    request_payload_sizes = (
        27182,
        8,
        1828,
        45904,
    )
    stub = test_pb2_grpc.TestServiceStub(channel)
    pipe = _Pipe()
    parent_bidi_call = stub.FullDuplexCall(pipe)
    child_processes = []
    first_message_received = False
    for response_size, payload_size in zip(
        request_response_sizes, request_payload_sizes
    ):
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(
                messages_pb2.ResponseParameters(size=response_size),
            ),
            payload=messages_pb2.Payload(body=b"\x00" * payload_size),
        )
        pipe.add(request)
        if first_message_received:
            child_process = _ChildProcess(
                child_target, (parent_bidi_call, channel, args)
            )
            child_process.start()
            child_processes.append(child_process)
        response = next(parent_bidi_call)
        first_message_received = True
        child_process = _ChildProcess(
            child_target, (parent_bidi_call, channel, args)
        )
        child_process.start()
        child_processes.append(child_process)
        _validate_payload_type_and_length(
            response, messages_pb2.COMPRESSABLE, response_size
        )
    pipe.close()
    if run_after_close:
        child_process = _ChildProcess(
            child_target, (parent_bidi_call, channel, args)
        )
        child_process.start()
        child_processes.append(child_process)
    for child_process in child_processes:
        child_process.finish()


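# The child inherits an in-progress bidi call: new RPCs on the inherited
# channel must fail, and the inherited call must report CANCELLED with the
# fork-specific details string.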
def _in_progress_bidi_continue_call(channel):
    def child_target(parent_bidi_call, parent_channel, args):
        stub = test_pb2_grpc.TestServiceStub(parent_channel)
        try:
            _async_unary(stub)
            raise Exception(
                "Child should not be able to re-use channel after fork"
            )
        except ValueError as expected_value_error:
            pass
        inherited_code = parent_bidi_call.code()
        inherited_details = parent_bidi_call.details()
        if inherited_code != grpc.StatusCode.CANCELLED:
            raise ValueError(
                "Expected inherited code CANCELLED, got %s" % inherited_code
            )
        if inherited_details != "Channel closed due to fork":
            raise ValueError(
                "Expected inherited details Channel closed due to fork, got %s"
                % inherited_details
            )

    # Don't run child_target after closing the parent call, as the call may
    # have received a status from the server before the fork occurs.
    _ping_pong_with_child_processes_after_first_response(
        channel, None, child_target, run_after_close=False
    )


def _in_progress_bidi_same_channel_async_call(channel):
    def child_target(parent_bidi_call, parent_channel, args):
        stub = test_pb2_grpc.TestServiceStub(parent_channel)
        try:
            _async_unary(stub)
            raise Exception(
                "Child should not be able to re-use channel after fork"
            )
        except ValueError as expected_value_error:
            pass

    _ping_pong_with_child_processes_after_first_response(
        channel, None, child_target
    )


def _in_progress_bidi_same_channel_blocking_call(channel):
    def child_target(parent_bidi_call, parent_channel, args):
        stub = test_pb2_grpc.TestServiceStub(parent_channel)
        try:
            _blocking_unary(stub)
            raise Exception(
                "Child should not be able to re-use channel after fork"
            )
        except ValueError as expected_value_error:
            pass

    _ping_pong_with_child_processes_after_first_response(
        channel, None, child_target
    )


def _in_progress_bidi_new_channel_async_call(channel, args):
    def child_target(parent_bidi_call, parent_channel, args):
        with _channel(args) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)
            _async_unary(stub)

    _ping_pong_with_child_processes_after_first_response(
        channel, args, child_target
    )


def _in_progress_bidi_new_channel_blocking_call(channel, args):
    def child_target(parent_bidi_call, parent_channel, args):
        with _channel(args) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)
            _blocking_unary(stub)

    _ping_pong_with_child_processes_after_first_response(
        channel, args, child_target
    )


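# Enum of all fork test cases; the string values are the names the test driver
# uses to select a case.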
@enum.unique
class TestCase(enum.Enum):
    CONNECTIVITY_WATCH = "connectivity_watch"
    CLOSE_CHANNEL_BEFORE_FORK = "close_channel_before_fork"
    ASYNC_UNARY_SAME_CHANNEL = "async_unary_same_channel"
    ASYNC_UNARY_NEW_CHANNEL = "async_unary_new_channel"
    BLOCKING_UNARY_SAME_CHANNEL = "blocking_unary_same_channel"
    BLOCKING_UNARY_NEW_CHANNEL = "blocking_unary_new_channel"
    IN_PROGRESS_BIDI_CONTINUE_CALL = "in_progress_bidi_continue_call"
    IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL = (
        "in_progress_bidi_same_channel_async_call"
    )
    IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL = (
        "in_progress_bidi_same_channel_blocking_call"
    )
    IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL = (
        "in_progress_bidi_new_channel_async_call"
    )
    IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL = (
        "in_progress_bidi_new_channel_blocking_call"
    )

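    # Dispatch a single case. A hypothetical driver invocation (argument keys
    # assumed from _channel) might look like:
    #   TestCase.CONNECTIVITY_WATCH.run_test(
    #       {"server_host": "localhost", "server_port": 50051, "use_tls": False}
    #   )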
    def run_test(self, args):
        _LOGGER.info("Running %s", self)
        channel = _channel(args)
        if self is TestCase.ASYNC_UNARY_SAME_CHANNEL:
            _async_unary_same_channel(channel)
        elif self is TestCase.ASYNC_UNARY_NEW_CHANNEL:
            _async_unary_new_channel(channel, args)
        elif self is TestCase.BLOCKING_UNARY_SAME_CHANNEL:
            _blocking_unary_same_channel(channel)
        elif self is TestCase.BLOCKING_UNARY_NEW_CHANNEL:
            _blocking_unary_new_channel(channel, args)
        elif self is TestCase.CLOSE_CHANNEL_BEFORE_FORK:
            _close_channel_before_fork(channel, args)
        elif self is TestCase.CONNECTIVITY_WATCH:
            _connectivity_watch(channel, args)
        elif self is TestCase.IN_PROGRESS_BIDI_CONTINUE_CALL:
            _in_progress_bidi_continue_call(channel)
        elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL:
            _in_progress_bidi_same_channel_async_call(channel)
        elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL:
            _in_progress_bidi_same_channel_blocking_call(channel)
        elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL:
            _in_progress_bidi_new_channel_async_call(channel, args)
        elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL:
            _in_progress_bidi_new_channel_blocking_call(channel, args)
        else:
            raise NotImplementedError(
                'Test case "%s" not implemented!' % self.name
            )
        channel.close()


# Useful when you need to map an address in a shared object (.so) back to the
# code it belongs to.
def dump_object_map():
    with open("/proc/self/maps", "r") as f:
        sys.stderr.write("=============== /proc/self/maps ===============\n")
        sys.stderr.write(f.read())
        sys.stderr.write("\n")
        sys.stderr.flush()