• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16from download_util import (
17    check_sha256,
18    check_sha256_by_mark,
19    extract_compress_files_and_gen_mark,
20    get_local_path,
21    run_cmd,
22    import_rich_module,
23)
24import os
25import sys
26import re
27import glob
28import traceback
29import threading
30from concurrent.futures import ThreadPoolExecutor, as_completed
31from multiprocessing import cpu_count
32import requests
33
34
class PoolDownloader:
    """Download, verify and extract prebuilt tool archives concurrently.

    Each entry of ``download_configs`` is a dict describing one archive; the
    keys read by this class are ``remote_url``, ``download_dir``,
    ``unzip_dir``, ``unzip_filename`` and ``name``. Entries are processed in a
    thread pool. Tools whose extracted output already passes the sha256 mark
    check are skipped and reported by :meth:`start`.
    """

    def __init__(self, download_configs: list, global_args: object = None):
        """Store the configuration and optionally create a progress display.

        Args:
            download_configs: list of per-archive config dicts (see class doc).
            global_args: parsed options object. This class reads its
                ``disable_rich``, ``glibc_version`` and ``tool_repo``
                attributes. A missing ``disable_rich`` (including
                ``global_args=None``) is treated as ``False``. Note that
                ``_process`` still reads ``glibc_version`` and ``tool_repo``
                directly, so a real options object is needed to download.
        """
        # The signature allows ``global_args=None``; reading ``.disable_rich``
        # directly would raise AttributeError in that case.
        if not getattr(global_args, "disable_rich", False):
            # NOTE(review): presumably a rich ``Progress`` instance -- it is
            # used as a context manager and via add_task/update/console.log.
            self.progress = import_rich_module()
        else:
            self.progress = None
        self.global_args = global_args
        self.download_configs = download_configs
        # Guards ``unchanged_tool_list``, which worker threads append to.
        self.lock = threading.Lock()
        self.unchanged_tool_list = []

    def start(self) -> list:
        """Process every configured archive and wait for all tasks to finish.

        Returns:
            ``"<name>_<archive basename>"`` entries for tools whose extracted
            output already passed the sha256 mark check and were therefore
            neither downloaded nor extracted again.
        """
        if self.progress:
            # Run the pool inside the progress context so its display is
            # active while the workers report progress.
            with self.progress:
                self._run_download_in_thread_pool()
        else:
            self._run_download_in_thread_pool()
        return self.unchanged_tool_list

    def _run_download_in_thread_pool(self):
        """Submit one ``_process`` task per config and wait for all of them."""
        try:
            worker_count = cpu_count()
        except NotImplementedError:
            # multiprocessing.cpu_count() raises this when the count is unknown.
            worker_count = 1
        with ThreadPoolExecutor(max_workers=worker_count) as pool:
            # Map each future to its archive basename for log messages.
            tasks = {}
            for config_item in self.download_configs:
                task = pool.submit(self._process, config_item)
                tasks[task] = os.path.basename(config_item.get("remote_url"))
            self._wait_for_download_tasks_complete(tasks)

    def _wait_for_download_tasks_complete(self, tasks: dict):
        """Log the outcome of each task in completion order.

        Args:
            tasks: mapping of future -> archive basename used in messages.

        A task that raised an ``Exception`` is logged together with its
        traceback, and the loop continues. Exceptions that are not
        ``Exception`` subclasses (such as the ``SystemExit`` raised by
        ``_try_download``) are re-raised by ``result()`` and propagate out of
        this loop.
        """
        for task in as_completed(tasks):
            tool_name = tasks.get(task)
            try:
                task.result()
            except Exception as e:
                # Name the archive instead of printing the Future's repr.
                self._adaptive_print(f"Task {tool_name} generated an exception: {e}", style="red")
                self._adaptive_print(traceback.format_exc())
            else:
                self._adaptive_print(
                    "{}, download and decompress completed".format(tool_name),
                    style="green",
                )

    def _adaptive_print(self, msg: str, **kwargs):
        """Log through the progress console when present, else ``print``.

        Extra keyword arguments (e.g. ``style``) are forwarded only to the
        progress console; plain ``print`` ignores them.
        """
        if self.progress:
            self.progress.console.log(msg, **kwargs)
        else:
            print(msg)

    def _process(self, operate: dict):
        """Ensure one archive is downloaded, verified and extracted.

        If ``check_sha256_by_mark`` reports an existing matching mark, the tool
        is recorded in ``unchanged_tool_list`` and nothing else is done.
        Otherwise, previous mark files and extraction output are removed. The
        archive is then re-downloaded unless an existing local copy passes
        ``check_sha256``, and finally it is extracted into ``unzip_dir`` while a
        new mark file is generated.

        Args:
            operate: one config dict (see class doc for the keys read).
        """
        global_args = self.global_args
        remote_url = operate.get("remote_url")
        if "python" in remote_url and global_args.glibc_version is not None:
            # Replace a "GLIBC<d>.<dd>" tag in python archive URLs with the
            # requested glibc version string.
            remote_url = re.sub(r"GLIBC[0-9]\.[0-9]{2}", global_args.glibc_version, remote_url)
        remote_url = global_args.tool_repo + remote_url

        download_root = operate.get("download_dir")
        unzip_dir = operate.get("unzip_dir")
        unzip_filename = operate.get("unzip_filename")
        local_path = get_local_path(download_root, remote_url)
        self._adaptive_print(f"start deal {remote_url}")
        mark_file_exist, mark_file_path = check_sha256_by_mark(remote_url, unzip_dir, unzip_filename)
        # Check whether the extracted files match the remote archive.
        if mark_file_exist:
            self._adaptive_print(
                "{}, Sha256 markword check OK.".format(remote_url), style="green"
            )
            with self.lock:
                self.unchanged_tool_list.append(operate.get("name") + "_" + os.path.basename(remote_url))
            return

        # Out of date: delete previous outputs first. The path parts are
        # escaped so glob metacharacters in them are matched literally, and
        # ``rm`` is only spawned when there is something to delete.
        stale_marks = glob.glob(
            "{}/*.{}.mark".format(glob.escape(unzip_dir), glob.escape(unzip_filename))
        )
        if stale_marks:
            run_cmd(["rm", "-rf"] + stale_marks)
        run_cmd(["rm", "-rf", "{}/{}".format(unzip_dir, unzip_filename)])

        # Verify an existing local archive; download when missing or mismatched.
        if os.path.exists(local_path):
            if check_sha256(remote_url, local_path):
                self._adaptive_print(
                    "{}, Sha256 check download OK.".format(local_path),
                    style="green",
                )
            else:
                # Local archive does not match: delete it and download again.
                os.remove(local_path)
                self._try_download(remote_url, local_path)
        else:
            self._try_download(remote_url, local_path)

        self._adaptive_print("Start decompression {}".format(local_path))
        extract_compress_files_and_gen_mark(local_path, unzip_dir, mark_file_path)
        self._adaptive_print(f"{local_path} extracted to {unzip_dir}")

    def _try_download(self, remote_url: str, local_path: str):
        """Download ``remote_url`` to ``local_path``, retrying on any error.

        Up to ``max_retry_times`` (3) attempts are made. If every attempt
        fails, a message is logged and ``sys.exit(1)`` is called. Because this
        runs in a pool worker, the resulting ``SystemExit`` is delivered to the
        main thread through the task's future.
        """
        max_retry_times = 3
        # Create the download directory (skip when the path has no directory
        # part, since os.makedirs("") would raise).
        download_dir = os.path.dirname(local_path)
        if download_dir:
            os.makedirs(download_dir, exist_ok=True)

        # Create a (not yet started) progress bar for this archive.
        progress = self.progress
        progress_task_id = progress.add_task(
            "download", filename=os.path.basename(remote_url), start=False
        ) if progress else None
        self._adaptive_print(f"Downloading {remote_url}")
        for _ in range(max_retry_times):
            try:
                self._download_remote_file(remote_url, local_path, progress_task_id)
                return
            except Exception as e:
                error_message = getattr(e, 'code', str(e))
                self._adaptive_print(
                    f"Failed to open {remote_url}, Error: {error_message}",
                    style="red"
                )

        # All attempts failed.
        self._adaptive_print(
            f"{local_path}, download failed after {max_retry_times} retries, "
            "please check network status. Prebuilts download exit.",
            style="red",
        )
        sys.exit(1)

    def _download_remote_file(self, remote_url: str, local_path: str, progress_task_id):
        """Stream ``remote_url`` into ``local_path`` with one HTTP GET.

        Raises:
            Whatever ``requests`` raises on connection or timeout failures, and
            ``requests.HTTPError`` for 4xx/5xx responses (``raise_for_status``).
        """
        buffer_size = 32768
        progress = self.progress
        # timeout=(connect timeout, read timeout) in seconds.
        with requests.get(remote_url, stream=True, timeout=(30, 300)) as response:
            response.raise_for_status()

            total_size = int(response.headers.get("Content-Length", 0))
            if progress:
                # Reset ``completed`` so a retry after a partial download does
                # not continue from the previous attempt's byte count.
                progress.update(progress_task_id, total=total_size, completed=0)
                progress.start_task(progress_task_id)
            self._save_to_local(response, local_path, buffer_size, progress_task_id)
        self._adaptive_print(f"Downloaded {local_path}")

    def _save_to_local(self, response: "requests.Response", local_path: str, buffer_size: int, progress_task_id):
        """Write the streamed response body to ``local_path`` in chunks.

        The file is opened write-only with O_CREAT | O_TRUNC. A newly created
        file gets mode 0o640 (subject to the umask). Progress advances by each
        written chunk's size.
        """
        fd = os.open(local_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode=0o640)
        with os.fdopen(fd, 'wb') as dest_file:
            for chunk in response.iter_content(chunk_size=buffer_size):
                if chunk:  # skip empty keep-alive chunks
                    dest_file.write(chunk)
                    self._update_progress(progress_task_id, len(chunk))

    def _update_progress(self, task_id, advance):
        """Advance the progress bar ``task_id`` by ``advance`` bytes, if any."""
        if self.progress:
            self.progress.update(task_id, advance=advance)
185