# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""

import asyncio
import enum
import inspect
import logging
import sys
from functools import partial
from typing import AsyncIterable, Optional, Tuple

import grpc
from grpc import _common
from grpc._cython import cygrpc

from . import _base_call
from ._metadata import Metadata
from ._typing import (DeserializingFunction, DoneCallbackType, MetadatumType,
                      RequestIterableType, RequestType, ResponseType,
                      SerializingFunction)

__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'

_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!'
_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.'
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
_API_STYLE_ERROR = 'The iterator and read/write APIs may not be mixed on a single RPC.'

_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
                           '\tstatus = {}\n'
                           '\tdetails = "{}"\n'
                           '>')

_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
                               '\tstatus = {}\n'
                               '\tdetails = "{}"\n'
                               '\tdebug_error_string = "{}"\n'
                               '>')

_LOGGER = logging.getLogger(__name__)


class AioRpcError(grpc.RpcError):
    """An implementation of RpcError to be used by the asynchronous API.

    A raised AioRpcError is a snapshot of the final status of the RPC: all of
    its values are already determined, so its accessor methods do not need to
    be coroutines.
    """

    _code: grpc.StatusCode
    _details: Optional[str]
    _initial_metadata: Optional[Metadata]
    _trailing_metadata: Optional[Metadata]
    _debug_error_string: Optional[str]

    def __init__(self,
                 code: grpc.StatusCode,
                 initial_metadata: Metadata,
                 trailing_metadata: Metadata,
                 details: Optional[str] = None,
                 debug_error_string: Optional[str] = None) -> None:
        """Constructor.

        Args:
          code: The status code with which the RPC has been finalized.
          initial_metadata: Optional initial metadata that could be sent by the
            Server.
          trailing_metadata: Optional metadata that could be sent by the Server.
          details: Optional details explaining the reason of the error.
          debug_error_string: Optional debug string describing the error.
        """
        # Initialize the base exception without arguments. Passing ``self``
        # here would store the exception inside its own ``args``: a reference
        # cycle that also makes pickling recurse forever.
        super().__init__()
        self._code = code
        self._details = details
        self._initial_metadata = initial_metadata
        self._trailing_metadata = trailing_metadata
        self._debug_error_string = debug_error_string

    def __reduce__(self):
        """Rebuilds the error from its constructor arguments.

        Used by ``pickle`` and ``copy``. Pickling only succeeds if the stored
        metadata objects are themselves picklable.
        """
        return (type(self), (self._code, self._initial_metadata,
                             self._trailing_metadata, self._details,
                             self._debug_error_string))

    def code(self) -> grpc.StatusCode:
        """Accesses the status code sent by the server.

        Returns:
          The `grpc.StatusCode` status code.
        """
        return self._code

    def details(self) -> Optional[str]:
        """Accesses the details sent by the server.

        Returns:
          The description of the error.
        """
        return self._details

    def initial_metadata(self) -> Metadata:
        """Accesses the initial metadata sent by the server.

        Returns:
          The initial metadata received.
        """
        return self._initial_metadata

    def trailing_metadata(self) -> Metadata:
        """Accesses the trailing metadata sent by the server.

        Returns:
          The trailing metadata received.
        """
        return self._trailing_metadata

    def debug_error_string(self) -> Optional[str]:
        """Accesses the debug error string sent by the server.

        Returns:
          The debug error string received, or None if none was provided.
        """
        return self._debug_error_string

    def _repr(self) -> str:
        """Assembles the error string for the RPC error."""
        return _NON_OK_CALL_REPRESENTATION.format(self.__class__.__name__,
                                                  self._code, self._details,
                                                  self._debug_error_string)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()


def _create_rpc_error(initial_metadata: Metadata,
                      status: cygrpc.AioRpcStatus) -> AioRpcError:
    """Builds an AioRpcError from initial metadata and a Core status object.

    Args:
      initial_metadata: Initial metadata, converted with
        ``Metadata.from_tuple``.
      status: The final status; its code, trailing metadata, details and
        debug error string are copied into the error.
    """
    return AioRpcError(
        _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()],
        Metadata.from_tuple(initial_metadata),
        Metadata.from_tuple(status.trailing_metadata()),
        details=status.details(),
        debug_error_string=status.debug_error_string(),
    )


