• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""This module generates the code for nanopb-based pw_rpc services."""
15
16import os
17from typing import Iterable, NamedTuple, Optional
18
19from pw_protobuf.output_file import OutputFile
20from pw_protobuf.proto_tree import ProtoServiceMethod
21from pw_protobuf.proto_tree import build_node_tree
22from pw_rpc import codegen
23from pw_rpc.codegen import (
24    client_call_type,
25    get_id,
26    CodeGenerator,
27    RPC_NAMESPACE,
28)
29
# Extension shared by generated protobuf headers. '.rpc' is placed in front of
# it to build the generated RPC header name (see
# _proto_filename_to_generated_header).
PROTO_H_EXTENSION = '.pb.h'
# Presumably the extension of generated protobuf C++ source files.
# NOTE(review): not referenced in the visible part of this module -- confirm
# whether it is still needed.
PROTO_CC_EXTENSION = '.pb.cc'
# Extension appended to a .proto file's base name to get the nanopb header
# name (see _proto_filename_to_nanopb_header).
NANOPB_H_EXTENSION = '.pb.h'
33
34
def _serde(method: ProtoServiceMethod) -> str:
    """Builds the kNanopbMethodSerde<...> expression for a method.

    The template arguments are the nanopb fields of the method's request
    type followed by those of its response type.
    """
    request_fields = method.request_type().nanopb_fields()
    response_fields = method.response_type().nanopb_fields()
    template_args = f'{request_fields}, {response_fields}'
    return f'{RPC_NAMESPACE}::internal::kNanopbMethodSerde<{template_args}>'
42
43
def _proto_filename_to_nanopb_header(proto_file: str) -> str:
    """Maps a .proto path to the name of its generated nanopb header.

    The final extension of proto_file (if any) is replaced with
    NANOPB_H_EXTENSION.
    """
    stem, _extension = os.path.splitext(proto_file)
    return f'{stem}{NANOPB_H_EXTENSION}'
47
48
def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Maps a .proto path to the name of its generated C++ RPC header.

    The final extension of proto_file (if any) is replaced with '.rpc'
    followed by PROTO_H_EXTENSION.
    """
    stem, _extension = os.path.splitext(proto_file)
    return stem + '.rpc' + PROTO_H_EXTENSION
53
54
def _client_call(method: ProtoServiceMethod) -> str:
    """Returns the templated Nanopb client call type for a method.

    The response struct is always a template argument; the request struct
    precedes it only for client streaming methods.
    """
    # The request struct, when present, must come before the response struct.
    template_args = (
        [method.request_type().nanopb_struct()]
        if method.client_streaming()
        else []
    )
    template_args.append(method.response_type().nanopb_struct())

    call_type = client_call_type(method, "Nanopb")
    joined_args = ", ".join(template_args)
    return f'{call_type}<{joined_args}>'
64
65
def _function(method: ProtoServiceMethod) -> str:
    """Returns '<client call type> <method name>' for a client function."""
    return_type = _client_call(method)
    function_name = method.name()
    return f'{return_type} {function_name}'
68
69
def _user_args(method: ProtoServiceMethod) -> Iterable[str]:
    """Yields the C++ parameter declarations a caller supplies for an RPC.

    Order: the request (omitted for client streaming methods), the
    response callback(s), then the on_error callback. Every callback
    defaults to nullptr.
    """
    # Parameter type shared by every callback that only receives a status.
    status_callback = '::pw::Function<void(::pw::Status)>&&'

    if not method.client_streaming():
        request = method.request_type().nanopb_struct()
        yield f'const {request}& request'

    response = method.response_type().nanopb_struct()

    if not method.server_streaming():
        # A single callback delivers both the response and the status.
        yield (
            f'::pw::Function<void(const {response}&, ::pw::Status)>&& '
            'on_completed = nullptr'
        )
    else:
        # Responses arrive through on_next; on_completed only gets the status.
        yield f'::pw::Function<void(const {response}&)>&& on_next = nullptr'
        yield f'{status_callback} on_completed = nullptr'

    yield f'{status_callback} on_error = nullptr'
