# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Translates gRPC's client-side API into gRPC's client-side Beta API."""

import grpc
from grpc import _common
from grpc.beta import _metadata
from grpc.beta import interfaces
from grpc.framework.common import cardinality
from grpc.framework.foundation import future
from grpc.framework.interfaces.face import face

# pylint: disable=too-many-arguments,too-many-locals,unused-argument

# Maps a gRPC status code to the Beta API (abortion kind, exception class)
# pair used to report it. Codes missing from this table fall back to a
# LOCAL_FAILURE abortion kind (_abortion) and to the base face.AbortionError
# exception class (_abortion_error).
_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = {
    grpc.StatusCode.CANCELLED:
        (face.Abortion.Kind.CANCELLED, face.CancellationError),
    grpc.StatusCode.UNKNOWN:
        (face.Abortion.Kind.REMOTE_FAILURE, face.RemoteError),
    grpc.StatusCode.DEADLINE_EXCEEDED:
        (face.Abortion.Kind.EXPIRED, face.ExpirationError),
    grpc.StatusCode.UNIMPLEMENTED:
        (face.Abortion.Kind.LOCAL_FAILURE, face.LocalError),
}


def _effective_metadata(metadata, metadata_transformer):
    """Computes the metadata actually sent with a call.

    Args:
      metadata: Caller-supplied metadata, or None (treated as an empty tuple).
      metadata_transformer: Optional callable applied to the metadata.

    Returns:
      The metadata (or `()` when it was None), passed through
      `metadata_transformer` when one is given.
    """
    non_none_metadata = () if metadata is None else metadata
    if metadata_transformer is None:
        return non_none_metadata
    else:
        return metadata_transformer(non_none_metadata)


def _credentials(grpc_call_options):
    """Returns the `credentials` of Beta call options, or None if no options."""
    return None if grpc_call_options is None else grpc_call_options.credentials


def _abortion(rpc_error_call):
    """Builds a face.Abortion describing a failed gRPC call.

    The abortion kind comes from the status-code table above; unknown codes
    are reported as LOCAL_FAILURE. Metadata is passed through unconverted.
    """
    code = rpc_error_call.code()
    pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
    error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0]
    return face.Abortion(error_kind, rpc_error_call.initial_metadata(),
                         rpc_error_call.trailing_metadata(), code,
                         rpc_error_call.details())


def _abortion_error(rpc_error_call):
    """Builds the Beta API exception corresponding to a failed gRPC call.

    The exception class comes from the status-code table above; unknown
    codes map to the base face.AbortionError.
    """
    code = rpc_error_call.code()
    pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
    exception_class = face.AbortionError if pair is None else pair[1]
    return exception_class(rpc_error_call.initial_metadata(),
                           rpc_error_call.trailing_metadata(), code,
                           rpc_error_call.details())


class _InvocationProtocolContext(interfaces.GRPCInvocationContext):

    def disable_next_request_compression(self):
        pass  # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.


class _Rendezvous(future.Future, face.Call):
    """Exposes a gRPC call through the Beta future.Future/face.Call APIs.

    Depending on the RPC shape, `response_future` and/or `response_iterator`
    may be None (see the call helpers below). Future-style methods require
    a response future, and iteration requires a response iterator. gRPC
    exceptions are translated into their Beta API equivalents.
    """

    def __init__(self, response_future, response_iterator, call):
        self._future = response_future
        self._iterator = response_iterator
        self._call = call

    def cancel(self):
        return self._call.cancel()

    def cancelled(self):
        return self._future.cancelled()

    def running(self):
        return self._future.running()

    def done(self):
        return self._future.done()

    def result(self, timeout=None):
        try:
            return self._future.result(timeout=timeout)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def exception(self, timeout=None):
        try:
            rpc_error_call = self._future.exception(timeout=timeout)
            if rpc_error_call is None:
                return None
            else:
                return _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def traceback(self, timeout=None):
        try:
            return self._future.traceback(timeout=timeout)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def add_done_callback(self, fn):
        # Beta callbacks receive this rendezvous, not the underlying future.
        self._future.add_done_callback(lambda ignored_callback: fn(self))

    def __iter__(self):
        return self

    def _next(self):
        try:
            return next(self._iterator)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)

    def __next__(self):
        return self._next()

    def next(self):
        return self._next()

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def add_abortion_callback(self, abortion_callback):

        def done_callback():
            if self.code() is not grpc.StatusCode.OK:
                abortion_callback(_abortion(self._call))

        registered = self._call.add_callback(done_callback)
        # When the callback could not be registered (presumably because the
        # call has already terminated), run the check immediately instead.
        return None if registered else done_callback()

    def protocol_context(self):
        return _InvocationProtocolContext()

    def initial_metadata(self):
        return _metadata.beta(self._call.initial_metadata())

    def terminal_metadata(self):
        return _metadata.beta(self._call.terminal_metadata())

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()


