• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Extension for declaring Pigweed Rust toolchains."""
15
16load("//pw_env_setup/bazel/cipd_setup:cipd_rules.bzl", "cipd_repository")
17load(":templates.bzl", "rust_analyzer_toolchain_template", "rust_toolchain_no_prebuilt_template", "rust_toolchain_template", "rustfmt_toolchain_template", "toolchain_template")
18load(":toolchains.bzl", "CHANNELS", "EXTRA_TARGETS", "HOSTS")
19
def _module_cipd_tag(module):
    """Returns the `cipd_tag` declared by `module`'s `toolchain` tags.

    If the module declares several `toolchain` tags, the last declaration
    takes precedence over earlier ones. Returns None when the module declares
    no `toolchain` tag at all.
    """
    declared = module.tags.toolchain
    if not declared:
        return None

    # Later declarations override earlier ones, so only the last one matters.
    return declared[-1].cipd_tag
31
def _find_cipd_tag(ctx):
    """Selects the CIPD tag used to fetch the Rust toolchain.

    A truthy tag declared by the root module wins. Otherwise the tag declared
    by the module named `pigweed` is returned.

    Fails if no module named `pigweed` is present in `ctx.modules`.
    """
    root_modules = [m for m in ctx.modules if m.is_root]
    pigweed_modules = [m for m in ctx.modules if m.name == "pigweed"]

    root_tag = _module_cipd_tag(root_modules[-1]) if root_modules else None

    if not pigweed_modules:
        fail("Unable to find pigweed module")

    # `or` falls back to pigweed's tag when the root tag is None or "".
    return root_tag or _module_cipd_tag(pigweed_modules[-1])
53
def _normalize_os_to_cipd(os):
    """Maps a Bazel OS name to the spelling used in CIPD package paths.

    Only "macos" is renamed (to "mac"); every other name passes through as-is.
    """
    return "mac" if os == "macos" else os
62
def _cipd_rust_stdlib_repository(triple, cpu, tag):
    """Declares the CIPD repository holding the Rust stdlib for `triple`."""
    cipd_repository(
        name = "rust_toolchain_target_{}_{}".format(triple, cpu),
        build_file = "//pw_toolchain/rust:rust_stdlib.BUILD",
        path = "fuchsia/third_party/rust/target/{}".format(triple),
        tag = tag,
    )

def _pw_rust_impl(ctx):
    """Implementation function of the `pw_rust` module extension.

    The function declares the following, all at the CIPD tag picked by
    `_find_cipd_tag`:
      * for every entry in `HOSTS`, a repository with the Rust host tools and
        one with the stdlib for the host's own triple;
      * for every `EXTRA_TARGETS` entry without `build_std`, a stdlib
        repository.
    It then declares the `pw_rust_toolchains` hub repository.
    """
    tag = _find_cipd_tag(ctx)

    for host in HOSTS:
        cipd_repository(
            name = "rust_toolchain_host_{}_{}".format(host["os"], host["cpu"]),
            build_file = "//pw_toolchain/rust:rust_toolchain.BUILD",
            path = "fuchsia/third_party/rust/host/{}-{}".format(
                _normalize_os_to_cipd(host["os"]),
                host["cipd_arch"],
            ),
            tag = tag,
        )
        _cipd_rust_stdlib_repository(host["triple"], host["cpu"], tag)

    # Targets with `build_std` use the no-prebuilt toolchain template (see
    # `_pw_rust_toolchain`), so no stdlib repository is fetched for them.
    prebuilt_targets = [t for t in EXTRA_TARGETS if not t.get("build_std", False)]
    for target in prebuilt_targets:
        _cipd_rust_stdlib_repository(target["triple"], target["cpu"], tag)

    _toolchain_repository_hub(name = "pw_rust_toolchains")
95
# Tag class behind `pw_rust.toolchain(...)`. Only the root module's and the
# `pigweed` module's declarations are consulted (see `_find_cipd_tag`).
_RUST_TOOLCHAIN_TAG = tag_class(
    attrs = {
        "cipd_tag": attr.string(
            doc = "The CIPD tag to use when fetching the Rust toolchain.",
        ),
    },
)
103
# Public entry point, used from MODULE.bazel via
# `use_extension("@pigweed//pw_toolchain/rust:extensions.bzl", "pw_rust")`.
pw_rust = module_extension(
    implementation = _pw_rust_impl,
    tag_classes = {
        "toolchain": _RUST_TOOLCHAIN_TAG,
    },
    doc = """Generate a repository for all Pigweed Rust toolchains.

        Declares a suite of Rust toolchains that may be registered in a
        MODULE.bazel file. If you would like to use the toolchains provided
        by Pigweed, add these lines to your MODULE.bazel:
        ```
        pw_rust = use_extension("@pigweed//pw_toolchain/rust:extensions.bzl", "pw_rust")
        use_repo(pw_rust, "pw_rust_toolchains")
        register_toolchains(
            "@pw_rust_toolchains//:all",
            dev_dependency = True,
        )
        ```

        If you would like to override the Rust compiler version, you can specify a
        CIPD version for an alternative toolchain to use in your project. Note that
        only the root module's specification of this tag is applied, and that if no
        version tag is specified Pigweed's value will be used as a fallback.
        ```
        pw_rust = use_extension("@pigweed//pw_toolchain/rust:extensions.bzl", "pw_rust")
        pw_rust.toolchain(cipd_tag = "rust_revision:bf9c7a64ad222b85397573668b39e6d1ab9f4a72")
        use_repo(pw_rust, "pw_rust_toolchains")
        register_toolchains(
            "@pw_rust_toolchains//:all",
            dev_dependency = True,
        )
        ```
    """,
)
138
def _pw_rust_toolchain(
        name,
        exec_triple,
        target_triple,
        toolchain_repo,
        target_repo,
        dylib_ext,
        exec_compatible_with,
        target_compatible_with,
        target_settings,
        extra_rustc_flags,
        analyzer_toolchain_name = None,
        rustfmt_toolchain_name = None,
        build_std = False):
    """Renders the BUILD text for one Rust toolchain and its optional companions.

    Args:
      name: Name handed to the rust toolchain template and `toolchain_template`.
      exec_triple: Rust triple of the execution platform.
      target_triple: Rust triple being compiled for.
      toolchain_repo: Label of the repository holding the host Rust tools.
      target_repo: Label of the repository holding the prebuilt stdlib. Only
        used when `build_std` is False.
      dylib_ext: Forwarded to the rust toolchain template.
      exec_compatible_with: Execution-platform constraints.
      target_compatible_with: Target-platform constraints.
      target_settings: Forwarded to `toolchain_template` and to the
        rust-analyzer / rustfmt templates.
      extra_rustc_flags: Forwarded to the rust toolchain template.
      analyzer_toolchain_name: If set, a rust-analyzer toolchain with this name
        is appended.
      rustfmt_toolchain_name: If set, a rustfmt toolchain with this name is
        appended.
      build_std: If True, `rust_toolchain_no_prebuilt_template` is used and
        `target_repo` is ignored.

    Returns:
      The concatenated output of the templates, in the order rust toolchain,
      `toolchain`, rust-analyzer, rustfmt.
    """
    rust_args = dict(
        name = name,
        exec_compatible_with = exec_compatible_with,
        target_compatible_with = target_compatible_with,
        dylib_ext = dylib_ext,
        toolchain_repo = toolchain_repo,
        exec_triple = exec_triple,
        target_triple = target_triple,
        extra_rustc_flags = extra_rustc_flags,
    )
    if build_std:
        content = rust_toolchain_no_prebuilt_template(**rust_args)
    else:
        content = rust_toolchain_template(target_repo = target_repo, **rust_args)

    content += toolchain_template(
        name = name,
        exec_compatible_with = exec_compatible_with,
        target_compatible_with = target_compatible_with,
        target_settings = target_settings,
    )

    # rust-analyzer and rustfmt share the toolchain's repo, constraints and
    # settings; only their names differ.
    companion_args = dict(
        toolchain_repo = toolchain_repo,
        exec_compatible_with = exec_compatible_with,
        target_compatible_with = target_compatible_with,
        target_settings = target_settings,
    )
    if analyzer_toolchain_name:
        content += rust_analyzer_toolchain_template(
            name = analyzer_toolchain_name,
            **companion_args
        )
    if rustfmt_toolchain_name:
        content += rustfmt_toolchain_template(
            name = rustfmt_toolchain_name,
            **companion_args
        )

    return content
203
def _host_constraints(host):
    """Returns a fresh list of the `@platforms` cpu and os constraints of `host`."""
    return [
        "@platforms//cpu:{}".format(host["cpu"]),
        "@platforms//os:{}".format(host["os"]),
    ]

def _BUILD_for_toolchain_repo():
    """Renders the BUILD.bazel contents of the toolchain hub repository.

    For every channel in `CHANNELS` and every host in `HOSTS`, the output
    declares:
      * a native toolchain for that host, with rust-analyzer and rustfmt
        companions;
      * one cross toolchain per entry in `EXTRA_TARGETS`.
    """
    snippets = [
        """load("@rules_rust//rust:toolchain.bzl", "rust_analyzer_toolchain", "rustfmt_toolchain", "rust_toolchain")\n""",
    ]
    for channel in CHANNELS:
        for host in HOSTS:
            host_tools_repo = "@rust_toolchain_host_{}_{}".format(host["os"], host["cpu"])

            snippets.append(_pw_rust_toolchain(
                name = "host_rust_toolchain_{}_{}_{}".format(host["os"], host["cpu"], channel["name"]),
                analyzer_toolchain_name = "host_rust_analyzer_toolchain_{}_{}_{}".format(host["os"], host["cpu"], channel["name"]),
                rustfmt_toolchain_name = "host_rustfmt_toolchain_{}_{}_{}".format(host["os"], host["cpu"], channel["name"]),
                exec_compatible_with = _host_constraints(host),
                target_compatible_with = _host_constraints(host),
                target_settings = channel["target_settings"],
                dylib_ext = host["dylib_ext"],
                target_repo = "@rust_toolchain_target_{}_{}".format(host["triple"], host["cpu"]),
                toolchain_repo = host_tools_repo,
                exec_triple = host["triple"],
                target_triple = host["triple"],
                extra_rustc_flags = channel["extra_rustc_flags"],
            ))

            for target in EXTRA_TARGETS:
                snippets.append(_pw_rust_toolchain(
                    name = "{}_{}_rust_toolchain_{}_{}_{}".format(host["os"], host["cpu"], target["triple"], target["cpu"], channel["name"]),
                    exec_triple = host["triple"],
                    target_triple = target["triple"],
                    target_repo = "@rust_toolchain_target_{}_{}".format(target["triple"], target["cpu"]),
                    toolchain_repo = host_tools_repo,
                    # NOTE(review): cross toolchains always use "*.so", even
                    # when the host is macOS; confirm this is intended.
                    dylib_ext = "*.so",
                    exec_compatible_with = _host_constraints(host),
                    target_compatible_with = [
                        "@platforms//cpu:{}".format(target["cpu"]),
                    ],
                    target_settings = channel["target_settings"],
                    extra_rustc_flags = channel["extra_rustc_flags"],
                    build_std = target.get("build_std", False),
                ))

    return "".join(snippets)
250
def _toolchain_repository_hub_impl(repository_ctx):
    """Writes the hub repository's WORKSPACE.bazel and BUILD.bazel files."""
    workspace_contents = """workspace(name = "{}")""".format(repository_ctx.name)
    build_contents = _BUILD_for_toolchain_repo()

    repository_ctx.file("WORKSPACE.bazel", workspace_contents)
    repository_ctx.file("BUILD.bazel", build_contents)
257
# Repository rule behind the hub repository that holds every generated Rust
# toolchain declaration.
_toolchain_repository_hub = repository_rule(
    implementation = _toolchain_repository_hub_impl,
    doc = "A repository of Pigweed Rust toolchains",
)
262