• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import datetime
17
18import grpc
19from grpc.experimental import aio
20from tests.unit import resources
21
22from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc
23from tests_aio.unit import _constants
24
# Invocation-metadata keys a client can set to have the test server echo the
# value back: the first as initial metadata, the second as trailing metadata.
_INITIAL_METADATA_KEY: str = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY: str = "x-grpc-test-echo-trailing-bin"
27
28
async def _maybe_echo_metadata(servicer_context):
    """Echo selected request metadata back to the client.

    When the client's invocation metadata contains ``_INITIAL_METADATA_KEY``,
    that key/value pair is sent back as initial metadata. When it contains
    ``_TRAILING_METADATA_KEY``, that pair is set as trailing metadata. Keys
    that are absent are simply skipped.

    Args:
      servicer_context: Server-side context of the RPC being handled.
    """
    received = dict(servicer_context.invocation_metadata())

    if _INITIAL_METADATA_KEY in received:
        await servicer_context.send_initial_metadata(
            ((_INITIAL_METADATA_KEY, received[_INITIAL_METADATA_KEY]),))

    if _TRAILING_METADATA_KEY in received:
        servicer_context.set_trailing_metadata(
            ((_TRAILING_METADATA_KEY, received[_TRAILING_METADATA_KEY]),))
40
41
async def _maybe_echo_status(request: messages_pb2.SimpleRequest,
                             servicer_context):
    """Fail the RPC with the status the client asked for, if it asked for one.

    Args:
      request: Incoming request message. If its ``response_status`` field is
        set, that status's code and message are handed to
        ``servicer_context.abort``. Despite the annotation, ``FullDuplexCall``
        also passes its streamed requests here.
      servicer_context: Server-side context of the RPC being handled.
    """
    if not request.HasField('response_status'):
        return
    requested_status = request.response_status
    await servicer_context.abort(requested_status.code,
                                 requested_status.message)
48
49
class TestServiceServicer(test_pb2_grpc.TestServiceServicer):
    """Asyncio implementation of ``grpc.testing.TestService`` used by tests.

    In addition to the servicer methods, it provides ``UnaryCallWithSleep``,
    which is registered separately by ``_create_extra_generic_handler``.
    """

    @staticmethod
    async def _maybe_sleep(interval_us):
        """Sleep for ``interval_us`` microseconds; 0 means no sleep at all."""
        if interval_us != 0:
            await asyncio.sleep(
                datetime.timedelta(microseconds=interval_us).total_seconds())

    @staticmethod
    def _streaming_response(payload_type, size):
        """Build one streaming response carrying ``size`` zero bytes.

        A ``size`` of 0 produces a response with no payload field set at all,
        rather than a payload with an empty body.
        """
        if size == 0:
            return messages_pb2.StreamingOutputCallResponse()
        return messages_pb2.StreamingOutputCallResponse(
            payload=messages_pb2.Payload(type=payload_type,
                                         body=b'\x00' * size))

    async def UnaryCall(self, request, context):
        """Reply with a COMPRESSABLE payload of ``request.response_size`` zeros.

        Requested metadata is echoed first, and the RPC is aborted instead if
        the request carries a ``response_status``.
        """
        await _maybe_echo_metadata(context)
        await _maybe_echo_status(request, context)
        return messages_pb2.SimpleResponse(
            payload=messages_pb2.Payload(type=messages_pb2.COMPRESSABLE,
                                         body=b'\x00' * request.response_size))

    async def EmptyCall(self, request, context):
        """Reply with an empty message."""
        return empty_pb2.Empty()

    async def StreamingOutputCall(
            self, request: messages_pb2.StreamingOutputCallRequest,
            unused_context):
        """Stream one response per entry of ``request.response_parameters``.

        Each entry may request a delay (``interval_us``) before its response
        and a payload size; the payload type is ``request.response_type``.
        """
        for response_parameters in request.response_parameters:
            await self._maybe_sleep(response_parameters.interval_us)
            yield self._streaming_response(request.response_type,
                                           response_parameters.size)

    # UnaryCallWithSleep is not part of the generated servicer: it is
    # registered programmatically by _create_extra_generic_handler when the
    # server is instantiated.
    async def UnaryCallWithSleep(self, unused_request, unused_context):
        """Wait ``_constants.UNARY_CALL_WITH_SLEEP_VALUE`` seconds, then reply."""
        await asyncio.sleep(_constants.UNARY_CALL_WITH_SLEEP_VALUE)
        return messages_pb2.SimpleResponse()

    async def StreamingInputCall(self, request_async_iterator, unused_context):
        """Reply with the summed payload body length of all streamed requests."""
        aggregate_size = 0
        async for request in request_async_iterator:
            # Protobuf sub-message fields are never None: an unset payload
            # reads as a default instance whose empty body adds 0.
            aggregate_size += len(request.payload.body)
        return messages_pb2.StreamingInputCallResponse(
            aggregated_payload_size=aggregate_size)

    async def FullDuplexCall(self, request_async_iterator, context):
        """Answer every streamed request with its own run of responses.

        Requested metadata is echoed once up front. Each request may carry a
        ``response_status`` (passed to ``context.abort``) and lists the
        responses to send back, as in ``StreamingOutputCall``.
        """
        await _maybe_echo_metadata(context)
        async for request in request_async_iterator:
            await _maybe_echo_status(request, context)
            for response_parameters in request.response_parameters:
                await self._maybe_sleep(response_parameters.interval_us)
                # NOTE(review): the payload type comes from request.payload.type
                # here, while StreamingOutputCall uses request.response_type.
                # Kept as-is because tests may rely on it; confirm intent.
                yield self._streaming_response(request.payload.type,
                                               response_parameters.size)
