• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""This module generates the code for nanopb-based pw_rpc services."""
15
16import os
17from typing import Iterable, NamedTuple, Optional
18
19from pw_protobuf.output_file import OutputFile
20from pw_protobuf.proto_tree import ProtoServiceMethod
21from pw_protobuf.proto_tree import build_node_tree
22from pw_rpc import codegen
23from pw_rpc.codegen import (client_call_type, get_id, CodeGenerator,
24                            RPC_NAMESPACE)
25
# Suffix used for generated C++ headers; the RPC header for a .proto file is
# named "<stem>.rpc" followed by this suffix.
PROTO_H_EXTENSION = '.pb.h'
# Suffix for generated C++ sources (not referenced in this part of the file).
PROTO_CC_EXTENSION = '.pb.cc'
# Suffix of the nanopb-generated header that holds a .proto file's messages.
NANOPB_H_EXTENSION = '.pb.h'
29
30
def _serde(method: ProtoServiceMethod) -> str:
    """Returns the NanopbMethodSerde for this method.

    The result is a C++ expression naming kNanopbMethodSerde, instantiated
    with the nanopb field descriptors of the request and response types.
    """
    request_fields = method.request_type().nanopb_fields()
    response_fields = method.response_type().nanopb_fields()
    return (f'{RPC_NAMESPACE}::internal::'
            f'kNanopbMethodSerde<{request_fields}, {response_fields}>')
36
37
def _proto_filename_to_nanopb_header(proto_file: str) -> str:
    """Returns the generated nanopb header name for a .proto file.

    The .proto extension is replaced by NANOPB_H_EXTENSION; any directory
    components of proto_file are kept.
    """
    stem, _ = os.path.splitext(proto_file)
    return stem + NANOPB_H_EXTENSION
41
42
def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Returns the generated C++ RPC header name for a .proto file.

    For example, "foo/bar.proto" maps to "foo/bar.rpc" + PROTO_H_EXTENSION.
    """
    stem, _ = os.path.splitext(proto_file)
    return '{}.rpc{}'.format(stem, PROTO_H_EXTENSION)
47
48
def _client_call(method: ProtoServiceMethod) -> str:
    """Returns the C++ client call type for a method.

    The Nanopb client call class (from client_call_type) is templated on the
    response struct, preceded by the request struct for client-streaming
    methods.
    """
    if method.client_streaming():
        type_list = (f'{method.request_type().nanopb_struct()}, '
                     f'{method.response_type().nanopb_struct()}')
    else:
        type_list = method.response_type().nanopb_struct()

    call_type = client_call_type(method, 'Nanopb')
    return f'{call_type}<{type_list}>'
58
59
def _function(method: ProtoServiceMethod) -> str:
    """Returns the return type and name of a client RPC function."""
    return '{} {}'.format(_client_call(method), method.name())
62
63
def _user_args(method: ProtoServiceMethod) -> Iterable[str]:
    """Yields the C++ parameters a caller passes to a client RPC function.

    Unless the method is client streaming, the first parameter is the request
    struct. The callbacks follow, each defaulting to nullptr:
      - server streaming: on_next(response) and on_completed(status);
      - otherwise: a single on_completed(response, status);
      - always last: on_error(status).
    """
    status_callback = '::pw::Function<void(::pw::Status)>&& {} = nullptr'

    if not method.client_streaming():
        yield f'const {method.request_type().nanopb_struct()}& request'

    response_struct = method.response_type().nanopb_struct()

    if method.server_streaming():
        yield (f'::pw::Function<void(const {response_struct}&)>&& '
               'on_next = nullptr')
        yield status_callback.format('on_completed')
    else:
        yield (f'::pw::Function<void(const {response_struct}&, '
               '::pw::Status)>&& on_completed = nullptr')

    yield status_callback.format('on_error')
