• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Rules for building protos in Rust with Prost and Tonic."""
2
3load("@rules_proto//proto:defs.bzl", "ProtoInfo", "proto_common")
4load("//proto/prost:providers.bzl", "ProstProtoInfo")
5load("//rust:defs.bzl", "rust_common")
6
7# buildifier: disable=bzl-visibility
8load("//rust/private:providers.bzl", "RustAnalyzerGroupInfo", "RustAnalyzerInfo")
9
10# buildifier: disable=bzl-visibility
11load("//rust/private:rust.bzl", "RUSTC_ATTRS")
12
13# buildifier: disable=bzl-visibility
14load("//rust/private:rust_analyzer.bzl", "write_rust_analyzer_spec_file")
15
16# buildifier: disable=bzl-visibility
17load("//rust/private:rustc.bzl", "rustc_compile_action")
18
19# buildifier: disable=bzl-visibility
20load("//rust/private:utils.bzl", "can_build_metadata")
21
# Rust edition used when compiling the generated Prost/Tonic crates.
RUST_EDITION = "2021"

# Toolchain type under which the Prost toolchain (rust_prost_toolchain) is registered.
TOOLCHAIN_TYPE = "@rules_rust//proto/prost:toolchain_type"

def _create_proto_lang_toolchain(ctx, prost_toolchain):
    """Builds the `ProtoLangToolchainInfo` handed to `proto_common.compile`.

    The "proto compiler" registered here is the Prost process wrapper
    (`_prost_process_wrapper`), and the Prost protoc plugin is wired in through the
    toolchain's `prost_plugin_flag`.

    Args:
      ctx: The rule/aspect context; must define the `_prost_process_wrapper` attribute.
      prost_toolchain: The resolved Prost `ToolchainInfo`.

    Returns:
      A `proto_common.ProtoLangToolchainInfo`.
    """
    prost_plugin = prost_toolchain.prost_plugin[DefaultInfo].files_to_run
    process_wrapper = ctx.attr._prost_process_wrapper[DefaultInfo].files_to_run

    return proto_common.ProtoLangToolchainInfo(
        mnemonic = "ProstGenProto",
        progress_message = "ProstGenProto %{label}",
        proto_compiler = process_wrapper,
        plugin = prost_plugin,
        plugin_format_flag = prost_toolchain.prost_plugin_flag,
        out_replacement_format_flag = "--prost_out=%s",
        protoc_opts = prost_toolchain.protoc_opts,
        provided_proto_sources = depset(),
        runtime = prost_toolchain.prost_runtime,
    )
40
def _compile_proto(ctx, crate_name, proto_info, deps, prost_toolchain, rustfmt_toolchain = None):
    """Runs protoc (via the Prost process wrapper) to generate Rust code for one proto target.

    Args:
      ctx: The aspect context of the target being processed.
      crate_name (str): Crate name the generated code will be compiled under.
      proto_info (ProtoInfo): The target's `ProtoInfo`.
      deps (list of Target): Proto dependencies; each must provide `ProstProtoInfo`.
      prost_toolchain: The resolved Prost `ToolchainInfo`.
      rustfmt_toolchain: Optional rustfmt toolchain; when set, its `rustfmt` binary is
        passed to the wrapper and its files are added to the action's tools.

    Returns:
      A `(lib_rs, package_info_file)` tuple of the two declared output Files.
    """
    # Package-info outputs of the direct proto deps; needed both for the manifest and as
    # action inputs, so compute the list once.
    dep_package_infos = [dep[ProstProtoInfo].package_info for dep in deps]

    # Manifest of the deps' package-info paths (one per line), passed via --deps_info.
    deps_info_file = ctx.actions.declare_file(ctx.label.name + ".prost_deps_info")
    ctx.actions.write(
        output = deps_info_file,
        content = "\n".join([info.path for info in dep_package_infos]),
    )

    package_info_file = ctx.actions.declare_file(ctx.label.name + ".prost_package_info")
    lib_rs = ctx.actions.declare_file("{}.lib.rs".format(ctx.label.name))

    proto_compiler = prost_toolchain.proto_compiler[DefaultInfo].files_to_run
    tools = depset([proto_compiler.executable])

    additional_args = ctx.actions.args()

    # Prost process wrapper specific args
    additional_args.add("--protoc={}".format(proto_compiler.executable.path))
    additional_args.add("--label={}".format(ctx.label))
    additional_args.add("--out_librs={}".format(lib_rs.path))
    additional_args.add("--package_info_output={}={}".format(crate_name, package_info_file.path))
    additional_args.add("--deps_info={}".format(deps_info_file.path))
    additional_args.add("--prost_opt=compile_well_known_types")
    additional_args.add("--descriptor_set={}".format(proto_info.direct_descriptor_set.path))
    additional_args.add_all(prost_toolchain.prost_opts, format_each = "--prost_opt=%s")

    # Tonic is optional: only wire in its plugin when the toolchain provides one.
    if prost_toolchain.tonic_plugin:
        tonic_plugin = prost_toolchain.tonic_plugin[DefaultInfo].files_to_run
        additional_args.add(prost_toolchain.tonic_plugin_flag % tonic_plugin.executable.path)
        additional_args.add("--tonic_opt=no_include")
        additional_args.add("--tonic_opt=compile_well_known_types")
        additional_args.add("--is_tonic")
        additional_args.add_all(prost_toolchain.tonic_opts, format_each = "--tonic_opt=%s")
        tools = depset([tonic_plugin.executable], transitive = [tools])

    if rustfmt_toolchain:
        additional_args.add("--rustfmt={}".format(rustfmt_toolchain.rustfmt.path))
        tools = depset(transitive = [tools, rustfmt_toolchain.all_files])

    additional_inputs = depset([deps_info_file, proto_info.direct_descriptor_set] + dep_package_infos)

    proto_common.compile(
        actions = ctx.actions,
        proto_info = proto_info,
        additional_tools = tools.to_list(),
        additional_inputs = additional_inputs,
        additional_args = additional_args,
        generated_files = [lib_rs, package_info_file],
        proto_lang_toolchain_info = _create_proto_lang_toolchain(ctx, prost_toolchain),
        plugin_output = ctx.bin_dir.path,
    )

    return lib_rs, package_info_file
94
def _find_provider(providers, field, kind):
    """Returns the first provider in `providers` that has attribute `field`.

    Providers are identified by a characteristic field rather than by type. Fails with a
    message naming `kind` when no provider matches.
    """
    matches = [p for p in providers if hasattr(p, field)]
    if not matches:
        fail("Couldn't find a {} in the list of providers".format(kind))
    return matches[0]

def _get_crate_info(providers):
    """Finds the CrateInfo provider (identified by its `name` field) in the list of providers."""
    return _find_provider(providers, "name", "CrateInfo")

def _get_dep_info(providers):
    """Finds the DepInfo provider (identified by its `direct_crates` field) in the list of providers."""
    return _find_provider(providers, "direct_crates", "DepInfo")

def _get_cc_info(providers):
    """Finds the CcInfo provider (identified by its `linking_context` field) in the list of providers."""
    return _find_provider(providers, "linking_context", "CcInfo")
115
def _compile_rust(ctx, attr, crate_name, src, deps, edition):
    """Compiles a Rust source file.

    Args:
      ctx (RuleContext): The rule context.
      attr (Attrs): The current rule's attributes (`ctx.attr` for rules, `ctx.rule.attr` for aspects)
      crate_name (str): The crate module name to use.
      src (File): The crate root source file to be compiled.
      deps (List of DepVariantInfo): A list of dependencies needed.
      edition (str): The Rust edition to use.

    Returns:
      A DepVariantInfo provider.
    """
    toolchain = ctx.toolchains["@rules_rust//rust:toolchain_type"]

    # Derived from the generated source's path; used in the artifact names and passed
    # through to rustc_compile_action.
    output_hash = repr(hash(src.path + ".prost"))

    # Artifacts are named lib<crate>-<hash>.<ext>.
    lib = ctx.actions.declare_file("lib{}-{}.rlib".format(crate_name, output_hash))

    # Only declare a metadata output when the toolchain/context can build one for rlibs.
    rmeta = None
    if can_build_metadata(toolchain, ctx, "rlib"):
        rmeta = ctx.actions.declare_file("lib{}-{}.rmeta".format(crate_name, output_hash))

    providers = rustc_compile_action(
        ctx = ctx,
        attr = attr,
        toolchain = toolchain,
        crate_info_dict = dict(
            name = crate_name,
            type = "rlib",
            root = src,
            srcs = depset([src]),
            deps = depset(deps),
            proc_macro_deps = depset([]),
            aliases = {},
            output = lib,
            metadata = rmeta,
            edition = edition,
            is_test = False,
            rustc_env = {},
            compile_data = depset([]),
            compile_data_targets = depset([]),
            owner = ctx.label,
        ),
        output_hash = output_hash,
    )

    return rust_common.dep_variant_info(
        crate_info = _get_crate_info(providers),
        dep_info = _get_dep_info(providers),
        cc_info = _get_cc_info(providers),
        build_info = None,
    )
193
# Characters allowed in a generated crate name; every other character is mapped to `_`.
_CRATE_NAME_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"

def _sanitize_crate_name(name):
    """Converts a target name into a string usable as a Rust crate name.

    Every character outside `[A-Za-z0-9_]` becomes `_`. This is a superset of the old
    `-`/`/` replacement, so names that already worked map to the same crate name, while
    names containing characters such as `.` or `+` (which rustc does not accept in crate
    names) now map to a valid name.
    """
    return "".join([
        name[i] if name[i] in _CRATE_NAME_CHARS else "_"
        for i in range(len(name))
    ])

def _rust_prost_aspect_impl(target, ctx):
    """Generates and compiles Prost (and optionally Tonic) Rust code for a proto target.

    The aspect propagates along `deps`, so each proto dependency is compiled into its own
    crate first and consumed here through its `ProstProtoInfo`.

    Args:
      target: The target the aspect is applied to; must provide `ProtoInfo`.
      ctx: The aspect context.

    Returns:
      `[ProstProtoInfo, <rust-analyzer info>]`, or `[]` when the target already provides
      `ProstProtoInfo`.
    """
    if ProstProtoInfo in target:
        return []

    rustfmt_toolchain = ctx.toolchains["@rules_rust//rust/rustfmt:toolchain_type"]
    prost_toolchain = ctx.toolchains[TOOLCHAIN_TYPE]

    # Runtime crates the generated code depends on; the Tonic runtime is optional.
    runtime_deps = []
    for prost_runtime in [prost_toolchain.prost_runtime, prost_toolchain.tonic_runtime]:
        if not prost_runtime:
            continue
        if rust_common.crate_group_info in prost_runtime:
            crate_group_info = prost_runtime[rust_common.crate_group_info]
            runtime_deps.extend(crate_group_info.dep_variant_infos.to_list())
        else:
            runtime_deps.append(rust_common.dep_variant_info(
                crate_info = prost_runtime[rust_common.crate_info] if rust_common.crate_info in prost_runtime else None,
                dep_info = prost_runtime[rust_common.dep_info] if rust_common.dep_info in prost_runtime else None,
                cc_info = prost_runtime[CcInfo] if CcInfo in prost_runtime else None,
                build_info = None,
            ))

    proto_deps = getattr(ctx.rule.attr, "deps", [])

    direct_deps = []
    transitive_deps = [depset(runtime_deps)]
    rust_analyzer_deps = []
    for proto_dep in proto_deps:
        # Named distinctly from the target's own ProtoInfo below to avoid shadowing.
        dep_prost_info = proto_dep[ProstProtoInfo]

        direct_deps.append(dep_prost_info.dep_variant_info)
        transitive_deps.append(depset(
            [dep_prost_info.dep_variant_info],
            transitive = [dep_prost_info.transitive_dep_infos],
        ))

        if RustAnalyzerInfo in proto_dep:
            rust_analyzer_deps.append(proto_dep[RustAnalyzerInfo])

    crate_name = _sanitize_crate_name(ctx.label.name)

    lib_rs, package_info_file = _compile_proto(
        ctx = ctx,
        crate_name = crate_name,
        proto_info = target[ProtoInfo],
        deps = proto_deps,
        prost_toolchain = prost_toolchain,
        rustfmt_toolchain = rustfmt_toolchain,
    )

    dep_variant_info = _compile_rust(
        ctx = ctx,
        attr = ctx.rule.attr,
        crate_name = crate_name,
        src = lib_rs,
        deps = runtime_deps + direct_deps,
        edition = RUST_EDITION,
    )

    # Always add `test` & `debug_assertions`. See rust-analyzer source code:
    # https://github.com/rust-analyzer/rust-analyzer/blob/2021-11-15/crates/project_model/src/workspace.rs#L529-L531
    cfgs = ["test", "debug_assertions"]

    rust_analyzer_info = write_rust_analyzer_spec_file(ctx, ctx.rule.attr, ctx.label, RustAnalyzerInfo(
        crate = dep_variant_info.crate_info,
        cfgs = cfgs,
        env = dep_variant_info.crate_info.rustc_env,
        deps = rust_analyzer_deps,
        crate_specs = depset(transitive = [dep.crate_specs for dep in rust_analyzer_deps]),
        proc_macro_dylib_path = None,
        build_info = dep_variant_info.build_info,
    ))

    return [
        ProstProtoInfo(
            dep_variant_info = dep_variant_info,
            transitive_dep_infos = depset(transitive = transitive_deps),
            package_info = package_info_file,
        ),
        rust_analyzer_info,
    ]
279
# Aspect-specific private attributes are merged with RUSTC_ATTRS; on a key collision the
# right-hand operand of `|` (RUSTC_ATTRS) wins.
rust_prost_aspect = aspect(
    implementation = _rust_prost_aspect_impl,
    doc = "An aspect used to generate and compile proto files with Prost.",
    attr_aspects = ["deps"],
    attrs = {
        "_prost_process_wrapper": attr.label(
            doc = "The wrapper script for the Prost protoc plugin.",
            default = Label("//proto/prost/private:protoc_wrapper"),
            executable = True,
            cfg = "exec",
        ),
        "_collect_cc_coverage": attr.label(
            default = Label("//util:collect_coverage"),
            executable = True,
            cfg = "exec",
        ),
        "_grep_includes": attr.label(
            default = Label("@bazel_tools//tools/cpp:grep-includes"),
            allow_single_file = True,
            cfg = "exec",
        ),
    } | RUSTC_ATTRS,
    fragments = ["cpp"],
    toolchains = [
        TOOLCHAIN_TYPE,
        "@bazel_tools//tools/cpp:toolchain_type",
        "@rules_rust//rust:toolchain_type",
        "@rules_rust//rust/rustfmt:toolchain_type",
    ],
)
310
def _rust_prost_library_impl(ctx):
    """Re-exports the crate produced by `rust_prost_aspect` for `ctx.attr.proto`.

    Returns:
      `DefaultInfo` with the compiled crate output, a crate group containing the crate
      and all of its transitive dependency variants, and a rust-analyzer group.
    """
    proto_target = ctx.attr.proto
    prost_info = proto_target[ProstProtoInfo]
    variant = prost_info.dep_variant_info

    compiled_crate = depset([variant.crate_info.output])
    all_variants = depset([variant], transitive = [prost_info.transitive_dep_infos])

    return [
        DefaultInfo(files = compiled_crate),
        rust_common.crate_group_info(dep_variant_infos = all_variants),
        RustAnalyzerGroupInfo(deps = [proto_target[RustAnalyzerInfo]]),
    ]
326
# Public rule: applies rust_prost_aspect to `proto` and exposes the resulting crate.
rust_prost_library = rule(
    implementation = _rust_prost_library_impl,
    doc = "A rule for generating a Rust library using Prost.",
    attrs = {
        "proto": attr.label(
            doc = "A `proto_library` target for which to generate Rust gencode.",
            mandatory = True,
            providers = [ProtoInfo],
            aspects = [rust_prost_aspect],
        ),
        "_collect_cc_coverage": attr.label(
            default = Label("@rules_rust//util:collect_coverage"),
            cfg = "exec",
            executable = True,
        ),
    },
)
344
def _rust_prost_toolchain_impl(ctx):
    """Validates the Tonic attributes and exposes the Prost toolchain as `ToolchainInfo`.

    Tonic support is optional: `tonic_plugin` and `tonic_runtime` must either both be set
    or both be omitted.

    Raises:
      fail: if only one of `tonic_plugin` / `tonic_runtime` is set, or if a Tonic plugin is
        set while `tonic_plugin_flag` is empty.
    """
    # `tonic_plugin_flag` always has a non-empty default, so including it in an any()/all()
    # check made `any(...)` unconditionally true and rejected every toolchain without Tonic.
    # Only the plugin and runtime labels signal whether Tonic was opted into.
    has_tonic_plugin = ctx.attr.tonic_plugin != None
    has_tonic_runtime = ctx.attr.tonic_runtime != None
    if has_tonic_plugin != has_tonic_runtime or (has_tonic_plugin and not ctx.attr.tonic_plugin_flag):
        fail("When one tonic attribute is added, all must be added")

    return [platform_common.ToolchainInfo(
        prost_opts = ctx.attr.prost_opts,
        prost_plugin = ctx.attr.prost_plugin,
        prost_plugin_flag = ctx.attr.prost_plugin_flag,
        prost_runtime = ctx.attr.prost_runtime,
        prost_types = ctx.attr.prost_types,
        proto_compiler = ctx.attr.proto_compiler,
        protoc_opts = ctx.fragments.proto.experimental_protoc_opts,
        tonic_opts = ctx.attr.tonic_opts,
        tonic_plugin = ctx.attr.tonic_plugin,
        tonic_plugin_flag = ctx.attr.tonic_plugin_flag,
        tonic_runtime = ctx.attr.tonic_runtime,
    )]
363
# Toolchain rule bundling protoc, the Prost plugin/runtime and the optional Tonic
# plugin/runtime used by rust_prost_aspect.
rust_prost_toolchain = rule(
    implementation = _rust_prost_toolchain_impl,
    doc = "Rust Prost toolchain rule.",
    fragments = ["proto"],
    attrs = {
        "prost_opts": attr.string_list(
            doc = "Additional options to add to Prost.",
        ),
        "prost_plugin": attr.label(
            doc = "The protoc plugin that generates Prost code.",
            cfg = "exec",
            executable = True,
            mandatory = True,
        ),
        "prost_plugin_flag": attr.string(
            doc = "Prost plugin flag format. (e.g. `--plugin=protoc-gen-prost=%s`)",
            default = "--plugin=protoc-gen-prost=%s",
        ),
        "prost_runtime": attr.label(
            doc = "The Prost runtime crates to use.",
            providers = [[rust_common.crate_info], [rust_common.crate_group_info]],
            mandatory = True,
        ),
        "prost_types": attr.label(
            doc = "The Prost types crates to use.",
            providers = [[rust_common.crate_info], [rust_common.crate_group_info]],
            mandatory = True,
        ),
        "proto_compiler": attr.label(
            doc = "The protoc compiler to use.",
            cfg = "exec",
            executable = True,
            mandatory = True,
        ),
        "tonic_opts": attr.string_list(
            doc = "Additional options to add to Tonic.",
        ),
        "tonic_plugin": attr.label(
            doc = "The protoc plugin that generates Tonic code. Must be set together with `tonic_runtime`.",
            cfg = "exec",
            executable = True,
        ),
        "tonic_plugin_flag": attr.string(
            doc = "Tonic plugin flag format. (e.g. `--plugin=protoc-gen-tonic=%s`)",
            default = "--plugin=protoc-gen-tonic=%s",
        ),
        "tonic_runtime": attr.label(
            doc = "The Tonic runtime crates to use. Must be set together with `tonic_plugin`.",
            providers = [[rust_common.crate_info], [rust_common.crate_group_info]],
        ),
    },
)
416
def _current_prost_runtime_impl(ctx):
    """Exposes the Prost toolchain's `prost_runtime` and `prost_types` as one crate group.

    A target that already provides a crate group contributes all of its dependency
    variants; any other target is wrapped in a single `dep_variant_info`.
    """
    prost_toolchain = ctx.toolchains[TOOLCHAIN_TYPE]

    collected = []
    for runtime_target in (prost_toolchain.prost_runtime, prost_toolchain.prost_types):
        if rust_common.crate_group_info in runtime_target:
            group = runtime_target[rust_common.crate_group_info]
            collected.extend(group.dep_variant_infos.to_list())
            continue

        has_crate = rust_common.crate_info in runtime_target
        has_dep = rust_common.dep_info in runtime_target
        has_cc = CcInfo in runtime_target
        collected.append(rust_common.dep_variant_info(
            crate_info = runtime_target[rust_common.crate_info] if has_crate else None,
            dep_info = runtime_target[rust_common.dep_info] if has_dep else None,
            cc_info = runtime_target[CcInfo] if has_cc else None,
            build_info = None,
        ))

    return [rust_common.crate_group_info(dep_variant_infos = depset(collected))]
437
# Resolves the registered Prost toolchain and re-exports its runtime crates.
current_prost_runtime = rule(
    implementation = _current_prost_runtime_impl,
    doc = "A rule for accessing the current Prost toolchain components needed by the process wrapper",
    toolchains = [TOOLCHAIN_TYPE],
    provides = [rust_common.crate_group_info],
)
444