• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""File Helper Functions."""
16
17import glob
18import hashlib
19import json
20import logging
21import os
22import shutil
23import sys
24import subprocess
25import tarfile
26import urllib.request
27import zipfile
28from pathlib import Path
29from typing import List
30
31_LOG = logging.getLogger(__name__)
32
33
class InvalidChecksumError(Exception):
    """A downloaded file's checksum did not match the expected value."""
36
37
def find_files(starting_dir: str,
               patterns: List[str],
               directories_only: bool = False) -> List[str]:
    """Return paths under starting_dir matching any of the glob patterns.

    Args:
        starting_dir: Directory to search; result paths are relative to it.
        patterns: Glob patterns; ``**`` is honored (recursive glob).
        directories_only: If True, only include directories in the result.

    Returns:
        Sorted list of matching paths, relative to starting_dir.

    Raises:
        FileNotFoundError: If starting_dir is not an existing directory.
    """
    # isdir() is False for nonexistent paths, so a separate exists()
    # check is unnecessary.
    if not os.path.isdir(starting_dir):
        raise FileNotFoundError(
            "Directory '{}' does not exist.".format(starting_dir))

    original_working_dir = os.getcwd()
    os.chdir(starting_dir)
    try:
        files = []
        for pattern in patterns:
            for file_path in glob.glob(pattern, recursive=True):
                if not directories_only or os.path.isdir(file_path):
                    files.append(file_path)
    finally:
        # Restore the working directory even if glob raises; the original
        # code leaked the chdir on any exception.
        os.chdir(original_working_dir)
    return sorted(files)
55
56
def sha256_sum(file_name):
    """Return the hex SHA-256 digest of the named file."""
    digest = hashlib.sha256()
    # Hash in fixed-size chunks so large files never load fully into memory.
    with open(file_name, "rb") as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
63
64
def md5_sum(file_name):
    """Return the hex MD5 digest of the named file."""
    digest = hashlib.md5()
    # Hash in fixed-size chunks so large files never load fully into memory.
    with open(file_name, "rb") as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
71
72
def verify_file_checksum(file_path,
                         expected_checksum,
                         sum_function=sha256_sum):
    """Check a file's checksum against an expected hex digest.

    Args:
        file_path: Path of the file to hash.
        expected_checksum: Expected hex digest string.
        sum_function: Callable mapping a file path to a hex digest
            (defaults to sha256_sum).

    Returns:
        True when the checksum matches.

    Raises:
        InvalidChecksumError: When the computed digest differs.
    """
    actual_checksum = sum_function(file_path)
    base_name = os.path.basename(file_path)
    if actual_checksum == expected_checksum:
        _LOG.debug("  %s:", sum_function.__name__)
        _LOG.debug("  %s %s", actual_checksum, base_name)
        return True
    raise InvalidChecksumError(
        f"Invalid {sum_function.__name__}\n"
        f"{actual_checksum} {base_name}\n"
        f"{expected_checksum} (expected)\n\n"
        "Please delete this file and try again:\n"
        f"{file_path}")
88
89
def relative_or_absolute_path(file_string: str):
    """Return a Path relative to os.getcwd(), else an absolute path."""
    candidate = Path(file_string)
    try:
        return candidate.relative_to(os.getcwd())
    except ValueError:
        # Not located under the current directory; use an absolute path.
        return candidate.resolve()
97
98
def download_to_cache(url: str,
                      expected_md5sum=None,
                      expected_sha256sum=None,
                      cache_directory=".cache",
                      downloaded_file_name=None) -> str:
    """Download a file into a local cache, skipping if already present.

    Args:
        url: URL to fetch.
        expected_md5sum: Optional MD5 hex digest to verify against.
        expected_sha256sum: Optional SHA-256 hex digest to verify against;
            takes precedence over expected_md5sum.
        cache_directory: Directory to store the download in; ``~`` and
            environment variables are expanded.
        downloaded_file_name: File name to save as; defaults to the last
            component of the URL.

    Returns:
        Absolute path to the downloaded (or already cached) file.

    Raises:
        InvalidChecksumError: If the file fails checksum verification.
    """
    cache_dir = os.path.realpath(
        os.path.expanduser(os.path.expandvars(cache_directory)))
    # urlretrieve does not create intermediate directories; without this
    # the first download into a fresh cache directory fails.
    os.makedirs(cache_dir, exist_ok=True)
    if not downloaded_file_name:
        # Use the last part of the URL as the file name.
        downloaded_file_name = url.split("/")[-1]
    downloaded_file = os.path.join(cache_dir, downloaded_file_name)

    if not os.path.exists(downloaded_file):
        _LOG.info("Downloading: %s", url)
        _LOG.info("Please wait...")
        urllib.request.urlretrieve(url, filename=downloaded_file)

    if os.path.exists(downloaded_file):
        _LOG.info("Downloaded: %s", relative_or_absolute_path(downloaded_file))
        # SHA-256 wins when both checksums are supplied.
        if expected_sha256sum:
            verify_file_checksum(downloaded_file,
                                 expected_sha256sum,
                                 sum_function=sha256_sum)
        elif expected_md5sum:
            verify_file_checksum(downloaded_file,
                                 expected_md5sum,
                                 sum_function=md5_sum)

    return downloaded_file
