# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of gRPC Python interceptors."""

import collections
import sys

import grpc


class _ServicePipeline(object):
    """Chains server interceptors in front of a terminal ``thunk``.

    Interceptors run in the order they were supplied. Each one receives a
    continuation that invokes the next interceptor; the continuation handed
    to the last interceptor calls ``thunk`` itself.
    """

    def __init__(self, interceptors):
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk, index):
        # Each call creates a fresh scope, so the returned continuation
        # keeps its own ``index`` (no late-binding surprises).
        return lambda context: self._intercept_at(thunk, index, context)

    def _intercept_at(self, thunk, index, context):
        if index < len(self.interceptors):
            interceptor = self.interceptors[index]
            thunk = self._continuation(thunk, index + 1)
            return interceptor.intercept_service(thunk, context)
        else:
            return thunk(context)

    def execute(self, thunk, context):
        """Passes ``context`` through every interceptor, then to ``thunk``.

        Returns whatever the first interceptor's ``intercept_service``
        returns, or ``thunk(context)`` when there are no interceptors.
        """
        return self._intercept_at(thunk, 0, context)


def service_pipeline(interceptors):
    """Returns a _ServicePipeline, or None when ``interceptors`` is empty."""
    return _ServicePipeline(interceptors) if interceptors else None


class _ClientCallDetails(
        collections.namedtuple('_ClientCallDetails',
                               ('method', 'timeout', 'metadata', 'credentials',
                                'wait_for_ready', 'compression')),
        grpc.ClientCallDetails):
    pass


# Sentinel distinguishing "attribute missing" from "attribute is None".
_MISSING = object()


def _unwrap_client_call_details(call_details, default_details):
    """Extracts call parameters, falling back to defaults for absent ones.

    Interceptors may hand any object to a continuation as the new call
    details; every attribute it lacks (i.e. whose lookup raises
    AttributeError) is taken from ``default_details`` instead.

    Args:
      call_details: The (possibly partial) details supplied by an interceptor.
      default_details: The _ClientCallDetails the call was started with.

    Returns:
      A (method, timeout, metadata, credentials, wait_for_ready, compression)
      tuple, in _ClientCallDetails field order.
    """
    values = []
    for field in _ClientCallDetails._fields:
        value = getattr(call_details, field, _MISSING)
        if value is _MISSING:
            value = getattr(default_details, field)
        values.append(value)
    return tuple(values)


class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call):  # pylint: disable=too-many-ancestors
    """Completed, failed outcome for an exception raised while intercepting.

    Reports StatusCode.INTERNAL; ``result()`` and iteration re-raise the
    captured exception, and ``exception()``/``traceback()`` expose it.
    """

    def __init__(self, exception, traceback):
        super(_FailureOutcome, self).__init__()
        self._exception = exception
        self._traceback = traceback

    def initial_metadata(self):
        return None

    def trailing_metadata(self):
        return None

    def code(self):
        return grpc.StatusCode.INTERNAL

    def details(self):
        return 'Exception raised while intercepting the RPC'

    def cancel(self):
        return False

    def cancelled(self):
        return False

    def is_active(self):
        return False

    def time_remaining(self):
        return None

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        raise self._exception

    def exception(self, ignored_timeout=None):
        return self._exception

    def traceback(self, ignored_timeout=None):
        return self._traceback

    def add_callback(self, unused_callback):
        return False

    def add_done_callback(self, fn):
        # Already done, so the callback runs immediately.
        fn(self)

    def __iter__(self):
        return self

    def __next__(self):
        raise self._exception

    def next(self):
        # Python 2 spelling of the iterator protocol.
        return self.__next__()


class _UnaryOutcome(grpc.Call, grpc.Future):
    """Completed, successful Future holding a response and its Call.

    Call-related methods delegate to the wrapped ``call``; Future methods
    report an already-finished success.
    """

    def __init__(self, response, call):
        self._response = response
        self._call = call

    def initial_metadata(self):
        return self._call.initial_metadata()

    def trailing_metadata(self):
        return self._call.trailing_metadata()

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def cancel(self):
        return self._call.cancel()

    def add_callback(self, callback):
        return self._call.add_callback(callback)

    def cancelled(self):
        return False

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        return self._response

    def exception(self, ignored_timeout=None):
        return None

    def traceback(self, ignored_timeout=None):
        return None

    def add_done_callback(self, fn):
        # Already done, so the callback runs immediately.
        fn(self)


