# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of interoperability test methods."""

# NOTE(lidiz) This module only exists in Bazel BUILD file, for more details
# please refer to comments in the "bazel_namespace_package_hack" module.
try:
    from tests import bazel_namespace_package_hack
    bazel_namespace_package_hack.sys_path_to_site_dir_hack()
except ImportError:
    pass

import enum
import json
import os
import threading
import time

from google import auth as google_auth
from google.auth import environment_vars as google_auth_environment_vars
from google.auth.transport import grpc as google_auth_transport_grpc
from google.auth.transport import requests as google_auth_transport_requests
import grpc

from src.proto.grpc.testing import empty_pb2
from src.proto.grpc.testing import messages_pb2

_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"

# Response sizes requested, and request payload sizes sent, on each round
# trip of the full-duplex tests (ping_pong, cancel_after_first_response).
_PING_PONG_RESPONSE_SIZES = (
    31415,
    9,
    2653,
    58979,
)
_PING_PONG_PAYLOAD_SIZES = (
    27182,
    8,
    1828,
    45904,
)


def _expect_status_code(call, expected_code):
    """Raise ValueError unless ``call`` finished with ``expected_code``."""
    actual_code = call.code()
    if actual_code != expected_code:
        raise ValueError('expected code %s, got %s' %
                         (expected_code, actual_code))


def _expect_status_details(call, expected_details):
    """Raise ValueError unless ``call``'s status details equal the expected."""
    actual_details = call.details()
    if actual_details != expected_details:
        raise ValueError('expected message %s, got %s' %
                         (expected_details, actual_details))


def _validate_status_code_and_details(call, expected_code, expected_details):
    """Check both the status code and the status details of ``call``."""
    _expect_status_code(call, expected_code)
    _expect_status_details(call, expected_details)


def _validate_payload_type_and_length(response, expected_type, expected_length):
    """Raise ValueError if the response payload's type or body size is wrong."""
    actual_type = response.payload.type
    # Payload types are plain enum integers: compare by value, not identity.
    if actual_type != expected_type:
        raise ValueError('expected payload type %s, got %s' %
                         (expected_type, actual_type))
    actual_length = len(response.payload.body)
    if actual_length != expected_length:
        raise ValueError('expected payload body size %d, got %d' %
                         (expected_length, actual_length))


def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
                                 call_credentials):
    """Issue a large UnaryCall, validate the payload, and return the response.

    Args:
      stub: the TestService stub.
      fill_username: value for SimpleRequest.fill_username.
      fill_oauth_scope: value for SimpleRequest.fill_oauth_scope.
      call_credentials: per-call credentials, or None.
    """
    size = 314159
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=size,
        payload=messages_pb2.Payload(body=b'\x00' * 271828),
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope)
    response_future = stub.UnaryCall.future(request,
                                            credentials=call_credentials)
    response = response_future.result()
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
    return response


def _unary_call_with_echo_status(stub, code, details):
    """Start a UnaryCall asking the server to echo back the given status."""
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=code, message=details))
    return stub.UnaryCall.future(request)


def _read_service_account_email():
    """Return ``client_email`` from the JSON key file named by the env var."""
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Use a context manager so the key file handle is always closed.
    with open(json_key_filename, 'r') as json_key_file:
        return json.load(json_key_file)['client_email']


def _empty_unary(stub):
    """EmptyCall must return an empty_pb2.Empty message."""
    response = stub.EmptyCall(empty_pb2.Empty())
    if not isinstance(response, empty_pb2.Empty):
        raise TypeError('response is of type "%s", not empty_pb2.Empty!' %
                        type(response))


def _large_unary(stub):
    """Large unary call without credentials or username/scope filling."""
    _large_unary_common_behavior(stub, False, False, None)


def _client_streaming(stub):
    """Stream several payloads and check the server's aggregated size."""
    payload_body_sizes = (
        27182,
        8,
        1828,
        45904,
    )
    requests = (messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\x00' * size))
                for size in payload_body_sizes)
    response = stub.StreamingInputCall(requests)
    # The server must report the sum of all sent payload sizes (74922).
    if response.aggregated_payload_size != sum(payload_body_sizes):
        raise ValueError('incorrect size %d!' %
                         response.aggregated_payload_size)


def _server_streaming(stub):
    """Request four sized responses and validate each one, in order."""
    sizes = (
        31415,
        9,
        2653,
        58979,
    )

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=size) for size in sizes))
    response_iterator = stub.StreamingOutputCall(request)
    received = 0
    for response in response_iterator:
        if received >= len(sizes):
            raise ValueError('expected %d responses, got more' % len(sizes))
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          sizes[received])
        received += 1
    # Too few responses must fail too, not pass silently.
    if received != len(sizes):
        raise ValueError('expected %d responses, got %d' %
                         (len(sizes), received))


