"Module extensions for using rules_rust with bzlmod"

load("//rust:defs.bzl", "rust_common")
load("//rust:repositories.bzl", "rust_register_toolchains", "rust_toolchain_tools_repository")
load("//rust/platform:triple.bzl", "get_host_triple")
load(
    "//rust/private:repository_utils.bzl",
    "DEFAULT_EXTRA_TARGET_TRIPLES",
    "DEFAULT_NIGHTLY_VERSION",
    "DEFAULT_STATIC_RUST_URL_TEMPLATES",
)

# The snippets embedded in the error messages below are meant to be copied
# verbatim into a MODULE.bazel, so any Rust version they cite must be a
# published release.

# `%s` is filled with the reason host tools could not be inferred.
_HOST_TOOL_ERR = """When %s, host tools must be explicitly defined. For example:

rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
rust.host_tools(
    edition = "2021",
    version = "1.70.0",
)
"""

# Minimal root-module toolchain setup, appended to the error messages below.
_EXAMPLE_TOOLCHAIN = """
rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
rust.toolchain(
    edition = "2021",
    versions = ["1.70.0"],
)
use_repo(rust, "rust_toolchains")
register_toolchains("@rust_toolchains//:all")"""

# `%s` is filled with the name of a (non-root) module that uses this extension.
_TRANSITIVE_DEP_ERR = """
Your transitive dependency %s is using rules_rust, so you need to define a rust toolchain.
To do so, you will need to add the following to your root MODULE.bazel. For example:

bazel_dep(name = "rules_rust", version = "<rules_rust version>")
""" + _EXAMPLE_TOOLCHAIN

_TOOLCHAIN_ERR = """
Please add at least one toolchain to your root MODULE.bazel. For example:
""" + _EXAMPLE_TOOLCHAIN

def _find_root_module(module_ctx):
    """Returns the root module among the modules that use this extension.

    Args:
        module_ctx: The module extension context.

    Returns:
        The entry of `module_ctx.modules` whose `is_root` is set. Fails with
        `_TRANSITIVE_DEP_ERR`, naming the first listed module, when the root
        module does not use the extension.
    """
    root = None
    for mod in module_ctx.modules:
        if mod.is_root:
            root = mod
            break
    if not root:
        fail(_TRANSITIVE_DEP_ERR % module_ctx.modules[0].name)
    return root

def _select_host_tools(root, toolchains):
    """Chooses the settings used for the `rust_host_tools` repository.

    A single explicit `host_tools` tag wins. Without one, the settings are
    derived from the root module's only `toolchain` tag, which must list at
    most one version.

    Args:
        root: The root module (as returned by `_find_root_module`).
        toolchains: The root module's non-empty list of `toolchain` tags.

    Returns:
        The `host_tools` tag itself, or a struct exposing the same fields.
        `version` is None when the toolchain tag lists no versions.
    """
    host_tools_tags = root.tags.host_tools
    if len(host_tools_tags) > 1:
        fail("Multiple host_tools were defined in your root MODULE.bazel")
    if host_tools_tags:
        return host_tools_tags[0]

    # Inferring host tools is only unambiguous with exactly one toolchain tag
    # that names at most one version.
    if len(toolchains) != 1:
        fail(_HOST_TOOL_ERR % "multiple toolchains are provided")
    toolchain = toolchains[0]
    if len(toolchain.versions) > 1:
        fail(_HOST_TOOL_ERR % "multiple toolchain versions are provided")

    return struct(
        allocator_library = toolchain.allocator_library,
        dev_components = toolchain.dev_components,
        edition = toolchain.edition,
        rustfmt_version = toolchain.rustfmt_version,
        sha256s = toolchain.sha256s,
        urls = toolchain.urls,
        version = toolchain.versions[0] if toolchain.versions else None,
    )

def _rust_impl(module_ctx):
    """Implementation of the `rust` module extension.

    Only the root module's tags are read. Defines the `rust_host_tools`
    repository for the host triple and calls `rust_register_toolchains` once
    per `toolchain` tag.

    Args:
        module_ctx: The module extension context.
    """

    # Toolchain configuration is only allowed in the root module.
    # It would be very confusing (and a security concern) if I was using the
    # default rust toolchains, then when I added a module built on rust, I was
    # suddenly using a custom rustc.
    root = _find_root_module(module_ctx)

    toolchains = root.tags.toolchain
    if not toolchains:
        fail(_TOOLCHAIN_ERR)

    host_tools = _select_host_tools(root, toolchains)

    # Host tools run on, and target, the machine performing the build.
    host_triple = get_host_triple(module_ctx)
    rust_toolchain_tools_repository(
        name = "rust_host_tools",
        exec_triple = host_triple.str,
        target_triple = host_triple.str,
        allocator_library = host_tools.allocator_library,
        dev_components = host_tools.dev_components,
        edition = host_tools.edition,
        rustfmt_version = host_tools.rustfmt_version,
        sha256s = host_tools.sha256s,
        urls = host_tools.urls,
        # An unset version (None, or "" from an empty `host_tools.version`)
        # falls back to the rules_rust default.
        version = host_tools.version or rust_common.default_version,
    )

    for toolchain in toolchains:
        rust_register_toolchains(
            dev_components = toolchain.dev_components,
            edition = toolchain.edition,
            allocator_library = toolchain.allocator_library,
            rustfmt_version = toolchain.rustfmt_version,
            rust_analyzer_version = toolchain.rust_analyzer_version,
            sha256s = toolchain.sha256s,
            extra_target_triples = toolchain.extra_target_triples,
            urls = toolchain.urls,
            versions = toolchain.versions,
            # The root module is expected to register the toolchains itself
            # (see `register_toolchains(...)` in _EXAMPLE_TOOLCHAIN);
            # presumably this flag suppresses registration here — TODO confirm
            # against rust_register_toolchains.
            register_toolchains = False,
        )

# Attributes shared by the `toolchain` and `host_tools` tag classes.
_COMMON_TAG_KWARGS = dict(
    # Label string of the allocator library to use.
    allocator_library = attr.string(default = "@rules_rust//ffi/cc/allocator_library"),
    dev_components = attr.bool(default = False),
    edition = attr.string(),
    # rustfmt defaults to the rules_rust default nightly version.
    rustfmt_version = attr.string(default = DEFAULT_NIGHTLY_VERSION),
    sha256s = attr.string_dict(),
    # Download URL templates; default to the rules_rust static templates.
    urls = attr.string_list(default = DEFAULT_STATIC_RUST_URL_TEMPLATES),
)

# `rust.toolchain(...)`: may be given several times in the root module.
_RUST_TOOLCHAIN_TAG = tag_class(attrs = dict(
    extra_target_triples = attr.string_list(default = DEFAULT_EXTRA_TARGET_TRIPLES),
    rust_analyzer_version = attr.string(),
    versions = attr.string_list(default = []),
    **_COMMON_TAG_KWARGS
))

# `rust.host_tools(...)`: at most one per root module; takes a single version.
_RUST_HOST_TOOLS_TAG = tag_class(attrs = dict(
    version = attr.string(),
    **_COMMON_TAG_KWARGS
))

rust = module_extension(
    implementation = _rust_impl,
    tag_classes = {
        "host_tools": _RUST_HOST_TOOLS_TAG,
        "toolchain": _RUST_TOOLCHAIN_TAG,
    },
)