# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of gRPC Python interceptors."""

import collections
import sys

import grpc

# The client-interceptor interfaces accepted by intercept_channel; an
# interceptor must implement at least one of them.
_CLIENT_INTERCEPTOR_TYPES = (
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
)


class _ServicePipeline(object):
    """Chains interceptors' ``intercept_service`` calls around a final thunk.

    Interceptors are invoked in the order given; each receives a
    continuation that resumes the chain at the next interceptor, and the
    last continuation calls the original thunk.
    """

    def __init__(self, interceptors):
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk, index):
        # A callable that resumes the chain at position ``index``.
        return lambda context: self._intercept_at(thunk, index, context)

    def _intercept_at(self, thunk, index, context):
        if index < len(self.interceptors):
            interceptor = self.interceptors[index]
            thunk = self._continuation(thunk, index + 1)
            return interceptor.intercept_service(thunk, context)
        else:
            return thunk(context)

    def execute(self, thunk, context):
        """Runs ``context`` through every interceptor, ending in ``thunk``."""
        return self._intercept_at(thunk, 0, context)


def service_pipeline(interceptors):
    """Returns a _ServicePipeline over ``interceptors``, or None if empty."""
    return _ServicePipeline(interceptors) if interceptors else None


class _ClientCallDetails(
        collections.namedtuple(
            '_ClientCallDetails',
            ('method', 'timeout', 'metadata', 'credentials')),
        grpc.ClientCallDetails):
    """Immutable grpc.ClientCallDetails built from a namedtuple."""
    pass


def _unwrap_client_call_details(call_details, default_details):
    """Extracts (method, timeout, metadata, credentials) from call details.

    Any attribute missing from ``call_details`` (an object supplied by an
    interceptor) is taken from ``default_details`` instead.

    Returns:
      A 4-tuple ``(method, timeout, metadata, credentials)``.
    """
    try:
        method = call_details.method
    except AttributeError:
        method = default_details.method

    try:
        timeout = call_details.timeout
    except AttributeError:
        timeout = default_details.timeout

    try:
        metadata = call_details.metadata
    except AttributeError:
        metadata = default_details.metadata

    try:
        credentials = call_details.credentials
    except AttributeError:
        credentials = default_details.credentials

    return method, timeout, metadata, credentials


class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call):
    """An already-completed, failed RPC outcome wrapping a local exception.

    Produced when an exception is raised while intercepting an RPC.
    ``code()`` reports INTERNAL, ``result()`` and iteration re-raise the
    captured exception, and the object is always done and never cancelled.
    """

    def __init__(self, exception, traceback):
        super(_FailureOutcome, self).__init__()
        self._exception = exception
        self._traceback = traceback

    def initial_metadata(self):
        return None

    def trailing_metadata(self):
        return None

    def code(self):
        return grpc.StatusCode.INTERNAL

    def details(self):
        return 'Exception raised while intercepting the RPC'

    def cancel(self):
        return False

    def cancelled(self):
        return False

    def is_active(self):
        return False

    def time_remaining(self):
        return None

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        raise self._exception

    def exception(self, ignored_timeout=None):
        return self._exception

    def traceback(self, ignored_timeout=None):
        return self._traceback

    def add_callback(self, callback):
        return False

    def add_done_callback(self, fn):
        # The outcome is already done, so the callback runs immediately.
        fn(self)

    def __iter__(self):
        return self

    def __next__(self):
        # Python 3 iterator protocol: iterating a failed streaming outcome
        # re-raises the exception captured during interception.
        raise self._exception

    def next(self):
        # Python 2 iterator protocol.
        return self.__next__()


class _UnaryOutcome(grpc.Call, grpc.Future):
    """An already-completed Future holding a unary response and its Call.

    Call-related methods delegate to the wrapped ``call``; Future methods
    report a successful, finished RPC whose result is ``response``.
    """

    def __init__(self, response, call):
        self._response = response
        self._call = call

    def initial_metadata(self):
        return self._call.initial_metadata()

    def trailing_metadata(self):
        return self._call.trailing_metadata()

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def cancel(self):
        return self._call.cancel()

    def add_callback(self, callback):
        return self._call.add_callback(callback)

    def cancelled(self):
        return False

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        return self._response

    def exception(self, ignored_timeout=None):
        return None

    def traceback(self, ignored_timeout=None):
        return None

    def add_done_callback(self, fn):
        # The outcome is already done, so the callback runs immediately.
        fn(self)


