• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from __future__ import print_function
2
3import base64
4import contextlib
5import copy
6import email.utils
7import functools
8import gzip
9import hashlib
10import httplib2
11import os
12import random
13import re
14import shutil
15import six
16import socket
17import struct
18import sys
19import threading
20import time
21import traceback
22import zlib
23from six.moves import http_client, queue
24
25
@contextlib.contextmanager
def assert_raises(exc_type):
    """Context manager asserting that its body raises one of ``exc_type``.

    Args:
        exc_type: an exception class, or a tuple of exception classes.

    Raises:
        AssertionError: if the body finishes without raising.
        Any exception not matching ``exc_type`` propagates unchanged.
    """

    def _name(t):
        return getattr(t, "__name__", None) or str(t)

    if not isinstance(exc_type, tuple):
        exc_type = (exc_type,)
    names = ", ".join(map(_name, exc_type))

    try:
        yield
    except exc_type:
        pass
    else:
        # Raise explicitly instead of ``assert False`` so the check still
        # fires when Python runs with -O (which strips assert statements).
        raise AssertionError("Expected exception(s) {0}".format(names))
41
42
class BufferedReader(object):
    """Minimal buffered reader whose lines end with ``\\r\\n`` (HTTP style).

    ``sock`` is normally a socket-like object with ``recv``.  Passing
    ``bytes`` instead serves that payload and then behaves as if the peer
    had closed the connection.
    """

    def __init__(self, sock):
        self._newline = b"\r\n"
        self._end = False
        if isinstance(sock, bytes):
            # Static payload: nothing to receive from.
            self._sock, self._buf = None, sock
        else:
            self._sock, self._buf = sock, b""

    def _fill(self, target=1, more=None, untilend=False):
        """Receive data until at least ``target`` bytes are buffered.

        ``more`` replaces ``target`` with "current buffer length + more".
        With ``untilend`` it keeps reading until EOF and then returns;
        otherwise reaching EOF first raises EOFError.  ``self._end`` is set
        whenever EOF is observed.
        """
        if more:
            target = len(self._buf) + more
        while untilend or len(self._buf) < target:
            # A bytes-backed reader (HttpRequest.from_bytes) has no socket;
            # it behaves like a connection that is already closed.
            chunk = b"" if self._sock is None else self._sock.recv(8 << 10)
            if chunk:
                self._buf += chunk
                continue
            self._end = True
            if not untilend:
                raise EOFError
            return

    def peek(self, size):
        """Return up to ``size`` buffered bytes without consuming them."""
        self._fill(target=size)
        return self._buf[:size]

    def read(self, size):
        """Consume and return ``size`` bytes (EOFError if EOF comes first)."""
        self._fill(target=size)
        head = self._buf[:size]
        self._buf = self._buf[size:]
        return head

    def readall(self):
        """Consume and return everything until the peer closes."""
        self._fill(untilend=True)
        data = self._buf
        self._buf = b""
        return data

    def readline(self):
        """Consume and return one line including its ``\\r\\n`` terminator."""
        sep = self._newline
        pos = self._buf.find(sep)
        while pos < 0:
            self._fill(more=1)
            pos = self._buf.find(sep)
        cut = pos + len(sep)
        line = self._buf[:cut]
        self._buf = self._buf[cut:]
        return line
