# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the grpc.aio.Channel class."""

import logging
import os
import unittest

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._constants import (UNARY_CALL_WITH_SLEEP_VALUE,
                                       UNREACHABLE_TARGET)
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server

_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'

_INVOCATION_METADATA = (
    ('x-grpc-test-echo-initial', 'initial-md-value'),
    ('x-grpc-test-echo-trailing-bin', b'\x00\x02'),
)

_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42


class TestChannel(AioTestBase):
    """Exercises every RPC cardinality through a grpc.aio channel.

    Each test opens its channel with ``async with`` so the channel is closed
    even when an assertion or an RPC fails part-way through the test.
    """

    async def setUp(self):
        self._server_target, self._server = await start_test_server()

    async def tearDown(self):
        await self._server.stop(None)

    async def test_async_context(self):
        """A channel used as an async context manager can issue a unary call."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            await hi(messages_pb2.SimpleRequest())

    async def test_unary_unary(self):
        """A unary-unary call returns a deserialized SimpleResponse."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            response = await hi(messages_pb2.SimpleRequest())

            self.assertIsInstance(response, messages_pb2.SimpleResponse)

    async def test_unary_call_times_out(self):
        """A deadline shorter than the server's sleep raises DEADLINE_EXCEEDED."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD_WITH_SLEEP,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString,
            )

            with self.assertRaises(grpc.RpcError) as exception_context:
                await hi(messages_pb2.SimpleRequest(),
                         timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)

            # StatusCode values are (int code, str description) tuples; the
            # description, title-cased, is the expected error details text.
            _, details = grpc.StatusCode.DEADLINE_EXCEEDED.value  # pylint: disable=unused-variable
            self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
                             exception_context.exception.code())
            self.assertEqual(details.title(),
                             exception_context.exception.details())
            self.assertIsNotNone(exception_context.exception.initial_metadata())
            self.assertIsNotNone(
                exception_context.exception.trailing_metadata())

    @unittest.skipIf(os.name == 'nt',
                     'TODO: https://github.com/grpc/grpc/issues/21658')
    async def test_unary_call_does_not_times_out(self):
        """A deadline well beyond the server's sleep completes with OK."""
        async with aio.insecure_channel(self._server_target) as channel:
            hi = channel.unary_unary(
                _UNARY_CALL_METHOD_WITH_SLEEP,
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString,
            )

            call = hi(messages_pb2.SimpleRequest(),
                      timeout=UNARY_CALL_WITH_SLEEP_VALUE * 5)
            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream(self):
        """A server-streaming call yields one sized response per parameter."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            for _ in range(_NUM_STREAM_RESPONSES):
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_RESPONSE_PAYLOAD_SIZE))

            # Invokes the actual RPC
            call = stub.StreamingOutputCall(request)

            # Validates the responses
            response_cnt = 0
            async for response in call:
                response_cnt += 1
                self.assertIsInstance(response,
                                      messages_pb2.StreamingOutputCallResponse)
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))

            self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_using_write(self):
        """A client-streaming call fed via write() aggregates payload sizes."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Invokes the actual RPC
            call = stub.StreamingInputCall()

            # Prepares the request
            payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            # Sends out requests
            for _ in range(_NUM_STREAM_RESPONSES):
                await call.write(request)
            await call.done_writing()

            # Validates the responses
            response = await call
            self.assertIsInstance(response,
                                  messages_pb2.StreamingInputCallResponse)
            self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                             response.aggregated_payload_size)

            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_using_async_gen(self):
        """A client-streaming call fed by an async generator aggregates sizes."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            async def gen():
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request

            # Invokes the actual RPC
            call = stub.StreamingInputCall(gen())

            # Validates the responses
            response = await call
            self.assertIsInstance(response,
                                  messages_pb2.StreamingInputCallResponse)
            self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                             response.aggregated_payload_size)

            self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_using_read_write(self):
        """A bidi call answers each write() with one readable response."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Invokes the actual RPC
            call = stub.FullDuplexCall()

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

            # Ping-pong: each write is followed by a read of its response.
            for _ in range(_NUM_STREAM_RESPONSES):
                await call.write(request)
                response = await call.read()
                self.assertIsInstance(response,
                                      messages_pb2.StreamingOutputCallResponse)
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))

            await call.done_writing()

            self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_stream_stream_using_async_gen(self):
        """A bidi call fed by an async generator streams back every response."""
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

            async def gen():
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request

            # Invokes the actual RPC
            call = stub.FullDuplexCall(gen())

            # Validates the responses; each request carries one response
            # parameter, so exactly one response is expected per request.
            response_cnt = 0
            async for response in call:
                response_cnt += 1
                self.assertIsInstance(response,
                                      messages_pb2.StreamingOutputCallResponse)
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))

            self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
            self.assertEqual(grpc.StatusCode.OK, await call.code())


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)