• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#***************************************************************************
4#                                  _   _ ____  _
5#  Project                     ___| | | |  _ \| |
6#                             / __| | | | |_) | |
7#                            | (__| |_| |  _ <| |___
8#                             \___|\___/|_| \_\_____|
9#
10# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
11#
12# This software is licensed as described in the file COPYING, which
13# you should have received as part of this distribution. The terms
14# are also available at https://curl.se/docs/copyright.html.
15#
16# You may opt to use, copy, modify, merge, publish, distribute and/or sell
17# copies of the Software, and permit persons to whom the Software is
18# furnished to do so, under the terms of the COPYING file.
19#
20# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
21# KIND, either express or implied.
22#
23# SPDX-License-Identifier: curl
24#
25###########################################################################
26#
27import json
28import logging
29import os
30import re
31import shutil
32import subprocess
33from datetime import timedelta, datetime
34from typing import List, Optional, Dict
35from urllib.parse import urlparse
36
37from .env import Env
38
39
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
41
42
class ExecResult:
    """Outcome of one executed command line.

    Holds the argument list, exit code, captured stdout/stderr lines and the
    run duration. Responses, results and assets can be attached afterwards
    through the ``add_*`` methods. When ``with_stats`` is set, every stdout
    line is parsed as one JSON document and collected in ``stats``;
    otherwise the whole stdout is tentatively parsed as a single JSON
    document, available as ``json``.
    """

    def __init__(self, args: List[str], exit_code: int,
                 stdout: List[str], stderr: List[str],
                 duration: Optional[timedelta] = None,
                 with_stats: bool = False):
        """Store the raw outcome and parse stdout.

        :param args: the command line that was executed
        :param exit_code: exit code of the process
        :param stdout: captured standard output, as a list of lines
        :param stderr: captured standard error, as a list of lines
        :param duration: run time; a zero timedelta is used when omitted
        :param with_stats: treat each stdout line as a JSON stats record
        """
        self._args = args
        self._exit_code = exit_code
        self._stdout = stdout
        self._stderr = stderr
        self._duration = duration if duration is not None else timedelta()
        self._response = None
        self._responses = []
        self._results = {}
        self._assets = []
        self._stats = []
        self._json_out = None
        self._with_stats = with_stats
        if with_stats:
            self._parse_stats()
        else:
            # Best effort only: stdout that is not JSON simply leaves
            # `json` as None. Only parse failures are swallowed, so that
            # KeyboardInterrupt and friends still propagate.
            try:
                self._json_out = json.loads(''.join(self._stdout))
            except (ValueError, TypeError):
                pass

    def __repr__(self):
        return f"ExecResult[code={self.exit_code}, args={self._args}, stdout={self._stdout}, stderr={self._stderr}]"

    def _parse_stats(self):
        """Parse each stdout line as JSON into `stats`.

        Parsing stops at the first line that is not valid JSON; the
        offending line and the complete stdout are logged as errors.
        """
        self._stats = []
        for line in self._stdout:
            try:
                self._stats.append(json.loads(line))
            except (ValueError, TypeError):
                log.error('not a JSON stat: %s', line)
                log.error('stdout is: %s', ''.join(self._stdout))
                break

    @property
    def exit_code(self) -> int:
        return self._exit_code

    @property
    def args(self) -> List[str]:
        return self._args

    @property
    def outraw(self) -> bytes:
        """The joined stdout lines, encoded to bytes."""
        return ''.join(self._stdout).encode()

    @property
    def stdout(self) -> str:
        """The joined stdout lines as one string."""
        return ''.join(self._stdout)

    @property
    def json(self) -> Optional[Dict]:
        """Output as JSON dictionary or None if not parseable."""
        return self._json_out

    @property
    def stderr(self) -> str:
        """The joined stderr lines as one string."""
        return ''.join(self._stderr)

    @property
    def duration(self) -> timedelta:
        return self._duration

    @property
    def response(self) -> Optional[Dict]:
        """The response added most recently, or None if none was added."""
        return self._response

    @property
    def responses(self) -> List[Dict]:
        """All added responses, in the order they were added."""
        return self._responses

    @property
    def results(self) -> Dict:
        return self._results

    @property
    def assets(self) -> List:
        return self._assets

    @property
    def with_stats(self) -> bool:
        return self._with_stats

    @property
    def stats(self) -> List:
        return self._stats

    @property
    def total_connects(self) -> Optional[int]:
        """Sum of 'num_connects' over all stats, or None without stats."""
        if not self.stats:
            return None
        return sum(stat['num_connects'] for stat in self.stats)

    def add_response(self, resp: Dict):
        """Append `resp` to the responses and make it the current one."""
        self._response = resp
        self._responses.append(resp)

    def add_results(self, results: Dict):
        """Merge `results` into the stored results.

        A 'response' entry, when present, is also added as a response.
        """
        self._results.update(results)
        if 'response' in results:
            self.add_response(results['response'])

    def add_assets(self, assets: List):
        self._assets.extend(assets)

    def check_responses(self, count: int, exp_status: Optional[int] = None,
                        exp_exitcode: Optional[int] = None):
        """Assert number, status and exit code of the collected responses.

        :param count: exact number of responses expected; with stats
            enabled, the same number of stats records is required
        :param exp_status: if given, every response must have this 'status'
        :param exp_exitcode: if given, every response that carries an
            'exitcode' must have this value
        :raises AssertionError: when an expectation is not met
        """
        assert len(self.responses) == count, \
            f'response count: expected {count}, got {len(self.responses)}'
        if exp_status is not None:
            for idx, x in enumerate(self.responses):
                assert x['status'] == exp_status, \
                    f'response #{idx} unexpected status: {x["status"]}'
        if exp_exitcode is not None:
            for idx, x in enumerate(self.responses):
                if 'exitcode' in x:
                    # compare against the expectation, not a fixed 0
                    assert x['exitcode'] == exp_exitcode, \
                        f'response #{idx} exitcode: expected {exp_exitcode}, ' \
                        f'got {x["exitcode"]}'
        if self.with_stats:
            assert len(self.stats) == count, f'{self}'

    def check_stats(self, count: int, exp_status: Optional[int] = None,
                    exp_exitcode: Optional[int] = None):
        """Assert number, http_code and exit code of the collected stats.

        :param count: exact number of stats records expected
        :param exp_status: if given, every record must carry this 'http_code'
        :param exp_exitcode: if given, every record that carries an
            'exitcode' must have this value
        :raises AssertionError: when an expectation is not met
        """
        assert len(self.stats) == count, \
            f'stats count: expected {count}, got {len(self.stats)}'
        if exp_status is not None:
            for idx, x in enumerate(self.stats):
                assert 'http_code' in x, \
                    f'status #{idx} reports no http_code'
                assert x['http_code'] == exp_status, \
                    f'status #{idx} unexpected http_code: {x["http_code"]}'
        if exp_exitcode is not None:
            for idx, x in enumerate(self.stats):
                if 'exitcode' in x:
                    # compare against the expectation, not a fixed 0
                    assert x['exitcode'] == exp_exitcode, \
                        f'status #{idx} exitcode: expected {exp_exitcode}, ' \
                        f'got {x["exitcode"]}'
