• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests behavior of the grpc.aio.Channel class."""
15
16import logging
17import os
18import unittest
19
20import grpc
21from grpc.experimental import aio
22
23from src.proto.grpc.testing import messages_pb2
24from src.proto.grpc.testing import test_pb2_grpc
25from tests.unit.framework.common import test_constants
26from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE
27from tests_aio.unit._constants import UNREACHABLE_TARGET
28from tests_aio.unit._test_base import AioTestBase
29from tests_aio.unit._test_server import start_test_server
30
# Fully-qualified method paths on grpc.testing.TestService. The tests below
# pass these to channel.unary_unary() when building a multicallable by hand
# instead of going through the generated stub.
_UNARY_CALL_METHOD = "/grpc.testing.TestService/UnaryCall"
_UNARY_CALL_METHOD_WITH_SLEEP = "/grpc.testing.TestService/UnaryCallWithSleep"
# NOTE(review): not referenced by any test visible in this file.
_STREAMING_OUTPUT_CALL_METHOD = "/grpc.testing.TestService/StreamingOutputCall"

# NOTE(review): not referenced by any test visible in this file; presumably
# (key, value) metadata pairs to attach to an invocation -- confirm before
# removing.
_INVOCATION_METADATA = (
    ("x-grpc-test-echo-initial", "initial-md-value"),
    ("x-grpc-test-echo-trailing-bin", b"\x00\x02"),
)

# Number of messages exchanged per streaming test: how many response
# parameters / requests are sent and how many responses are expected back.
_NUM_STREAM_RESPONSES = 5
# Length in bytes of the zero-filled payload body sent in each streaming
# input request.
_REQUEST_PAYLOAD_SIZE = 7
# Size requested through ResponseParameters; the tests assert that each
# response's payload body has exactly this length.
_RESPONSE_PAYLOAD_SIZE = 42
43
44
class TestChannel(AioTestBase):
    """Exercises unary and streaming RPCs through ``aio.insecure_channel``.

    Each test dials the server started by ``start_test_server()`` in
    ``setUp``. Every channel is opened with ``async with`` so that it is
    closed even when an assertion fails part-way through a test.
    """

    async def setUp(self):
        """Start the test server and remember the target to connect to."""
        self._server_target, self._server = await start_test_server()

    async def tearDown(self):
        """Stop the test server, passing ``None`` as the grace period."""
        await self._server.stop(None)

    async def test_async_context(self):
        """A channel used as an async context manager can serve a unary call."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString,
            )
            await hi(messages_pb2.SimpleRequest())

    async def test_unary_unary(self):
        """Awaiting a unary-unary call yields a ``SimpleResponse``."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString,
            )
            response = await hi(messages_pb2.SimpleRequest())

            self.assertIsInstance(response, messages_pb2.SimpleResponse)

    async def test_unary_call_times_out(self):
        """A timeout shorter than the server's sleep raises DEADLINE_EXCEEDED."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD_WITH_SLEEP,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString,
            )

            # Half of the sleep value, so the deadline expires before the
            # handler is expected to answer.
            with self.assertRaises(grpc.RpcError) as exception_context:
                await hi(
                    messages_pb2.SimpleRequest(),
                    timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
                )

            # The enum value unpacks into two items; only the second one (the
            # details text) is needed here.
            _, details = grpc.StatusCode.DEADLINE_EXCEEDED.value
            rpc_error = exception_context.exception

            self.assertEqual(
                grpc.StatusCode.DEADLINE_EXCEEDED,
                rpc_error.code(),
            )
            # The raised error's details are compared against the title-cased
            # form of the enum's details text.
            self.assertEqual(details.title(), rpc_error.details())
            self.assertIsNotNone(rpc_error.initial_metadata())
            self.assertIsNotNone(rpc_error.trailing_metadata())

    @unittest.skipIf(
        os.name == "nt", "TODO: https://github.com/grpc/grpc/issues/21658"
    )
    async def test_unary_call_does_not_times_out(self):
        """A timeout well above the server's sleep lets the call finish OK."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD_WITH_SLEEP,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString,
            )

            call = hi(
                messages_pb2.SimpleRequest(),
                timeout=UNARY_CALL_WITH_SLEEP_VALUE * 5,
            )
            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream(self):
        """A unary-stream call yields one sized response per parameter."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            for _ in range(_NUM_STREAM_RESPONSES):
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_RESPONSE_PAYLOAD_SIZE
                    )
                )

            # Invokes the actual RPC
            call = stub.StreamingOutputCall(request)

            # Validates the responses
            response_cnt = 0
            async for response in call:
                response_cnt += 1
                self.assertIsInstance(
                    response, messages_pb2.StreamingOutputCallResponse
                )
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )

            self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_using_write(self):
        """Requests sent with ``write()`` are aggregated into one response."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Invokes the actual RPC
            call = stub.StreamingInputCall()

            # Prepares the request
            payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            # Sends out requests
            for _ in range(_NUM_STREAM_RESPONSES):
                await call.write(request)
            await call.done_writing()

            # Validates the responses
            response = await call
            self.assertIsInstance(
                response, messages_pb2.StreamingInputCallResponse
            )
            self.assertEqual(
                _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                response.aggregated_payload_size,
            )

            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_using_async_gen(self):
        """Requests fed from an async generator are aggregated likewise."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            async def gen():
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request

            # Invokes the actual RPC
            call = stub.StreamingInputCall(gen())

            # Validates the responses
            response = await call
            self.assertIsInstance(
                response, messages_pb2.StreamingInputCallResponse
            )
            self.assertEqual(
                _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                response.aggregated_payload_size,
            )

            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_using_read_write(self):
        """Each ``write()`` on a full-duplex call is answered by one ``read()``."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Invokes the actual RPC
            call = stub.FullDuplexCall()

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
            )

            for _ in range(_NUM_STREAM_RESPONSES):
                await call.write(request)
                response = await call.read()
                self.assertIsInstance(
                    response, messages_pb2.StreamingOutputCallResponse
                )
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )

            await call.done_writing()

            self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_stream_stream_using_async_gen(self):
        """A full-duplex call fed by an async generator streams responses."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
            )

            async def gen():
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request

            # Invokes the actual RPC
            call = stub.FullDuplexCall(gen())

            # Validates the responses. Counting them keeps the test from
            # passing vacuously when the stream yields nothing; each request
            # carries a single response parameter, so one response per
            # request is expected.
            response_cnt = 0
            async for response in call:
                response_cnt += 1
                self.assertIsInstance(
                    response, messages_pb2.StreamingOutputCallResponse
                )
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )

            self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
            self.assertEqual(grpc.StatusCode.OK, await call.code())
252
253
# Script entry point: lets this test module be run directly.
if __name__ == "__main__":
    # Configure root logging at DEBUG level before handing control to the
    # unittest runner.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
257