"""Test helpers: HTTP message parsing, mock connections and throw-away
in-process socket servers used to exercise an HTTP client.
"""
from __future__ import print_function

import base64
import contextlib
import copy
import email.utils
import functools
import gzip
import hashlib
import httplib2
import os
import random
import re
import shutil
import six
import socket
import struct
import sys
import threading
import time
import traceback
import zlib
from six.moves import http_client, queue


@contextlib.contextmanager
def assert_raises(exc_type):
    """Context manager asserting that the body raises ``exc_type``.

    ``exc_type`` is an exception class or a tuple of classes.  A matching
    exception is swallowed; if the body finishes without raising,
    ``AssertionError`` is raised.  Other exceptions propagate unchanged.
    """

    def _name(t):
        return getattr(t, "__name__", None) or str(t)

    if not isinstance(exc_type, tuple):
        exc_type = (exc_type,)
    names = ", ".join(map(_name, exc_type))

    try:
        yield
    except exc_type:
        pass
    else:
        # Explicit raise (not `assert False`) so the check survives `python -O`.
        raise AssertionError("Expected exception(s) {0}".format(names))


class BufferedReader(object):
    """Minimal buffered reader over a socket, with CRLF-terminated lines.

    ``sock`` is either an object with ``recv(n)`` or a ``bytes`` value; in
    the latter case the bytes become the whole buffer and no socket is read.
    """

    def __init__(self, sock):
        self._buf = b""
        self._end = False
        self._newline = b"\r\n"
        self._sock = sock
        if isinstance(sock, bytes):
            self._sock = None
            self._buf = sock

    def _fill(self, target=1, more=None, untilend=False):
        """Receive until the buffer holds ``target`` bytes.

        ``more`` (if truthy) requests that many bytes beyond what is already
        buffered.  With ``untilend`` the reader drains the source and returns
        at end of stream; otherwise hitting end of stream raises ``EOFError``.
        """
        if more:
            target = len(self._buf) + more
        while untilend or (len(self._buf) < target):
            # crutch to enable HttpRequest.from_bytes
            if self._sock is None:
                chunk = b""
            else:
                chunk = self._sock.recv(8 << 10)
            if not chunk:
                self._end = True
                if untilend:
                    return
                raise EOFError
            self._buf += chunk

    def peek(self, size):
        """Return the first ``size`` buffered bytes without consuming them."""
        self._fill(target=size)
        return self._buf[:size]

    def read(self, size):
        """Consume and return exactly ``size`` bytes (``EOFError`` if short)."""
        self._fill(target=size)
        chunk, self._buf = self._buf[:size], self._buf[size:]
        return chunk

    def readall(self):
        """Consume and return everything up to end of stream."""
        self._fill(untilend=True)
        chunk, self._buf = self._buf, b""
        return chunk

    def readline(self):
        """Consume and return one line, including its trailing CRLF."""
        while True:
            i = self._buf.find(self._newline)
            if i >= 0:
                break
            self._fill(more=1)
        inext = i + len(self._newline)
        line, self._buf = self._buf[:inext], self._buf[inext:]
        return line


