• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import datetime
17
18import grpc
19from grpc.experimental import aio
20from tests.unit import resources
21
22from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc
23from tests_aio.unit import _constants
24
25_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
26_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"
27
28
29async def _maybe_echo_metadata(servicer_context):
30    """Copies metadata from request to response if it is present."""
31    invocation_metadata = dict(servicer_context.invocation_metadata())
32    if _INITIAL_METADATA_KEY in invocation_metadata:
33        initial_metadatum = (_INITIAL_METADATA_KEY,
34                             invocation_metadata[_INITIAL_METADATA_KEY])
35        await servicer_context.send_initial_metadata((initial_metadatum,))
36    if _TRAILING_METADATA_KEY in invocation_metadata:
37        trailing_metadatum = (_TRAILING_METADATA_KEY,
38                              invocation_metadata[_TRAILING_METADATA_KEY])
39        servicer_context.set_trailing_metadata((trailing_metadatum,))
40
41
async def _maybe_echo_status(request: messages_pb2.SimpleRequest,
                             servicer_context):
    """Abort the RPC with the status the client asked for, if any.

    When ``request`` has its ``response_status`` field set, the call is
    aborted through ``servicer_context.abort`` using that status's ``code``
    and ``message``. Otherwise this is a no-op. The function is also used
    for the streaming requests handled by ``FullDuplexCall``. Presumably
    those message types carry the same ``response_status`` field; TODO
    confirm against the proto.
    """
    if not request.HasField('response_status'):
        return
    wanted = request.response_status
    await servicer_context.abort(wanted.code, wanted.message)
48
49
async def _sleep_for_response_interval(response_parameters):
    """Pause for ``response_parameters.interval_us`` microseconds, unless it is 0."""
    interval_us = response_parameters.interval_us
    if interval_us != 0:
        await asyncio.sleep(
            datetime.timedelta(microseconds=interval_us).total_seconds())


class TestServiceServicer(test_pb2_grpc.TestServiceServicer):
    """Asyncio servicer backing the ``grpc.testing.TestService`` test server.

    Response payloads are zero-filled byte strings whose length comes from
    the request. ``UnaryCallWithSleep`` is not declared by the proto service
    and must be registered separately (see ``_create_extra_generic_handler``).
    """

    async def UnaryCall(self, request, context):
        """Reply with a COMPRESSABLE payload of ``request.response_size`` bytes.

        Echo headers in the request metadata are reflected first. If the
        request carries a ``response_status``, the RPC is aborted with it
        instead of replying.
        """
        await _maybe_echo_metadata(context)
        await _maybe_echo_status(request, context)
        zero_payload = messages_pb2.Payload(
            type=messages_pb2.COMPRESSABLE,
            body=b'\x00' * request.response_size)
        return messages_pb2.SimpleResponse(payload=zero_payload)

    async def EmptyCall(self, request, context):
        """Return an empty message; the request and context are ignored."""
        return empty_pb2.Empty()

    async def StreamingOutputCall(
            self, request: messages_pb2.StreamingOutputCallRequest,
            unused_context):
        """Stream one response per entry in ``request.response_parameters``.

        Before each response the handler waits ``interval_us`` microseconds
        (no wait when it is 0). Each response carries ``size`` zero bytes
        typed as ``request.response_type``.
        """
        for params in request.response_parameters:
            await _sleep_for_response_interval(params)
            yield messages_pb2.StreamingOutputCallResponse(
                payload=messages_pb2.Payload(type=request.response_type,
                                             body=b'\x00' * params.size))

    async def StreamingInputCall(self, request_async_iterator, unused_context):
        """Consume the request stream and report the total payload bytes seen."""
        total_size = 0
        async for request in request_async_iterator:
            payload = request.payload
            if payload is None or not payload.body:
                continue
            total_size += len(payload.body)
        return messages_pb2.StreamingInputCallResponse(
            aggregated_payload_size=total_size)

    async def FullDuplexCall(self, request_async_iterator, context):
        """Answer each incoming request with its requested list of responses.

        Echo metadata is handled once, before any request is read. For each
        request, a demanded ``response_status`` aborts the call. Otherwise one
        response per ``response_parameters`` entry is yielded, each after its
        optional ``interval_us`` pause.
        """
        await _maybe_echo_metadata(context)
        async for request in request_async_iterator:
            await _maybe_echo_status(request, context)
            for params in request.response_parameters:
                await _sleep_for_response_interval(params)
                # NOTE(review): unlike StreamingOutputCall this types the
                # response with request.payload.type rather than
                # request.response_type; kept as-is to preserve behavior.
                yield messages_pb2.StreamingOutputCallResponse(
                    payload=messages_pb2.Payload(type=request.payload.type,
                                                 body=b'\x00' * params.size))

    # The method below is NOT declared in the proto file. It is registered
    # programmatically through a generic handler when the server is created.
    async def UnaryCallWithSleep(self, unused_request, unused_context):
        """Wait ``_constants.UNARY_CALL_WITH_SLEEP_VALUE`` seconds, then reply empty."""
        await asyncio.sleep(_constants.UNARY_CALL_WITH_SLEEP_VALUE)
        return messages_pb2.SimpleResponse()