129
130
def extract_zipfile(archive_file: str, dest_dir: str):
    """Extract a zipfile preserving Unix file permissions.

    ZipFile.extract() does not restore POSIX permission bits, so they are
    re-applied from each entry's external attributes after extraction.

    Args:
        archive_file: Path of the zip archive to extract.
        dest_dir: Destination directory for the extracted files.
    """
    destination_path = Path(dest_dir)
    with zipfile.ZipFile(archive_file) as archive:
        for info in archive.infolist():
            archive.extract(info.filename, path=dest_dir)
            permissions = info.external_attr >> 16
            # Archives created on non-POSIX hosts often store 0 here;
            # chmod(0) would make the extracted file unreadable, so only
            # apply permissions that are actually present.
            if permissions:
                out_path = destination_path / info.filename
                out_path.chmod(permissions)
140
141
def extract_tarfile(archive_file: str, dest_dir: str):
    """Extract a tar archive (any compression tarfile supports) into dest_dir.

    Args:
        archive_file: Path of the tar archive to extract.
        dest_dir: Destination directory for the extracted files.
    """
    with tarfile.open(archive_file, 'r') as archive:
        try:
            # The "tar" filter rejects absolute paths and parent-directory
            # traversal (CVE-2007-4559) while preserving permissions.
            archive.extractall(path=dest_dir, filter="tar")
        except TypeError:
            # Older Python without the extraction-filter backport.
            archive.extractall(path=dest_dir)
145
146
def extract_archive(archive_file: str,
                    dest_dir: str,
                    cache_dir: str,
                    remove_single_toplevel_folder=True):
    """Extract a tar or zip file.

    Args:
        archive_file (str): Absolute path to the archive file.
        dest_dir (str): Extraction destination directory; created if it
            does not already exist.
        cache_dir (str): Directory where temp files can be created.
        remove_single_toplevel_folder (bool): If the archive contains only a
            single folder move the contents of that into the destination
            directory.

    Returns:
        List of Paths for everything under dest_dir after extraction.
        Exits the process (status 1) on an unrecognized archive format.
    """
    # Extract into a hidden temp directory first so a failed extraction
    # never leaves partial files in dest_dir.
    temp_extract_dir = os.path.join(cache_dir,
                                    "." + os.path.basename(archive_file))
    os.makedirs(temp_extract_dir, exist_ok=True)

    _LOG.info("Extracting: %s", relative_or_absolute_path(archive_file))
    if zipfile.is_zipfile(archive_file):
        extract_zipfile(archive_file, temp_extract_dir)
    elif tarfile.is_tarfile(archive_file):
        extract_tarfile(archive_file, temp_extract_dir)
    else:
        _LOG.error("Unknown archive format: %s", archive_file)
        return sys.exit(1)

    _LOG.info("Installing into: %s", relative_or_absolute_path(dest_dir))
    path_to_extracted_files = temp_extract_dir

    extracted_top_level_files = os.listdir(temp_extract_dir)
    # If the archive holds a single top-level folder, strip that folder so
    # its contents land directly in dest_dir.
    if remove_single_toplevel_folder and len(extracted_top_level_files) == 1:
        path_to_extracted_files = os.path.join(temp_extract_dir,
                                               extracted_top_level_files[0])

    # Move extracted files to dest_dir, creating it if needed; shutil.move
    # fails when the destination directory is missing.
    os.makedirs(dest_dir, exist_ok=True)
    for file_name in os.listdir(path_to_extracted_files):
        source_file = os.path.join(path_to_extracted_files, file_name)
        dest_file = os.path.join(dest_dir, file_name)
        shutil.move(source_file, dest_file)

    # rm -rf temp_extract_dir
    shutil.rmtree(temp_extract_dir, ignore_errors=True)

    # Return List of extracted files
    return list(Path(dest_dir).rglob("*"))
197
198
def remove_empty_directories(directory):
    """Recursively remove empty directories (and broken symlinks)."""
    # Walk deepest paths first so a parent emptied by removing its
    # children is itself removed on a later iteration.
    for entry in sorted(Path(directory).rglob("*"), reverse=True):
        if entry.is_symlink() and not entry.exists():
            # Broken symlink: remove the link itself.
            entry.unlink()
        elif entry.is_dir() and not any(entry.iterdir()):
            entry.rmdir()
209
210
def decode_file_json(file_name):
    """Decode JSON values from a file.

    Does not raise an error if the file cannot be decoded."""

    # Expand ~ and environment variables, then resolve to a real path.
    file_path = os.path.realpath(
        os.path.expanduser(os.path.expandvars(file_name)))

    json_file_options = {}
    try:
        json_file_options = json.loads(Path(file_path).read_text())
    except (FileNotFoundError, json.JSONDecodeError):
        _LOG.warning("Unable to read file '%s'", file_path)

    # Always hand back the resolved path alongside the parsed options so
    # callers can report where the values came from.
    return json_file_options, file_path
228
229
def git_apply_patch(root_directory,
                    patch_file,
                    ignore_whitespace=True,
                    unsafe_paths=False):
    """Use `git apply` to apply a diff file.

    Args:
        root_directory: Directory passed to `git apply --directory`.
        patch_file: Path of the patch file to apply.
        ignore_whitespace: Pass --ignore-whitespace so whitespace drift in
            the patched files does not reject the patch.
        unsafe_paths: Pass --unsafe-paths to allow patching files outside
            a git working tree.

    Raises:
        subprocess.CalledProcessError: If `git apply` exits non-zero.
    """
    _LOG.info("Applying Patch: %s", patch_file)
    git_apply_command = ["git", "apply"]
    if ignore_whitespace:
        git_apply_command.append("--ignore-whitespace")
    if unsafe_paths:
        git_apply_command.append("--unsafe-paths")
    git_apply_command += ["--directory", root_directory, patch_file]
    # check=True: a failed patch application was previously ignored
    # silently; surface the failure to the caller instead.
    subprocess.run(git_apply_command, check=True)
244