class _Pipe(object):
    """A blocking, closeable request iterator for streaming RPCs.

    Values passed to ``add`` are yielded in FIFO order. ``next`` blocks until a
    value is available or the pipe is closed. Once the pipe is closed and
    drained, iteration stops. Usable as a context manager that closes the
    pipe on exit.
    """

    def __init__(self):
        self._condition = threading.Condition()
        self._values = []
        self._open = True

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        with self._condition:
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.pop(0)
            else:
                raise StopIteration()

    def add(self, value):
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        with self._condition:
            self._open = False
            # Wake every waiter: none of them will ever receive another value.
            self._condition.notify_all()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


def _ping_pong(stub):
    """Alternate one request / one response over a full-duplex call."""
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        for response_size, payload_size in zip(_PING_PONG_RESPONSE_SIZES,
                                               _PING_PONG_PAYLOAD_SIZES):
            request = messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(messages_pb2.ResponseParameters(
                    size=response_size),),
                payload=messages_pb2.Payload(body=b'\x00' * payload_size))
            pipe.add(request)
            response = next(response_iterator)
            _validate_payload_type_and_length(response,
                                              messages_pb2.COMPRESSABLE,
                                              response_size)


def _cancel_after_begin(stub):
    """Cancel a client-streaming call before sending anything."""
    with _Pipe() as pipe:
        response_future = stub.StreamingInputCall.future(pipe)
        response_future.cancel()
        if not response_future.cancelled():
            raise ValueError('expected cancelled method to return True')
        if response_future.code() is not grpc.StatusCode.CANCELLED:
            raise ValueError('expected status code CANCELLED')


def _cancel_after_first_response(stub):
    """Cancel a full-duplex call after receiving its first response."""
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)

        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(
                size=_PING_PONG_RESPONSE_SIZES[0]),),
            payload=messages_pb2.Payload(
                body=b'\x00' * _PING_PONG_PAYLOAD_SIZES[0]))
        pipe.add(request)
        # We test the contents of the response in the Ping Pong test - don't
        # check them here.
        next(response_iterator)
        response_iterator.cancel()

        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.CANCELLED:
                raise
        else:
            raise ValueError('expected call to be cancelled')


def _timeout_on_sleeping_server(stub):
    """A full-duplex call with a 1 ms timeout must fail DEADLINE_EXCEEDED."""
    request_payload_size = 27182
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, timeout=0.001)

        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            payload=messages_pb2.Payload(body=b'\x00' * request_payload_size))
        pipe.add(request)
        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
                raise
        else:
            raise ValueError('expected call to exceed deadline')


def _empty_stream(stub):
    """A full-duplex call with no requests must yield no responses."""
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        pipe.close()
        try:
            next(response_iterator)
        except StopIteration:
            pass
        else:
            raise ValueError('expected exactly 0 responses')


def _status_code_and_message(stub):
    """Server must echo a requested status on unary and full-duplex calls."""
    details = 'test status message'
    code = 2
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    response_future = _unary_call_with_echo_status(stub, code, details)
    _validate_status_code_and_details(response_future, status, details)

    # Test with a FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),),
            payload=messages_pb2.Payload(body=b'\x00'),
            response_status=messages_pb2.EchoStatus(code=code, message=details))
        pipe.add(request)  # sends the initial request.
        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            # Explicit check (not `assert`) so it still runs under `python -O`.
            _expect_status_code(rpc_error, status)
    # Dropping out of with block closes the pipe
    _validate_status_code_and_details(response_iterator, status, details)


def _unimplemented_method(test_service_stub):
    """Calling an unimplemented method must fail with UNIMPLEMENTED."""
    response_future = test_service_stub.UnimplementedCall.future(
        empty_pb2.Empty())
    _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)


def _unimplemented_service(unimplemented_service_stub):
    """Calling a method of an unimplemented service must fail UNIMPLEMENTED."""
    response_future = unimplemented_service_stub.UnimplementedCall.future(
        empty_pb2.Empty())
    _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)


