• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Common RPC codegen utilities."""
15
16import abc
17from datetime import datetime
18import os
19from typing import cast, Any, Iterable, Union
20
21from pw_protobuf.output_file import OutputFile
22from pw_protobuf.proto_tree import ProtoNode, ProtoService, ProtoServiceMethod
23from pw_rpc import ids
24
# Name and version of this code generator, written into generated headers.
PLUGIN_NAME = 'pw_rpc_codegen'
PLUGIN_VERSION = '0.3.0'

# Fully qualified C++ namespace of the pw_rpc library.
RPC_NAMESPACE = '::pw::rpc'

# Placeholder comments written into the bodies of generated service stubs.
STUB_REQUEST_TODO = ('// TODO: Read the request as appropriate for '
                     'your application')
STUB_RESPONSE_TODO = ('// TODO: Fill in the response as appropriate for '
                      'your application')
STUB_WRITER_TODO = ('// TODO: Send responses with the writer as '
                    'appropriate for your application')
STUB_READER_TODO = ('// TODO: Set the client stream callback and send a '
                    'response as appropriate for your application')
STUB_READER_WRITER_TODO = ('// TODO: Set the client stream callback and send '
                           'responses as appropriate for your application')
43
44
def get_id(item: Union[ProtoService, ProtoServiceMethod]) -> str:
    """Returns the pw_rpc ID of a service or method as a C++ hex literal.

    Services are hashed by their full proto path and methods by their bare
    name, using ids.calculate(). The hash is zero-padded to 8 hex digits.
    """
    if isinstance(item, ProtoService):
        hashed_name = item.proto_path()
    else:
        hashed_name = item.name()
    return '0x' + format(ids.calculate(hashed_name), '08x')
48
49
def client_call_type(method: ProtoServiceMethod, prefix: str) -> str:
    """Returns the C++ client call class for the method's RPC type.

    Unary methods use UnaryReceiver, server streaming methods ClientReader,
    client streaming methods ClientWriter, and bidirectional streaming methods
    ClientReaderWriter. The class name has ``prefix`` prepended and is
    qualified with RPC_NAMESPACE.

    Raises:
      NotImplementedError: if the method type is not one of the four above.
    """
    method_type = method.type()
    call_classes = {
        ProtoServiceMethod.Type.UNARY: 'UnaryReceiver',
        ProtoServiceMethod.Type.SERVER_STREAMING: 'ClientReader',
        ProtoServiceMethod.Type.CLIENT_STREAMING: 'ClientWriter',
        ProtoServiceMethod.Type.BIDIRECTIONAL_STREAMING: 'ClientReaderWriter',
    }
    if method_type not in call_classes:
        raise NotImplementedError(f'Unknown {method_type}')

    return f'{RPC_NAMESPACE}::{prefix}{call_classes[method_type]}'
64
65
class CodeGenerator(abc.ABC):
    """Generates RPC code for services and clients.

    Concrete subclasses provide the implementation-specific pieces (names,
    includes, method descriptors, client functions) through the abstract
    methods below. Output is written line by line through line() and indent().
    """
    def __init__(self, output_filename: str) -> None:
        # Every line produced by this generator is written to this object.
        self.output = OutputFile(output_filename)

    def indent(self, amount: int = OutputFile.INDENT_WIDTH) -> Any:
        """Indents the output. Use in a with block.

        Args:
          amount: extra indentation for lines written inside the with block;
            defaults to OutputFile.INDENT_WIDTH.
        """
        return self.output.indent(amount)

    def line(self, value: str = '') -> None:
        """Writes a line to the output; with no argument, an empty line."""
        self.output.write_line(value)

    def indented_list(self, *args: str, end: str = ',') -> None:
        """Outputs each arg on its own line, indented 4 past the current level.

        Every arg except the last is followed by ','. The last arg is followed
        by ``end``. Nothing is written when no args are given.
        """
        if not args:
            # No last element to terminate with `end`; emit nothing instead of
            # failing with IndexError.
            return

        *leading, last = args
        with self.indent(4):
            for arg in leading:
                self.line(arg + ',')

            self.line(last + end)

    @abc.abstractmethod
    def name(self) -> str:
        """Name of the pw_rpc implementation.

        Used as the inner namespace of generated code: pw_rpc::<name>.
        """

    @abc.abstractmethod
    def method_union_name(self) -> str:
        """Name of the MethodUnion class to use.

        Used as the element type of the generated kPwRpcMethods array.
        """

    @abc.abstractmethod
    def includes(self, proto_file_name: str) -> Iterable[str]:
        """Yields #include lines.

        These are merged with the common pw_rpc includes and sorted before
        being written.
        """

    @abc.abstractmethod
    def service_aliases(self) -> None:
        """Generates reader/writer aliases.

        Written in the public section of the generated Service class.
        """

    @abc.abstractmethod
    def method_descriptor(self, method: ProtoServiceMethod) -> None:
        """Generates the kPwRpcMethods table entry for a service method."""

    @abc.abstractmethod
    def client_member_function(self, method: ProtoServiceMethod) -> None:
        """Generates the client code for the Client member functions."""

    @abc.abstractmethod
    def client_static_function(self, method: ProtoServiceMethod) -> None:
        """Generates method static functions that instantiate a Client."""

    def method_info_specialization(self, method: ProtoServiceMethod) -> None:
        """Generates impl-specific additions to the MethodInfo specialization.

        May be empty if the generator has nothing to add to the MethodInfo.
        """

    def private_additions(self, service: ProtoService) -> None:
        """Additions to the private section of the outer generated class.

        NOTE(review): no call site for this hook is visible in this file.
        Confirm it is invoked somewhere before relying on an override.
        """
123
124
def generate_package(file_descriptor_proto, proto_package: ProtoNode,
                     gen: CodeGenerator) -> None:
    """Writes the generated RPC header for one proto package.

    The header has four parts, in order:
      * a banner;
      * the #include block;
      * a pw_rpc::<implementation> namespace holding one wrapper class per
        service, nested in the package's C++ namespace when it has one;
      * the MethodInfo specializations for every RPC.

    Args:
      file_descriptor_proto: descriptor of the .proto file. Only its name is
        used, and it is passed to gen.includes().
      proto_package: the package node whose services are generated.
      gen: implementation-specific generator used for all output.
    """
    assert proto_package.type() == ProtoNode.Type.PACKAGE

    out_name = os.path.basename(gen.output.name())
    gen.line(f'// {out_name} automatically '
             f'generated by {PLUGIN_NAME} {PLUGIN_VERSION}')
    generated_at = datetime.now().isoformat()
    gen.line(f'// on {generated_at}')

    for preamble in ('// clang-format off', '#pragma once\n',
                     '#include <array>', '#include <cstdint>',
                     '#include <type_traits>\n'):
        gen.line(preamble)

    # Common pw_rpc headers plus implementation-specific ones, sorted together.
    includes = [
        '#include "pw_rpc/internal/method_info.h"',
        '#include "pw_rpc/internal/method_lookup.h"',
        '#include "pw_rpc/internal/service_client.h"',
        '#include "pw_rpc/method_type.h"',
        '#include "pw_rpc/service.h"',
    ]
    includes.extend(gen.includes(file_descriptor_proto.name))
    includes.sort()
    for include in includes:
        gen.line(include)
    gen.line()

    package_ns = proto_package.cpp_namespace()
    if package_ns:
        # A namespace definition cannot begin with '::', so drop it.
        file_namespace = (package_ns[2:]
                          if package_ns.startswith('::') else package_ns)
        gen.line(f'namespace {file_namespace} {{')
    else:
        file_namespace = ''

    impl_ns = f'pw_rpc::{gen.name()}'
    gen.line(f'namespace {impl_ns} {{')
    gen.line()

    services = [
        cast(ProtoService, child)
        for child in proto_package
        if child.type() == ProtoNode.Type.SERVICE
    ]
    for service in services:
        _generate_service_and_client(gen, service)

    gen.line()
    gen.line(f'}}  // namespace {impl_ns}\n')
    if file_namespace:
        gen.line(f'}}  // namespace {file_namespace}')

    gen.line()
    gen.line('// Specialize MethodInfo for each RPC to provide metadata at '
             'compile time.')
    for service in services:
        _generate_info(gen, file_namespace, service)
185
186
def _generate_service_and_client(gen: CodeGenerator,
                                 service: ProtoService) -> None:
    """Emits the outer wrapper class for one service.

    The wrapper cannot be instantiated. Its public section nests the Service
    base class and the Client class. Its private section holds kServiceId,
    the hash of the service's proto path.
    """
    class_name = service.name()

    gen.line('// Wrapper class that namespaces server and client code for '
             'this RPC service.')
    gen.line(f'class {class_name} final {{')
    gen.line(' public:')

    with gen.indent():
        # A deleted default constructor keeps the wrapper from being created.
        gen.line(f'{class_name}() = delete;')
        gen.line()
        _generate_service(gen, service)
        gen.line()
        _generate_client(gen, service)

    gen.line(' private:')
    with gen.indent():
        gen.line(f'// Hash of "{service.proto_path()}".')
        gen.line(f'static constexpr uint32_t kServiceId = {get_id(service)};')

    gen.line('};')
211
212
def _check_method_name(method: ProtoServiceMethod) -> None:
    """Rejects method names already used by the generated wrapper class.

    The wrapper class nests classes named Service and Client, so RPC methods
    may not use those names.

    Raises:
      ValueError: if the method is named "Service" or "Client".
    """
    method_name = method.name()
    if method_name not in ('Service', 'Client'):
        return

    full_name = f'{method.service().proto_path()}.{method_name}'
    raise ValueError(f'"{full_name}" is not a valid method name! The name '
                     f'"{method_name}" is reserved for internal use by '
                     'pw_rpc.')
219
220
def _generate_client(gen: CodeGenerator, service: ProtoService) -> None:
    """Emits the Client class for a service and its static call functions.

    The Client class derives from ServiceClient and gets one member function
    per method from gen.client_member_function(). After the class, each
    method name is checked with _check_method_name(), and
    gen.client_static_function() emits the method's static function.
    """
    methods = service.methods()

    gen.line('// The Client is used to invoke RPCs for this service.')
    gen.line('class Client final : public '
             f'{RPC_NAMESPACE}::internal::ServiceClient {{')
    gen.line(' public:')
    with gen.indent():
        gen.line(f'constexpr Client({RPC_NAMESPACE}::Client& client, '
                 'uint32_t channel_id)')
        gen.line('    : ServiceClient(client, channel_id) {}')
        for method in methods:
            gen.line()
            gen.client_member_function(method)
    gen.line('};')
    gen.line()

    static_banner = (
        '// Static functions for invoking RPCs on a pw_rpc server. '
        'These functions are ',
        '// equivalent to instantiating a Client and calling the '
        'corresponding RPC.',
    )
    for comment in static_banner:
        gen.line(comment)

    for method in methods:
        _check_method_name(method)
        gen.client_static_function(method)
        gen.line()
247
248
def _generate_info(gen: CodeGenerator, namespace: str,
                   service: ProtoService) -> None:
    """Generates MethodInfo for each method.

    Each specialization contains:
      * kServiceId, kMethodId and kType;
      * a Function() template returning &ServiceImpl::<method>;
      * any additions from gen.method_info_specialization().

    Args:
      gen: generator used for all output.
      namespace: C++ namespace enclosing pw_rpc::<impl>; may be empty.
      service: the service whose methods are described.
    """
    method_info = f'struct {RPC_NAMESPACE.lstrip(":")}::internal::MethodInfo'
    service_id = get_id(service)
    wrapper = f'{namespace}::pw_rpc::{gen.name()}::{service.name()}'

    for method in service.methods():
        gen.line('template <>')
        gen.line(f'{method_info}<{wrapper}::{method.name()}> {{')

        with gen.indent():
            gen.line(f'static constexpr uint32_t kServiceId = {service_id};')
            method_id = get_id(method)
            gen.line(f'static constexpr uint32_t kMethodId = {method_id};')
            cc_type = method.type().cc_enum()
            gen.line(f'static constexpr {RPC_NAMESPACE}::MethodType kType = '
                     f'{cc_type};')
            gen.line()

            gen.line('template <typename ServiceImpl>')
            gen.line('static constexpr auto Function() {')
            with gen.indent():
                gen.line(f'return &ServiceImpl::{method.name()};')
            gen.line('}')

            gen.method_info_specialization(method)

        gen.line('};')
        gen.line()
280
281
def _generate_service(gen: CodeGenerator, service: ProtoService) -> None:
    """Generates a C++ class for an RPC service.

    Emits the ``Service`` class template that implementations derive from. It
    is parameterized on the implementing class and has three sections:
      * public: implementation aliases (gen.service_aliases()) and name();
      * protected: a constexpr constructor that passes kServiceId and
        kPwRpcMethods to the pw_rpc Service base class;
      * private: the kPwRpcMethods table and the kPwRpcMethodIds lookup
        table, with MethodLookup declared as a friend.
    """
    base_class = f'{RPC_NAMESPACE}::Service'
    gen.line('// The RPC service base class.')
    gen.line(
        '// Inherit from this to implement an RPC service for a pw_rpc server.'
    )
    gen.line('template <typename Implementation>')
    gen.line(f'class Service : public {base_class} {{')
    gen.line(' public:')

    with gen.indent():
        gen.service_aliases()

        gen.line()
        gen.line(f'static constexpr const char* name() '
                 f'{{ return "{service.name()}"; }}')

        gen.line()

    gen.line(' protected:')

    with gen.indent():
        gen.line('constexpr Service() : '
                 f'{base_class}(kServiceId, kPwRpcMethods) {{}}')

    gen.line()
    gen.line(' private:')

    with gen.indent():
        # Built from RPC_NAMESPACE like every other pw_rpc reference here,
        # instead of a hard-coded copy of the namespace.
        gen.line(f'friend class {RPC_NAMESPACE}::internal::MethodLookup;')
        gen.line()

        # Method table: gen.method_descriptor() is called for each method.
        gen.line('static constexpr std::array<'
                 f'{RPC_NAMESPACE}::internal::{gen.method_union_name()},'
                 f' {len(service.methods())}> kPwRpcMethods = {{')

        with gen.indent(4):
            for method in service.methods():
                gen.method_descriptor(method)

        gen.line('};\n')

        # IDs of the methods, in the same iteration order as the table above.
        _method_lookup_table(gen, service)

    gen.line('};')
331
332
def _method_lookup_table(gen: CodeGenerator, service: ProtoService) -> None:
    """Generates array of method IDs for looking up methods at compile time.

    IDs are emitted in service.methods() order, each followed by a comment
    naming the method it was hashed from.
    """
    methods = service.methods()
    gen.line(f'static constexpr std::array<uint32_t, {len(methods)}> '
             'kPwRpcMethodIds = {')
    with gen.indent(4):
        for method in methods:
            method_id = get_id(method)
            gen.line(f'{method_id},  // Hash of "{method.name()}"')
    gen.line('};')
343
344
class StubGenerator(abc.ABC):
    """Generates stub method implementations that can be copied-and-pasted.

    Each RPC type has two methods:
      * a *_signature method that returns the C++ signature of the method;
      * a *_stub method that writes the body of the method to an OutputFile
        and returns nothing.

    Callers in this file pass prefix='' for in-class declarations and
    prefix='<ServiceName>::' for out-of-class definitions. The streaming stub
    bodies have defaults that write TODO comments and cast the method's
    parameters to void, which silences unused-parameter warnings.
    """
    @abc.abstractmethod
    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns the signature of this unary method."""

    @abc.abstractmethod
    def unary_stub(self, method: ProtoServiceMethod,
                   output: OutputFile) -> None:
        """Writes the body of the stub for this unary method to output."""

    @abc.abstractmethod
    def server_streaming_signature(self, method: ProtoServiceMethod,
                                   prefix: str) -> str:
        """Returns the signature of this server streaming method."""

    def server_streaming_stub(  # pylint: disable=no-self-use
            self, unused_method: ProtoServiceMethod,
            output: OutputFile) -> None:
        """Writes the body of the stub for this server streaming method.

        The default body marks the request and writer parameters as used.
        """
        output.write_line(STUB_REQUEST_TODO)
        output.write_line('static_cast<void>(request);')
        output.write_line(STUB_WRITER_TODO)
        output.write_line('static_cast<void>(writer);')

    @abc.abstractmethod
    def client_streaming_signature(self, method: ProtoServiceMethod,
                                   prefix: str) -> str:
        """Returns the signature of this client streaming method."""

    def client_streaming_stub(  # pylint: disable=no-self-use
            self, unused_method: ProtoServiceMethod,
            output: OutputFile) -> None:
        """Writes the body of the stub for this client streaming method.

        The default body marks the reader parameter as used.
        """
        output.write_line(STUB_READER_TODO)
        output.write_line('static_cast<void>(reader);')

    @abc.abstractmethod
    def bidirectional_streaming_signature(self, method: ProtoServiceMethod,
                                          prefix: str) -> str:
        """Returns the signature of this bidirectional streaming method."""

    def bidirectional_streaming_stub(  # pylint: disable=no-self-use
            self, unused_method: ProtoServiceMethod,
            output: OutputFile) -> None:
        """Writes the body of the stub for this bidirectional streaming method.

        The default body marks the reader_writer parameter as used.
        """
        output.write_line(STUB_READER_WRITER_TODO)
        output.write_line('static_cast<void>(reader_writer);')
