• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""File Helper Functions."""
16
import glob
import hashlib
import json
import logging
import os
import shutil
import subprocess
import sys
import tarfile
import urllib.request
import zipfile
from pathlib import Path
from typing import List
30
# Module-level logger named after this module.
_LOG = logging.getLogger(__name__)
32
33
class InvalidChecksumError(Exception):
    """Raised when a file's checksum does not match the expected value."""

    pass
36
37
def find_files(
    starting_dir: str, patterns: List[str], directories_only: bool = False
) -> List[str]:
    """Return paths under starting_dir matching any of the glob patterns.

    Args:
        starting_dir: Directory to search from; must exist.
        patterns: Glob patterns; '**' is honored (recursive matching).
        directories_only: If True, only matching directories are returned.

    Returns:
        Sorted list of matching paths, relative to starting_dir.

    Raises:
        FileNotFoundError: If starting_dir does not exist or is not a
            directory.
    """
    # isdir() implies exists(); the original double check was redundant.
    if not os.path.isdir(starting_dir):
        raise FileNotFoundError(
            "Directory '{}' does not exist.".format(starting_dir)
        )

    original_working_dir = os.getcwd()
    os.chdir(starting_dir)
    try:
        files = []
        for pattern in patterns:
            for file_path in glob.glob(pattern, recursive=True):
                # Simplified from `not d or (d and isdir)`; the second
                # `directories_only` test was redundant.
                if not directories_only or os.path.isdir(file_path):
                    files.append(file_path)
    finally:
        # Fix: always restore the caller's working directory, even if
        # glob raises; previously an exception left the cwd changed.
        os.chdir(original_working_dir)
    return sorted(files)
57
58
def sha256_sum(file_name):
    """Return the hexadecimal SHA-256 digest of the named file."""
    digest = hashlib.sha256()
    # Read in fixed-size chunks so large files never sit fully in memory.
    with open(file_name, "rb") as source:
        while True:
            block = source.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
65
66
def md5_sum(file_name):
    """Return the hexadecimal MD5 digest of the named file."""
    digest = hashlib.md5()
    # Read in fixed-size chunks so large files never sit fully in memory.
    with open(file_name, "rb") as source:
        while True:
            block = source.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
73
74
def verify_file_checksum(file_path, expected_checksum, sum_function=sha256_sum):
    """Compare a file's checksum against an expected value.

    Returns True on a match; raises InvalidChecksumError otherwise.
    """
    actual_checksum = sum_function(file_path)
    if actual_checksum == expected_checksum:
        _LOG.debug("  %s:", sum_function.__name__)
        _LOG.debug("  %s %s", actual_checksum, os.path.basename(file_path))
        return True

    raise InvalidChecksumError(
        f"Invalid {sum_function.__name__}\n"
        f"{actual_checksum} {os.path.basename(file_path)}\n"
        f"{expected_checksum} (expected)\n\n"
        "Please delete this file and try again:\n"
        f"{file_path}"
    )
89
90
def relative_or_absolute_path(file_string: str):
    """Return a Path relative to os.getcwd(), else an absolute path."""
    candidate = Path(file_string)
    try:
        return candidate.relative_to(os.getcwd())
    except ValueError:
        # Not located under the current directory; fall back to an
        # absolute, fully-resolved path.
        return candidate.resolve()
98
99
def download_to_cache(
    url: str,
    expected_md5sum=None,
    expected_sha256sum=None,
    cache_directory=".cache",
    downloaded_file_name=None,
) -> str:
    """Download a URL into a cache directory, verifying optional checksums.

    The download is skipped if the target file already exists in the cache.

    Args:
        url: URL to fetch.
        expected_md5sum: Optional MD5 hex digest to verify against.
        expected_sha256sum: Optional SHA-256 hex digest to verify against;
            takes precedence over expected_md5sum if both are given.
        cache_directory: Directory to store the file in; '~' and
            environment variables are expanded, and the directory is
            created if missing.
        downloaded_file_name: Name for the stored file; defaults to the
            last component of the URL.

    Returns:
        Absolute path to the downloaded (or cached) file.

    Raises:
        InvalidChecksumError: If a provided checksum does not match.
    """
    cache_dir = os.path.realpath(
        os.path.expanduser(os.path.expandvars(cache_directory))
    )
    # Fix: urlretrieve does not create intermediate directories, so a
    # fresh cache directory previously caused the download to fail.
    os.makedirs(cache_dir, exist_ok=True)

    if not downloaded_file_name:
        # Use the last part of the URL as the file name.
        downloaded_file_name = url.split("/")[-1]
    downloaded_file = os.path.join(cache_dir, downloaded_file_name)

    if not os.path.exists(downloaded_file):
        _LOG.info("Downloading: %s", url)
        _LOG.info("Please wait...")
        urllib.request.urlretrieve(url, filename=downloaded_file)

    if os.path.exists(downloaded_file):
        _LOG.info("Downloaded: %s", relative_or_absolute_path(downloaded_file))
        if expected_sha256sum:
            verify_file_checksum(
                downloaded_file, expected_sha256sum, sum_function=sha256_sum
            )
        elif expected_md5sum:
            verify_file_checksum(
                downloaded_file, expected_md5sum, sum_function=md5_sum
            )

    return downloaded_file