class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multicallable that routes calls through an interceptor.

    ``thunk`` maps a method name to the underlying multicallable, so an
    interceptor may redirect the call to a different method.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        response, _ = self._with_call(request,
                                      timeout=timeout,
                                      metadata=metadata,
                                      credentials=credentials,
                                      wait_for_ready=wait_for_ready,
                                      compression=compression)
        return response

    def _with_call(self,
                   request,
                   timeout=None,
                   metadata=None,
                   credentials=None,
                   wait_for_ready=None,
                   compression=None):
        """Runs the interceptor; returns ``(response, outcome)``.

        Exceptions raised by the interceptor itself propagate unchanged.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            try:
                response, call = self._thunk(new_method).with_call(
                    request,
                    timeout=new_timeout,
                    metadata=new_metadata,
                    credentials=new_credentials,
                    wait_for_ready=new_wait_for_ready,
                    compression=new_compression)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Handed back to the interceptor as the outcome. NOTE(review):
                # relies on the RpcError also offering result() so that
                # call.result() below re-raises it -- confirm.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        call = self._interceptor.intercept_unary_unary(continuation,
                                                       client_call_details,
                                                       request)
        return call.result(), call

    def with_call(self,
                  request,
                  timeout=None,
                  metadata=None,
                  credentials=None,
                  wait_for_ready=None,
                  compression=None):
        return self._with_call(request,
                               timeout=timeout,
                               metadata=metadata,
                               credentials=credentials,
                               wait_for_ready=wait_for_ready,
                               compression=compression)

    def future(self,
               request,
               timeout=None,
               metadata=None,
               credentials=None,
               wait_for_ready=None,
               compression=None):
        """Asynchronous variant; interceptor exceptions become a Future."""
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            return self._thunk(new_method).future(
                request,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials,
                wait_for_ready=new_wait_for_ready,
                compression=new_compression)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, client_call_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multicallable that routes calls through an interceptor.

    Exceptions raised by the interceptor are returned as a _FailureOutcome,
    whose iteration re-raises them.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            return self._thunk(new_method)(request,
                                           timeout=new_timeout,
                                           metadata=new_metadata,
                                           credentials=new_credentials,
                                           wait_for_ready=new_wait_for_ready,
                                           compression=new_compression)

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, client_call_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multicallable that routes calls through an interceptor."""

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        response, _ = self._with_call(request_iterator,
                                      timeout=timeout,
                                      metadata=metadata,
                                      credentials=credentials,
                                      wait_for_ready=wait_for_ready,
                                      compression=compression)
        return response

    def _with_call(self,
                   request_iterator,
                   timeout=None,
                   metadata=None,
                   credentials=None,
                   wait_for_ready=None,
                   compression=None):
        """Runs the interceptor; returns ``(response, outcome)``.

        Exceptions raised by the interceptor itself propagate unchanged.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            try:
                response, call = self._thunk(new_method).with_call(
                    request_iterator,
                    timeout=new_timeout,
                    metadata=new_metadata,
                    credentials=new_credentials,
                    wait_for_ready=new_wait_for_ready,
                    compression=new_compression)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Handed back to the interceptor as the outcome. NOTE(review):
                # relies on the RpcError also offering result() so that
                # call.result() below re-raises it -- confirm.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        call = self._interceptor.intercept_stream_unary(continuation,
                                                        client_call_details,
                                                        request_iterator)
        return call.result(), call

    def with_call(self,
                  request_iterator,
                  timeout=None,
                  metadata=None,
                  credentials=None,
                  wait_for_ready=None,
                  compression=None):
        return self._with_call(request_iterator,
                               timeout=timeout,
                               metadata=metadata,
                               credentials=credentials,
                               wait_for_ready=wait_for_ready,
                               compression=compression)

    def future(self,
               request_iterator,
               timeout=None,
               metadata=None,
               credentials=None,
               wait_for_ready=None,
               compression=None):
        """Asynchronous variant; interceptor exceptions become a Future."""
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            return self._thunk(new_method).future(
                request_iterator,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials,
                wait_for_ready=new_wait_for_ready,
                compression=new_compression)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, client_call_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multicallable that routes calls through an interceptor.

    Exceptions raised by the interceptor are returned as a _FailureOutcome,
    whose iteration re-raises them.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            return self._thunk(new_method)(request_iterator,
                                           timeout=new_timeout,
                                           metadata=new_metadata,
                                           credentials=new_credentials,
                                           wait_for_ready=new_wait_for_ready,
                                           compression=new_compression)

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, client_call_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _Channel(grpc.Channel):
    """grpc.Channel wrapper that applies one client interceptor.

    Only the multicallable kinds whose interceptor interface ``interceptor``
    implements are intercepted; the others come straight from the wrapped
    channel.
    """

    def __init__(self, channel, interceptor):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(self, callback, try_to_connect=False):
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback):
        self._channel.unsubscribe(callback)

    def unary_unary(self,
                    method,
                    request_serializer=None,
                    response_deserializer=None):
        thunk = lambda m: self._channel.unary_unary(m, request_serializer,
                                                    response_deserializer)
        if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
            return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def unary_stream(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):
        thunk = lambda m: self._channel.unary_stream(m, request_serializer,
                                                     response_deserializer)
        if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
            return _UnaryStreamMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def stream_unary(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):
        thunk = lambda m: self._channel.stream_unary(m, request_serializer,
                                                     response_deserializer)
        if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
            return _StreamUnaryMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def stream_stream(self,
                      method,
                      request_serializer=None,
                      response_deserializer=None):
        thunk = lambda m: self._channel.stream_stream(m, request_serializer,
                                                      response_deserializer)
        if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
            return _StreamStreamMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._close()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def close(self):
        self._channel.close()


_CLIENT_INTERCEPTOR_TYPES = (
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
)


def intercept_channel(channel, *interceptors):
    """Wraps ``channel`` so every RPC passes through ``interceptors``.

    Interceptors are wrapped in reverse order, so the first one given is the
    outermost wrapper and therefore intercepts first.

    Args:
      channel: The grpc.Channel to wrap.
      *interceptors: Objects implementing at least one of the four client
        interceptor interfaces.

    Returns:
      A grpc.Channel (``channel`` itself when no interceptors are given).

    Raises:
      TypeError: If an interceptor implements none of the client
        interceptor interfaces.
    """
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, _CLIENT_INTERCEPTOR_TYPES):
            raise TypeError('interceptor must be '
                            'grpc.UnaryUnaryClientInterceptor or '
                            'grpc.UnaryStreamClientInterceptor or '
                            'grpc.StreamUnaryClientInterceptor or '
                            'grpc.StreamStreamClientInterceptor')
        channel = _Channel(channel, interceptor)
    return channel