def parse_http_message(kind, buf):
    """Parse one HTTP message of class ``kind`` from BufferedReader ``buf``.

    ``kind`` must be exactly ``HttpRequest`` or ``HttpResponse``.  Returns the
    populated message, or ``None`` when the stream is already at its end or
    ends before a start line could be read.

    Raises ``AssertionError`` on a malformed start line and
    ``NotImplementedError`` for chunked transfer encoding.
    """
    if buf._end:
        return None
    try:
        start_line = buf.readline()
    except EOFError:
        return None
    msg = kind()
    msg.raw = start_line
    if kind is HttpRequest:
        assert re.match(
            br".+ HTTP/\d\.\d\r\n$", start_line
        ), "Start line does not look like HTTP request: " + repr(start_line)
        msg.method, msg.uri, msg.proto = start_line.rstrip().decode().split(" ", 2)
        assert msg.proto.startswith("HTTP/"), repr(start_line)
    elif kind is HttpResponse:
        assert re.match(
            br"^HTTP/\d\.\d \d+ .+\r\n$", start_line
        ), "Start line does not look like HTTP response: " + repr(start_line)
        msg.proto, msg.status, msg.reason = start_line.rstrip().decode().split(" ", 2)
        msg.status = int(msg.status)
        assert msg.proto.startswith("HTTP/"), repr(start_line)
    else:
        raise Exception("Use HttpRequest or HttpResponse .from_{bytes,buffered}")
    msg.version = msg.proto[5:]

    # Header section: lines up to the first empty one; names are lower-cased.
    while True:
        line = buf.readline()
        msg.raw += line
        line = line.rstrip()
        if not line:
            break
        t = line.decode().split(":", 1)
        msg.headers[t[0].lower()] = t[1].lstrip()

    content_length_string = msg.headers.get("content-length", "")
    if content_length_string.isdigit():
        content_length = int(content_length_string)
        msg.body = msg.body_raw = buf.read(content_length)
    elif msg.headers.get("transfer-encoding") == "chunked":
        # `NotImplemented` is a constant, not an exception; raising it would
        # surface as a confusing TypeError.
        raise NotImplementedError
    elif msg.version == "1.0":
        # HTTP/1.0 without content-length: body runs until the peer closes.
        msg.body = msg.body_raw = buf.readall()
    else:
        msg.body = msg.body_raw = b""

    msg.raw += msg.body_raw
    return msg


class HttpMessage(object):
    """Base container for a parsed HTTP message (see parse_http_message)."""

    def __init__(self):
        self.headers = {}

    @classmethod
    def from_bytes(cls, bs):
        """Parse a message of this class from a complete ``bytes`` value."""
        buf = BufferedReader(bs)
        return parse_http_message(cls, buf)

    @classmethod
    def from_buffered(cls, buf):
        """Parse the next message of this class from a BufferedReader."""
        return parse_http_message(cls, buf)

    def __repr__(self):
        return "{} {}".format(self.__class__, repr(vars(self)))


class HttpRequest(HttpMessage):
    pass


class HttpResponse(HttpMessage):
    pass


class MockResponse(six.BytesIO):
    """In-memory response body whose keyword arguments act as headers."""

    def __init__(self, body, **kwargs):
        six.BytesIO.__init__(self, body)
        self.headers = kwargs

    def items(self):
        return self.headers.items()

    def iteritems(self):
        return six.iteritems(self.headers)


class MockHTTPConnection(object):
    """This class is just a mock of httplib.HTTPConnection used for testing
    """

    def __init__(
        self,
        host,
        port=None,
        key_file=None,
        cert_file=None,
        strict=None,
        timeout=None,
        proxy_info=None,
    ):
        self.host = host
        self.port = port
        self.timeout = timeout
        self.log = ""
        self.sock = None

    def set_debuglevel(self, level):
        pass

    def connect(self):
        "Connect to a host on a given port."
        pass

    def close(self):
        pass

    def request(self, method, request_uri, body, headers):
        pass

    def getresponse(self):
        return MockResponse(b"the body", status="200")


class MockHTTPBadStatusConnection(object):
    """Mock of httplib.HTTPConnection that raises BadStatusLine.
    """

    # Class-level counter of getresponse() calls; reset by every __init__.
    num_calls = 0

    def __init__(
        self,
        host,
        port=None,
        key_file=None,
        cert_file=None,
        strict=None,
        timeout=None,
        proxy_info=None,
    ):
        self.host = host
        self.port = port
        self.timeout = timeout
        self.log = ""
        self.sock = None
        MockHTTPBadStatusConnection.num_calls = 0

    def set_debuglevel(self, level):
        pass

    def connect(self):
        pass

    def close(self):
        pass

    def request(self, method, request_uri, body, headers):
        pass

    def getresponse(self):
        MockHTTPBadStatusConnection.num_calls += 1
        raise http_client.BadStatusLine("")