class Call:
    """Base implementation of client RPC Call object.

    Implements logic around final status, metadata and cancellation. Most
    operations delegate to the wrapped ``cygrpc._AioCall``.
    """
    _loop: asyncio.AbstractEventLoop
    _code: grpc.StatusCode
    _cython_call: cygrpc._AioCall
    _metadata: Tuple[MetadatumType, ...]
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction

    def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata,
                 request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._cython_call = cython_call
        # Stored as a tuple; this is what is handed to the Cython call methods.
        self._metadata = tuple(metadata)
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __del__(self) -> None:
        """Cancels the RPC if it is still in flight at garbage collection."""
        # The '_cython_call' object might be destructed before Call object
        if hasattr(self, '_cython_call'):
            if not self._cython_call.done():
                self._cancel(_GC_CANCELLATION_DETAILS)

    def cancelled(self) -> bool:
        """Returns whether the underlying call reports being cancelled."""
        return self._cython_call.cancelled()

    def _cancel(self, details: str) -> bool:
        """Forwards the application cancellation reasoning.

        Returns:
          True if the call was still running and a cancellation was issued,
          False if the call had already finished.
        """
        if not self._cython_call.done():
            self._cython_call.cancel(details)
            return True
        else:
            return False

    def cancel(self) -> bool:
        """Cancels the RPC on behalf of the application.

        Returns:
          True if a cancellation was issued, False if the RPC was already done.
        """
        return self._cancel(_LOCAL_CANCELLATION_DETAILS)

    def done(self) -> bool:
        """Returns whether the underlying call has finished."""
        return self._cython_call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Registers ``callback`` to run when the underlying call finishes.

        The callback receives this Call object as its first positional
        argument.
        """
        cb = partial(callback, self)
        self._cython_call.add_done_callback(cb)

    def time_remaining(self) -> Optional[float]:
        """Returns the time remaining reported by the call, or None."""
        return self._cython_call.time_remaining()

    async def initial_metadata(self) -> Metadata:
        """Waits for and returns the initial metadata of the RPC."""
        raw_metadata_tuple = await self._cython_call.initial_metadata()
        return Metadata.from_tuple(raw_metadata_tuple)

    async def trailing_metadata(self) -> Metadata:
        """Waits for the final status and returns its trailing metadata."""
        raw_metadata_tuple = (await
                              self._cython_call.status()).trailing_metadata()
        return Metadata.from_tuple(raw_metadata_tuple)

    async def code(self) -> grpc.StatusCode:
        """Waits for the final status and returns its `grpc.StatusCode`."""
        cygrpc_code = (await self._cython_call.status()).code()
        return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code]

    async def details(self) -> str:
        """Waits for the final status and returns its details."""
        return (await self._cython_call.status()).details()

    async def debug_error_string(self) -> str:
        """Waits for the final status and returns its debug error string."""
        return (await self._cython_call.status()).debug_error_string()

    async def _raise_for_status(self) -> None:
        """Raises if the RPC did not finish successfully.

        Raises:
          asyncio.CancelledError: if the RPC was cancelled locally.
          AioRpcError: if the final status code is not OK.
        """
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        code = await self.code()
        if code != grpc.StatusCode.OK:
            raise _create_rpc_error(await self.initial_metadata(), await
                                    self._cython_call.status())

    def _repr(self) -> str:
        return repr(self._cython_call)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()


class _APIStyle(enum.IntEnum):
    """Which API a stream is driven through.

    The two styles may not be mixed on a single RPC (see _API_STYLE_ERROR).
    """
    UNKNOWN = 0
    ASYNC_GENERATOR = 1
    READER_WRITER = 2


class _UnaryResponseMixin(Call):
    """Call behaviour for RPCs that produce exactly one response message."""
    # Task that resolves to the deserialized response or cygrpc.EOF.
    _call_response: asyncio.Task

    def _init_unary_response_mixin(self, response_task: asyncio.Task):
        self._call_response = response_task

    def cancel(self) -> bool:
        """Cancels the RPC and the Task that produces its response."""
        if super().cancel():
            self._call_response.cancel()
            return True
        else:
            return False

    def __await__(self) -> ResponseType:
        """Wait till the ongoing RPC request finishes."""
        try:
            response = yield from self._call_response
        except asyncio.CancelledError:
            # Even if we caught all other CancelledError, there is still
            # this corner case. If the application cancels immediately after
            # the Call object is created, we will observe this
            # `CancelledError`.
            if not self.cancelled():
                self.cancel()
            raise

        # NOTE(lidiz) If we raise RpcError in the task, and users doesn't
        # 'await' on it. AsyncIO will log 'Task exception was never retrieved'.
        # Instead, if we move the exception raising here, the spam stops.
        # Unfortunately, there can only be one 'yield from' in '__await__'. So,
        # we need to access the private instance variable.
        if response is cygrpc.EOF:
            if self._cython_call.is_locally_cancelled():
                raise asyncio.CancelledError()
            else:
                raise _create_rpc_error(self._cython_call._initial_metadata,
                                        self._cython_call._status)
        else:
            return response


class _StreamResponseMixin(Call):
    """Call behaviour for RPCs that produce a stream of response messages.

    Responses can be consumed either with ``async for`` or with ``read()``,
    but not both on the same RPC.
    """
    _message_aiter: AsyncIterable[ResponseType]
    # Task that must complete before any message can be read.
    _preparation: asyncio.Task
    _response_style: _APIStyle

    def _init_stream_response_mixin(self, preparation: asyncio.Task):
        self._message_aiter = None
        self._preparation = preparation
        self._response_style = _APIStyle.UNKNOWN

    def _update_response_style(self, style: _APIStyle):
        """Locks in the first API style used; rejects a different one later.

        Raises:
          cygrpc.UsageError: if ``style`` differs from the style already used.
        """
        if self._response_style is _APIStyle.UNKNOWN:
            self._response_style = style
        elif self._response_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        """Cancels the RPC and its preparation Task."""
        if super().cancel():
            self._preparation.cancel()
            return True
        else:
            return False

    async def _fetch_stream_responses(self) -> AsyncIterable[ResponseType]:
        """Yields response messages until EOF, then raises for bad status."""
        message = await self._read()
        while message is not cygrpc.EOF:
            yield message
            message = await self._read()

        # If the read operation failed, Core should explain why.
        await self._raise_for_status()

    def __aiter__(self) -> AsyncIterable[ResponseType]:
        self._update_response_style(_APIStyle.ASYNC_GENERATOR)
        if self._message_aiter is None:
            self._message_aiter = self._fetch_stream_responses()
        return self._message_aiter

    async def _read(self) -> ResponseType:
        """Reads and deserializes one response message.

        Returns:
          The deserialized message, or ``cygrpc.EOF`` when Core reports no
          further message.

        Raises:
          asyncio.CancelledError: if the RPC was cancelled locally, or if the
            reading was interrupted and the final status does not explain it.
          AioRpcError: if the reading was interrupted and the RPC finished
            with a non-OK status.
        """
        # Wait for the request being sent
        await self._preparation

        # Reads response message from Core
        try:
            raw_response = await self._cython_call.receive_serialized_message()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            await self._raise_for_status()
            # The final status did not account for the interruption (e.g. the
            # RPC already finished OK). Propagate the cancellation instead of
            # falling through with ``raw_response`` unbound.
            raise

        if raw_response is cygrpc.EOF:
            return cygrpc.EOF
        else:
            return _common.deserialize(raw_response,
                                       self._response_deserializer)

    async def read(self) -> ResponseType:
        """Reads the next response message (reader/writer API).

        Returns:
          The next message, or ``cygrpc.EOF`` when the stream has ended.

        Raises:
          cygrpc.UsageError: if the iterator API was already used.
          asyncio.CancelledError / AioRpcError: if the RPC ended unsuccessfully.
        """
        if self.done():
            await self._raise_for_status()
            return cygrpc.EOF
        self._update_response_style(_APIStyle.READER_WRITER)

        response_message = await self._read()

        if response_message is cygrpc.EOF:
            # If the read operation failed, Core should explain why.
            await self._raise_for_status()
        return response_message


class _StreamRequestMixin(Call):
    """Call behaviour for RPCs that send a stream of request messages.

    Requests are supplied either through a request iterator given at
    construction time, or through ``write()`` / ``done_writing()``.
    """
    # Set once the Cython call reports that metadata has been sent.
    _metadata_sent: asyncio.Event
    _done_writing_flag: bool
    _async_request_poller: Optional[asyncio.Task]
    _request_style: _APIStyle

    def _init_stream_request_mixin(
            self, request_iterator: Optional[RequestIterableType]):
        if sys.version_info < (3, 10):
            self._metadata_sent = asyncio.Event(loop=self._loop)
        else:
            # The ``loop`` argument was removed in Python 3.10; the Event
            # binds to the running loop on first use instead.
            self._metadata_sent = asyncio.Event()
        self._done_writing_flag = False

        # If user passes in an async iterator, create a consumer Task.
        if request_iterator is not None:
            self._async_request_poller = self._loop.create_task(
                self._consume_request_iterator(request_iterator))
            self._request_style = _APIStyle.ASYNC_GENERATOR
        else:
            self._async_request_poller = None
            self._request_style = _APIStyle.READER_WRITER

    def _raise_for_different_style(self, style: _APIStyle):
        """Raises cygrpc.UsageError unless ``style`` is this RPC's style."""
        if self._request_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        """Cancels the RPC and, if present, the request-consuming Task."""
        if super().cancel():
            if self._async_request_poller is not None:
                self._async_request_poller.cancel()
            return True
        else:
            return False

    def _metadata_sent_observer(self):
        """Callback handed to the Cython call; marks metadata as sent."""
        self._metadata_sent.set()

    async def _consume_request_iterator(self,
                                        request_iterator: RequestIterableType
                                       ) -> None:
        """Writes every request from ``request_iterator``, then half-closes.

        Accepts both async and sync iterables. This runs as a background Task
        whose result nobody retrieves, so apart from cancellation no exception
        is allowed to escape: errors are logged, and if the RPC is still
        active when one occurs, the RPC is cancelled.
        """
        try:
            if inspect.isasyncgen(request_iterator) or hasattr(
                    request_iterator, '__aiter__'):
                async for request in request_iterator:
                    await self._write(request)
            else:
                for request in request_iterator:
                    await self._write(request)

            await self._done_writing()
        except asyncio.CancelledError:
            # Cancellation of this Task (e.g. via ``cancel()``) must propagate.
            raise
        except AioRpcError as rpc_error:
            # Rpc status should be exposed through other API. Exceptions raised
            # within this Task won't be retrieved by another coroutine. It's
            # better to suppress the error than spamming users' screen.
            _LOGGER.debug('Exception while consuming the request_iterator: %s',
                          rpc_error)
        except Exception:  # pylint: disable=broad-except
            if self.done():
                # The RPC already terminated (e.g. ``_write`` raised
                # InvalidStateError); its status is exposed through other APIs.
                _LOGGER.debug('Exception while consuming the request_iterator',
                              exc_info=True)
            else:
                # The application's iterator (or writing a request) failed
                # while the RPC is still active. Nobody awaits this Task, so
                # log the error and cancel the RPC instead of leaving it
                # half-open until its deadline.
                _LOGGER.exception('Client request_iterator raised exception')
                self.cancel()

    async def _write(self, request: RequestType) -> None:
        """Serializes and sends one request message.

        Raises:
          asyncio.InvalidStateError: if the RPC is finished or half-closed.
          asyncio.CancelledError / AioRpcError: if the RPC ended unsuccessfully
            while waiting for metadata or sending.
        """
        if self.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        if self._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
        if not self._metadata_sent.is_set():
            await self._metadata_sent.wait()
            if self.done():
                await self._raise_for_status()

        serialized_request = _common.serialize(request,
                                               self._request_serializer)
        try:
            await self._cython_call.send_serialized_message(serialized_request)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            await self._raise_for_status()

    async def _done_writing(self) -> None:
        """Half-closes the request stream at most once; no-op once done."""
        if self.done():
            # If the RPC is finished, do nothing.
            return
        if not self._done_writing_flag:
            # If the done writing is not sent before, try to send it.
            self._done_writing_flag = True
            try:
                await self._cython_call.send_receive_close()
            except asyncio.CancelledError:
                if not self.cancelled():
                    self.cancel()
                await self._raise_for_status()

    async def write(self, request: RequestType) -> None:
        """Sends one request message (reader/writer API only)."""
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._write(request)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._done_writing()

    async def wait_for_connection(self) -> None:
        """Waits until metadata is sent; raises if the RPC already failed."""
        await self._metadata_sent.wait()
        if self.done():
            await self._raise_for_status()