78
79
class NanopbCodeGenerator(CodeGenerator):
    """Generates an RPC service and client using the Nanopb API.

    Each hook emits C++ text through helpers inherited from CodeGenerator
    (line, indent, indented_list), which are defined outside this file.
    """
    def name(self) -> str:
        """Returns the identifier of this code generator."""
        return 'nanopb'

    def method_union_name(self) -> str:
        """Returns the name of the C++ method union type for nanopb."""
        return 'NanopbMethodUnion'

    def includes(self, proto_file_name: str) -> Iterable[str]:
        """Yields the #include lines needed by the generated header.

        The pw_rpc nanopb headers come first. The last include is the nanopb
        header generated for proto_file_name, which declares the file's
        messages and enums. Headers for files imported by the .proto are
        expected to be #included from that header.
        """
        for rpc_header in ('client_reader_writer.h',
                           'internal/method_union.h',
                           'server_reader_writer.h'):
            yield f'#include "pw_rpc/nanopb/{rpc_header}"'

        yield ('#include '
               f'"{_proto_filename_to_nanopb_header(proto_file_name)}"')

    def service_aliases(self) -> None:
        """Emits ServerWriter/ServerReader/ServerReaderWriter aliases."""
        aliases = (
            ('typename Response', 'ServerWriter', 'Response'),
            ('typename Request, typename Response', 'ServerReader',
             'Request, Response'),
            ('typename Request, typename Response', 'ServerReaderWriter',
             'Request, Response'),
        )
        for template_params, alias, type_args in aliases:
            self.line(f'template <{template_params}>')
            # Each alias maps onto the Nanopb-prefixed class of the same name.
            self.line(f'using {alias} = '
                      f'{RPC_NAMESPACE}::Nanopb{alias}<{type_args}>;')

    def method_descriptor(self, method: ProtoServiceMethod) -> None:
        """Emits the GetNanopbOrRawMethodFor entry for one service method."""
        template_args = (f'&Implementation::{method.name()}, '
                         f'{method.type().cc_enum()}, '
                         f'{method.request_type().nanopb_struct()}, '
                         f'{method.response_type().nanopb_struct()}')
        self.line(f'{RPC_NAMESPACE}::internal::'
                  f'GetNanopbOrRawMethodFor<{template_args}>(')
        with self.indent(4):
            # The generated comment records which name the method ID is for.
            self.line(f'{get_id(method)},  // Hash of "{method.name()}"')
            self.line(f'{_serde(method)}),')

    def client_member_function(self, method: ProtoServiceMethod) -> None:
        """Outputs client code for a single RPC method.

        The emitted const member function returns the result of the Nanopb
        client call's Start<>() function. It forwards the service client's
        client() and channel_id(), the service and method IDs, the serde,
        the callbacks, and (for non-client-streaming methods) the request.
        """
        server_streaming = method.server_streaming()

        self.line(f'{_function(method)}(')
        self.indented_list(*_user_args(method), end=') const {')

        with self.indent():
            call_base = 'Stream' if server_streaming else 'Unary'
            response_struct = method.response_type().nanopb_struct()
            self.line(f'return {RPC_NAMESPACE}::internal::'
                      f'Nanopb{call_base}ResponseClientCall<'
                      f'{response_struct}>::'
                      f'Start<{_client_call(method)}>(')

            service_client = f'{RPC_NAMESPACE}::internal::ServiceClient'

            # The callbacks match the parameter names produced by _user_args;
            # the request, when present, is passed last.
            callbacks = ['std::move(on_next)'] if server_streaming else []
            callbacks.extend(
                ('std::move(on_completed)', 'std::move(on_error)'))
            request = [] if method.client_streaming() else ['request']

            self.indented_list(f'{service_client}::client()',
                               f'{service_client}::channel_id()',
                               'kServiceId',
                               get_id(method),
                               _serde(method),
                               *callbacks,
                               *request,
                               end=');')

        self.line('}')

    def client_static_function(self, method: ProtoServiceMethod) -> None:
        """Emits a static helper that forwards to the member function.

        The helper takes a Client and channel ID in addition to the user
        arguments. It constructs a Client and calls the same-named member
        function with the user arguments.
        """
        self.line(f'static {_function(method)}(')
        self.indented_list(f'{RPC_NAMESPACE}::Client& client',
                           'uint32_t channel_id',
                           *_user_args(method),
                           end=') {')

        with self.indent():
            self.line(f'return Client(client, channel_id).{method.name()}(')

            forwarded = [
                argument for argument, wanted in (
                    ('request', not method.client_streaming()),
                    ('std::move(on_next)', method.server_streaming()),
                    ('std::move(on_completed)', True),
                    ('std::move(on_error)', True),
                ) if wanted
            ]
            self.indented_list(*forwarded, end=');')

        self.line('}')

    def method_info_specialization(self, method: ProtoServiceMethod) -> None:
        """Emits Request/Response aliases and a serde() accessor."""
        request_struct = method.request_type().nanopb_struct()
        response_struct = method.response_type().nanopb_struct()

        self.line()
        self.line(f'using Request = {request_struct};')
        self.line(f'using Response = {response_struct};')
        self.line()
        self.line(f'static constexpr const {RPC_NAMESPACE}::internal::'
                  'NanopbMethodSerde& serde() {')
        with self.indent():
            self.line(f'return {_serde(method)};')
        self.line('}')