@contextlib.contextmanager
def server_socket(fun, request_count=1, timeout=5):
    """Run a throw-away TCP server on localhost and yield its base URL.

    ``fun(client_socket, tick)`` is called in a background thread for every
    accepted connection.  The handler calls ``tick(request)`` after each
    request; it counts the request and returns whether the connection should
    be kept open.  Accepting stops once ``request_count`` requests were seen.

    On exit the listening socket is closed and the thread joined; any
    exception recorded by the server thread is re-raised.
    """
    gresult = [None]
    gcounter = [0]

    def tick(request):
        gcounter[0] += 1
        keep = True
        keep &= gcounter[0] < request_count
        keep &= request.headers.get("connection", "").lower() != "close"
        return keep

    def server_socket_thread(srv):
        try:
            while gcounter[0] < request_count:
                client, _ = srv.accept()
                try:
                    client.settimeout(timeout)
                    fun(client, tick)
                finally:
                    try:
                        client.shutdown(socket.SHUT_RDWR)
                    except (IOError, socket.error):
                        pass
                    # FIXME: client.close() introduces connection reset by peer
                    # at least in other/connection_close test
                    # should not be a problem since socket would close upon garbage collection
            if gcounter[0] > request_count:
                gresult[0] = Exception(
                    "Request count expected={0} actual={1}".format(
                        request_count, gcounter[0]
                    )
                )
        except Exception as e:
            # traceback.print_exc caused IOError: concurrent operation on sys.stderr.close() under setup.py test
            # format_exc() already returns the native str type; encoding it
            # would make the write itself fail on Python 3.
            sys.stderr.write(traceback.format_exc())
            gresult[0] = e

    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # SO_REUSEADDR only has an effect when set before bind().
    try:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    except socket.error as ex:
        print("non critical error on SO_REUSEADDR", ex)
    server.bind(("localhost", 0))
    server.listen(10)
    server.settimeout(timeout)
    t = threading.Thread(target=server_socket_thread, args=(server,))
    t.daemon = True
    t.start()
    try:
        yield u"http://{0}:{1}/".format(*server.getsockname())
    finally:
        # Release the listening socket even when the with-body raised.
        server.close()
    t.join()
    if gresult[0] is not None:
        raise gresult[0]


def _serve_connection(sock, tick, respond):
    """Read requests from ``sock`` and answer each with ``respond(request)``.

    Stops at end of stream or as soon as ``tick(request)`` returns false.
    Each request gets a 1-based ``number`` within its connection.
    """
    buf = BufferedReader(sock)
    i = 0
    while True:
        request = HttpRequest.from_buffered(buf)
        if request is None:
            break
        i += 1
        # NOTE(review): getsockname() is the local address of the accepted
        # socket, not the peer's -- confirm whether getpeername() was intended.
        request.client_addr = sock.getsockname()
        request.number = i
        sock.sendall(respond(request))
        if not tick(request):
            break


def server_yield(fun, **kwargs):
    """Serve responses produced by a generator.

    ``fun(get_request)`` must return a generator; each parsed request is put
    on a queue (readable through ``get_request``) and the generator's next
    value is sent as the raw response.  ``kwargs`` go to server_socket.
    """
    q = queue.Queue(1)
    g = fun(q.get)

    def respond(request):
        q.put(request)
        return six.next(g)

    def server_yield_socket_handler(sock, tick):
        _serve_connection(sock, tick, respond)

    return server_socket(server_yield_socket_handler, **kwargs)


def server_request(request_handler, **kwargs):
    """Serve raw response bytes returned by ``request_handler(request=...)``.

    ``kwargs`` are passed through to server_socket.
    """

    def respond(request):
        return request_handler(request=request)

    def server_request_socket_handler(sock, tick):
        _serve_connection(sock, tick, respond)

    return server_socket(server_request_socket_handler, **kwargs)


def server_const_bytes(response_content, **kwargs):
    """Serve the same raw ``response_content`` bytes for every request."""
    return server_request(lambda request: response_content, **kwargs)


# Keyword arguments that belong to http_response_bytes (as opposed to the
# server_* functions) when both are passed through one **kwargs.
_http_kwargs = (
    "proto",
    "status",
    "headers",
    "body",
    "add_content_length",
    "add_date",
    "add_etag",
    "undefined_body_length",
)


