• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"Module extensions for using rules_rust with bzlmod"
2
3load("//rust:defs.bzl", "rust_common")
4load("//rust:repositories.bzl", "rust_register_toolchains", "rust_toolchain_tools_repository")
5load("//rust/platform:triple.bzl", "get_host_triple")
6load(
7    "//rust/private:repository_utils.bzl",
8    "DEFAULT_EXTRA_TARGET_TRIPLES",
9    "DEFAULT_NIGHTLY_VERSION",
10    "DEFAULT_STATIC_RUST_URL_TEMPLATES",
11)
12
# Messages reported via fail() by the `rust` extension implementation when the
# MODULE.bazel configuration cannot be acted on. Each message is assembled
# line by line so the exact text (including leading/trailing newlines) is
# visible at a glance.

# MODULE.bazel snippet declaring one Rust toolchain and registering the
# resulting `rust_toolchains` repository. Appended to the messages that ask the
# user to add a toolchain. Begins with a newline and has no trailing newline.
_EXAMPLE_TOOLCHAIN = "\n".join([
    "",
    'rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")',
    "rust.toolchain(",
    '    edition = "2021",',
    '    versions = ["1.70.2"],',
    ")",
    'use_repo(rust, "rust_toolchains")',
    'register_toolchains("@rust_toolchains//:all")',
])

# Used when host tools cannot be inferred from the toolchain tags. The single
# `%s` placeholder receives the reason, e.g. "multiple toolchains are provided".
_HOST_TOOL_ERR = "\n".join([
    "When %s, host tools must be explicitly defined. For example:",
    "",
    'rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")',
    "rust.host_tools(",
    '    edition = "2021",',
    '    version = "1.70.2",',
    ")",
    "",
])

# Used when no module using the extension is the root module. The `%s`
# placeholder receives the name of the first module that uses the extension.
_TRANSITIVE_DEP_ERR = "\n".join([
    "",
    "Your transitive dependency %s is using rules_rust, so you need to define a rust toolchain.",
    "To do so, you will need to add the following to your root MODULE.bazel. For example:",
    "",
    'bazel_dep(name = "rules_rust", version = "<rules_rust version>")',
    "",
]) + _EXAMPLE_TOOLCHAIN

# Used when the root module uses the extension but declares no `toolchain` tag.
_TOOLCHAIN_ERR = "\n".join([
    "",
    "Please add at least one toolchain to your root MODULE.bazel. For example:",
    "",
]) + _EXAMPLE_TOOLCHAIN
41
def _host_tools_from_toolchain_tags(toolchains):
    """Derives host-tool settings from the root module's `toolchain` tags.

    Only used when the root module declares no `host_tools` tag. Inference is
    possible only when exactly one `toolchain` tag exists and it lists at most
    one version; otherwise this fails with `_HOST_TOOL_ERR`.

    Args:
        toolchains: the `toolchain` tags of the root module.

    Returns:
        A struct exposing the same fields `_rust_impl` reads from a
        `host_tools` tag. Its `version` is None when the toolchain lists no
        versions.
    """
    if len(toolchains) != 1:
        fail(_HOST_TOOL_ERR % "multiple toolchains are provided")
    only = toolchains[0]

    if len(only.versions) > 1:
        fail(_HOST_TOOL_ERR % "multiple toolchain versions are provided")
    single_version = only.versions[0] if only.versions else None

    return struct(
        allocator_library = only.allocator_library,
        dev_components = only.dev_components,
        edition = only.edition,
        rustfmt_version = only.rustfmt_version,
        sha256s = only.sha256s,
        urls = only.urls,
        version = single_version,
    )

