# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of interoperability test methods."""

import enum
import json
import os
import threading

from google import auth as google_auth
from google.auth import environment_vars as google_auth_environment_vars
from google.auth.transport import grpc as google_auth_transport_grpc
from google.auth.transport import requests as google_auth_transport_requests
import grpc
from grpc.beta import implementations

from src.proto.grpc.testing import empty_pb2
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc

_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"


def _maybe_echo_metadata(servicer_context):
    """Copies metadata from request to response if it is present.

    If the invocation metadata contains _INITIAL_METADATA_KEY, that entry is
    sent back as initial metadata; if it contains _TRAILING_METADATA_KEY,
    that entry is set as trailing metadata.
    """
    invocation_metadata = dict(servicer_context.invocation_metadata())
    if _INITIAL_METADATA_KEY in invocation_metadata:
        initial_metadatum = (_INITIAL_METADATA_KEY,
                             invocation_metadata[_INITIAL_METADATA_KEY])
        servicer_context.send_initial_metadata((initial_metadatum,))
    if _TRAILING_METADATA_KEY in invocation_metadata:
        trailing_metadatum = (_TRAILING_METADATA_KEY,
                              invocation_metadata[_TRAILING_METADATA_KEY])
        servicer_context.set_trailing_metadata((trailing_metadatum,))


def _maybe_echo_status_and_message(request, servicer_context):
    """Sets the response context code and details if the request asks for them.

    Only acts when the request has its `response_status` field set.
    """
    if request.HasField('response_status'):
        servicer_context.set_code(request.response_status.code)
        servicer_context.set_details(request.response_status.message)


class TestService(test_pb2_grpc.TestServiceServicer):
    """Server-side implementation of the interop TestService."""

    def EmptyCall(self, request, context):
        """Echoes requested metadata and returns an Empty message."""
        _maybe_echo_metadata(context)
        return empty_pb2.Empty()

    def UnaryCall(self, request, context):
        """Returns a COMPRESSABLE payload of `request.response_size` zero bytes.

        Also echoes requested metadata and any requested status/details.
        """
        _maybe_echo_metadata(context)
        _maybe_echo_status_and_message(request, context)
        return messages_pb2.SimpleResponse(
            payload=messages_pb2.Payload(
                type=messages_pb2.COMPRESSABLE,
                body=b'\x00' * request.response_size))

    def StreamingOutputCall(self, request, context):
        """Yields one zero-filled response per entry in response_parameters."""
        _maybe_echo_status_and_message(request, context)
        for response_parameters in request.response_parameters:
            yield messages_pb2.StreamingOutputCallResponse(
                payload=messages_pb2.Payload(
                    type=request.response_type,
                    body=b'\x00' * response_parameters.size))

    def StreamingInputCall(self, request_iterator, context):
        """Returns the total payload body length over all streamed requests."""
        aggregate_size = 0
        for request in request_iterator:
            if request.payload is not None and request.payload.body:
                aggregate_size += len(request.payload.body)
        return messages_pb2.StreamingInputCallResponse(
            aggregated_payload_size=aggregate_size)

    def FullDuplexCall(self, request_iterator, context):
        """For each request, yields one response per response_parameters entry.

        Metadata is echoed once up front; status/details are applied per
        request as requested.
        """
        _maybe_echo_metadata(context)
        for request in request_iterator:
            _maybe_echo_status_and_message(request, context)
            for response_parameters in request.response_parameters:
                yield messages_pb2.StreamingOutputCallResponse(
                    payload=messages_pb2.Payload(
                        type=request.payload.type,
                        body=b'\x00' * response_parameters.size))

    # NOTE(nathaniel): Apparently this is the same as the full-duplex call?
    # NOTE(atash): It isn't even called in the interop spec (Oct 22 2015)...
    def HalfDuplexCall(self, request_iterator, context):
        """Delegates to FullDuplexCall."""
        return self.FullDuplexCall(request_iterator, context)


def _expect_status_code(call, expected_code):
    """Raises ValueError if `call.code()` differs from `expected_code`."""
    if call.code() != expected_code:
        raise ValueError('expected code %s, got %s' % (expected_code,
                                                       call.code()))