def http_response_bytes(
    proto="HTTP/1.1",
    status="200 OK",
    headers=None,
    body=b"",
    add_content_length=True,
    add_date=False,
    add_etag=False,
    undefined_body_length=False,
    **kwargs
):
    """Build a raw HTTP response as bytes.

    ``status`` is a full status line fragment ("200 OK") or a bare code
    (int or digit string), in which case the standard reason phrase is
    appended.  ``headers`` is not modified; generated headers
    (content-length, date, etag) never override supplied ones.
    ``undefined_body_length`` suppresses content-length.  Extra keyword
    arguments are accepted and ignored.

    Raises ``Exception`` when the client would have no way to determine the
    body length (no content-length, not HTTP/1.0, length not declared
    undefined).
    """
    if undefined_body_length:
        add_content_length = False
    # Work on a copy so the caller's dict is never mutated.
    headers = {} if headers is None else dict(headers)
    if add_content_length:
        headers.setdefault("content-length", str(len(body)))
    if add_date:
        headers.setdefault("date", email.utils.formatdate())
    if add_etag:
        headers.setdefault("etag", '"{0}"'.format(hashlib.md5(body).hexdigest()))
    header_string = "".join("{0}: {1}\r\n".format(k, v) for k, v in headers.items())
    if (
        not undefined_body_length
        and proto != "HTTP/1.0"
        and "content-length" not in headers
    ):
        raise Exception(
            "httplib2.tests.http_response_bytes: client could not figure response body length"
        )
    if str(status).isdigit():
        # http_client.responses is keyed by int, so "200" must be converted.
        status = "{} {}".format(status, http_client.responses[int(status)])
    response = (
        "{proto} {status}\r\n{headers}\r\n".format(
            proto=proto, status=status, headers=header_string
        ).encode()
        + body
    )
    return response


def make_http_reflect(**kwargs):
    """Return a handler that echoes the raw request as the response body.

    ``kwargs`` are http_response_bytes arguments; ``body`` is not allowed.
    """
    assert "body" not in kwargs, "make_http_reflect will overwrite response " "body"

    def fun(request):
        kw = copy.deepcopy(kwargs)
        kw["body"] = request.raw
        response = http_response_bytes(**kw)
        return response

    return fun


def server_route(routes, **kwargs):
    """Serve by request URI.

    ``routes`` maps URI to raw response bytes or to a callable taking
    ``request=``.  The "" key is the wildcard; unmatched requests without a
    wildcard get a 404.
    """
    response_404 = http_response_bytes(status="404 Not Found")
    response_wildcard = routes.get("")

    def handler(request):
        target = routes.get(request.uri, response_wildcard) or response_404
        if callable(target):
            response = target(request=request)
        else:
            response = target
        return response

    return server_request(handler, **kwargs)


def server_const_http(**kwargs):
    """Serve one fixed response built by http_response_bytes.

    Response-related keyword arguments (see ``_http_kwargs``) are split off;
    the remainder goes to the server.
    """
    response_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in _http_kwargs}
    response = http_response_bytes(**response_kwargs)
    return server_const_bytes(response, **kwargs)


def server_list_http(responses, **kwargs):
    """Serve ``responses`` in order, one per request.

    ``request_count`` defaults to ``len(responses)``.
    """
    i = iter(responses)

    def handler(request):
        return next(i)

    kwargs.setdefault("request_count", len(responses))
    return server_request(handler, **kwargs)


def server_reflect(**kwargs):
    """Serve responses that echo each raw request back as the body."""
    response_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in _http_kwargs}
    http_handler = make_http_reflect(**response_kwargs)
    return server_request(http_handler, **kwargs)


def http_parse_auth(s):
    """https://tools.ietf.org/html/rfc7235#section-2.1

    Returns a dict of auth parameters with lower-cased keys.
    """
    # NOTE(review): everything before the first space is discarded as the
    # scheme; callers in this file pass the value with the scheme already
    # removed, so the first parameter is dropped -- confirm this is intended.
    _scheme, rest = s.split(" ", 1)
    result = {}
    while True:
        m = httplib2.WWW_AUTH_RELAXED.search(rest)
        if not m:
            break
        groups = m.groups()
        if len(groups) != 3:
            # Without a new `rest` the loop could never make progress.
            break
        key, value, rest = groups
        result[key.lower()] = httplib2.UNQUOTE_PAIRS.sub(r"\1", value)
    return result


