1from __future__ import print_function 2 3import base64 4import contextlib 5import copy 6import email.utils 7import functools 8import gzip 9import hashlib 10import httplib2 11import os 12import random 13import re 14import shutil 15import six 16import socket 17import ssl 18import struct 19import sys 20import threading 21import time 22import traceback 23import zlib 24from six.moves import http_client, queue 25 26 27DUMMY_URL = "http://127.0.0.1:1" 28DUMMY_HTTPS_URL = "https://127.0.0.1:2" 29 30tls_dir = os.path.join(os.path.dirname(__file__), "tls") 31CA_CERTS = os.path.join(tls_dir, "ca.pem") 32CA_UNUSED_CERTS = os.path.join(tls_dir, "ca_unused.pem") 33CLIENT_PEM = os.path.join(tls_dir, "client.pem") 34CLIENT_ENCRYPTED_PEM = os.path.join(tls_dir, "client_encrypted.pem") 35SERVER_PEM = os.path.join(tls_dir, "server.pem") 36SERVER_CHAIN = os.path.join(tls_dir, "server_chain.pem") 37 38 39@contextlib.contextmanager 40def assert_raises(exc_type): 41 def _name(t): 42 return getattr(t, "__name__", None) or str(t) 43 44 if not isinstance(exc_type, tuple): 45 exc_type = (exc_type,) 46 names = ", ".join(map(_name, exc_type)) 47 48 try: 49 yield 50 except exc_type: 51 pass 52 else: 53 assert False, "Expected exception(s) {0}".format(names) 54 55 56class BufferedReader(object): 57 """io.BufferedReader with \r\n support 58 """ 59 60 def __init__(self, sock): 61 self._buf = b"" 62 self._end = False 63 self._newline = b"\r\n" 64 self._sock = sock 65 if isinstance(sock, bytes): 66 self._sock = None 67 self._buf = sock 68 69 def _fill(self, target=1, more=None, untilend=False): 70 if more: 71 target = len(self._buf) + more 72 while untilend or (len(self._buf) < target): 73 # crutch to enable HttpRequest.from_bytes 74 if self._sock is None: 75 chunk = b"" 76 else: 77 chunk = self._sock.recv(8 << 10) 78 # print("!!! 
recv", chunk) 79 if not chunk: 80 self._end = True 81 if untilend: 82 return 83 else: 84 raise EOFError 85 self._buf += chunk 86 87 def peek(self, size): 88 self._fill(target=size) 89 return self._buf[:size] 90 91 def read(self, size): 92 self._fill(target=size) 93 chunk, self._buf = self._buf[:size], self._buf[size:] 94 return chunk 95 96 def readall(self): 97 self._fill(untilend=True) 98 chunk, self._buf = self._buf, b"" 99 return chunk 100 101 def readline(self): 102 while True: 103 i = self._buf.find(self._newline) 104 if i >= 0: 105 break 106 self._fill(more=1) 107 inext = i + len(self._newline) 108 line, self._buf = self._buf[:inext], self._buf[inext:] 109 return line 110 111 112def parse_http_message(kind, buf): 113 if buf._end: 114 return None 115 try: 116 start_line = buf.readline() 117 except EOFError: 118 return None 119 msg = kind() 120 msg.raw = start_line 121 if kind is HttpRequest: 122 assert re.match( 123 br".+ HTTP/\d\.\d\r\n$", start_line 124 ), "Start line does not look like HTTP request: " + repr(start_line) 125 msg.method, msg.uri, msg.proto = start_line.rstrip().decode().split(" ", 2) 126 assert msg.proto.startswith("HTTP/"), repr(start_line) 127 elif kind is HttpResponse: 128 assert re.match( 129 br"^HTTP/\d\.\d \d+ .+\r\n$", start_line 130 ), "Start line does not look like HTTP response: " + repr(start_line) 131 msg.proto, msg.status, msg.reason = start_line.rstrip().decode().split(" ", 2) 132 msg.status = int(msg.status) 133 assert msg.proto.startswith("HTTP/"), repr(start_line) 134 else: 135 raise Exception("Use HttpRequest or HttpResponse .from_{bytes,buffered}") 136 msg.version = msg.proto[5:] 137 138 while True: 139 line = buf.readline() 140 msg.raw += line 141 line = line.rstrip() 142 if not line: 143 break 144 t = line.decode().split(":", 1) 145 msg.headers[t[0].lower()] = t[1].lstrip() 146 147 content_length_string = msg.headers.get("content-length", "") 148 if content_length_string.isdigit(): 149 content_length = 
int(content_length_string) 150 msg.body = msg.body_raw = buf.read(content_length) 151 elif msg.headers.get("transfer-encoding") == "chunked": 152 raise NotImplemented 153 elif msg.version == "1.0": 154 msg.body = msg.body_raw = buf.readall() 155 else: 156 msg.body = msg.body_raw = b"" 157 158 msg.raw += msg.body_raw 159 return msg 160 161 162class HttpMessage(object): 163 def __init__(self): 164 self.headers = {} 165 166 @classmethod 167 def from_bytes(cls, bs): 168 buf = BufferedReader(bs) 169 return parse_http_message(cls, buf) 170 171 @classmethod 172 def from_buffered(cls, buf): 173 return parse_http_message(cls, buf) 174 175 def __repr__(self): 176 return "{} {}".format(self.__class__, repr(vars(self))) 177 178 179class HttpRequest(HttpMessage): 180 pass 181 182 183class HttpResponse(HttpMessage): 184 pass 185 186 187class MockResponse(six.BytesIO): 188 def __init__(self, body, **kwargs): 189 six.BytesIO.__init__(self, body) 190 self.headers = kwargs 191 192 def items(self): 193 return self.headers.items() 194 195 def iteritems(self): 196 return six.iteritems(self.headers) 197 198 199class MockHTTPConnection(object): 200 """This class is just a mock of httplib.HTTPConnection used for testing 201 """ 202 203 def __init__( 204 self, 205 host, 206 port=None, 207 key_file=None, 208 cert_file=None, 209 strict=None, 210 timeout=None, 211 proxy_info=None, 212 ): 213 self.host = host 214 self.port = port 215 self.timeout = timeout 216 self.log = "" 217 self.sock = None 218 219 def set_debuglevel(self, level): 220 pass 221 222 def connect(self): 223 "Connect to a host on a given port." 224 pass 225 226 def close(self): 227 pass 228 229 def request(self, method, request_uri, body, headers): 230 pass 231 232 def getresponse(self): 233 return MockResponse(b"the body", status="200") 234 235 236class MockHTTPBadStatusConnection(object): 237 """Mock of httplib.HTTPConnection that raises BadStatusLine. 
238 """ 239 240 num_calls = 0 241 242 def __init__( 243 self, 244 host, 245 port=None, 246 key_file=None, 247 cert_file=None, 248 strict=None, 249 timeout=None, 250 proxy_info=None, 251 ): 252 self.host = host 253 self.port = port 254 self.timeout = timeout 255 self.log = "" 256 self.sock = None 257 MockHTTPBadStatusConnection.num_calls = 0 258 259 def set_debuglevel(self, level): 260 pass 261 262 def connect(self): 263 pass 264 265 def close(self): 266 pass 267 268 def request(self, method, request_uri, body, headers): 269 pass 270 271 def getresponse(self): 272 MockHTTPBadStatusConnection.num_calls += 1 273 raise http_client.BadStatusLine("") 274 275 276@contextlib.contextmanager 277def server_socket(fun, request_count=1, timeout=5, scheme="", tls=None): 278 """Base socket server for tests. 279 Likely you want to use server_request or other higher level helpers. 280 All arguments except fun can be passed to other server_* helpers. 281 282 :param fun: fun(client_sock, tick) called after successful accept(). 283 :param request_count: test succeeds after exactly this number of requests, triggered by tick(request) 284 :param timeout: seconds. 285 :param scheme: affects yielded value 286 "" - build normal http/https URI. 287 string - build normal URI using supplied scheme. 288 None - yield (addr, port) tuple. 289 :param tls: 290 None (default) - plain HTTP. 291 True - HTTPS with reasonable defaults. Likely you want httplib2.Http(ca_certs=tests.CA_CERTS) 292 string - path to custom server cert+key PEM file. 
293 callable - function(context, listener, skip_errors) -> ssl_wrapped_listener 294 """ 295 gresult = [None] 296 gcounter = [0] 297 tls_skip_errors = [ 298 "TLSV1_ALERT_UNKNOWN_CA", 299 ] 300 301 def tick(request): 302 gcounter[0] += 1 303 keep = True 304 keep &= gcounter[0] < request_count 305 if request is not None: 306 keep &= request.headers.get("connection", "").lower() != "close" 307 return keep 308 309 def server_socket_thread(srv): 310 try: 311 while gcounter[0] < request_count: 312 try: 313 client, _ = srv.accept() 314 except ssl.SSLError as e: 315 if e.reason in tls_skip_errors: 316 return 317 raise 318 319 try: 320 client.settimeout(timeout) 321 fun(client, tick) 322 finally: 323 try: 324 client.shutdown(socket.SHUT_RDWR) 325 except (IOError, socket.error): 326 pass 327 # FIXME: client.close() introduces connection reset by peer 328 # at least in other/connection_close test 329 # should not be a problem since socket would close upon garbage collection 330 if gcounter[0] > request_count: 331 gresult[0] = Exception( 332 "Request count expected={0} actual={1}".format( 333 request_count, gcounter[0] 334 ) 335 ) 336 except Exception as e: 337 # traceback.print_exc caused IOError: concurrent operation on sys.stderr.close() under setup.py test 338 print(traceback.format_exc(), file=sys.stderr) 339 gresult[0] = e 340 341 bind_hostname = "localhost" 342 server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 343 server.bind((bind_hostname, 0)) 344 try: 345 server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 346 except socket.error as ex: 347 print("non critical error on SO_REUSEADDR", ex) 348 server.listen(10) 349 server.settimeout(timeout) 350 server_port = server.getsockname()[1] 351 if tls is True: 352 tls = SERVER_CHAIN 353 if tls: 354 context = ssl_context() 355 if callable(tls): 356 context.load_cert_chain(SERVER_CHAIN) 357 server = tls(context, server, tls_skip_errors) 358 else: 359 context.load_cert_chain(tls) 360 server = 
context.wrap_socket(server, server_side=True) 361 if scheme == "": 362 scheme = "https" if tls else "http" 363 364 t = threading.Thread(target=server_socket_thread, args=(server,)) 365 t.daemon = True 366 t.start() 367 if scheme is None: 368 yield (bind_hostname, server_port) 369 else: 370 yield u"{scheme}://{host}:{port}/".format(scheme=scheme, host=bind_hostname, port=server_port) 371 server.close() 372 t.join() 373 if gresult[0] is not None: 374 raise gresult[0] 375 376 377def server_yield(fun, **kwargs): 378 q = queue.Queue(1) 379 g = fun(q.get) 380 381 def server_yield_socket_handler(sock, tick): 382 buf = BufferedReader(sock) 383 i = 0 384 while True: 385 request = HttpRequest.from_buffered(buf) 386 if request is None: 387 break 388 i += 1 389 request.client_sock = sock 390 request.number = i 391 q.put(request) 392 response = six.next(g) 393 sock.sendall(response) 394 request.client_sock = None 395 if not tick(request): 396 break 397 398 return server_socket(server_yield_socket_handler, **kwargs) 399 400 401def server_request(request_handler, **kwargs): 402 def server_request_socket_handler(sock, tick): 403 buf = BufferedReader(sock) 404 i = 0 405 while True: 406 request = HttpRequest.from_buffered(buf) 407 if request is None: 408 break 409 # print("--- debug request\n" + request.raw.decode("ascii", "replace")) 410 i += 1 411 request.client_sock = sock 412 request.number = i 413 response = request_handler(request=request) 414 # print("--- debug response\n" + response.decode("ascii", "replace")) 415 sock.sendall(response) 416 request.client_sock = None 417 if not tick(request): 418 break 419 420 return server_socket(server_request_socket_handler, **kwargs) 421 422 423def server_const_bytes(response_content, **kwargs): 424 return server_request(lambda request: response_content, **kwargs) 425 426 427_http_kwargs = ( 428 "proto", 429 "status", 430 "headers", 431 "body", 432 "add_content_length", 433 "add_date", 434 "add_etag", 435 "undefined_body_length", 436) 
def http_response_bytes(
    proto="HTTP/1.1",
    status="200 OK",
    headers=None,
    body=b"",
    add_content_length=True,
    add_date=False,
    add_etag=False,
    undefined_body_length=False,
    **kwargs
):
    """Build raw HTTP response bytes.

    :param proto: protocol token for the status line.
    :param status: full status ("200 OK") or a bare code (int or digit
        string), in which case the standard reason phrase is appended.
    :param headers: optional mapping of header name to value. It is copied,
        never modified.
    :param body: response body bytes.
    :param add_content_length: add content-length unless already present.
    :param add_date: add date header unless already present.
    :param add_etag: add etag (quoted MD5 of body) unless already present.
    :param undefined_body_length: omit content-length on purpose.
    :param kwargs: accepted and ignored.
    :raises Exception: the client would have no way to find the body length.
    """
    if undefined_body_length:
        add_content_length = False
    # Copy so generated headers (e.g. content-length) never leak into a
    # dict the caller may reuse for a response with a different body.
    headers = dict(headers) if headers else {}
    if add_content_length:
        headers.setdefault("content-length", str(len(body)))
    if add_date:
        headers.setdefault("date", email.utils.formatdate())
    if add_etag:
        headers.setdefault("etag", '"{0}"'.format(hashlib.md5(body).hexdigest()))
    header_string = "".join("{0}: {1}\r\n".format(k, v) for k, v in headers.items())
    if (
        not undefined_body_length
        and proto != "HTTP/1.0"
        and "content-length" not in headers
    ):
        raise Exception(
            "httplib2.tests.http_response_bytes: client could not figure response body length"
        )
    if str(status).isdigit():
        # The reason-phrase table is keyed by int, so convert digit strings.
        status = "{} {}".format(status, http_client.responses[int(status)])
    response = (
        "{proto} {status}\r\n{headers}\r\n".format(
            proto=proto, status=status, headers=header_string
        ).encode()
        + body
    )
    return response