def _expect_status_details(call, expected_details):
    """Raises ValueError if `call.details()` differs from `expected_details`."""
    if call.details() != expected_details:
        raise ValueError('expected message %s, got %s' % (expected_details,
                                                          call.details()))


def _validate_status_code_and_details(call, expected_code, expected_details):
    """Checks both the status code and the status details of `call`."""
    _expect_status_code(call, expected_code)
    _expect_status_details(call, expected_details)


def _validate_payload_type_and_length(response, expected_type, expected_length):
    """Raises ValueError unless the response payload has the expected type/size.

    Payload types are protobuf enum values (ints), so they are compared with
    `!=` rather than by identity, and the actual value is reported on failure.
    """
    if response.payload.type != expected_type:
        raise ValueError('expected payload type %s, got %s' %
                         (expected_type, response.payload.type))
    if len(response.payload.body) != expected_length:
        raise ValueError('expected payload body size %d, got %d' %
                         (expected_length, len(response.payload.body)))


def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
                                 call_credentials):
    """Performs a large UnaryCall and validates the response payload.

    Sends a 271828-byte payload asking for a 314159-byte COMPRESSABLE
    response, using `call_credentials` (may be None) for the call.

    Returns:
      The SimpleResponse, so callers can inspect username/oauth_scope.
    """
    size = 314159
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=size,
        payload=messages_pb2.Payload(body=b'\x00' * 271828),
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope)
    response_future = stub.UnaryCall.future(
        request, credentials=call_credentials)
    response = response_future.result()
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
    return response


def _empty_unary(stub):
    """Checks that EmptyCall returns an empty_pb2.Empty instance."""
    response = stub.EmptyCall(empty_pb2.Empty())
    if not isinstance(response, empty_pb2.Empty):
        raise TypeError(
            'response is of type "%s", not empty_pb2.Empty!' % type(response))


def _large_unary(stub):
    """Runs the large unary call without credentials or user/scope filling."""
    _large_unary_common_behavior(stub, False, False, None)


def _client_streaming(stub):
    """Streams four payloads and checks the server-reported aggregate size."""
    payload_body_sizes = (
        27182,
        8,
        1828,
        45904,
    )
    payloads = (messages_pb2.Payload(body=b'\x00' * size)
                for size in payload_body_sizes)
    requests = (messages_pb2.StreamingInputCallRequest(payload=payload)
                for payload in payloads)
    response = stub.StreamingInputCall(requests)
    # 74922 == sum(payload_body_sizes).
    if response.aggregated_payload_size != 74922:
        raise ValueError(
            'incorrect size %d!' % response.aggregated_payload_size)


def _server_streaming(stub):
    """Requests four sized responses and validates each one, in order.

    Raises ValueError if a response has the wrong type/size or if the server
    sends a number of responses other than the number requested.
    """
    sizes = (
        31415,
        9,
        2653,
        58979,
    )

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=size) for size in sizes))
    response_count = 0
    for response in stub.StreamingOutputCall(request):
        if response_count < len(sizes):
            _validate_payload_type_and_length(
                response, messages_pb2.COMPRESSABLE, sizes[response_count])
        response_count += 1
    if response_count != len(sizes):
        raise ValueError('expected %d responses, got %d' %
                         (len(sizes), response_count))