def _rust_impl(module_ctx):
    """Implementation function of the `rust` module extension.

    Reads the `toolchain` and `host_tools` tags of the root module only. It
    then creates the `rust_host_tools` repository for the host triple and calls
    `rust_register_toolchains` once per `toolchain` tag, passing
    `register_toolchains = False`.

    Args:
        module_ctx: the module extension context.
    """

    # Only the root module may configure toolchains. Silently switching to a
    # custom rustc because some dependency asked for one would be surprising
    # and a security concern, so tags from non-root modules are never read.
    roots = [m for m in module_ctx.modules if m.is_root]
    if not roots:
        fail(_TRANSITIVE_DEP_ERR % module_ctx.modules[0].name)
    root = roots[-1]

    toolchains = root.tags.toolchain
    if not toolchains:
        fail(_TOOLCHAIN_ERR)

    # An explicit `host_tools` tag wins; otherwise infer from the toolchain tag.
    explicit_host_tools = root.tags.host_tools
    if len(explicit_host_tools) > 1:
        fail("Multiple host_tools were defined in your root MODULE.bazel")
    if explicit_host_tools:
        host_tools = explicit_host_tools[0]
    else:
        host_tools = _host_tools_from_toolchain_tags(toolchains)

    # Both the exec and the target triple of the host tools are the host triple.
    host_triple = get_host_triple(module_ctx).str
    rust_toolchain_tools_repository(
        name = "rust_host_tools",
        exec_triple = host_triple,
        target_triple = host_triple,
        allocator_library = host_tools.allocator_library,
        dev_components = host_tools.dev_components,
        edition = host_tools.edition,
        rustfmt_version = host_tools.rustfmt_version,
        sha256s = host_tools.sha256s,
        urls = host_tools.urls,
        # A missing (None) or empty version falls back to the rules_rust default.
        version = host_tools.version or rust_common.default_version,
    )

    # NOTE(review): rust_register_toolchains is invoked once per `toolchain`
    # tag; if it creates fixed-name repositories, several tags may collide —
    # TODO confirm against //rust:repositories.bzl.
    for tc in toolchains:
        rust_register_toolchains(
            allocator_library = tc.allocator_library,
            dev_components = tc.dev_components,
            edition = tc.edition,
            extra_target_triples = tc.extra_target_triples,
            rust_analyzer_version = tc.rust_analyzer_version,
            rustfmt_version = tc.rustfmt_version,
            sha256s = tc.sha256s,
            urls = tc.urls,
            versions = tc.versions,
            # The root MODULE.bazel registers `@rust_toolchains//:all` itself
            # (see _EXAMPLE_TOOLCHAIN), so nothing is registered here.
            register_toolchains = False,
        )
109
# Attributes accepted by both the `toolchain` and `host_tools` tags. Each value
# is forwarded unchanged by `_rust_impl` to `rust_register_toolchains` and/or
# `rust_toolchain_tools_repository`.
_COMMON_TAG_KWARGS = dict(
    allocator_library = attr.string(
        doc = "Label of the allocator library target to use with the toolchain.",
        default = "@rules_rust//ffi/cc/allocator_library",
    ),
    dev_components = attr.bool(
        doc = "Whether to also fetch the Rust development components.",
        default = False,
    ),
    edition = attr.string(
        doc = "The Rust edition, e.g. \"2021\".",
    ),
    rustfmt_version = attr.string(
        doc = "The version of rustfmt to fetch.",
        default = DEFAULT_NIGHTLY_VERSION,
    ),
    sha256s = attr.string_dict(
        doc = "SHA-256 checksums for the downloaded Rust artifacts.",
    ),
    urls = attr.string_list(
        doc = "URL templates from which Rust artifacts are downloaded.",
        default = DEFAULT_STATIC_RUST_URL_TEMPLATES,
    ),
)

_RUST_TOOLCHAIN_TAG = tag_class(
    doc = (
        "Declares Rust toolchains to set up. Only tags in the root module " +
        "are used; tags from other modules are ignored."
    ),
    attrs = dict(
        extra_target_triples = attr.string_list(
            doc = "Additional target triples to set up toolchains for.",
            default = DEFAULT_EXTRA_TARGET_TRIPLES,
        ),
        rust_analyzer_version = attr.string(
            doc = "The version of rust-analyzer to fetch.",
        ),
        versions = attr.string_list(
            doc = (
                "Rust versions of the toolchains, e.g. [\"1.70.2\"]. When the " +
                "root module has no `host_tools` tag, at most one version may " +
                "be listed and it also selects the host tools version; if the " +
                "list is empty, the host tools use the rules_rust default version."
            ),
            default = [],
        ),
        **_COMMON_TAG_KWARGS
    ),
)

_RUST_HOST_TOOLS_TAG = tag_class(
    doc = (
        "Configures the `rust_host_tools` repository, whose exec and target " +
        "triple are both the host triple. At most one may appear in the root " +
        "module; when omitted, its settings are copied from the single " +
        "`toolchain` tag."
    ),
    attrs = dict(
        version = attr.string(
            doc = "Rust version of the host tools. If empty, the rules_rust default version is used.",
        ),
        **_COMMON_TAG_KWARGS
    ),
)

# Public entry point: `rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")`.
rust = module_extension(
    doc = "Sets up Rust toolchains for bzlmod. Only tags declared in the root module are used.",
    implementation = _rust_impl,
    tag_classes = {
        "host_tools": _RUST_HOST_TOOLS_TAG,
        "toolchain": _RUST_TOOLCHAIN_TAG,
    },
)
138