def store_request_response(out):
    """Decorator factory recording (request, parsed response) pairs.

    When ``out`` is not None every call of the wrapped handler appends a
    ``(request, HttpResponse)`` tuple to it.  The handler's raw response
    bytes are returned unchanged.
    """

    def wrapper(fun):
        @functools.wraps(fun)
        def wrapped(request, *a, **kw):
            response_bytes = fun(request, *a, **kw)
            if out is not None:
                response = HttpResponse.from_bytes(response_bytes)
                out.append((request, response))
            return response_bytes

        return wrapped

    return wrapper


def http_reflect_with_auth(
    allow_scheme, allow_credentials, out_renew_nonce=None, out_requests=None
):
    """Build a reflecting request handler protected by HTTP authentication.

    allow_scheme - 'basic', 'digest', etc
    allow_credentials - sequence of ('name', 'password')
    out_renew_nonce - None | [function]
        Way to return nonce renew function to caller.
        Kind of `out` parameter in some programming languages.
        Allows to keep same signature for all handler builder functions.
    out_requests - None | []
        If set to list, every parsed request will be appended here.
    """
    glastnc = [None]
    gnextnonce = [None]
    gserver_nonce = [gen_digest_nonce(salt=b"n")]
    realm = "httplib2 test"
    server_opaque = gen_digest_nonce(salt=b"o")

    def renew_nonce():
        if gnextnonce[0]:
            raise AssertionError(
                "previous nextnonce was not used, probably bug in " "test code"
            )
        gnextnonce[0] = gen_digest_nonce()
        return gserver_nonce[0], gnextnonce[0]

    if out_renew_nonce:
        out_renew_nonce[0] = renew_nonce

    def deny(**kwargs):
        """Build a 401 response carrying the www-authenticate challenge."""
        nonce_stale = kwargs.pop("nonce_stale", False)
        if nonce_stale:
            kwargs.setdefault("body", b"nonce stale")
        if allow_scheme == "basic":
            authenticate = 'basic realm="{realm}"'.format(realm=realm)
        elif allow_scheme == "digest":
            authenticate = (
                'digest realm="{realm}", qop="auth"'
                + ', nonce="{nonce}", opaque="{opaque}"'
                + (", stale=true" if nonce_stale else "")
            ).format(realm=realm, nonce=gserver_nonce[0], opaque=server_opaque)
        else:
            raise Exception("unknown allow_scheme={0}".format(allow_scheme))
        deny_headers = {"www-authenticate": authenticate}
        kwargs.setdefault("status", 401)
        # supplied headers may overwrite generated ones
        deny_headers.update(kwargs.get("headers", {}))
        kwargs["headers"] = deny_headers
        kwargs.setdefault("body", b"HTTP authorization required")
        return http_response_bytes(**kwargs)

    @store_request_response(out_requests)
    def http_reflect_with_auth_handler(request):
        auth_header = request.headers.get("authorization", "")
        if not auth_header:
            return deny()
        if " " not in auth_header:
            return http_response_bytes(
                status=400, body=b"authorization header syntax error"
            )
        scheme, data = auth_header.split(" ", 1)
        scheme = scheme.lower()
        if scheme != allow_scheme:
            return deny(body=b"must use different auth scheme")
        if scheme == "basic":
            decoded = base64.b64decode(data).decode()
            username, password = decoded.split(":", 1)
            if (username, password) in allow_credentials:
                return make_http_reflect()(request)
            else:
                return deny(body=b"supplied credentials are not allowed")
        elif scheme == "digest":
            server_nonce_old = gserver_nonce[0]
            nextnonce = gnextnonce[0]
            if nextnonce:
                # server decided to change nonce, in this case, guided by caller test code
                gserver_nonce[0] = nextnonce
                gnextnonce[0] = None
            server_nonce_current = gserver_nonce[0]
            auth_info = http_parse_auth(data)
            client_cnonce = auth_info.get("cnonce", "")
            client_nc = auth_info.get("nc", "")
            client_nonce = auth_info.get("nonce", "")
            client_opaque = auth_info.get("opaque", "")
            client_qop = auth_info.get("qop", "auth").strip('"')

            # TODO: auth_info.get('algorithm', 'md5')
            hasher = hashlib.md5

            # TODO: client_qop auth-int
            ha2 = hasher(":".join((request.method, request.uri)).encode()).hexdigest()

            if client_nonce != server_nonce_current:
                if client_nonce == server_nonce_old:
                    return deny(nonce_stale=True)
                return deny(body=b"invalid nonce")
            if not client_nc:
                return deny(body=b"auth-info nc missing")
            if client_opaque != server_opaque:
                return deny(
                    body="auth-info opaque mismatch expected={} actual={}".format(
                        server_opaque, client_opaque
                    ).encode()
                )
            for allow_username, allow_password in allow_credentials:
                ha1 = hasher(
                    ":".join((allow_username, realm, allow_password)).encode()
                ).hexdigest()
                allow_response = hasher(
                    ":".join(
                        (ha1, client_nonce, client_nc, client_cnonce, client_qop, ha2)
                    ).encode()
                ).hexdigest()
                if auth_info.get("response", "") != allow_response:
                    continue
                # rspauth is only needed once the credentials matched.
                rspauth_ha2 = hasher(":{}".format(request.uri).encode()).hexdigest()
                rspauth = hasher(
                    ":".join(
                        (
                            ha1,
                            client_nonce,
                            client_nc,
                            client_cnonce,
                            client_qop,
                            rspauth_ha2,
                        )
                    ).encode()
                ).hexdigest()
                # TODO: fix or remove doubtful comment
                # do we need to save nc only on success?
                glastnc[0] = client_nc
                allow_headers = {
                    "authentication-info": " ".join(
                        (
                            'nextnonce="{}"'.format(nextnonce) if nextnonce else "",
                            "qop={}".format(client_qop),
                            'rspauth="{}"'.format(rspauth),
                            'cnonce="{}"'.format(client_cnonce),
                            "nc={}".format(client_nc),
                        )
                    ).strip()
                }
                return make_http_reflect(headers=allow_headers)(request)
            return deny(body=b"supplied credentials are not allowed")
        else:
            return http_response_bytes(
                status=400,
                body="unknown authorization scheme={0}".format(scheme).encode(),
            )

    return http_reflect_with_auth_handler


