• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Invocation-side implementation of gRPC Asyncio Python."""
15
16import asyncio
17import enum
18import inspect
19import logging
20from functools import partial
21from typing import AsyncIterable, Optional, Tuple
22
23import grpc
24from grpc import _common
25from grpc._cython import cygrpc
26
27from . import _base_call
28from ._metadata import Metadata
29from ._typing import (DeserializingFunction, DoneCallbackType, MetadatumType,
30                      RequestIterableType, RequestType, ResponseType,
31                      SerializingFunction)
32
# Every public call type defined in this module is exported.
__all__ = ('AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall',
           'StreamUnaryCall', 'StreamStreamCall')

# Details passed to the Cython call when the client side cancels an RPC,
# either explicitly (Call.cancel) or when the Call object is finalized.
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!'
# Messages of the asyncio.InvalidStateError raised by the request-writing path.
_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.'
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
# Message of the cygrpc.UsageError raised when iteration and read()/write()
# are mixed on one RPC.
_API_STYLE_ERROR = ('The iterator and read/write APIs may not be mixed on a '
                    'single RPC.')

# repr() templates; the non-OK variant also includes the debug error string.
# (The OK variant is not referenced in this part of the module.)
_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
                           '\tstatus = {}\n'
                           '\tdetails = "{}"\n'
                           '>')

_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
                               '\tstatus = {}\n'
                               '\tdetails = "{}"\n'
                               '\tdebug_error_string = "{}"\n'
                               '>')

_LOGGER = logging.getLogger(__name__)
53
54
class AioRpcError(grpc.RpcError):
    """An implementation of RpcError to be used by the asynchronous API.

    A raised RpcError is a snapshot of the final status of the RPC: its values
    are already determined, hence its methods do not need to be coroutines.
    """

    _code: grpc.StatusCode
    _details: Optional[str]
    _initial_metadata: Optional[Metadata]
    _trailing_metadata: Optional[Metadata]
    _debug_error_string: Optional[str]

    def __init__(self,
                 code: grpc.StatusCode,
                 initial_metadata: Metadata,
                 trailing_metadata: Metadata,
                 details: Optional[str] = None,
                 debug_error_string: Optional[str] = None) -> None:
        """Constructor.

        Args:
          code: The status code with which the RPC has been finalized.
          initial_metadata: Initial metadata that could be sent by the Server.
          trailing_metadata: Trailing metadata that could be sent by the
            Server.
          details: Optional details explaining the reason of the error.
          debug_error_string: Optional debug information about the failure.
        """
        # Do not pass ``self`` to the base constructor: storing the instance
        # inside its own ``args`` creates a reference cycle and makes pickle
        # recurse forever. Pickling is handled by ``__reduce__`` instead.
        super().__init__()
        self._code = code
        self._details = details
        self._initial_metadata = initial_metadata
        self._trailing_metadata = trailing_metadata
        self._debug_error_string = debug_error_string

    def code(self) -> grpc.StatusCode:
        """Accesses the status code sent by the server.

        Returns:
          The `grpc.StatusCode` status code.
        """
        return self._code

    def details(self) -> Optional[str]:
        """Accesses the details sent by the server.

        Returns:
          The description of the error, or None if none was provided.
        """
        return self._details

    def initial_metadata(self) -> Metadata:
        """Accesses the initial metadata sent by the server.

        Returns:
          The initial metadata received.
        """
        return self._initial_metadata

    def trailing_metadata(self) -> Metadata:
        """Accesses the trailing metadata sent by the server.

        Returns:
          The trailing metadata received.
        """
        return self._trailing_metadata

    def debug_error_string(self) -> Optional[str]:
        """Accesses the debug error string sent by the server.

        Returns:
          The debug error string received, or None if none was provided.
        """
        return self._debug_error_string

    def _repr(self) -> str:
        """Assembles the error string for the RPC error."""
        return _NON_OK_CALL_REPRESENTATION.format(self.__class__.__name__,
                                                  self._code, self._details,
                                                  self._debug_error_string)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()

    def __reduce__(self):
        """Makes the error picklable by rebuilding it through the constructor."""
        return (type(self), (self._code, self._initial_metadata,
                             self._trailing_metadata, self._details,
                             self._debug_error_string))
142
143
def _create_rpc_error(initial_metadata: Metadata,
                      status: cygrpc.AioRpcStatus) -> AioRpcError:
    """Snapshots a finished RPC into an AioRpcError.

    Args:
      initial_metadata: Initial metadata of the RPC, in a form accepted by
        ``Metadata.from_tuple``.
      status: The final status object of the RPC.

    Returns:
      An AioRpcError holding the translated status code, the initial and
      trailing metadata, the details and the debug error string.
    """
    code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
    initial = Metadata.from_tuple(initial_metadata)
    trailing = Metadata.from_tuple(status.trailing_metadata())
    return AioRpcError(code,
                       initial,
                       trailing,
                       details=status.details(),
                       debug_error_string=status.debug_error_string())
153
154
class Call:
    """Base implementation of client RPC Call object.

    Wraps the Cython call object together with the request serializer and
    response deserializer. Implements the behaviour shared by every call type:
    final status, metadata retrieval and cancellation.
    """
    _loop: asyncio.AbstractEventLoop
    _code: grpc.StatusCode
    _cython_call: cygrpc._AioCall
    _metadata: Tuple[MetadatumType]
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction

    def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata,
                 request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._cython_call = cython_call
        # Frozen into a tuple; subclasses pass it to the Cython call when they
        # start the RPC.
        self._metadata = tuple(metadata)
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __del__(self) -> None:
        # '_cython_call' may be missing at finalization time (it might be
        # destructed before this object, or construction did not complete).
        # A call that is still running when collected gets cancelled.
        if hasattr(self, '_cython_call') and not self._cython_call.done():
            self._cancel(_GC_CANCELLATION_DETAILS)

    def cancelled(self) -> bool:
        """Returns whether the underlying Cython call reports cancellation."""
        return self._cython_call.cancelled()

    def _cancel(self, details: str) -> bool:
        """Cancels the RPC with ``details`` unless it has already finished.

        Returns:
          True if a cancellation was issued, False if the RPC was already done.
        """
        if self._cython_call.done():
            return False
        self._cython_call.cancel(details)
        return True

    def cancel(self) -> bool:
        """Cancels the RPC on behalf of the application."""
        return self._cancel(_LOCAL_CANCELLATION_DETAILS)

    def done(self) -> bool:
        """Returns whether the underlying Cython call has finished."""
        return self._cython_call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Registers ``callback``; it is invoked with this Call as argument."""
        self._cython_call.add_done_callback(partial(callback, self))

    def time_remaining(self) -> Optional[float]:
        """Returns what the Cython call reports as the remaining time."""
        return self._cython_call.time_remaining()

    async def initial_metadata(self) -> Metadata:
        """Awaits and returns the initial metadata of the RPC."""
        return Metadata.from_tuple(await self._cython_call.initial_metadata())

    async def trailing_metadata(self) -> Metadata:
        """Awaits the final status and returns its trailing metadata."""
        status = await self._cython_call.status()
        return Metadata.from_tuple(status.trailing_metadata())

    async def code(self) -> grpc.StatusCode:
        """Awaits the final status and returns it as a `grpc.StatusCode`."""
        status = await self._cython_call.status()
        return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]

    async def details(self) -> str:
        """Awaits the final status and returns its details."""
        status = await self._cython_call.status()
        return status.details()

    async def debug_error_string(self) -> str:
        """Awaits the final status and returns its debug error string."""
        status = await self._cython_call.status()
        return status.debug_error_string()

    async def _raise_for_status(self) -> None:
        """Raises if the RPC did not end successfully.

        Raises:
          asyncio.CancelledError: The call was cancelled locally.
          AioRpcError: The final status code is not OK.
        """
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        if await self.code() == grpc.StatusCode.OK:
            return
        initial_metadata = await self.initial_metadata()
        status = await self._cython_call.status()
        raise _create_rpc_error(initial_metadata, status)

    def _repr(self) -> str:
        return repr(self._cython_call)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()
