# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of interoperability test methods."""

import argparse
import asyncio
import collections
import datetime
import enum
import inspect
import json
import os
import threading
import time
from typing import Any, Optional, Union

import grpc
from google import auth as google_auth
from google.auth import environment_vars as google_auth_environment_vars
from google.auth.transport import grpc as google_auth_transport_grpc
from google.auth.transport import requests as google_auth_transport_requests
from grpc.experimental import aio

from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc

_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"

# Response sizes requested from the server by the streaming test cases.
_RESPONSE_SIZES = (
    31415,
    9,
    2653,
    58979,
)
# Request payload body sizes sent to the server by the streaming test cases.
_REQUEST_PAYLOAD_SIZES = (
    27182,
    8,
    1828,
    45904,
)


async def _expect_status_code(call: aio.Call,
                              expected_code: grpc.StatusCode) -> None:
    """Check the status code of a call.

    Raises:
      ValueError: if the call's status code differs from `expected_code`.
    """
    code = await call.code()
    if code != expected_code:
        # Report the value that was actually compared rather than awaiting
        # the call a second time.
        raise ValueError('expected code %s, got %s' % (expected_code, code))


async def _expect_status_details(call: aio.Call, expected_details: str) -> None:
    """Check the status details (message) of a call.

    Raises:
      ValueError: if the call's details differ from `expected_details`.
    """
    details = await call.details()
    if details != expected_details:
        raise ValueError('expected message %s, got %s' %
                         (expected_details, details))


async def _validate_status_code_and_details(call: aio.Call,
                                            expected_code: grpc.StatusCode,
                                            expected_details: str) -> None:
    """Check both the status code and the status details of a call.

    Raises:
      ValueError: if either the code or the details do not match.
    """
    await _expect_status_code(call, expected_code)
    await _expect_status_details(call, expected_details)


def _validate_payload_type_and_length(
        response: Union[messages_pb2.SimpleResponse,
                        messages_pb2.StreamingOutputCallResponse],
        expected_type: Any, expected_length: int) -> None:
    """Check the payload type and payload body size of a response.

    Raises:
      ValueError: if the payload type or the body length is not as expected.
    """
    # Compare by value: identity comparison of the enum numbers would only
    # work by accident of the interpreter's integer caching.
    if response.payload.type != expected_type:
        raise ValueError('expected payload type %s, got %s' %
                         (expected_type, response.payload.type))
    if len(response.payload.body) != expected_length:
        raise ValueError('expected payload body size %d, got %d' %
                         (expected_length, len(response.payload.body)))


def _read_client_email() -> str:
    """Return the `client_email` entry of the JSON key file.

    The key file path is read from the environment variable named by
    `google_auth_environment_vars.CREDENTIALS`.

    Raises:
      KeyError: if the environment variable or the `client_email` entry is
        missing.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Use a context manager so the key file is closed deterministically.
    with open(json_key_filename, 'r') as json_key_file:
        return json.load(json_key_file)['client_email']


async def _large_unary_common_behavior(
        stub: test_pb2_grpc.TestServiceStub, fill_username: bool,
        fill_oauth_scope: bool, call_credentials: Optional[grpc.CallCredentials]
) -> messages_pb2.SimpleResponse:
    """Issue a large UnaryCall and validate the response payload.

    Sends a 271828-byte zero-filled payload, asks for a 314159-byte
    COMPRESSABLE response and validates the type and size of what comes back.

    Args:
      stub: the test service stub to call.
      fill_username: forwarded to the request's `fill_username` field.
      fill_oauth_scope: forwarded to the request's `fill_oauth_scope` field.
      call_credentials: optional per-call credentials passed to the RPC.

    Returns:
      The validated response.

    Raises:
      ValueError: if the response payload has the wrong type or size.
    """
    size = 314159
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=size,
        payload=messages_pb2.Payload(body=b'\x00' * 271828),
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope)
    response = await stub.UnaryCall(request, credentials=call_credentials)
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
    return response


async def _empty_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Call EmptyCall and check that an `empty_pb2.Empty` comes back."""
    response = await stub.EmptyCall(empty_pb2.Empty())
    if not isinstance(response, empty_pb2.Empty):
        raise TypeError('response is of type "%s", not empty_pb2.Empty!' %
                        type(response))


