• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implementation of gRPC Python interceptors."""
15
16import collections
17import sys
18
19import grpc
20
21
class _ServicePipeline(object):
    """Runs a fixed sequence of interceptors in front of a final thunk.

    Each interceptor must expose ``intercept_service(continuation, context)``.
    The continuation it receives resumes the pipeline at the next
    interceptor; once every interceptor has been visited, the original
    thunk is called with the context.
    """

    def __init__(self, interceptors):
        # Snapshot the iterable so later mutation by the caller has no effect.
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk, index):
        """Return a one-argument callable that resumes the chain at `index`."""

        def proceed(context):
            return self._intercept_at(thunk, index, context)

        return proceed

    def _intercept_at(self, thunk, index, context):
        """Invoke the interceptor at `index`, or `thunk` when past the end."""
        if index >= len(self.interceptors):
            return thunk(context)
        current = self.interceptors[index]
        resume = self._continuation(thunk, index + 1)
        return current.intercept_service(resume, context)

    def execute(self, thunk, context):
        """Run the whole chain from the first interceptor."""
        return self._intercept_at(thunk, 0, context)
40
41
def service_pipeline(interceptors):
    """Return a _ServicePipeline for `interceptors`, or None if it is falsy."""
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)
44
45
class _ClientCallDetails(
        collections.namedtuple(
            '_ClientCallDetails',
            ('method', 'timeout', 'metadata', 'credentials',
             'wait_for_ready', 'compression')),
        grpc.ClientCallDetails):
    """Immutable record of the per-call options handed to interceptors.

    Fields, in positional order: method, timeout, metadata, credentials,
    wait_for_ready, compression.
    """
52
53
def _unwrap_client_call_details(call_details, default_details):
    """Flatten call details into a tuple, filling gaps from the defaults.

    Args:
      call_details: Object that may define any subset of the attributes
        method, timeout, metadata, credentials, wait_for_ready and
        compression.
      default_details: Object consulted for each attribute that
        `call_details` does not provide. It is only read for attributes
        that are actually missing.

    Returns:
      A 6-tuple (method, timeout, metadata, credentials, wait_for_ready,
      compression).
    """
    field_names = ('method', 'timeout', 'metadata', 'credentials',
                   'wait_for_ready', 'compression')
    # A private sentinel is needed because None is a legitimate value for
    # any of these fields.
    absent = object()
    resolved = []
    for field_name in field_names:
        value = getattr(call_details, field_name, absent)
        if value is absent:
            value = getattr(default_details, field_name)
        resolved.append(value)
    return tuple(resolved)
86
87
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call):  # pylint: disable=too-many-ancestors
    """Already-completed, failed outcome that carries a captured exception.

    The object answers every status query with fixed "finished and failed"
    values, re-raises the stored exception from `result()` and from
    iteration, and reports `grpc.StatusCode.INTERNAL` as its code.
    """

    def __init__(self, exception, traceback):
        super(_FailureOutcome, self).__init__()
        self._exception = exception
        self._traceback = traceback

    # --- Status and metadata: fixed values, there is no underlying call. ---

    def code(self):
        return grpc.StatusCode.INTERNAL

    def details(self):
        return 'Exception raised while intercepting the RPC'

    def initial_metadata(self):
        return None

    def trailing_metadata(self):
        return None

    # --- Liveness and cancellation: nothing is active, nothing to cancel. ---

    def is_active(self):
        return False

    def time_remaining(self):
        return None

    def cancel(self):
        return False

    def cancelled(self):
        return False

    def add_callback(self, unused_callback):
        return False

    # --- Future protocol: always done, always failed. ---

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        raise self._exception

    def exception(self, ignored_timeout=None):
        return self._exception

    def traceback(self, ignored_timeout=None):
        return self._traceback

    def add_done_callback(self, fn):
        # Already done, so the callback runs immediately on the caller.
        fn(self)

    # --- Iterator protocol: the first step raises the stored exception. ---

    def __iter__(self):
        return self

    def __next__(self):
        raise self._exception

    def next(self):
        return self.__next__()
148
149
class _UnaryOutcome(grpc.Call, grpc.Future):
    """Already-completed, successful outcome pairing a response with its call.

    Status, metadata, liveness and cancellation queries are forwarded to
    the wrapped call; the future-style queries report a finished,
    exception-free result.
    """

    def __init__(self, response, call):
        self._response = response
        self._call = call

    # --- Forwarded to the wrapped call. ---

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()

    def initial_metadata(self):
        return self._call.initial_metadata()

    def trailing_metadata(self):
        return self._call.trailing_metadata()

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def cancel(self):
        return self._call.cancel()

    def add_callback(self, callback):
        return self._call.add_callback(callback)

    # --- Future protocol: always done, never failed. ---

    def cancelled(self):
        return False

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        return self._response

    def exception(self, ignored_timeout=None):
        return None

    def traceback(self, ignored_timeout=None):
        return None

    def add_done_callback(self, fn):
        # Already done, so the callback runs immediately on the caller.
        fn(self)
200
201
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable that routes each call via an interceptor.

    `thunk` maps a method name to the underlying multi-callable, `method`
    is the default method name, and `interceptor` provides
    `intercept_unary_unary(continuation, client_call_details, request)`.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        """Invoke the RPC through the interceptor; return only the response."""
        response, _ = self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression)
        return response

    def _with_call(self,
                   request,
                   timeout=None,
                   metadata=None,
                   credentials=None,
                   wait_for_ready=None,
                   compression=None):
        """Invoke the RPC through the interceptor; return (response, call)."""
        original_details = _ClientCallDetails(
            self._method, timeout, metadata, credentials, wait_for_ready,
            compression)

        def continuation(new_details, request):
            # Anything the interceptor left unspecified falls back to the
            # details of the original invocation.
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, original_details)
            try:
                response, call = self._thunk(target_method).with_call(
                    request,
                    timeout=call_timeout,
                    metadata=call_metadata,
                    credentials=call_credentials,
                    wait_for_ready=call_wait_for_ready,
                    compression=call_compression)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Returned, not raised, so the interceptor can inspect it.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_unary_unary(
            continuation, original_details, request)
        return outcome.result(), outcome

    def with_call(self,
                  request,
                  timeout=None,
                  metadata=None,
                  credentials=None,
                  wait_for_ready=None,
                  compression=None):
        """Invoke the RPC through the interceptor; return (response, call)."""
        return self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression)

    def future(self,
               request,
               timeout=None,
               metadata=None,
               credentials=None,
               wait_for_ready=None,
               compression=None):
        """Start the RPC through the interceptor; return its future.

        An exception escaping the interceptor is returned wrapped in a
        _FailureOutcome rather than raised.
        """
        original_details = _ClientCallDetails(
            self._method, timeout, metadata, credentials, wait_for_ready,
            compression)

        def continuation(new_details, request):
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, original_details)
            return self._thunk(target_method).future(
                request,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials,
                wait_for_ready=call_wait_for_ready,
                compression=call_compression)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, original_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
302
303
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable that routes each call via an interceptor.

    `thunk` maps a method name to the underlying multi-callable, `method`
    is the default method name, and `interceptor` provides
    `intercept_unary_stream(continuation, client_call_details, request)`.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        """Invoke the RPC through the interceptor and return its result.

        An exception escaping the interceptor is returned wrapped in a
        _FailureOutcome rather than raised.
        """
        original_details = _ClientCallDetails(
            self._method, timeout, metadata, credentials, wait_for_ready,
            compression)

        def continuation(new_details, request):
            # Anything the interceptor left unspecified falls back to the
            # details of the original invocation.
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, original_details)
            multi_callable = self._thunk(target_method)
            return multi_callable(
                request,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials,
                wait_for_ready=call_wait_for_ready,
                compression=call_compression)

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, original_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
339
340
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable that routes each call via an interceptor.

    `thunk` maps a method name to the underlying multi-callable, `method`
    is the default method name, and `interceptor` provides
    `intercept_stream_unary(continuation, client_call_details,
    request_iterator)`.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        """Invoke the RPC through the interceptor; return only the response."""
        response, _ = self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression)
        return response

    def _with_call(self,
                   request_iterator,
                   timeout=None,
                   metadata=None,
                   credentials=None,
                   wait_for_ready=None,
                   compression=None):
        """Invoke the RPC through the interceptor; return (response, call)."""
        original_details = _ClientCallDetails(
            self._method, timeout, metadata, credentials, wait_for_ready,
            compression)

        def continuation(new_details, request_iterator):
            # Anything the interceptor left unspecified falls back to the
            # details of the original invocation.
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, original_details)
            try:
                response, call = self._thunk(target_method).with_call(
                    request_iterator,
                    timeout=call_timeout,
                    metadata=call_metadata,
                    credentials=call_credentials,
                    wait_for_ready=call_wait_for_ready,
                    compression=call_compression)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Returned, not raised, so the interceptor can inspect it.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_stream_unary(
            continuation, original_details, request_iterator)
        return outcome.result(), outcome

    def with_call(self,
                  request_iterator,
                  timeout=None,
                  metadata=None,
                  credentials=None,
                  wait_for_ready=None,
                  compression=None):
        """Invoke the RPC through the interceptor; return (response, call)."""
        return self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression)

    def future(self,
               request_iterator,
               timeout=None,
               metadata=None,
               credentials=None,
               wait_for_ready=None,
               compression=None):
        """Start the RPC through the interceptor; return its future.

        An exception escaping the interceptor is returned wrapped in a
        _FailureOutcome rather than raised.
        """
        original_details = _ClientCallDetails(
            self._method, timeout, metadata, credentials, wait_for_ready,
            compression)

        def continuation(new_details, request_iterator):
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, original_details)
            return self._thunk(target_method).future(
                request_iterator,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials,
                wait_for_ready=call_wait_for_ready,
                compression=call_compression)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, original_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
441
442
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable that routes each call via an interceptor.

    `thunk` maps a method name to the underlying multi-callable, `method`
    is the default method name, and `interceptor` provides
    `intercept_stream_stream(continuation, client_call_details,
    request_iterator)`.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        """Invoke the RPC through the interceptor and return its result.

        An exception escaping the interceptor is returned wrapped in a
        _FailureOutcome rather than raised.
        """
        original_details = _ClientCallDetails(
            self._method, timeout, metadata, credentials, wait_for_ready,
            compression)

        def continuation(new_details, request_iterator):
            # Anything the interceptor left unspecified falls back to the
            # details of the original invocation.
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, original_details)
            multi_callable = self._thunk(target_method)
            return multi_callable(
                request_iterator,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials,
                wait_for_ready=call_wait_for_ready,
                compression=call_compression)

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, original_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
478
479
class _Channel(grpc.Channel):
    """Channel wrapper that applies a single interceptor to a wrapped channel.

    For each RPC arity, the interceptor is applied only when it is an
    instance of the matching client-interceptor type; otherwise the
    wrapped channel's multi-callable is returned unchanged.
    """

    def __init__(self, channel, interceptor):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(self, callback, try_to_connect=False):
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback):
        self._channel.unsubscribe(callback)

    def unary_unary(self,
                    method,
                    request_serializer=None,
                    response_deserializer=None):

        def thunk(m):
            return self._channel.unary_unary(m, request_serializer,
                                             response_deserializer)

        if not isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
            return thunk(method)
        return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)

    def unary_stream(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):

        def thunk(m):
            return self._channel.unary_stream(m, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.UnaryStreamClientInterceptor):
            return thunk(method)
        return _UnaryStreamMultiCallable(thunk, method, self._interceptor)

    def stream_unary(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):

        def thunk(m):
            return self._channel.stream_unary(m, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamUnaryClientInterceptor):
            return thunk(method)
        return _StreamUnaryMultiCallable(thunk, method, self._interceptor)

    def stream_stream(self,
                      method,
                      request_serializer=None,
                      response_deserializer=None):

        def thunk(m):
            return self._channel.stream_stream(m, request_serializer,
                                               response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamStreamClientInterceptor):
            return thunk(method)
        return _StreamStreamMultiCallable(thunk, method, self._interceptor)

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._close()
        # Returning False lets any in-flight exception propagate.
        return False

    def close(self):
        self._channel.close()
548
549
def intercept_channel(channel, *interceptors):
    """Wrap `channel` so that calls pass through the given interceptors.

    Args:
      channel: The channel to wrap.
      *interceptors: Zero or more client interceptors. Each must be an
        instance of at least one of grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor
        or grpc.StreamStreamClientInterceptor.

    Returns:
      `channel` itself when no interceptors are given; otherwise a _Channel
      wrapping it once per interceptor.

    Raises:
      TypeError: If an interceptor is not an instance of any of the four
        client-interceptor types.
    """
    # Wrap starting from the last interceptor so that the first one supplied
    # ends up as the outermost wrapper.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, (grpc.UnaryUnaryClientInterceptor,
                                        grpc.UnaryStreamClientInterceptor,
                                        grpc.StreamUnaryClientInterceptor,
                                        grpc.StreamStreamClientInterceptor)):
            raise TypeError('interceptor must be '
                            'grpc.UnaryUnaryClientInterceptor or '
                            'grpc.UnaryStreamClientInterceptor or '
                            'grpc.StreamUnaryClientInterceptor or '
                            'grpc.StreamStreamClientInterceptor')
        channel = _Channel(channel, interceptor)
    return channel
563