# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""This module generates the code for nanopb-based pw_rpc services."""

import os
from typing import Iterable, Iterator

from pw_protobuf.output_file import OutputFile
from pw_protobuf.proto_tree import ProtoNode, ProtoService, ProtoServiceMethod
from pw_protobuf.proto_tree import build_node_tree
from pw_rpc import codegen
from pw_rpc.codegen import RPC_NAMESPACE
import pw_rpc.ids

# The generated RPC header and the nanopb header share the '.pb.h' suffix; the
# RPC header is distinguished by an additional '.rpc' infix (see
# _proto_filename_to_generated_header).
PROTO_H_EXTENSION = '.pb.h'
PROTO_CC_EXTENSION = '.pb.cc'
NANOPB_H_EXTENSION = '.pb.h'


def _proto_filename_to_nanopb_header(proto_file: str) -> str:
    """Returns the generated nanopb header name for a .proto file.

    For example, 'foo/bar.proto' maps to 'foo/bar.pb.h'.
    """
    return os.path.splitext(proto_file)[0] + NANOPB_H_EXTENSION


def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Returns the generated C++ RPC header name for a .proto file.

    For example, 'foo/bar.proto' maps to 'foo/bar.rpc.pb.h'.
    """
    filename = os.path.splitext(proto_file)[0]
    return f'{filename}.rpc{PROTO_H_EXTENSION}'


def _generate_method_descriptor(method: ProtoServiceMethod, method_id: int,
                                output: OutputFile) -> None:
    """Generates a nanopb method descriptor for an RPC method.

    Emits a GetNanopbOrRawMethodFor<...>(...) expression terminated by a
    comma. Passed to codegen.service_class as the method-descriptor callback.

    Args:
        method: The RPC method to describe.
        method_id: The method's ID, emitted as an 8-digit hex literal.
        output: Destination for the generated C++.
    """

    req_fields = f'{method.request_type().nanopb_name()}_fields'
    res_fields = f'{method.response_type().nanopb_name()}_fields'
    impl_method = f'&Implementation::{method.name()}'

    output.write_line(
        f'{RPC_NAMESPACE}::internal::GetNanopbOrRawMethodFor<{impl_method}, '
        f'{method.type().cc_enum()}, '
        f'{method.request_type().nanopb_name()}, '
        f'{method.response_type().nanopb_name()}>(')
    with output.indent(4):
        output.write_line(f'0x{method_id:08x},  // Hash of "{method.name()}"')
        output.write_line(f'{req_fields},')
        output.write_line(f'{res_fields}),')


def _generate_server_writer_alias(output: OutputFile) -> None:
    """Emits a ServerWriter<T> alias for ::pw::rpc::ServerWriter<T>."""
    output.write_line('template <typename T>')
    output.write_line('using ServerWriter = ::pw::rpc::ServerWriter<T>;')


def _generate_code_for_service(service: ProtoService, root: ProtoNode,
                               output: OutputFile) -> None:
    """Generates a C++ derived class for a nanopb RPC service."""
    codegen.service_class(service, root, output, _generate_server_writer_alias,
                          'NanopbMethodUnion', _generate_method_descriptor)


def _generate_code_for_client_method(method: ProtoServiceMethod,
                                     output: OutputFile) -> None:
    """Outputs client code for a single RPC method.

    Emits a static member function that constructs a NanopbClientCall for the
    method, sends the request, and returns the call object.

    Args:
        method: The RPC method to generate a client function for.
        output: Destination for the generated C++.

    Raises:
        NotImplementedError: If the method is neither unary nor server
            streaming.
    """

    req = method.request_type().nanopb_name()
    res = method.response_type().nanopb_name()
    method_id = pw_rpc.ids.calculate(method.name())

    if method.type() == ProtoServiceMethod.Type.UNARY:
        callback = f'{RPC_NAMESPACE}::UnaryResponseHandler<{res}>'
    elif method.type() == ProtoServiceMethod.Type.SERVER_STREAMING:
        callback = f'{RPC_NAMESPACE}::ServerStreamingResponseHandler<{res}>'
    else:
        raise NotImplementedError(
            'Only unary and server streaming RPCs are currently supported')

    output.write_line()
    # Emit the return type's template argument as its own line through the
    # OutputFile indentation API. Embedding '\n' in a single write_line call
    # would place the continuation at a fixed column, ignoring the current
    # indentation level.
    output.write_line('static NanopbClientCall<')
    with output.indent(4):
        output.write_line(f'{callback}>')
    output.write_line(f'{method.name()}({RPC_NAMESPACE}::Channel& channel,')
    # Align the remaining parameters just past the '(' following the name.
    with output.indent(len(method.name()) + 1):
        output.write_line(f'const {req}& request,')
        output.write_line(f'{callback}& callback) {{')

    with output.indent():
        output.write_line(f'NanopbClientCall<{callback}>')
        output.write_line('    call(&channel,')
        # 9 == len('    call('), aligning arguments after the open paren.
        with output.indent(9):
            output.write_line('kServiceId,')
            output.write_line(
                f'0x{method_id:08x},  // Hash of "{method.name()}"')
            output.write_line('callback,')
            output.write_line(f'{req}_fields,')
            output.write_line(f'{res}_fields);')
        output.write_line('call.SendRequest(&request);')
        output.write_line('return call;')

    output.write_line('}')


def _generate_code_for_client(service: ProtoService, root: ProtoNode,
                              output: OutputFile) -> None:
    """Outputs client code for an RPC service.

    Emits a '<Service>Client' class inside a 'nanopb' namespace. The class has
    a deleted default constructor, one static function per RPC method, and the
    hashed service ID as a private constant.
    """

    output.write_line('namespace nanopb {')

    class_name = f'{service.cpp_namespace(root)}Client'
    # Separate blank lines are written with write_line() rather than embedded
    # '\n' so that OutputFile's indentation is applied to the text lines.
    output.write_line()
    output.write_line(f'class {class_name} {{')
    output.write_line(' public:')

    with output.indent():
        output.write_line('template <typename T>')
        output.write_line(
            f'using NanopbClientCall = {RPC_NAMESPACE}::NanopbClientCall<T>;')

        output.write_line()
        output.write_line(f'{class_name}() = delete;')

        for method in service.methods():
            _generate_code_for_client_method(method, output)

    service_name_hash = pw_rpc.ids.calculate(service.proto_path())
    output.write_line()
    output.write_line(' private:')

    with output.indent():
        output.write_line(f'// Hash of "{service.proto_path()}".')
        output.write_line(
            f'static constexpr uint32_t kServiceId = 0x{service_name_hash:08x};'
        )

    output.write_line('};')

    output.write_line()
    output.write_line('}  // namespace nanopb')
    output.write_line()


def includes(proto_file, unused_package: ProtoNode) -> Iterator[str]:
    """Yields the #include lines needed by a generated nanopb RPC header.

    Args:
        proto_file: The .proto file being processed; only its 'name'
            attribute is read (presumably a FileDescriptorProto -- confirm).
        unused_package: Not used; accepted to match the includes callback
            signature that codegen.package invokes.
    """
    yield '#include "pw_rpc/internal/nanopb_method_union.h"'
    yield '#include "pw_rpc/nanopb_client_call.h"'

    # Include the corresponding nanopb header file for this proto file, in which
    # the file's messages and enums are generated. All other files imported from
    # the .proto file are #included in there.
    nanopb_header = _proto_filename_to_nanopb_header(proto_file.name)
    yield f'#include "{nanopb_header}"'


def _generate_code_for_package(proto_file, package: ProtoNode,
                               output: OutputFile) -> None:
    """Generates code for a header file corresponding to a .proto file."""

    codegen.package(proto_file, package, output, includes,
                    _generate_code_for_service, _generate_code_for_client)


class StubGenerator(codegen.StubGenerator):
    """Generates nanopb-flavored stub implementations of service methods."""
    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns the C++ signature of a unary method implementation."""
        return (f'pw::Status {prefix}{method.name()}(ServerContext&, '
                f'const {method.request_type().nanopb_name()}& request, '
                f'{method.response_type().nanopb_name()}& response)')

    def unary_stub(self, method: ProtoServiceMethod,
                   output: OutputFile) -> None:
        """Emits a unary stub body that returns Status::Unimplemented()."""
        output.write_line(codegen.STUB_REQUEST_TODO)
        # Casting to void silences unused-parameter warnings in the stub.
        output.write_line('static_cast<void>(request);')
        output.write_line(codegen.STUB_RESPONSE_TODO)
        output.write_line('static_cast<void>(response);')
        output.write_line('return pw::Status::Unimplemented();')

    def server_streaming_signature(self, method: ProtoServiceMethod,
                                   prefix: str) -> str:
        """Returns the C++ signature of a server streaming implementation."""
        return (
            f'void {prefix}{method.name()}(ServerContext&, '
            f'const {method.request_type().nanopb_name()}& request, '
            f'ServerWriter<{method.response_type().nanopb_name()}>& writer)')


def process_proto_file(proto_file) -> Iterable[OutputFile]:
    """Generates code for a single .proto file.

    Returns:
        A single-element list holding the generated '<name>.rpc.pb.h' file,
        which contains the service/client code followed by method stubs.
    """

    _, package_root = build_node_tree(proto_file)
    output_filename = _proto_filename_to_generated_header(proto_file.name)
    output_file = OutputFile(output_filename)
    _generate_code_for_package(proto_file, package_root, output_file)

    output_file.write_line()
    codegen.package_stubs(package_root, output_file, StubGenerator())

    return [output_file]