# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of interoperability test methods."""

import argparse
import asyncio
import collections
import datetime
import enum
import inspect
import json
import os
import threading
import time
from typing import Any, Optional, Union

from google import auth as google_auth
from google.auth import environment_vars as google_auth_environment_vars
from google.auth.transport import grpc as google_auth_transport_grpc
from google.auth.transport import requests as google_auth_transport_requests
import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import empty_pb2
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc

_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"


async def _expect_status_code(
    call: aio.Call, expected_code: grpc.StatusCode
) -> None:
    """Raise ValueError unless the call finished with `expected_code`."""
    code = await call.code()
    if code != expected_code:
        # Reuse the value already fetched instead of awaiting the call again.
        raise ValueError("expected code %s, got %s" % (expected_code, code))


async def _expect_status_details(call: aio.Call, expected_details: str) -> None:
    """Raise ValueError unless the call's status details match exactly."""
    details = await call.details()
    if details != expected_details:
        raise ValueError(
            "expected message %s, got %s" % (expected_details, details)
        )


async def _validate_status_code_and_details(
    call: aio.Call, expected_code: grpc.StatusCode, expected_details: str
) -> None:
    """Check the status code first, then the status details, of `call`."""
    await _expect_status_code(call, expected_code)
    await _expect_status_details(call, expected_details)


def _validate_payload_type_and_length(
    response: Union[
        messages_pb2.SimpleResponse, messages_pb2.StreamingOutputCallResponse
    ],
    expected_type: Any,
    expected_length: int,
) -> None:
    """Raise ValueError if the response payload has the wrong type or size.

    Args:
        response: A response message carrying a `payload` field.
        expected_type: The payload type value the response must carry.
        expected_length: The required length of `payload.body` in bytes.
    """
    # Compare by value: identity (`is`) on integer enum values only works by
    # accident of interpreter-level small-int caching.
    if response.payload.type != expected_type:
        # Report the received value itself, not its Python type.
        raise ValueError(
            "expected payload type %s, got %s"
            % (expected_type, response.payload.type)
        )
    if len(response.payload.body) != expected_length:
        raise ValueError(
            "expected payload body size %d, got %d"
            % (expected_length, len(response.payload.body))
        )


def _read_client_email() -> str:
    """Return `client_email` from the JSON key file named in the environment.

    The file path is read from the environment variable whose name is
    `google_auth_environment_vars.CREDENTIALS`. The file is opened in a
    `with` block so the handle is closed even if parsing fails.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    with open(json_key_filename, "r", encoding="utf-8") as key_file:
        return json.load(key_file)["client_email"]


def _expect_username(wanted: str, actual: str) -> None:
    """Raise ValueError when the echoed username differs from `wanted`."""
    if wanted != actual:
        raise ValueError("expected username %s, got %s" % (wanted, actual))


async def _large_unary_common_behavior(
    stub: test_pb2_grpc.TestServiceStub,
    fill_username: bool,
    fill_oauth_scope: bool,
    call_credentials: Optional[grpc.CallCredentials],
) -> messages_pb2.SimpleResponse:
    """Issue a large UnaryCall and validate the response payload.

    Args:
        stub: The test service stub to call.
        fill_username: Forwarded to the request's `fill_username` field.
        fill_oauth_scope: Forwarded to the request's `fill_oauth_scope` field.
        call_credentials: Per-call credentials, or None for none.

    Returns:
        The validated response, so callers can inspect further fields.
    """
    size = 314159
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=size,
        payload=messages_pb2.Payload(body=b"\x00" * 271828),
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope,
    )
    response = await stub.UnaryCall(request, credentials=call_credentials)
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
    return response


async def _empty_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Call EmptyCall and require an `empty_pb2.Empty` response."""
    response = await stub.EmptyCall(empty_pb2.Empty())
    if not isinstance(response, empty_pb2.Empty):
        raise TypeError(
            'response is of type "%s", not empty_pb2.Empty!' % type(response)
        )