393
394
def _select_stub_methods(gen: StubGenerator, method: ProtoServiceMethod):
    """Picks the signature and stub writers matching the method's RPC type.

    Returns:
      A (signature, stub) pair of bound methods of ``gen``.

    Raises:
      NotImplementedError: for an unrecognized method type.
    """
    kind = method.type()
    handlers = {
        ProtoServiceMethod.Type.UNARY:
            (gen.unary_signature, gen.unary_stub),
        ProtoServiceMethod.Type.SERVER_STREAMING:
            (gen.server_streaming_signature, gen.server_streaming_stub),
        ProtoServiceMethod.Type.CLIENT_STREAMING:
            (gen.client_streaming_signature, gen.client_streaming_stub),
        ProtoServiceMethod.Type.BIDIRECTIONAL_STREAMING:
            (gen.bidirectional_streaming_signature,
             gen.bidirectional_streaming_stub),
    }
    if kind not in handlers:
        raise NotImplementedError(f'Unrecognized method type {kind}')
    return handlers[kind]
410
411
# Banner that package_stubs() writes immediately after the
# #ifdef _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS guard: a C block comment of
# ASCII art followed by // comments explaining the purpose of the stubs.
# It is a raw string so the backslashes in the art are kept literally.
_STUBS_COMMENT = r'''
/*
    ____                __                          __        __  _
   /  _/___ ___  ____  / /__  ____ ___  ___  ____  / /_____ _/ /_(_)___  ____
   / // __ `__ \/ __ \/ / _ \/ __ `__ \/ _ \/ __ \/ __/ __ `/ __/ / __ \/ __ \
 _/ // / / / / / /_/ / /  __/ / / / / /  __/ / / / /_/ /_/ / /_/ / /_/ / / / /
/___/_/ /_/ /_/ .___/_/\___/_/ /_/ /_/\___/_/ /_/\__/\__,_/\__/_/\____/_/ /_/
             /_/
   _____ __        __         __
  / ___// /___  __/ /_  _____/ /
  \__ \/ __/ / / / __ \/ ___/ /
 ___/ / /_/ /_/ / /_/ (__  )_/
/____/\__/\__,_/_.___/____(_)

*/
// This section provides stub implementations of the RPC services in this file.
// The code below may be referenced or copied to serve as a starting point for
// your RPC service implementations.
'''
431
432
def package_stubs(proto_package: ProtoNode, gen: CodeGenerator,
                  stub_generator: StubGenerator) -> None:
    """Generates the RPC stubs for a package.

    The stub section is guarded by #ifdef
    _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS. It contains, in order:
      * the _STUBS_COMMENT banner;
      * an #include of the generated output file;
      * an implementation class declaration for every service;
      * the stub method definitions.

    Declarations and definitions are each wrapped in the package's C++
    namespace when the package has one.
    """
    file_ns = proto_package.cpp_namespace()
    # Only wrap the stubs in a namespace when the package declares one.
    has_ns = bool(file_ns)
    if has_ns and file_ns.startswith('::'):
        # A namespace definition cannot begin with '::'.
        file_ns = file_ns[2:]

    def start_ns() -> None:
        """Opens the package namespace; no-op when there is none."""
        if has_ns:
            gen.line(f'namespace {file_ns} {{\n')

    def finish_ns() -> None:
        """Closes the package namespace; no-op when there is none."""
        if has_ns:
            gen.line(f'}}  // namespace {file_ns}\n')

    services = [
        cast(ProtoService, node) for node in proto_package
        if node.type() == ProtoNode.Type.SERVICE
    ]

    gen.line('#ifdef _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS')
    gen.line(_STUBS_COMMENT)

    gen.line(f'#include "{gen.output.name()}"\n')

    start_ns()

    for node in services:
        _service_declaration_stub(node, gen, stub_generator)

    gen.line()

    finish_ns()

    start_ns()

    for node in services:
        _service_definition_stub(node, gen, stub_generator)
        gen.line()

    finish_ns()

    gen.line('#endif  // _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS')