class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
    """Object for managing unary-unary RPC calls.

    Returned when an instance of `UnaryUnaryMultiCallable` object is called.
    """
    _request: RequestType
    _invocation_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(self, request: RequestType, deadline: Optional[float],
                 metadata: Metadata,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata, request_serializer, response_deserializer, loop)
        self._request = request
        self._invocation_task = loop.create_task(self._invoke())
        self._init_unary_response_mixin(self._invocation_task)

    async def _invoke(self) -> ResponseType:
        """Performs the RPC.

        Returns:
          The deserialized response, or ``cygrpc.EOF`` if the call did not
          finish OK. ``__await__`` turns EOF into the appropriate exception.
        """
        serialized_request = _common.serialize(self._request,
                                               self._request_serializer)

        # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
        # because the asyncio.Task class do not cache the exception object.
        # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
        try:
            serialized_response = await self._cython_call.unary_unary(
                serialized_request, self._metadata)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()

        if self._cython_call.is_ok():
            return _common.deserialize(serialized_response,
                                       self._response_deserializer)
        else:
            return cygrpc.EOF

    async def wait_for_connection(self) -> None:
        """Waits for the invocation Task; raises if the RPC failed."""
        await self._invocation_task
        if self.done():
            await self._raise_for_status()


class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
    """Object for managing unary-stream RPC calls.

    Returned when an instance of `UnaryStreamMultiCallable` object is called.
    """
    _request: RequestType
    _send_unary_request_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(self, request: RequestType, deadline: Optional[float],
                 metadata: Metadata,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata, request_serializer, response_deserializer, loop)
        self._request = request
        self._send_unary_request_task = loop.create_task(
            self._send_unary_request())
        self._init_stream_response_mixin(self._send_unary_request_task)

    async def _send_unary_request(self) -> None:
        """Serializes the single request and starts the unary-stream RPC."""
        serialized_request = _common.serialize(self._request,
                                               self._request_serializer)
        try:
            await self._cython_call.initiate_unary_stream(
                serialized_request, self._metadata)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

    async def wait_for_connection(self) -> None:
        """Waits until the request is sent; raises if the RPC failed."""
        await self._send_unary_request_task
        if self.done():
            await self._raise_for_status()


