• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""This module generates the code for pw_protobuf pw_rpc services."""
15
16import os
17from typing import Iterable
18
19from pw_protobuf.output_file import OutputFile
20from pw_protobuf.proto_tree import ProtoServiceMethod
21from pw_protobuf.proto_tree import build_node_tree
22from pw_rpc import codegen
23from pw_rpc.codegen import (
24    client_call_type,
25    get_id,
26    CodeGenerator,
27    RPC_NAMESPACE,
28)
29
# Extension of the pw_protobuf (pwpb) header generated for a .proto file.
PWPB_H_EXTENSION = '.pwpb.h'
# Extension that follows '.rpc' in the generated RPC header's name. It is
# currently the same value as PWPB_H_EXTENSION.
PROTO_H_EXTENSION = PWPB_H_EXTENSION
32
33
def _proto_filename_to_pwpb_header(proto_file: str) -> str:
    """Maps a .proto path to the pwpb message header generated for it.

    The last extension of ``proto_file`` is replaced with PWPB_H_EXTENSION,
    e.g. ``foo/bar.proto`` -> ``foo/bar.pwpb.h``.
    """
    stem, _extension = os.path.splitext(proto_file)
    return stem + PWPB_H_EXTENSION
38
39
def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Maps a .proto path to the name of the generated C++ RPC header.

    The last extension of ``proto_file`` is replaced with ``.rpc`` followed by
    PROTO_H_EXTENSION, e.g. ``foo/bar.proto`` -> ``foo/bar.rpc.pwpb.h``.
    """
    stem = os.path.splitext(proto_file)[0]
    return stem + '.rpc' + PROTO_H_EXTENSION
44
45
def _serde(method: ProtoServiceMethod) -> str:
    """Builds the C++ expression for this method's kPwpbMethodSerde.

    The serde template is instantiated with the addresses of the request and
    response pwpb message tables.
    """
    request_table = method.request_type().pwpb_table()
    response_table = method.response_type().pwpb_table()
    return (
        RPC_NAMESPACE
        + '::internal::kPwpbMethodSerde<'
        + f'&{request_table}, &{response_table}>'
    )
53
54
def _client_call(method: ProtoServiceMethod) -> str:
    """Returns the C++ client call type for ``method`` with template arguments.

    When the method is client streaming, the call type is templated on both
    the request and response pwpb structs; otherwise on the response only.
    """
    if method.client_streaming():
        type_args = [
            method.request_type().pwpb_struct(),
            method.response_type().pwpb_struct(),
        ]
    else:
        type_args = [method.response_type().pwpb_struct()]

    call_class = client_call_type(method, "Pwpb")
    return f'{call_class}<{", ".join(type_args)}>'
64
65
def _function(method: ProtoServiceMethod) -> str:
    """Returns the return type and name of a generated client function."""
    return_type = _client_call(method)
    return f'{return_type} {method.name()}'
68
69
def _user_args(method: ProtoServiceMethod) -> Iterable[str]:
    """Yields the caller-facing C++ parameters of a generated client call.

    Methods that are not client streaming first take the request by const
    reference. Callbacks follow, each defaulting to nullptr. Server-streaming
    methods get separate on_next and on_completed callbacks. Other methods get
    one on_completed receiving both the response and the status. on_error is
    always last.
    """
    if not method.client_streaming():
        request = method.request_type().pwpb_struct()
        yield f'const {request}& request'

    response = method.response_type().pwpb_struct()
    completion_callbacks = (
        [
            f'::pw::Function<void(const {response}&)>&& on_next = nullptr',
            '::pw::Function<void(::pw::Status)>&& on_completed = nullptr',
        ]
        if method.server_streaming()
        else [
            f'::pw::Function<void(const {response}&, ::pw::Status)>&& '
            'on_completed = nullptr'
        ]
    )
    yield from completion_callbacks

    yield '::pw::Function<void(::pw::Status)>&& on_error = nullptr'
