• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Translates gRPC's client-side API into gRPC's client-side Beta API."""
15
16import grpc
17from grpc import _common
18from grpc.beta import _metadata
19from grpc.beta import interfaces
20from grpc.framework.common import cardinality
21from grpc.framework.foundation import future
22from grpc.framework.interfaces.face import face
23
24# pylint: disable=too-many-arguments,too-many-locals,unused-argument
25
# Maps a gRPC status code to the pair
# (face.Abortion.Kind member, face exception class) used to report it through
# the Beta API. Status codes missing from this table are handled by the
# fallbacks in _abortion and _abortion_error.
_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = dict((
    (grpc.StatusCode.CANCELLED,
     (face.Abortion.Kind.CANCELLED, face.CancellationError)),
    (grpc.StatusCode.UNKNOWN,
     (face.Abortion.Kind.REMOTE_FAILURE, face.RemoteError)),
    (grpc.StatusCode.DEADLINE_EXCEEDED,
     (face.Abortion.Kind.EXPIRED, face.ExpirationError)),
    (grpc.StatusCode.UNIMPLEMENTED,
     (face.Abortion.Kind.LOCAL_FAILURE, face.LocalError)),
))
36
37
def _effective_metadata(metadata, metadata_transformer):
    """Returns the metadata to use for an RPC.

    Args:
      metadata: The caller-supplied metadata, or None.
      metadata_transformer: A callable taking the (never None) metadata and
        returning the metadata to use, or None to use the metadata as given.

    Returns:
      The metadata with None replaced by an empty tuple, passed through
      metadata_transformer when one is supplied.
    """
    base_metadata = metadata if metadata is not None else ()
    if metadata_transformer is not None:
        return metadata_transformer(base_metadata)
    return base_metadata
44
45
def _credentials(grpc_call_options):
    """Returns grpc_call_options.credentials, or None when no options are given."""
    if grpc_call_options is None:
        return None
    return grpc_call_options.credentials
48
49
def _abortion(rpc_error_call):
    """Builds the face.Abortion describing a terminated RPC.

    Args:
      rpc_error_call: An object offering code(), initial_metadata(),
        trailing_metadata() and details().

    Returns:
      A face.Abortion whose kind is looked up from the call's status code,
      falling back to LOCAL_FAILURE for codes without a table entry.
    """
    status_code = rpc_error_call.code()
    mapped = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(
        status_code)
    if mapped is None:
        abortion_kind = face.Abortion.Kind.LOCAL_FAILURE
    else:
        abortion_kind = mapped[0]
    return face.Abortion(
        abortion_kind,
        rpc_error_call.initial_metadata(),
        rpc_error_call.trailing_metadata(),
        status_code,
        rpc_error_call.details())
57
58
def _abortion_error(rpc_error_call):
    """Builds the Beta API exception corresponding to a failed RPC.

    Args:
      rpc_error_call: An object offering code(), initial_metadata(),
        trailing_metadata() and details().

    Returns:
      An instance of the exception class looked up from the call's status
      code, or of face.AbortionError for codes without a table entry. The
      exception is returned, not raised.
    """
    status_code = rpc_error_call.code()
    mapped = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(
        status_code)
    if mapped is None:
        error_class = face.AbortionError
    else:
        error_class = mapped[1]
    return error_class(
        rpc_error_call.initial_metadata(),
        rpc_error_call.trailing_metadata(),
        status_code,
        rpc_error_call.details())
66
67
class _InvocationProtocolContext(interfaces.GRPCInvocationContext):
    """Invocation-side protocol context; its only method is a no-op for now."""

    def disable_next_request_compression(self):
        """Does nothing at present and returns None."""
        # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
        return None
72
73
class _Rendezvous(future.Future, face.Call):
    """Presents a gRPC call through the Beta API's Future and Call interfaces.

    The three wrapped objects are populated differently by the callers in this
    module:
      * blocking invocations with a call: (None, None, call)
      * future invocations: (response_future, None, response_future)
      * response-streaming invocations:
        (None, response_iterator, response_iterator)

    Methods that delegate to an absent (None) object are therefore not usable
    for that kind of invocation.
    """

    def __init__(self, response_future, response_iterator, call):
        """Stores the wrapped objects.

        Args:
          response_future: The object backing the Future methods, or None.
          response_iterator: The object backing iteration, or None.
          call: The object backing the Call methods (cancel, is_active,
            time_remaining, metadata, code, details).
        """
        self._future = response_future
        self._iterator = response_iterator
        self._call = call

    def cancel(self):
        """Cancels the underlying call and returns its cancel() result."""
        return self._call.cancel()

    def cancelled(self):
        """Returns the wrapped future's cancelled() value."""
        return self._future.cancelled()

    def running(self):
        """Returns the wrapped future's running() value."""
        return self._future.running()

    def done(self):
        """Returns the wrapped future's done() value."""
        return self._future.done()

    def result(self, timeout=None):
        """Returns the wrapped future's result, translating gRPC exceptions.

        Raises:
          An error built by _abortion_error if the RPC raised grpc.RpcError.
          future.TimeoutError: On grpc.FutureTimeoutError.
          future.CancelledError: On grpc.FutureCancelledError.
        """
        try:
            return self._future.result(timeout=timeout)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def exception(self, timeout=None):
        """Returns the translated exception of the wrapped future, or None.

        Raises:
          future.TimeoutError: On grpc.FutureTimeoutError.
          future.CancelledError: On grpc.FutureCancelledError.
        """
        try:
            rpc_error_call = self._future.exception(timeout=timeout)
            if rpc_error_call is None:
                return None
            else:
                return _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def traceback(self, timeout=None):
        """Returns the wrapped future's traceback.

        Raises:
          future.TimeoutError: On grpc.FutureTimeoutError.
          future.CancelledError: On grpc.FutureCancelledError.
        """
        try:
            return self._future.traceback(timeout=timeout)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def add_done_callback(self, fn):
        """Registers fn to be called with this object when the future is done."""
        # The wrapped future passes itself to the callback; callers of this
        # adapter are handed the adapter instead.
        self._future.add_done_callback(lambda ignored_callback: fn(self))

    def __iter__(self):
        return self

    def _next(self):
        """Returns the next response, translating grpc.RpcError."""
        try:
            return next(self._iterator)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)

    def __next__(self):
        return self._next()

    def next(self):
        # Python 2 spelling of the iterator protocol.
        return self._next()

    def is_active(self):
        """Returns the underlying call's is_active() value."""
        return self._call.is_active()

    def time_remaining(self):
        """Returns the underlying call's time_remaining() value."""
        return self._call.time_remaining()

    def add_abortion_callback(self, abortion_callback):
        """Arranges for abortion_callback to receive a face.Abortion.

        The callback is only invoked when the call's code is not
        grpc.StatusCode.OK. If the underlying call declines the registration
        (add_callback returns a false value), the check runs immediately.
        """

        def done_callback():
            if self.code() is not grpc.StatusCode.OK:
                abortion_callback(_abortion(self._call))

        registered = self._call.add_callback(done_callback)
        return None if registered else done_callback()

    def protocol_context(self):
        """Returns a new _InvocationProtocolContext."""
        return _InvocationProtocolContext()

    def initial_metadata(self):
        """Returns the call's initial metadata converted by _metadata.beta."""
        return _metadata.beta(self._call.initial_metadata())

    def terminal_metadata(self):
        """Returns the call's trailing metadata converted by _metadata.beta."""
        # The wrapped gRPC call objects name this accessor trailing_metadata()
        # (the same accessor _abortion and _abortion_error use); they have no
        # terminal_metadata() method, which is the Beta API's name for it.
        return _metadata.beta(self._call.trailing_metadata())

    def code(self):
        """Returns the underlying call's status code."""
        return self._call.code()

    def details(self):
        """Returns the underlying call's details."""
        return self._call.details()