242
243
244class _APIStyle(enum.IntEnum):
245    UNKNOWN = 0
246    ASYNC_GENERATOR = 1
247    READER_WRITER = 2
248
249
class _UnaryResponseMixin(Call):
    """Single-response behaviour: ``await call`` and response-task cancel."""
    _call_response: asyncio.Task

    def _init_unary_response_mixin(self, response_task: asyncio.Task):
        # The task is expected to resolve to the deserialized response, or to
        # cygrpc.EOF when the RPC did not succeed.
        self._call_response = response_task

    def cancel(self) -> bool:
        if not super().cancel():
            return False
        self._call_response.cancel()
        return True

    def __await__(self) -> ResponseType:
        """Wait till the ongoing RPC request finishes."""
        try:
            response = yield from self._call_response
        except asyncio.CancelledError:
            # Even if we caught all other CancelledError, there is still this
            # corner case: if the application cancels immediately after the
            # Call object is created, this `CancelledError` is observed here.
            if not self.cancelled():
                self.cancel()
            raise

        if response is not cygrpc.EOF:
            return response

        # NOTE(lidiz) Raising RpcError inside the task would make asyncio log
        # 'Task exception was never retrieved' when users never await it, so
        # the error is raised here instead. Only one 'yield from' is possible
        # in '__await__', hence the private Cython attributes are read directly.
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        raise _create_rpc_error(self._cython_call._initial_metadata,
                                self._cython_call._status)