97
98
def parse_http_message(kind, buf):
    """Parse one HTTP/1.x message of type ``kind`` from ``buf``.

    Args:
        kind: ``HttpRequest`` or ``HttpResponse``. It is instantiated and
            selects how the start line is interpreted.
        buf: a ``BufferedReader`` positioned at the start of a message.

    Returns:
        A ``kind`` instance with these attributes filled in, or ``None`` if
        the reader is already at EOF or the connection closes before a start
        line arrives:

        - ``raw``: all bytes consumed.
        - ``headers``: names lower-cased.
        - ``body``/``body_raw``.
        - ``version``.
        - the start-line fields.

    Raises:
        AssertionError: if the start line is malformed.
        EOFError: if the connection closes mid-headers or mid-body.
        NotImplementedError: for ``Transfer-Encoding: chunked`` bodies.
        Exception: if ``kind`` is neither HttpRequest nor HttpResponse.
    """
    if buf._end:
        return None
    try:
        start_line = buf.readline()
    except EOFError:
        # Peer closed the connection before sending another message.
        return None
    msg = kind()
    msg.raw = start_line
    if kind is HttpRequest:
        assert re.match(
            br".+ HTTP/\d\.\d\r\n$", start_line
        ), "Start line does not look like HTTP request: " + repr(start_line)
        msg.method, msg.uri, msg.proto = start_line.rstrip().decode().split(" ", 2)
        assert msg.proto.startswith("HTTP/"), repr(start_line)
    elif kind is HttpResponse:
        assert re.match(
            br"^HTTP/\d\.\d \d+ .+\r\n$", start_line
        ), "Start line does not look like HTTP response: " + repr(start_line)
        msg.proto, msg.status, msg.reason = start_line.rstrip().decode().split(" ", 2)
        msg.status = int(msg.status)
        assert msg.proto.startswith("HTTP/"), repr(start_line)
    else:
        raise Exception("Use HttpRequest or HttpResponse .from_{bytes,buffered}")
    # "HTTP/1.1" -> "1.1"
    msg.version = msg.proto[5:]

    # Header lines up to the first empty line; names are stored lower-cased.
    while True:
        line = buf.readline()
        msg.raw += line
        line = line.rstrip()
        if not line:
            break
        t = line.decode().split(":", 1)
        msg.headers[t[0].lower()] = t[1].lstrip()

    # Body framing: Content-Length, else chunked (unsupported), else
    # read-until-close for HTTP/1.0, else no body.
    content_length_string = msg.headers.get("content-length", "")
    if content_length_string.isdigit():
        content_length = int(content_length_string)
        msg.body = msg.body_raw = buf.read(content_length)
    elif msg.headers.get("transfer-encoding", "").lower() == "chunked":
        # NotImplemented is a constant, not an exception class; raising it
        # produced a confusing TypeError instead of this error.
        raise NotImplementedError("chunked transfer-encoding is not supported")
    elif msg.version == "1.0":
        msg.body = msg.body_raw = buf.readall()
    else:
        msg.body = msg.body_raw = b""

    msg.raw += msg.body_raw
    return msg
147
148
class HttpMessage(object):
    """Common base for parsed HTTP requests and responses.

    A fresh instance only has an empty ``headers`` dict; the parser sets the
    remaining attributes (``raw``, ``body``, start-line fields, ...).
    """

    def __init__(self):
        self.headers = dict()

    @classmethod
    def from_buffered(cls, buf):
        """Parse one message of this class from an existing BufferedReader."""
        return parse_http_message(cls, buf)

    @classmethod
    def from_bytes(cls, bs):
        """Parse one message of this class from a complete bytes payload."""
        reader = BufferedReader(bs)
        return parse_http_message(cls, reader)

    def __repr__(self):
        return "{0} {1!r}".format(type(self), vars(self))
164
165
class HttpRequest(HttpMessage):
    """Parsed HTTP request; parsing sets ``method``, ``uri``, ``proto`` and ``version``."""
168
169
class HttpResponse(HttpMessage):
    """Parsed HTTP response; parsing sets ``proto``, ``status`` (int), ``reason`` and ``version``."""
172
173
class MockResponse(six.BytesIO):
    """In-memory response: readable body plus keyword arguments as headers."""

    def __init__(self, body, **headers):
        # Explicit base call: on Python 2 six.BytesIO is an old-style class,
        # so super() is not usable here.
        six.BytesIO.__init__(self, body)
        self.headers = headers

    def iteritems(self):
        """Python-2 style header iterator."""
        return six.iteritems(self.headers)

    def items(self):
        """Header (name, value) pairs."""
        return self.headers.items()
184
185
class MockHTTPConnection(object):
    """Network-free stand-in for ``httplib.HTTPConnection`` used in tests.

    Connection-management methods do nothing; ``getresponse()`` always
    returns a new ``MockResponse(b"the body", status="200")``.
    """

    def __init__(
        self,
        host,
        port=None,
        key_file=None,
        cert_file=None,
        strict=None,
        timeout=None,
        proxy_info=None,
    ):
        # Only host/port/timeout are kept; the other arguments exist for
        # signature compatibility and are ignored.
        self.host, self.port, self.timeout = host, port, timeout
        self.sock = None
        self.log = ""

    def set_debuglevel(self, level):
        """Accept and ignore a debug level."""

    def connect(self):
        """Pretend to connect; no socket is opened."""

    def close(self):
        """Nothing to close."""

    def request(self, method, request_uri, body, headers):
        """Accept and discard the request."""

    def getresponse(self):
        """Return a canned 200 response."""
        response = MockResponse(b"the body", status="200")
        return response