193
194
class _CallbackFunction(NamedTuple):
    """Represents a callback function parameter in a client RPC call.

    str() renders it as a ::pw::Function rvalue-reference parameter. A
    " = <default>" suffix is appended only when default_value is non-empty.
    """

    function_type: str  # Signature inside ::pw::Function<...>.
    name: str  # C++ parameter name.
    default_value: Optional[str] = None  # Default argument text, if any.

    def __str__(self):
        declaration = f'::pw::Function<{self.function_type}>&& {self.name}'
        return (f'{declaration} = {self.default_value}'
                if self.default_value else declaration)
207
208
class StubGenerator(codegen.StubGenerator):
    """Generates Nanopb RPC stubs.

    Each *_signature method returns the C++ signature text for one RPC kind,
    using the request/response nanopb structs. The space after "(" is part of
    the emitted text.
    """
    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns a unary signature returning ::pw::Status."""
        name = method.name()
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        return (f'::pw::Status {prefix}{name}( '
                f'const {request}& request, {response}& response)')

    def unary_stub(self, method: ProtoServiceMethod,
                   output: OutputFile) -> None:
        """Writes a placeholder unary body that returns Unimplemented."""
        stub_body = (
            codegen.STUB_REQUEST_TODO,
            'static_cast<void>(request);',
            codegen.STUB_RESPONSE_TODO,
            'static_cast<void>(response);',
            'return ::pw::Status::Unimplemented();',
        )
        for stub_line in stub_body:
            output.write_line(stub_line)

    def server_streaming_signature(self, method: ProtoServiceMethod,
                                   prefix: str) -> str:
        """Returns a signature taking a request and a ServerWriter."""
        name = method.name()
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        return (f'void {prefix}{name}( '
                f'const {request}& request, '
                f'ServerWriter<{response}>& writer)')

    def client_streaming_signature(self, method: ProtoServiceMethod,
                                   prefix: str) -> str:
        """Returns a signature taking a ServerReader."""
        name = method.name()
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        return (f'void {prefix}{name}( '
                f'ServerReader<{request}, {response}>& reader)')

    def bidirectional_streaming_signature(self, method: ProtoServiceMethod,
                                          prefix: str) -> str:
        """Returns a signature taking a ServerReaderWriter."""
        name = method.name()
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        return (f'void {prefix}{name}( '
                f'ServerReaderWriter<{request}, {response}>& reader_writer)')
242
243
def process_proto_file(proto_file) -> Iterable[OutputFile]:
    """Generates code for a single .proto file.

    Builds the proto node tree and runs the Nanopb code generator over the
    package. It then adds stubs via codegen.package_stubs and returns a
    single-element list holding the generator's output file. proto_file is
    presumably a file descriptor exposing .name (only .name is read here).
    """
    _, package_root = build_node_tree(proto_file)
    generator = NanopbCodeGenerator(
        _proto_filename_to_generated_header(proto_file.name))

    codegen.generate_package(proto_file, package_root, generator)
    codegen.package_stubs(package_root, generator, StubGenerator())

    return [generator.output]
255