"""Utility functions for generating protobuf code."""

load("@rules_proto//proto:defs.bzl", "ProtoInfo")

_PROTO_EXTENSION = ".proto"
_VIRTUAL_IMPORTS = "/_virtual_imports/"

def well_known_proto_libs():
    """Returns the labels of the protobuf well-known-type proto libraries.

    Returns:
      A list of label strings in the @com_google_protobuf repository.
    """
    return [
        "@com_google_protobuf//:any_proto",
        "@com_google_protobuf//:api_proto",
        "@com_google_protobuf//:compiler_plugin_proto",
        "@com_google_protobuf//:descriptor_proto",
        "@com_google_protobuf//:duration_proto",
        "@com_google_protobuf//:empty_proto",
        "@com_google_protobuf//:field_mask_proto",
        "@com_google_protobuf//:source_context_proto",
        "@com_google_protobuf//:struct_proto",
        "@com_google_protobuf//:timestamp_proto",
        "@com_google_protobuf//:type_proto",
        "@com_google_protobuf//:wrappers_proto",
    ]

def get_proto_root(workspace_root):
    """Gets the root protobuf directory.

    Args:
      workspace_root: context.label.workspace_root

    Returns:
      The directory relative to which generated include paths should be:
      "/" + workspace_root when workspace_root is non-empty, otherwise "".
    """
    if workspace_root:
        return "/{}".format(workspace_root)
    else:
        return ""

def _strip_proto_extension(proto_filename):
    """Returns proto_filename without its ".proto" suffix; fails otherwise."""
    if not proto_filename.endswith(_PROTO_EXTENSION):
        fail('"{}" does not end with "{}"'.format(
            proto_filename,
            _PROTO_EXTENSION,
        ))
    return proto_filename[:-len(_PROTO_EXTENSION)]

def proto_path_to_generated_filename(proto_path, fmt_str):
    """Calculates the name of a generated file for a protobuf path.

    For example, "examples/protos/helloworld.proto" might map to
    "helloworld.pb.h".

    Args:
      proto_path: The path to the .proto file.
      fmt_str: A format string used to calculate the generated filename. For
        example, "{}.pb.h" might be used to calculate a C++ header filename.

    Returns:
      The generated filename.
    """
    return fmt_str.format(_strip_proto_extension(proto_path))

def get_include_directory(source_file):
    """Returns the include directory path for the source_file.

    All of the include statements within the given source_file are
    calculated relative to the directory returned by this method. The
    returned directory path can be used as the "--proto_path=" argument
    value.

    Args:
      source_file: A proto file.

    Returns:
      The include directory path for the source_file.
    """
    directory = source_file.path
    prefix_len = 0

    if is_in_virtual_imports(source_file):
        # Split only at the first marker: a maxsplit of 2 yields three parts
        # (breaking this two-name unpacking) if the marker occurs twice.
        root, relative = directory.split(_VIRTUAL_IMPORTS, 1)

        # The include root is the first directory below _virtual_imports/.
        return root + _VIRTUAL_IMPORTS + relative.split("/", 1)[0]

    # For generated files, skip the output root prefix so the check below
    # looks at the part of the path after it.
    if not source_file.is_source and directory.startswith(source_file.root.path):
        prefix_len = len(source_file.root.path) + 1

    # Files under external/<repo>/ use that directory as their include root.
    # Match "external/" (with the slash) so that a top-level directory that
    # merely starts with "external" is not mistaken for it.
    if directory.startswith("external/", prefix_len):
        external_separator = directory.find("/", prefix_len)
        repository_separator = directory.find("/", external_separator + 1)
        return directory[:repository_separator]
    else:
        return source_file.root.path if source_file.root.path else "."

def get_plugin_args(
        plugin,
        flags,
        dir_out,
        generate_mocks,
        plugin_name = "PLUGIN"):
    """Returns arguments configuring protoc to use a plugin for a language.

    Args:
      plugin: An executable file to run as the protoc plugin.
      flags: The plugin flags to be passed to protoc.
      dir_out: The output directory for the plugin.
      generate_mocks: A bool indicating whether to generate mocks.
      plugin_name: The name of the plugin; it must be unique when more than
        one plugin is used in a single protoc command.

    Returns:
      A list of protoc arguments configuring the plugin.
    """
    augmented_flags = list(flags)
    if generate_mocks:
        augmented_flags.append("generate_mock_code=true")

    # protoc accepts plugin parameters as "<comma-separated flags>:<out dir>".
    augmented_dir_out = dir_out
    if augmented_flags:
        augmented_dir_out = ",".join(augmented_flags) + ":" + dir_out

    return [
        "--plugin=protoc-gen-{plugin_name}={plugin_path}".format(
            plugin_name = plugin_name,
            plugin_path = plugin.path,
        ),
        "--{plugin_name}_out={dir_out}".format(
            plugin_name = plugin_name,
            dir_out = augmented_dir_out,
        ),
    ]