221
222
class MockHTTPBadStatusConnection(object):
    """Mock of ``httplib.HTTPConnection`` whose ``getresponse()`` always
    raises ``BadStatusLine``.

    The class attribute ``num_calls`` counts ``getresponse()`` calls since
    the most recent instance was created.
    """

    num_calls = 0

    def __init__(
        self,
        host,
        port=None,
        key_file=None,
        cert_file=None,
        strict=None,
        timeout=None,
        proxy_info=None,
    ):
        self.host, self.port, self.timeout = host, port, timeout
        self.sock = None
        self.log = ""
        # Each new connection restarts the class-wide call counter.
        MockHTTPBadStatusConnection.num_calls = 0

    def set_debuglevel(self, level):
        """Accept and ignore a debug level."""

    def connect(self):
        """Pretend to connect; no socket is opened."""

    def close(self):
        """Nothing to close."""

    def request(self, method, request_uri, body, headers):
        """Accept and discard the request."""

    def getresponse(self):
        """Count the call, then fail with an empty BadStatusLine."""
        MockHTTPBadStatusConnection.num_calls += 1
        raise http_client.BadStatusLine("")
261
262
@contextlib.contextmanager
def server_socket(fun, request_count=1, timeout=5):
    """Run a throw-away TCP server on localhost for the duration of a ``with``.

    A daemon thread accepts connections and calls ``fun(client, tick)`` for
    each one. ``fun`` must call ``tick(request)`` after serving each request.
    ``tick`` counts requests and returns False once ``request_count`` requests
    have been served, or when the request sent ``Connection: close``.

    Args:
        fun: per-connection handler ``fun(sock, tick)``.
        request_count: number of requests after which the server stops accepting.
        timeout: timeout in seconds for the listening and accepted sockets.

    Yields:
        The server base URL, e.g. ``u"http://127.0.0.1:PORT/"``.

    Raises:
        The first exception raised inside the server thread, re-raised in the
        caller after the ``with`` body completes normally.
    """
    gresult = [None]
    gcounter = [0]

    def tick(request):
        gcounter[0] += 1
        keep = True
        keep &= gcounter[0] < request_count
        keep &= request.headers.get("connection", "").lower() != "close"
        return keep

    def server_socket_thread(srv):
        try:
            while gcounter[0] < request_count:
                client, _ = srv.accept()
                try:
                    client.settimeout(timeout)
                    fun(client, tick)
                finally:
                    try:
                        client.shutdown(socket.SHUT_RDWR)
                    except (IOError, socket.error):
                        pass
                    # FIXME: client.close() introduces connection reset by peer
                    # at least in other/connection_close test
                    # should not be a problem since socket would close upon garbage collection
            if gcounter[0] > request_count:
                gresult[0] = Exception(
                    "Request count expected={0} actual={1}".format(
                        request_count, gcounter[0]
                    )
                )
        except Exception as e:
            # Record the error first so it is reported even if writing to
            # stderr fails (traceback.print_exc caused IOError: concurrent
            # operation on sys.stderr.close() under setup.py test).
            gresult[0] = e
            # sys.stderr is a text stream on Python 3: write str, not bytes.
            sys.stderr.write(traceback.format_exc())

    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # SO_REUSEADDR only influences bind(), so it must be set beforehand.
        try:
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except socket.error as ex:
            print("non critical error on SO_REUSEADDR", ex)
        server.bind(("localhost", 0))
        server.listen(10)
        server.settimeout(timeout)
        t = threading.Thread(target=server_socket_thread, args=(server,))
        t.daemon = True
        t.start()
    except Exception:
        server.close()
        raise
    try:
        yield u"http://{0}:{1}/".format(*server.getsockname())
    finally:
        # Runs even when the with-body raises, so the socket and thread
        # are always cleaned up.
        server.close()
        t.join()
    if gresult[0] is not None:
        raise gresult[0]
317
318
def server_yield(fun, **kwargs):
    """Serve responses produced by a generator.

    ``fun`` is called once with a ``get`` callable (``queue.Queue.get``) and
    must return an iterator. For every parsed request, the request is put
    on the queue and the iterator's next item is sent back as the raw
    response bytes. The generator can therefore call ``get()`` to fetch the
    request it is answering.

    Remaining ``kwargs`` go to ``server_socket`` (e.g. ``request_count``).
    """
    q = queue.Queue(1)
    g = fun(q.get)

    def server_yield_handler(request):
        # Hand the request to the generator before asking it for the reply.
        # The socket loop itself is shared with server_request instead of
        # being duplicated here.
        q.put(request)
        return six.next(g)

    return server_request(server_yield_handler, **kwargs)