class _Pipe(object):
    """A thread-safe, blocking iterator fed by `add()` and ended by `close()`.

    `next()` blocks until a value is available or the pipe is closed; after
    close, remaining values are still drained before StopIteration is raised.
    Usable as a context manager that closes the pipe on exit.
    """

    def __init__(self):
        self._condition = threading.Condition()
        self._values = []
        self._open = True

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Returns the next queued value, blocking while the pipe is empty."""
        with self._condition:
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.pop(0)
            else:
                raise StopIteration()

    def add(self, value):
        """Appends a value and wakes a waiting consumer."""
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        """Marks the pipe closed so the iterator ends once drained."""
        with self._condition:
            self._open = False
            self._condition.notify()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()


def _ping_pong(stub):
    """Sends four requests over FullDuplexCall, validating each reply in turn."""
    request_response_sizes = (
        31415,
        9,
        2653,
        58979,
    )
    request_payload_sizes = (
        27182,
        8,
        1828,
        45904,
    )

    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        for response_size, payload_size in zip(request_response_sizes,
                                               request_payload_sizes):
            request = messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(
                    messages_pb2.ResponseParameters(size=response_size),),
                payload=messages_pb2.Payload(body=b'\x00' * payload_size))
            pipe.add(request)
            response = next(response_iterator)
            _validate_payload_type_and_length(
                response, messages_pb2.COMPRESSABLE, response_size)


def _cancel_after_begin(stub):
    """Cancels a StreamingInputCall before sending anything and checks status."""
    with _Pipe() as pipe:
        response_future = stub.StreamingInputCall.future(pipe)
        response_future.cancel()
        if not response_future.cancelled():
            raise ValueError('expected cancelled method to return True')
        if response_future.code() is not grpc.StatusCode.CANCELLED:
            raise ValueError('expected status code CANCELLED')


def _cancel_after_first_response(stub):
    """Cancels a FullDuplexCall after one response; expects CANCELLED after."""
    request_response_sizes = (
        31415,
        9,
        2653,
        58979,
    )
    request_payload_sizes = (
        27182,
        8,
        1828,
        45904,
    )
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)

        response_size = request_response_sizes[0]
        payload_size = request_payload_sizes[0]
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(
                messages_pb2.ResponseParameters(size=response_size),),
            payload=messages_pb2.Payload(body=b'\x00' * payload_size))
        pipe.add(request)
        # We test the contents of the response in the Ping Pong test - don't
        # check them here; just wait for it to arrive.
        next(response_iterator)
        response_iterator.cancel()

        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.CANCELLED:
                raise
        else:
            raise ValueError('expected call to be cancelled')


def _timeout_on_sleeping_server(stub):
    """Uses a 1ms deadline on FullDuplexCall and expects DEADLINE_EXCEEDED."""
    request_payload_size = 27182
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, timeout=0.001)

        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            payload=messages_pb2.Payload(body=b'\x00' * request_payload_size))
        pipe.add(request)
        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
                raise
        else:
            raise ValueError('expected call to exceed deadline')


def _empty_stream(stub):
    """Closes a FullDuplexCall immediately and expects zero responses."""
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        pipe.close()
        try:
            next(response_iterator)
            raise ValueError('expected exactly 0 responses')
        except StopIteration:
            pass


def _status_code_and_message(stub):
    """Asks the server to echo a status code/message on unary and duplex calls."""
    details = 'test status message'
    code = 2
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=code, message=details))
    response_future = stub.UnaryCall.future(request)
    _validate_status_code_and_details(response_future, status, details)

    # Test with a FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),),
            payload=messages_pb2.Payload(body=b'\x00'),
            response_status=messages_pb2.EchoStatus(code=code, message=details))
        pipe.add(request)  # sends the initial request.
    # Dropping out of with block closes the pipe
    _validate_status_code_and_details(response_iterator, status, details)


def _unimplemented_method(test_service_stub):
    """Calls UnimplementedCall on the test service; expects UNIMPLEMENTED."""
    response_future = (test_service_stub.UnimplementedCall.future(
        empty_pb2.Empty()))
    _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)


def _unimplemented_service(unimplemented_service_stub):
    """Calls UnimplementedCall on the unimplemented service stub."""
    response_future = (unimplemented_service_stub.UnimplementedCall.future(
        empty_pb2.Empty()))
    _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)


def _custom_metadata(stub):
    """Sends echo metadata and checks it comes back on unary and duplex calls."""
    initial_metadata_value = "test_initial_metadata_value"
    # The trailing key ends in "-bin", so its value is binary metadata and is
    # passed (and compared) as bytes.
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
                (_TRAILING_METADATA_KEY, trailing_metadata_value))

    def _validate_metadata(response):
        initial_metadata = dict(response.initial_metadata())
        if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
            raise ValueError('expected initial metadata %s, got %s' %
                             (initial_metadata_value,
                              initial_metadata[_INITIAL_METADATA_KEY]))
        trailing_metadata = dict(response.trailing_metadata())
        if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
            # Report the value actually received in the trailing metadata.
            raise ValueError('expected trailing metadata %s, got %s' %
                             (trailing_metadata_value,
                              trailing_metadata[_TRAILING_METADATA_KEY]))

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'))
    response_future = stub.UnaryCall.future(request, metadata=metadata)
    _validate_metadata(response_future)

    # Testing with FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),))
        pipe.add(request)  # Sends the request
        next(response_iterator)  # Causes server to send trailing metadata
    # Dropping out of the with block closes the pipe
    _validate_metadata(response_iterator)


def _load_client_email():
    """Returns 'client_email' from the JSON key file named by the environment.

    The file path is read from the google-auth credentials environment
    variable; the file handle is closed before returning.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    with open(json_key_filename, 'rb') as json_key_file:
        return json.load(json_key_file)['client_email']


