# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Get the requirement files by platform."""

load(":whl_target_platforms.bzl", "whl_target_platforms")

# TODO @aignas 2024-05-13: consider using the same platform tags as are used in
# the //python:versions.bzl
DEFAULT_PLATFORMS = [
    "linux_aarch64",
    "linux_arm",
    "linux_ppc",
    "linux_s390x",
    "linux_x86_64",
    "osx_aarch64",
    "osx_x86_64",
    "windows_x86_64",
]

def _default_platforms(*, filter):
    """Select the entries of DEFAULT_PLATFORMS that match a filter string.

    Args:
        filter: str, the platform filter. It may only contain alphanumerics,
            '_' and '*', and '*' is only allowed at the very end. A leading
            'cp3<minor>_' part is dropped before matching.

    Returns:
        A list of platform strings from DEFAULT_PLATFORMS. With a trailing
        '*' the remainder of the filter is matched as a prefix (a bare '*'
        returns all DEFAULT_PLATFORMS); without '*' the filter is matched as
        a substring. The list may be empty if nothing matches.
    """
    if not filter:
        fail("Must specific a filter string, got: {}".format(filter))

    if filter.startswith("cp3"):
        # TODO @aignas 2024-05-23: properly handle python versions in the filter.
        # For now we are just dropping it to ensure that we don't fail.
        _, _, filter = filter.partition("_")

    sanitized = filter.replace("*", "").replace("_", "")
    if sanitized and not sanitized.isalnum():
        fail("The platform filter can only contain '*', '_' and alphanumerics")

    if "*" in filter:
        prefix = filter.rstrip("*")
        if "*" in prefix:
            fail("The filter can only contain '*' at the end of it")

        if not prefix:
            return DEFAULT_PLATFORMS

        return [p for p in DEFAULT_PLATFORMS if p.startswith(prefix)]
    else:
        return [p for p in DEFAULT_PLATFORMS if filter in p]

def _platforms_from_args(extra_pip_args):
    """Extract the target platforms requested via '--platform' pip arguments.

    The following spellings are recognized: '--platform=<value>',
    '--platform <value>' passed as a single argument, and '--platform'
    followed by the value in the next argument.

    Args:
        extra_pip_args: list of str or None, the extra pip arguments.

    Returns:
        A de-duplicated list of the `target_platform` values that
        `whl_target_platforms` yields for each '--platform' value, in first
        seen order. Empty if no '--platform' argument was passed.
    """
    platform_values = []

    if not extra_pip_args:
        return platform_values

    for arg in extra_pip_args:
        # An empty string at the end of the list is a placeholder meaning that
        # the previous argument was a flag still waiting for its value.
        if platform_values and platform_values[-1] == "":
            platform_values[-1] = arg
            continue

        if arg == "--platform":
            platform_values.append("")
            continue

        if not arg.startswith("--platform"):
            continue

        _, _, plat = arg.partition("=")
        if not plat:
            _, _, plat = arg.partition(" ")
        if plat:
            platform_values.append(plat)
        else:
            platform_values.append("")

    if not platform_values:
        return []

    # Use a dict as an ordered set to drop duplicates.
    platforms = {
        p.target_platform: None
        for arg in platform_values
        for p in whl_target_platforms(arg)
    }
    return list(platforms.keys())

def _platform(platform_string, python_version = None):
    """Prefix a platform string with the 'cp3<minor>_' python version tag.

    Args:
        platform_string: str, the platform, e.g. "linux_x86_64".
        python_version: str or None, a version of the form "3.x" or "3.x.x".
            Only the minor component is used; the "cp3" part of the result is
            hard-coded.

    Returns:
        The `platform_string` unchanged if no `python_version` is given or if
        it already starts with "cp3", otherwise "cp3<minor>_<platform_string>".
    """
    if not python_version or platform_string.startswith("cp3"):
        return platform_string

    _, _, tail = python_version.partition(".")
    minor, _, _ = tail.partition(".")

    return "cp3{}_{}".format(minor, platform_string)