def make_http_reflect(**kwargs):
    """Return a handler that responds with the raw request as the body.

    kwargs are passed to http_response_bytes; "body" is not allowed.
    """
    assert "body" not in kwargs, "make_http_reflect will overwrite response " "body"

    def fun(request):
        kw = copy.deepcopy(kwargs)
        kw["body"] = request.raw
        response = http_response_bytes(**kw)
        return response

    return fun


def server_route(routes, **kwargs):
    """Serve by request URI.

    :param routes: dict of uri -> response bytes or callable(request=...).
        The "" key is the fallback; without it unmatched URIs get a 404.
    """
    response_404 = http_response_bytes(status="404 Not Found")
    response_wildcard = routes.get("")

    def handler(request):
        target = routes.get(request.uri, response_wildcard) or response_404
        if callable(target):
            response = target(request=request)
        else:
            response = target
        return response

    return server_request(handler, **kwargs)


def server_const_http(**kwargs):
    """Serve one constant response built by http_response_bytes.

    kwargs named in _http_kwargs build the response; the rest go to the server.
    """
    response_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in _http_kwargs}
    response = http_response_bytes(**response_kwargs)
    return server_const_bytes(response, **kwargs)


def server_list_http(responses, **kwargs):
    """Serve the given raw responses in order, one per request.

    request_count defaults to len(responses).
    """
    i = iter(responses)

    def handler(request):
        return next(i)

    kwargs.setdefault("request_count", len(responses))
    return server_request(handler, **kwargs)


