# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of interoperability test methods."""

import argparse
import asyncio
import collections
import datetime
import enum
import inspect
import json
import os
import threading
import time
from typing import Any, Optional, Union

import grpc
from google import auth as google_auth
from google.auth import environment_vars as google_auth_environment_vars
from google.auth.transport import grpc as google_auth_transport_grpc
from google.auth.transport import requests as google_auth_transport_requests
from grpc.experimental import aio

from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc

_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"


async def _expect_status_code(call: aio.Call,
                              expected_code: grpc.StatusCode) -> None:
    """Raise ValueError unless ``call`` finished with ``expected_code``."""
    code = await call.code()
    if code != expected_code:
        # Report the value already fetched rather than awaiting the call again.
        raise ValueError('expected code %s, got %s' % (expected_code, code))


async def _expect_status_details(call: aio.Call, expected_details: str) -> None:
    """Raise ValueError unless ``call``'s status details equal ``expected_details``."""
    details = await call.details()
    if details != expected_details:
        raise ValueError('expected message %s, got %s' %
                         (expected_details, details))


async def _validate_status_code_and_details(call: aio.Call,
                                            expected_code: grpc.StatusCode,
                                            expected_details: str) -> None:
    """Check both the final status code and the status details of ``call``.

    Raises:
        ValueError: if either the code or the details do not match.
    """
    await _expect_status_code(call, expected_code)
    await _expect_status_details(call, expected_details)


def _validate_payload_type_and_length(response: Union[
        messages_pb2.SimpleResponse, messages_pb2.StreamingOutputCallResponse],
                                      expected_type: Any,
                                      expected_length: int) -> None:
    """Check the payload type and body length carried by ``response``.

    Args:
        response: a unary or streaming response that has a ``payload`` field.
        expected_type: the expected payload type enum value
            (e.g. ``messages_pb2.COMPRESSABLE``).
        expected_length: the expected number of bytes in ``payload.body``.

    Raises:
        ValueError: if the type or the body length does not match.
    """
    # Protobuf enum fields are plain ints in Python: compare by value, not by
    # identity (``is`` only works by accident of small-int caching).
    if response.payload.type != expected_type:
        raise ValueError('expected payload type %s, got %s' %
                         (expected_type, response.payload.type))
    if len(response.payload.body) != expected_length:
        raise ValueError('expected payload body size %d, got %d' %
                         (expected_length, len(response.payload.body)))


async def _large_unary_common_behavior(
    stub: test_pb2_grpc.TestServiceStub, fill_username: bool,
    fill_oauth_scope: bool, call_credentials: Optional[grpc.CallCredentials]
) -> messages_pb2.SimpleResponse:
    """Send a large UnaryCall and validate the size/type of the response.

    Args:
        stub: the TestService stub to call.
        fill_username: forwarded as ``SimpleRequest.fill_username``.
        fill_oauth_scope: forwarded as ``SimpleRequest.fill_oauth_scope``.
        call_credentials: optional per-call credentials (may be None).

    Returns:
        The validated response, so callers can inspect username/oauth_scope.
    """
    size = 314159
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=size,
        payload=messages_pb2.Payload(body=b'\x00' * 271828),
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope)
    response = await stub.UnaryCall(request, credentials=call_credentials)
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
    return response