async def _large_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Run the large unary exchange without credentials-related fields."""
    await _large_unary_common_behavior(stub, False, False, None)


async def _client_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Stream several payloads and check the aggregated size reported back."""

    async def request_gen():
        for size in _REQUEST_PAYLOAD_SIZES:
            yield messages_pb2.StreamingInputCallRequest(
                payload=messages_pb2.Payload(body=b'\x00' * size))

    response = await stub.StreamingInputCall(request_gen())
    if response.aggregated_payload_size != sum(_REQUEST_PAYLOAD_SIZES):
        raise ValueError('incorrect size %d!' %
                         response.aggregated_payload_size)


async def _server_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Request several responses in one call and validate each one in order."""
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=size)
            for size in _RESPONSE_SIZES))
    call = stub.StreamingOutputCall(request)
    for size in _RESPONSE_SIZES:
        response = await call.read()
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          size)


async def _ping_pong(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Alternate one write and one read on a FullDuplexCall, then check OK."""
    call = stub.FullDuplexCall()
    for response_size, payload_size in zip(_RESPONSE_SIZES,
                                           _REQUEST_PAYLOAD_SIZES):
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(
                size=response_size),),
            payload=messages_pb2.Payload(body=b'\x00' * payload_size))

        await call.write(request)
        response = await call.read()
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          response_size)
    await call.done_writing()
    await _validate_status_code_and_details(call, grpc.StatusCode.OK, '')


async def _cancel_after_begin(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Cancel a StreamingInputCall right after starting it.

    Raises:
      ValueError: if the call does not report itself cancelled or its status
        code is not CANCELLED.
    """
    call = stub.StreamingInputCall()
    call.cancel()
    if not call.cancelled():
        raise ValueError('expected cancelled method to return True')
    code = await call.code()
    if code is not grpc.StatusCode.CANCELLED:
        raise ValueError('expected status code CANCELLED')


async def _cancel_after_first_response(
        stub: test_pb2_grpc.TestServiceStub) -> None:
    """Cancel a FullDuplexCall after one request/response exchange.

    Raises:
      ValueError: if a read after cancellation does not raise
        `asyncio.CancelledError`, or the final status code is not CANCELLED.
    """
    call = stub.FullDuplexCall()

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(
            size=_RESPONSE_SIZES[0]),),
        payload=messages_pb2.Payload(body=b'\x00' * _REQUEST_PAYLOAD_SIZES[0]))

    await call.write(request)
    await call.read()

    call.cancel()

    try:
        await call.read()
    except asyncio.CancelledError:
        # Raise explicitly instead of using `assert`, which is stripped when
        # Python runs with -O and would silently skip this check.
        await _expect_status_code(call, grpc.StatusCode.CANCELLED)
    else:
        raise ValueError('expected call to be cancelled')


async def _timeout_on_sleeping_server(
        stub: test_pb2_grpc.TestServiceStub) -> None:
    """Check that a call with a one-second timeout ends in DEADLINE_EXCEEDED.

    The request sets `interval_us` to twice the call timeout.

    Raises:
      ValueError: if the read completes without raising.
      aio.AioRpcError: if the call fails with any other status code.
    """
    request_payload_size = 27182
    time_limit = datetime.timedelta(seconds=1)

    call = stub.FullDuplexCall(timeout=time_limit.total_seconds())

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        payload=messages_pb2.Payload(body=b'\x00' * request_payload_size),
        response_parameters=(messages_pb2.ResponseParameters(
            interval_us=int(time_limit.total_seconds() * 2 * 10**6)),))
    await call.write(request)
    await call.done_writing()
    try:
        await call.read()
    except aio.AioRpcError as rpc_error:
        # Only DEADLINE_EXCEEDED is the expected outcome; anything else is a
        # genuine failure and is propagated unchanged.
        if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
            raise
    else:
        raise ValueError('expected call to exceed deadline')


async def _empty_stream(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Close a FullDuplexCall without writing and expect EOF on read.

    Raises:
      ValueError: if the first read returns anything other than `aio.EOF`.
    """
    call = stub.FullDuplexCall()
    await call.done_writing()
    response = await call.read()
    # Raise explicitly instead of using `assert`, which is stripped when
    # Python runs with -O and would silently skip this check.
    if response != aio.EOF:
        raise ValueError('expected EOF on empty stream, got %r' % (response,))


async def _status_code_and_message(
        stub: test_pb2_grpc.TestServiceStub) -> None:
    """Ask the server to echo a status and check it on unary and duplex calls."""
    details = 'test status message'
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=status.value[0],
                                                message=details))
    call = stub.UnaryCall(request)
    await _validate_status_code_and_details(call, status, details)

    # Test with a FullDuplexCall
    call = stub.FullDuplexCall()
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(size=1),),
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=status.value[0],
                                                message=details))
    await call.write(request)  # sends the initial request.
    await call.done_writing()
    await _validate_status_code_and_details(call, status, details)