def _custom_metadata(stub):
    """Server must echo initial and trailing metadata on both call kinds."""
    initial_metadata_value = "test_initial_metadata_value"
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
                (_TRAILING_METADATA_KEY, trailing_metadata_value))

    def _validate_metadata(response):
        # .get() so a missing key is reported as a ValueError, not KeyError.
        received_initial = dict(
            response.initial_metadata()).get(_INITIAL_METADATA_KEY)
        if received_initial != initial_metadata_value:
            raise ValueError('expected initial metadata %s, got %s' %
                             (initial_metadata_value, received_initial))
        received_trailing = dict(
            response.trailing_metadata()).get(_TRAILING_METADATA_KEY)
        if received_trailing != trailing_metadata_value:
            raise ValueError('expected trailing metadata %s, got %s' %
                             (trailing_metadata_value, received_trailing))

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'))
    response_future = stub.UnaryCall.future(request, metadata=metadata)
    _validate_metadata(response_future)

    # Testing with FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),))
        pipe.add(request)  # Sends the request
        next(response_iterator)  # Causes server to send trailing metadata
    # Dropping out of the with block closes the pipe
    _validate_metadata(response_iterator)


def _compute_engine_creds(stub, args):
    """Username must match ``args.default_service_account``."""
    response = _large_unary_common_behavior(stub, True, True, None)
    if args.default_service_account != response.username:
        raise ValueError('expected username %s, got %s' %
                         (args.default_service_account, response.username))


def _oauth2_auth_token(stub, args):
    """Username must match the key file; scope must be within --oauth_scope."""
    wanted_email = _read_service_account_email()
    response = _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
    # An empty received scope is a substring of anything, so reject it
    # explicitly rather than letting the containment check pass vacuously.
    if not response.oauth_scope or response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected received oauth scope "{}" to be found in "{}"'.format(
                response.oauth_scope, args.oauth_scope))


def _jwt_token_creds(stub, args):
    """Username must match the service account in the JSON key file."""
    wanted_email = _read_service_account_email()
    response = _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))


def _per_rpc_creds(stub, args):
    """Same as jwt_token_creds, but credentials are attached per call."""
    wanted_email = _read_service_account_email()
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request()))
    response = _large_unary_common_behavior(stub, True, False, call_credentials)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))


def _special_status_message(stub, args):
    """Status details with whitespace and non-ASCII text must round-trip."""
    details = b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'.decode(
        'utf-8')
    code = 2
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    response_future = _unary_call_with_echo_status(stub, code, details)
    _validate_status_code_and_details(response_future, status, details)


@enum.unique
class TestCase(enum.Enum):
    EMPTY_UNARY = 'empty_unary'
    LARGE_UNARY = 'large_unary'
    SERVER_STREAMING = 'server_streaming'
    CLIENT_STREAMING = 'client_streaming'
    PING_PONG = 'ping_pong'
    CANCEL_AFTER_BEGIN = 'cancel_after_begin'
    CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
    EMPTY_STREAM = 'empty_stream'
    STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
    UNIMPLEMENTED_METHOD = 'unimplemented_method'
    UNIMPLEMENTED_SERVICE = 'unimplemented_service'
    CUSTOM_METADATA = "custom_metadata"
    COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
    OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
    JWT_TOKEN_CREDS = 'jwt_token_creds'
    PER_RPC_CREDS = 'per_rpc_creds'
    TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
    SPECIAL_STATUS_MESSAGE = 'special_status_message'

    def test_interoperability(self, stub, args):
        """Run this test case against ``stub``.

        ``args`` is passed only to the credential-related test cases and to
        special_status_message.

        Raises:
          NotImplementedError: if this member has no implementation.
        """
        # Built inside the method: a dict assigned in the Enum class body
        # would itself become an enum member.
        stub_only_cases = {
            TestCase.EMPTY_UNARY: _empty_unary,
            TestCase.LARGE_UNARY: _large_unary,
            TestCase.SERVER_STREAMING: _server_streaming,
            TestCase.CLIENT_STREAMING: _client_streaming,
            TestCase.PING_PONG: _ping_pong,
            TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin,
            TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response,
            TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server,
            TestCase.EMPTY_STREAM: _empty_stream,
            TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message,
            TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method,
            TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service,
            TestCase.CUSTOM_METADATA: _custom_metadata,
        }
        stub_and_args_cases = {
            TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds,
            TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token,
            TestCase.JWT_TOKEN_CREDS: _jwt_token_creds,
            TestCase.PER_RPC_CREDS: _per_rpc_creds,
            TestCase.SPECIAL_STATUS_MESSAGE: _special_status_message,
        }
        if self in stub_only_cases:
            stub_only_cases[self](stub)
        elif self in stub_and_args_cases:
            stub_and_args_cases[self](stub, args)
        else:
            raise NotImplementedError('Test case "%s" not implemented!' %
                                      self.name)