170
171
def _blocking_unary_unary(channel, group, method, timeout, with_call,
                          protocol_options, metadata, metadata_transformer,
                          request, request_serializer, response_deserializer):
    """Invokes a unary-unary method on channel and returns its outcome.

    Returns:
      The value returned by the multi-callable, or, when with_call is true,
      a (response, _Rendezvous) pair wrapping the call.

    Raises:
      The error built by _abortion_error when grpc.RpcError is raised.
    """
    try:
        multi_callable = channel.unary_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        transformed_metadata = _effective_metadata(metadata,
                                                   metadata_transformer)
        # Pick the entry point first so both variants share one invocation.
        invoke = multi_callable.with_call if with_call else multi_callable
        outcome = invoke(
            request,
            timeout=timeout,
            metadata=_metadata.unbeta(transformed_metadata),
            credentials=_credentials(protocol_options))
        if with_call:
            response, call = outcome
            return response, _Rendezvous(None, None, call)
        return outcome
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)
195
196
def _future_unary_unary(channel, group, method, timeout, protocol_options,
                        metadata, metadata_transformer, request,
                        request_serializer, response_deserializer):
    """Starts a unary-unary method via the multi-callable's future() entry.

    Returns:
      A _Rendezvous using the returned object as both its future and its call.
    """
    multi_callable = channel.unary_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    transformed_metadata = _effective_metadata(metadata, metadata_transformer)
    pending_response = multi_callable.future(
        request,
        timeout=timeout,
        metadata=_metadata.unbeta(transformed_metadata),
        credentials=_credentials(protocol_options))
    return _Rendezvous(pending_response, None, pending_response)