289
290
class _StreamResponseMixin(Call):
    """Response-streaming behaviour for unary-stream and stream-stream calls.

    Responses can be consumed either with ``async for`` or with explicit
    ``read()`` calls; the two styles may not be mixed on one RPC.
    """
    _message_aiter: Optional[AsyncIterable[ResponseType]]
    _preparation: asyncio.Task
    _response_style: _APIStyle

    def _init_stream_response_mixin(self, preparation: asyncio.Task):
        # ``preparation`` is awaited before every read (see _read).
        self._message_aiter = None
        self._preparation = preparation
        self._response_style = _APIStyle.UNKNOWN

    def _update_response_style(self, style: _APIStyle):
        """Pins the response style on first use; rejects mixing styles."""
        if self._response_style is _APIStyle.UNKNOWN:
            self._response_style = style
        elif self._response_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        if super().cancel():
            self._preparation.cancel()
            return True
        else:
            return False

    async def _fetch_stream_responses(self) -> AsyncIterable[ResponseType]:
        """Yields response messages until EOF, then raises on failed status."""
        message = await self._read()
        while message is not cygrpc.EOF:
            yield message
            message = await self._read()

        # If the read operation failed, Core should explain why.
        await self._raise_for_status()

    def __aiter__(self) -> AsyncIterable[ResponseType]:
        self._update_response_style(_APIStyle.ASYNC_GENERATOR)
        if self._message_aiter is None:
            self._message_aiter = self._fetch_stream_responses()
        return self._message_aiter

    async def _read(self) -> ResponseType:
        """Reads one message; returns the deserialized message or cygrpc.EOF."""
        # Wait for the request being sent
        await self._preparation

        # Reads response message from Core
        try:
            raw_response = await self._cython_call.receive_serialized_message()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            await self._raise_for_status()
            # _raise_for_status() returns normally when the RPC already
            # finished with OK. Propagate the cancellation instead of falling
            # through with ``raw_response`` unbound.
            raise

        if raw_response is cygrpc.EOF:
            return cygrpc.EOF
        else:
            return _common.deserialize(raw_response,
                                       self._response_deserializer)

    async def read(self) -> ResponseType:
        """Reads one response message with the read/write API style.

        Returns:
          The next response message, or cygrpc.EOF once the stream has ended.

        Raises:
          cygrpc.UsageError: Async iteration is already used on this RPC.
          AioRpcError: The RPC finished with a non-OK status.
          asyncio.CancelledError: The RPC was cancelled locally.
        """
        if self.done():
            await self._raise_for_status()
            return cygrpc.EOF
        self._update_response_style(_APIStyle.READER_WRITER)

        response_message = await self._read()

        if response_message is cygrpc.EOF:
            # If the read operation failed, Core should explain why.
            await self._raise_for_status()
        return response_message
