• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Tests behavior of the Call classes."""
15
16import asyncio
17import logging
18import unittest
19import datetime
20
21import grpc
22from grpc.experimental import aio
23
24from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
25from tests_aio.unit._test_base import AioTestBase
26from tests_aio.unit._test_server import start_test_server
27from tests_aio.unit._constants import UNREACHABLE_TARGET
28
# Base unit of time (seconds) for deadlines and server-side delays.
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds()

_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
# One short timeout expressed in microseconds, for ResponseParameters.
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
# Largest signed 32-bit integer; used to make the server wait "forever".
_INFINITE_INTERVAL_US = (1 << 31) - 1
37
38
class _MulticallableTestMixin:
    """Gives each test a running test server plus a channel and stub to it."""

    async def setUp(self):
        # A fresh server per test keeps state from leaking between cases.
        server_address, server = await start_test_server()
        self._server = server
        self._channel = aio.insecure_channel(server_address)
        self._stub = test_pb2_grpc.TestServiceStub(self._channel)

    async def tearDown(self):
        # Close the client side first, then stop the server with no grace
        # period.
        await self._channel.close()
        await self._server.stop(None)
49
50
class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a unary-unary multicallable."""

    async def test_call_to_string(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        # str()/repr() must work while the RPC is still in flight...
        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

        await call

        # ...and after it has completed.
        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

    async def test_call_ok(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.done())

        response = await call

        self.assertTrue(call.done())
        self.assertIsInstance(response, messages_pb2.SimpleResponse)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        # The response is cached on the call object, so awaiting the call
        # again returns the very same response object.
        response_retry = await call
        self.assertIs(response, response_retry)

    async def test_call_rpc_error(self):
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            call = stub.UnaryCall(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_call_code_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_details_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual('', await call.details())

    async def test_call_initial_metadata_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_trailing_metadata_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.trailing_metadata())

    async def test_call_initial_metadata_cancelable(self):
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.initial_metadata()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Initial metadata can still be requested even though a previous
        # task waiting on it was cancelled.
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_initial_metadata_multiple_waiters(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.initial_metadata()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call
        expected = [aio.Metadata() for _ in range(2)]
        self.assertEqual(expected, await asyncio.gather(task1, task2))

    async def test_call_code_cancelable(self):
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.code()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # The status code can still be requested even though a previous
        # task waiting on it was cancelled.
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_call_code_multiple_waiters(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.code()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call

        self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await
                         asyncio.gather(task1, task2))

    async def test_cancel_unary_unary(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.cancelled())

        # Only the first cancel() has an effect.
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call

        # The info in the RpcError should match the info in Call object.
        self.assertTrue(call.cancelled())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
        self.assertEqual(await call.details(),
                         _LOCAL_CANCEL_DETAILS_EXPECTATION)

    async def test_cancel_unary_unary_in_task(self):
        coro_started = asyncio.Event()
        call = self._stub.EmptyCall(messages_pb2.SimpleRequest())

        async def another_coro():
            coro_started.set()
            await call

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        # Cancelling the task awaiting the call also cancels the RPC.
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_passing_credentials_fails_over_insecure_channel(self):
        call_credentials = grpc.composite_call_credentials(
            grpc.access_token_call_credentials("abc"),
            grpc.access_token_call_credentials("def"),
        )
        with self.assertRaisesRegex(
                aio.UsageError,
                "Call credentials are only valid on secure channels"):
            self._stub.UnaryCall(messages_pb2.SimpleRequest(),
                                 credentials=call_credentials)
216
217
class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a unary-stream multicallable."""

    async def test_call_rpc_error(self):
        # `async with` guarantees the channel is closed even when an
        # assertion below fails.
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            request = messages_pb2.StreamingOutputCallRequest()
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.StreamingOutputCall(request)

            with self.assertRaises(aio.AioRpcError) as exception_context:
                async for _ in call:
                    pass

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_cancel_unary_stream(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()
        self.assertTrue(call.cancelled())

    async def test_multiple_cancel_unary_stream(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Only the first cancel() has an effect.
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

    async def test_early_cancel_unary_stream(self):
        """Test cancellation before receiving messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

        self.assertTrue(call.cancelled())

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())

    async def test_late_cancel_unary_stream(self):
        """Test cancellation after received all messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # After all messages are received, the final status may already
        # have arrived or may still be on its way. That is a data race, so
        # the only expectation here is not to crash.
        call.cancel()
        self.assertIn(await call.code(),
                      [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])

    async def test_too_many_reads_unary_stream(self):
        """Test calling read after all messages were received returns EOF."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        # After the RPC has finished, further reads keep returning EOF.
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        self.assertIs(await call.read(), aio.EOF)

    async def test_unary_stream_async_generator(self):
        """Sunny day test case for unary_stream."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        async for response in call:
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_cancel_unary_stream_in_task_using_read(self):
        coro_started = asyncio.Event()

        # Configures the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            await call.read()

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_cancel_unary_stream_in_task_using_async_for(self):
        coro_started = asyncio.Event()

        # Configures the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            async for _ in call:
                pass

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_time_remaining(self):
        request = messages_pb2.StreamingOutputCallRequest()
        # First message comes back immediately
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
        # Second message comes back after a unit of wait time
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        call = self._stub.StreamingOutputCall(request,
                                              timeout=_SHORT_TIMEOUT_S * 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the same as the timeout
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the timeout minus a unit of wait time
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2)

        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_empty_responses(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters())

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(b'', response.SerializeToString())

        self.assertEqual(grpc.StatusCode.OK, await call.code())
492
493
class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a stream-unary multicallable."""

    async def test_cancel_stream_unary(self):
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_early_cancel_stream_unary(self):
        call = self._stub.StreamingInputCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        # Should be no-op
        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_write_after_done_writing(self):
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Half-closes the request stream
        await call.done_writing()

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        # Prepares the request the generator below streams repeatedly. It
        # must be a StreamingInputCallRequest, the type this RPC expects.
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    # Pause between requests so the cancellation below lands
                    # while the generator is still producing.
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.StreamingInputCall(request_iterator())

        # Cancel the RPC after at least one request has been sent
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            await call

        await request_iterator_received_the_exception.wait()

        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)
        requests = [request] * _NUM_STREAM_RESPONSES

        # Sends out requests
        call = self._stub.StreamingInputCall(requests)

        # RPC should succeed
        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_rpc_error(self):
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # The error should be raised automatically without any traffic.
            call = stub.StreamingInputCall()
            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_timeout(self):
        call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S)

        # The error should be raised automatically without any traffic.
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code())
        self.assertTrue(call.done())
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await call.code())
639
640
# Requests for the ping-pong style full-duplex tests: each one carries a
# single ResponseParameters entry, i.e. asks the server for one response.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest(
    response_parameters=[
        messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE),
    ])
