• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Implementation of gRPC Python interceptors."""
15
16import collections
17import sys
18
19import grpc
20
21
class _ServicePipeline(object):
    """Chains objects exposing ``intercept_service`` around a final thunk.

    Each interceptor receives a continuation that resumes the chain at the
    next interceptor. Once every interceptor has been passed, the original
    ``thunk`` is called with the context.
    """

    def __init__(self, interceptors):
        # Snapshot into a tuple so later changes to the caller's sequence
        # (or exhaustion of an iterator) cannot alter the chain.
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk, index):
        """Return a one-argument callable that resumes the chain at ``index``."""

        def resume(context):
            return self._intercept_at(thunk, index, context)

        return resume

    def _intercept_at(self, thunk, index, context):
        """Run the interceptor at ``index``, or ``thunk`` once past the end."""
        if index >= len(self.interceptors):
            return thunk(context)
        next_step = self._continuation(thunk, index + 1)
        return self.interceptors[index].intercept_service(next_step, context)

    def execute(self, thunk, context):
        """Pass ``context`` through every interceptor, starting with the first.

        Returns whatever the first interceptor's ``intercept_service`` returns,
        or ``thunk(context)`` directly when there are no interceptors.
        """
        return self._intercept_at(thunk, 0, context)
40
41
def service_pipeline(interceptors):
    """Build an interceptor pipeline, or return None when there is nothing to run.

    Args:
      interceptors: A possibly empty (or None) sequence of objects exposing
        ``intercept_service``.

    Returns:
      A _ServicePipeline over ``interceptors`` when it is truthy, else None.
    """
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)
44
45
class _ClientCallDetails(
        collections.namedtuple('_ClientCallDetails',
                               'method timeout metadata credentials'),
        grpc.ClientCallDetails):
    """Immutable record of the per-call options passed to client interceptors.

    Fields, in order: ``method``, ``timeout``, ``metadata``, ``credentials``.
    """
52
53
def _unwrap_client_call_details(call_details, default_details):
    """Resolve the effective call options after an interceptor has run.

    An interceptor may pass any object to the continuation as its call
    details. Every field that cannot be read from it (reading raises
    AttributeError) is taken from ``default_details`` instead.

    Args:
      call_details: The details object handed to the continuation.
      default_details: The details the RPC was originally invoked with.

    Returns:
      A ``(method, timeout, metadata, credentials)`` tuple.
    """

    def resolve(field):
        try:
            return getattr(call_details, field)
        except AttributeError:
            return getattr(default_details, field)

    return (resolve('method'), resolve('timeout'), resolve('metadata'),
            resolve('credentials'))
76
77
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call):
    """Already-failed Call-Future returned when intercepting an RPC raises.

    Reports status INTERNAL with no metadata. ``result()`` and iteration
    re-raise the captured exception; ``exception()`` and ``traceback()``
    return it. Because it is iterable, it can also stand in for the response
    iterator of a streaming-response call.

    Args:
      exception: The exception raised while intercepting the RPC.
      traceback: The traceback captured alongside ``exception``.
    """

    def __init__(self, exception, traceback):
        super(_FailureOutcome, self).__init__()
        self._exception = exception
        self._traceback = traceback

    def initial_metadata(self):
        return None

    def trailing_metadata(self):
        return None

    def code(self):
        return grpc.StatusCode.INTERNAL

    def details(self):
        return 'Exception raised while intercepting the RPC'

    def cancel(self):
        return False

    def cancelled(self):
        return False

    def is_active(self):
        return False

    def time_remaining(self):
        return None

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        raise self._exception

    def exception(self, ignored_timeout=None):
        return self._exception

    def traceback(self, ignored_timeout=None):
        return self._traceback

    def add_callback(self, callback):
        # The RPC is already over; report the callback as not registered.
        return False

    def add_done_callback(self, fn):
        # Already done, so the callback is invoked immediately.
        fn(self)

    def __iter__(self):
        return self

    def __next__(self):
        # Python 3 iterator protocol. Without this method, iter() on this
        # object raises TypeError instead of surfacing the captured exception.
        raise self._exception

    def next(self):
        # Python 2 iterator protocol.
        return self.__next__()
135
136
class _UnaryOutcome(grpc.Call, grpc.Future):
    """Completed Call-Future for a unary-response RPC whose call returned normally.

    The Future half reports the stored response as an already-finished
    result. The Call half forwards every query to the underlying call object.

    Args:
      response: The response message returned by the RPC.
      call: The underlying call object the RPC returned with the response.
    """

    def __init__(self, response, call):
        self._response = response
        self._call = call

    # --- grpc.Future: the RPC has already finished without an exception. ---

    def cancelled(self):
        return False

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        return self._response

    def exception(self, ignored_timeout=None):
        return None

    def traceback(self, ignored_timeout=None):
        return None

    def add_done_callback(self, fn):
        # Already done: run the callback before returning.
        fn(self)

    # --- grpc.Call: delegate to the underlying call object. ---

    def initial_metadata(self):
        return self._call.initial_metadata()

    def trailing_metadata(self):
        return self._call.trailing_metadata()

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def cancel(self):
        return self._call.cancel()

    def add_callback(self, callback):
        return self._call.add_callback(callback)
187
188
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable that routes every call through an interceptor.

    ``thunk`` maps a method name to the underlying channel's multi-callable.
    It is invoked inside the continuation with whichever method the
    interceptor's call details carry.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        response, ignored_call = self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)
        return response

    def _with_call(self, request, timeout=None, metadata=None,
                   credentials=None):
        """Run the RPC through the interceptor and return (response, call).

        Raises:
          Whatever ``result()`` on the interceptor's returned outcome raises
          (for a failed RPC, its RpcError), and anything the interceptor
          itself raises.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials)

        def continuation(new_details, request):
            new_method, new_timeout, new_metadata, new_credentials = (
                _unwrap_client_call_details(new_details, client_call_details))
            try:
                response, call = self._thunk(new_method).with_call(
                    request,
                    timeout=new_timeout,
                    metadata=new_metadata,
                    credentials=new_credentials)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Interceptors are promised a Call-Future whose exception is
                # the RpcError, so hand the error back instead of raising it.
                # This relies on the error also being a grpc.Future (assumed
                # true for errors raised by the underlying channel). Any other
                # RpcError is re-raised as before.
                if isinstance(rpc_error, grpc.Future):
                    return rpc_error
                raise
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        call = self._interceptor.intercept_unary_unary(
            continuation, client_call_details, request)
        return call.result(), call

    def with_call(self, request, timeout=None, metadata=None, credentials=None):
        return self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)

    def future(self, request, timeout=None, metadata=None, credentials=None):
        """Start the RPC asynchronously through the interceptor.

        If the interceptor raises, a _FailureOutcome carrying that exception
        is returned instead of propagating it.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials)

        def continuation(new_details, request):
            new_method, new_timeout, new_metadata, new_credentials = (
                _unwrap_client_call_details(new_details, client_call_details))
            return self._thunk(new_method).future(
                request,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, client_call_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
253
254
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable that routes each call through an interceptor.

    If the interceptor raises, the call returns a _FailureOutcome carrying
    that exception instead of propagating it.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        original_details = _ClientCallDetails(self._method, timeout, metadata,
                                              credentials)

        def continuation(new_details, request):
            # Any field the interceptor did not supply falls back to the original.
            target, call_timeout, call_metadata, call_credentials = (
                _unwrap_client_call_details(new_details, original_details))
            underlying = self._thunk(target)
            return underlying(request,
                              timeout=call_timeout,
                              metadata=call_metadata,
                              credentials=call_credentials)

        try:
            outcome = self._interceptor.intercept_unary_stream(
                continuation, original_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
        return outcome
280
281
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable that routes every call through an interceptor.

    ``thunk`` maps a method name to the underlying channel's multi-callable.
    It is invoked inside the continuation with whichever method the
    interceptor's call details carry.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None):
        response, ignored_call = self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)
        return response

    def _with_call(self,
                   request_iterator,
                   timeout=None,
                   metadata=None,
                   credentials=None):
        """Run the RPC through the interceptor and return (response, call).

        Raises:
          Whatever ``result()`` on the interceptor's returned outcome raises
          (for a failed RPC, its RpcError), and anything the interceptor
          itself raises.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials)

        def continuation(new_details, request_iterator):
            new_method, new_timeout, new_metadata, new_credentials = (
                _unwrap_client_call_details(new_details, client_call_details))
            try:
                response, call = self._thunk(new_method).with_call(
                    request_iterator,
                    timeout=new_timeout,
                    metadata=new_metadata,
                    credentials=new_credentials)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Interceptors are promised a Call-Future whose exception is
                # the RpcError, so hand the error back instead of raising it.
                # This relies on the error also being a grpc.Future (assumed
                # true for errors raised by the underlying channel). Any other
                # RpcError is re-raised as before.
                if isinstance(rpc_error, grpc.Future):
                    return rpc_error
                raise
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        call = self._interceptor.intercept_stream_unary(
            continuation, client_call_details, request_iterator)
        return call.result(), call

    def with_call(self,
                  request_iterator,
                  timeout=None,
                  metadata=None,
                  credentials=None):
        return self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)

    def future(self,
               request_iterator,
               timeout=None,
               metadata=None,
               credentials=None):
        """Start the RPC asynchronously through the interceptor.

        If the interceptor raises, a _FailureOutcome carrying that exception
        is returned instead of propagating it.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials)

        def continuation(new_details, request_iterator):
            new_method, new_timeout, new_metadata, new_credentials = (
                _unwrap_client_call_details(new_details, client_call_details))
            return self._thunk(new_method).future(
                request_iterator,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, client_call_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
361
362
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable that routes each call through an interceptor.

    If the interceptor raises, the call returns a _FailureOutcome carrying
    that exception instead of propagating it.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None):
        original_details = _ClientCallDetails(self._method, timeout, metadata,
                                              credentials)

        def continuation(new_details, request_iterator):
            # Any field the interceptor did not supply falls back to the original.
            target, call_timeout, call_metadata, call_credentials = (
                _unwrap_client_call_details(new_details, original_details))
            underlying = self._thunk(target)
            return underlying(request_iterator,
                              timeout=call_timeout,
                              metadata=call_metadata,
                              credentials=call_credentials)

        try:
            outcome = self._interceptor.intercept_stream_stream(
                continuation, original_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
        return outcome
392
393
class _Channel(grpc.Channel):
    """grpc.Channel wrapper that applies a single client interceptor.

    For each RPC arity, the multi-callable is wrapped only when the
    interceptor is an instance of the matching grpc.*ClientInterceptor class.
    Otherwise the wrapped channel's multi-callable is returned unchanged.
    Connectivity subscription and closing are forwarded to the wrapped channel.
    """

    def __init__(self, channel, interceptor):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(self, callback, try_to_connect=False):
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback):
        self._channel.unsubscribe(callback)

    def unary_unary(self,
                    method,
                    request_serializer=None,
                    response_deserializer=None):

        def factory(target_method):
            # Build the underlying multi-callable for a (possibly rewritten)
            # method name.
            return self._channel.unary_unary(
                target_method, request_serializer, response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.UnaryUnaryClientInterceptor):
            return factory(method)
        return _UnaryUnaryMultiCallable(factory, method, self._interceptor)

    def unary_stream(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):

        def factory(target_method):
            return self._channel.unary_stream(
                target_method, request_serializer, response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.UnaryStreamClientInterceptor):
            return factory(method)
        return _UnaryStreamMultiCallable(factory, method, self._interceptor)

    def stream_unary(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):

        def factory(target_method):
            return self._channel.stream_unary(
                target_method, request_serializer, response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamUnaryClientInterceptor):
            return factory(method)
        return _StreamUnaryMultiCallable(factory, method, self._interceptor)

    def stream_stream(self,
                      method,
                      request_serializer=None,
                      response_deserializer=None):

        def factory(target_method):
            return self._channel.stream_stream(
                target_method, request_serializer, response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamStreamClientInterceptor):
            return factory(method)
        return _StreamStreamMultiCallable(factory, method, self._interceptor)

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._close()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def close(self):
        self._close()
458
459
def intercept_channel(channel, *interceptors):
    """Wrap ``channel`` so that its RPCs pass through ``interceptors``.

    Interceptors are applied so that the first one given is the outermost:
    it sees each RPC first and its continuation leads to the next one.

    Args:
      channel: The channel to wrap.
      *interceptors: Zero or more objects, each an instance of at least one
        of grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor or grpc.StreamStreamClientInterceptor.

    Returns:
      The wrapped channel, or ``channel`` itself when no interceptors are given.

    Raises:
      TypeError: If an interceptor implements none of those interfaces.
    """
    # Wrap innermost-first so the first interceptor ends up outermost.
    # ``interceptors`` is already a tuple, so it can be reversed directly.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, (grpc.UnaryUnaryClientInterceptor,
                                        grpc.UnaryStreamClientInterceptor,
                                        grpc.StreamUnaryClientInterceptor,
                                        grpc.StreamStreamClientInterceptor)):
            raise TypeError('interceptor must be '
                            'grpc.UnaryUnaryClientInterceptor or '
                            'grpc.UnaryStreamClientInterceptor or '
                            'grpc.StreamUnaryClientInterceptor or '
                            'grpc.StreamStreamClientInterceptor')
        channel = _Channel(channel, interceptor)
    return channel
473