340
341
def server_request(request_handler, **kwargs):
    """Serve each parsed request with ``request_handler(request=...)``.

    The handler returns raw response bytes, which are sent unchanged. Before
    the call, each request gets ``client_addr`` and a 1-based per-connection
    ``number``. Remaining ``kwargs`` go to ``server_socket``.
    """

    def server_request_socket_handler(sock, tick):
        reader = BufferedReader(sock)
        served = 0
        request = HttpRequest.from_buffered(reader)
        while request is not None:
            served += 1
            # NOTE(review): getsockname() is the local (server) end of the
            # connection; getpeername() would be the client's address.
            request.client_addr = sock.getsockname()
            request.number = served
            sock.sendall(request_handler(request=request))
            if not tick(request):
                return
            request = HttpRequest.from_buffered(reader)

    return server_socket(server_request_socket_handler, **kwargs)
359
360
def server_const_bytes(response_content, **kwargs):
    """Answer every request with the same raw ``response_content`` bytes."""

    def constant_handler(request):
        return response_content

    return server_request(constant_handler, **kwargs)
363
364
# Keyword names that server_const_http() and server_reflect() pop out of their
# **kwargs to build the HTTP response. They mirror the parameters of
# http_response_bytes(). All remaining keywords go on to the server helpers
# (server_const_bytes / server_request).
_http_kwargs = (
    "proto",
    "status",
    "headers",
    "body",
    "add_content_length",
    "add_date",
    "add_etag",
    "undefined_body_length",
)
375
376
def http_response_bytes(
    proto="HTTP/1.1",
    status="200 OK",
    headers=None,
    body=b"",
    add_content_length=True,
    add_date=False,
    add_etag=False,
    undefined_body_length=False,
    **kwargs
):
    """Serialize an HTTP response to raw bytes for the test servers.

    Args:
        proto: protocol token for the status line.
        status: full status text (``"404 Not Found"``) or a bare status code
            (``404`` or ``"404"``), which gets its standard reason phrase.
        headers: optional mapping of header name -> value.  It is copied and
            never modified.
        body: response body bytes.
        add_content_length: add ``content-length`` unless already present.
        add_date: add a ``date`` header unless already present.
        add_etag: add an MD5-based ``etag`` unless already present.
        undefined_body_length: send no length, so the body is delimited by
            connection close. Forces ``add_content_length`` off.
        **kwargs: accepted and ignored.

    Returns:
        Status line, headers, blank line and body as bytes.

    Raises:
        Exception: if a non-HTTP/1.0 response would give the client no way
            to find the end of the body.
    """
    if undefined_body_length:
        add_content_length = False
    # Copy so the setdefault() calls below never leak into the caller's dict
    # (e.g. one dict reused for several responses).
    headers = {} if headers is None else dict(headers)
    if add_content_length:
        headers.setdefault("content-length", str(len(body)))
    if add_date:
        headers.setdefault("date", email.utils.formatdate())
    if add_etag:
        headers.setdefault("etag", '"{0}"'.format(hashlib.md5(body).hexdigest()))
    header_string = "".join("{0}: {1}\r\n".format(k, v) for k, v in headers.items())
    if (
        not undefined_body_length
        and proto != "HTTP/1.0"
        and "content-length" not in headers
    ):
        raise Exception(
            "httplib2.tests.http_response_bytes: client could not figure response body length"
        )
    if str(status).isdigit():
        # responses is keyed by int; "404" (str) used to raise KeyError.
        code = int(status)
        status = "{} {}".format(code, http_client.responses.get(code, "Unknown"))
    response = (
        "{proto} {status}\r\n{headers}\r\n".format(
            proto=proto, status=status, headers=header_string
        ).encode()
        + body
    )
    return response
416
417
def make_http_reflect(**kwargs):
    """Return a handler that echoes the raw request back as the response body.

    ``kwargs`` are forwarded to ``http_response_bytes``. They are deep-copied
    for every request, so the shared dict is never mutated.
    """
    assert "body" not in kwargs, "make_http_reflect will overwrite response body"

    def fun(request):
        response_kwargs = copy.deepcopy(kwargs)
        response_kwargs["body"] = request.raw
        return http_response_bytes(**response_kwargs)

    return fun
