• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests behavior of the Call classes."""
15
16import asyncio
17import datetime
18import logging
19import random
20import unittest
21
22import grpc
23from grpc.experimental import aio
24
25from src.proto.grpc.testing import messages_pb2
26from src.proto.grpc.testing import test_pb2_grpc
27from tests_aio.unit._constants import UNREACHABLE_TARGET
28from tests_aio.unit._test_base import AioTestBase
29from tests_aio.unit._test_server import start_test_server
30
# Base unit of wall-clock time, in seconds, for timeouts and server delays.
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds()

# Shape of the streaming RPCs exercised by the tests.
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7
# Details string the tests expect after call.cancel() on the client side.
_LOCAL_CANCEL_DETAILS_EXPECTATION = "Locally cancelled by application!"
# One _SHORT_TIMEOUT_S expressed in microseconds (server-side response delay).
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1_000_000)
# Largest signed 32-bit integer; used as an effectively unbounded delay.
_INFINITE_INTERVAL_US = (1 << 31) - 1

# Knobs for the randomized interleaving test.
_NONDETERMINISTIC_ITERATIONS = 50
_NONDETERMINISTIC_SERVER_SLEEP_MAX_US = 1000
42
43
class _MulticallableTestMixin:
    """Per-test fixture: a running test server plus a channel and stub to it.

    Mixed into the AioTestBase subclasses below; test methods use
    ``self._server``, ``self._channel`` and ``self._stub``.
    """

    async def setUp(self):
        server_address, server = await start_test_server()
        self._server = server
        self._channel = aio.insecure_channel(server_address)
        self._stub = test_pb2_grpc.TestServiceStub(self._channel)

    async def tearDown(self):
        # Close the client channel first, then stop the server.
        await self._channel.close()
        await self._server.stop(None)
53
54
class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the Call object returned by unary-unary multicallables."""

    async def test_call_to_string(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        # str() and repr() must not raise, both while the call is in flight
        # and after it has completed.
        self.assertIsInstance(str(call), str)
        self.assertIsInstance(repr(call), str)

        await call

        self.assertIsInstance(str(call), str)
        self.assertIsInstance(repr(call), str)

    async def test_call_ok(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.done())

        response = await call

        self.assertTrue(call.done())
        self.assertIsInstance(response, messages_pb2.SimpleResponse)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        # Response is cached at call object level, reentrance
        # returns again the same response
        response_retry = await call
        self.assertIs(response, response_retry)

    async def test_call_rpc_error(self):
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            call = stub.UnaryCall(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(
                grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()
            )

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_call_code_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_details_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual("", await call.details())

    async def test_call_initial_metadata_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_trailing_metadata_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.trailing_metadata())

    async def test_call_initial_metadata_cancelable(self):
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.initial_metadata()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()
        # Let the cancellation be delivered before querying the call again,
        # so the task is not left pending and does not race the check below.
        # The task may also have finished before cancel() took effect.
        try:
            await task
        except asyncio.CancelledError:
            pass

        # Test that initial metadata can still be asked even though
        # a cancellation happened with the previous task
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_initial_metadata_multiple_waiters(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.initial_metadata()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call
        expected = [aio.Metadata() for _ in range(2)]
        self.assertEqual(expected, await asyncio.gather(task1, task2))

    async def test_call_code_cancelable(self):
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.code()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()
        # Let the cancellation be delivered before querying the call again,
        # so the task is not left pending and does not race the check below.
        # The task may also have finished before cancel() took effect.
        try:
            await task
        except asyncio.CancelledError:
            pass

        # Test that code can still be asked even though
        # a cancellation happened with the previous task
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_call_code_multiple_waiters(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.code()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call

        self.assertEqual(
            [grpc.StatusCode.OK, grpc.StatusCode.OK],
            await asyncio.gather(task1, task2),
        )

    async def test_cancel_unary_unary(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.cancelled())

        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call

        # After a local cancellation the Call object reports CANCELLED
        # together with the local-cancellation details.
        self.assertTrue(call.cancelled())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
        self.assertEqual(
            await call.details(), _LOCAL_CANCEL_DETAILS_EXPECTATION
        )

    async def test_cancel_unary_unary_in_task(self):
        coro_started = asyncio.Event()
        # UnaryCall takes a SimpleRequest; EmptyCall would expect an Empty
        # message and only accepted a SimpleRequest by accident.
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def another_coro():
            coro_started.set()
            await call

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_passing_credentials_fails_over_insecure_channel(self):
        call_credentials = grpc.composite_call_credentials(
            grpc.access_token_call_credentials("abc"),
            grpc.access_token_call_credentials("def"),
        )
        with self.assertRaisesRegex(
            aio.UsageError, "Call credentials are only valid on secure channels"
        ):
            self._stub.UnaryCall(
                messages_pb2.SimpleRequest(), credentials=call_credentials
            )
224
225
class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the Call object returned by unary-stream multicallables."""

    async def test_call_rpc_error(self):
        # Use the channel as an async context manager so it is closed even
        # if one of the assertions below fails.
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            request = messages_pb2.StreamingOutputCallRequest()
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.StreamingOutputCall(request)

            with self.assertRaises(aio.AioRpcError) as exception_context:
                async for _ in call:
                    pass

            self.assertEqual(
                grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()
            )

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_cancel_unary_stream(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(
            _LOCAL_CANCEL_DETAILS_EXPECTATION, await call.details()
        )
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()
        self.assertTrue(call.cancelled())

    async def test_multiple_cancel_unary_stream(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Only the first cancel() takes effect; later ones report False.
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

    async def test_early_cancel_unary_stream(self):
        """Test cancellation before receiving messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

        self.assertTrue(call.cancelled())

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(
            _LOCAL_CANCEL_DETAILS_EXPECTATION, await call.details()
        )

    async def test_late_cancel_unary_stream(self):
        """Test cancellation after receiving all messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # After all messages received, it is possible that the final state
        # is received or on its way. It's basically a data race, so our
        # expectation here is do not crash :)
        call.cancel()
        self.assertIn(
            await call.code(), [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED]
        )

    async def test_too_many_reads_unary_stream(self):
        """Test that reading past the last message returns EOF."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        # After the RPC is finished, further reads keep returning EOF.
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        self.assertIs(await call.read(), aio.EOF)

    async def test_unary_stream_async_generator(self):
        """Sunny day test case for unary_stream."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        async for response in call:
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_cancel_unary_stream_with_many_interleavings(self):
        """A cheap alternative to a structured fuzzer.

        Certain classes of error only appear for very specific interleavings of
        coroutines. Rather than inserting semi-private asyncio.Events throughout
        the implementation on which to coordinate and explicitly waiting on those
        in tests, we instead search for bugs over the space of interleavings by
        stochastically varying the durations of certain events within the test.
        """

        # We range over several orders of magnitude to ensure that switching
        # platforms (e.g. to slow CI machines) does not result in this test
        # becoming a no-op.
        sleep_ranges = (10.0**-i for i in range(1, 4))
        for sleep_range in sleep_ranges:
            for _ in range(_NONDETERMINISTIC_ITERATIONS):
                interval_us = random.randrange(
                    _NONDETERMINISTIC_SERVER_SLEEP_MAX_US
                )
                sleep_secs = sleep_range * random.random()

                coro_started = asyncio.Event()

                # One response, delayed server-side by a random interval
                # below _NONDETERMINISTIC_SERVER_SLEEP_MAX_US.
                request = messages_pb2.StreamingOutputCallRequest()
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=1,
                        interval_us=interval_us,
                    )
                )

                # Invokes the actual RPC
                call = self._stub.StreamingOutputCall(request)

                unhandled_error = False

                async def another_coro():
                    nonlocal unhandled_error
                    coro_started.set()
                    try:
                        await call.read()
                    except asyncio.CancelledError:
                        pass
                    except Exception:
                        # Record the failure for the assertion below, and
                        # re-raise so the original traceback is not lost.
                        unhandled_error = True
                        raise

                task = self.loop.create_task(another_coro())
                await coro_started.wait()
                await asyncio.sleep(sleep_secs)

                task.cancel()

                try:
                    await task
                except asyncio.CancelledError:
                    pass

                self.assertFalse(unhandled_error)

    async def test_cancel_unary_stream_in_task_using_read(self):
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            )
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            await call.read()

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_cancel_unary_stream_in_task_using_async_for(self):
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            )
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            async for _ in call:
                pass

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_time_remaining(self):
        request = messages_pb2.StreamingOutputCallRequest()
        # First message comes back immediately
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
            )
        )
        # Second message comes back after a unit of wait time
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            )
        )

        call = self._stub.StreamingOutputCall(
            request, timeout=_SHORT_TIMEOUT_S * 2
        )

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the same as the timeout
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the timeout minus a unit of wait time
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2)

        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_empty_responses(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters()
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(b"", response.SerializeToString())

        self.assertEqual(grpc.StatusCode.OK, await call.code())
587
588
class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the Call object returned by stream-unary multicallables."""

    async def test_cancel_stream_unary(self):
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        # Half-closing an already-cancelled call must not raise.
        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_early_cancel_stream_unary(self):
        call = self._stub.StreamingInputCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        # Should be no-op
        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_write_after_done_writing(self):
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Half-closes the request stream; any later write is rejected.
        await call.done_writing()

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(
            _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
            response.aggregated_payload_size,
        )

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        # Prepares the request. StreamingInputCall consumes
        # StreamingInputCallRequest messages (not StreamingOutputCallRequest).
        payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            # Produces requests slowly enough that the RPC is cancelled while
            # this iterator is still running.
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.StreamingInputCall(request_iterator())

        # Cancel the RPC after at least one request has been sent
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            await call

        await request_iterator_received_the_exception.wait()

        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        # Prepares the request
        payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)
        requests = [request] * _NUM_STREAM_RESPONSES

        # Sends out requests
        call = self._stub.StreamingInputCall(requests)

        # RPC should succeed
        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(
            _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
            response.aggregated_payload_size,
        )

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_rpc_error(self):
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # The error should be raised automatically without any traffic.
            call = stub.StreamingInputCall()
            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(
                grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()
            )

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_timeout(self):
        call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S)

        # The error should be raised automatically without any traffic.
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code())
        self.assertTrue(call.done())
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await call.code())
739
740
# Requests for ping-pong style streaming: each asks the server for exactly
# one response, either of _RESPONSE_PAYLOAD_SIZE bytes or with default
# (empty) response parameters.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest(
    response_parameters=[
        messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
    ]
)
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE = (
    messages_pb2.StreamingOutputCallRequest(
        response_parameters=[messages_pb2.ResponseParameters()]
    )
)
752
753
754class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
755    async def test_cancel(self):
756        # Invokes the actual RPC
757        call = self._stub.FullDuplexCall()
758
759        for _ in range(_NUM_STREAM_RESPONSES):
760            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
761            response = await call.read()
762            self.assertIsInstance(
763                response, messages_pb2.StreamingOutputCallResponse
764            )
765            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
766
767        # Cancels the RPC
768        self.assertFalse(call.done())
769        self.assertFalse(call.cancelled())
770        self.assertTrue(call.cancel())
771        self.assertTrue(call.cancelled())
772        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
773
774    async def test_cancel_with_pending_read(self):
775        call = self._stub.FullDuplexCall()
776
777        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
778
779        # Cancels the RPC
780        self.assertFalse(call.done())
781        self.assertFalse(call.cancelled())
782        self.assertTrue(call.cancel())
783        self.assertTrue(call.cancelled())
784        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
785
786    async def test_cancel_with_ongoing_read(self):
787        call = self._stub.FullDuplexCall()
788        coro_started = asyncio.Event()
789
790        async def read_coro():
791            coro_started.set()
792            await call.read()
793
794        read_task = self.loop.create_task(read_coro())
795        await coro_started.wait()
796        self.assertFalse(read_task.done())
797
798        # Cancels the RPC
799        self.assertFalse(call.done())
800        self.assertFalse(call.cancelled())
801        self.assertTrue(call.cancel())
802        self.assertTrue(call.cancelled())
803        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
804
805    async def test_early_cancel(self):
806        call = self._stub.FullDuplexCall()
807
808        # Cancels the RPC
809        self.assertFalse(call.done())
810        self.assertFalse(call.cancelled())
811        self.assertTrue(call.cancel())
812        self.assertTrue(call.cancelled())
813        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
814
815    async def test_cancel_after_done_writing(self):
816        call = self._stub.FullDuplexCall()
817        request_with_delay = messages_pb2.StreamingOutputCallRequest()
818        request_with_delay.response_parameters.append(
819            messages_pb2.ResponseParameters(interval_us=10000)
820        )
821        await call.write(request_with_delay)
822        await call.write(request_with_delay)
823        await call.done_writing()
824
825        # Cancels the RPC
826        self.assertFalse(call.cancelled())
827        self.assertTrue(call.cancel())
828        self.assertTrue(call.cancelled())
829        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
830
831    async def test_late_cancel(self):
832        call = self._stub.FullDuplexCall()
833        await call.done_writing()
834        self.assertEqual(grpc.StatusCode.OK, await call.code())
835
836        # Cancels the RPC
837        self.assertTrue(call.done())
838        self.assertFalse(call.cancelled())
839        self.assertFalse(call.cancel())
840        self.assertFalse(call.cancelled())
841
842        # Status is still OK
843        self.assertEqual(grpc.StatusCode.OK, await call.code())
844
845    async def test_async_generator(self):
846        async def request_generator():
847            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
848            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
849
850        call = self._stub.FullDuplexCall(request_generator())
851        async for response in call:
852            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
853
854        self.assertEqual(await call.code(), grpc.StatusCode.OK)
855
856    async def test_too_many_reads(self):
857        async def request_generator():
858            for _ in range(_NUM_STREAM_RESPONSES):
859                yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
860
861        call = self._stub.FullDuplexCall(request_generator())
862        for _ in range(_NUM_STREAM_RESPONSES):
863            response = await call.read()
864            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
865        self.assertIs(await call.read(), aio.EOF)
866
867        self.assertEqual(await call.code(), grpc.StatusCode.OK)
868        # After the RPC finished, the read should also produce EOF
869        self.assertIs(await call.read(), aio.EOF)
870
871    async def test_read_write_after_done_writing(self):
872        call = self._stub.FullDuplexCall()
873
874        # Writes two requests, and pending two requests
875        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
876        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
877        await call.done_writing()
878
879        # Further write should fail
880        with self.assertRaises(asyncio.InvalidStateError):
881            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
882
883        # But read should be unaffected
884        response = await call.read()
885        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
886        response = await call.read()
887        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
888
889        self.assertEqual(await call.code(), grpc.StatusCode.OK)
890
    async def test_error_in_async_generator(self):
        """Cancelling the RPC raises CancelledError into the request iterator.

        The call is cancelled while its async request iterator is still
        producing requests. Both the response iteration and the request
        iterator are expected to observe ``asyncio.CancelledError``.
        """
        # Server will pause between responses
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            )
        )

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            # assertRaises both checks for and swallows the CancelledError
            # raised inside this generator. If the loop finishes without
            # being interrupted, assertRaises fails instead and the event
            # below is never set.
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.FullDuplexCall(request_iterator())

        # Cancel the RPC after at least one response
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        # Iterating the responses is interrupted by the cancellation.
        with self.assertRaises(asyncio.CancelledError):
            async for response in call:
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )

        # Block until request_iterator() has seen the CancelledError.
        await request_iterator_received_the_exception.wait()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        # No failures in the cancel later task!
        await cancel_later_task
931
932    async def test_normal_iterable_requests(self):
933        requests = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES
934
935        call = self._stub.FullDuplexCall(iter(requests))
936        async for response in call:
937            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
938
939        self.assertEqual(await call.code(), grpc.StatusCode.OK)
940
941    async def test_empty_ping_pong(self):
942        call = self._stub.FullDuplexCall()
943        for _ in range(_NUM_STREAM_RESPONSES):
944            await call.write(_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE)
945            response = await call.read()
946            self.assertEqual(b"", response.SerializeToString())
947        await call.done_writing()
948        self.assertEqual(await call.code(), grpc.StatusCode.OK)
949
950
if __name__ == "__main__":
    # When run directly: enable DEBUG logging and verbose per-test output.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
954