def _service_account_email_from_key_file() -> str:
    """Return ``client_email`` from the JSON key file named in the environment.

    The file path is read from the environment variable whose name is
    ``google_auth_environment_vars.CREDENTIALS``.

    Raises:
        KeyError: if that environment variable or the ``client_email`` key
            is missing.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Use a context manager so the key file handle is always closed.
    with open(json_key_filename, 'r') as json_key_file:
        return json.load(json_key_file)['client_email']


async def _validate_unary_echo_status(stub: test_pb2_grpc.TestServiceStub,
                                      status: grpc.StatusCode,
                                      details: str) -> None:
    """Send a UnaryCall requesting ``status``/``details`` and verify both.

    The request's ``response_status`` asks the server for the given code
    (``status.value[0]`` is the numeric code) and message; the call's final
    status must then match exactly.
    """
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=status.value[0],
                                                message=details))
    call = stub.UnaryCall(request)
    await _validate_status_code_and_details(call, status, details)


async def _empty_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """EmptyCall must return an ``empty_pb2.Empty`` instance."""
    response = await stub.EmptyCall(empty_pb2.Empty())
    if not isinstance(response, empty_pb2.Empty):
        raise TypeError('response is of type "%s", not empty_pb2.Empty!' %
                        type(response))


async def _large_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Large UnaryCall without username, OAuth scope or call credentials."""
    await _large_unary_common_behavior(stub, False, False, None)


async def _client_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Stream several payloads and check the server's aggregated size."""
    payload_body_sizes = (
        27182,
        8,
        1828,
        45904,
    )

    async def request_gen():
        for size in payload_body_sizes:
            yield messages_pb2.StreamingInputCallRequest(
                payload=messages_pb2.Payload(body=b'\x00' * size))

    response = await stub.StreamingInputCall(request_gen())
    if response.aggregated_payload_size != sum(payload_body_sizes):
        raise ValueError('incorrect size %d!' %
                         response.aggregated_payload_size)


async def _server_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Request one streamed response per size and validate each in order."""
    sizes = (
        31415,
        9,
        2653,
        58979,
    )

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=size) for size in sizes))
    call = stub.StreamingOutputCall(request)
    for size in sizes:
        response = await call.read()
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          size)


async def _ping_pong(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Alternate one write and one read on a full-duplex call, then expect OK."""
    request_response_sizes = (
        31415,
        9,
        2653,
        58979,
    )
    request_payload_sizes = (
        27182,
        8,
        1828,
        45904,
    )

    call = stub.FullDuplexCall()
    for response_size, payload_size in zip(request_response_sizes,
                                           request_payload_sizes):
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(
                size=response_size),),
            payload=messages_pb2.Payload(body=b'\x00' * payload_size))

        await call.write(request)
        response = await call.read()
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          response_size)
    await call.done_writing()
    await _validate_status_code_and_details(call, grpc.StatusCode.OK, '')


