# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""This module generates the code for pw_protobuf pw_rpc services."""

import os
from typing import Iterable

from pw_protobuf.output_file import OutputFile
from pw_protobuf.proto_tree import ProtoServiceMethod
from pw_protobuf.proto_tree import build_node_tree
from pw_rpc import codegen
from pw_rpc.codegen import (
    client_call_type,
    get_id,
    CodeGenerator,
    RPC_NAMESPACE,
)

# Extension of the pw_protobuf-generated header for a .proto file. Both the
# message header (<name>.pwpb.h) and the RPC header (<name>.rpc.pwpb.h) use it.
PWPB_H_EXTENSION = '.pwpb.h'

# Kept as a public name for backward compatibility; it has always had the same
# value as PWPB_H_EXTENSION in this module, so it is now an alias for it.
PROTO_H_EXTENSION = PWPB_H_EXTENSION


def _strip_extension(proto_file: str) -> str:
    """Returns proto_file with its final extension (e.g. '.proto') removed."""
    return os.path.splitext(proto_file)[0]


def _proto_filename_to_pwpb_header(proto_file: str) -> str:
    """Returns the generated pwpb header name for a .proto file.

    For example, 'pw_rpc/echo.proto' maps to 'pw_rpc/echo.pwpb.h'.
    """
    return f'{_strip_extension(proto_file)}{PWPB_H_EXTENSION}'