211
212
def _unary_stream(channel, group, method, timeout, protocol_options, metadata,
                  metadata_transformer, request, request_serializer,
                  response_deserializer):
    """Invokes a unary-stream method on channel.

    Returns:
      A _Rendezvous using the returned object as both its iterator and its
      call.
    """
    multi_callable = channel.unary_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    transformed_metadata = _effective_metadata(metadata, metadata_transformer)
    responses = multi_callable(
        request,
        timeout=timeout,
        metadata=_metadata.unbeta(transformed_metadata),
        credentials=_credentials(protocol_options))
    return _Rendezvous(None, responses, responses)
227
228
def _blocking_stream_unary(channel, group, method, timeout, with_call,
                           protocol_options, metadata, metadata_transformer,
                           request_iterator, request_serializer,
                           response_deserializer):
    """Invokes a stream-unary method on channel and returns its outcome.

    Returns:
      The value returned by the multi-callable, or, when with_call is true,
      a (response, _Rendezvous) pair wrapping the call.

    Raises:
      The error built by _abortion_error when grpc.RpcError is raised.
    """
    try:
        multi_callable = channel.stream_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        transformed_metadata = _effective_metadata(metadata,
                                                   metadata_transformer)
        # Pick the entry point first so both variants share one invocation.
        invoke = multi_callable.with_call if with_call else multi_callable
        outcome = invoke(
            request_iterator,
            timeout=timeout,
            metadata=_metadata.unbeta(transformed_metadata),
            credentials=_credentials(protocol_options))
        if with_call:
            response, call = outcome
            return response, _Rendezvous(None, None, call)
        return outcome
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)
253
254
def _future_stream_unary(channel, group, method, timeout, protocol_options,
                         metadata, metadata_transformer, request_iterator,
                         request_serializer, response_deserializer):
    """Starts a stream-unary method via the multi-callable's future() entry.

    Returns:
      A _Rendezvous using the returned object as both its future and its call.
    """
    multi_callable = channel.stream_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    transformed_metadata = _effective_metadata(metadata, metadata_transformer)
    pending_response = multi_callable.future(
        request_iterator,
        timeout=timeout,
        metadata=_metadata.unbeta(transformed_metadata),
        credentials=_credentials(protocol_options))
    return _Rendezvous(pending_response, None, pending_response)
269
270
def _stream_stream(channel, group, method, timeout, protocol_options, metadata,
                   metadata_transformer, request_iterator, request_serializer,
                   response_deserializer):
    """Invokes a stream-stream method on channel.

    Returns:
      A _Rendezvous using the returned object as both its iterator and its
      call.
    """
    multi_callable = channel.stream_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    transformed_metadata = _effective_metadata(metadata, metadata_transformer)
    responses = multi_callable(
        request_iterator,
        timeout=timeout,
        metadata=_metadata.unbeta(transformed_metadata),
        credentials=_credentials(protocol_options))
    return _Rendezvous(None, responses, responses)
