• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""The Python example of utilizing the wait-for-ready flag."""
15
16from concurrent import futures
17from contextlib import contextmanager
18import logging
19import socket
20import threading
21
22import grpc
23
# Module-level logger; INFO is enabled so the per-RPC outcome lines are emitted.
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)

# Load the message module and the service module for helloworld.proto at
# runtime, as the 2-tuple returned by grpc.protos_and_services.
helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services(
    "helloworld.proto"
)
30
31
def _bind_ephemeral_tcp_socket():
    """Return a TCP socket bound to an OS-assigned port, preferring IPv6.

    ``socket.has_ipv6`` only says the interpreter supports IPv6; creating or
    binding an IPv6 socket can still fail on a host where IPv6 is unavailable,
    so IPv4 is tried as a fallback.

    Raises:
        OSError: if no address family could be bound (the last error seen).
    """
    families = [socket.AF_INET]
    if socket.has_ipv6:
        families.insert(0, socket.AF_INET6)

    last_error = None
    for family in families:
        candidate = None
        try:
            candidate = socket.socket(family)
            # Port 0 asks the OS to pick any free port.
            candidate.bind(("", 0))
        except OSError as error:
            # Do not leak a socket that was created but could not be bound.
            if candidate is not None:
                candidate.close()
            last_error = error
        else:
            return candidate
    raise last_error


@contextmanager
def get_free_loopback_tcp_port():
    """Reserve a free TCP port for the duration of the ``with`` block.

    A socket is bound (but never put into listening mode) to an OS-chosen
    port and kept open while the block runs; it is closed on exit, including
    when the block raises.

    Yields:
        The address string ``"localhost:<port>"`` for the reserved port.

    Raises:
        OSError: if no socket could be created and bound.
    """
    tcp_socket = _bind_ephemeral_tcp_socket()
    try:
        address_tuple = tcp_socket.getsockname()
        yield "localhost:%s" % (address_tuple[1])
    finally:
        # Runs on normal exit and when an exception is thrown into the
        # generator, so the port is always released.
        tcp_socket.close()
42
43
class Greeter(helloworld_pb2_grpc.GreeterServicer):
    """Greeter servicer that replies to every request with a greeting."""

    def SayHello(self, request, unused_context):
        """Return a ``HelloReply`` greeting the name carried by ``request``.

        Args:
            request: The incoming message; only its ``name`` field is read.
            unused_context: The RPC context; not used.
        """
        greeting = "Hello, %s!" % request.name
        return helloworld_pb2.HelloReply(message=greeting)
47
48
def create_server(server_address):
    """Build a Greeter server bound to ``server_address`` (not yet started).

    Args:
        server_address: Address of the form ``"host:port"`` with an explicit
            numeric port; the text after the last ``":"`` is parsed as the
            port.

    Returns:
        The configured ``grpc`` server; the caller is responsible for calling
        ``start()`` and ``stop()``.

    Raises:
        ValueError: if the port part of ``server_address`` is not an integer.
        RuntimeError: if the port reported by ``add_insecure_port`` differs
            from the requested one.
    """
    # Parse first so a malformed address fails before anything is bound.
    requested_port = int(server_address.rpartition(":")[2])

    server = grpc.server(futures.ThreadPoolExecutor())
    helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server)
    bound_port = server.add_insecure_port(server_address)
    # An explicit check instead of ``assert`` so it still runs under
    # ``python -O``, where assert statements are removed.
    if bound_port != requested_port:
        raise RuntimeError(
            "Failed to bind server to %s (bound port: %s)"
            % (server_address, bound_port)
        )
    return server
55
56
def process(stub, wait_for_ready=None):
    """Send one ``SayHello`` RPC through ``stub`` and log what came back.

    The function also checks the example's expectations: a failed RPC must
    carry ``StatusCode.UNAVAILABLE`` and may only happen when
    ``wait_for_ready`` is falsy, while a successful RPC may only happen when
    ``wait_for_ready`` is truthy.

    Args:
        stub: Greeter stub used to issue the call.
        wait_for_ready: Value forwarded as the call's ``wait_for_ready``
            option.

    Raises:
        AssertionError: if the outcome contradicts the expectations above.
    """
    mode = "disabled"
    if wait_for_ready:
        mode = "enabled"

    request = helloworld_pb2.HelloRequest(name="you")
    try:
        reply = stub.SayHello(request, wait_for_ready=wait_for_ready)
    except grpc.RpcError as error:
        assert error.code() == grpc.StatusCode.UNAVAILABLE
        assert not wait_for_ready
        # The error object itself is what gets logged for a failed call.
        message = error
    else:
        message = reply.message
        assert wait_for_ready

    _LOGGER.info("Wait-for-ready %s, client received: %s", mode, message)
75
76
def main():
    """Run the demo: send two RPCs first, then bring the server up.

    One RPC is sent with ``wait_for_ready`` off and one with it on, both
    before any server exists. The server is started only after the channel
    has reported TRANSIENT_FAILURE.
    """
    # Hold a free port while the client side is set up.
    with get_free_loopback_tcp_port() as server_address:
        # Set by the connectivity callback to wake up the main thread.
        saw_transient_failure = threading.Event()

        def on_connectivity_change(state):
            if state != grpc.ChannelConnectivity.TRANSIENT_FAILURE:
                return
            saw_transient_failure.set()

        # Open the channel and watch its connectivity state.
        channel = grpc.insecure_channel(server_address)
        channel.subscribe(on_connectivity_change)
        stub = helloworld_pb2_grpc.GreeterStub(channel)

        # RPC with wait_for_ready disabled.
        fail_fast_rpc = threading.Thread(target=process, args=(stub, False))
        fail_fast_rpc.start()
        # RPC with wait_for_ready enabled.
        waiting_rpc = threading.Thread(target=process, args=(stub, True))
        waiting_rpc.start()

    # Block until the channel has entered the TRANSIENT_FAILURE state, then
    # start a server on the same address.
    saw_transient_failure.wait()
    server = create_server(server_address)
    server.start()

    # Expected to fail with StatusCode.UNAVAILABLE.
    fail_fast_rpc.join()
    # Expected to succeed.
    waiting_rpc.join()

    server.stop(None)
    channel.close()
118
119
if __name__ == "__main__":
    # Configure root logging at INFO before running the demo.
    logging.basicConfig(level=logging.INFO)
    main()
123