• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Common RPC codegen utilities."""
15
16import abc
17from datetime import datetime
18import os
19from typing import cast, Any, Callable, Iterable
20
21from pw_protobuf.output_file import OutputFile
22from pw_protobuf.proto_tree import ProtoNode, ProtoService, ProtoServiceMethod
23
24import pw_rpc.ids
25
# Identification written into the banner of every generated header.
PLUGIN_NAME = 'pw_rpc_codegen'
PLUGIN_VERSION = '0.2.0'

# Fully qualified C++ namespace of the pw_rpc library.
RPC_NAMESPACE = '::pw::rpc'

# Placeholder TODO comments for the bodies of generated stub methods.
STUB_REQUEST_TODO = ('// TODO: Read the request as appropriate for your '
                     'application')
STUB_RESPONSE_TODO = ('// TODO: Fill in the response as appropriate for '
                      'your application')
STUB_WRITER_TODO = ('// TODO: Send responses with the writer as '
                    'appropriate for your application')
38
# Callback signatures through which implementation-specific code generation
# is plugged into the shared helpers of this module.

# Produces extra #include lines from the file descriptor and the package node.
IncludesGenerator = Callable[[Any, ProtoNode], Iterable[str]]

# Writes the code for one service of a package.
ServiceGenerator = Callable[[ProtoService, ProtoNode, OutputFile], None]

# Writes one entry of a service's method table, given the method and its ID.
MethodGenerator = Callable[[ProtoServiceMethod, int, OutputFile], None]

# Writes the server writer type alias into a generated service class.
ServerWriterGenerator = Callable[[OutputFile], None]
43
44
def package(file_descriptor_proto, proto_package: ProtoNode,
            output: OutputFile, includes: IncludesGenerator,
            service: ServiceGenerator, client: ServiceGenerator) -> None:
    """Generates service and client code for a package.

    Writes a complete C++ header to ``output``. The header contains a banner
    naming the output file, this plugin and the generation time, then
    ``#pragma once`` and the standard and pw_rpc includes plus whatever
    ``includes`` adds. After that, for each service node in the package, it
    contains the code written by ``service`` followed by the code written by
    ``client``. When the package has a C++ namespace, the services are
    wrapped in it, with any leading ``::`` dropped because a namespace
    definition cannot be spelled that way.

    Args:
      file_descriptor_proto: File descriptor; only forwarded to ``includes``.
      proto_package: Package node to generate code for; must be a PACKAGE.
      output: Destination for the generated header.
      includes: Returns extra ``#include`` lines for this package.
      service: Writes the service code for one service.
      client: Writes the client code for one service.
    """
    assert proto_package.type() == ProtoNode.Type.PACKAGE

    header_name = os.path.basename(output.name())
    output.write_line(f'// {header_name} automatically '
                      f'generated by {PLUGIN_NAME} {PLUGIN_VERSION}')
    output.write_line(f'// on {datetime.now().isoformat()}')
    output.write_line('// clang-format off')
    output.write_line('#pragma once\n')

    output.write_line('#include <array>')
    output.write_line('#include <cstdint>')
    output.write_line('#include <type_traits>\n')

    # The fixed pw_rpc headers and the implementation-specific ones are
    # emitted together, in sorted order.
    pw_rpc_headers = [
        '#include "pw_rpc/internal/method_lookup.h"',
        '#include "pw_rpc/server_context.h"',
        '#include "pw_rpc/service.h"',
    ]
    extra_headers = includes(file_descriptor_proto, proto_package)
    for include_line in sorted([*pw_rpc_headers, *extra_headers]):
        output.write_line(include_line)
    output.write_line()

    # None means the services are emitted at global scope.
    file_namespace = proto_package.cpp_namespace() or None
    if file_namespace is not None:
        if file_namespace.startswith('::'):
            file_namespace = file_namespace[2:]
        output.write_line(f'namespace {file_namespace} {{')

    for node in proto_package:
        if node.type() != ProtoNode.Type.SERVICE:
            continue
        proto_service = cast(ProtoService, node)
        service(proto_service, proto_package, output)
        client(proto_service, proto_package, output)

    if file_namespace is not None:
        output.write_line(f'}}  // namespace {file_namespace}')
87
88
def service_class(service: ProtoService, root: ProtoNode, output: OutputFile,
                  server_writer_alias: ServerWriterGenerator,
                  method_union: str,
                  method_descriptor: MethodGenerator) -> None:
    """Generates the C++ base class template for an RPC service.

    The class is written inside ``namespace generated`` as
    ``template <typename Implementation> class <Name>``, deriving from
    ``::pw::rpc::Service``. Its public section holds the ``ServerContext``
    alias, whatever ``server_writer_alias`` writes, a constexpr constructor,
    deleted copy operations and a static ``name()``. Its private section
    holds ``kServiceId`` (the hash of the service's proto path), the
    ``kMethods`` table and the ``kMethodIds`` table.

    Args:
      service: The service to generate a class for.
      root: Passed to ``service.cpp_namespace`` to name the generated class.
      output: Destination for the generated code.
      server_writer_alias: Writes the server writer type alias.
      method_union: Name of the ``::pw::rpc::internal`` type stored in the
          ``kMethods`` table.
      method_descriptor: Writes one ``kMethods`` entry, given a method and
          the hash of its name.
    """
    write = output.write_line
    base_class = f'{RPC_NAMESPACE}::Service'
    name = service.name()
    methods = service.methods()

    write('namespace generated {')
    write('\ntemplate <typename Implementation>')
    write(f'class {service.cpp_namespace(root)} : public {base_class} {{')
    write(' public:')

    with output.indent():
        write(f'using ServerContext = {RPC_NAMESPACE}::ServerContext;')
        server_writer_alias(output)
        write()

        # Constructor, non-copyability, and the service name.
        write(f'constexpr {name}()')
        write(f'    : {base_class}(kServiceId, kMethods) {{}}')
        write()
        write(f'{name}(const {name}&) = delete;')
        write(f'{name}& operator=(const {name}&) = delete;')
        write()
        write(f'static constexpr const char* name() {{ return "{name}"; }}')
        write()
        write('// Used by MethodLookup to identify the generated service '
              'base.')
        write('constexpr void _PwRpcInternalGeneratedBase() const {}')

    write('\n private:')

    with output.indent():
        write('friend class ::pw::rpc::internal::MethodLookup;\n')
        write(f'// Hash of "{service.proto_path()}".')
        service_id = pw_rpc.ids.calculate(service.proto_path())
        write(f'static constexpr uint32_t kServiceId = 0x{service_id:08x};')
        write()

        # Method table: one entry per method, in declaration order.
        write(f'static constexpr std::array<{RPC_NAMESPACE}::internal::'
              f'{method_union}, {len(methods)}> kMethods = {{')
        with output.indent(4):
            for method in methods:
                method_descriptor(method, pw_rpc.ids.calculate(method.name()),
                                  output)
        write('};\n')

        # Method ID table (kMethodIds) for compile-time method lookup.
        _method_lookup_table(service, output)

    write('};')
    write('\n}  // namespace generated\n')
158
159
def _method_lookup_table(service: ProtoService, output: OutputFile) -> None:
    """Writes ``kMethodIds``, the array of method IDs for compile-time lookup.

    The entries follow the order of ``service.methods()``. Each entry is the
    hash of a method name, followed by a comment naming that method.
    """
    methods = service.methods()
    output.write_line(f'static constexpr std::array<uint32_t, {len(methods)}> '
                      'kMethodIds = {')

    with output.indent(4):
        for method in methods:
            method_name = method.name()
            method_id = pw_rpc.ids.calculate(method_name)
            output.write_line(
                f'0x{method_id:08x},  // Hash of "{method_name}"')

    output.write_line('};\n')
172
173
class StubGenerator(abc.ABC):
    """Supplies signatures and placeholder bodies for generated RPC stubs.

    Subclasses provide the C++ signature and stub body for each supported
    method type. The stub generation in this module uses each signature
    twice: once for the in-class declaration and once for the out-of-class
    definition.
    """

    @abc.abstractmethod
    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns the C++ signature of a unary method.

        Callers in this module pass ``''`` as ``prefix`` for the in-class
        declaration and ``'<Service>::'`` for the out-of-class definition.
        """

    @abc.abstractmethod
    def unary_stub(self, method: ProtoServiceMethod,
                   output: OutputFile) -> None:
        """Writes the placeholder body of a unary method to ``output``."""

    @abc.abstractmethod
    def server_streaming_signature(self, method: ProtoServiceMethod,
                                   prefix: str) -> str:
        """Returns the C++ signature of a server streaming method.

        ``prefix`` is used the same way as in ``unary_signature``.
        """

    def server_streaming_stub(  # pylint: disable=no-self-use
            self, unused_method: ProtoServiceMethod,
            output: OutputFile) -> None:
        """Writes a default placeholder body for a server streaming method.

        For each of ``request`` and ``writer``, the body has a TODO comment
        followed by a cast of that parameter to ``void``, which keeps the
        unused parameters from triggering compiler warnings.
        """
        for todo_comment, parameter in ((STUB_REQUEST_TODO, 'request'),
                                        (STUB_WRITER_TODO, 'writer')):
            output.write_line(todo_comment)
            output.write_line(f'static_cast<void>({parameter});')
197
198
def _select_stub_methods(generator: StubGenerator, method: ProtoServiceMethod):
    """Picks the signature and stub writers that match ``method``'s type.

    Returns:
      A ``(signature, stub)`` pair of bound ``generator`` methods.

    Raises:
      NotImplementedError: If the method is neither unary nor server
          streaming (that is, client or bidirectional streaming).
    """
    method_type = method.type()

    if method_type is ProtoServiceMethod.Type.SERVER_STREAMING:
        return (generator.server_streaming_signature,
                generator.server_streaming_stub)

    if method_type is ProtoServiceMethod.Type.UNARY:
        return generator.unary_signature, generator.unary_stub

    raise NotImplementedError(
        'Client and bidirectional streaming not yet implemented')
209
210
# Banner and explanatory note that package_stubs() writes at the start of the
# stubs section, inside the _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS guard.
# It is a raw string, so the backslashes in the ASCII art are emitted
# verbatim rather than being treated as escape sequences.
_STUBS_COMMENT = r'''
/*
    ____                __                          __        __  _
   /  _/___ ___  ____  / /__  ____ ___  ___  ____  / /_____ _/ /_(_)___  ____
   / // __ `__ \/ __ \/ / _ \/ __ `__ \/ _ \/ __ \/ __/ __ `/ __/ / __ \/ __ \
 _/ // / / / / / /_/ / /  __/ / / / / /  __/ / / / /_/ /_/ / /_/ / /_/ / / / /
/___/_/ /_/ /_/ .___/_/\___/_/ /_/ /_/\___/_/ /_/\__/\__,_/\__/_/\____/_/ /_/
             /_/
   _____ __        __         __
  / ___// /___  __/ /_  _____/ /
  \__ \/ __/ / / / __ \/ ___/ /
 ___/ / /_/ /_/ / /_/ (__  )_/
/____/\__/\__,_/_.___/____(_)

*/
// This section provides stub implementations of the RPC services in this file.
// The code below may be referenced or copied to serve as a starting point for
// your RPC service implementations.
'''
230
231
def package_stubs(proto_package: ProtoNode, output: OutputFile,
                  stub_generator: StubGenerator) -> None:
    """Generates the RPC stubs for a package.

    Everything is wrapped in ``#ifdef _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS``.
    Inside the guard come the stubs banner and an ``#include`` of the output
    file itself, followed by two sections. Each section is wrapped in the
    package's C++ namespace when there is one, with any leading ``::``
    dropped. The first section declares an implementation class for every
    service. The second defines every method of those classes with a
    placeholder body.
    """
    # None means no namespace wrapping is emitted.
    file_ns = proto_package.cpp_namespace() or None
    if file_ns is not None and file_ns.startswith('::'):
        file_ns = file_ns[2:]

    def open_namespace() -> None:
        if file_ns is not None:
            output.write_line(f'namespace {file_ns} {{\n')

    def close_namespace() -> None:
        if file_ns is not None:
            output.write_line(f'}}  // namespace {file_ns}\n')

    services = [
        cast(ProtoService, node) for node in proto_package
        if node.type() == ProtoNode.Type.SERVICE
    ]

    output.write_line('#ifdef _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS')
    output.write_line(_STUBS_COMMENT)
    output.write_line(f'#include "{output.name()}"\n')

    # Section 1: implementation class declarations.
    open_namespace()
    for proto_service in services:
        _generate_service_class(proto_service, output, stub_generator)
    output.write_line()
    close_namespace()

    # Section 2: method definitions with placeholder bodies.
    open_namespace()
    for proto_service in services:
        _generate_service_stubs(proto_service, output, stub_generator)
        output.write_line()
    close_namespace()

    output.write_line('#endif  // _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS')
267        _generate_service_stubs(node, output, stub_generator)
268        output.write_line()
269
270    finish_ns()
271
272    output.write_line('#endif  // _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS')
273
274
def _generate_service_class(service: ProtoService, output: OutputFile,
                            stub_generator: StubGenerator) -> None:
    """Writes the stub implementation class declaration for ``service``.

    The class ``Name`` derives from ``generated::Name<Name>``. It declares
    every method of the service under ``public:``, with a blank line between
    declarations, using the signature that ``_select_stub_methods`` picks
    for the method's type.
    """
    output.write_line(f'// Implementation class for {service.proto_path()}.')
    class_name = service.name()
    output.write_line(f'class {class_name} '
                      f': public generated::{class_name}<{class_name}> {{')
    output.write_line(' public:')

    with output.indent():
        for index, method in enumerate(service.methods()):
            if index > 0:
                output.write_line()  # Separate consecutive declarations.
            signature, _ = _select_stub_methods(stub_generator, method)
            output.write_line(signature(method, '') + ';')

    output.write_line('};\n')
292
293            signature, _ = _select_stub_methods(stub_generator, method)
294
295            output.write_line(signature(method, '') + ';')
296
297    output.write_line('};\n')
298
299
def _generate_service_stubs(service: ProtoService, output: OutputFile,
                            stub_generator: StubGenerator) -> None:
    """Writes out-of-class definitions for every method of ``service``.

    Each definition uses the signature qualified with ``<Name>::``, and its
    body is the indented stub from ``stub_generator``. Consecutive
    definitions are separated by a blank line.
    """
    output.write_line(f'// Method definitions for {service.proto_path()}.')

    for index, method in enumerate(service.methods()):
        if index > 0:
            output.write_line()  # Separate consecutive definitions.

        signature, stub = _select_stub_methods(stub_generator, method)
        definition = signature(method, f'{service.name()}::')
        output.write_line(definition + ' {')
        with output.indent():
            stub(method, output)
        output.write_line('}')
318