# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Translates gRPC's client-side API into gRPC's client-side Beta API."""

import grpc
from grpc import _common
from grpc.beta import _metadata
from grpc.beta import interfaces
from grpc.framework.common import cardinality
from grpc.framework.foundation import future
from grpc.framework.interfaces.face import face

# pylint: disable=too-many-arguments,too-many-locals,unused-argument

# Maps the gRPC status codes that have a dedicated Beta-API counterpart to a
# (face.Abortion.Kind, face.AbortionError subclass) pair. Codes missing from
# this table fall back to Kind.LOCAL_FAILURE in _abortion and to the base
# face.AbortionError class in _abortion_error.
_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = {
    grpc.StatusCode.CANCELLED: (face.Abortion.Kind.CANCELLED,
                                face.CancellationError),
    grpc.StatusCode.UNKNOWN: (face.Abortion.Kind.REMOTE_FAILURE,
                              face.RemoteError),
    grpc.StatusCode.DEADLINE_EXCEEDED: (face.Abortion.Kind.EXPIRED,
                                        face.ExpirationError),
    grpc.StatusCode.UNIMPLEMENTED: (face.Abortion.Kind.LOCAL_FAILURE,
                                    face.LocalError),
}


def _effective_metadata(metadata, metadata_transformer):
    """Returns the metadata to send, after applying the optional transformer.

    Args:
      metadata: Caller-supplied Beta-API metadata, or None (treated as an
        empty tuple).
      metadata_transformer: Optional callable that receives the non-None
        metadata and returns the metadata that should actually be sent.

    Returns:
      The (possibly transformed) metadata, still in Beta-API form.
    """
    non_none_metadata = () if metadata is None else metadata
    if metadata_transformer is None:
        return non_none_metadata
    else:
        return metadata_transformer(non_none_metadata)


def _credentials(grpc_call_options):
    """Returns the call credentials carried by Beta protocol options, or None."""
    return None if grpc_call_options is None else grpc_call_options.credentials


def _abortion(rpc_error_call):
    """Builds a face.Abortion describing a terminated, non-OK RPC.

    Args:
      rpc_error_call: An object exposing the grpc.Call accessors
        (code, initial_metadata, trailing_metadata, details).

    Returns:
      A face.Abortion whose kind comes from the status-code table, defaulting
      to Kind.LOCAL_FAILURE for codes the table does not list.
    """
    code = rpc_error_call.code()
    pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
    error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0]
    return face.Abortion(error_kind, rpc_error_call.initial_metadata(),
                         rpc_error_call.trailing_metadata(), code,
                         rpc_error_call.details())


def _abortion_error(rpc_error_call):
    """Translates a failed gRPC call into the matching Beta-API exception.

    Args:
      rpc_error_call: A grpc.RpcError that also exposes the grpc.Call
        accessors (code, initial_metadata, trailing_metadata, details).

    Returns:
      An instance (not raised) of the face.AbortionError subclass chosen by
      the status-code table, or of face.AbortionError itself for codes the
      table does not list.
    """
    code = rpc_error_call.code()
    pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
    exception_class = face.AbortionError if pair is None else pair[1]
    return exception_class(rpc_error_call.initial_metadata(),
                           rpc_error_call.trailing_metadata(), code,
                           rpc_error_call.details())


class _InvocationProtocolContext(interfaces.GRPCInvocationContext):
    """Beta-API invocation protocol context; currently a no-op placeholder."""

    def disable_next_request_compression(self):
        pass  # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.