def _compute_engine_creds(stub, args):
    """Checks the reported username equals args.default_service_account."""
    response = _large_unary_common_behavior(stub, True, True, None)
    if args.default_service_account != response.username:
        raise ValueError('expected username %s, got %s' %
                         (args.default_service_account, response.username))


def _oauth2_auth_token(stub, args):
    """Checks username matches the key file and scope appears in args."""
    wanted_email = _load_client_email()
    response = _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))
    if response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope))


def _jwt_token_creds(stub, args):
    """Checks the reported username matches the key file's client_email."""
    wanted_email = _load_client_email()
    response = _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))


def _per_rpc_creds(stub, args):
    """Attaches google-auth call credentials per RPC and checks the username."""
    wanted_email = _load_client_email()
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request()))
    response = _large_unary_common_behavior(stub, True, False, call_credentials)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))


@enum.unique
class TestCase(enum.Enum):
    """Names of the interop test cases, each runnable via test_interoperability."""
    EMPTY_UNARY = 'empty_unary'
    LARGE_UNARY = 'large_unary'
    SERVER_STREAMING = 'server_streaming'
    CLIENT_STREAMING = 'client_streaming'
    PING_PONG = 'ping_pong'
    CANCEL_AFTER_BEGIN = 'cancel_after_begin'
    CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
    EMPTY_STREAM = 'empty_stream'
    STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
    UNIMPLEMENTED_METHOD = 'unimplemented_method'
    UNIMPLEMENTED_SERVICE = 'unimplemented_service'
    CUSTOM_METADATA = "custom_metadata"
    COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
    OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
    JWT_TOKEN_CREDS = 'jwt_token_creds'
    PER_RPC_CREDS = 'per_rpc_creds'
    TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'

    def test_interoperability(self, stub, args):
        """Runs this test case against `stub`.

        `args` is only consulted by the credential-based test cases.

        Raises:
          NotImplementedError: if no implementation is mapped to this case.
        """
        if self is TestCase.EMPTY_UNARY:
            _empty_unary(stub)
        elif self is TestCase.LARGE_UNARY:
            _large_unary(stub)
        elif self is TestCase.SERVER_STREAMING:
            _server_streaming(stub)
        elif self is TestCase.CLIENT_STREAMING:
            _client_streaming(stub)
        elif self is TestCase.PING_PONG:
            _ping_pong(stub)
        elif self is TestCase.CANCEL_AFTER_BEGIN:
            _cancel_after_begin(stub)
        elif self is TestCase.CANCEL_AFTER_FIRST_RESPONSE:
            _cancel_after_first_response(stub)
        elif self is TestCase.TIMEOUT_ON_SLEEPING_SERVER:
            _timeout_on_sleeping_server(stub)
        elif self is TestCase.EMPTY_STREAM:
            _empty_stream(stub)
        elif self is TestCase.STATUS_CODE_AND_MESSAGE:
            _status_code_and_message(stub)
        elif self is TestCase.UNIMPLEMENTED_METHOD:
            _unimplemented_method(stub)
        elif self is TestCase.UNIMPLEMENTED_SERVICE:
            _unimplemented_service(stub)
        elif self is TestCase.CUSTOM_METADATA:
            _custom_metadata(stub)
        elif self is TestCase.COMPUTE_ENGINE_CREDS:
            _compute_engine_creds(stub, args)
        elif self is TestCase.OAUTH2_AUTH_TOKEN:
            _oauth2_auth_token(stub, args)
        elif self is TestCase.JWT_TOKEN_CREDS:
            _jwt_token_creds(stub, args)
        elif self is TestCase.PER_RPC_CREDS:
            _per_rpc_creds(stub, args)
        else:
            raise NotImplementedError(
                'Test case "%s" not implemented!' % self.name)