359
360
class _StreamRequestMixin(Call):
    """Request-streaming behaviour for stream-unary and stream-stream calls.

    Requests come either from a request iterator consumed by a background task,
    or from explicit ``write()`` / ``done_writing()`` calls; the two styles may
    not be mixed on one RPC.
    """
    _metadata_sent: asyncio.Event
    _done_writing_flag: bool
    _async_request_poller: Optional[asyncio.Task]
    _request_style: _APIStyle

    def _init_stream_request_mixin(
            self, request_iterator: Optional[RequestIterableType]):
        try:
            # Python < 3.10: bind the event to the call's loop explicitly.
            self._metadata_sent = asyncio.Event(loop=self._loop)
        except TypeError:
            # Python >= 3.10 removed the ``loop`` argument; the event binds to
            # the running loop the first time it is used.
            self._metadata_sent = asyncio.Event()
        self._done_writing_flag = False

        # If user passes in an async iterator, create a consumer Task.
        if request_iterator is not None:
            self._async_request_poller = self._loop.create_task(
                self._consume_request_iterator(request_iterator))
            self._request_style = _APIStyle.ASYNC_GENERATOR
        else:
            self._async_request_poller = None
            self._request_style = _APIStyle.READER_WRITER

    def _raise_for_different_style(self, style: _APIStyle):
        if self._request_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        if super().cancel():
            if self._async_request_poller is not None:
                self._async_request_poller.cancel()
            return True
        else:
            return False

    def _metadata_sent_observer(self):
        """Callback handed to the Cython call; marks metadata as sent."""
        self._metadata_sent.set()

    async def _consume_request_iterator(self,
                                        request_iterator: RequestIterableType
                                       ) -> None:
        """Writes every request from ``request_iterator``, then half-closes."""
        try:
            if inspect.isasyncgen(request_iterator) or hasattr(
                    request_iterator, '__aiter__'):
                async for request in request_iterator:
                    await self._write(request)
            else:
                for request in request_iterator:
                    await self._write(request)

            await self._done_writing()
        except (AioRpcError, asyncio.InvalidStateError) as rpc_error:
            # Rpc status should be exposed through other API. Exceptions raised
            # within this Task won't be retrieved by another coroutine. It's
            # better to suppress the error than spamming users' screen.
            # InvalidStateError here means the RPC finished before the
            # iterator was exhausted (done_writing() is not reachable in this
            # style).
            _LOGGER.debug('Exception while consuming the request_iterator: %s',
                          rpc_error)

    async def _write(self, request: RequestType) -> None:
        if self.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        if self._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
        if not self._metadata_sent.is_set():
            await self._metadata_sent.wait()
            if self.done():
                await self._raise_for_status()
                # The RPC finished with OK while we waited; writing is no
                # longer possible, same as the check at the top.
                raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        serialized_request = _common.serialize(request,
                                               self._request_serializer)
        try:
            await self._cython_call.send_serialized_message(serialized_request)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            await self._raise_for_status()
            # Do not swallow the cancellation when the RPC ended with OK.
            raise

    async def _done_writing(self) -> None:
        if self.done():
            # If the RPC is finished, do nothing.
            return
        if not self._done_writing_flag:
            # If the done writing is not sent before, try to send it.
            self._done_writing_flag = True
            try:
                await self._cython_call.send_receive_close()
            except asyncio.CancelledError:
                if not self.cancelled():
                    self.cancel()
                await self._raise_for_status()
                # Do not swallow the cancellation when the RPC ended with OK.
                raise

    async def write(self, request: RequestType) -> None:
        """Sends one request message (read/write API style only)."""
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._write(request)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._done_writing()

    async def wait_for_connection(self) -> None:
        """Waits until metadata is sent; raises if the RPC already failed."""
        await self._metadata_sent.wait()
        if self.done():
            await self._raise_for_status()
465
466
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
    """Object for managing unary-unary RPC calls.

    Returned when an instance of `UnaryUnaryMultiCallable` object is called.
    """
    _request: RequestType
    _invocation_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(self, request: RequestType, deadline: Optional[float],
                 metadata: Metadata,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata, request_serializer, response_deserializer, loop)
        self._request = request
        self._invocation_task = loop.create_task(self._invoke())
        self._init_unary_response_mixin(self._invocation_task)

    async def _invoke(self) -> ResponseType:
        """Runs the RPC; returns the deserialized response or cygrpc.EOF."""
        serialized_request = _common.serialize(self._request,
                                               self._request_serializer)

        # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
        # because the asyncio.Task class do not cache the exception object.
        # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
        try:
            serialized_response = await self._cython_call.unary_unary(
                serialized_request, self._metadata)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            if self._cython_call.is_ok():
                # Core reports success but no response was received here;
                # propagate the cancellation instead of reading the unbound
                # ``serialized_response``.
                raise
            return cygrpc.EOF

        if self._cython_call.is_ok():
            return _common.deserialize(serialized_response,
                                       self._response_deserializer)
        return cygrpc.EOF

    async def wait_for_connection(self) -> None:
        """Waits for the invocation task; raises if the RPC failed."""
        await self._invocation_task
        if self.done():
            await self._raise_for_status()