async def _large_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Run the large unary exchange without username, scope or credentials."""
    await _large_unary_common_behavior(stub, False, False, None)


async def _client_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Stream several payloads and check the aggregated size reported back."""
    payload_body_sizes = (
        27182,
        8,
        1828,
        45904,
    )

    async def request_gen():
        for size in payload_body_sizes:
            yield messages_pb2.StreamingInputCallRequest(
                payload=messages_pb2.Payload(body=b"\x00" * size)
            )

    response = await stub.StreamingInputCall(request_gen())
    if response.aggregated_payload_size != sum(payload_body_sizes):
        raise ValueError(
            "incorrect size %d!" % response.aggregated_payload_size
        )


async def _server_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Request several response sizes and validate each streamed response."""
    sizes = (
        31415,
        9,
        2653,
        58979,
    )

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        # One ResponseParameters entry per requested size, in order.
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=size) for size in sizes
        ),
    )
    call = stub.StreamingOutputCall(request)
    for size in sizes:
        response = await call.read()
        _validate_payload_type_and_length(
            response, messages_pb2.COMPRESSABLE, size
        )


async def _ping_pong(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Alternate write/read on a FullDuplexCall, then expect status OK."""
    request_response_sizes = (
        31415,
        9,
        2653,
        58979,
    )
    request_payload_sizes = (
        27182,
        8,
        1828,
        45904,
    )

    call = stub.FullDuplexCall()
    for response_size, payload_size in zip(
        request_response_sizes, request_payload_sizes
    ):
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(
                messages_pb2.ResponseParameters(size=response_size),
            ),
            payload=messages_pb2.Payload(body=b"\x00" * payload_size),
        )

        await call.write(request)
        response = await call.read()
        _validate_payload_type_and_length(
            response, messages_pb2.COMPRESSABLE, response_size
        )
    await call.done_writing()
    await _validate_status_code_and_details(call, grpc.StatusCode.OK, "")