def server_reflect(**kwargs):
    """Serve responses that echo the raw request as body.

    kwargs named in _http_kwargs shape the response; the rest go to the server.
    """
    response_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in _http_kwargs}
    http_handler = make_http_reflect(**response_kwargs)
    return server_request(http_handler, **kwargs)


def http_parse_auth(s):
    """https://tools.ietf.org/html/rfc7235#section-2.1

    Parse "<scheme> <params>" into a dict of lower-cased parameter name to
    unquoted value. Raises ValueError when `s` contains no space.
    """
    _, rest = s.split(" ", 1)
    result = {}
    while True:
        m = httplib2.WWW_AUTH_RELAXED.search(rest)
        if not m:
            break
        if len(m.groups()) != 3:
            # `rest` would never advance; stop instead of looping forever.
            break
        key, value, rest = m.groups()
        result[key.lower()] = httplib2.UNQUOTE_PAIRS.sub(r"\1", value)
    return result


def store_request_response(out):
    """Decorator factory recording (request, parsed response) pairs into `out`.

    :param out: list to append to, or None to disable recording.
    """

    def wrapper(fun):
        @functools.wraps(fun)
        def wrapped(request, *a, **kw):
            response_bytes = fun(request, *a, **kw)
            if out is not None:
                response = HttpResponse.from_bytes(response_bytes)
                out.append((request, response))
            return response_bytes

        return wrapped

    return wrapper