109
110
def _create_extra_generic_handler(servicer: TestServiceServicer):
    """Build a generic handler for test-only ``grpc.testing.TestService`` methods.

    These methods are not provided by the proto file but are used during the
    tests, so they are registered programmatically instead.

    Args:
      servicer: Servicer whose bound methods implement the extra calls.

    Returns:
      A generic RPC handler for the ``grpc.testing.TestService`` service.
    """
    unary_call_with_sleep = grpc.unary_unary_rpc_method_handler(
        servicer.UnaryCallWithSleep,
        request_deserializer=messages_pb2.SimpleRequest.FromString,
        response_serializer=messages_pb2.SimpleResponse.SerializeToString,
    )
    return grpc.method_handlers_generic_handler(
        'grpc.testing.TestService',
        {'UnaryCallWithSleep': unary_call_with_sleep},
    )
124
125
async def start_test_server(port=0,
                            secure=False,
                            server_credentials=None,
                            interceptors=None):
    """Create, configure and start an asyncio test server.

    Args:
      port: Port to bind on ``[::]``; the default 0 requests any free port.
      secure: Serve over TLS (``add_secure_port``) instead of plaintext.
      server_credentials: Credentials used when ``secure`` is true. When
        None, credentials are built from the test private key and
        certificate chain in ``resources``. Ignored when ``secure`` is false.
      interceptors: Server interceptors forwarded to ``aio.server``.

    Returns:
      A ``(address, server)`` tuple. ``address`` is ``'localhost:<port>'``,
      where ``<port>`` is the port returned by the ``add_*_port`` call.
    """
    server = aio.server(options=(('grpc.so_reuseport', 0),),
                        interceptors=interceptors)
    servicer = TestServiceServicer()
    test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)
    server.add_generic_rpc_handlers((_create_extra_generic_handler(servicer),))

    if not secure:
        bound_port = server.add_insecure_port('[::]:%d' % port)
    else:
        credentials = (server_credentials if server_credentials is not None
                       else grpc.ssl_server_credentials([
                           (resources.private_key(),
                            resources.certificate_chain())
                       ]))
        bound_port = server.add_secure_port('[::]:%d' % port, credentials)

    await server.start()

    # NOTE(lidizheng) returning the server to prevent it from deallocation
    return 'localhost:%d' % bound_port, server
150