97                        datetime.timedelta(microseconds=response_parameters.
98                                           interval_us).total_seconds())
99                yield messages_pb2.StreamingOutputCallResponse(
100                    payload=messages_pb2.Payload(type=request.payload.type,
101                                                 body=b'\x00' *
102                                                 response_parameters.size))
103
104
def _create_extra_generic_handler(servicer: TestServiceServicer):
    """Build a generic handler for test-only methods absent from the proto.

    Currently this exposes ``servicer.UnaryCallWithSleep`` as the unary-unary
    method ``UnaryCallWithSleep`` of service ``grpc.testing.TestService``.
    It uses ``SimpleRequest`` / ``SimpleResponse`` (de)serialization.
    """
    sleep_handler = grpc.unary_unary_rpc_method_handler(
        servicer.UnaryCallWithSleep,
        request_deserializer=messages_pb2.SimpleRequest.FromString,
        response_serializer=messages_pb2.SimpleResponse.SerializeToString,
    )
    handlers_by_method = {'UnaryCallWithSleep': sleep_handler}
    return grpc.method_handlers_generic_handler('grpc.testing.TestService',
                                                handlers_by_method)
112                request_deserializer=messages_pb2.SimpleRequest.FromString,
113                response_serializer=messages_pb2.SimpleResponse.
114                SerializeToString)
115    }
116    return grpc.method_handlers_generic_handler('grpc.testing.TestService',
117                                                rpc_method_handlers)
118
119
async def start_test_server(port=0,
                            secure=False,
                            server_credentials=None,
                            interceptors=None):
    """Create, configure and start an asyncio test server.

    Args:
      port: Port to bind on ``[::]``. The port actually bound (as returned by
        ``add_secure_port`` / ``add_insecure_port``) is the one reported back.
      secure: Bind a TLS port instead of an insecure one.
      server_credentials: Credentials for the secure port. If ``secure`` is
        true and this is None, SSL credentials are built from
        ``resources.private_key()`` and ``resources.certificate_chain()``.
      interceptors: Optional interceptors forwarded to ``aio.server``.

    Returns:
      A ``(address, server)`` tuple. ``address`` is
      ``'localhost:<bound port>'``.
    """
    server = aio.server(options=(('grpc.so_reuseport', 0),),
                        interceptors=interceptors)
    servicer = TestServiceServicer()
    test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)
    extra_handler = _create_extra_generic_handler(servicer)
    server.add_generic_rpc_handlers((extra_handler,))

    if secure and server_credentials is None:
        key_and_chain = (resources.private_key(),
                         resources.certificate_chain())
        server_credentials = grpc.ssl_server_credentials([key_and_chain])

    bind_address = '[::]:%d' % port
    if secure:
        bound_port = server.add_secure_port(bind_address, server_credentials)
    else:
        bound_port = server.add_insecure_port(bind_address)

    await server.start()

    # The server object is returned as well so the caller holds a reference
    # and it is not deallocated while tests are still using it.
    return 'localhost:%d' % bound_port, server
138        port = server.add_insecure_port('[::]:%d' % port)
139
140    await server.start()
141
142    # NOTE(lidizheng) returning the server to prevent it from deallocation
143    return 'localhost:%d' % port, server
144