1"""Rules for building protos in Rust with Prost and Tonic.""" 2 3load("@rules_proto//proto:defs.bzl", "ProtoInfo", "proto_common") 4load("//proto/prost:providers.bzl", "ProstProtoInfo") 5load("//rust:defs.bzl", "rust_common") 6 7# buildifier: disable=bzl-visibility 8load("//rust/private:providers.bzl", "RustAnalyzerGroupInfo", "RustAnalyzerInfo") 9 10# buildifier: disable=bzl-visibility 11load("//rust/private:rust.bzl", "RUSTC_ATTRS") 12 13# buildifier: disable=bzl-visibility 14load("//rust/private:rust_analyzer.bzl", "write_rust_analyzer_spec_file") 15 16# buildifier: disable=bzl-visibility 17load("//rust/private:rustc.bzl", "rustc_compile_action") 18 19# buildifier: disable=bzl-visibility 20load("//rust/private:utils.bzl", "can_build_metadata") 21 22RUST_EDITION = "2021" 23 24TOOLCHAIN_TYPE = "@rules_rust//proto/prost:toolchain_type" 25 26def _create_proto_lang_toolchain(ctx, prost_toolchain): 27 proto_lang_toolchain = proto_common.ProtoLangToolchainInfo( 28 out_replacement_format_flag = "--prost_out=%s", 29 plugin_format_flag = prost_toolchain.prost_plugin_flag, 30 plugin = prost_toolchain.prost_plugin[DefaultInfo].files_to_run, 31 runtime = prost_toolchain.prost_runtime, 32 provided_proto_sources = depset(), 33 proto_compiler = ctx.attr._prost_process_wrapper[DefaultInfo].files_to_run, 34 protoc_opts = prost_toolchain.protoc_opts, 35 progress_message = "ProstGenProto %{label}", 36 mnemonic = "ProstGenProto", 37 ) 38 39 return proto_lang_toolchain 40 41def _compile_proto(ctx, crate_name, proto_info, deps, prost_toolchain, rustfmt_toolchain = None): 42 deps_info_file = ctx.actions.declare_file(ctx.label.name + ".prost_deps_info") 43 dep_package_infos = [dep[ProstProtoInfo].package_info for dep in deps] 44 ctx.actions.write( 45 output = deps_info_file, 46 content = "\n".join([file.path for file in dep_package_infos]), 47 ) 48 49 package_info_file = ctx.actions.declare_file(ctx.label.name + ".prost_package_info") 50 lib_rs = ctx.actions.declare_file("{}.lib.rs".format(ctx.label.name)) 51 52 proto_compiler = prost_toolchain.proto_compiler[DefaultInfo].files_to_run 53 tools = depset([proto_compiler.executable]) 54 55 additional_args = ctx.actions.args() 56 57 # Prost process wrapper specific args 58 additional_args.add("--protoc={}".format(proto_compiler.executable.path)) 59 additional_args.add("--label={}".format(ctx.label)) 60 additional_args.add("--out_librs={}".format(lib_rs.path)) 61 additional_args.add("--package_info_output={}".format("{}={}".format(crate_name, package_info_file.path))) 62 additional_args.add("--deps_info={}".format(deps_info_file.path)) 63 additional_args.add("--prost_opt=compile_well_known_types") 64 additional_args.add("--descriptor_set={}".format(proto_info.direct_descriptor_set.path)) 65 additional_args.add_all(prost_toolchain.prost_opts, format_each = "--prost_opt=%s") 66 67 if prost_toolchain.tonic_plugin: 68 tonic_plugin = prost_toolchain.tonic_plugin[DefaultInfo].files_to_run 69 additional_args.add(prost_toolchain.tonic_plugin_flag % tonic_plugin.executable.path) 70 additional_args.add("--tonic_opt=no_include") 71 additional_args.add("--tonic_opt=compile_well_known_types") 72 additional_args.add("--is_tonic") 73 additional_args.add_all(prost_toolchain.tonic_opts, format_each = "--tonic_opt=%s") 74 tools = depset([tonic_plugin.executable], transitive = [tools]) 75 76 if rustfmt_toolchain: 77 additional_args.add("--rustfmt={}".format(rustfmt_toolchain.rustfmt.path)) 78 tools = depset(transitive = [tools, rustfmt_toolchain.all_files]) 79 80 additional_inputs = depset([deps_info_file, proto_info.direct_descriptor_set] + [dep[ProstProtoInfo].package_info for dep in deps]) 81 82 proto_common.compile( 83 actions = ctx.actions, 84 proto_info = proto_info, 85 additional_tools = tools.to_list(), 86 additional_inputs = additional_inputs, 87 additional_args = additional_args, 88 generated_files = [lib_rs, package_info_file], 89 proto_lang_toolchain_info = _create_proto_lang_toolchain(ctx, prost_toolchain), 90 plugin_output = ctx.bin_dir.path, 91 ) 92 93 return lib_rs, package_info_file 94 95def _get_crate_info(providers): 96 """Finds the CrateInfo provider in the list of providers.""" 97 for provider in providers: 98 if hasattr(provider, "name"): 99 return provider 100 fail("Couldn't find a CrateInfo in the list of providers") 101 102def _get_dep_info(providers): 103 """Finds the DepInfo provider in the list of providers.""" 104 for provider in providers: 105 if hasattr(provider, "direct_crates"): 106 return provider 107 fail("Couldn't find a DepInfo in the list of providers") 108 109def _get_cc_info(providers): 110 """Finds the CcInfo provider in the list of providers.""" 111 for provider in providers: 112 if hasattr(provider, "linking_context"): 113 return provider 114 fail("Couldn't find a CcInfo in the list of providers") 115 116def _compile_rust(ctx, attr, crate_name, src, deps, edition): 117 """Compiles a Rust source file. 118 119 Args: 120 ctx (RuleContext): The rule context. 121 attr (Attrs): The current rule's attributes (`ctx.attr` for rules, `ctx.rule.attr` for aspects) 122 crate_name (str): The crate module name to use. 123 src (File): The crate root source file to be compiled. 124 deps (List of DepVariantInfo): A list of dependencies needed. 125 edition (str): The Rust edition to use. 126 127 Returns: 128 A DepVariantInfo provider. 129 """ 130 toolchain = ctx.toolchains["@rules_rust//rust:toolchain_type"] 131 output_hash = repr(hash(src.path + ".prost")) 132 133 lib_name = "{prefix}{name}-{lib_hash}{extension}".format( 134 prefix = "lib", 135 name = crate_name, 136 lib_hash = output_hash, 137 extension = ".rlib", 138 ) 139 140 rmeta_name = "{prefix}{name}-{lib_hash}{extension}".format( 141 prefix = "lib", 142 name = crate_name, 143 lib_hash = output_hash, 144 extension = ".rmeta", 145 ) 146 147 lib = ctx.actions.declare_file(lib_name) 148 rmeta = None 149 150 if can_build_metadata(toolchain, ctx, "rlib"): 151 rmeta_name = "{prefix}{name}-{lib_hash}{extension}".format( 152 prefix = "lib", 153 name = crate_name, 154 lib_hash = output_hash, 155 extension = ".rmeta", 156 ) 157 rmeta = ctx.actions.declare_file(rmeta_name) 158 159 providers = rustc_compile_action( 160 ctx = ctx, 161 attr = attr, 162 toolchain = toolchain, 163 crate_info_dict = dict( 164 name = crate_name, 165 type = "rlib", 166 root = src, 167 srcs = depset([src]), 168 deps = depset(deps), 169 proc_macro_deps = depset([]), 170 aliases = {}, 171 output = lib, 172 metadata = rmeta, 173 edition = edition, 174 is_test = False, 175 rustc_env = {}, 176 compile_data = depset([]), 177 compile_data_targets = depset([]), 178 owner = ctx.label, 179 ), 180 output_hash = output_hash, 181 ) 182 183 crate_info = _get_crate_info(providers) 184 dep_info = _get_dep_info(providers) 185 cc_info = _get_cc_info(providers) 186 187 return rust_common.dep_variant_info( 188 crate_info = crate_info, 189 dep_info = dep_info, 190 cc_info = cc_info, 191 build_info = None, 192 ) 193 194def _rust_prost_aspect_impl(target, ctx): 195 if ProstProtoInfo in target: 196 return [] 197 198 runtime_deps = [] 199 200 rustfmt_toolchain = ctx.toolchains["@rules_rust//rust/rustfmt:toolchain_type"] 201 prost_toolchain = ctx.toolchains["@rules_rust//proto/prost:toolchain_type"] 202 for prost_runtime in [prost_toolchain.prost_runtime, prost_toolchain.tonic_runtime]: 203 if not prost_runtime: 204 continue 205 if rust_common.crate_group_info in prost_runtime: 206 crate_group_info = prost_runtime[rust_common.crate_group_info] 207 runtime_deps.extend(crate_group_info.dep_variant_infos.to_list()) 208 else: 209 runtime_deps.append(rust_common.dep_variant_info( 210 crate_info = prost_runtime[rust_common.crate_info] if rust_common.crate_info in prost_runtime else None, 211 dep_info = prost_runtime[rust_common.dep_info] if rust_common.dep_info in prost_runtime else None, 212 cc_info = prost_runtime[CcInfo] if CcInfo in prost_runtime else None, 213 build_info = None, 214 )) 215 216 proto_deps = getattr(ctx.rule.attr, "deps", []) 217 218 direct_deps = [] 219 transitive_deps = [depset(runtime_deps)] 220 rust_analyzer_deps = [] 221 for proto_dep in proto_deps: 222 proto_info = proto_dep[ProstProtoInfo] 223 224 direct_deps.append(proto_info.dep_variant_info) 225 transitive_deps.append(depset( 226 [proto_info.dep_variant_info], 227 transitive = [proto_info.transitive_dep_infos], 228 )) 229 230 if RustAnalyzerInfo in proto_dep: 231 rust_analyzer_deps.append(proto_dep[RustAnalyzerInfo]) 232 233 deps = runtime_deps + direct_deps 234 235 crate_name = ctx.label.name.replace("-", "_").replace("/", "_") 236 237 proto_info = target[ProtoInfo] 238 239 lib_rs, package_info_file = _compile_proto( 240 ctx = ctx, 241 crate_name = crate_name, 242 proto_info = proto_info, 243 deps = proto_deps, 244 prost_toolchain = prost_toolchain, 245 rustfmt_toolchain = rustfmt_toolchain, 246 ) 247 248 dep_variant_info = _compile_rust( 249 ctx = ctx, 250 attr = ctx.rule.attr, 251 crate_name = crate_name, 252 src = lib_rs, 253 deps = deps, 254 edition = RUST_EDITION, 255 ) 256 257 # Always add `test` & `debug_assertions`. See rust-analyzer source code: 258 # https://github.com/rust-analyzer/rust-analyzer/blob/2021-11-15/crates/project_model/src/workspace.rs#L529-L531 259 cfgs = ["test", "debug_assertions"] 260 261 rust_analyzer_info = write_rust_analyzer_spec_file(ctx, ctx.rule.attr, ctx.label, RustAnalyzerInfo( 262 crate = dep_variant_info.crate_info, 263 cfgs = cfgs, 264 env = dep_variant_info.crate_info.rustc_env, 265 deps = rust_analyzer_deps, 266 crate_specs = depset(transitive = [dep.crate_specs for dep in rust_analyzer_deps]), 267 proc_macro_dylib_path = None, 268 build_info = dep_variant_info.build_info, 269 )) 270 271 return [ 272 ProstProtoInfo( 273 dep_variant_info = dep_variant_info, 274 transitive_dep_infos = depset(transitive = transitive_deps), 275 package_info = package_info_file, 276 ), 277 rust_analyzer_info, 278 ] 279 280rust_prost_aspect = aspect( 281 doc = "An aspect used to generate and compile proto files with Prost.", 282 implementation = _rust_prost_aspect_impl, 283 attr_aspects = ["deps"], 284 attrs = { 285 "_collect_cc_coverage": attr.label( 286 default = Label("//util:collect_coverage"), 287 executable = True, 288 cfg = "exec", 289 ), 290 "_grep_includes": attr.label( 291 allow_single_file = True, 292 default = Label("@bazel_tools//tools/cpp:grep-includes"), 293 cfg = "exec", 294 ), 295 "_prost_process_wrapper": attr.label( 296 doc = "The wrapper script for the Prost protoc plugin.", 297 cfg = "exec", 298 executable = True, 299 default = Label("//proto/prost/private:protoc_wrapper"), 300 ), 301 } | RUSTC_ATTRS, 302 fragments = ["cpp"], 303 toolchains = [ 304 TOOLCHAIN_TYPE, 305 "@bazel_tools//tools/cpp:toolchain_type", 306 "@rules_rust//rust:toolchain_type", 307 "@rules_rust//rust/rustfmt:toolchain_type", 308 ], 309) 310 311def _rust_prost_library_impl(ctx): 312 proto_dep = ctx.attr.proto 313 rust_proto_info = proto_dep[ProstProtoInfo] 314 dep_variant_info = rust_proto_info.dep_variant_info 315 316 return [ 317 DefaultInfo(files = depset([dep_variant_info.crate_info.output])), 318 rust_common.crate_group_info( 319 dep_variant_infos = depset( 320 [dep_variant_info], 321 transitive = [rust_proto_info.transitive_dep_infos], 322 ), 323 ), 324 RustAnalyzerGroupInfo(deps = [proto_dep[RustAnalyzerInfo]]), 325 ] 326 327rust_prost_library = rule( 328 doc = "A rule for generating a Rust library using Prost.", 329 implementation = _rust_prost_library_impl, 330 attrs = { 331 "proto": attr.label( 332 doc = "A `proto_library` target for which to generate Rust gencode.", 333 providers = [ProtoInfo], 334 aspects = [rust_prost_aspect], 335 mandatory = True, 336 ), 337 "_collect_cc_coverage": attr.label( 338 default = Label("@rules_rust//util:collect_coverage"), 339 executable = True, 340 cfg = "exec", 341 ), 342 }, 343) 344 345def _rust_prost_toolchain_impl(ctx): 346 tonic_attrs = [ctx.attr.tonic_plugin_flag, ctx.attr.tonic_plugin, ctx.attr.tonic_runtime] 347 if any(tonic_attrs) and not all(tonic_attrs): 348 fail("When one tonic attribute is added, all must be added") 349 350 return [platform_common.ToolchainInfo( 351 prost_opts = ctx.attr.prost_opts, 352 prost_plugin = ctx.attr.prost_plugin, 353 prost_plugin_flag = ctx.attr.prost_plugin_flag, 354 prost_runtime = ctx.attr.prost_runtime, 355 prost_types = ctx.attr.prost_types, 356 proto_compiler = ctx.attr.proto_compiler, 357 protoc_opts = ctx.fragments.proto.experimental_protoc_opts, 358 tonic_opts = ctx.attr.tonic_opts, 359 tonic_plugin = ctx.attr.tonic_plugin, 360 tonic_plugin_flag = ctx.attr.tonic_plugin_flag, 361 tonic_runtime = ctx.attr.tonic_runtime, 362 )] 363 364rust_prost_toolchain = rule( 365 implementation = _rust_prost_toolchain_impl, 366 doc = "Rust Prost toolchain rule.", 367 fragments = ["proto"], 368 attrs = { 369 "prost_opts": attr.string_list( 370 doc = "Additional options to add to Prost.", 371 ), 372 "prost_plugin": attr.label( 373 doc = "Additional plugins to add to Prost.", 374 cfg = "exec", 375 executable = True, 376 mandatory = True, 377 ), 378 "prost_plugin_flag": attr.string( 379 doc = "Prost plugin flag format. (e.g. `--plugin=protoc-gen-prost=%s`)", 380 default = "--plugin=protoc-gen-prost=%s", 381 ), 382 "prost_runtime": attr.label( 383 doc = "The Prost runtime crates to use.", 384 providers = [[rust_common.crate_info], [rust_common.crate_group_info]], 385 mandatory = True, 386 ), 387 "prost_types": attr.label( 388 doc = "The Prost types crates to use.", 389 providers = [[rust_common.crate_info], [rust_common.crate_group_info]], 390 mandatory = True, 391 ), 392 "proto_compiler": attr.label( 393 doc = "The protoc compiler to use.", 394 cfg = "exec", 395 executable = True, 396 mandatory = True, 397 ), 398 "tonic_opts": attr.string_list( 399 doc = "Additional options to add to Tonic.", 400 ), 401 "tonic_plugin": attr.label( 402 doc = "Additional plugins to add to Tonic.", 403 cfg = "exec", 404 executable = True, 405 ), 406 "tonic_plugin_flag": attr.string( 407 doc = "Tonic plugin flag format. (e.g. `--plugin=protoc-gen-tonic=%s`))", 408 default = "--plugin=protoc-gen-tonic=%s", 409 ), 410 "tonic_runtime": attr.label( 411 doc = "The Tonic runtime crates to use.", 412 providers = [[rust_common.crate_info], [rust_common.crate_group_info]], 413 ), 414 }, 415) 416 417def _current_prost_runtime_impl(ctx): 418 toolchain = ctx.toolchains[TOOLCHAIN_TYPE] 419 420 runtime_deps = [] 421 422 for target in [toolchain.prost_runtime, toolchain.prost_types]: 423 if rust_common.crate_group_info in target: 424 crate_group_info = target[rust_common.crate_group_info] 425 runtime_deps.extend(crate_group_info.dep_variant_infos.to_list()) 426 else: 427 runtime_deps.append(rust_common.dep_variant_info( 428 crate_info = target[rust_common.crate_info] if rust_common.crate_info in target else None, 429 dep_info = target[rust_common.dep_info] if rust_common.dep_info in target else None, 430 cc_info = target[CcInfo] if CcInfo in target else None, 431 build_info = None, 432 )) 433 434 return [rust_common.crate_group_info( 435 dep_variant_infos = depset(runtime_deps), 436 )] 437 438current_prost_runtime = rule( 439 doc = "A rule for accessing the current Prost toolchain components needed by the process wrapper", 440 provides = [rust_common.crate_group_info], 441 implementation = _current_prost_runtime_impl, 442 toolchains = [TOOLCHAIN_TYPE], 443) 444