async def _unimplemented_method(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Call UnimplementedCall on the test service and expect UNIMPLEMENTED."""
    call = stub.UnimplementedCall(empty_pb2.Empty())
    await _expect_status_code(call, grpc.StatusCode.UNIMPLEMENTED)


async def _unimplemented_service(
        stub: test_pb2_grpc.UnimplementedServiceStub) -> None:
    """Call UnimplementedCall on the unimplemented service stub.

    Expects the call to finish with UNIMPLEMENTED.
    """
    call = stub.UnimplementedCall(empty_pb2.Empty())
    await _expect_status_code(call, grpc.StatusCode.UNIMPLEMENTED)


async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Send custom metadata and check it on unary and duplex calls.

    The value sent under `_INITIAL_METADATA_KEY` must appear in the call's
    initial metadata and the value sent under `_TRAILING_METADATA_KEY` must
    appear in its trailing metadata.

    Raises:
      ValueError: if either metadata value differs from what was sent.
    """
    initial_metadata_value = "test_initial_metadata_value"
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = aio.Metadata(
        (_INITIAL_METADATA_KEY, initial_metadata_value),
        (_TRAILING_METADATA_KEY, trailing_metadata_value),
    )

    async def _validate_metadata(call):
        initial_metadata = await call.initial_metadata()
        if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
            raise ValueError('expected initial metadata %s, got %s' %
                             (initial_metadata_value,
                              initial_metadata[_INITIAL_METADATA_KEY]))

        trailing_metadata = await call.trailing_metadata()
        if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
            raise ValueError('expected trailing metadata %s, got %s' %
                             (trailing_metadata_value,
                              trailing_metadata[_TRAILING_METADATA_KEY]))

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'))
    call = stub.UnaryCall(request, metadata=metadata)
    await _validate_metadata(call)

    # Testing with FullDuplexCall
    call = stub.FullDuplexCall(metadata=metadata)
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(size=1),))
    await call.write(request)
    await call.read()
    await call.done_writing()
    await _validate_metadata(call)


async def _compute_engine_creds(stub: test_pb2_grpc.TestServiceStub,
                                args: argparse.Namespace) -> None:
    """Check the echoed username against `args.default_service_account`.

    Raises:
      ValueError: if the username in the response does not match.
    """
    response = await _large_unary_common_behavior(stub, True, True, None)
    if args.default_service_account != response.username:
        raise ValueError('expected username %s, got %s' %
                         (args.default_service_account, response.username))


async def _oauth2_auth_token(stub: test_pb2_grpc.TestServiceStub,
                             args: argparse.Namespace) -> None:
    """Check the echoed username and OAuth scope.

    The username must equal the key file's `client_email`, and the echoed
    scope must be a substring of `args.oauth_scope`.

    Raises:
      ValueError: if the username or the OAuth scope does not match.
    """
    wanted_email = _read_client_email()
    response = await _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
    if response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope))