def http_reflect_with_auth(
    allow_scheme, allow_credentials, out_renew_nonce=None, out_requests=None
):
    """Build a reflect handler protected by HTTP authentication.

    allow_scheme - 'basic', 'digest', etc
    allow_credentials - sequence of ('name', 'password')
    out_renew_nonce - None | [function]
        Way to return nonce renew function to caller.
        Kind of `out` parameter in some programming languages.
        Allows to keep same signature for all handler builder functions.
    out_requests - None | []
        If set to list, every parsed request will be appended here.
    """
    # One-element lists are mutable cells shared with nested functions.
    glastnc = [None]
    gnextnonce = [None]
    gserver_nonce = [gen_digest_nonce(salt=b"n")]
    realm = "httplib2 test"
    server_opaque = gen_digest_nonce(salt=b"o")

    def renew_nonce():
        if gnextnonce[0]:
            assert False, (
                "previous nextnonce was not used, probably bug in " "test code"
            )
        gnextnonce[0] = gen_digest_nonce()
        return gserver_nonce[0], gnextnonce[0]

    if out_renew_nonce:
        out_renew_nonce[0] = renew_nonce

    def deny(**kwargs):
        # Build a 401 response carrying the www-authenticate challenge.
        nonce_stale = kwargs.pop("nonce_stale", False)
        if nonce_stale:
            kwargs.setdefault("body", b"nonce stale")
        if allow_scheme == "basic":
            authenticate = 'basic realm="{realm}"'.format(realm=realm)
        elif allow_scheme == "digest":
            authenticate = (
                'digest realm="{realm}", qop="auth"'
                + ', nonce="{nonce}", opaque="{opaque}"'
                + (", stale=true" if nonce_stale else "")
            ).format(realm=realm, nonce=gserver_nonce[0], opaque=server_opaque)
        else:
            raise Exception("unknown allow_scheme={0}".format(allow_scheme))
        deny_headers = {"www-authenticate": authenticate}
        kwargs.setdefault("status", 401)
        # supplied headers may overwrite generated ones
        deny_headers.update(kwargs.get("headers", {}))
        kwargs["headers"] = deny_headers
        kwargs.setdefault("body", b"HTTP authorization required")
        return http_response_bytes(**kwargs)

    @store_request_response(out_requests)
    def http_reflect_with_auth_handler(request):
        auth_header = request.headers.get("authorization", "")
        if not auth_header:
            return deny()
        if " " not in auth_header:
            return http_response_bytes(
                status=400, body=b"authorization header syntax error"
            )
        scheme, data = auth_header.split(" ", 1)
        scheme = scheme.lower()
        if scheme != allow_scheme:
            return deny(body=b"must use different auth scheme")
        if scheme == "basic":
            decoded = base64.b64decode(data).decode()
            username, password = decoded.split(":", 1)
            if (username, password) in allow_credentials:
                return make_http_reflect()(request)
            else:
                return deny(body=b"supplied credentials are not allowed")
        elif scheme == "digest":
            server_nonce_old = gserver_nonce[0]
            nextnonce = gnextnonce[0]
            if nextnonce:
                # server decided to change nonce, in this case, guided by caller test code
                gserver_nonce[0] = nextnonce
                gnextnonce[0] = None
            server_nonce_current = gserver_nonce[0]
            auth_info = http_parse_auth(data)
            client_cnonce = auth_info.get("cnonce", "")
            client_nc = auth_info.get("nc", "")
            client_nonce = auth_info.get("nonce", "")
            client_opaque = auth_info.get("opaque", "")
            client_qop = auth_info.get("qop", "auth").strip('"')

            # TODO: auth_info.get('algorithm', 'md5')
            hasher = hashlib.md5

            # TODO: client_qop auth-int
            ha2 = hasher(":".join((request.method, request.uri)).encode()).hexdigest()

            if client_nonce != server_nonce_current:
                if client_nonce == server_nonce_old:
                    return deny(nonce_stale=True)
                return deny(body=b"invalid nonce")
            if not client_nc:
                return deny(body=b"auth-info nc missing")
            if client_opaque != server_opaque:
                return deny(
                    body="auth-info opaque mismatch expected={} actual={}".format(
                        server_opaque, client_opaque
                    ).encode()
                )
            for allow_username, allow_password in allow_credentials:
                ha1 = hasher(
                    ":".join((allow_username, realm, allow_password)).encode()
                ).hexdigest()
                allow_response = hasher(
                    ":".join(
                        (ha1, client_nonce, client_nc, client_cnonce, client_qop, ha2)
                    ).encode()
                ).hexdigest()
                if auth_info.get("response", "") == allow_response:
                    # rspauth is only needed for the success response.
                    rspauth_ha2 = hasher(":{}".format(request.uri).encode()).hexdigest()
                    rspauth = hasher(
                        ":".join(
                            (
                                ha1,
                                client_nonce,
                                client_nc,
                                client_cnonce,
                                client_qop,
                                rspauth_ha2,
                            )
                        ).encode()
                    ).hexdigest()
                    # TODO: fix or remove doubtful comment
                    # do we need to save nc only on success?
                    glastnc[0] = client_nc
                    allow_headers = {
                        "authentication-info": " ".join(
                            (
                                'nextnonce="{}"'.format(nextnonce) if nextnonce else "",
                                "qop={}".format(client_qop),
                                'rspauth="{}"'.format(rspauth),
                                'cnonce="{}"'.format(client_cnonce),
                                "nc={}".format(client_nc),
                            )
                        ).strip()
                    }
                    return make_http_reflect(headers=allow_headers)(request)
            return deny(body=b"supplied credentials are not allowed")
        else:
            return http_response_bytes(
                status=400,
                body="unknown authorization scheme={0}".format(scheme).encode(),
            )

    return http_reflect_with_auth_handler