def _blocking_unary_unary(channel, group, method, timeout, with_call,
                          protocol_options, metadata, metadata_transformer,
                          request, request_serializer, response_deserializer):
    """Performs a blocking unary-unary call.

    Returns:
      The response, or a (response, _Rendezvous) pair when `with_call` is
      truthy.

    Raises:
      face.AbortionError (or a subclass) when the call fails with a
      grpc.RpcError.
    """
    try:
        multi_callable = channel.unary_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        effective_metadata = _effective_metadata(metadata, metadata_transformer)
        if with_call:
            response, call = multi_callable.with_call(
                request,
                timeout=timeout,
                metadata=_metadata.unbeta(effective_metadata),
                credentials=_credentials(protocol_options))
            return response, _Rendezvous(None, None, call)
        else:
            return multi_callable(request,
                                  timeout=timeout,
                                  metadata=_metadata.unbeta(effective_metadata),
                                  credentials=_credentials(protocol_options))
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)


def _future_unary_unary(channel, group, method, timeout, protocol_options,
                        metadata, metadata_transformer, request,
                        request_serializer, response_deserializer):
    """Starts a unary-unary call; returns a _Rendezvous over its future."""
    multi_callable = channel.unary_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_future = multi_callable.future(
        request,
        timeout=timeout,
        metadata=_metadata.unbeta(effective_metadata),
        credentials=_credentials(protocol_options))
    return _Rendezvous(response_future, None, response_future)


def _unary_stream(channel, group, method, timeout, protocol_options, metadata,
                  metadata_transformer, request, request_serializer,
                  response_deserializer):
    """Starts a unary-stream call; returns a _Rendezvous over its responses."""
    multi_callable = channel.unary_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_iterator = multi_callable(
        request,
        timeout=timeout,
        metadata=_metadata.unbeta(effective_metadata),
        credentials=_credentials(protocol_options))
    return _Rendezvous(None, response_iterator, response_iterator)


def _blocking_stream_unary(channel, group, method, timeout, with_call,
                           protocol_options, metadata, metadata_transformer,
                           request_iterator, request_serializer,
                           response_deserializer):
    """Performs a blocking stream-unary call.

    Returns:
      The response, or a (response, _Rendezvous) pair when `with_call` is
      truthy.

    Raises:
      face.AbortionError (or a subclass) when the call fails with a
      grpc.RpcError.
    """
    try:
        multi_callable = channel.stream_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        effective_metadata = _effective_metadata(metadata, metadata_transformer)
        if with_call:
            response, call = multi_callable.with_call(
                request_iterator,
                timeout=timeout,
                metadata=_metadata.unbeta(effective_metadata),
                credentials=_credentials(protocol_options))
            return response, _Rendezvous(None, None, call)
        else:
            return multi_callable(request_iterator,
                                  timeout=timeout,
                                  metadata=_metadata.unbeta(effective_metadata),
                                  credentials=_credentials(protocol_options))
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)


def _future_stream_unary(channel, group, method, timeout, protocol_options,
                         metadata, metadata_transformer, request_iterator,
                         request_serializer, response_deserializer):
    """Starts a stream-unary call; returns a _Rendezvous over its future."""
    multi_callable = channel.stream_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_future = multi_callable.future(
        request_iterator,
        timeout=timeout,
        metadata=_metadata.unbeta(effective_metadata),
        credentials=_credentials(protocol_options))
    return _Rendezvous(response_future, None, response_future)


def _stream_stream(channel, group, method, timeout, protocol_options, metadata,
                   metadata_transformer, request_iterator, request_serializer,
                   response_deserializer):
    """Starts a stream-stream call; returns a _Rendezvous over its responses."""
    multi_callable = channel.stream_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_iterator = multi_callable(
        request_iterator,
        timeout=timeout,
        metadata=_metadata.unbeta(effective_metadata),
        credentials=_credentials(protocol_options))
    return _Rendezvous(None, response_iterator, response_iterator)


