• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import argparse
16import collections
17from concurrent import futures
18import logging
19import signal
20import socket
21import sys
22import threading
23import time
24from typing import DefaultDict, Dict, List, Mapping, Sequence, Set, Tuple
25
26import grpc
27from grpc_channelz.v1 import channelz
28from grpc_channelz.v1 import channelz_pb2
29from grpc_csm_observability import CsmOpenTelemetryPlugin
30from grpc_health.v1 import health as grpc_health
31from grpc_health.v1 import health_pb2
32from grpc_health.v1 import health_pb2_grpc
33from grpc_reflection.v1alpha import reflection
34from opentelemetry.exporter.prometheus import PrometheusMetricReader
35from opentelemetry.sdk.metrics import MeterProvider
36from prometheus_client import start_http_server
37
38from src.proto.grpc.testing import empty_pb2
39from src.proto.grpc.testing import messages_pb2
40from src.proto.grpc.testing import test_pb2
41from src.proto.grpc.testing import test_pb2_grpc
42from src.python.grpcio_tests.tests.fork import native_debug
43
# Register the native_debug failure signal handler as soon as this module is
# imported, before any server is created.
native_debug.install_failure_signal_handler()

# NOTE: This interop server is not compatible with every xDS interop test; it
#  implements only enough functionality to pass the xDS security tests.

# Address the gRPC listeners bind to.
_LISTEN_HOST = "0.0.0.0"
# Port of the Prometheus HTTP endpoint started for CSM observability.
_PROMETHEUS_PORT = 9464
# Worker count for the ThreadPoolExecutors created in this module.
_THREAD_POOL_SIZE = 256

# Attach a timestamped StreamHandler to the root logger; the level itself is
# chosen at startup from --verbose.
formatter = logging.Formatter(fmt="%(asctime)s: %(levelname)-8s %(message)s")
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(console_handler)
60
61
class TestService(test_pb2_grpc.TestServiceServicer):
    """TestService servicer that identifies which backend handled a call.

    Only EmptyCall and UnaryCall are implemented here. Both send a
    ``("hostname", <hostname>)`` pair as initial metadata before replying.
    """

    def __init__(self, server_id, hostname):
        self._server_id = server_id
        self._hostname = hostname

    def _send_hostname(self, context: grpc.ServicerContext) -> None:
        """Send this server's hostname as the call's initial metadata."""
        context.send_initial_metadata((("hostname", self._hostname),))

    def EmptyCall(
        self, _: empty_pb2.Empty, context: grpc.ServicerContext
    ) -> empty_pb2.Empty:
        """Reply with an empty message after sending hostname metadata."""
        self._send_hostname(context)
        return empty_pb2.Empty()

    def UnaryCall(
        self, request: messages_pb2.SimpleRequest, context: grpc.ServicerContext
    ) -> messages_pb2.SimpleResponse:
        """Reply with the server id and hostname, plus an optional payload.

        When ``request.response_size`` is positive, the payload body is that
        many ``b"0"`` bytes. Otherwise, the payload field is left unset.
        """
        self._send_hostname(context)
        size = request.response_size
        response = messages_pb2.SimpleResponse()
        if size > 0:
            response.payload.body = b"0" * size
        response.server_id = self._server_id
        response.hostname = self._hostname
        return response
86
87
def _configure_maintenance_server(
    server: grpc.Server, maintenance_port: int
) -> None:
    """Add channelz, health and reflection services to *server*.

    The server gets an insecure listener on ``maintenance_port``. The health
    servicer runs non-blocking on its own thread pool. It reports SERVING for
    the TestService, Health, Channelz and reflection service names, and the
    same names are published through server reflection.

    NOTE(review): TestService is reported/advertised here even when it is
    actually served by a different server on another port; confirm that
    this is what the interop harness expects.
    """
    channelz.add_channelz_servicer(server)
    server.add_insecure_port(f"{_LISTEN_HOST}:{maintenance_port}")

    health_pool = futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE)
    health_servicer = grpc_health.HealthServicer(
        experimental_non_blocking=True,
        experimental_thread_pool=health_pool,
    )
    health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)

    service_names = (
        test_pb2.DESCRIPTOR.services_by_name["TestService"].full_name,
        health_pb2.DESCRIPTOR.services_by_name["Health"].full_name,
        channelz_pb2.DESCRIPTOR.services_by_name["Channelz"].full_name,
        reflection.SERVICE_NAME,
    )
    for name in service_names:
        health_servicer.set(name, health_pb2.HealthCheckResponse.SERVING)
    reflection.enable_server_reflection(service_names, server)
111
112
def _configure_test_server(
    server: grpc.Server, port: int, secure_mode: bool, server_id: str
) -> None:
    """Register TestService on *server* and add its listening port.

    The servicer reports ``server_id`` and this machine's hostname. In secure
    mode the port uses xDS server credentials, with insecure credentials as
    the fallback. Otherwise, a plain insecure port is added.
    """
    servicer = TestService(server_id, socket.gethostname())
    test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)

    address = f"{_LISTEN_HOST}:{port}"
    if secure_mode:
        logger.info("Running with xDS Server credentials")
        fallback_creds = grpc.insecure_server_credentials()
        server.add_secure_port(
            address, grpc.xds_server_credentials(fallback_creds)
        )
    else:
        server.add_insecure_port(address)