async def _cancel_after_begin(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Cancel a client-streaming call before sending anything."""
    call = stub.StreamingInputCall()
    call.cancel()
    if not call.cancelled():
        raise ValueError('expected cancelled method to return True')
    code = await call.code()
    if code is not grpc.StatusCode.CANCELLED:
        raise ValueError('expected status code CANCELLED')


async def _cancel_after_first_response(
        stub: test_pb2_grpc.TestServiceStub) -> None:
    """Cancel a full-duplex call after one request/response exchange.

    A further ``read()`` must raise ``asyncio.CancelledError`` and the call's
    code must be CANCELLED.
    """
    # Sizes of the first exchange of the ping_pong sequence.
    response_size = 31415
    payload_size = 27182

    call = stub.FullDuplexCall()

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(
            size=response_size),),
        payload=messages_pb2.Payload(body=b'\x00' * payload_size))

    await call.write(request)
    await call.read()

    call.cancel()

    try:
        await call.read()
    except asyncio.CancelledError:
        code = await call.code()
        # Explicit check instead of ``assert`` so it also runs under ``-O``.
        if code is not grpc.StatusCode.CANCELLED:
            raise ValueError('expected status code CANCELLED, got %s' % code)
    else:
        raise ValueError('expected call to be cancelled')


async def _timeout_on_sleeping_server(
        stub: test_pb2_grpc.TestServiceStub) -> None:
    """Ask the server for a response delayed by twice the 1s call deadline.

    The call must fail with DEADLINE_EXCEEDED; any other RPC error is
    re-raised.
    """
    request_payload_size = 27182
    time_limit = datetime.timedelta(seconds=1)

    call = stub.FullDuplexCall(timeout=time_limit.total_seconds())

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        payload=messages_pb2.Payload(body=b'\x00' * request_payload_size),
        response_parameters=(messages_pb2.ResponseParameters(
            interval_us=int(time_limit.total_seconds() * 2 * 10**6)),))
    await call.write(request)
    await call.done_writing()
    try:
        await call.read()
    except aio.AioRpcError as rpc_error:
        if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
            raise
    else:
        raise ValueError('expected call to exceed deadline')


async def _empty_stream(stub: test_pb2_grpc.TestServiceStub) -> None:
    """A full-duplex call with no requests must yield no responses."""
    call = stub.FullDuplexCall()
    await call.done_writing()
    # Explicit check instead of ``assert`` so it also runs under ``-O``.
    if await call.read() != aio.EOF:
        raise ValueError('expected the response stream to be empty')


async def _status_code_and_message(
        stub: test_pb2_grpc.TestServiceStub) -> None:
    """Requested status code/message are echoed for unary and duplex calls."""
    details = 'test status message'
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    await _validate_unary_echo_status(stub, status, details)

    # Test with a FullDuplexCall
    call = stub.FullDuplexCall()
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(size=1),),
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=status.value[0],
                                                message=details))
    await call.write(request)  # sends the initial request.
    await call.done_writing()
    try:
        await call.read()
    except aio.AioRpcError as rpc_error:
        # Explicit check instead of ``assert`` so it also runs under ``-O``.
        if rpc_error.code() != status:
            raise ValueError('expected code %s, got %s' %
                             (status, rpc_error.code()))
    await _validate_status_code_and_details(call, status, details)


async def _unimplemented_method(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Calling UnimplementedCall on TestService must yield UNIMPLEMENTED."""
    call = stub.UnimplementedCall(empty_pb2.Empty())
    await _expect_status_code(call, grpc.StatusCode.UNIMPLEMENTED)


async def _unimplemented_service(
        stub: test_pb2_grpc.UnimplementedServiceStub) -> None:
    """Calling UnimplementedCall on UnimplementedService must yield UNIMPLEMENTED."""
    call = stub.UnimplementedCall(empty_pb2.Empty())
    await _expect_status_code(call, grpc.StatusCode.UNIMPLEMENTED)


async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Metadata sent under the echo keys must come back as initial/trailing metadata."""
    initial_metadata_value = "test_initial_metadata_value"
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = aio.Metadata(
        (_INITIAL_METADATA_KEY, initial_metadata_value),
        (_TRAILING_METADATA_KEY, trailing_metadata_value),
    )

    async def _validate_metadata(call):
        initial_metadata = await call.initial_metadata()
        if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
            raise ValueError('expected initial metadata %s, got %s' %
                             (initial_metadata_value,
                              initial_metadata[_INITIAL_METADATA_KEY]))

        trailing_metadata = await call.trailing_metadata()
        if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
            raise ValueError('expected trailing metadata %s, got %s' %
                             (trailing_metadata_value,
                              trailing_metadata[_TRAILING_METADATA_KEY]))

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'))
    call = stub.UnaryCall(request, metadata=metadata)
    await _validate_metadata(call)

    # Testing with FullDuplexCall
    call = stub.FullDuplexCall(metadata=metadata)
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(size=1),))
    await call.write(request)
    await call.read()
    await call.done_writing()
    await _validate_metadata(call)


async def _compute_engine_creds(stub: test_pb2_grpc.TestServiceStub,
                                args: argparse.Namespace) -> None:
    """The reported username must equal ``args.default_service_account``."""
    response = await _large_unary_common_behavior(stub, True, True, None)
    if args.default_service_account != response.username:
        raise ValueError('expected username %s, got %s' %
                         (args.default_service_account, response.username))


async def _oauth2_auth_token(stub: test_pb2_grpc.TestServiceStub,
                             args: argparse.Namespace) -> None:
    """Check the username and that the returned scope occurs in ``args.oauth_scope``."""
    wanted_email = _service_account_email_from_key_file()
    response = await _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
    if response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope))


async def _jwt_token_creds(stub: test_pb2_grpc.TestServiceStub) -> None:
    """The reported username must equal the key file's ``client_email``."""
    wanted_email = _service_account_email_from_key_file()
    response = await _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))