428
429
def server_route(routes, **kwargs):
    """Serve requests by exact ``request.uri`` lookup in ``routes``.

    Values are raw response bytes or callables ``f(request=...)`` that return
    bytes. The ``""`` key is the fallback for unknown URIs. A 404 is sent
    when there is no fallback or the selected value is falsy.
    """
    response_404 = http_response_bytes(status="404 Not Found")
    fallback = routes.get("")

    def handler(request):
        target = routes.get(request.uri, fallback) or response_404
        return target(request=request) if callable(target) else target

    return server_request(handler, **kwargs)
443
444
def server_const_http(**kwargs):
    """Serve one fixed response built by ``http_response_bytes``.

    Keywords listed in ``_http_kwargs`` build the response; the others are
    passed on to ``server_const_bytes``.
    """
    response_kwargs = {}
    for name in list(kwargs):
        if name in _http_kwargs:
            response_kwargs[name] = kwargs.pop(name)
    return server_const_bytes(http_response_bytes(**response_kwargs), **kwargs)
449
450
def server_list_http(responses, **kwargs):
    """Answer successive requests with successive items of ``responses``.

    ``request_count`` defaults to ``len(responses)``.
    """
    remaining = iter(responses)
    kwargs.setdefault("request_count", len(responses))
    return server_request(lambda request: next(remaining), **kwargs)
459
460
def server_reflect(**kwargs):
    """Serve responses that echo each raw request back as the body.

    Keywords listed in ``_http_kwargs`` shape the response; the others are
    passed on to ``server_request``.
    """
    response_kwargs = {}
    for name in list(kwargs):
        if name in _http_kwargs:
            response_kwargs[name] = kwargs.pop(name)
    return server_request(make_http_reflect(**response_kwargs), **kwargs)
465
466
def http_parse_auth(s):
    """Parse an ``Authorization``-style value into its auth parameters.

    The leading scheme token is discarded. Returns a dict that maps
    lower-cased parameter names to values, with quoted-pair escapes removed.
    See https://tools.ietf.org/html/rfc7235#section-2.1
    """
    _scheme, params = s.split(" ", 1)
    parsed = {}
    match = httplib2.WWW_AUTH_RELAXED.search(params)
    while match:
        groups = match.groups()
        if len(groups) == 3:
            name, raw_value, params = groups
            parsed[name.lower()] = httplib2.UNQUOTE_PAIRS.sub(r"\1", raw_value)
        match = httplib2.WWW_AUTH_RELAXED.search(params)
    return parsed
480
481
def store_request_response(out):
    """Decorator factory that records (request, parsed response) pairs.

    When ``out`` is not None (anything with ``append``), each call of the
    decorated handler appends
    ``(request, HttpResponse.from_bytes(response_bytes))`` to it. The
    handler's raw response bytes are always returned unchanged.
    """

    def wrapper(fun):
        @functools.wraps(fun)
        def wrapped(request, *a, **kw):
            response_bytes = fun(request, *a, **kw)
            if out is None:
                return response_bytes
            out.append((request, HttpResponse.from_bytes(response_bytes)))
            return response_bytes

        return wrapped

    return wrapper