def get_cache_path():
    """Return the test cache directory path, deleting it first if it exists.

    The path comes from the ``httplib2_test_cache_path`` environment
    variable, falling back to ``./_httplib2_test_cache``.
    """
    default = "./_httplib2_test_cache"
    path = os.environ.get("httplib2_test_cache_path") or default
    if os.path.exists(path):
        shutil.rmtree(path)
    return path


def gen_digest_nonce(salt=b""):
    """Return a base64 text nonce derived from the current time and ``salt``."""
    t = struct.pack(">Q", int(time.time() * 1e9))
    return base64.b64encode(t + b":" + hashlib.sha1(t + salt).digest()).decode()


def gen_password():
    """Return a random string of 8..64 characters with code points 0..127."""
    length = random.randint(8, 64)
    return "".join(six.unichr(random.randint(0, 127)) for _ in range(length))


def gzip_compress(bs):
    """Return ``bs`` compressed in gzip format (compression level 6)."""
    buf = six.BytesIO()
    with gzip.GzipFile(fileobj=buf, mode="wb", compresslevel=6) as gf:
        gf.write(bs)
    return buf.getvalue()


def gzip_decompress(bs):
    """Inverse of gzip_compress."""
    return zlib.decompress(bs, zlib.MAX_WBITS | 16)


def deflate_compress(bs):
    """Return ``bs`` as a raw deflate stream (no zlib header or checksum)."""
    do = zlib.compressobj(9, zlib.DEFLATED, -zlib.MAX_WBITS)
    return do.compress(bs) + do.flush()


def deflate_decompress(bs):
    """Inverse of deflate_compress."""
    return zlib.decompress(bs, -zlib.MAX_WBITS)