#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
import re
import sys
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count

import requests

from download_util import (
    check_sha256,
    check_sha256_by_mark,
    extract_compress_files_and_gen_mark,
    get_local_path,
    import_rich_module,
    run_cmd,
)


class PoolDownloader:
    """Download and extract a batch of prebuilt archives in a thread pool.

    Each entry in ``download_configs`` is a dict describing one archive
    (keys read here: ``remote_url``, ``download_dir``, ``unzip_dir``,
    ``unzip_filename``, ``name``).  Archives whose extracted contents
    already match the remote sha256 mark are skipped and reported via the
    return value of :meth:`start`.
    """

    def __init__(self, download_configs: list, global_args: object = None):
        """Initialize the downloader.

        :param download_configs: list of per-archive config dicts.
        :param global_args: parsed CLI arguments; expected attributes are
            ``disable_rich``, ``glibc_version`` and ``tool_repo``.  May be
            ``None`` (rich progress output is then disabled).
        """
        # Bug fix: the declared default is None, but the original code
        # dereferenced global_args.disable_rich unconditionally and would
        # raise AttributeError.  Treat None as "rich disabled".
        if global_args is not None and not global_args.disable_rich:
            self.progress = import_rich_module()
        else:
            self.progress = None
        self.global_args = global_args
        self.download_configs = download_configs
        # Protects unchanged_tool_list, which is appended to from worker threads.
        self.lock = threading.Lock()
        self.unchanged_tool_list = []

    def start(self) -> list:
        """Run all downloads; return names of tools that were already up to date."""
        if self.progress:
            # Entering the progress context renders the live progress display
            # for the duration of the downloads.
            with self.progress:
                self._run_download_in_thread_pool()
        else:
            self._run_download_in_thread_pool()
        return self.unchanged_tool_list

    def _run_download_in_thread_pool(self):
        """Submit one task per config item and wait for all to finish."""
        try:
            cnt = cpu_count()
        except Exception:
            # cpu_count() can raise on exotic platforms; degrade to serial.
            cnt = 1
        with ThreadPoolExecutor(max_workers=cnt) as pool:
            tasks = dict()
            for config_item in self.download_configs:
                task = pool.submit(self._process, config_item)
                tasks[task] = os.path.basename(config_item.get("remote_url"))
            self._wait_for_download_tasks_complete(tasks)

    def _wait_for_download_tasks_complete(self, tasks: dict):
        """Collect task results, logging success or the raised exception.

        :param tasks: mapping of Future -> archive basename (for messages).
        """
        for task in as_completed(tasks):
            try:
                _ = task.result()
            except Exception as e:
                # Note: SystemExit (raised by _try_download on permanent
                # failure) is NOT caught here and propagates to the caller.
                self._adaptive_print(f"Task {task} generated an exception: {e}", style="red")
                self._adaptive_print(traceback.format_exc())
            else:
                self._adaptive_print(
                    "{}, download and decompress completed".format(tasks.get(task)),
                    style="green",
                )

    def _adaptive_print(self, msg: str, **kwargs):
        """Log through the rich console when available, else plain print."""
        if self.progress:
            self.progress.console.log(msg, **kwargs)
        else:
            print(msg)

    def _process(self, operate: dict):
        """Ensure one archive is downloaded, verified and extracted.

        Runs on a worker thread.  If the extracted output already matches
        the remote sha256 mark file, the tool is recorded as unchanged;
        otherwise stale output is removed, the archive is (re)downloaded as
        needed and extracted.

        :param operate: one entry of ``download_configs``.
        """
        global_args = self.global_args
        remote_url = operate.get("remote_url")
        # Python archives are published per glibc version; substitute the
        # requested one into the URL (pattern like "GLIBC2.27").
        if "python" in remote_url and global_args.glibc_version is not None:
            remote_url = re.sub(r"GLIBC[0-9]\.[0-9]{2}", global_args.glibc_version, remote_url)
        remote_url = global_args.tool_repo + remote_url

        download_root = operate.get("download_dir")
        unzip_dir = operate.get("unzip_dir")
        unzip_filename = operate.get("unzip_filename")
        local_path = get_local_path(download_root, remote_url)
        self._adaptive_print(f"start deal {remote_url}")
        mark_file_exist, mark_file_path = check_sha256_by_mark(remote_url, unzip_dir, unzip_filename)
        # Check whether the extracted files already match the remote archive.
        if mark_file_exist:
            self._adaptive_print(
                "{}, Sha256 markword check OK.".format(remote_url), style="green"
            )
            with self.lock:
                self.unchanged_tool_list.append(operate.get("name") + "_" + os.path.basename(remote_url))
        else:
            # Mismatch: remove the stale extracted output first.
            # Bug fix: only invoke rm when the glob matched something;
            # "rm -rf" with no operands fails with "missing operand".
            stale_marks = glob.glob(f"{unzip_dir}/*.{unzip_filename}.mark", recursive=False)
            if stale_marks:
                run_cmd(["rm", "-rf"] + stale_marks)
            run_cmd(["rm", "-rf", '{}/{}'.format(unzip_dir, unzip_filename)])
            # Verify the archive itself.
            if os.path.exists(local_path):
                check_result = check_sha256(remote_url, local_path)
                if check_result:
                    self._adaptive_print(
                        "{}, Sha256 check download OK.".format(local_path),
                        style="green",
                    )
                else:
                    # Corrupt archive: delete and re-download.
                    os.remove(local_path)
                    self._try_download(remote_url, local_path)
            else:
                # Archive not present locally: download it.
                self._try_download(remote_url, local_path)

            # Extract the archive and write the sha256 mark file.
            self._adaptive_print("Start decompression {}".format(local_path))
            extract_compress_files_and_gen_mark(local_path, unzip_dir, mark_file_path)
            self._adaptive_print(f"{local_path} extracted to {unzip_dir}")

    def _try_download(self, remote_url: str, local_path: str):
        """Download ``remote_url`` to ``local_path`` with up to 3 attempts.

        Raises SystemExit when every attempt fails.  NOTE(review): this runs
        on a worker thread, so SystemExit does not terminate the process
        directly — it propagates through Future.result() past the
        ``except Exception`` handler in _wait_for_download_tasks_complete.
        """
        max_retry_times = 3
        # Ensure the download directory exists.
        download_dir = os.path.dirname(local_path)
        os.makedirs(download_dir, exist_ok=True)

        # Create a (not yet started) progress bar task when rich is enabled.
        progress = self.progress
        progress_task_id = progress.add_task(
            "download", filename=os.path.basename(remote_url), start=False
        ) if progress else None
        self._adaptive_print(f"Downloading {remote_url}")
        for retry_times in range(max_retry_times):
            try:
                self._download_remote_file(remote_url, local_path, progress_task_id)
                return
            except Exception as e:
                # urllib-style errors carry .code; fall back to str(e).
                error_message = getattr(e, 'code', str(e))
                self._adaptive_print(
                    f"Failed to open {remote_url}, Error: {error_message}",
                    style="red"
                )

        # Retry limit reached: report and abort this task.
        self._adaptive_print(
            f"{local_path}, download failed after {max_retry_times} retries, "
            "please check network status. Prebuilts download exit."
        )
        sys.exit(1)

    def _download_remote_file(self, remote_url: str, local_path: str, progress_task_id):
        """Stream one remote file to disk, updating the progress bar.

        :param progress_task_id: rich task id, or None when rich is disabled.
        :raises requests.HTTPError: on a non-2xx HTTP status.
        """
        buffer_size = 32768
        progress = self.progress
        # Stream with requests; (connect, read) timeouts of 30s / 300s.
        with requests.get(remote_url, stream=True, timeout=(30, 300)) as response:
            response.raise_for_status()  # raise on HTTP error status

            # Content-Length may be absent; rich treats total=0 as unknown.
            total_size = int(response.headers.get("Content-Length", 0))
            if progress:
                progress.update(progress_task_id, total=total_size)
                progress.start_task(progress_task_id)
            self._save_to_local(response, local_path, buffer_size, progress_task_id)
        self._adaptive_print(f"Downloaded {local_path}")

    def _save_to_local(self, response: requests.Response, local_path: str, buffer_size: int, progress_task_id):
        """Write the response body to ``local_path`` in chunks (mode 0o640)."""
        with os.fdopen(os.open(local_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode=0o640), 'wb') as dest_file:
            for chunk in response.iter_content(chunk_size=buffer_size):
                if chunk:  # skip keep-alive chunks
                    dest_file.write(chunk)
                    self._update_progress(progress_task_id, len(chunk))

    def _update_progress(self, task_id, advance):
        """Advance the progress bar by ``advance`` bytes when rich is enabled."""
        if self.progress:
            self.progress.update(task_id, advance=advance)