def requirements_files_by_platform(
        *,
        requirements_by_platform = {},
        requirements_osx = None,
        requirements_linux = None,
        requirements_lock = None,
        requirements_windows = None,
        extra_pip_args = None,
        python_version = None,
        logger = None,
        fail_fn = fail):
    """Resolve the requirement files by target platform.

    Args:
        requirements_by_platform (label_keyed_string_dict): a way to have
            different package versions (or different packages) for different
            os, arch combinations. The values are comma separated platforms
            or platform filters ending with '*'.
        requirements_osx (label): The requirements file for the osx OS.
        requirements_linux (label): The requirements file for the linux OS.
        requirements_lock (label): The requirements file for all OSes, or used as a fallback.
        requirements_windows (label): The requirements file for windows OS.
        extra_pip_args (string list): Extra pip arguments to perform extra validations and to
            be joined with args defined in files.
        python_version: str or None. This is needed when the get_index_urls is
            specified. It should be of the form "3.x.x".
        logger: repo_utils.logger or None, a simple struct to log diagnostic messages.
        fail_fn (Callable[[str], None]): A failure function used in testing failure cases.

    Returns:
        A dict with keys as the labels to the files and values as lists of
        platforms that the files support. None is returned if `fail_fn` was
        called and it did not abort the evaluation.
    """
    if not (
        requirements_lock or
        requirements_linux or
        requirements_osx or
        requirements_windows or
        requirements_by_platform
    ):
        fail_fn(
            "A 'requirements_lock' attribute must be specified, a platform-specific lockfiles " +
            "via 'requirements_by_platform' or an os-specific lockfiles must be specified " +
            "via 'requirements_*' attributes",
        )
        return None

    platforms = _platforms_from_args(extra_pip_args)
    if logger:
        logger.debug(lambda: "Platforms from pip args: {}".format(platforms))

    if platforms:
        lock_files = [
            f
            for f in [
                requirements_lock,
                requirements_linux,
                requirements_osx,
                requirements_windows,
            ] + list(requirements_by_platform.keys())
            if f
        ]

        if len(lock_files) > 1:
            # If the --platform argument is used, check that we are using
            # a single `requirements_lock` file instead of the OS specific ones as that is
            # the only correct way to use the API.
            fail_fn("only a single 'requirements_lock' file can be used when using '--platform' pip argument, consider specifying it via 'requirements_lock' attribute")
            return None

        files_by_platform = [
            (lock_files[0], platforms),
        ]
        if logger:
            logger.debug(lambda: "Files by platform with the platform set in the args: {}".format(files_by_platform))
    else:
        # NOTE: in Starlark `dict.items()` returns a new list, hence the
        # `append` calls below are valid.
        files_by_platform = {
            file: [
                platform
                for filter_or_platform in specifier.split(",")
                for platform in (_default_platforms(filter = filter_or_platform) if filter_or_platform.endswith("*") else [filter_or_platform])
            ]
            for file, specifier in requirements_by_platform.items()
        }.items()

        if logger:
            logger.debug(lambda: "Files by platform with the platform set in the attrs: {}".format(files_by_platform))

        for f in [
            # If the users need a greater span of the platforms, they should consider
            # using the 'requirements_by_platform' attribute.
            (requirements_linux, _default_platforms(filter = "linux_*")),
            (requirements_osx, _default_platforms(filter = "osx_*")),
            (requirements_windows, _default_platforms(filter = "windows_*")),
            # `None` marks the file as the fallback for all of the platforms
            # that have not been claimed by any other file.
            (requirements_lock, None),
        ]:
            if f[0]:
                if logger:
                    logger.debug(lambda: "Adding an extra item to files_by_platform: {}".format(f))
                files_by_platform.append(f)

    configured_platforms = {}
    requirements = {}
    for file, plats in files_by_platform:
        if plats == None:
            # The fallback file: it gets all of the default platforms that
            # have not been configured so far.
            default_platforms = [_platform(p, python_version) for p in DEFAULT_PLATFORMS]
            plats = [
                p
                for p in default_platforms
                if p not in configured_platforms
            ]
            if logger:
                logger.debug(lambda: "File {} will be used for the remaining platforms {} that are not in configured_platforms: {}".format(
                    file,
                    plats,
                    default_platforms,
                ))
            for p in plats:
                configured_platforms[p] = file
        elif not plats:
            # The platform filters of this file matched nothing. An empty
            # list must not be confused with the `None` fallback marker,
            # otherwise this file would silently claim all remaining platforms.
            if logger:
                logger.debug(lambda: "File {} will be ignored because its platform filters did not match any platform".format(file))
            continue
        else:
            resolved = []
            for raw in plats:
                p = _platform(raw, python_version)
                if p in configured_platforms:
                    if configured_platforms[p] == file:
                        # Overlapping filters (e.g. "linux_*,linux_x86_64")
                        # that point to the same file are not a conflict.
                        continue

                    fail_fn(
                        "Expected the platform '{}' to be map only to a single requirements file, but got multiple: '{}', '{}'".format(
                            p,
                            configured_platforms[p],
                            file,
                        ),
                    )
                    return None

                configured_platforms[p] = file
                resolved.append(p)
            plats = resolved

        if logger:
            logger.debug(lambda: "Configured platforms for file {} are {}".format(file, plats))

        for p in plats:
            if p in requirements:
                # This should never happen because in the code above we should
                # have unambiguous selection of the requirements files.
                fail_fn("Attempting to override a requirements file '{}' with '{}' for platform '{}'".format(
                    requirements[p],
                    file,
                    p,
                ))
                return None
            requirements[p] = file

    # Now return a dict that is similar to requirements_by_platform - where we
    # have labels/files as keys in the dict to minimize the number of times we
    # may parse the same file.

    ret = {}
    for plat, file in requirements.items():
        ret.setdefault(file, []).append(plat)

    return ret