495
496
def http_reflect_with_auth(
    allow_scheme, allow_credentials, out_renew_nonce=None, out_requests=None
):
    """Build a request handler that reflects requests back only after HTTP auth.

    On valid credentials the handler echoes the raw request back as the
    response body (via ``make_http_reflect``). Otherwise it answers with a
    401 carrying a ``www-authenticate`` challenge. A header without a space,
    or an unknown scheme, gets a 400.

    Args:
        allow_scheme: ``"basic"`` or ``"digest"`` (lower-case). Any other
            value makes challenge generation (``deny``) raise.
        allow_credentials: sequence of ``(username, password)`` tuples.
        out_renew_nonce: ``None`` or a one-element list. If given, slot 0 is
            set to ``renew_nonce()``, which lets the caller make the server
            rotate its digest nonce. This is an "out" parameter, which keeps
            the signature the same as the other handler builders.
        out_requests: ``None`` or a list. If a list, a
            ``(request, parsed_response)`` pair is appended for every
            handled request.

    Returns:
        ``handler(request) -> bytes``, suitable for ``server_request``.
    """
    # Last "nc" value accepted from a client; written below but never read.
    glastnc = [None]
    # Nonce queued by renew_nonce(); the next digest request switches to it.
    gnextnonce = [None]
    # Nonce currently issued in digest challenges.
    gserver_nonce = [gen_digest_nonce(salt=b"n")]
    realm = "httplib2 test"
    server_opaque = gen_digest_nonce(salt=b"o")

    def renew_nonce():
        # Queue a new nonce and return (current nonce, queued next nonce).
        # NOTE(review): this `assert` is stripped under `python -O`.
        if gnextnonce[0]:
            assert False, (
                "previous nextnonce was not used, probably bug in " "test code"
            )
        gnextnonce[0] = gen_digest_nonce()
        return gserver_nonce[0], gnextnonce[0]

    if out_renew_nonce:
        # Hand renew_nonce back to the caller through the one-element list.
        out_renew_nonce[0] = renew_nonce

    def deny(**kwargs):
        # Build a 401 challenge for allow_scheme; remaining kwargs go to
        # http_response_bytes. nonce_stale=True adds "stale=true" (digest).
        nonce_stale = kwargs.pop("nonce_stale", False)
        if nonce_stale:
            kwargs.setdefault("body", b"nonce stale")
        if allow_scheme == "basic":
            authenticate = 'basic realm="{realm}"'.format(realm=realm)
        elif allow_scheme == "digest":
            authenticate = (
                'digest realm="{realm}", qop="auth"'
                + ', nonce="{nonce}", opaque="{opaque}"'
                + (", stale=true" if nonce_stale else "")
            ).format(realm=realm, nonce=gserver_nonce[0], opaque=server_opaque)
        else:
            raise Exception("unknown allow_scheme={0}".format(allow_scheme))
        deny_headers = {"www-authenticate": authenticate}
        kwargs.setdefault("status", 401)
        # supplied headers may overwrite generated ones
        deny_headers.update(kwargs.get("headers", {}))
        kwargs["headers"] = deny_headers
        kwargs.setdefault("body", b"HTTP authorization required")
        return http_response_bytes(**kwargs)

    @store_request_response(out_requests)
    def http_reflect_with_auth_handler(request):
        auth_header = request.headers.get("authorization", "")
        if not auth_header:
            return deny()
        if " " not in auth_header:
            return http_response_bytes(
                status=400, body=b"authorization header syntax error"
            )
        scheme, data = auth_header.split(" ", 1)
        scheme = scheme.lower()
        if scheme != allow_scheme:
            return deny(body=b"must use different auth scheme")
        if scheme == "basic":
            # data is base64 of "username:password".
            decoded = base64.b64decode(data).decode()
            username, password = decoded.split(":", 1)
            if (username, password) in allow_credentials:
                return make_http_reflect()(request)
            else:
                return deny(body=b"supplied credentials are not allowed")
        elif scheme == "digest":
            server_nonce_old = gserver_nonce[0]
            nextnonce = gnextnonce[0]
            if nextnonce:
                # server decided to change nonce, in this case, guided by caller test code
                gserver_nonce[0] = nextnonce
                gnextnonce[0] = None
            server_nonce_current = gserver_nonce[0]
            auth_info = http_parse_auth(data)
            client_cnonce = auth_info.get("cnonce", "")
            client_nc = auth_info.get("nc", "")
            client_nonce = auth_info.get("nonce", "")
            client_opaque = auth_info.get("opaque", "")
            client_qop = auth_info.get("qop", "auth").strip('"')

            # TODO: auth_info.get('algorithm', 'md5')
            hasher = hashlib.md5

            # TODO: client_qop auth-int
            # HA2 = H(method:uri)
            ha2 = hasher(":".join((request.method, request.uri)).encode()).hexdigest()

            # A client still using the pre-rotation nonce is told it is stale.
            if client_nonce != server_nonce_current:
                if client_nonce == server_nonce_old:
                    return deny(nonce_stale=True)
                return deny(body=b"invalid nonce")
            if not client_nc:
                return deny(body=b"auth-info nc missing")
            if client_opaque != server_opaque:
                return deny(
                    body="auth-info opaque mismatch expected={} actual={}".format(
                        server_opaque, client_opaque
                    ).encode()
                )
            # Try every allowed credential pair:
            # HA1 = H(user:realm:pass);
            # response = H(HA1:nonce:nc:cnonce:qop:HA2).
            for allow_username, allow_password in allow_credentials:
                ha1 = hasher(
                    ":".join((allow_username, realm, allow_password)).encode()
                ).hexdigest()
                allow_response = hasher(
                    ":".join(
                        (ha1, client_nonce, client_nc, client_cnonce, client_qop, ha2)
                    ).encode()
                ).hexdigest()
                # rspauth (for authentication-info) uses HA2 with an empty
                # method, i.e. H(":uri").
                rspauth_ha2 = hasher(":{}".format(request.uri).encode()).hexdigest()
                rspauth = hasher(
                    ":".join(
                        (
                            ha1,
                            client_nonce,
                            client_nc,
                            client_cnonce,
                            client_qop,
                            rspauth_ha2,
                        )
                    ).encode()
                ).hexdigest()
                if auth_info.get("response", "") == allow_response:
                    # TODO: fix or remove doubtful comment
                    # do we need to save nc only on success?
                    glastnc[0] = client_nc
                    # Space-separated parameters; nextnonce is included only
                    # when a rotation happened on this request.
                    allow_headers = {
                        "authentication-info": " ".join(
                            (
                                'nextnonce="{}"'.format(nextnonce) if nextnonce else "",
                                "qop={}".format(client_qop),
                                'rspauth="{}"'.format(rspauth),
                                'cnonce="{}"'.format(client_cnonce),
                                "nc={}".format(client_nc),
                            )
                        ).strip()
                    }
                    return make_http_reflect(headers=allow_headers)(request)
            return deny(body=b"supplied credentials are not allowed")
        else:
            return http_response_bytes(
                status=400,
                body="unknown authorization scheme={0}".format(scheme).encode(),
            )

    return http_reflect_with_auth_handler