class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Routes unary-unary invocations through ``intercept_unary_unary``.

    ``thunk`` maps a method name to the underlying channel's multi-callable,
    so interceptors may redirect the call to a different method.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        response, ignored_call = self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)
        return response

    def _with_call(self, request, timeout=None, metadata=None,
                   credentials=None):
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials)

        def continuation(new_details, request):
            new_method, new_timeout, new_metadata, new_credentials = (
                _unwrap_client_call_details(new_details, client_call_details))
            try:
                response, call = self._thunk(new_method).with_call(
                    request,
                    timeout=new_timeout,
                    metadata=new_metadata,
                    credentials=new_credentials)
                return _UnaryOutcome(response, call)
            except grpc.RpcError:
                # RPC failures propagate to the interceptor unchanged.
                raise
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        call = self._interceptor.intercept_unary_unary(
            continuation, client_call_details, request)
        # result() re-raises if the interceptor returned a failed outcome.
        return call.result(), call

    def with_call(self, request, timeout=None, metadata=None, credentials=None):
        return self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)

    def future(self, request, timeout=None, metadata=None, credentials=None):
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials)

        def continuation(new_details, request):
            new_method, new_timeout, new_metadata, new_credentials = (
                _unwrap_client_call_details(new_details, client_call_details))
            return self._thunk(new_method).future(
                request,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, client_call_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            # Exceptions from the interceptor surface through the future.
            return _FailureOutcome(exception, sys.exc_info()[2])


class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Routes unary-stream invocations through ``intercept_unary_stream``."""

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials)

        def continuation(new_details, request):
            new_method, new_timeout, new_metadata, new_credentials = (
                _unwrap_client_call_details(new_details, client_call_details))
            return self._thunk(new_method)(
                request,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials)

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, client_call_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            # Returned outcome re-raises when iterated or when result() is
            # called.
            return _FailureOutcome(exception, sys.exc_info()[2])


class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Routes stream-unary invocations through ``intercept_stream_unary``."""

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None):
        response, ignored_call = self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)
        return response

    def _with_call(self,
                   request_iterator,
                   timeout=None,
                   metadata=None,
                   credentials=None):
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials)

        def continuation(new_details, request_iterator):
            new_method, new_timeout, new_metadata, new_credentials = (
                _unwrap_client_call_details(new_details, client_call_details))
            try:
                response, call = self._thunk(new_method).with_call(
                    request_iterator,
                    timeout=new_timeout,
                    metadata=new_metadata,
                    credentials=new_credentials)
                return _UnaryOutcome(response, call)
            except grpc.RpcError:
                # RPC failures propagate to the interceptor unchanged.
                raise
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        call = self._interceptor.intercept_stream_unary(
            continuation, client_call_details, request_iterator)
        # result() re-raises if the interceptor returned a failed outcome.
        return call.result(), call

    def with_call(self,
                  request_iterator,
                  timeout=None,
                  metadata=None,
                  credentials=None):
        return self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)

    def future(self,
               request_iterator,
               timeout=None,
               metadata=None,
               credentials=None):
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials)

        def continuation(new_details, request_iterator):
            new_method, new_timeout, new_metadata, new_credentials = (
                _unwrap_client_call_details(new_details, client_call_details))
            return self._thunk(new_method).future(
                request_iterator,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, client_call_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            # Exceptions from the interceptor surface through the future.
            return _FailureOutcome(exception, sys.exc_info()[2])


class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Routes stream-stream invocations through ``intercept_stream_stream``."""

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None):
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials)

        def continuation(new_details, request_iterator):
            new_method, new_timeout, new_metadata, new_credentials = (
                _unwrap_client_call_details(new_details, client_call_details))
            return self._thunk(new_method)(
                request_iterator,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials)

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, client_call_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            # Returned outcome re-raises when iterated or when result() is
            # called.
            return _FailureOutcome(exception, sys.exc_info()[2])


class _Channel(grpc.Channel):
    """A grpc.Channel that applies one client interceptor to its RPCs.

    Each multi-callable factory wraps the underlying channel's
    multi-callable only if ``interceptor`` implements the matching
    interface; otherwise the underlying multi-callable is returned as is.
    """

    def __init__(self, channel, interceptor):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(self, callback, try_to_connect=False):
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback):
        self._channel.unsubscribe(callback)

    def unary_unary(self,
                    method,
                    request_serializer=None,
                    response_deserializer=None):
        thunk = lambda m: self._channel.unary_unary(
            m, request_serializer, response_deserializer)
        if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
            return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def unary_stream(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):
        thunk = lambda m: self._channel.unary_stream(
            m, request_serializer, response_deserializer)
        if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
            return _UnaryStreamMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def stream_unary(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):
        thunk = lambda m: self._channel.stream_unary(
            m, request_serializer, response_deserializer)
        if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
            return _StreamUnaryMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def stream_stream(self,
                      method,
                      request_serializer=None,
                      response_deserializer=None):
        thunk = lambda m: self._channel.stream_stream(
            m, request_serializer, response_deserializer)
        if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
            return _StreamStreamMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._close()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def close(self):
        self._channel.close()


def intercept_channel(channel, *interceptors):
    """Wraps ``channel`` so that RPCs pass through ``interceptors``.

    Interceptors are applied in reverse, so the first one listed becomes
    the outermost wrapper and is invoked first for each RPC.

    Args:
      channel: The grpc.Channel to wrap.
      *interceptors: Objects each implementing at least one of the client
        interceptor interfaces.

    Returns:
      A grpc.Channel applying the interceptors (``channel`` itself if none
      are given).

    Raises:
      TypeError: If an interceptor implements none of the client
        interceptor interfaces.
    """
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, _CLIENT_INTERCEPTOR_TYPES):
            raise TypeError('interceptor must be '
                            'grpc.UnaryUnaryClientInterceptor or '
                            'grpc.UnaryStreamClientInterceptor or '
                            'grpc.StreamUnaryClientInterceptor or '
                            'grpc.StreamStreamClientInterceptor')
        channel = _Channel(channel, interceptor)
    return channel