474
475
def _service_declaration_stub(service: ProtoService, gen: CodeGenerator,
                              stub_generator: StubGenerator) -> None:
    """Emits the declaration of a stub implementation class for a service.

    The class derives from the generated Service<> base, with itself as the
    template argument. It declares one member function per RPC method, with
    blank lines between declarations.
    """
    svc_name = service.name()
    gen.line(f'// Implementation class for {service.proto_path()}.')
    gen.line(f'class {svc_name} : public pw_rpc::{gen.name()}::'
             f'{svc_name}::Service<{svc_name}> {{')
    gen.line(' public:')

    with gen.indent():
        for index, method in enumerate(service.methods()):
            if index > 0:
                gen.line()
            signature, _ = _select_stub_methods(stub_generator, method)
            gen.line(f'{signature(method, "")};')

    gen.line('};\n')
498
499
def _service_definition_stub(service: ProtoService, gen: CodeGenerator,
                             stub_generator: StubGenerator) -> None:
    """Emits out-of-class stub definitions for every method of a service.

    Each definition uses the method's signature qualified with
    '<ServiceName>::', and its body is written by the matching stub writer.
    Definitions are separated by blank lines.
    """
    gen.line(f'// Method definitions for {service.proto_path()}.')
    qualifier = f'{service.name()}::'

    for index, method in enumerate(service.methods()):
        if index:
            gen.line()

        signature, write_body = _select_stub_methods(stub_generator, method)
        gen.line(f'{signature(method, qualifier)} {{')
        with gen.indent():
            write_body(method, gen.output)
        gen.line('}')
518