class _Rendezvous(future.Future, face.Call):
    """Adapts a gRPC future/iterator/call trio to the Beta future and Call APIs.

    Depending on the RPC shape, some constructor arguments are None: unary-
    response futures pass no iterator, streaming-response calls pass no
    future, and the blocking "with_call" paths pass only the call. Methods
    that delegate to a missing piece are therefore only usable for the RPC
    shapes that supplied it.
    """

    def __init__(self, response_future, response_iterator, call):
        # grpc.Future of the single response, or None.
        self._future = response_future
        # Iterator over streamed responses, or None.
        self._iterator = response_iterator
        # grpc.Call used for status, metadata and cancellation.
        self._call = call

    def cancel(self):
        return self._call.cancel()

    def cancelled(self):
        return self._future.cancelled()

    def running(self):
        return self._future.running()

    def done(self):
        return self._future.done()

    def result(self, timeout=None):
        """Returns the response, translating gRPC errors to Beta-API errors.

        Raises:
          face.AbortionError (or a subclass): if the RPC failed.
          future.TimeoutError: if `timeout` elapsed before completion.
          future.CancelledError: if the RPC was cancelled.
        """
        try:
            return self._future.result(timeout=timeout)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def exception(self, timeout=None):
        """Returns the RPC's Beta-API error, or None if it succeeded.

        Raises:
          future.TimeoutError: if `timeout` elapsed before completion.
          future.CancelledError: if the RPC was cancelled.
        """
        try:
            rpc_error_call = self._future.exception(timeout=timeout)
            if rpc_error_call is None:
                return None
            else:
                return _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def traceback(self, timeout=None):
        try:
            return self._future.traceback(timeout=timeout)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def add_done_callback(self, fn):
        # Beta callers expect to receive this rendezvous, not the underlying
        # grpc.Future, so the argument the gRPC future passes is ignored.
        self._future.add_done_callback(lambda ignored_callback: fn(self))

    def __iter__(self):
        return self

    def _next(self):
        """Returns the next streamed response, translating gRPC errors."""
        try:
            return next(self._iterator)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)

    def __next__(self):
        return self._next()

    # Python 2 iterator protocol.
    def next(self):
        return self._next()

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def add_abortion_callback(self, abortion_callback):
        """Arranges for `abortion_callback` to receive a face.Abortion.

        The callback runs only if the RPC terminates with a non-OK status. If
        the call refuses the registration (add_callback returns a falsy
        value), the check runs immediately on the calling thread.
        """

        def done_callback():
            if self.code() is not grpc.StatusCode.OK:
                abortion_callback(_abortion(self._call))

        registered = self._call.add_callback(done_callback)
        return None if registered else done_callback()

    def protocol_context(self):
        return _InvocationProtocolContext()

    def initial_metadata(self):
        return _metadata.beta(self._call.initial_metadata())

    def terminal_metadata(self):
        # grpc.Call exposes the metadata sent at the end of the RPC as
        # trailing_metadata(); it has no terminal_metadata() accessor, so the
        # previous delegation raised AttributeError.
        return _metadata.beta(self._call.trailing_metadata())

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()


def _blocking_unary_unary(channel, group, method, timeout, with_call,
                          protocol_options, metadata, metadata_transformer,
                          request, request_serializer, response_deserializer):
    """Performs a blocking unary-unary RPC through the Beta API.

    Returns:
      The response, or a (response, _Rendezvous) pair when `with_call` is
      truthy.

    Raises:
      face.AbortionError (or a subclass): translated from grpc.RpcError.
    """
    try:
        multi_callable = channel.unary_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        effective_metadata = _effective_metadata(metadata, metadata_transformer)
        if with_call:
            response, call = multi_callable.with_call(
                request,
                timeout=timeout,
                metadata=_metadata.unbeta(effective_metadata),
                credentials=_credentials(protocol_options))
            return response, _Rendezvous(None, None, call)
        else:
            return multi_callable(
                request,
                timeout=timeout,
                metadata=_metadata.unbeta(effective_metadata),
                credentials=_credentials(protocol_options))
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)


def _future_unary_unary(channel, group, method, timeout, protocol_options,
                        metadata, metadata_transformer, request,
                        request_serializer, response_deserializer):
    """Starts a unary-unary RPC and returns a _Rendezvous future for it.

    RPC failures are translated to Beta-API errors when the returned object's
    result()/exception() is called, not here.
    """
    multi_callable = channel.unary_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_future = multi_callable.future(
        request,
        timeout=timeout,
        metadata=_metadata.unbeta(effective_metadata),
        credentials=_credentials(protocol_options))
    # The gRPC future object doubles as the call object.
    return _Rendezvous(response_future, None, response_future)


def _unary_stream(channel, group, method, timeout, protocol_options, metadata,
                  metadata_transformer, request, request_serializer,
                  response_deserializer):
    """Starts a unary-stream RPC and returns a _Rendezvous response iterator.

    RPC failures surface as Beta-API errors while iterating the result.
    """
    multi_callable = channel.unary_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_iterator = multi_callable(
        request,
        timeout=timeout,
        metadata=_metadata.unbeta(effective_metadata),
        credentials=_credentials(protocol_options))
    # The gRPC response iterator doubles as the call object.
    return _Rendezvous(None, response_iterator, response_iterator)


def _blocking_stream_unary(channel, group, method, timeout, with_call,
                           protocol_options, metadata, metadata_transformer,
                           request_iterator, request_serializer,
                           response_deserializer):
    """Performs a blocking stream-unary RPC through the Beta API.

    Returns:
      The response, or a (response, _Rendezvous) pair when `with_call` is
      truthy.

    Raises:
      face.AbortionError (or a subclass): translated from grpc.RpcError.
    """
    try:
        multi_callable = channel.stream_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        effective_metadata = _effective_metadata(metadata, metadata_transformer)
        if with_call:
            response, call = multi_callable.with_call(
                request_iterator,
                timeout=timeout,
                metadata=_metadata.unbeta(effective_metadata),
                credentials=_credentials(protocol_options))
            return response, _Rendezvous(None, None, call)
        else:
            return multi_callable(
                request_iterator,
                timeout=timeout,
                metadata=_metadata.unbeta(effective_metadata),
                credentials=_credentials(protocol_options))
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)