def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Returns the generated C++ RPC header name for a .proto file.

    For example, 'pw_rpc/echo.proto' maps to 'pw_rpc/echo.rpc.pwpb.h'.
    """
    return f'{_strip_extension(proto_file)}.rpc{PWPB_H_EXTENSION}'


def _serde(method: ProtoServiceMethod) -> str:
    """Returns the PwpbMethodSerde for this method.

    The result is a C++ expression naming kPwpbMethodSerde, templated on the
    addresses of the request and response ``pwpb_table()`` symbols.
    """
    return (
        f'{RPC_NAMESPACE}::internal::kPwpbMethodSerde<'
        f'&{method.request_type().pwpb_table()}, '
        f'&{method.response_type().pwpb_table()}>'
    )


def _client_call(method: ProtoServiceMethod) -> str:
    """Returns the C++ call type returned by this method's client function.

    The base type name comes from ``client_call_type(method, "Pwpb")``. Its
    template arguments are the request struct (client-streaming methods only)
    followed by the response struct.
    """
    template_args = []

    if method.client_streaming():
        template_args.append(method.request_type().pwpb_struct())

    template_args.append(method.response_type().pwpb_struct())

    return f'{client_call_type(method, "Pwpb")}<{", ".join(template_args)}>'


def _function(method: ProtoServiceMethod) -> str:
    """Returns '<call type> <method name>', the head of a client function."""
    return f'{_client_call(method)} {method.name()}'


def _user_args(method: ProtoServiceMethod) -> Iterable[str]:
    """Yields the C++ parameter declarations a caller supplies for an RPC.

    The parameters are, in order:
      - the request, by const reference (non-client-streaming methods only);
      - for server-streaming methods, separate on_next(response) and
        on_completed(status) callbacks;
      - for other methods, one on_completed(response, status) callback;
      - an on_error(status) callback.

    Every callback parameter defaults to nullptr.
    """
    if not method.client_streaming():
        yield f'const {method.request_type().pwpb_struct()}& request'

    response = method.response_type().pwpb_struct()

    if method.server_streaming():
        yield f'::pw::Function<void(const {response}&)>&& on_next = nullptr'
        yield '::pw::Function<void(::pw::Status)>&& on_completed = nullptr'
    else:
        yield (
            f'::pw::Function<void(const {response}&, ::pw::Status)>&& '
            'on_completed = nullptr'
        )

    yield '::pw::Function<void(::pw::Status)>&& on_error = nullptr'


class PwpbCodeGenerator(CodeGenerator):
    """Generates an RPC service and client using the pw_protobuf API."""

    def name(self) -> str:
        """Returns this generator's identifier, 'pwpb'."""
        return 'pwpb'

    def method_union_name(self) -> str:
        """Returns the name of the C++ method union type for pwpb methods."""
        return 'PwpbMethodUnion'

    def includes(self, proto_file_name: str) -> Iterable[str]:
        """Yields the #include lines needed by the generated RPC header."""
        yield '#include "pw_rpc/pwpb/client_reader_writer.h"'
        yield '#include "pw_rpc/pwpb/internal/method_union.h"'
        yield '#include "pw_rpc/pwpb/server_reader_writer.h"'

        # Include the corresponding pwpb header file for this proto file, in
        # which the file's messages and enums are generated. All other files
        # imported from the .proto file are #included in there.
        pwpb_header = _proto_filename_to_pwpb_header(proto_file_name)
        yield f'#include "{pwpb_header}"'

    def service_aliases(self) -> None:
        """Emits ServerWriter/ServerReader/ServerReaderWriter type aliases.

        Each alias maps to the corresponding Pwpb* server call type in the
        RPC namespace, so generated signatures can use the short names.
        """
        self.line('template <typename Response>')
        self.line(
            'using ServerWriter = '
            f'{RPC_NAMESPACE}::PwpbServerWriter<Response>;'
        )
        self.line('template <typename Request, typename Response>')
        self.line(
            'using ServerReader = '
            f'{RPC_NAMESPACE}::PwpbServerReader<Request, Response>;'
        )
        self.line('template <typename Request, typename Response>')
        self.line(
            'using ServerReaderWriter = '
            f'{RPC_NAMESPACE}::PwpbServerReaderWriter<Request, Response>;'
        )

    def method_descriptor(self, method: ProtoServiceMethod) -> None:
        """Emits the service method-table entry for one RPC method.

        The entry is a GetPwpbOrRawMethodFor<...>(id, serde) call templated on
        the implementation's member function, the method type, and the
        request/response structs.
        """
        impl_method = f'&Implementation::{method.name()}'

        self.line(
            f'{RPC_NAMESPACE}::internal::GetPwpbOrRawMethodFor<{impl_method}, '
            f'{method.type().cc_enum()}, '
            f'{method.request_type().pwpb_struct()}, '
            f'{method.response_type().pwpb_struct()}>('
        )
        # The call's arguments are emitted as a continuation of the line above.
        with self.indent(4):
            self.line(f'{get_id(method)}, // Hash of "{method.name()}"')
            self.line(f'{_serde(method)}),')

    def client_member_function(self, method: ProtoServiceMethod) -> None:
        """Outputs client code for a single RPC method.

        The generated const member function forwards the caller's arguments
        to Pwpb{Unary,Stream}ResponseClientCall<Response>::Start<CallType>().
        """

        self.line(f'{_function(method)}(')
        self.indented_list(*_user_args(method), end=') const {')

        with self.indent():
            client_call = _client_call(method)
            # Server-streaming methods (including bidirectional ones) use the
            # Stream call base; all others use the Unary one.
            base = 'Stream' if method.server_streaming() else 'Unary'
            self.line(
                f'return {RPC_NAMESPACE}::internal::'
                f'Pwpb{base}ResponseClientCall<'
                f'{method.response_type().pwpb_struct()}>::'
                f'Start<{client_call}>('
            )

            service_client = RPC_NAMESPACE + '::internal::ServiceClient'

            # Positional arguments to Start(): routing info and serde first,
            # then the callbacks, then the request (if the caller sent one).
            args = [
                f'{service_client}::client()',
                f'{service_client}::channel_id()',
                'kServiceId',
                get_id(method),
                _serde(method),
            ]
            if method.server_streaming():
                args.append('std::move(on_next)')

            args.append('std::move(on_completed)')
            args.append('std::move(on_error)')

            if not method.client_streaming():
                args.append('request')

            self.indented_list(*args, end=');')

        self.line('}')

    def client_static_function(self, method: ProtoServiceMethod) -> None:
        """Emits a static client function that takes an explicit client.

        The generated function builds Client(client, channel_id) and forwards
        the remaining user arguments to the member function of the same name.
        """
        self.line(f'static {_function(method)}(')
        self.indented_list(
            f'{RPC_NAMESPACE}::Client& client',
            'uint32_t channel_id',
            *_user_args(method),
            end=') {',
        )

        with self.indent():
            self.line(f'return Client(client, channel_id).{method.name()}(')

            # Forwarded arguments follow the same order as _user_args().
            args = []

            if not method.client_streaming():
                args.append('request')

            if method.server_streaming():
                args.append('std::move(on_next)')

            self.indented_list(
                *args,
                'std::move(on_completed)',
                'std::move(on_error)',
                end=');',
            )

        self.line('}')

    def method_info_specialization(self, method: ProtoServiceMethod) -> None:
        """Emits the pwpb-specific members of a method info specialization.

        These are Request/Response type aliases and a static constexpr
        serde() accessor returning this method's PwpbMethodSerde.
        """
        self.line()
        self.line(f'using Request = {method.request_type().pwpb_struct()};')
        self.line(f'using Response = {method.response_type().pwpb_struct()};')
        self.line()
        self.line(
            f'static constexpr const {RPC_NAMESPACE}::'
            'PwpbMethodSerde& serde() {'
        )
        with self.indent():
            self.line(f'return {_serde(method)};')
        self.line('}')