async def _cancel_after_begin(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Cancel a StreamingInputCall immediately and expect CANCELLED."""
    call = stub.StreamingInputCall()
    call.cancel()
    if not call.cancelled():
        raise ValueError("expected cancelled method to return True")
    code = await call.code()
    if code is not grpc.StatusCode.CANCELLED:
        raise ValueError("expected status code CANCELLED")


async def _cancel_after_first_response(
    stub: test_pb2_grpc.TestServiceStub,
) -> None:
    """Cancel a FullDuplexCall after one response and expect CANCELLED.

    Raises:
        ValueError: If the read after cancellation does not raise
            `asyncio.CancelledError`, or the final status code is not
            CANCELLED.
    """
    # Only a single request/response pair is exchanged before cancelling.
    response_size = 31415
    payload_size = 27182

    call = stub.FullDuplexCall()
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(
            messages_pb2.ResponseParameters(size=response_size),
        ),
        payload=messages_pb2.Payload(body=b"\x00" * payload_size),
    )

    await call.write(request)
    await call.read()

    call.cancel()

    try:
        await call.read()
    except asyncio.CancelledError:
        # Explicit check rather than `assert`, which `python -O` strips.
        await _expect_status_code(call, grpc.StatusCode.CANCELLED)
    else:
        raise ValueError("expected call to be cancelled")


async def _timeout_on_sleeping_server(
    stub: test_pb2_grpc.TestServiceStub,
) -> None:
    """Expect DEADLINE_EXCEEDED when the response interval exceeds the timeout.

    The call is opened with a one second timeout and the request asks for a
    response interval of twice that duration.
    """
    request_payload_size = 27182
    time_limit = datetime.timedelta(seconds=1)

    call = stub.FullDuplexCall(timeout=time_limit.total_seconds())

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        payload=messages_pb2.Payload(body=b"\x00" * request_payload_size),
        response_parameters=(
            messages_pb2.ResponseParameters(
                interval_us=int(time_limit.total_seconds() * 2 * 10**6)
            ),
        ),
    )
    await call.write(request)
    await call.done_writing()
    try:
        await call.read()
    except aio.AioRpcError as rpc_error:
        # Any other RPC failure is unexpected and is propagated unchanged.
        if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
            raise
    else:
        raise ValueError("expected call to exceed deadline")


async def _empty_stream(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Half-close a FullDuplexCall without writing and expect EOF on read."""
    call = stub.FullDuplexCall()
    await call.done_writing()
    response = await call.read()
    # Explicit check rather than `assert`, which `python -O` strips.
    if not response == aio.EOF:
        raise ValueError("expected EOF on empty stream, got %r" % (response,))


async def _status_code_and_message(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Ask for a specific status and message over unary and duplex calls."""
    details = "test status message"
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b"\x00"),
        response_status=messages_pb2.EchoStatus(
            code=status.value[0], message=details
        ),
    )
    call = stub.UnaryCall(request)
    await _validate_status_code_and_details(call, status, details)

    # Test with a FullDuplexCall
    call = stub.FullDuplexCall()
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(size=1),),
        payload=messages_pb2.Payload(body=b"\x00"),
        response_status=messages_pb2.EchoStatus(
            code=status.value[0], message=details
        ),
    )
    await call.write(request)  # sends the initial request.
    await call.done_writing()
    try:
        await call.read()
    except aio.AioRpcError as rpc_error:
        # Explicit check rather than `assert`, which `python -O` strips.
        if rpc_error.code() != status:
            raise ValueError(
                "expected code %s, got %s" % (status, rpc_error.code())
            )
    await _validate_status_code_and_details(call, status, details)


async def _unimplemented_method(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Expect UNIMPLEMENTED from TestService's UnimplementedCall."""
    call = stub.UnimplementedCall(empty_pb2.Empty())
    await _expect_status_code(call, grpc.StatusCode.UNIMPLEMENTED)


async def _unimplemented_service(
    stub: test_pb2_grpc.UnimplementedServiceStub,
) -> None:
    """Expect UNIMPLEMENTED from UnimplementedService's UnimplementedCall."""
    call = stub.UnimplementedCall(empty_pb2.Empty())
    await _expect_status_code(call, grpc.StatusCode.UNIMPLEMENTED)


async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Send custom metadata and expect it back as initial/trailing metadata.

    Both a UnaryCall and a FullDuplexCall are exercised with the same
    metadata.
    """
    initial_metadata_value = "test_initial_metadata_value"
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = aio.Metadata(
        (_INITIAL_METADATA_KEY, initial_metadata_value),
        (_TRAILING_METADATA_KEY, trailing_metadata_value),
    )

    async def _validate_metadata(call):
        initial_metadata = await call.initial_metadata()
        if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
            raise ValueError(
                "expected initial metadata %s, got %s"
                % (
                    initial_metadata_value,
                    initial_metadata[_INITIAL_METADATA_KEY],
                )
            )

        trailing_metadata = await call.trailing_metadata()
        if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
            raise ValueError(
                "expected trailing metadata %s, got %s"
                % (
                    trailing_metadata_value,
                    trailing_metadata[_TRAILING_METADATA_KEY],
                )
            )

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b"\x00"),
    )
    call = stub.UnaryCall(request, metadata=metadata)
    await _validate_metadata(call)

    # Testing with FullDuplexCall
    call = stub.FullDuplexCall(metadata=metadata)
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(size=1),),
    )
    await call.write(request)
    await call.read()
    await call.done_writing()
    await _validate_metadata(call)


async def _compute_engine_creds(
    stub: test_pb2_grpc.TestServiceStub, args: argparse.Namespace
) -> None:
    """Check the echoed username against `args.default_service_account`."""
    response = await _large_unary_common_behavior(stub, True, True, None)
    _expect_username(args.default_service_account, response.username)


async def _oauth2_auth_token(
    stub: test_pb2_grpc.TestServiceStub, args: argparse.Namespace
) -> None:
    """Check the echoed username and OAuth scope.

    The username must equal the key file's `client_email`, and the echoed
    scope must be a substring of `args.oauth_scope`.
    """
    wanted_email = _read_client_email()
    response = await _large_unary_common_behavior(stub, True, True, None)
    _expect_username(wanted_email, response.username)
    if response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope
            )
        )


async def _jwt_token_creds(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Check the echoed username against the key file's `client_email`."""
    wanted_email = _read_client_email()
    response = await _large_unary_common_behavior(stub, True, False, None)
    _expect_username(wanted_email, response.username)


async def _per_rpc_creds(
    stub: test_pb2_grpc.TestServiceStub, args: argparse.Namespace
) -> None:
    """Attach Google call credentials to one RPC and check the username."""
    wanted_email = _read_client_email()
    # The project id returned alongside the credentials is not needed here.
    google_credentials, _ = google_auth.default(scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request(),
        )
    )
    response = await _large_unary_common_behavior(
        stub, True, False, call_credentials
    )
    _expect_username(wanted_email, response.username)


async def _special_status_message(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Ask for a status message with whitespace and non-ASCII characters."""
    details = (
        b"\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP"
        b" \xf0\x9f\x98\x88\t\n".decode("utf-8")
    )
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b"\x00"),
        response_status=messages_pb2.EchoStatus(
            code=status.value[0], message=details
        ),
    )
    call = stub.UnaryCall(request)
    await _validate_status_code_and_details(call, status, details)


@enum.unique
class TestCase(enum.Enum):
    """Names of the interoperability test cases implemented in this module."""

    EMPTY_UNARY = "empty_unary"
    LARGE_UNARY = "large_unary"
    SERVER_STREAMING = "server_streaming"
    CLIENT_STREAMING = "client_streaming"
    PING_PONG = "ping_pong"
    CANCEL_AFTER_BEGIN = "cancel_after_begin"
    CANCEL_AFTER_FIRST_RESPONSE = "cancel_after_first_response"
    TIMEOUT_ON_SLEEPING_SERVER = "timeout_on_sleeping_server"
    EMPTY_STREAM = "empty_stream"
    STATUS_CODE_AND_MESSAGE = "status_code_and_message"
    UNIMPLEMENTED_METHOD = "unimplemented_method"
    UNIMPLEMENTED_SERVICE = "unimplemented_service"
    CUSTOM_METADATA = "custom_metadata"
    COMPUTE_ENGINE_CREDS = "compute_engine_creds"
    OAUTH2_AUTH_TOKEN = "oauth2_auth_token"
    JWT_TOKEN_CREDS = "jwt_token_creds"
    PER_RPC_CREDS = "per_rpc_creds"
    SPECIAL_STATUS_MESSAGE = "special_status_message"


_TEST_CASE_IMPLEMENTATION_MAPPING = {
    TestCase.EMPTY_UNARY: _empty_unary,
    TestCase.LARGE_UNARY: _large_unary,
    TestCase.SERVER_STREAMING: _server_streaming,
    TestCase.CLIENT_STREAMING: _client_streaming,
    TestCase.PING_PONG: _ping_pong,
    TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin,
    TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response,
    TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server,
    TestCase.EMPTY_STREAM: _empty_stream,
    TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message,
    TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method,
    TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service,
    TestCase.CUSTOM_METADATA: _custom_metadata,
    TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds,
    TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token,
    TestCase.JWT_TOKEN_CREDS: _jwt_token_creds,
    TestCase.PER_RPC_CREDS: _per_rpc_creds,
    TestCase.SPECIAL_STATUS_MESSAGE: _special_status_message,
}


async def test_interoperability(
    case: TestCase,
    stub: test_pb2_grpc.TestServiceStub,
    args: Optional[argparse.Namespace] = None,
) -> None:
    """Run the implementation registered for `case` against `stub`.

    The implementation's arity decides the call shape: one-parameter
    implementations receive only `stub`; two-parameter ones also receive
    `args`.

    Args:
        case: The test case to run.
        stub: The stub passed to the implementation.
        args: Parsed command-line arguments; required by the two-parameter
            implementations.

    Raises:
        NotImplementedError: If no implementation is registered for `case`.
        ValueError: If the implementation needs `args` but it is None, or
            its parameter count is neither 1 nor 2.
    """
    method = _TEST_CASE_IMPLEMENTATION_MAPPING.get(case)
    if method is None:
        raise NotImplementedError(f'Test case "{case}" not implemented!')

    num_params = len(inspect.signature(method).parameters)
    if num_params == 1:
        await method(stub)
    elif num_params == 2:
        if args is None:
            raise ValueError(f"Failed to run case [{case}]: args is None")
        await method(stub, args)
    else:
        raise ValueError(f"Invalid number of parameters [{num_params}]")