86
87
class PwpbCodeGenerator(CodeGenerator):
    """Emits pw_rpc service and client C++ code that uses pwpb structs."""

    def name(self) -> str:
        """Returns this generator's short name."""
        return 'pwpb'

    def method_union_name(self) -> str:
        """Returns the C++ method union type used by generated services."""
        return 'PwpbMethodUnion'

    def includes(self, proto_file_name: str) -> Iterable[str]:
        """Yields the #include lines needed by the generated RPC header."""
        yield from (
            '#include "pw_rpc/pwpb/client_reader_writer.h"',
            '#include "pw_rpc/pwpb/internal/method_union.h"',
            '#include "pw_rpc/pwpb/server_reader_writer.h"',
        )

        # This .proto's messages and enums are generated into its pwpb header.
        # That header #includes the files imported by the .proto, so they do
        # not need to be listed here.
        message_header = _proto_filename_to_pwpb_header(proto_file_name)
        yield f'#include "{message_header}"'

    def service_aliases(self) -> None:
        """Emits ServerWriter, ServerReader and ServerReaderWriter aliases."""
        one_param = 'template <typename Response>'
        two_params = 'template <typename Request, typename Response>'
        aliases = (
            (one_param, 'ServerWriter', 'PwpbServerWriter<Response>'),
            (
                two_params,
                'ServerReader',
                'PwpbServerReader<Request, Response>',
            ),
            (
                two_params,
                'ServerReaderWriter',
                'PwpbServerReaderWriter<Request, Response>',
            ),
        )
        for template, alias, target in aliases:
            self.line(template)
            self.line(f'using {alias} = {RPC_NAMESPACE}::{target};')

    def method_descriptor(self, method: ProtoServiceMethod) -> None:
        """Emits the GetPwpbOrRawMethodFor<...>(id, serde) entry for a method."""
        impl_method = f'&Implementation::{method.name()}'
        method_type = method.type().cc_enum()
        request = method.request_type().pwpb_struct()
        response = method.response_type().pwpb_struct()
        self.line(
            f'{RPC_NAMESPACE}::internal::GetPwpbOrRawMethodFor<'
            f'{impl_method}, {method_type}, {request}, {response}>('
        )
        with self.indent(4):
            self.line(f'{get_id(method)},  // Hash of "{method.name()}"')
            self.line(f'{_serde(method)}),')

    def client_member_function(self, method: ProtoServiceMethod) -> None:
        """Emits the Client member function that starts a call to ``method``.

        The function's parameters come from ``_user_args``. Its body returns
        Pwpb{Unary,Stream}ResponseClientCall<Response>::Start<CallType>(...).
        Start() is passed the service client's client and channel ID,
        kServiceId, the method ID, the serde, and the callbacks. The request
        is passed too, unless the method is client streaming.
        """
        self.line(f'{_function(method)}(')
        self.indented_list(*_user_args(method), end=') const {')

        with self.indent():
            call_type = _client_call(method)
            server_streams = method.server_streaming()
            response_kind = 'Stream' if server_streams else 'Unary'
            response = method.response_type().pwpb_struct()
            self.line(
                f'return {RPC_NAMESPACE}::internal::Pwpb{response_kind}'
                f'ResponseClientCall<{response}>::'
                f'Start<{call_type}>('
            )

            service_client = f'{RPC_NAMESPACE}::internal::ServiceClient'
            start_args = [
                f'{service_client}::client()',
                f'{service_client}::channel_id()',
                'kServiceId',
                get_id(method),
                _serde(method),
            ]
            if server_streams:
                start_args.append('std::move(on_next)')
            start_args += ['std::move(on_completed)', 'std::move(on_error)']
            if not method.client_streaming():
                start_args.append('request')

            self.indented_list(*start_args, end=');')

        self.line('}')

    def client_static_function(self, method: ProtoServiceMethod) -> None:
        """Emits a static overload that takes a Client and a channel ID.

        The overload forwards its remaining arguments to the member function
        of a temporary Client(client, channel_id).
        """
        self.line(f'static {_function(method)}(')
        self.indented_list(
            f'{RPC_NAMESPACE}::Client& client',
            'uint32_t channel_id',
            *_user_args(method),
            end=') {',
        )

        with self.indent():
            self.line(f'return Client(client, channel_id).{method.name()}(')

            forwarded = []
            if not method.client_streaming():
                forwarded.append('request')
            if method.server_streaming():
                forwarded.append('std::move(on_next)')
            forwarded += ['std::move(on_completed)', 'std::move(on_error)']

            self.indented_list(*forwarded, end=');')

        self.line('}')

    def method_info_specialization(self, method: ProtoServiceMethod) -> None:
        """Emits Request/Response aliases and a serde() accessor."""
        request = method.request_type().pwpb_struct()
        response = method.response_type().pwpb_struct()

        self.line()
        self.line(f'using Request = {request};')
        self.line(f'using Response = {response};')
        self.line()
        self.line(
            f'static constexpr const {RPC_NAMESPACE}::PwpbMethodSerde& '
            'serde() {'
        )
        with self.indent():
            self.line(f'return {_serde(method)};')
        self.line('}')
217
218
class StubGenerator(codegen.StubGenerator):
    """Produces implementation stubs for pw_protobuf-based RPC services."""

    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns the C++ signature of a unary method implementation."""
        function = f'{prefix}{method.name()}'
        request = method.request_type().pwpb_struct()
        response = method.response_type().pwpb_struct()
        return (
            f'::pw::Status {function}( '
            f'const {request}& request, {response}& response)'
        )

    def unary_stub(
        self, method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes the body of a unary stub.

        The body contains codegen's request and response TODO lines, casts
        both arguments to void, and returns ``::pw::Status::Unimplemented()``.
        """
        body = (
            codegen.STUB_REQUEST_TODO,
            'static_cast<void>(request);',
            codegen.STUB_RESPONSE_TODO,
            'static_cast<void>(response);',
            'return ::pw::Status::Unimplemented();',
        )
        for statement in body:
            output.write_line(statement)

    def server_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a server-streaming implementation."""
        function = f'{prefix}{method.name()}'
        request = method.request_type().pwpb_struct()
        response = method.response_type().pwpb_struct()
        return (
            f'void {function}( '
            f'const {request}& request, ServerWriter<{response}>& writer)'
        )

    def client_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a client-streaming implementation."""
        function = f'{prefix}{method.name()}'
        request = method.request_type().pwpb_struct()
        response = method.response_type().pwpb_struct()
        return f'void {function}( ServerReader<{request}, {response}>& reader)'

    def bidirectional_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a bidirectional-streaming implementation."""
        function = f'{prefix}{method.name()}'
        request = method.request_type().pwpb_struct()
        response = method.response_type().pwpb_struct()
        return (
            f'void {function}( '
            f'ServerReaderWriter<{request}, {response}>& reader_writer)'
        )
264
265
def process_proto_file(proto_file) -> Iterable[OutputFile]:
    """Runs the pwpb RPC code generator over a single .proto file.

    Args:
      proto_file: the .proto file to generate code for. Its ``name`` attribute
        is used to derive the generated RPC header's name.

    Returns:
      A one-element list holding the generator's output file.
    """
    _root, package_root = build_node_tree(proto_file)

    header_name = _proto_filename_to_generated_header(proto_file.name)
    generator = PwpbCodeGenerator(header_name)

    codegen.generate_package(proto_file, package_root, generator)
    codegen.package_stubs(package_root, generator, StubGenerator())

    return [generator.output]
278