class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
                      _base_call.StreamUnaryCall):
    """Object for managing stream-unary RPC calls.

    Returned when an instance of `StreamUnaryMultiCallable` object is called.
    """

    # pylint: disable=too-many-arguments
    def __init__(self, request_iterator: Optional[RequestIterableType],
                 deadline: Optional[float], metadata: Metadata,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata, request_serializer, response_deserializer, loop)

        self._init_stream_request_mixin(request_iterator)
        self._init_unary_response_mixin(loop.create_task(self._conduct_rpc()))

    async def _conduct_rpc(self) -> ResponseType:
        """Runs the RPC; returns the response, or ``cygrpc.EOF`` if not OK."""
        try:
            serialized_response = await self._cython_call.stream_unary(
                self._metadata, self._metadata_sent_observer)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()

        if self._cython_call.is_ok():
            return _common.deserialize(serialized_response,
                                       self._response_deserializer)
        else:
            return cygrpc.EOF


class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
                       _base_call.StreamStreamCall):
    """Object for managing stream-stream RPC calls.

    Returned when an instance of `StreamStreamMultiCallable` object is called.
    """
    _initializer: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(self, request_iterator: Optional[RequestIterableType],
                 deadline: Optional[float], metadata: Metadata,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata, request_serializer, response_deserializer, loop)
        self._initializer = self._loop.create_task(self._prepare_rpc())
        self._init_stream_request_mixin(request_iterator)
        self._init_stream_response_mixin(self._initializer)

    async def _prepare_rpc(self):
        """This method prepares the RPC for receiving/sending messages.

        All other operations around the stream should only happen after the
        completion of this method.
        """
        try:
            await self._cython_call.initiate_stream_stream(
                self._metadata, self._metadata_sent_observer)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            # No need to raise RpcError here, because no one will `await` this task.