# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import datetime

import grpc
from grpc.experimental import aio
from tests.unit import resources

from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc
from tests_aio.unit import _constants

_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"


async def _maybe_echo_metadata(servicer_context):
    """Copies metadata from request to response if it is present.

    If the client sent `_INITIAL_METADATA_KEY`, its value is sent back as
    initial metadata; if it sent `_TRAILING_METADATA_KEY`, its value is set
    as trailing metadata.
    """
    invocation_metadata = dict(servicer_context.invocation_metadata())
    if _INITIAL_METADATA_KEY in invocation_metadata:
        initial_metadatum = (_INITIAL_METADATA_KEY,
                             invocation_metadata[_INITIAL_METADATA_KEY])
        await servicer_context.send_initial_metadata((initial_metadatum,))
    if _TRAILING_METADATA_KEY in invocation_metadata:
        trailing_metadatum = (_TRAILING_METADATA_KEY,
                              invocation_metadata[_TRAILING_METADATA_KEY])
        servicer_context.set_trailing_metadata((trailing_metadatum,))


async def _maybe_echo_status(request: messages_pb2.SimpleRequest,
                             servicer_context):
    """Echos the RPC status if demanded by the request.

    When `response_status` is set on the request, the RPC is aborted with
    that code and message (`abort` does not return normally).
    """
    if request.HasField('response_status'):
        await servicer_context.abort(request.response_status.code,
                                     request.response_status.message)


async def _sleep_for_interval(response_parameters):
    """Sleeps for `response_parameters.interval_us` microseconds, if non-zero."""
    if response_parameters.interval_us != 0:
        await asyncio.sleep(
            datetime.timedelta(microseconds=response_parameters.interval_us).
            total_seconds())


def _create_streaming_response(response_parameters, payload_type):
    """Builds one StreamingOutputCallResponse for `response_parameters`.

    A non-zero `size` yields a payload of that many zero bytes tagged with
    `payload_type`; a zero `size` yields a response with no payload set.
    """
    if response_parameters.size != 0:
        return messages_pb2.StreamingOutputCallResponse(
            payload=messages_pb2.Payload(type=payload_type,
                                         body=b'\x00' *
                                         response_parameters.size))
    return messages_pb2.StreamingOutputCallResponse()


class TestServiceServicer(test_pb2_grpc.TestServiceServicer):
    """Async implementation of grpc.testing.TestService used by the tests."""

    async def UnaryCall(self, request, context):
        """Echoes metadata/status and returns `response_size` zero bytes."""
        await _maybe_echo_metadata(context)
        await _maybe_echo_status(request, context)
        return messages_pb2.SimpleResponse(
            payload=messages_pb2.Payload(type=messages_pb2.COMPRESSABLE,
                                         body=b'\x00' * request.response_size))

    async def EmptyCall(self, request, context):
        """Returns an empty message."""
        return empty_pb2.Empty()

    async def StreamingOutputCall(
            self, request: messages_pb2.StreamingOutputCallRequest,
            unused_context):
        """Streams one response per entry in `request.response_parameters`.

        Payloads are tagged with `request.response_type`.
        """
        for response_parameters in request.response_parameters:
            await _sleep_for_interval(response_parameters)
            yield _create_streaming_response(response_parameters,
                                             request.response_type)

    # This method is not provided by the proto file; it is registered
    # programmatically (see _create_extra_generic_handler) when the server
    # is instantiated.
    async def UnaryCallWithSleep(self, unused_request, unused_context):
        """Sleeps for UNARY_CALL_WITH_SLEEP_VALUE seconds, then responds."""
        await asyncio.sleep(_constants.UNARY_CALL_WITH_SLEEP_VALUE)
        return messages_pb2.SimpleResponse()

    async def StreamingInputCall(self, request_async_iterator, unused_context):
        """Returns the total number of payload body bytes received."""
        aggregate_size = 0
        async for request in request_async_iterator:
            # Protobuf sub-message access never returns None; an unset
            # payload simply has an empty body, contributing 0 bytes.
            aggregate_size += len(request.payload.body)
        return messages_pb2.StreamingInputCallResponse(
            aggregated_payload_size=aggregate_size)

    async def FullDuplexCall(self, request_async_iterator, context):
        """For each request, optionally aborts, then streams its responses.

        Payloads are tagged with the incoming `request.payload.type`.
        """
        await _maybe_echo_metadata(context)
        async for request in request_async_iterator:
            await _maybe_echo_status(request, context)
            for response_parameters in request.response_parameters:
                await _sleep_for_interval(response_parameters)
                yield _create_streaming_response(response_parameters,
                                                 request.payload.type)


def _create_extra_generic_handler(servicer: TestServiceServicer):
    """Builds a handler for test-only methods not present in the proto file.

    Currently registers `UnaryCallWithSleep` under the
    'grpc.testing.TestService' service name.
    """
    rpc_method_handlers = {
        'UnaryCallWithSleep':
            grpc.unary_unary_rpc_method_handler(
                servicer.UnaryCallWithSleep,
                request_deserializer=messages_pb2.SimpleRequest.FromString,
                response_serializer=messages_pb2.SimpleResponse.
                SerializeToString)
    }
    return grpc.method_handlers_generic_handler('grpc.testing.TestService',
                                                rpc_method_handlers)


async def start_test_server(port=0,
                            secure=False,
                            server_credentials=None,
                            interceptors=None):
    """Creates, binds and starts an aio server running TestServiceServicer.

    Args:
      port: Port to bind on '[::]'; 0 lets gRPC pick one.
      secure: If True, bind a secure port; otherwise an insecure one.
      server_credentials: Credentials for the secure port. When `secure` is
        True and this is None, credentials are built from the test
        resources' private key and certificate chain. Ignored when `secure`
        is False.
      interceptors: Optional server interceptors passed to `aio.server`.

    Returns:
      A tuple of ('localhost:<port>', server), where <port> is the value
      returned by `add_secure_port`/`add_insecure_port`.
    """
    server = aio.server(options=(('grpc.so_reuseport', 0),),
                        interceptors=interceptors)
    servicer = TestServiceServicer()
    test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)

    server.add_generic_rpc_handlers((_create_extra_generic_handler(servicer),))

    if secure:
        if server_credentials is None:
            server_credentials = grpc.ssl_server_credentials([
                (resources.private_key(), resources.certificate_chain())
            ])
        port = server.add_secure_port('[::]:%d' % port, server_credentials)
    else:
        port = server.add_insecure_port('[::]:%d' % port)

    await server.start()

    # NOTE(lidizheng) returning the server to prevent it from being
    # deallocated.
    return 'localhost:%d' % port, server