#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""File Helper Functions."""

import glob
import hashlib
import json
import logging
import os
import shutil
import subprocess
import sys
import tarfile
import urllib.request
import zipfile
from pathlib import Path
from typing import List

_LOG = logging.getLogger(__name__)


class InvalidChecksumError(Exception):
    """Raised when a file's checksum does not match the expected value."""


def find_files(
    starting_dir: str, patterns: List[str], directories_only=False
) -> List[str]:
    """Return sorted paths under starting_dir that match any glob pattern.

    Args:
        starting_dir: Directory to search; must exist.
        patterns: Glob patterns evaluated relative to starting_dir
            (``recursive=True``, so ``**`` is honored).
        directories_only: If True, keep only paths that are directories.

    Returns:
        Sorted list of matching paths, relative to starting_dir.

    Raises:
        FileNotFoundError: If starting_dir is not an existing directory.
    """
    # isdir() already implies existence, so a single check suffices.
    if not os.path.isdir(starting_dir):
        raise FileNotFoundError(
            "Directory '{}' does not exist.".format(starting_dir)
        )

    original_working_dir = os.getcwd()
    os.chdir(starting_dir)
    try:
        files = []
        for pattern in patterns:
            for file_path in glob.glob(pattern, recursive=True):
                if not directories_only or os.path.isdir(file_path):
                    files.append(file_path)
    finally:
        # Restore the working directory even if globbing raises; otherwise
        # the caller's process is left chdir'd into starting_dir.
        os.chdir(original_working_dir)
    return sorted(files)


def _hash_file(file_name, hash_obj):
    """Feed the file's contents into hash_obj in 4 KiB chunks; return hex digest."""
    with open(file_name, "rb") as file_handle:
        for chunk in iter(lambda: file_handle.read(4096), b""):
            hash_obj.update(chunk)
    return hash_obj.hexdigest()


def sha256_sum(file_name):
    """Return the hex SHA-256 digest of the given file."""
    return _hash_file(file_name, hashlib.sha256())


def md5_sum(file_name):
    """Return the hex MD5 digest of the given file."""
    return _hash_file(file_name, hashlib.md5())


def verify_file_checksum(file_path, expected_checksum, sum_function=sha256_sum):
    """Check file_path against an expected checksum.

    Args:
        file_path: File to check.
        expected_checksum: Expected hex digest string.
        sum_function: Callable mapping a path to a hex digest
            (``sha256_sum`` or ``md5_sum``).

    Returns:
        True if the checksum matches.

    Raises:
        InvalidChecksumError: If the computed checksum differs.
    """
    downloaded_checksum = sum_function(file_path)
    if downloaded_checksum != expected_checksum:
        raise InvalidChecksumError(
            f"Invalid {sum_function.__name__}\n"
            f"{downloaded_checksum} {os.path.basename(file_path)}\n"
            f"{expected_checksum} (expected)\n\n"
            "Please delete this file and try again:\n"
            f"{file_path}"
        )

    _LOG.debug("  %s:", sum_function.__name__)
    _LOG.debug("  %s %s", downloaded_checksum, os.path.basename(file_path))
    return True


def relative_or_absolute_path(file_string: str):
    """Return a Path relative to os.getcwd(), else an absolute path."""
    file_path = Path(file_string)
    try:
        return file_path.relative_to(os.getcwd())
    except ValueError:
        return file_path.resolve()


def download_to_cache(
    url: str,
    expected_md5sum=None,
    expected_sha256sum=None,
    cache_directory=".cache",
    downloaded_file_name=None,
) -> str:
    """Download a URL into a cache directory, verifying its checksum.

    The download is skipped if the target file already exists in the cache.
    If both checksums are given, only the SHA-256 sum is verified.

    Args:
        url: URL to fetch.
        expected_md5sum: Optional expected MD5 hex digest.
        expected_sha256sum: Optional expected SHA-256 hex digest.
        cache_directory: Directory for downloaded files; created if missing.
            ``~`` and environment variables are expanded.
        downloaded_file_name: Name for the cached file; defaults to the last
            path component of the URL.

    Returns:
        Absolute path to the downloaded (or previously cached) file.

    Raises:
        InvalidChecksumError: If a provided checksum does not match.
    """
    cache_dir = os.path.realpath(
        os.path.expanduser(os.path.expandvars(cache_directory))
    )
    # urlretrieve cannot create intermediate directories itself.
    os.makedirs(cache_dir, exist_ok=True)
    if not downloaded_file_name:
        # Use the last part of the URL as the file name.
        downloaded_file_name = url.split("/")[-1]
    downloaded_file = os.path.join(cache_dir, downloaded_file_name)

    if not os.path.exists(downloaded_file):
        _LOG.info("Downloading: %s", url)
        _LOG.info("Please wait...")
        urllib.request.urlretrieve(url, filename=downloaded_file)

    if os.path.exists(downloaded_file):
        _LOG.info("Downloaded: %s", relative_or_absolute_path(downloaded_file))
        if expected_sha256sum:
            verify_file_checksum(
                downloaded_file, expected_sha256sum, sum_function=sha256_sum
            )
        elif expected_md5sum:
            verify_file_checksum(
                downloaded_file, expected_md5sum, sum_function=md5_sum
            )

    return downloaded_file