128
def _run(
    port: int,
    maintenance_port: int,
    secure_mode: bool,
    server_id: str,
    enable_csm_observability: bool,
) -> None:
    """Start the interop server(s) and block until they terminate.

    When ``port == maintenance_port``, a single (non-xDS) server hosts both the
    test service and the maintenance services. Otherwise, the maintenance
    server is started first. A separate test server, created with
    ``xds=secure_mode``, is started next. This call then waits for the test
    server and afterwards for the maintenance server to terminate.

    Args:
        port: Port for TestService.
        maintenance_port: Port for channelz, health and reflection.
        secure_mode: If True, serve TestService with xDS server credentials.
        server_id: Value returned in ``SimpleResponse.server_id``.
        enable_csm_observability: If True, a CSM OpenTelemetry plugin is
            registered globally for the duration of this call.
    """
    csm_plugin = None
    if enable_csm_observability:
        csm_plugin = _prepare_csm_observability_plugin()
        csm_plugin.register_global()

    # Every server created so far. They are stopped in ``finally`` so that a
    # failure while configuring/starting one of them (e.g. a port that cannot
    # be bound) does not leave another one serving.
    servers: List[grpc.Server] = []
    try:
        if port == maintenance_port:
            server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE)
            )
            servers.append(server)
            _configure_test_server(server, port, secure_mode, server_id)
            _configure_maintenance_server(server, maintenance_port)
            server.start()
            logger.info("Test server listening on port %d", port)
            logger.info(
                "Maintenance server listening on port %d", maintenance_port
            )
            server.wait_for_termination()
        else:
            maintenance_server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE)
            )
            servers.append(maintenance_server)
            _configure_maintenance_server(maintenance_server, maintenance_port)
            maintenance_server.start()
            logger.info(
                "Maintenance server listening on port %d", maintenance_port
            )
            test_server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE),
                xds=secure_mode,
            )
            servers.append(test_server)
            _configure_test_server(test_server, port, secure_mode, server_id)
            test_server.start()
            logger.info("Test server listening on port %d", port)
            test_server.wait_for_termination()
            maintenance_server.wait_for_termination()
    finally:
        # Server.stop() is idempotent, so this is harmless after a normal
        # termination. After a failure or interrupt, grace=None aborts any
        # in-flight RPCs immediately.
        for server in servers:
            server.stop(None)
        # Always undo register_global(), even if serving failed.
        if csm_plugin:
            csm_plugin.deregister_global()
168
169
def bool_arg(arg: str) -> bool:
    """Parse a command-line boolean for use as an argparse ``type=``.

    Matching is case-insensitive and ignores surrounding whitespace. The
    accepted spellings follow the former ``distutils.util.strtobool``:
    true/t/yes/y/on/1 and false/f/no/n/off/0.

    Raises:
        argparse.ArgumentTypeError: if *arg* is not a recognized spelling;
            argparse reports it as a usage error.
    """
    # Normalize once instead of lower-casing for every comparison.
    normalized = arg.strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Could not parse '{arg}' as a bool.")
177
178
def _prepare_csm_observability_plugin() -> CsmOpenTelemetryPlugin:
    """Build a CSM OpenTelemetry plugin whose metrics are exported via Prometheus.

    This starts the Prometheus HTTP endpoint on ``_PROMETHEUS_PORT`` and wires
    a ``PrometheusMetricReader`` into the ``MeterProvider`` given to the
    plugin. The plugin is returned unregistered; registering it is left to
    the caller.
    """
    # Bind the metrics endpoint to the same host as the gRPC listeners
    # instead of repeating the literal address.
    start_http_server(port=_PROMETHEUS_PORT, addr=_LISTEN_HOST)
    reader = PrometheusMetricReader()
    meter_provider = MeterProvider(metric_readers=[reader])
    return CsmOpenTelemetryPlugin(meter_provider=meter_provider)
188
189
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run Python xDS interop server."
    )
    parser.add_argument(
        "--port", type=int, default=8080, help="Port for test server."
    )
    parser.add_argument(
        "--maintenance_port",
        type=int,
        default=8080,
        help="Port for servers besides test server.",
    )
    # bool_arg takes an explicit value (e.g. --secure_mode=true); this is not
    # a presence flag.
    parser.add_argument(
        "--secure_mode",
        type=bool_arg,
        default=False,
        help="If true, uses xDS to retrieve server credentials.",
    )
    parser.add_argument(
        "--server_id",
        type=str,
        default="python_server",
        help="The server ID to return in responses.",
    )
    parser.add_argument(
        "--verbose",
        help="verbose log output",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--enable_csm_observability",
        help="Whether to enable CSM Observability",
        default=False,
        type=bool_arg,
    )
    args = parser.parse_args()
    logger.setLevel(logging.DEBUG if args.verbose else logging.INFO)
    if args.secure_mode and args.port == args.maintenance_port:
        # An invalid flag combination is a usage error. parser.error() prints
        # usage and exits with status 2 instead of dumping a traceback.
        parser.error(
            "--port and --maintenance_port must not be the same when"
            " --secure_mode is set."
        )
    _run(
        args.port,
        args.maintenance_port,
        args.secure_mode,
        args.server_id,
        args.enable_csm_observability,
    )
244