646
647
def get_cache_path():
    """Return a clean test cache directory path.

    The path comes from the ``httplib2_test_cache_path`` environment
    variable. If that is unset or empty, ``./_httplib2_test_cache`` is used.
    Any existing directory at that path is deleted first.
    """
    env_path = os.environ.get("httplib2_test_cache_path")
    cache_path = env_path if env_path else "./_httplib2_test_cache"
    # Start every run from a clean slate.
    if not os.path.exists(cache_path):
        return cache_path
    shutil.rmtree(cache_path)
    return cache_path
654
655
def gen_digest_nonce(salt=b""):
    """Return a base64 nonce built from a timestamp and a SHA-1 digest.

    The decoded token is an 8-byte big-endian nanosecond timestamp, a
    ``b":"``, then ``sha1(timestamp + salt)``.
    """
    stamp = struct.pack(">Q", int(time.time() * 1e9))
    digest = hashlib.sha1(stamp + salt).digest()
    token = b":".join((stamp, digest))
    return base64.b64encode(token).decode()
659
660
def gen_password():
    """Return a random string of 8-64 characters with code points 0-127.

    Uses the non-cryptographic ``random`` module, which is fine for test
    fixtures.
    """
    size = random.randint(8, 64)
    codes = [random.randint(0, 127) for _ in range(size)]
    return "".join(map(six.unichr, codes))
664
665
def gzip_compress(bs):
    """Return ``bs`` gzip-compressed (compression level 6) as bytes."""
    import io  # io.BytesIO exists on Python 2.6+ and 3; no six needed.

    buf = io.BytesIO()
    # The with-block always closes the GzipFile, which writes the gzip trailer.
    with gzip.GzipFile(fileobj=buf, mode="wb", compresslevel=6) as gf:
        gf.write(bs)
    return buf.getvalue()
674
675
def gzip_decompress(bs):
    """Decompress gzip-framed bytes; ``wbits | 16`` selects gzip header/trailer."""
    gzip_wbits = zlib.MAX_WBITS | 16
    return zlib.decompress(bs, gzip_wbits)
678
679
def deflate_compress(bs):
    """Compress ``bs`` at level 9 as a raw DEFLATE stream (no zlib header or trailer)."""
    compressor = zlib.compressobj(9, zlib.DEFLATED, -zlib.MAX_WBITS)
    return b"".join((compressor.compress(bs), compressor.flush()))
683
684
def deflate_decompress(bs):
    """Inflate a raw DEFLATE stream (negative wbits: no zlib header or trailer)."""
    raw_deflate_wbits = -zlib.MAX_WBITS
    return zlib.decompress(bs, raw_deflate_wbits)
687