285
286
class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
    """Beta API multi-callable for one unary-unary method of a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        """Records the channel, method identity and (de)serializers to use."""
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request,
                 timeout,
                 metadata=None,
                 with_call=False,
                 protocol_options=None):
        """Delegates to _blocking_unary_unary with the stored configuration."""
        return _blocking_unary_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            with_call=with_call,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def future(self, request, timeout, metadata=None, protocol_options=None):
        """Delegates to _future_unary_unary with the stored configuration."""
        return _future_unary_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(self,
              request,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()
325
326
class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
    """Beta API multi-callable for one unary-stream method of a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        """Records the channel, method identity and (de)serializers to use."""
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self, request, timeout, metadata=None, protocol_options=None):
        """Delegates to _unary_stream with the stored configuration."""
        return _unary_stream(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(self,
              request,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()
353
354
class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
    """Beta API multi-callable for one stream-unary method of a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        """Records the channel, method identity and (de)serializers to use."""
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request_iterator,
                 timeout,
                 metadata=None,
                 with_call=False,
                 protocol_options=None):
        """Delegates to _blocking_stream_unary with the stored configuration."""
        return _blocking_stream_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            with_call=with_call,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def future(self,
               request_iterator,
               timeout,
               metadata=None,
               protocol_options=None):
        """Delegates to _future_stream_unary with the stored configuration."""
        return _future_stream_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(self,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()
397
398
class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
    """Beta API multi-callable for one stream-stream method of a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        """Records the channel, method identity and (de)serializers to use."""
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request_iterator,
                 timeout,
                 metadata=None,
                 protocol_options=None):
        """Delegates to _stream_stream with the stored configuration."""
        return _stream_stream(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(self,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()
428
429
class _GenericStub(face.GenericStub):
    """Beta API generic stub that issues RPCs over a wrapped channel.

    Serializers and deserializers are looked up per (group, method) pair;
    a method without an entry gets None for the missing behavior.
    """

    def __init__(self, channel, metadata_transformer, request_serializers,
                 response_deserializers):
        """Stores the channel and per-method configuration.

        Args:
          channel: The channel used to create multi-callables.
          metadata_transformer: Passed through to every invocation; may be
            None.
          request_serializers: Mapping keyed by (group, method); a false
            value is replaced by an empty dict.
          response_deserializers: Mapping keyed by (group, method); a false
            value is replaced by an empty dict.
        """
        self._channel = channel
        self._metadata_transformer = metadata_transformer
        self._request_serializers = request_serializers or {}
        self._response_deserializers = response_deserializers or {}

    def _serialization_behaviors(self, group, method):
        """Returns (request_serializer, response_deserializer) for a method.

        Either element is None when nothing is registered under
        (group, method).
        """
        key = (group, method)
        return (self._request_serializers.get(key),
                self._response_deserializers.get(key))

    def blocking_unary_unary(self,
                             group,
                             method,
                             request,
                             timeout,
                             metadata=None,
                             with_call=None,
                             protocol_options=None):
        """Delegates to _blocking_unary_unary for the named method."""
        request_serializer, response_deserializer = (
            self._serialization_behaviors(group, method))
        return _blocking_unary_unary(
            channel=self._channel,
            group=group,
            method=method,
            timeout=timeout,
            with_call=with_call,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)

    def future_unary_unary(self,
                           group,
                           method,
                           request,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        """Delegates to _future_unary_unary for the named method."""
        request_serializer, response_deserializer = (
            self._serialization_behaviors(group, method))
        return _future_unary_unary(
            channel=self._channel,
            group=group,
            method=method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)

    def inline_unary_stream(self,
                            group,
                            method,
                            request,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        """Delegates to _unary_stream for the named method."""
        request_serializer, response_deserializer = (
            self._serialization_behaviors(group, method))
        return _unary_stream(
            channel=self._channel,
            group=group,
            method=method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)

    def blocking_stream_unary(self,
                              group,
                              method,
                              request_iterator,
                              timeout,
                              metadata=None,
                              with_call=None,
                              protocol_options=None):
        """Delegates to _blocking_stream_unary for the named method."""
        request_serializer, response_deserializer = (
            self._serialization_behaviors(group, method))
        return _blocking_stream_unary(
            channel=self._channel,
            group=group,
            method=method,
            timeout=timeout,
            with_call=with_call,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)

    def future_stream_unary(self,
                            group,
                            method,
                            request_iterator,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        """Delegates to _future_stream_unary for the named method."""
        request_serializer, response_deserializer = (
            self._serialization_behaviors(group, method))
        return _future_stream_unary(
            channel=self._channel,
            group=group,
            method=method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)

    def inline_stream_stream(self,
                             group,
                             method,
                             request_iterator,
                             timeout,
                             metadata=None,
                             protocol_options=None):
        """Delegates to _stream_stream for the named method."""
        request_serializer, response_deserializer = (
            self._serialization_behaviors(group, method))
        return _stream_stream(
            channel=self._channel,
            group=group,
            method=method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)

    def event_unary_unary(self,
                          group,
                          method,
                          request,
                          receiver,
                          abortion_callback,
                          timeout,
                          metadata=None,
                          protocol_options=None):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()

    def event_unary_stream(self,
                           group,
                           method,
                           request,
                           receiver,
                           abortion_callback,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()

    def event_stream_unary(self,
                           group,
                           method,
                           receiver,
                           abortion_callback,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()

    def event_stream_stream(self,
                            group,
                            method,
                            receiver,
                            abortion_callback,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()

    def unary_unary(self, group, method):
        """Returns a _UnaryUnaryMultiCallable for the named method."""
        request_serializer, response_deserializer = (
            self._serialization_behaviors(group, method))
        return _UnaryUnaryMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def unary_stream(self, group, method):
        """Returns a _UnaryStreamMultiCallable for the named method."""
        request_serializer, response_deserializer = (
            self._serialization_behaviors(group, method))
        return _UnaryStreamMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def stream_unary(self, group, method):
        """Returns a _StreamUnaryMultiCallable for the named method."""
        request_serializer, response_deserializer = (
            self._serialization_behaviors(group, method))
        return _StreamUnaryMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def stream_stream(self, group, method):
        """Returns a _StreamStreamMultiCallable for the named method."""
        request_serializer, response_deserializer = (
            self._serialization_behaviors(group, method))
        return _StreamStreamMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning False lets any exception from the with-body propagate.
        return False
666
667
class _DynamicStub(face.DynamicStub):
    """Stub exposing the methods of one group as attributes.

    Attribute access is resolved through the cardinalities mapping and served
    by the matching multi-callable factory of the backing generic stub.
    """

    def __init__(self, backing_generic_stub, group, cardinalities):
        """Stores the backing stub, the group name and the cardinalities.

        Args:
          backing_generic_stub: Object offering unary_unary, unary_stream,
            stream_unary and stream_stream factories taking (group, method).
          group: The group passed to those factories.
          cardinalities: Mapping from method name to a
            cardinality.Cardinality member.
        """
        self._generic_stub = backing_generic_stub
        self._group = group
        self._cardinalities = cardinalities

    def __getattr__(self, attr):
        """Returns the multi-callable for the method named attr.

        Raises:
          AttributeError: If attr has none of the four known cardinalities.
        """
        kind = self._cardinalities.get(attr)
        if kind is cardinality.Cardinality.UNARY_UNARY:
            return self._generic_stub.unary_unary(self._group, attr)
        if kind is cardinality.Cardinality.UNARY_STREAM:
            return self._generic_stub.unary_stream(self._group, attr)
        if kind is cardinality.Cardinality.STREAM_UNARY:
            return self._generic_stub.stream_unary(self._group, attr)
        if kind is cardinality.Cardinality.STREAM_STREAM:
            return self._generic_stub.stream_stream(self._group, attr)
        raise AttributeError('_DynamicStub object has no attribute "%s"!' %
                             attr)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning False lets any exception from the with-body propagate.
        return False
694
695
def generic_stub(channel, host, metadata_transformer, request_serializers,
                 response_deserializers):
    """Creates a face.GenericStub implementation over channel.

    Args:
      channel: The channel the stub issues RPCs on.
      host: Accepted but not used by this function.
      metadata_transformer: Callable applied to invocation metadata, or None.
      request_serializers: Mapping keyed by (group, method), or None.
      response_deserializers: Mapping keyed by (group, method), or None.

    Returns:
      A _GenericStub.
    """
    stub = _GenericStub(
        channel,
        metadata_transformer,
        request_serializers,
        response_deserializers)
    return stub
700
701
def dynamic_stub(channel, service, cardinalities, host, metadata_transformer,
                 request_serializers, response_deserializers):
    """Creates a face.DynamicStub implementation over channel.

    Args:
      channel: The channel the stub issues RPCs on.
      service: The group name used for every method of the stub.
      cardinalities: Mapping from method name to cardinality.
      host: Accepted but not used by this function.
      metadata_transformer: Callable applied to invocation metadata, or None.
      request_serializers: Mapping keyed by (group, method), or None.
      response_deserializers: Mapping keyed by (group, method), or None.

    Returns:
      A _DynamicStub backed by a new _GenericStub.
    """
    backing_stub = _GenericStub(
        channel,
        metadata_transformer,
        request_serializers,
        response_deserializers)
    return _DynamicStub(backing_stub, service, cardinalities)
707