def _future_stream_unary(channel, group, method, timeout, protocol_options,
                         metadata, metadata_transformer, request_iterator,
                         request_serializer, response_deserializer):
    """Starts a stream-unary RPC and returns a _Rendezvous future for it.

    RPC failures are translated to Beta-API errors when the returned object's
    result()/exception() is called, not here.
    """
    multi_callable = channel.stream_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_future = multi_callable.future(
        request_iterator,
        timeout=timeout,
        metadata=_metadata.unbeta(effective_metadata),
        credentials=_credentials(protocol_options))
    # The gRPC future object doubles as the call object.
    return _Rendezvous(response_future, None, response_future)


def _stream_stream(channel, group, method, timeout, protocol_options, metadata,
                   metadata_transformer, request_iterator, request_serializer,
                   response_deserializer):
    """Starts a stream-stream RPC and returns a _Rendezvous response iterator.

    RPC failures surface as Beta-API errors while iterating the result.
    """
    multi_callable = channel.stream_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_iterator = multi_callable(
        request_iterator,
        timeout=timeout,
        metadata=_metadata.unbeta(effective_metadata),
        credentials=_credentials(protocol_options))
    # The gRPC response iterator doubles as the call object.
    return _Rendezvous(None, response_iterator, response_iterator)


class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
    """Beta unary-unary multi-callable bound to one (group, method)."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request,
                 timeout,
                 metadata=None,
                 with_call=False,
                 protocol_options=None):
        return _blocking_unary_unary(
            self._channel, self._group, self._method, timeout, with_call,
            protocol_options, metadata, self._metadata_transformer, request,
            self._request_serializer, self._response_deserializer)

    def future(self, request, timeout, metadata=None, protocol_options=None):
        return _future_unary_unary(
            self._channel, self._group, self._method, timeout, protocol_options,
            metadata, self._metadata_transformer, request,
            self._request_serializer, self._response_deserializer)

    def event(self,
              request,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        raise NotImplementedError()


class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
    """Beta unary-stream multi-callable bound to one (group, method)."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self, request, timeout, metadata=None, protocol_options=None):
        return _unary_stream(
            self._channel, self._group, self._method, timeout, protocol_options,
            metadata, self._metadata_transformer, request,
            self._request_serializer, self._response_deserializer)

    def event(self,
              request,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        raise NotImplementedError()


class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
    """Beta stream-unary multi-callable bound to one (group, method)."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request_iterator,
                 timeout,
                 metadata=None,
                 with_call=False,
                 protocol_options=None):
        return _blocking_stream_unary(
            self._channel, self._group, self._method, timeout, with_call,
            protocol_options, metadata, self._metadata_transformer,
            request_iterator, self._request_serializer,
            self._response_deserializer)

    def future(self,
               request_iterator,
               timeout,
               metadata=None,
               protocol_options=None):
        return _future_stream_unary(
            self._channel, self._group, self._method, timeout, protocol_options,
            metadata, self._metadata_transformer, request_iterator,
            self._request_serializer, self._response_deserializer)

    def event(self,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        raise NotImplementedError()


class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
    """Beta stream-stream multi-callable bound to one (group, method)."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request_iterator,
                 timeout,
                 metadata=None,
                 protocol_options=None):
        return _stream_stream(
            self._channel, self._group, self._method, timeout, protocol_options,
            metadata, self._metadata_transformer, request_iterator,
            self._request_serializer, self._response_deserializer)

    def event(self,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        raise NotImplementedError()


class _GenericStub(face.GenericStub):
    """Beta GenericStub implemented on top of a grpc.Channel.

    Serializers and deserializers are looked up per call in dictionaries
    keyed by (group, method) tuples; a missing entry yields None.
    """

    def __init__(self, channel, metadata_transformer, request_serializers,
                 response_deserializers):
        self._channel = channel
        self._metadata_transformer = metadata_transformer
        self._request_serializers = request_serializers or {}
        self._response_deserializers = response_deserializers or {}

    def _codecs(self, group, method):
        """Returns (request_serializer, response_deserializer) for a method.

        Either element is None when nothing was registered for the
        (group, method) key.
        """
        key = (group, method)
        return (self._request_serializers.get(key),
                self._response_deserializers.get(key))

    def blocking_unary_unary(self,
                             group,
                             method,
                             request,
                             timeout,
                             metadata=None,
                             with_call=None,
                             protocol_options=None):
        request_serializer, response_deserializer = self._codecs(group, method)
        return _blocking_unary_unary(self._channel, group, method, timeout,
                                     with_call, protocol_options, metadata,
                                     self._metadata_transformer, request,
                                     request_serializer, response_deserializer)

    def future_unary_unary(self,
                           group,
                           method,
                           request,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        request_serializer, response_deserializer = self._codecs(group, method)
        return _future_unary_unary(self._channel, group, method, timeout,
                                   protocol_options, metadata,
                                   self._metadata_transformer, request,
                                   request_serializer, response_deserializer)

    def inline_unary_stream(self,
                            group,
                            method,
                            request,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        request_serializer, response_deserializer = self._codecs(group, method)
        return _unary_stream(self._channel, group, method, timeout,
                             protocol_options, metadata,
                             self._metadata_transformer, request,
                             request_serializer, response_deserializer)

    def blocking_stream_unary(self,
                              group,
                              method,
                              request_iterator,
                              timeout,
                              metadata=None,
                              with_call=None,
                              protocol_options=None):
        request_serializer, response_deserializer = self._codecs(group, method)
        return _blocking_stream_unary(
            self._channel, group, method, timeout, with_call, protocol_options,
            metadata, self._metadata_transformer, request_iterator,
            request_serializer, response_deserializer)

    def future_stream_unary(self,
                            group,
                            method,
                            request_iterator,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        request_serializer, response_deserializer = self._codecs(group, method)
        return _future_stream_unary(
            self._channel, group, method, timeout, protocol_options, metadata,
            self._metadata_transformer, request_iterator, request_serializer,
            response_deserializer)

    def inline_stream_stream(self,
                             group,
                             method,
                             request_iterator,
                             timeout,
                             metadata=None,
                             protocol_options=None):
        request_serializer, response_deserializer = self._codecs(group, method)
        return _stream_stream(self._channel, group, method, timeout,
                              protocol_options, metadata,
                              self._metadata_transformer, request_iterator,
                              request_serializer, response_deserializer)

    def event_unary_unary(self,
                          group,
                          method,
                          request,
                          receiver,
                          abortion_callback,
                          timeout,
                          metadata=None,
                          protocol_options=None):
        raise NotImplementedError()

    def event_unary_stream(self,
                           group,
                           method,
                           request,
                           receiver,
                           abortion_callback,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        raise NotImplementedError()

    def event_stream_unary(self,
                           group,
                           method,
                           receiver,
                           abortion_callback,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        raise NotImplementedError()

    def event_stream_stream(self,
                            group,
                            method,
                            receiver,
                            abortion_callback,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        raise NotImplementedError()

    def unary_unary(self, group, method):
        request_serializer, response_deserializer = self._codecs(group, method)
        return _UnaryUnaryMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def unary_stream(self, group, method):
        request_serializer, response_deserializer = self._codecs(group, method)
        return _UnaryStreamMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def stream_unary(self, group, method):
        request_serializer, response_deserializer = self._codecs(group, method)
        return _StreamUnaryMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def stream_stream(self, group, method):
        request_serializer, response_deserializer = self._codecs(group, method)
        return _StreamStreamMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Never suppresses exceptions raised inside the with-block.
        return False


class _DynamicStub(face.DynamicStub):
    """Beta DynamicStub exposing a service's methods as attributes.

    Attribute lookups are resolved against `cardinalities` (a mapping from
    method name to cardinality.Cardinality) and return the matching
    multi-callable from the backing generic stub.
    """

    def __init__(self, backing_generic_stub, group, cardinalities):
        self._generic_stub = backing_generic_stub
        self._group = group
        self._cardinalities = cardinalities

    def __getattr__(self, attr):
        method_cardinality = self._cardinalities.get(attr)
        if method_cardinality is cardinality.Cardinality.UNARY_UNARY:
            return self._generic_stub.unary_unary(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.UNARY_STREAM:
            return self._generic_stub.unary_stream(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.STREAM_UNARY:
            return self._generic_stub.stream_unary(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.STREAM_STREAM:
            return self._generic_stub.stream_stream(self._group, attr)
        else:
            raise AttributeError(
                '_DynamicStub object has no attribute "%s"!' % attr)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Never suppresses exceptions raised inside the with-block.
        return False


def generic_stub(channel, host, metadata_transformer, request_serializers,
                 response_deserializers):
    """Creates a Beta face.GenericStub backed by `channel`.

    `host` is accepted for interface compatibility but is not used here.
    """
    return _GenericStub(channel, metadata_transformer, request_serializers,
                        response_deserializers)


def dynamic_stub(channel, service, cardinalities, host, metadata_transformer,
                 request_serializers, response_deserializers):
    """Creates a Beta face.DynamicStub for `service` backed by `channel`.

    `service` is used as the method group; `host` is accepted for interface
    compatibility but is not used here.
    """
    return _DynamicStub(
        _GenericStub(channel, metadata_transformer, request_serializers,
                     response_deserializers), service, cardinalities)