class StubGenerator(codegen.StubGenerator):
    """Generates pw_protobuf RPC stubs."""

    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns the C++ signature of a unary service method.

        The method takes the request by const reference, fills in the
        response by reference, and returns a ::pw::Status.
        """
        return (
            f'::pw::Status {prefix}{method.name()}( '
            f'const {method.request_type().pwpb_struct()}& request, '
            f'{method.response_type().pwpb_struct()}& response)'
        )

    def unary_stub(
        self, method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes a placeholder unary method body that returns Unimplemented.

        The static_casts silence unused-parameter warnings in the stub.
        """
        output.write_line(codegen.STUB_REQUEST_TODO)
        output.write_line('static_cast<void>(request);')
        output.write_line(codegen.STUB_RESPONSE_TODO)
        output.write_line('static_cast<void>(response);')
        output.write_line('return ::pw::Status::Unimplemented();')

    def server_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a server-streaming service method."""
        return (
            f'void {prefix}{method.name()}( '
            f'const {method.request_type().pwpb_struct()}& request, '
            f'ServerWriter<{method.response_type().pwpb_struct()}>& writer)'
        )

    def client_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a client-streaming service method."""
        return (
            f'void {prefix}{method.name()}( '
            f'ServerReader<{method.request_type().pwpb_struct()}, '
            f'{method.response_type().pwpb_struct()}>& reader)'
        )

    def bidirectional_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a bidirectional streaming method."""
        return (
            f'void {prefix}{method.name()}( '
            f'ServerReaderWriter<{method.request_type().pwpb_struct()}, '
            f'{method.response_type().pwpb_struct()}>& reader_writer)'
        )


def process_proto_file(proto_file) -> Iterable[OutputFile]:
    """Generates code for a single .proto file.

    Args:
        proto_file: The parsed .proto file. Its ``name`` attribute is used to
            derive the output header name. (Presumably a protobuf
            FileDescriptorProto; confirm against callers.)

    Returns:
        A single-element list holding the generator's output file. Both
        codegen.generate_package and codegen.package_stubs are run through
        the same generator.
    """

    _, package_root = build_node_tree(proto_file)
    output_filename = _proto_filename_to_generated_header(proto_file.name)

    generator = PwpbCodeGenerator(output_filename)
    codegen.generate_package(proto_file, package_root, generator)

    codegen.package_stubs(package_root, generator, StubGenerator())

    return [generator.output]