134
135
def extract_zipfile(archive_file: str, dest_dir: str):
    """Extract a zipfile, preserving Unix permission bits.

    ZipFile.extract() discards the mode bits stored in each entry, so
    they are re-applied here after extraction.

    Args:
        archive_file: Path to the zip archive.
        dest_dir: Directory to extract into.
    """
    destination_path = Path(dest_dir)
    with zipfile.ZipFile(archive_file) as archive:
        for info in archive.infolist():
            archive.extract(info.filename, path=dest_dir)
            # Unix mode bits live in the high 16 bits of external_attr.
            permissions = info.external_attr >> 16
            # Fix: entries written without Unix attributes (e.g. archives
            # created on Windows) have external_attr == 0; chmod(0) made
            # the extracted file unreadable. Keep the default mode instead.
            if permissions:
                out_path = destination_path / info.filename
                out_path.chmod(permissions)
145
146
def extract_tarfile(archive_file: str, dest_dir: str):
    """Extract every member of a tar archive into dest_dir."""
    # NOTE(review): extractall() trusts member paths; only use on trusted
    # archives (newer Python versions offer extraction filters for this).
    archive = tarfile.open(archive_file, "r")
    try:
        archive.extractall(path=dest_dir)
    finally:
        archive.close()
150
151
def extract_archive(
    archive_file: str,
    dest_dir: str,
    cache_dir: str,
    remove_single_toplevel_folder=True,
):
    """Extract a tar or zip file.

    Args:
        archive_file (str): Absolute path to the archive file.
        dest_dir (str): Extraction destination directory.
        cache_dir (str): Directory where temp files can be created.
        remove_single_toplevel_folder (bool): If the archive contains only a
            single folder move the contents of that into the destination
            directory.

    Returns:
        List of Path objects for everything under dest_dir after the move.
    """
    # Make a temporary directory to extract files into
    # (a dot-prefixed sibling of the archive name inside cache_dir).
    temp_extract_dir = os.path.join(
        cache_dir, "." + os.path.basename(archive_file)
    )
    os.makedirs(temp_extract_dir, exist_ok=True)

    _LOG.info("Extracting: %s", relative_or_absolute_path(archive_file))
    if zipfile.is_zipfile(archive_file):
        extract_zipfile(archive_file, temp_extract_dir)
    elif tarfile.is_tarfile(archive_file):
        extract_tarfile(archive_file, temp_extract_dir)
    else:
        _LOG.error("Unknown archive format: %s", archive_file)
        # sys.exit() raises SystemExit, so this `return` never yields a
        # value; the process terminates here on unknown formats.
        return sys.exit(1)

    _LOG.info("Installing into: %s", relative_or_absolute_path(dest_dir))
    path_to_extracted_files = temp_extract_dir

    extracted_top_level_files = os.listdir(temp_extract_dir)
    # Check if tarfile has only one folder
    # If yes, make that the new path_to_extracted_files
    if remove_single_toplevel_folder and len(extracted_top_level_files) == 1:
        path_to_extracted_files = os.path.join(
            temp_extract_dir, extracted_top_level_files[0]
        )

    # Move extracted files to dest_dir
    # NOTE(review): shutil.move assumes dest_dir exists and that name
    # collisions in dest_dir are acceptable to the caller — confirm.
    extracted_files = os.listdir(path_to_extracted_files)
    for file_name in extracted_files:
        source_file = os.path.join(path_to_extracted_files, file_name)
        dest_file = os.path.join(dest_dir, file_name)
        shutil.move(source_file, dest_file)

    # rm -rf temp_extract_dir
    shutil.rmtree(temp_extract_dir, ignore_errors=True)

    # Return List of extracted files
    return list(Path(dest_dir).rglob("*"))
206
207
def remove_empty_directories(directory):
    """Recursively remove empty directories and broken symlinks."""
    # Reverse-sorted order visits children before their parents, so a
    # directory emptied on this pass is removed on the same pass.
    for entry in sorted(Path(directory).rglob("*"), reverse=True):
        if entry.is_symlink() and not entry.exists():
            # Broken symlink: delete the link itself.
            entry.unlink()
        elif entry.is_dir() and not os.listdir(entry):
            entry.rmdir()
218
219
def decode_file_json(file_name):
    """Decode JSON values from a file.

    Does not raise an error if the file cannot be decoded."""
    # Expand env vars and '~', then resolve to an absolute real path.
    expanded_name = os.path.expandvars(file_name)
    file_path = os.path.realpath(os.path.expanduser(expanded_name))

    json_file_options = {}
    try:
        with open(file_path, "r") as json_file:
            json_file_options = json.load(json_file)
    except (FileNotFoundError, json.JSONDecodeError):
        _LOG.warning("Unable to read file '%s'", file_path)

    return json_file_options, file_path
238
239
def git_apply_patch(
    root_directory, patch_file, ignore_whitespace=True, unsafe_paths=False
):
    """Use `git apply` to apply a diff file.

    Args:
        root_directory: Directory prepended to all patched paths
            (passed via `git apply --directory`).
        patch_file: Path to the patch/diff file to apply.
        ignore_whitespace: Pass --ignore-whitespace so whitespace drift in
            context lines does not reject the patch.
        unsafe_paths: Pass --unsafe-paths to allow patches that touch
            files outside the working tree.

    Returns:
        The subprocess.CompletedProcess for the `git apply` invocation.

    Raises:
        subprocess.CalledProcessError: If `git apply` exits non-zero.
    """
    _LOG.info("Applying Patch: %s", patch_file)
    git_apply_command = ["git", "apply"]
    if ignore_whitespace:
        git_apply_command.append("--ignore-whitespace")
    if unsafe_paths:
        git_apply_command.append("--unsafe-paths")
    git_apply_command += ["--directory", root_directory, patch_file]
    # Fix: check=True makes a failed patch raise instead of being
    # silently ignored; the result is returned for callers that inspect it.
    return subprocess.run(git_apply_command, check=True)
253