# Same shape, but the single response is requested with default (empty)
# parameters.
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE = (
    messages_pb2.StreamingOutputCallRequest(
        response_parameters=[messages_pb2.ResponseParameters()]))
649
650
class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a stream-stream multicallable."""

    async def test_cancel(self):
        # Invokes the actual RPC
        call = self._stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
            response = await call.read()
            self.assertIsInstance(response,
                                  messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_with_pending_read(self):
        call = self._stub.FullDuplexCall()

        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_with_ongoing_read(self):
        call = self._stub.FullDuplexCall()
        coro_started = asyncio.Event()

        async def read_coro():
            coro_started.set()
            await call.read()

        read_task = self.loop.create_task(read_coro())
        await coro_started.wait()
        self.assertFalse(read_task.done())

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        # Reap the reader task so it does not outlive the test. Its exact
        # outcome is not what this test checks, hence return_exceptions.
        await asyncio.gather(read_task, return_exceptions=True)

    async def test_early_cancel(self):
        call = self._stub.FullDuplexCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_after_done_writing(self):
        call = self._stub.FullDuplexCall()
        await call.done_writing()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_late_cancel(self):
        call = self._stub.FullDuplexCall()
        await call.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Cancelling a finished RPC has no effect
        self.assertTrue(call.done())
        self.assertFalse(call.cancelled())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancelled())

        # Status is still OK
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_async_generator(self):

        async def request_generator():
            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        call = self._stub.FullDuplexCall(request_generator())
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_too_many_reads(self):

        async def request_generator():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        call = self._stub.FullDuplexCall(request_generator())
        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        # After the RPC finished, the read should also produce EOF
        self.assertIs(await call.read(), aio.EOF)

    async def test_read_write_after_done_writing(self):
        call = self._stub.FullDuplexCall()

        # Writes two requests and half-closes, leaving two responses unread
        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await call.done_writing()

        # Further write should fail
        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # But read should be unaffected
        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        # Server will pause between responses
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.FullDuplexCall(request_iterator())

        # Cancel the RPC after at least one response
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            async for response in call:
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))

        await request_iterator_received_the_exception.wait()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        requests = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES

        call = self._stub.FullDuplexCall(iter(requests))
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_empty_ping_pong(self):
        call = self._stub.FullDuplexCall()
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE)
            response = await call.read()
            self.assertEqual(b'', response.SerializeToString())
        await call.done_writing()
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
841
842
if __name__ == '__main__':
    # When run directly: debug-level logging plus per-test output.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
846