# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the grpc.aio.Channel class."""

import logging
import os
import unittest

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE
from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server

_UNARY_CALL_METHOD = "/grpc.testing.TestService/UnaryCall"
_UNARY_CALL_METHOD_WITH_SLEEP = "/grpc.testing.TestService/UnaryCallWithSleep"
_STREAMING_OUTPUT_CALL_METHOD = "/grpc.testing.TestService/StreamingOutputCall"

_INVOCATION_METADATA = (
    ("x-grpc-test-echo-initial", "initial-md-value"),
    ("x-grpc-test-echo-trailing-bin", b"\x00\x02"),
)

_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42


class TestChannel(AioTestBase):
    """Exercises unary and streaming RPCs over an insecure aio channel.

    Every test opens its channel with ``async with`` so the channel is
    closed even when an assertion or RPC fails part-way through.
    """

    async def setUp(self):
        # Starts a fresh test server per test; its target is used by
        # every channel below.
        self._server_target, self._server = await start_test_server()

    async def tearDown(self):
        await self._server.stop(None)

    async def test_async_context(self):
        """A channel used as an async context manager can issue a call."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString,
            )
            await hi(messages_pb2.SimpleRequest())

    async def test_unary_unary(self):
        """A unary-unary call returns a deserialized SimpleResponse."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString,
            )
            response = await hi(messages_pb2.SimpleRequest())

            self.assertIsInstance(response, messages_pb2.SimpleResponse)

    async def test_unary_call_times_out(self):
        """A deadline shorter than the server-side sleep raises DEADLINE_EXCEEDED."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD_WITH_SLEEP,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString,
            )

            with self.assertRaises(grpc.RpcError) as exception_context:
                await hi(
                    messages_pb2.SimpleRequest(),
                    timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
                )

            # The status code's value is a (code, details) pair; only the
            # human-readable details string is compared below.
            _, details = (
                grpc.StatusCode.DEADLINE_EXCEEDED.value
            )  # pylint: disable=unused-variable
            self.assertEqual(
                grpc.StatusCode.DEADLINE_EXCEEDED,
                exception_context.exception.code(),
            )
            self.assertEqual(
                details.title(), exception_context.exception.details()
            )
            self.assertIsNotNone(exception_context.exception.initial_metadata())
            self.assertIsNotNone(
                exception_context.exception.trailing_metadata()
            )

    @unittest.skipIf(
        os.name == "nt", "TODO: https://github.com/grpc/grpc/issues/21658"
    )
    async def test_unary_call_does_not_times_out(self):
        """A deadline well above the server-side sleep completes with OK."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD_WITH_SLEEP,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString,
            )

            call = hi(
                messages_pb2.SimpleRequest(),
                timeout=UNARY_CALL_WITH_SLEEP_VALUE * 5,
            )
            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream(self):
        """A server-streaming call yields one response per requested parameter."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            for _ in range(_NUM_STREAM_RESPONSES):
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
                )

            # Invokes the actual RPC
            call = stub.StreamingOutputCall(request)

            # Validates the responses
            response_cnt = 0
            async for response in call:
                response_cnt += 1
                self.assertIs(
                    type(response), messages_pb2.StreamingOutputCallResponse
                )
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )

            self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_using_write(self):
        """A client-streaming call driven by write()/done_writing() aggregates payloads."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Invokes the actual RPC
            call = stub.StreamingInputCall()

            # Prepares the request
            payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            # Sends out requests
            for _ in range(_NUM_STREAM_RESPONSES):
                await call.write(request)
            await call.done_writing()

            # Validates the responses
            response = await call
            self.assertIsInstance(
                response, messages_pb2.StreamingInputCallResponse
            )
            self.assertEqual(
                _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                response.aggregated_payload_size,
            )

            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_using_async_gen(self):
        """A client-streaming call fed by an async generator aggregates payloads."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            async def gen():
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request

            # Invokes the actual RPC
            call = stub.StreamingInputCall(gen())

            # Validates the responses
            response = await call
            self.assertIsInstance(
                response, messages_pb2.StreamingInputCallResponse
            )
            self.assertEqual(
                _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                response.aggregated_payload_size,
            )

            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_using_read_write(self):
        """A bidi call alternating write()/read() gets one response per request."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Invokes the actual RPC
            call = stub.FullDuplexCall()

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
            )

            for _ in range(_NUM_STREAM_RESPONSES):
                await call.write(request)
                response = await call.read()
                self.assertIsInstance(
                    response, messages_pb2.StreamingOutputCallResponse
                )
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )

            await call.done_writing()

            self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_stream_stream_using_async_gen(self):
        """A bidi call fed by an async generator streams back sized responses."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
            )

            async def gen():
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request

            # Invokes the actual RPC
            call = stub.FullDuplexCall(gen())

            async for response in call:
                self.assertIsInstance(
                    response, messages_pb2.StreamingOutputCallResponse
                )
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )

            self.assertEqual(grpc.StatusCode.OK, await call.code())


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)