• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Utility functions for generating protobuf code."""
2
3load("@rules_proto//proto:defs.bzl", "ProtoInfo")
4
# File-name suffix every proto source must carry; _strip_proto_extension
# removes it (and fails otherwise) when deriving generated file names.
_PROTO_EXTENSION = ".proto"
# Path component that marks a file as relocated by a proto_library using
# import_prefix / strip_import_prefix (see is_in_virtual_imports). The
# surrounding slashes make it match only a whole directory name.
_VIRTUAL_IMPORTS = "/_virtual_imports/"
7
def well_known_proto_libs():
    """Lists protobuf's bundled well-known proto_library targets.

    Returns:
      A freshly built list of "@com_google_protobuf//:<name>_proto" labels,
      in alphabetical order.
    """
    names = [
        "any",
        "api",
        "compiler_plugin",
        "descriptor",
        "duration",
        "empty",
        "field_mask",
        "source_context",
        "struct",
        "timestamp",
        "type",
        "wrappers",
    ]
    return ["@com_google_protobuf//:{}_proto".format(name) for name in names]
23
def get_proto_root(workspace_root):
    """Gets the root protobuf directory.

    Args:
      workspace_root: context.label.workspace_root

    Returns:
      "/" followed by workspace_root when workspace_root is non-empty, and ""
      otherwise. This is the directory relative to which generated include
      paths should be.
    """
    return "/{}".format(workspace_root) if workspace_root else ""
37
def _strip_proto_extension(proto_filename):
    """Drops the trailing ".proto" from proto_filename.

    Fails the build if proto_filename does not end with ".proto".
    """
    suffix_len = len(_PROTO_EXTENSION)
    if proto_filename.endswith(_PROTO_EXTENSION):
        return proto_filename[:-suffix_len]
    fail('"{}" does not end with "{}"'.format(
        proto_filename,
        _PROTO_EXTENSION,
    ))
45
def proto_path_to_generated_filename(proto_path, fmt_str):
    """Maps a .proto path to the name of a file generated from it.

    The ".proto" suffix is removed and the remainder, directories included,
    is substituted into fmt_str. For example, with fmt_str "{}.pb.h" the path
    "examples/protos/helloworld.proto" becomes
    "examples/protos/helloworld.pb.h".

    Args:
      proto_path: The path to the .proto file; the build fails if it does not
        end with ".proto".
      fmt_str: A format string with one "{}" placeholder that receives the
        extension-less path, e.g. "{}.pb.h" for a C++ header.

    Returns:
      The generated filename.
    """
    stem = _strip_proto_extension(proto_path)
    return fmt_str.format(stem)
61
def get_include_directory(source_file):
    """Returns the include directory path for the source_file. I.e. all of the
    include statements within the given source_file are calculated relative to
    the directory returned by this method.

    The returned directory path can be used as the "--proto_path=" argument
    value.

    Resolution order:
      * Virtual imports (see is_in_virtual_imports): the path up to and
        including the directory that follows the first "/_virtual_imports/".
      * Paths that, after stripping the output root of a generated file,
        start with "external/<repo>/": the path up to "external/<repo>".
      * Otherwise: the file's root path, or "." when that is empty.

    Args:
      source_file: A proto file.

    Returns:
      The include directory path for the source_file.
    """
    path = source_file.path

    if is_in_virtual_imports(source_file):
        # Split on the first marker only: maxsplit=2 would produce three
        # parts (and fail to unpack) for a path containing the marker twice.
        root, relative = path.split(_VIRTUAL_IMPORTS, 1)
        return root + _VIRTUAL_IMPORTS + relative.split("/", 1)[0]

    root_path = source_file.root.path

    # For generated files, look past the output root (e.g. bazel-out/.../bin).
    prefix_len = 0
    if not source_file.is_source and path.startswith(root_path):
        prefix_len = len(root_path) + 1

    # Match "external/" rather than "external" so that top-level directories
    # such as "externals/" are not mistaken for an external repository (which
    # previously made the second find() return -1 and chop off the last
    # character of the path).
    if path.startswith("external/", prefix_len):
        external_separator = path.find("/", prefix_len)
        repository_separator = path.find("/", external_separator + 1)
        return path[:repository_separator]

    return root_path if root_path else "."
93
def get_plugin_args(
        plugin,
        flags,
        dir_out,
        generate_mocks,
        plugin_name = "PLUGIN"):
    """Builds the protoc arguments that register and invoke a plugin.

    Args:
      plugin: An executable file to run as the protoc plugin.
      flags: The plugin flags to be passed to protoc; not modified.
      dir_out: The output directory for the plugin.
      generate_mocks: A bool; when true "generate_mock_code=true" is appended
        to the plugin flags.
      plugin_name: A name of the plugin, it is required to be unique when there
        are more than one plugin used in a single protoc command.
    Returns:
      A two-element list: the --plugin=protoc-gen-<name>=<path> argument and
      the --<name>_out=[<comma-joined flags>:]<dir_out> argument.
    """
    options = list(flags) + (["generate_mock_code=true"] if generate_mocks else [])

    # protoc expects plugin options before the output directory, separated
    # from it by ':'; with no options the directory is used on its own.
    out_value = dir_out
    if options:
        out_value = ",".join(options) + ":" + dir_out

    register_arg = "--plugin=protoc-gen-{plugin_name}={plugin_path}".format(
        plugin_name = plugin_name,
        plugin_path = plugin.path,
    )
    output_arg = "--{plugin_name}_out={dir_out}".format(
        plugin_name = plugin_name,
        dir_out = out_value,
    )
    return [register_arg, output_arg]
130
def _get_staged_proto_file(context, source_file):
    """Returns source_file, or a copy of it declared in the current package.

    The file is used as-is when its directory equals the current target's
    package or when it is a virtual import. Otherwise a file with the same
    basename is declared in the current package and filled by a copy action.

    Args:
      context: The ctx object for the rule.
      source_file: A proto file.

    Returns:
      The File to hand to protoc.
    """
    if (source_file.dirname == context.label.package or
        is_in_virtual_imports(source_file)):
        return source_file

    # Current target and source_file are in different packages (most
    # probably even in different repositories).
    copied_proto = context.actions.declare_file(source_file.basename)
    context.actions.run_shell(
        inputs = [source_file],
        outputs = [copied_proto],
        # Pass the paths as positional arguments instead of formatting them
        # into the command string, so spaces or shell metacharacters in a
        # path are not reinterpreted by the shell.
        arguments = [source_file.path, copied_proto.path],
        command = 'cp "$1" "$2"',
        mnemonic = "CopySourceProto",
    )
    return copied_proto
147
def protos_from_context(context):
    """Collects the direct proto sources of the rule's deps.

    Each source is passed through _get_staged_proto_file, which may register
    a copy action for it.

    Args:
      context: The ctx object for the rule.

    Returns:
      A list of the protos, in dep order and then source order.
    """
    return [
        _get_staged_proto_file(context, proto_file)
        for dep in context.attr.deps
        for proto_file in dep[ProtoInfo].direct_sources
    ]
162
def includes_from_deps(deps):
    """Get includes from rule dependencies.

    Args:
      deps: Targets providing ProtoInfo.

    Returns:
      The flattened transitive_imports of every dep, concatenated in dep
      order. Files shared by several deps appear once per dep.
    """
    includes = []
    for dep in deps:
        includes.extend(dep[ProtoInfo].transitive_imports.to_list())
    return includes
170
def get_proto_arguments(protos, genfiles_dir_path):
    """Get the protoc arguments specifying which protos to compile.

    Each argument is the proto's path with a leading directory removed:
    the include directory for virtual imports, or genfiles_dir_path for
    other files under it. Any other path is passed through unchanged.

    Args:
      protos: The proto files to compile.
      genfiles_dir_path: Path of the rule's genfiles directory.

    Returns:
      A list with one path argument per proto, in input order.
    """
    arguments = []
    for proto in protos:
        strip_prefix_len = 0

        # Prefixes are matched together with the following "/", so a sibling
        # directory that merely shares the prefix (e.g. "bin2" vs "bin") is
        # not stripped, and an empty prefix no longer chops off the first
        # character of the path.
        if is_in_virtual_imports(proto):
            incl_directory = get_include_directory(proto)
            if proto.path.startswith(incl_directory + "/"):
                strip_prefix_len = len(incl_directory) + 1
        elif proto.path.startswith(genfiles_dir_path + "/"):
            strip_prefix_len = len(genfiles_dir_path) + 1

        arguments.append(proto.path[strip_prefix_len:])

    return arguments
186
def declare_out_files(protos, context, generated_file_format):
    """Declares and returns the files to be generated.

    Virtual-import protos keep their path from "_virtual_imports/" onward;
    all other protos contribute only their basename. The ".proto" suffix is
    then replaced via generated_file_format (e.g. "{}.pb.h").

    Args:
      protos: The proto files being compiled.
      context: The ctx object for the rule.
      generated_file_format: Format string with one "{}" placeholder.

    Returns:
      A list of declared Files, one per proto, in input order.
    """
    declared = []
    for proto in protos:
        if is_in_virtual_imports(proto):
            # Skip the leading "/" of the marker so the result is relative.
            start = proto.path.index(_VIRTUAL_IMPORTS) + 1
            relative_path = proto.path[start:]
        else:
            relative_path = proto.basename
        generated_name = proto_path_to_generated_filename(
            relative_path,
            generated_file_format,
        )
        declared.append(context.actions.declare_file(generated_name))
    return declared
207
def get_out_dir(protos, context):
    """ Returns the calculated value for --<lang>_out= protoc argument based on
    the input source proto files and current context.

    Args:
        protos: A list of protos to be used as source files in protoc command.
            Either all of them or none of them may be virtual imports; a mix
            fails the build.
        context: A ctx object for the rule.
    Returns:
        A struct whose `path` is the value of the --<lang>_out= argument and
        whose `import_path` is `path` from the "_virtual_imports" directory
        onward (None when no proto is a virtual import).
    """
    virtual_count = len([p for p in protos if is_in_virtual_imports(p)])
    if virtual_count == 0:
        return struct(path = context.genfiles_dir.path, import_path = None)

    # Compare against the total so a mix is rejected regardless of order; the
    # previous loop only caught a real proto listed after a virtual one.
    if virtual_count != len(protos):
        fail("Proto sources must be either all virtual imports or all real")

    # NOTE(review): only protos[0]'s include directory is used, which assumes
    # every virtual proto in the list shares it -- confirm with callers.
    out_dir = get_include_directory(protos[0])
    ws_root = protos[0].owner.workspace_root
    if ws_root and ws_root in out_dir:
        # Remove the last occurrence of the repository root from the path.
        out_dir = "".join(out_dir.rsplit(ws_root, 1))
    return struct(
        path = out_dir,
        import_path = out_dir[out_dir.find(_VIRTUAL_IMPORTS) + 1:],
    )
234
def is_in_virtual_imports(source_file, virtual_folder = _VIRTUAL_IMPORTS):
    """Reports whether source_file lives under a _virtual_imports directory.

    proto_library targets that set import_prefix and/or strip_import_prefix
    place their output under a _virtual_imports directory. Source
    (non-generated) files are never considered virtual.

    Args:
        source_file: A proto file.
        virtual_folder: The path fragment to look for; defaults to
            "/_virtual_imports/".
    Returns:
        True if source_file is generated and its path contains
        virtual_folder, False otherwise.
    """
    if source_file.is_source:
        return False
    return virtual_folder in source_file.path
249