514
515
class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
    """Object for managing unary-stream RPC calls.

    Returned when an instance of `UnaryStreamMultiCallable` object is called.
    """
    _request: RequestType
    _send_unary_request_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(self, request: RequestType, deadline: Optional[float],
                 metadata: Metadata,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        cython_call = channel.call(method, deadline, credentials,
                                   wait_for_ready)
        super().__init__(cython_call, metadata, request_serializer,
                         response_deserializer, loop)
        self._request = request
        send_task = loop.create_task(self._send_unary_request())
        self._send_unary_request_task = send_task
        # Reads wait for the request to be sent before receiving messages.
        self._init_stream_response_mixin(send_task)

    async def _send_unary_request(self) -> ResponseType:
        """Serializes the single request and starts the RPC in Core."""
        payload = _common.serialize(self._request, self._request_serializer)
        try:
            await self._cython_call.initiate_unary_stream(
                payload, self._metadata)
        except asyncio.CancelledError:
            # Make sure the RPC itself is cancelled, then let the cancellation
            # propagate to whoever awaits this task.
            if not self.cancelled():
                self.cancel()
            raise

    async def wait_for_connection(self) -> None:
        """Waits until the request is sent; raises if the RPC failed."""
        await self._send_unary_request_task
        if self.done():
            await self._raise_for_status()
555
556
class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
                      _base_call.StreamUnaryCall):
    """Object for managing stream-unary RPC calls.

    Returned when an instance of `StreamUnaryMultiCallable` object is called.
    """

    # pylint: disable=too-many-arguments
    def __init__(self, request_iterator: Optional[RequestIterableType],
                 deadline: Optional[float], metadata: Metadata,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata, request_serializer, response_deserializer, loop)

        self._init_stream_request_mixin(request_iterator)
        self._init_unary_response_mixin(loop.create_task(self._conduct_rpc()))

    async def _conduct_rpc(self) -> ResponseType:
        """Runs the RPC; returns the deserialized response or cygrpc.EOF."""
        try:
            serialized_response = await self._cython_call.stream_unary(
                self._metadata, self._metadata_sent_observer)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            if self._cython_call.is_ok():
                # Core reports success but no response was received here;
                # propagate the cancellation instead of reading the unbound
                # ``serialized_response``.
                raise
            return cygrpc.EOF

        if self._cython_call.is_ok():
            return _common.deserialize(serialized_response,
                                       self._response_deserializer)
        return cygrpc.EOF
592
593
class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
                       _base_call.StreamStreamCall):
    """Object for managing stream-stream RPC calls.

    Returned when an instance of `StreamStreamMultiCallable` object is called.
    """
    _initializer: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(self, request_iterator: Optional[RequestIterableType],
                 deadline: Optional[float], metadata: Metadata,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        cython_call = channel.call(method, deadline, credentials,
                                   wait_for_ready)
        super().__init__(cython_call, metadata, request_serializer,
                         response_deserializer, loop)
        # The initializer task is created before the request mixin's consumer
        # task; both mixins are then set up, and reads wait on the initializer.
        self._initializer = self._loop.create_task(self._prepare_rpc())
        self._init_stream_request_mixin(request_iterator)
        self._init_stream_response_mixin(self._initializer)

    async def _prepare_rpc(self):
        """This method prepares the RPC for receiving/sending messages.

        All other operations around the stream should only happen after the
        completion of this method.
        """
        try:
            await self._cython_call.initiate_stream_stream(
                self._metadata, self._metadata_sent_observer)
        except asyncio.CancelledError:
            # The cancellation is deliberately not re-raised: the response
            # reading path awaits this task, and the outcome of a cancelled
            # RPC is surfaced through the call's status (_raise_for_status)
            # rather than through this task.
            if not self.cancelled():
                self.cancel()
630