async def _per_rpc_creds(stub: test_pb2_grpc.TestServiceStub,
                         args: argparse.Namespace) -> None:
    """Attach Google default credentials per call and check the username."""
    wanted_email = _service_account_email_from_key_file()
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request()))
    response = await _large_unary_common_behavior(stub, True, False,
                                                  call_credentials)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))


async def _special_status_message(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Status details containing whitespace and BMP/non-BMP chars round-trip."""
    details = (b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba '
               b'and non-BMP \xf0\x9f\x98\x88\t\n').decode('utf-8')
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    await _validate_unary_echo_status(stub, status, details)


@enum.unique
class TestCase(enum.Enum):
    """Interop test case names, as passed in by the test driver."""
    EMPTY_UNARY = 'empty_unary'
    LARGE_UNARY = 'large_unary'
    SERVER_STREAMING = 'server_streaming'
    CLIENT_STREAMING = 'client_streaming'
    PING_PONG = 'ping_pong'
    CANCEL_AFTER_BEGIN = 'cancel_after_begin'
    CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
    TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
    EMPTY_STREAM = 'empty_stream'
    STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
    UNIMPLEMENTED_METHOD = 'unimplemented_method'
    UNIMPLEMENTED_SERVICE = 'unimplemented_service'
    CUSTOM_METADATA = "custom_metadata"
    COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
    OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
    JWT_TOKEN_CREDS = 'jwt_token_creds'
    PER_RPC_CREDS = 'per_rpc_creds'
    SPECIAL_STATUS_MESSAGE = 'special_status_message'


# Maps each TestCase to its coroutine implementation. Implementations take
# either ``(stub)`` or ``(stub, args)``; test_interoperability dispatches on
# the parameter count.
_TEST_CASE_IMPLEMENTATION_MAPPING = {
    TestCase.EMPTY_UNARY: _empty_unary,
    TestCase.LARGE_UNARY: _large_unary,
    TestCase.SERVER_STREAMING: _server_streaming,
    TestCase.CLIENT_STREAMING: _client_streaming,
    TestCase.PING_PONG: _ping_pong,
    TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin,
    TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response,
    TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server,
    TestCase.EMPTY_STREAM: _empty_stream,
    TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message,
    TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method,
    TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service,
    TestCase.CUSTOM_METADATA: _custom_metadata,
    TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds,
    TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token,
    TestCase.JWT_TOKEN_CREDS: _jwt_token_creds,
    TestCase.PER_RPC_CREDS: _per_rpc_creds,
    TestCase.SPECIAL_STATUS_MESSAGE: _special_status_message,
}


async def test_interoperability(
        case: TestCase,
        stub: test_pb2_grpc.TestServiceStub,
        args: Optional[argparse.Namespace] = None) -> None:
    """Run the implementation registered for ``case``.

    Implementations that take two parameters also receive ``args``.

    Raises:
        NotImplementedError: if ``case`` has no registered implementation.
        ValueError: if the implementation needs ``args`` but it is None, or
            if it has an unsupported number of parameters.
    """
    method = _TEST_CASE_IMPLEMENTATION_MAPPING.get(case)
    if method is None:
        raise NotImplementedError(f'Test case "{case}" not implemented!')
    num_params = len(inspect.signature(method).parameters)
    if num_params == 1:
        await method(stub)
    elif num_params == 2:
        if args is None:
            raise ValueError(f'Failed to run case [{case}]: args is None')
        await method(stub, args)
    else:
        raise ValueError(f'Invalid number of parameters [{num_params}]')