86
87
class NanopbCodeGenerator(CodeGenerator):
    """Emits the C++ RPC service and client code for the Nanopb API."""

    def name(self) -> str:
        """Returns the name of this generator, 'nanopb'."""
        return 'nanopb'

    def method_union_name(self) -> str:
        """Returns the name of the Nanopb method union type."""
        return 'NanopbMethodUnion'

    def includes(self, proto_file_name: str) -> Iterable[str]:
        """Yields the #include lines needed for the given .proto file."""
        yield from (
            '#include "pw_rpc/nanopb/client_reader_writer.h"',
            '#include "pw_rpc/nanopb/internal/method_union.h"',
            '#include "pw_rpc/nanopb/server_reader_writer.h"',
        )

        # The nanopb header generated for this proto file holds its messages
        # and enums, and itself #includes everything the .proto file imports,
        # so it is the only generated header that has to be listed here.
        yield f'#include "{_proto_filename_to_nanopb_header(proto_file_name)}"'

    def service_aliases(self) -> None:
        """Writes the ServerWriter/ServerReader/ServerReaderWriter aliases."""
        aliases = (
            ('ServerWriter', ('Response',)),
            ('ServerReader', ('Request', 'Response')),
            ('ServerReaderWriter', ('Request', 'Response')),
        )
        for alias, type_params in aliases:
            declared = ', '.join(f'typename {param}' for param in type_params)
            forwarded = ', '.join(type_params)
            self.line(f'template <{declared}>')
            self.line(
                f'using {alias} = {RPC_NAMESPACE}::Nanopb{alias}<{forwarded}>;'
            )

    def method_descriptor(self, method: ProtoServiceMethod) -> None:
        """Writes the GetNanopbOrRawMethodFor<...>(...) entry for a method."""
        method_name = method.name()
        template_args = (
            f'&Implementation::{method_name}, '
            f'{method.type().cc_enum()}, '
            f'{method.request_type().nanopb_struct()}, '
            f'{method.response_type().nanopb_struct()}'
        )
        self.line(
            f'{RPC_NAMESPACE}::internal::'
            f'GetNanopbOrRawMethodFor<{template_args}>('
        )
        with self.indent(4):
            self.line(f'{get_id(method)},  // Hash of "{method_name}"')
            self.line(f'{_serde(method)}),')

    def client_member_function(self, method: ProtoServiceMethod) -> None:
        """Writes the const client member function that starts an RPC."""
        self.line(f'{_function(method)}(')
        self.indented_list(*_user_args(method), end=') const {')

        with self.indent():
            call_kind = 'Stream' if method.server_streaming() else 'Unary'
            response = method.response_type().nanopb_struct()
            self.line(
                f'return {RPC_NAMESPACE}::internal::'
                f'Nanopb{call_kind}ResponseClientCall<{response}>::'
                f'Start<{_client_call(method)}>('
            )

            service_client = f'{RPC_NAMESPACE}::internal::ServiceClient'

            # on_next is only passed for server streaming methods, and the
            # request is only passed when the client is not streaming.
            self.indented_list(
                f'{service_client}::client()',
                f'{service_client}::channel_id()',
                'kServiceId',
                get_id(method),
                _serde(method),
                *(['std::move(on_next)'] if method.server_streaming() else []),
                'std::move(on_completed)',
                'std::move(on_error)',
                *([] if method.client_streaming() else ['request']),
                end=');',
            )

        self.line('}')

    def client_static_function(self, method: ProtoServiceMethod) -> None:
        """Writes the static function that forwards to a temporary Client."""
        self.line(f'static {_function(method)}(')
        self.indented_list(
            f'{RPC_NAMESPACE}::Client& client',
            'uint32_t channel_id',
            *_user_args(method),
            end=') {',
        )

        with self.indent():
            self.line(f'return Client(client, channel_id).{method.name()}(')

            # Forward the arguments in the order the member function takes
            # them: request, on_next, on_completed, on_error.
            forwarded_args = [
                *([] if method.client_streaming() else ['request']),
                *(['std::move(on_next)'] if method.server_streaming() else []),
                'std::move(on_completed)',
                'std::move(on_error)',
            ]
            self.indented_list(*forwarded_args, end=');')

        self.line('}')

    def method_info_specialization(self, method: ProtoServiceMethod) -> None:
        """Writes the Request/Response aliases and serde() for a method."""
        self.line()

        request = method.request_type().nanopb_struct()
        self.line(f'using Request = {request};')
        response = method.response_type().nanopb_struct()
        self.line(f'using Response = {response};')

        self.line()

        serde_signature = (
            f'static constexpr const {RPC_NAMESPACE}::internal::'
            'NanopbMethodSerde& serde() {'
        )
        self.line(serde_signature)
        with self.indent():
            self.line(f'return {_serde(method)};')
        self.line('}')
216
217
class _CallbackFunction(NamedTuple):
    """A ::pw::Function callback parameter of a client RPC call.

    str() renders the C++ parameter declaration, including the default
    value when a non-empty one is set.
    """

    # C++ function signature placed inside ::pw::Function<...>.
    function_type: str
    # Name of the parameter.
    name: str
    # Default value for the parameter; omitted from the output when unset.
    default_value: Optional[str] = None

    def __str__(self) -> str:
        declaration = f'::pw::Function<{self.function_type}>&& {self.name}'
        if not self.default_value:
            return declaration
        return f'{declaration} = {self.default_value}'
230
231
class StubGenerator(codegen.StubGenerator):
    """Produces the signatures and bodies of Nanopb RPC service stubs."""

    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns the C++ signature of a unary method stub."""
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        parameters = f'const {request}& request, {response}& response'
        return f'::pw::Status {prefix}{method.name()}( {parameters})'

    def unary_stub(
        self, method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes the placeholder body of a unary method stub to output."""
        stub_lines = (
            codegen.STUB_REQUEST_TODO,
            'static_cast<void>(request);',
            codegen.STUB_RESPONSE_TODO,
            'static_cast<void>(response);',
            'return ::pw::Status::Unimplemented();',
        )
        for stub_line in stub_lines:
            output.write_line(stub_line)

    def server_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a server streaming method stub."""
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        parameters = (
            f'const {request}& request, ServerWriter<{response}>& writer'
        )
        return f'void {prefix}{method.name()}( {parameters})'

    def client_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a client streaming method stub."""
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        parameters = f'ServerReader<{request}, {response}>& reader'
        return f'void {prefix}{method.name()}( {parameters})'

    def bidirectional_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a bidirectional streaming stub."""
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        parameters = (
            f'ServerReaderWriter<{request}, {response}>& reader_writer'
        )
        return f'void {prefix}{method.name()}( {parameters})'
277
278
def process_proto_file(proto_file) -> Iterable[OutputFile]:
    """Generates the Nanopb RPC header for a single .proto file.

    Args:
        proto_file: The proto file to process; it is passed to
            build_node_tree and its name attribute determines the name of
            the generated header.

    Returns:
        A one-element list holding the generator's output file.
    """
    _unused, root_package = build_node_tree(proto_file)

    header_name = _proto_filename_to_generated_header(proto_file.name)
    code_generator = NanopbCodeGenerator(header_name)

    # The package code is generated first; the stubs are added afterwards.
    codegen.generate_package(proto_file, root_package, code_generator)
    codegen.package_stubs(root_package, code_generator, StubGenerator())

    return [code_generator.output]
290