def get_cache_path():
    """Return the test cache directory path, removing it first if it exists.

    Path comes from the httplib2_test_cache_path environment variable, with
    "./_httplib2_test_cache" as fallback.
    """
    default = "./_httplib2_test_cache"
    path = os.environ.get("httplib2_test_cache_path") or default
    if os.path.exists(path):
        shutil.rmtree(path)
    return path


def gen_digest_nonce(salt=b""):
    """Return a base64 text nonce built from the current time and `salt`."""
    t = struct.pack(">Q", int(time.time() * 1e9))
    return base64.b64encode(t + b":" + hashlib.sha1(t + salt).digest()).decode()


def gen_password():
    """Return a random 8..64 character string of code points 0..127."""
    length = random.randint(8, 64)
    return "".join(six.unichr(random.randint(0, 127)) for _ in range(length))


def gzip_compress(bs):
    """Return `bs` compressed as a gzip stream (level 6)."""
    buf = six.BytesIO()
    # `with` closes the GzipFile (writing the trailer) even if write() fails.
    with gzip.GzipFile(fileobj=buf, mode="wb", compresslevel=6) as gf:
        gf.write(bs)
    return buf.getvalue()


def gzip_decompress(bs):
    """Return the payload of a gzip stream."""
    return zlib.decompress(bs, zlib.MAX_WBITS | 16)


def deflate_compress(bs):
    """Return `bs` compressed as a raw deflate stream (no zlib header)."""
    do = zlib.compressobj(9, zlib.DEFLATED, -zlib.MAX_WBITS)
    return do.compress(bs) + do.flush()


def deflate_decompress(bs):
    """Return the payload of a raw deflate stream."""
    return zlib.decompress(bs, -zlib.MAX_WBITS)


def ssl_context(protocol=None):
    """Workaround for old SSLContext() required protocol argument.

    :param protocol: ssl.PROTOCOL_* constant to use; None keeps the
        interpreter default (PROTOCOL_SSLv23 on Python older than 3.5.3).
    """
    if protocol is not None:
        return ssl.SSLContext(protocol)
    if sys.version_info < (3, 5, 3):
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return ssl.SSLContext()