# Copyright 2022 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checks that xds_interop_client only sends the RPC types it is configured for.

The test launches xds_interop_server and xds_interop_client as subprocesses,
repeatedly reconfigures the client's RPC type, and compares accumulated
per-method stats before and after each interval.
"""

import collections
import contextlib
import logging
import os
import subprocess
import sys
import tempfile
import time
from typing import IO, Iterable, Iterator, List, Mapping, Set, Tuple
import unittest

import grpc.experimental
import xds_interop_client
import xds_interop_server

from src.proto.grpc.testing import empty_pb2
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2
from src.proto.grpc.testing import test_pb2_grpc
import src.python.grpcio_tests.tests.unit.framework.common as framework_common

_CLIENT_PATH = os.path.abspath(os.path.realpath(xds_interop_client.__file__))
_SERVER_PATH = os.path.abspath(os.path.realpath(xds_interop_server.__file__))

# (enum value sent in ClientConfigureRequest, key used in stats_per_method).
_METHODS = (
    (messages_pb2.ClientConfigureRequest.UNARY_CALL, "UNARY_CALL"),
    (messages_pb2.ClientConfigureRequest.EMPTY_CALL, "EMPTY_CALL"),
)

_QPS = 100
_NUM_CHANNELS = 20

_TEST_ITERATIONS = 10
_ITERATION_DURATION_SECONDS = 1
_SUBPROCESS_TIMEOUT_SECONDS = 2


def _set_union(a: Iterable, b: Iterable) -> Set:
    """Return a new set containing every element of ``a`` and ``b``."""
    c = set(a)
    c.update(b)
    return c


@contextlib.contextmanager
def _start_python_with_args(
    file: str, args: List[str]
) -> Iterator[Tuple[subprocess.Popen, IO[str], IO[str]]]:
    """Run ``file`` under the current interpreter with ``args``.

    Yields ``(process, stdout, stderr)``. The two streams are temporary files
    that the child writes into; callers may rewind and read them (see
    ``_dump_stream``).

    On exit, including exit via an exception, the child is killed and reaped
    while the temporary files are still open. A failure anywhere inside the
    ``with`` body therefore never leaves a running subprocess behind.
    """
    with tempfile.TemporaryFile(mode="r") as stdout:
        with tempfile.TemporaryFile(mode="r") as stderr:
            proc = subprocess.Popen(
                (sys.executable, file) + tuple(args),
                stdout=stdout,
                stderr=stderr,
            )
            try:
                yield proc, stdout, stderr
            finally:
                # kill() is a no-op for an already-reaped process, so this is
                # safe even if the child exited on its own.
                proc.kill()
                proc.wait(timeout=_SUBPROCESS_TIMEOUT_SECONDS)


def _dump_stream(process_name: str, stream_name: str, stream: IO[str]):
    """Write the entire contents of ``stream`` to stderr under a header."""
    sys.stderr.write(f"{process_name} {stream_name}:\n")
    # The child wrote through its own file descriptor; rewind before reading.
    stream.seek(0)
    sys.stderr.write(stream.read())


def _dump_streams(
    process_name: str,
    stdout: IO[str],
    stderr: IO[str],
):
    """Dump a subprocess's captured stdout then stderr, followed by a footer."""
    _dump_stream(process_name, "stdout", stdout)
    _dump_stream(process_name, "stderr", stderr)
    sys.stderr.write(f"End {process_name} output.\n")


def _index_accumulated_stats(
    response: messages_pb2.LoadBalancerAccumulatedStatsResponse,
) -> Mapping[str, Mapping[int, int]]:
    """Flatten an accumulated-stats response into ``{method: {status: count}}``.

    Only the methods listed in ``_METHODS`` are indexed. Missing entries read
    as 0 because the result is a nested ``defaultdict(int)``.
    """
    indexed = collections.defaultdict(lambda: collections.defaultdict(int))
    for _, method_str in _METHODS:
        method_stats = response.stats_per_method[method_str]
        for status, count in method_stats.result.items():
            indexed[method_str][status] = count
    return indexed


