• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests behavior of the Call classes."""
15
16import asyncio
17import datetime
18import logging
19import random
20import unittest
21
22import grpc
23from grpc.experimental import aio
24
25from src.proto.grpc.testing import messages_pb2
26from src.proto.grpc.testing import test_pb2_grpc
27from tests_aio.unit._constants import UNREACHABLE_TARGET
28from tests_aio.unit._test_base import AioTestBase
29from tests_aio.unit._test_server import start_test_server
30
# Base unit of wait time, in seconds, shared by the timing-sensitive tests.
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds()

# Shape of the streaming RPCs issued by the tests.
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7
# Status details expected on a call that was cancelled from the client side.
_LOCAL_CANCEL_DETAILS_EXPECTATION = "Locally cancelled by application!"
# One unit of wait time, expressed in microseconds.
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 10**6)
# Largest signed 32-bit value; used as an interval that never elapses in practice.
_INFINITE_INTERVAL_US = (1 << 31) - 1

# Knobs for the randomized cancellation-interleaving test.
_NONDETERMINISTIC_ITERATIONS = 50
_NONDETERMINISTIC_SERVER_SLEEP_MAX_US = 1000
42
43
class _MulticallableTestMixin:
    """Mixin giving each test a fresh test server, channel and service stub.

    setUp stores the server in ``self._server``, an insecure channel to it in
    ``self._channel`` and a ``TestServiceStub`` in ``self._stub``; tearDown
    releases the channel and the server again.
    """

    async def setUp(self):
        """Start the test server and connect a channel and stub to it."""
        address, self._server = await start_test_server()
        self._channel = aio.insecure_channel(address)
        self._stub = test_pb2_grpc.TestServiceStub(self._channel)

    async def tearDown(self):
        """Close the channel and stop the server."""
        try:
            await self._channel.close()
        finally:
            # Stop the server even when closing the channel raised, so a
            # failing close cannot leave a running server behind.
            await self._server.stop(None)
53
54
class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests the call object returned by the unary-unary multicallable."""

    async def test_call_to_string(self):
        """str() and repr() of a call work before and after it completes."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

        await call

        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

    async def test_call_ok(self):
        """A successful call reports done/OK and caches its response."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.done())

        response = await call

        self.assertTrue(call.done())
        self.assertIsInstance(response, messages_pb2.SimpleResponse)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        # The response is cached on the call object, so awaiting the call
        # again returns the very same response instance.
        response_retry = await call
        self.assertIs(response, response_retry)

    async def test_call_rpc_error(self):
        """A call on an unreachable target fails with UNAVAILABLE."""
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            call = stub.UnaryCall(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(
                grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()
            )

            # The call object must report the same status as the exception.
            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_call_code_awaitable(self):
        """code() can be awaited without awaiting the call first."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_details_awaitable(self):
        """details() can be awaited without awaiting the call first."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual("", await call.details())

    async def test_call_initial_metadata_awaitable(self):
        """initial_metadata() can be awaited without awaiting the call first."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_trailing_metadata_awaitable(self):
        """trailing_metadata() can be awaited without awaiting the call first."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.trailing_metadata())

    async def test_call_initial_metadata_cancelable(self):
        """Cancelling one initial_metadata() waiter does not break later ones."""
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.initial_metadata()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Initial metadata can still be requested even though the previous
        # waiting task was cancelled.
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_initial_metadata_multiple_waiters(self):
        """Several concurrent initial_metadata() waiters all get the result."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.initial_metadata()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call
        expected = [aio.Metadata() for _ in range(2)]
        self.assertEqual(expected, await asyncio.gather(*[task1, task2]))

    async def test_call_code_cancelable(self):
        """Cancelling one code() waiter does not break later ones."""
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.code()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # The status code can still be requested even though the previous
        # waiting task was cancelled.
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_call_code_multiple_waiters(self):
        """Several concurrent code() waiters all get the final status."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.code()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call

        self.assertEqual(
            [grpc.StatusCode.OK, grpc.StatusCode.OK],
            await asyncio.gather(task1, task2),
        )

    async def test_cancel_unary_unary(self):
        """Cancelling a call locally yields CANCELLED with the local details."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.cancelled())

        # Only the first cancel() takes effect; repeating it reports False.
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call

        # The call object must expose the cancellation status and details.
        self.assertTrue(call.cancelled())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
        self.assertEqual(
            await call.details(), _LOCAL_CANCEL_DETAILS_EXPECTATION
        )

    async def test_cancel_unary_unary_in_task(self):
        """Cancelling the task awaiting a call also cancels the call."""
        coro_started = asyncio.Event()
        call = self._stub.EmptyCall(messages_pb2.SimpleRequest())

        async def another_coro():
            coro_started.set()
            await call

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_passing_credentials_fails_over_insecure_channel(self):
        """Call credentials on an insecure channel raise aio.UsageError."""
        call_credentials = grpc.composite_call_credentials(
            grpc.access_token_call_credentials("abc"),
            grpc.access_token_call_credentials("def"),
        )
        with self.assertRaisesRegex(
            aio.UsageError, "Call credentials are only valid on secure channels"
        ):
            self._stub.UnaryCall(
                messages_pb2.SimpleRequest(), credentials=call_credentials
            )
224
225
class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests the call object returned by the unary-stream multicallable."""

    @staticmethod
    def _build_request(num_responses, **response_parameters):
        """Build a StreamingOutputCallRequest with identical response entries.

        Args:
            num_responses: Number of ResponseParameters entries to append.
            **response_parameters: Keyword arguments forwarded unchanged to
                each messages_pb2.ResponseParameters (e.g. size, interval_us).

        Returns:
            The populated messages_pb2.StreamingOutputCallRequest.
        """
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(num_responses):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(**response_parameters)
            )
        return request

    async def test_call_rpc_error(self):
        """Iterating a call on an unreachable target fails with UNAVAILABLE."""
        request = messages_pb2.StreamingOutputCallRequest()

        # The context manager closes the channel even when an assertion
        # below fails, so the channel is never leaked.
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.StreamingOutputCall(request)

            with self.assertRaises(aio.AioRpcError) as exception_context:
                async for _ in call:
                    pass

            self.assertEqual(
                grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()
            )

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_cancel_unary_stream(self):
        """Cancelling after the first response yields CANCELLED."""
        # Prepares the request
        request = self._build_request(
            _NUM_STREAM_RESPONSES,
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_RESPONSE_INTERVAL_US,
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(
            _LOCAL_CANCEL_DETAILS_EXPECTATION, await call.details()
        )
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()
        self.assertTrue(call.cancelled())

    async def test_multiple_cancel_unary_stream(self):
        """Only the first of several cancel() calls reports success."""
        # Prepares the request
        request = self._build_request(
            _NUM_STREAM_RESPONSES,
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_RESPONSE_INTERVAL_US,
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

    async def test_early_cancel_unary_stream(self):
        """Test cancellation before receiving messages."""
        # Prepares the request
        request = self._build_request(
            _NUM_STREAM_RESPONSES,
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_RESPONSE_INTERVAL_US,
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

        self.assertTrue(call.cancelled())

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(
            _LOCAL_CANCEL_DETAILS_EXPECTATION, await call.details()
        )

    async def test_late_cancel_unary_stream(self):
        """Test cancellation after received all messages."""
        # Prepares the request
        request = self._build_request(
            _NUM_STREAM_RESPONSES, size=_RESPONSE_PAYLOAD_SIZE
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # After all messages received, it is possible that the final state
        # is received or on its way. It's basically a data race, so our
        # expectation here is simply that nothing crashes and either
        # outcome is accepted.
        call.cancel()
        self.assertIn(
            await call.code(), [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED]
        )

    async def test_too_many_reads_unary_stream(self):
        """Test that reads after all messages were received return EOF."""
        # Prepares the request
        request = self._build_request(
            _NUM_STREAM_RESPONSES, size=_RESPONSE_PAYLOAD_SIZE
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        # After the RPC is finished, further reads keep returning EOF.
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        self.assertIs(await call.read(), aio.EOF)

    async def test_unary_stream_async_generator(self):
        """Sunny day test case for unary_stream."""
        # Prepares the request
        request = self._build_request(
            _NUM_STREAM_RESPONSES, size=_RESPONSE_PAYLOAD_SIZE
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        async for response in call:
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_cancel_unary_stream_with_many_interleavings(self):
        """A cheap alternative to a structured fuzzer.

        Certain classes of error only appear for very specific interleavings of
        coroutines. Rather than inserting semi-private asyncio.Events throughout
        the implementation on which to coordinate and explicitly waiting on those
        in tests, we instead search for bugs over the space of interleavings by
        stochastically varying the durations of certain events within the test.
        """

        # We range over several orders of magnitude to ensure that switching platforms
        # (i.e. to slow CI machines) does not result in this test becoming a no-op.
        sleep_ranges = (10.0**-i for i in range(1, 4))
        for sleep_range in sleep_ranges:
            for _ in range(_NONDETERMINISTIC_ITERATIONS):
                interval_us = random.randrange(
                    _NONDETERMINISTIC_SERVER_SLEEP_MAX_US
                )
                sleep_secs = sleep_range * random.random()

                coro_started = asyncio.Event()

                # Asks for a single one-byte response with a short, randomly
                # chosen interval, so the cancellation below can land before
                # or after the response.
                request = self._build_request(
                    1, size=1, interval_us=interval_us
                )

                # Invokes the actual RPC
                call = self._stub.StreamingOutputCall(request)

                unhandled_error = False

                async def another_coro():
                    nonlocal unhandled_error
                    coro_started.set()
                    try:
                        await call.read()
                    except asyncio.CancelledError:
                        # Cancellation is the expected outcome of this race.
                        pass
                    except Exception:
                        # Record anything else so the test can fail on it.
                        unhandled_error = True
                        raise

                task = self.loop.create_task(another_coro())
                await coro_started.wait()
                await asyncio.sleep(sleep_secs)

                task.cancel()

                try:
                    await task
                except asyncio.CancelledError:
                    pass

                self.assertFalse(unhandled_error)

    async def test_cancel_unary_stream_in_task_using_read(self):
        """Cancelling a task blocked in read() cancels the call."""
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = self._build_request(
            1,
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_INFINITE_INTERVAL_US,
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            await call.read()

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_cancel_unary_stream_in_task_using_async_for(self):
        """Cancelling a task blocked in ``async for`` cancels the call."""
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = self._build_request(
            1,
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_INFINITE_INTERVAL_US,
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            async for _ in call:
                pass

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_time_remaining(self):
        """time_remaining() shrinks as the call's timeout is consumed."""
        request = messages_pb2.StreamingOutputCallRequest()
        # First message comes back immediately
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
            )
        )
        # Second message comes back after a unit of wait time
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            )
        )

        call = self._stub.StreamingOutputCall(
            request, timeout=_SHORT_TIMEOUT_S * 2
        )

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the same as the timeout
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the timeout minus a unit of wait time
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2)

        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_empty_responses(self):
        """Responses that serialize to zero bytes are still delivered."""
        # Prepares the request
        request = self._build_request(_NUM_STREAM_RESPONSES)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(b"", response.SerializeToString())

        self.assertEqual(grpc.StatusCode.OK, await call.code())
587
588
class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Exercises the call object produced by the stream-unary multicallable."""

    async def test_cancel_stream_unary(self):
        """Cancelling after several writes makes awaiting the call raise."""
        rpc = self._stub.StreamingInputCall()

        # A single request message, reused for every write.
        message = messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        )

        # Push the requests to the server.
        for _write_index in range(_NUM_STREAM_RESPONSES):
            await rpc.write(message)

        # The call is still live; cancel it and check the state flips.
        self.assertFalse(rpc.done())
        self.assertFalse(rpc.cancelled())
        self.assertTrue(rpc.cancel())
        self.assertTrue(rpc.cancelled())

        await rpc.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await rpc

    async def test_early_cancel_stream_unary(self):
        """Cancelling before any write rejects later writes."""
        rpc = self._stub.StreamingInputCall()

        # Cancel right away, before any traffic.
        self.assertFalse(rpc.done())
        self.assertFalse(rpc.cancelled())
        self.assertTrue(rpc.cancel())
        self.assertTrue(rpc.cancelled())

        with self.assertRaises(asyncio.InvalidStateError):
            await rpc.write(messages_pb2.StreamingInputCallRequest())

        # On a cancelled call this is expected to be a no-op.
        await rpc.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await rpc

    async def test_write_after_done_writing(self):
        """Writes after done_writing() fail, but the RPC still succeeds."""
        rpc = self._stub.StreamingInputCall()

        # A single request message, reused for every write.
        message = messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        )

        # Push the requests to the server.
        for _write_index in range(_NUM_STREAM_RESPONSES):
            await rpc.write(message)

        # Signal that the client has finished sending.
        await rpc.done_writing()

        with self.assertRaises(asyncio.InvalidStateError):
            await rpc.write(messages_pb2.StreamingInputCallRequest())

        reply = await rpc
        self.assertIsInstance(reply, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(
            _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
            reply.aggregated_payload_size,
        )

        self.assertEqual(await rpc.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        """Cancelling the call is surfaced inside the request generator."""
        # The message the request generator yields repeatedly.
        message = messages_pb2.StreamingOutputCallRequest()
        message.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            )
        )

        # Set once the generator has observed the cancellation.
        generator_saw_cancellation = asyncio.Event()

        async def request_iterator():
            with self.assertRaises(asyncio.CancelledError):
                for _yield_index in range(_NUM_STREAM_RESPONSES):
                    yield message
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            generator_saw_cancellation.set()

        rpc = self._stub.StreamingInputCall(request_iterator())

        # Cancel the RPC two wait units in, while the generator is running.
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            rpc.cancel()

        canceller = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            await rpc

        await generator_saw_cancellation.wait()

        # Awaiting the task re-raises anything that went wrong inside it.
        await canceller

    async def test_normal_iterable_requests(self):
        """A plain list of requests is accepted as the request iterable."""
        # The same message object repeated once per expected write.
        message = messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        )
        batch = [message for _ in range(_NUM_STREAM_RESPONSES)]

        # Hand the whole batch to the multicallable.
        rpc = self._stub.StreamingInputCall(batch)

        # The RPC is expected to succeed and report the total payload size.
        reply = await rpc
        self.assertIsInstance(reply, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(
            _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
            reply.aggregated_payload_size,
        )

        self.assertEqual(await rpc.code(), grpc.StatusCode.OK)

    async def test_call_rpc_error(self):
        """A call on an unreachable target fails with UNAVAILABLE."""
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            unreachable_stub = test_pb2_grpc.TestServiceStub(channel)

            # No request is written; the failure has to surface on its own.
            rpc = unreachable_stub.StreamingInputCall()
            with self.assertRaises(aio.AioRpcError) as raised:
                await rpc

            self.assertEqual(
                grpc.StatusCode.UNAVAILABLE, raised.exception.code()
            )

            self.assertTrue(rpc.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await rpc.code())

    async def test_timeout(self):
        """An idle call with a timeout ends with DEADLINE_EXCEEDED."""
        rpc = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S)

        # No request is written; the deadline has to fire on its own.
        with self.assertRaises(aio.AioRpcError) as raised:
            await rpc

        error = raised.exception
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, error.code())
        self.assertTrue(rpc.done())
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await rpc.code())
739
740
# Requests used to stream in a ping-pong manner: each one asks the server
# for exactly one response.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest(
    response_parameters=[
        messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
    ]
)
# Same shape, but the single response entry carries default parameters.
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE = (
    messages_pb2.StreamingOutputCallRequest(
        response_parameters=[messages_pb2.ResponseParameters()]
    )
)
752
753
754class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
755    async def test_cancel(self):
756        # Invokes the actual RPC
757        call = self._stub.FullDuplexCall()
758
759        for _ in range(_NUM_STREAM_RESPONSES):
760            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
761            response = await call.read()
762            self.assertIsInstance(
763                response, messages_pb2.StreamingOutputCallResponse
764            )
765            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
766
767        # Cancels the RPC
768        self.assertFalse(call.done())
769        self.assertFalse(call.cancelled())
770        self.assertTrue(call.cancel())
771        self.assertTrue(call.cancelled())
772        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
773
774    async def test_cancel_with_pending_read(self):
775        call = self._stub.FullDuplexCall()
776
777        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
778
779        # Cancels the RPC
780        self.assertFalse(call.done())
781        self.assertFalse(call.cancelled())
782        self.assertTrue(call.cancel())
783        self.assertTrue(call.cancelled())
784        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
785
786    async def test_cancel_with_ongoing_read(self):
787        call = self._stub.FullDuplexCall()
788        coro_started = asyncio.Event()
789
790        async def read_coro():
791            coro_started.set()
792            await call.read()
793
794        read_task = self.loop.create_task(read_coro())
795        await coro_started.wait()
796        self.assertFalse(read_task.done())
797
798        # Cancels the RPC
799        self.assertFalse(call.done())
800        self.assertFalse(call.cancelled())
801        self.assertTrue(call.cancel())
802        self.assertTrue(call.cancelled())
803        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
804
805    async def test_early_cancel(self):
806        call = self._stub.FullDuplexCall()
807
808        # Cancels the RPC
809        self.assertFalse(call.done())
810        self.assertFalse(call.cancelled())
811        self.assertTrue(call.cancel())
812        self.assertTrue(call.cancelled())
813        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
814
815    async def test_cancel_after_done_writing(self):
816        call = self._stub.FullDuplexCall()
817        await call.done_writing()
818
819        # Cancels the RPC
820        self.assertFalse(call.done())
821        self.assertFalse(call.cancelled())
822        self.assertTrue(call.cancel())
823        self.assertTrue(call.cancelled())
824        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
825
826    async def test_late_cancel(self):
827        call = self._stub.FullDuplexCall()
828        await call.done_writing()
829        self.assertEqual(grpc.StatusCode.OK, await call.code())
830
831        # Cancels the RPC
832        self.assertTrue(call.done())
833        self.assertFalse(call.cancelled())
834        self.assertFalse(call.cancel())
835        self.assertFalse(call.cancelled())
836
837        # Status is still OK
838        self.assertEqual(grpc.StatusCode.OK, await call.code())
839
840    async def test_async_generator(self):
841        async def request_generator():
842            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
843            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
844
845        call = self._stub.FullDuplexCall(request_generator())
846        async for response in call:
847            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
848
849        self.assertEqual(await call.code(), grpc.StatusCode.OK)
850
851    async def test_too_many_reads(self):
852        async def request_generator():
853            for _ in range(_NUM_STREAM_RESPONSES):
854                yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
855
856        call = self._stub.FullDuplexCall(request_generator())
857        for _ in range(_NUM_STREAM_RESPONSES):
858            response = await call.read()
859            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
860        self.assertIs(await call.read(), aio.EOF)
861
862        self.assertEqual(await call.code(), grpc.StatusCode.OK)
863        # After the RPC finished, the read should also produce EOF
864        self.assertIs(await call.read(), aio.EOF)
865
866    async def test_read_write_after_done_writing(self):
867        call = self._stub.FullDuplexCall()
868
869        # Writes two requests, and pending two requests
870        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
871        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
872        await call.done_writing()
873
874        # Further write should fail
875        with self.assertRaises(asyncio.InvalidStateError):
876            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
877
878        # But read should be unaffected
879        response = await call.read()
880        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
881        response = await call.read()
882        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
883
884        self.assertEqual(await call.code(), grpc.StatusCode.OK)
885
886    async def test_error_in_async_generator(self):
887        # Server will pause between responses
888        request = messages_pb2.StreamingOutputCallRequest()
889        request.response_parameters.append(
890            messages_pb2.ResponseParameters(
891                size=_RESPONSE_PAYLOAD_SIZE,
892                interval_us=_RESPONSE_INTERVAL_US,
893            )
894        )
895
896        # We expect the request iterator to receive the exception
897        request_iterator_received_the_exception = asyncio.Event()
898
899        async def request_iterator():
900            with self.assertRaises(asyncio.CancelledError):
901                for _ in range(_NUM_STREAM_RESPONSES):
902                    yield request
903                    await asyncio.sleep(_SHORT_TIMEOUT_S)
904            request_iterator_received_the_exception.set()
905
906        call = self._stub.FullDuplexCall(request_iterator())
907
908        # Cancel the RPC after at least one response
909        async def cancel_later():
910            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
911            call.cancel()
912
913        cancel_later_task = self.loop.create_task(cancel_later())
914
915        with self.assertRaises(asyncio.CancelledError):
916            async for response in call:
917                self.assertEqual(
918                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
919                )
920
921        await request_iterator_received_the_exception.wait()
922
923        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
924        # No failures in the cancel later task!
925        await cancel_later_task
926
927    async def test_normal_iterable_requests(self):
928        requests = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES
929
930        call = self._stub.FullDuplexCall(iter(requests))
931        async for response in call:
932            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
933
934        self.assertEqual(await call.code(), grpc.StatusCode.OK)
935
936    async def test_empty_ping_pong(self):
937        call = self._stub.FullDuplexCall()
938        for _ in range(_NUM_STREAM_RESPONSES):
939            await call.write(_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE)
940            response = await call.read()
941            self.assertEqual(b"", response.SerializeToString())
942        await call.done_writing()
943        self.assertEqual(await call.code(), grpc.StatusCode.OK)
944
945
if __name__ == "__main__":
    # Script entry point: configure DEBUG-level logging first, then hand
    # control to the unittest runner with verbose per-test output.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
949