187
188
class CurlClient:
    """Runs the curl executable inside a private run directory.

    stdout, stderr, response headers and traces of each run are written to
    fixed files in the run directory; the outcome is returned as an
    `ExecResult`.
    """

    # maps an ALPN protocol name to the curl option selecting it
    ALPN_ARG = {
        'http/0.9': '--http0.9',
        'http/1.0': '--http1.0',
        'http/1.1': '--http1.1',
        'h2': '--http2',
        'h2c': '--http2',
        'h3': '--http3-only',
    }

    # Options making curl write one JSON stats record per transfer. The
    # backslash and 'n' are passed to curl as two literal characters.
    _STATS_ARGS = ('-w', '%{json}\\n')

    def __init__(self, env: 'Env', run_dir: Optional[str] = None):
        """Set up file locations and start with an empty run directory.

        :param env: test environment; supplies `curl`, `gen_dir`,
            `verbose` and `ca.cert_file`
        :param run_dir: directory to run curl in, defaults to
            `<env.gen_dir>/curl`. Any existing directory is removed and
            created anew.
        """
        self.env = env
        # the CURL environment variable overrides the configured executable
        self._curl = os.environ['CURL'] if 'CURL' in os.environ else env.curl
        self._run_dir = run_dir if run_dir else os.path.join(env.gen_dir, 'curl')
        self._stdoutfile = f'{self._run_dir}/curl.stdout'
        self._stderrfile = f'{self._run_dir}/curl.stderr'
        self._headerfile = f'{self._run_dir}/curl.headers'
        self._tracefile = f'{self._run_dir}/curl.trace'
        self._log_path = f'{self._run_dir}/curl.log'
        self._rmrf(self._run_dir)
        self._mkpath(self._run_dir)

    @property
    def run_dir(self) -> str:
        return self._run_dir

    def download_file(self, i: int) -> str:
        """Path of the file the `i`-th download is saved to."""
        return os.path.join(self.run_dir, f'download_{i}.data')

    def _rmf(self, path):
        """Remove the file `path` if it exists."""
        if os.path.exists(path):
            os.remove(path)

    def _rmrf(self, path):
        """Remove the directory tree `path` if it exists."""
        if os.path.exists(path):
            shutil.rmtree(path)

    def _mkpath(self, path):
        """Create the directory `path`, with parents, unless it exists."""
        if not os.path.exists(path):
            os.makedirs(path)

    def http_get(self, url: str, extra_args: Optional[List[str]] = None):
        """Run curl on `url` with `extra_args`, without stats output."""
        return self._raw(url, options=extra_args, with_stats=False)

    def http_download(self, urls: List[str],
                      alpn_proto: Optional[str] = None,
                      with_stats: bool = True,
                      with_headers: bool = False,
                      no_save: bool = False,
                      extra_args: Optional[List[str]] = None):
        """Download `urls`, saving bodies to `download_#1.data` files.

        :param urls: the URLs to retrieve
        :param alpn_proto: key of `ALPN_ARG` selecting the HTTP version
        :param with_stats: have curl emit a JSON stats record per transfer
        :param with_headers: dump and parse response headers
        :param no_save: write bodies to /dev/null instead of files
        :param extra_args: additional curl options; the list passed in is
            not modified
        :return: the `ExecResult` of the run
        """
        # work on a copy, the caller's list must stay untouched
        extra_args = list(extra_args) if extra_args else []
        extra_args.extend(['-o', '/dev/null' if no_save else 'download_#1.data'])
        # remove any existing ones
        for i in range(100):
            self._rmf(self.download_file(i))
        if with_stats:
            extra_args.extend(self._STATS_ARGS)
        return self._raw(urls, alpn_proto=alpn_proto, options=extra_args,
                         with_stats=with_stats,
                         with_headers=with_headers)

    def http_upload(self, urls: List[str], data: str,
                    alpn_proto: Optional[str] = None,
                    with_stats: bool = True,
                    with_headers: bool = False,
                    extra_args: Optional[List[str]] = None):
        """Send `data` via `--data-binary` to `urls`.

        Response bodies are saved to `download_#1.data` files. The
        `extra_args` list passed in is not modified.

        :return: the `ExecResult` of the run
        """
        extra_args = list(extra_args) if extra_args else []
        extra_args.extend([
            '--data-binary', data, '-o', 'download_#1.data',
        ])
        if with_stats:
            extra_args.extend(self._STATS_ARGS)
        return self._raw(urls, alpn_proto=alpn_proto, options=extra_args,
                         with_stats=with_stats,
                         with_headers=with_headers)

    def http_put(self, urls: List[str], data=None, fdata=None,
                 alpn_proto: Optional[str] = None,
                 with_stats: bool = True,
                 with_headers: bool = False,
                 extra_args: Optional[List[str]] = None):
        """Upload to `urls` with curl's `-T` option.

        :param data: text fed to curl on stdin (`-T -`); only used when
            `fdata` is None
        :param fdata: argument passed to `-T`; takes precedence over `data`
        :param extra_args: additional curl options; the list passed in is
            not modified
        :return: the `ExecResult` of the run
        """
        extra_args = list(extra_args) if extra_args else []
        if fdata is not None:
            extra_args.extend(['-T', fdata])
        elif data is not None:
            extra_args.extend(['-T', '-'])
        extra_args.extend(['-o', 'download_#1.data'])
        if with_stats:
            extra_args.extend(self._STATS_ARGS)
        return self._raw(urls, intext=data,
                         alpn_proto=alpn_proto, options=extra_args,
                         with_stats=with_stats,
                         with_headers=with_headers)

    def response_file(self, idx: int):
        """Path of the saved body of response `idx`, same as `download_file`."""
        return self.download_file(idx)

    def run_direct(self, args, with_stats: bool = False):
        """Run curl with exactly `args`, saving output to `download.data`.

        Unlike the `http_*` methods, none of the standard options from
        `_complete_args` are added.
        """
        my_args = [self._curl]
        if with_stats:
            my_args.extend(self._STATS_ARGS)
        my_args.extend(['-o', 'download.data'])
        my_args.extend(args)
        return self._run(args=my_args, with_stats=with_stats)

    def _run(self, args, intext='', with_stats: bool = False):
        """Execute `args` in the run directory and collect the outcome.

        Output files of a previous run are removed first. stdout and stderr
        are redirected to files and read back as lists of lines.

        :param args: full command line, including the executable
        :param intext: text passed to the process on stdin, if non-empty
        :param with_stats: passed on to `ExecResult`
        """
        self._rmf(self._stdoutfile)
        self._rmf(self._stderrfile)
        self._rmf(self._headerfile)
        self._rmf(self._tracefile)
        start = datetime.now()
        with open(self._stdoutfile, 'w') as cout, \
                open(self._stderrfile, 'w') as cerr:
            p = subprocess.run(args, stderr=cerr, stdout=cout,
                               cwd=self._run_dir, shell=False,
                               input=intext.encode() if intext else None)
        # read back with context managers so the handles get closed
        with open(self._stdoutfile) as fd:
            coutput = fd.readlines()
        with open(self._stderrfile) as fd:
            cerrput = fd.readlines()
        return ExecResult(args=args, exit_code=p.returncode,
                          stdout=coutput, stderr=cerrput,
                          duration=datetime.now() - start,
                          with_stats=with_stats)

    def _raw(self, urls, intext='', timeout=10, options=None, insecure=False,
             alpn_proto: Optional[str] = None,
             force_resolve=True,
             with_stats=False,
             with_headers=True):
        """Build the full command line for `urls`, run it, parse headers.

        On exit code 0 and with `with_headers` set, the header dump is
        parsed into responses on the result. If stdout was JSON, it is
        attached to the last response under the key "json".
        """
        args = self._complete_args(
            urls=urls, timeout=timeout, options=options, insecure=insecure,
            alpn_proto=alpn_proto, force_resolve=force_resolve,
            with_headers=with_headers)
        r = self._run(args, intext=intext, with_stats=with_stats)
        if r.exit_code == 0 and with_headers:
            self._parse_headerfile(self._headerfile, r=r)
            # the header dump may have contained no response at all
            if r.json and r.response is not None:
                r.response["json"] = r.json
        return r

    def _complete_args(self, urls, timeout=None, options=None,
                       insecure=False, force_resolve=True,
                       alpn_proto: Optional[str] = None,
                       with_headers: bool = True):
        """Assemble the curl command line for one or several URLs.

        General options come first. Then, for every URL, its protocol,
        TLS trust, `--resolve`, timeout and `options` arguments are added,
        followed by the URL itself.

        :param urls: a single URL or a list of URLs
        :param timeout: connect timeout in seconds; omitted if None or <= 0
        :param options: extra arguments, repeated in front of every URL
        :param insecure: add `--insecure` for non-http URLs
        :param force_resolve: add a `--resolve` entry pointing the URL's
            hostname at 127.0.0.1, unless the host is 'localhost' or starts
            with a digit, '[' or ':'
        :param alpn_proto: key of `ALPN_ARG`
        :param with_headers: dump response headers to the header file
        :raises Exception: if `alpn_proto` is not a key of `ALPN_ARG`
        :return: the argument list, starting with the curl executable
        """
        if not isinstance(urls, list):
            urls = [urls]

        args = [self._curl, "-s", "--path-as-is"]
        if with_headers:
            args.extend(["-D", self._headerfile])
        # one --trace option is enough; timestamps only at higher verbosity
        if self.env.verbose > 2:
            args.extend(['--trace', self._tracefile, '--trace-time'])
        elif self.env.verbose > 1:
            args.extend(['--trace', self._tracefile])

        for url in urls:
            # inspect the URL at hand, not just the first one
            u = urlparse(url)
            if alpn_proto is not None:
                if alpn_proto not in self.ALPN_ARG:
                    raise Exception(f'unknown ALPN protocol: "{alpn_proto}"')
                args.append(self.ALPN_ARG[alpn_proto])

            if u.scheme == 'http':
                pass
            elif insecure:
                args.append('--insecure')
            elif options and "--cacert" in options:
                pass
            elif u.hostname:
                args.extend(["--cacert", self.env.ca.cert_file])

            if force_resolve and u.hostname and u.hostname != 'localhost' \
                    and not re.match(r'^(\d+|\[|:).*', u.hostname):
                port = u.port if u.port else 443
                args.extend(["--resolve", f"{u.hostname}:{port}:127.0.0.1"])
            if timeout is not None and int(timeout) > 0:
                args.extend(["--connect-timeout", str(int(timeout))])
            if options:
                args.extend(options)
            args.append(url)
        return args

    def _parse_headerfile(self, headerfile: str,
                          r: Optional['ExecResult'] = None) -> 'ExecResult':
        """Parse a header dump into responses and add them to `r`.

        Each status line starts a new response dictionary with the keys
        "protocol", "status", "description", "header", "trailer" and
        "body" (the raw stdout of `r`). Header and trailer names are
        lower-cased.

        :param headerfile: path of the header dump to read
        :param r: result to add the responses to; an empty one is created
            if omitted
        :raises AssertionError: on a line that fits no expected element
        :return: `r`, or the newly created result
        """
        with open(headerfile) as fd:
            lines = fd.readlines()
        if r is None:
            r = ExecResult(args=[], exit_code=0, stdout=[], stderr=[])

        response = None

        def fin_response(resp):
            if resp:
                r.add_response(resp)

        # kinds of lines acceptable next
        expected = ['status']
        for line in lines:
            line = line.strip()
            if not line:
                if 'trailer' in expected:
                    # end of trailers
                    fin_response(response)
                    response = None
                    expected = ['status']
                elif 'header' in expected:
                    # end of header, another status or trailers might follow
                    expected = ['status', 'trailer']
                else:
                    assert False, f"unexpected line: '{line}'"
                continue
            if 'status' in expected:
                # log.debug("reading 1st response line: %s", line)
                m = re.match(r'^(\S+) (\d+)( .*)?$', line)
                if m:
                    fin_response(response)
                    response = {
                        "protocol": m.group(1),
                        "status": int(m.group(2)),
                        "description": m.group(3),
                        "header": {},
                        "trailer": {},
                        "body": r.outraw
                    }
                    expected = ['header']
                    continue
            if 'trailer' in expected:
                m = re.match(r'^([^:]+):\s*(.*)$', line)
                if m:
                    response['trailer'][m.group(1).lower()] = m.group(2)
                    continue
            if 'header' in expected:
                m = re.match(r'^([^:]+):\s*(.*)$', line)
                if m:
                    response['header'][m.group(1).lower()] = m.group(2)
                    continue
            assert False, f"unexpected line: '{line}', expected: {expected}"

        fin_response(response)
        return r
448
449