def _subtract_indexed_stats(
    a: Mapping[str, Mapping[int, int]], b: Mapping[str, Mapping[int, int]]
):
    """Return ``a - b`` element-wise over the union of methods and statuses.

    Works with any nested ``Mapping``. Absent methods or statuses count as 0.
    The result is a nested ``defaultdict(int)``.
    """
    c = collections.defaultdict(lambda: collections.defaultdict(int))
    for method in _set_union(a.keys(), b.keys()):
        a_method = a.get(method, {})
        b_method = b.get(method, {})
        for status in _set_union(a_method.keys(), b_method.keys()):
            c[method][status] = a_method.get(status, 0) - b_method.get(
                status, 0
            )
    return c


def _collect_stats(
    stats_port: int, duration: int
) -> Mapping[str, Mapping[int, int]]:
    """Measure per-method, per-status RPC counts over ``duration`` seconds.

    The client's accumulated stats are sampled twice via its
    LoadBalancerStatsService on ``stats_port``, ``duration`` seconds apart,
    and the difference is returned.
    """
    settings = {
        "target": f"localhost:{stats_port}",
        "insecure": True,
    }
    response = test_pb2_grpc.LoadBalancerStatsService.GetClientAccumulatedStats(
        messages_pb2.LoadBalancerAccumulatedStatsRequest(), **settings
    )
    before = _index_accumulated_stats(response)
    time.sleep(duration)
    response = test_pb2_grpc.LoadBalancerStatsService.GetClientAccumulatedStats(
        messages_pb2.LoadBalancerAccumulatedStatsRequest(), **settings
    )
    after = _index_accumulated_stats(response)
    return _subtract_indexed_stats(after, before)


class XdsInteropClientTest(unittest.TestCase):
    def _assert_client_consistent(
        self, server_port: int, stats_port: int, qps: int, num_channels: int
    ):
        """Cycle the client's configured RPC type and verify its traffic.

        For each iteration, the client is configured to send only one method
        type. The method under test must then record successful (status 0)
        RPCs during the measurement window, and every other (method, status)
        pair must record none.

        ``server_port``, ``qps`` and ``num_channels`` are currently unused.
        """
        settings = {
            "target": f"localhost:{stats_port}",
            "insecure": True,
        }
        for i in range(_TEST_ITERATIONS):
            target_method, target_method_str = _METHODS[i % len(_METHODS)]
            test_pb2_grpc.XdsUpdateClientConfigureService.Configure(
                messages_pb2.ClientConfigureRequest(types=[target_method]),
                **settings,
            )
            delta = _collect_stats(stats_port, _ITERATION_DURATION_SECONDS)
            logging.info("Delta: %s", delta)
            # Assert on the target explicitly. Checking only inside the loop
            # below would pass vacuously if the target method recorded no
            # OK status at all in this window.
            self.assertGreater(
                delta.get(target_method_str, {}).get(0, 0), 0, delta
            )
            for _, method_str in _METHODS:
                for status, count in delta[method_str].items():
                    if status == 0 and method_str == target_method_str:
                        continue
                    self.assertEqual(count, 0, delta)

    def test_configure_consistency(self):
        _, server_port, server_socket = framework_common.get_socket()

        # Both context managers kill and reap their subprocess on exit, even
        # if startup or an assertion below fails.
        with _start_python_with_args(
            _SERVER_PATH,
            [f"--port={server_port}", f"--maintenance_port={server_port}"],
        ) as (_server, server_stdout, server_stderr):
            # Send RPC to server to make sure it's running.
            logging.info("Sending RPC to server.")
            try:
                test_pb2_grpc.TestService.EmptyCall(
                    empty_pb2.Empty(),
                    f"localhost:{server_port}",
                    insecure=True,
                    wait_for_ready=True,
                )
            except BaseException:
                _dump_streams("server", server_stdout, server_stderr)
                raise
            logging.info("Server successfully started.")
            server_socket.close()
            _, stats_port, stats_socket = framework_common.get_socket()
            with _start_python_with_args(
                _CLIENT_PATH,
                [
                    f"--server=localhost:{server_port}",
                    f"--stats_port={stats_port}",
                    f"--qps={_QPS}",
                    f"--num_channels={_NUM_CHANNELS}",
                ],
            ) as (_client, client_stdout, client_stderr):
                stats_socket.close()
                try:
                    self._assert_client_consistent(
                        server_port, stats_port, _QPS, _NUM_CHANNELS
                    )
                except BaseException:
                    _dump_streams("server", server_stdout, server_stderr)
                    _dump_streams("client", client_stdout, client_stderr)
                    raise


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)