def _get_staged_proto_file(context, source_file):
    """Returns source_file, or a copy of it staged in the current package.

    Args:
      context: The ctx object for the rule.
      source_file: A proto file from one of the rule's deps.

    Returns:
      source_file itself if it is in the current package or is a virtual
      import; otherwise a newly declared copy with the same basename.
    """
    if (source_file.dirname == context.label.package or
        is_in_virtual_imports(source_file)):
        # Current target and source_file are in the same package.
        return source_file

    # Current target and source_file are in different packages (most
    # probably even in different repositories).
    copied_proto = context.actions.declare_file(source_file.basename)
    context.actions.run_shell(
        inputs = [source_file],
        outputs = [copied_proto],
        # Pass the paths as positional arguments ($1, $2) and quote them in
        # the script instead of formatting them into the command string, so
        # spaces or shell metacharacters in a path cannot break the command.
        arguments = [source_file.path, copied_proto.path],
        command = 'cp "$1" "$2"',
        mnemonic = "CopySourceProto",
    )
    return copied_proto

def protos_from_context(context):
    """Copies proto files to the appropriate location.

    Args:
      context: The ctx object for the rule.

    Returns:
      A list of the protos: the direct sources of every dep's ProtoInfo, in
      dep order, each staged via _get_staged_proto_file.
    """
    return [
        _get_staged_proto_file(context, file)
        for src in context.attr.deps
        for file in src[ProtoInfo].direct_sources
    ]

def includes_from_deps(deps):
    """Get includes from rule dependencies.

    Args:
      deps: Targets providing ProtoInfo.

    Returns:
      A flat list of every dep's transitive_imports files, concatenated in
      dep order. Files shared by several deps may appear more than once.
    """
    return [
        file
        for src in deps
        for file in src[ProtoInfo].transitive_imports.to_list()
    ]

def get_proto_arguments(protos, genfiles_dir_path):
    """Get the protoc arguments specifying which protos to compile.

    Args:
      protos: The proto files to compile.
      genfiles_dir_path: The path of the genfiles output directory.

    Returns:
      A list of proto paths. Virtual-import protos are made relative to their
      include directory; other protos under genfiles_dir_path are made
      relative to it; all others are passed with their full path.
    """
    arguments = []
    for proto in protos:
        strip_prefix_len = 0
        if is_in_virtual_imports(proto):
            incl_directory = get_include_directory(proto)
            if proto.path.startswith(incl_directory):
                strip_prefix_len = len(incl_directory) + 1
        elif proto.path.startswith(genfiles_dir_path):
            strip_prefix_len = len(genfiles_dir_path) + 1

        arguments.append(proto.path[strip_prefix_len:])

    return arguments

def declare_out_files(protos, context, generated_file_format):
    """Declares and returns the files to be generated.

    Args:
      protos: The proto files being compiled.
      context: The ctx object for the rule.
      generated_file_format: A format string applied to each proto path with
        its ".proto" suffix stripped, e.g. "{}.pb.h".

    Returns:
      A list of declared output Files, one per proto. Non-virtual protos are
      named from their basename; virtual-import protos from their path
      starting at "_virtual_imports/".
    """
    out_file_paths = []
    for proto in protos:
        if not is_in_virtual_imports(proto):
            out_file_paths.append(proto.basename)
        else:
            path = proto.path[proto.path.index(_VIRTUAL_IMPORTS) + 1:]
            out_file_paths.append(path)

    return [
        context.actions.declare_file(
            proto_path_to_generated_filename(
                out_file_path,
                generated_file_format,
            ),
        )
        for out_file_path in out_file_paths
    ]

def get_out_dir(protos, context):
    """Returns the value for the --<lang>_out= protoc argument.

    The value is calculated from the input source proto files and the
    current context.

    Args:
      protos: A list of protos to be used as source files in protoc command.
      context: A ctx object for the rule.

    Returns:
      A struct with fields `path` (the value of the --<lang>_out= argument)
      and `import_path` (the path starting at "_virtual_imports/" when the
      protos are virtual imports, otherwise None).
    """
    virtual_count = len([proto for proto in protos if is_in_virtual_imports(proto)])
    if virtual_count == 0:
        return struct(path = context.genfiles_dir.path, import_path = None)

    # Reject any mix regardless of order; previously a real proto followed by
    # a virtual one slipped through and produced an out dir derived from the
    # real proto.
    if virtual_count != len(protos):
        fail("Proto sources must be either all virtual imports or all real")

    out_dir = get_include_directory(protos[0])
    ws_root = protos[0].owner.workspace_root
    if ws_root and out_dir.find(ws_root) >= 0:
        # Remove the last occurrence of the workspace root from the path.
        out_dir = "".join(out_dir.rsplit(ws_root, 1))
    return struct(
        path = out_dir,
        import_path = out_dir[out_dir.find(_VIRTUAL_IMPORTS) + 1:],
    )

def is_in_virtual_imports(source_file, virtual_folder = _VIRTUAL_IMPORTS):
    """Determines if source_file is virtual.

    A file is virtual when it is placed in a _virtual_imports subdirectory.
    The output of all proto_library targets which use import_prefix and/or
    strip_import_prefix arguments is placed under the _virtual_imports
    directory.

    Args:
      source_file: A proto file.
      virtual_folder: The virtual folder name (is set to "_virtual_imports"
        by default).

    Returns:
      True if source_file is a generated (non-source) file whose path
      contains virtual_folder, False otherwise.
    """
    return not source_file.is_source and virtual_folder in source_file.path