def extract_zipfile(archive_file: str, dest_dir: str):
    """Extract a zipfile preserving permissions."""
    destination_path = Path(dest_dir)
    with zipfile.ZipFile(archive_file) as archive:
        for info in archive.infolist():
            archive.extract(info.filename, path=dest_dir)
            # Unix mode bits live in the high 16 bits of external_attr.
            permissions = info.external_attr >> 16
            # Archives created on non-Unix hosts may carry no mode bits;
            # chmod(0) would make the file unreadable, so skip those.
            if permissions:
                out_path = destination_path / info.filename
                out_path.chmod(permissions)


def extract_tarfile(archive_file: str, dest_dir: str):
    """Extract a tar archive, refusing members that escape dest_dir.

    Guards against path-traversal entries (e.g. ``../../etc/passwd``,
    CVE-2007-4559) before calling extractall.

    Raises:
        RuntimeError: If any archive member would land outside dest_dir.
    """
    dest_root = os.path.realpath(dest_dir)
    with tarfile.open(archive_file, 'r') as archive:
        for member in archive.getmembers():
            member_dest = os.path.realpath(
                os.path.join(dest_root, member.name)
            )
            if os.path.commonpath([dest_root, member_dest]) != dest_root:
                raise RuntimeError(
                    "Refusing to extract '{}' outside of '{}'".format(
                        member.name, dest_root
                    )
                )
        archive.extractall(path=dest_root)


def extract_archive(
    archive_file: str,
    dest_dir: str,
    cache_dir: str,
    remove_single_toplevel_folder=True,
):
    """Extract a tar or zip file.

    Args:
        archive_file (str): Absolute path to the archive file.
        dest_dir (str): Extraction destination directory.
        cache_dir (str): Directory where temp files can be created.
        remove_single_toplevel_folder (bool): If the archive contains only a
            single folder move the contents of that into the destination
            directory.

    Returns:
        List of Paths for everything extracted into dest_dir.
    """
    # Make a temporary directory to extract files into
    temp_extract_dir = os.path.join(
        cache_dir, "." + os.path.basename(archive_file)
    )
    os.makedirs(temp_extract_dir, exist_ok=True)

    _LOG.info("Extracting: %s", relative_or_absolute_path(archive_file))
    if zipfile.is_zipfile(archive_file):
        extract_zipfile(archive_file, temp_extract_dir)
    elif tarfile.is_tarfile(archive_file):
        extract_tarfile(archive_file, temp_extract_dir)
    else:
        _LOG.error("Unknown archive format: %s", archive_file)
        sys.exit(1)

    _LOG.info("Installing into: %s", relative_or_absolute_path(dest_dir))
    path_to_extracted_files = temp_extract_dir

    extracted_top_level_files = os.listdir(temp_extract_dir)
    # Check if the archive has only one top-level folder.
    # If yes, make that the new path_to_extracted_files so its contents are
    # moved directly into dest_dir.
    if remove_single_toplevel_folder and len(extracted_top_level_files) == 1:
        path_to_extracted_files = os.path.join(
            temp_extract_dir, extracted_top_level_files[0]
        )

    # Move extracted files to dest_dir
    extracted_files = os.listdir(path_to_extracted_files)
    for file_name in extracted_files:
        source_file = os.path.join(path_to_extracted_files, file_name)
        dest_file = os.path.join(dest_dir, file_name)
        shutil.move(source_file, dest_file)

    # rm -rf temp_extract_dir
    shutil.rmtree(temp_extract_dir, ignore_errors=True)

    # Return List of extracted files
    return list(Path(dest_dir).rglob("*"))


def remove_empty_directories(directory):
    """Recursively remove empty directories (and broken symlinks)."""

    # reverse=True visits children before their parents, so directories that
    # become empty after their contents are removed are themselves removed.
    for path in sorted(Path(directory).rglob("*"), reverse=True):
        # If broken symlink
        if path.is_symlink() and not path.exists():
            path.unlink()
        # if empty directory
        elif path.is_dir() and len(os.listdir(path)) == 0:
            path.rmdir()


def decode_file_json(file_name):
    """Decode JSON values from a file.

    Does not raise an error if the file cannot be decoded.

    Args:
        file_name: Path to a JSON file; ``~`` and environment variables are
            expanded.

    Returns:
        Tuple of (decoded values or {} on failure, absolute file path).
    """

    # Get absolute path to the file.
    file_path = os.path.realpath(
        os.path.expanduser(os.path.expandvars(file_name))
    )

    json_file_options = {}
    try:
        with open(file_path, "r", encoding="utf-8") as jfile:
            json_file_options = json.load(jfile)
    except (FileNotFoundError, json.JSONDecodeError):
        _LOG.warning("Unable to read file '%s'", file_path)

    return json_file_options, file_path


def git_apply_patch(
    root_directory, patch_file, ignore_whitespace=True, unsafe_paths=False
):
    """Use `git apply` to apply a diff file.

    Args:
        root_directory: Passed to ``git apply --directory``.
        patch_file: Path to the diff to apply.
        ignore_whitespace: Pass ``--ignore-whitespace`` to git.
        unsafe_paths: Pass ``--unsafe-paths`` to git.

    Raises:
        subprocess.CalledProcessError: If ``git apply`` exits non-zero;
            previously a failed patch was silently ignored.
    """

    _LOG.info("Applying Patch: %s", patch_file)
    git_apply_command = ["git", "apply"]
    if ignore_whitespace:
        git_apply_command.append("--ignore-whitespace")
    if unsafe_paths:
        git_apply_command.append("--unsafe-paths")
    git_apply_command += ["--directory", root_directory, patch_file]
    # check=True surfaces patch failures instead of swallowing them.
    subprocess.run(git_apply_command, check=True)