async def _jwt_token_creds(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Check the echoed username against the key file's `client_email`.

    Raises:
      ValueError: if the username in the response does not match.
    """
    wanted_email = _read_client_email()
    response = await _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))


async def _per_rpc_creds(stub: test_pb2_grpc.TestServiceStub,
                         args: argparse.Namespace) -> None:
    """Attach Google default credentials to a single RPC and check the username.

    Credentials are obtained with `google_auth.default` for the scope in
    `args.oauth_scope` and passed as per-call credentials.

    Raises:
      ValueError: if the username in the response does not match the key
        file's `client_email`.
    """
    wanted_email = _read_client_email()
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request()))
    response = await _large_unary_common_behavior(stub, True, False,
                                                  call_credentials)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))


async def _special_status_message(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Check that a status message with whitespace and Unicode is echoed intact.

    The message contains tabs, CR/LF and both a BMP and a non-BMP character.
    """
    details = b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'.decode(
        'utf-8')
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=status.value[0],
                                                message=details))
    call = stub.UnaryCall(request)
    await _validate_status_code_and_details(call, status, details)


@enum.unique
class TestCase(enum.Enum):
    """Names of the interoperability test cases this module implements."""
    EMPTY_UNARY = 'empty_unary'
    LARGE_UNARY = 'large_unary'
    SERVER_STREAMING = 'server_streaming'
    CLIENT_STREAMING = 'client_streaming'
    PING_PONG = 'ping_pong'
    CANCEL_AFTER_BEGIN = 'cancel_after_begin'
    CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
    TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
    EMPTY_STREAM = 'empty_stream'
    STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
    UNIMPLEMENTED_METHOD = 'unimplemented_method'
    UNIMPLEMENTED_SERVICE = 'unimplemented_service'
    CUSTOM_METADATA = "custom_metadata"
    COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
    OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
    JWT_TOKEN_CREDS = 'jwt_token_creds'
    PER_RPC_CREDS = 'per_rpc_creds'
    SPECIAL_STATUS_MESSAGE = 'special_status_message'


# Maps each test case to the coroutine function that implements it.
_TEST_CASE_IMPLEMENTATION_MAPPING = {
    TestCase.EMPTY_UNARY: _empty_unary,
    TestCase.LARGE_UNARY: _large_unary,
    TestCase.SERVER_STREAMING: _server_streaming,
    TestCase.CLIENT_STREAMING: _client_streaming,
    TestCase.PING_PONG: _ping_pong,
    TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin,
    TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response,
    TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server,
    TestCase.EMPTY_STREAM: _empty_stream,
    TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message,
    TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method,
    TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service,
    TestCase.CUSTOM_METADATA: _custom_metadata,
    TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds,
    TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token,
    TestCase.JWT_TOKEN_CREDS: _jwt_token_creds,
    TestCase.PER_RPC_CREDS: _per_rpc_creds,
    TestCase.SPECIAL_STATUS_MESSAGE: _special_status_message,
}


async def test_interoperability(
        case: TestCase,
        stub: test_pb2_grpc.TestServiceStub,
        args: Optional[argparse.Namespace] = None) -> None:
    """Run the implementation of a single interoperability test case.

    The implementation is called with `stub` only, or with `stub` and `args`,
    depending on how many parameters its signature declares.

    Args:
      case: the test case to run.
      stub: the stub handed to the test implementation.
      args: parsed command-line arguments; required by the implementations
        that take two parameters.

    Raises:
      NotImplementedError: if `case` has no registered implementation.
      ValueError: if the implementation needs `args` but none was given, or
        if it declares an unsupported number of parameters.
    """
    method = _TEST_CASE_IMPLEMENTATION_MAPPING.get(case)
    if method is None:
        raise NotImplementedError(f'Test case "{case}" not implemented!')

    num_params = len(inspect.signature(method).parameters)
    if num_params == 1:
        await method(stub)
    elif num_params == 2:
        if args is None:
            raise ValueError(f'Failed to run case [{case}]: args is None')
        await method(stub, args)
    else:
        raise ValueError(f'Invalid number of parameters [{num_params}]')