class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
    """Beta unary-unary multi-callable bound to one (group, method)."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request,
                 timeout,
                 metadata=None,
                 with_call=False,
                 protocol_options=None):
        return _blocking_unary_unary(self._channel, self._group, self._method,
                                     timeout, with_call, protocol_options,
                                     metadata, self._metadata_transformer,
                                     request, self._request_serializer,
                                     self._response_deserializer)

    def future(self, request, timeout, metadata=None, protocol_options=None):
        return _future_unary_unary(self._channel, self._group, self._method,
                                   timeout, protocol_options, metadata,
                                   self._metadata_transformer, request,
                                   self._request_serializer,
                                   self._response_deserializer)

    def event(self,
              request,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        raise NotImplementedError()


class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
    """Beta unary-stream multi-callable bound to one (group, method)."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self, request, timeout, metadata=None, protocol_options=None):
        return _unary_stream(self._channel, self._group, self._method, timeout,
                             protocol_options, metadata,
                             self._metadata_transformer, request,
                             self._request_serializer,
                             self._response_deserializer)

    def event(self,
              request,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        raise NotImplementedError()


class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
    """Beta stream-unary multi-callable bound to one (group, method)."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request_iterator,
                 timeout,
                 metadata=None,
                 with_call=False,
                 protocol_options=None):
        return _blocking_stream_unary(self._channel, self._group, self._method,
                                      timeout, with_call, protocol_options,
                                      metadata, self._metadata_transformer,
                                      request_iterator,
                                      self._request_serializer,
                                      self._response_deserializer)

    def future(self,
               request_iterator,
               timeout,
               metadata=None,
               protocol_options=None):
        return _future_stream_unary(self._channel, self._group, self._method,
                                    timeout, protocol_options, metadata,
                                    self._metadata_transformer,
                                    request_iterator, self._request_serializer,
                                    self._response_deserializer)

    def event(self,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        raise NotImplementedError()


class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
    """Beta stream-stream multi-callable bound to one (group, method)."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request_iterator,
                 timeout,
                 metadata=None,
                 protocol_options=None):
        return _stream_stream(self._channel, self._group, self._method, timeout,
                              protocol_options, metadata,
                              self._metadata_transformer, request_iterator,
                              self._request_serializer,
                              self._response_deserializer)

    def event(self,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        raise NotImplementedError()


class _GenericStub(face.GenericStub):
    """Beta face.GenericStub backed by a gRPC channel.

    Serializers and deserializers are looked up per (group, method) pair;
    a missing entry yields None, which is forwarded to the channel as-is.
    """

    def __init__(self, channel, metadata_transformer, request_serializers,
                 response_deserializers):
        self._channel = channel
        self._metadata_transformer = metadata_transformer
        self._request_serializers = request_serializers or {}
        self._response_deserializers = response_deserializers or {}

    def _serialization(self, group, method):
        """Returns (request_serializer, response_deserializer) for a method.

        Either element is None when no entry is registered for
        (group, method).
        """
        key = (group, method)
        return (self._request_serializers.get(key),
                self._response_deserializers.get(key))

    def blocking_unary_unary(self,
                             group,
                             method,
                             request,
                             timeout,
                             metadata=None,
                             with_call=None,
                             protocol_options=None):
        request_serializer, response_deserializer = self._serialization(
            group, method)
        return _blocking_unary_unary(self._channel, group, method, timeout,
                                     with_call, protocol_options, metadata,
                                     self._metadata_transformer, request,
                                     request_serializer, response_deserializer)

    def future_unary_unary(self,
                           group,
                           method,
                           request,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        request_serializer, response_deserializer = self._serialization(
            group, method)
        return _future_unary_unary(self._channel, group, method, timeout,
                                   protocol_options, metadata,
                                   self._metadata_transformer, request,
                                   request_serializer, response_deserializer)

    def inline_unary_stream(self,
                            group,
                            method,
                            request,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        request_serializer, response_deserializer = self._serialization(
            group, method)
        return _unary_stream(self._channel, group, method, timeout,
                             protocol_options, metadata,
                             self._metadata_transformer, request,
                             request_serializer, response_deserializer)

    def blocking_stream_unary(self,
                              group,
                              method,
                              request_iterator,
                              timeout,
                              metadata=None,
                              with_call=None,
                              protocol_options=None):
        request_serializer, response_deserializer = self._serialization(
            group, method)
        return _blocking_stream_unary(self._channel, group, method, timeout,
                                      with_call, protocol_options, metadata,
                                      self._metadata_transformer,
                                      request_iterator, request_serializer,
                                      response_deserializer)

    def future_stream_unary(self,
                            group,
                            method,
                            request_iterator,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        request_serializer, response_deserializer = self._serialization(
            group, method)
        return _future_stream_unary(self._channel, group, method, timeout,
                                    protocol_options, metadata,
                                    self._metadata_transformer,
                                    request_iterator, request_serializer,
                                    response_deserializer)

    def inline_stream_stream(self,
                             group,
                             method,
                             request_iterator,
                             timeout,
                             metadata=None,
                             protocol_options=None):
        request_serializer, response_deserializer = self._serialization(
            group, method)
        return _stream_stream(self._channel, group, method, timeout,
                              protocol_options, metadata,
                              self._metadata_transformer, request_iterator,
                              request_serializer, response_deserializer)

    def event_unary_unary(self,
                          group,
                          method,
                          request,
                          receiver,
                          abortion_callback,
                          timeout,
                          metadata=None,
                          protocol_options=None):
        raise NotImplementedError()

    def event_unary_stream(self,
                           group,
                           method,
                           request,
                           receiver,
                           abortion_callback,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        raise NotImplementedError()

    def event_stream_unary(self,
                           group,
                           method,
                           receiver,
                           abortion_callback,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        raise NotImplementedError()

    def event_stream_stream(self,
                            group,
                            method,
                            receiver,
                            abortion_callback,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        raise NotImplementedError()

    def unary_unary(self, group, method):
        request_serializer, response_deserializer = self._serialization(
            group, method)
        return _UnaryUnaryMultiCallable(self._channel, group, method,
                                        self._metadata_transformer,
                                        request_serializer,
                                        response_deserializer)

    def unary_stream(self, group, method):
        request_serializer, response_deserializer = self._serialization(
            group, method)
        return _UnaryStreamMultiCallable(self._channel, group, method,
                                         self._metadata_transformer,
                                         request_serializer,
                                         response_deserializer)

    def stream_unary(self, group, method):
        request_serializer, response_deserializer = self._serialization(
            group, method)
        return _StreamUnaryMultiCallable(self._channel, group, method,
                                         self._metadata_transformer,
                                         request_serializer,
                                         response_deserializer)

    def stream_stream(self, group, method):
        request_serializer, response_deserializer = self._serialization(
            group, method)
        return _StreamStreamMultiCallable(self._channel, group, method,
                                          self._metadata_transformer,
                                          request_serializer,
                                          response_deserializer)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False


class _DynamicStub(face.DynamicStub):
    """Beta face.DynamicStub resolving method names via their cardinality.

    Attribute access `stub.Method` returns the multi-callable of
    `backing_generic_stub` matching `cardinalities[Method]`.
    """

    def __init__(self, backing_generic_stub, group, cardinalities):
        self._generic_stub = backing_generic_stub
        self._group = group
        self._cardinalities = cardinalities

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. Reading the map via
        # __dict__ (rather than self._cardinalities) keeps an instance whose
        # __init__ never ran -- e.g. one being probed by copy/pickle for
        # __setstate__ -- from recursing into __getattr__ forever.
        cardinalities = self.__dict__.get('_cardinalities')
        if cardinalities is None:
            raise AttributeError('_DynamicStub object has no attribute "%s"!' %
                                 attr)
        method_cardinality = cardinalities.get(attr)
        if method_cardinality is cardinality.Cardinality.UNARY_UNARY:
            return self._generic_stub.unary_unary(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.UNARY_STREAM:
            return self._generic_stub.unary_stream(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.STREAM_UNARY:
            return self._generic_stub.stream_unary(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.STREAM_STREAM:
            return self._generic_stub.stream_stream(self._group, attr)
        else:
            raise AttributeError('_DynamicStub object has no attribute "%s"!' %
                                 attr)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False


def generic_stub(channel, host, metadata_transformer, request_serializers,
                 response_deserializers):
    """Creates a Beta generic stub over `channel`.

    `host` is accepted for interface compatibility but is not used here.
    """
    return _GenericStub(channel, metadata_transformer, request_serializers,
                        response_deserializers)


def dynamic_stub(channel, service, cardinalities, host, metadata_transformer,
                 request_serializers, response_deserializers):
    """Creates a Beta dynamic stub for `service` over `channel`.

    `host` is accepted for interface compatibility but is not used here.
    """
    return _DynamicStub(
        _GenericStub(channel, metadata_transformer, request_serializers,
                     response_deserializers), service, cardinalities)