• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Common RPC codegen utilities."""
15
16import abc
17from datetime import datetime
18import os
19from typing import cast, Any, Iterable, Union
20
21from pw_protobuf.output_file import OutputFile
22from pw_protobuf.proto_tree import ProtoNode, ProtoService, ProtoServiceMethod
23from pw_rpc import ids
24
# Identity of this plugin; both values are written into the header comment
# of every generated file.
PLUGIN_NAME = 'pw_rpc_codegen'
PLUGIN_VERSION = '0.3.0'

# Fully qualified C++ namespace of the pw_rpc library.
RPC_NAMESPACE = '::pw::rpc'

# Placeholder comments emitted into the bodies of generated service stubs.
# todo-check: disable
STUB_REQUEST_TODO = (
    '// TODO: Read the request as appropriate '
    'for your application'
)
STUB_RESPONSE_TODO = (
    '// TODO: Fill in the response as appropriate '
    'for your application'
)
STUB_WRITER_TODO = (
    '// TODO: Send responses with the writer '
    'as appropriate for your application'
)
STUB_READER_TODO = (
    '// TODO: Set the client stream callback and send a response '
    'as appropriate for your application'
)
STUB_READER_WRITER_TODO = (
    '// TODO: Set the client stream callback and send responses '
    'as appropriate for your application'
)
# todo-check: enable
50
51
def get_id(item: Union[ProtoService, ProtoServiceMethod]) -> str:
    """Returns the pw_rpc ID of a service or method as a C++ hex literal.

    A service is identified by hashing its full proto path; a method is
    identified by hashing only its own name. The hash is formatted as
    '0x' followed by at least eight zero-padded lowercase hex digits.
    """
    if isinstance(item, ProtoService):
        hashed_name = item.proto_path()
    else:
        hashed_name = item.name()

    id_value = ids.calculate(hashed_name)
    return f'0x{id_value:08x}'
55
56
def client_call_type(method: ProtoServiceMethod, prefix: str) -> str:
    """Returns the client call class used for an RPC method.

    The class is chosen from the method type:

      UNARY                   -> UnaryReceiver
      SERVER_STREAMING        -> ClientReader
      CLIENT_STREAMING        -> ClientWriter
      BIDIRECTIONAL_STREAMING -> ClientReaderWriter

    Args:
      method: The RPC method to select a call class for.
      prefix: Text inserted between the pw_rpc namespace and the class name;
          pass '' for the unprefixed class.

    Returns:
      The qualified class name, e.g. '::pw::rpc::UnaryReceiver' for a unary
      method with an empty prefix.

    Raises:
      NotImplementedError: If the method type is not one of the above.
    """
    method_type = method.type()
    if method_type is ProtoServiceMethod.Type.UNARY:
        call_class = 'UnaryReceiver'
    elif method_type is ProtoServiceMethod.Type.SERVER_STREAMING:
        call_class = 'ClientReader'
    elif method_type is ProtoServiceMethod.Type.CLIENT_STREAMING:
        call_class = 'ClientWriter'
    elif method_type is ProtoServiceMethod.Type.BIDIRECTIONAL_STREAMING:
        call_class = 'ClientReaderWriter'
    else:
        raise NotImplementedError(f'Unknown {method_type}')

    return f'{RPC_NAMESPACE}::{prefix}{call_class}'
71
72
class CodeGenerator(abc.ABC):
    """Generates RPC code for services and clients.

    Subclasses implement the abstract hooks for one pw_rpc implementation.
    The module-level generate_package() calls those hooks while writing the
    generated file into self.output.
    """

    def __init__(self, output_filename: str) -> None:
        # All generated text accumulates in this OutputFile.
        self.output = OutputFile(output_filename)

    def indent(self, amount: int = OutputFile.INDENT_WIDTH) -> Any:
        """Indents the output. Use in a with block."""
        return self.output.indent(amount)

    def line(self, value: str = '') -> None:
        """Writes a line to the output."""
        self.output.write_line(value)

    def indented_list(self, *args: str, end: str = ',') -> None:
        """Outputs each arg on its own line inside an indent(4) block.

        Every arg except the last is followed by ','; the last arg is
        followed by end. Writes nothing when no args are given.
        """
        # Guard: with no args there is no last element to terminate, and
        # args[-1] would raise IndexError.
        if not args:
            return

        with self.indent(4):
            for arg in args[:-1]:
                self.line(arg + ',')

            self.line(args[-1] + end)

    @abc.abstractmethod
    def name(self) -> str:
        """Name of the pw_rpc implementation.

        Used as the generated sub-namespace: pw_rpc::<name>.
        """

    @abc.abstractmethod
    def method_union_name(self) -> str:
        """Name of the MethodUnion class to use.

        It becomes the element type (qualified with ::pw::rpc::internal) of
        each service's kPwRpcMethods array.
        """

    @abc.abstractmethod
    def includes(self, proto_file_name: str) -> Iterable[str]:
        """Yields #include lines.

        Args:
          proto_file_name: Name of the .proto file being compiled.

        The yielded lines are merged with the common pw_rpc includes and
        written in sorted order.
        """

    @abc.abstractmethod
    def service_aliases(self) -> None:
        """Generates reader/writer aliases.

        Written at the top of the public section of each Service template.
        """

    @abc.abstractmethod
    def method_descriptor(self, method: ProtoServiceMethod) -> None:
        """Generates code for a service method.

        Called once per method while writing the initializer of the
        Service's kPwRpcMethods array.
        """

    @abc.abstractmethod
    def client_member_function(self, method: ProtoServiceMethod) -> None:
        """Generates the client code for the Client member functions.

        Called once per method inside the generated Client class.
        """

    @abc.abstractmethod
    def client_static_function(self, method: ProtoServiceMethod) -> None:
        """Generates method static functions that instantiate a Client.

        Called once per method after the Client class has been closed.
        """

    def method_info_specialization(self, method: ProtoServiceMethod) -> None:
        """Generates impl-specific additions to the MethodInfo specialization.

        May be empty if the generator has nothing to add to the MethodInfo.
        Called after the common MethodInfo members have been written.
        """

    def private_additions(self, service: ProtoService) -> None:
        """Additions to the private section of the outer generated class.

        NOTE(review): no function in this module calls this hook; confirm
        whether it is invoked anywhere else before relying on it.
        """
131
132
def generate_package(
    file_descriptor_proto, proto_package: ProtoNode, gen: CodeGenerator
) -> None:
    """Generates service and client code for a package.

    The output contains, in order:
      * a header comment naming the generator;
      * the #include lines: standard headers first, then the pw_rpc headers
        merged and sorted with the ones from gen.includes();
      * the package namespace, containing a pw_rpc::<gen.name()> namespace
        with one wrapper class per service;
      * the MethodInfo specializations, written after every namespace has
        been closed.

    Args:
      file_descriptor_proto: Descriptor of the .proto file. Only its name is
          used here, and it is passed to gen.includes().
      proto_package: Package node to generate code for. Must be a PACKAGE.
      gen: Implementation-specific generator that receives the output.
    """
    assert proto_package.type() == ProtoNode.Type.PACKAGE

    header_name = os.path.basename(gen.output.name())
    gen.line(
        f'// {header_name} automatically '
        f'generated by {PLUGIN_NAME} {PLUGIN_VERSION}'
    )
    gen.line(f'// on {datetime.now().isoformat()}')
    gen.line('// clang-format off')
    gen.line('#pragma once\n')

    gen.line('#include <array>')
    gen.line('#include <cstdint>')
    gen.line('#include <type_traits>\n')

    pw_rpc_headers = (
        'pw_rpc/internal/method_info.h',
        'pw_rpc/internal/method_lookup.h',
        'pw_rpc/internal/service_client.h',
        'pw_rpc/method_type.h',
        'pw_rpc/service.h',
        'pw_rpc/service_id.h',
    )
    include_lines = [f'#include "{header}"' for header in pw_rpc_headers]
    include_lines.extend(gen.includes(file_descriptor_proto.name))
    for include_line in sorted(include_lines):
        gen.line(include_line)
    gen.line()

    # The package namespace is written without its leading '::'.
    file_namespace = proto_package.cpp_namespace(codegen_subnamespace=None)
    if file_namespace:
        if file_namespace.startswith('::'):
            file_namespace = file_namespace[2:]
        gen.line(f'namespace {file_namespace} {{')
    else:
        file_namespace = ''

    impl_namespace = f'pw_rpc::{gen.name()}'
    gen.line(f'namespace {impl_namespace} {{')
    gen.line()

    services = [
        cast(ProtoService, child)
        for child in proto_package
        if child.type() == ProtoNode.Type.SERVICE
    ]
    for service in services:
        _generate_service_and_client(gen, service)

    gen.line()
    gen.line(f'}}  // namespace {impl_namespace}\n')
    if file_namespace:
        gen.line(f'}}  // namespace {file_namespace}')
    gen.line()

    # MethodInfo specializations are written after all namespaces close.
    gen.line(
        '// Specialize MethodInfo for each RPC to provide metadata at '
        'compile time.'
    )
    for service in services:
        _generate_info(gen, file_namespace, service)
200
201
def _generate_service_and_client(
    gen: CodeGenerator, service: ProtoService
) -> None:
    """Emits the outer C++ wrapper class for one RPC service.

    The wrapper has a deleted default constructor and holds:
      * a service_id() accessor;
      * the Service base class template;
      * the Client class and its static helpers.
    The raw service ID hash (kServiceId) is kept private.
    """
    service_name = service.name()

    gen.line(
        '// Wrapper class that namespaces server and client code for '
        'this RPC service.'
    )
    gen.line(f'class {service_name} final {{')
    gen.line(' public:')

    with gen.indent():
        # A deleted constructor makes the wrapper non-instantiable.
        gen.line(f'{service_name}() = delete;')
        gen.line()

        gen.line('static constexpr ::pw::rpc::ServiceId service_id() {')
        with gen.indent():
            gen.line('return ::pw::rpc::internal::WrapServiceId(kServiceId);')
        gen.line('}')
        gen.line()

        _generate_service(gen, service)
        gen.line()
        _generate_client(gen, service)

    gen.line(' private:')
    with gen.indent():
        gen.line(f'// Hash of "{service.proto_path()}".')
        gen.line(f'static constexpr uint32_t kServiceId = {get_id(service)};')

    gen.line('};')
235
236
def _check_method_name(method: ProtoServiceMethod) -> None:
    """Raises ValueError if a method name would break the generated C++.

    Two kinds of names are rejected:
      * a name equal to the enclosing service's name, because the generated
        member would be indistinguishable from a constructor; and
      * 'Service', 'ServiceInfo' and 'Client', which the generated code
        already uses for its own nested declarations.
    """
    method_name = method.name()
    service = method.service()

    if method_name == service.name():
        raise ValueError(
            'Attempted to compile `pw_rpc` for proto with method '
            f'`{method_name}` inside a service of the same name. '
            '`pw_rpc` does not yet support methods with the same name as their '
            'enclosing service.'
        )

    reserved_names = ('Service', 'ServiceInfo', 'Client')
    if method_name in reserved_names:
        raise ValueError(
            f'"{service.proto_path()}.{method_name}" is not a '
            f'valid method name! The name "{method_name}" is reserved '
            'for internal use by pw_rpc.'
        )
254
255
def _generate_client(gen: CodeGenerator, service: ProtoService) -> None:
    """Generates the Client class and the static RPC helper functions.

    The Client class derives from internal::ServiceClient. It gets one
    member function per method from gen.client_member_function(). After the
    class, gen.client_static_function() emits the static equivalents. Each
    method name is validated with _check_method_name() just before its
    static function is generated.
    """
    service_client = f'{RPC_NAMESPACE}::internal::ServiceClient'

    gen.line('// The Client is used to invoke RPCs for this service.')
    gen.line(f'class Client final : public {service_client} {{')
    gen.line(' public:')

    with gen.indent():
        gen.line(
            f'constexpr Client({RPC_NAMESPACE}::Client& client, '
            'uint32_t channel_id)'
        )
        gen.line('    : ServiceClient(client, channel_id) {}')
        gen.line()
        gen.line(f'using ServiceInfo = {service.name()};')

        for rpc in service.methods():
            gen.line()
            gen.client_member_function(rpc)

    gen.line('};')
    gen.line()

    gen.line(
        '// Static functions for invoking RPCs on a pw_rpc server. '
        'These functions are '
    )
    gen.line(
        '// equivalent to instantiating a Client and calling the '
        'corresponding RPC.'
    )
    for rpc in service.methods():
        _check_method_name(rpc)
        gen.client_static_function(rpc)
        gen.line()
292
293
def _generate_info(
    gen: CodeGenerator, namespace: str, service: ProtoService
) -> None:
    """Generates MethodInfo for each method.

    Each RPC gets a MethodInfo specialization keyed on its generated method.
    The specialization exposes:
      * the service and method IDs (kServiceId, kMethodId);
      * the method type (kType);
      * Function<ServiceImpl>(), which returns a pointer to the
        implementation's member function;
      * the GeneratedClient type.
    gen.method_info_specialization() may append more members.

    Args:
      gen: Code generator to write with.
      namespace: The package's C++ namespace without a leading '::' (as
          generate_package passes it), or '' for the global namespace.
      service: Service whose methods are described.
    """
    service_id = get_id(service)
    info_struct = f'struct {RPC_NAMESPACE.lstrip(":")}::internal::MethodInfo'
    service_path = f'pw_rpc::{gen.name()}::{service.name()}'
    client_prefix = f'::{namespace}' if namespace else ''

    for method in service.methods():
        method_name = method.name()

        gen.line('template <>')
        gen.line(
            f'{info_struct}<{namespace}::{service_path}::{method_name}> {{'
        )

        with gen.indent():
            gen.line(f'static constexpr uint32_t kServiceId = {service_id};')
            gen.line(f'static constexpr uint32_t kMethodId = {get_id(method)};')
            gen.line(
                f'static constexpr {RPC_NAMESPACE}::MethodType kType = '
                f'{method.type().cc_enum()};'
            )
            gen.line()

            gen.line('template <typename ServiceImpl>')
            gen.line('static constexpr auto Function() {')
            with gen.indent():
                gen.line(f'return &ServiceImpl::{method_name};')
            gen.line('}')

            gen.line(
                f'using GeneratedClient = {client_prefix}::{service_path}'
                '::Client;'
            )

            gen.method_info_specialization(method)

        gen.line('};')
        gen.line()
337
338
def _generate_service(gen: CodeGenerator, service: ProtoService) -> None:
    """Generates a C++ class for an RPC service.

    The generated Service<Implementation> class template is the base class
    that RPC implementations inherit from. Its private section holds two
    tables:
      * kPwRpcMethods, filled by gen.method_descriptor() once per method;
      * kPwRpcMethodIds, the table of method IDs.
    ::pw::rpc::internal::MethodLookup is declared a friend, so it can access
    both private tables.
    """
    base_class = f'{RPC_NAMESPACE}::Service'
    service_name = service.name()

    gen.line('// The RPC service base class.')
    gen.line(
        '// Inherit from this to implement an RPC service for a pw_rpc server.'
    )
    gen.line('template <typename Implementation>')
    gen.line(f'class Service : public {base_class} {{')

    # Public: implementation-specific aliases, name(), and ServiceInfo.
    gen.line(' public:')
    with gen.indent():
        gen.service_aliases()
        gen.line()
        gen.line(
            'static constexpr const char* name() '
            f'{{ return "{service_name}"; }}'
        )
        gen.line()
        gen.line(f'using ServiceInfo = {service_name};')
        gen.line()

    # Protected constructor: only derived implementations can construct it.
    gen.line(' protected:')
    with gen.indent():
        gen.line(
            f'constexpr Service() : {base_class}(kServiceId, kPwRpcMethods) {{}}'
        )
    gen.line()

    # Private: the method table and the method ID table.
    gen.line(' private:')
    with gen.indent():
        gen.line('friend class ::pw::rpc::internal::MethodLookup;')
        gen.line()

        method_union = f'{RPC_NAMESPACE}::internal::{gen.method_union_name()}'
        method_count = len(service.methods())
        gen.line(
            f'static constexpr std::array<{method_union}, {method_count}> '
            'kPwRpcMethods = {'
        )
        with gen.indent(4):
            for method in service.methods():
                gen.method_descriptor(method)
        gen.line('};\n')

        _method_lookup_table(gen, service)

    gen.line('};')
395
396
def _method_lookup_table(gen: CodeGenerator, service: ProtoService) -> None:
    """Generates array of method IDs for looking up methods at compile time.

    Emits kPwRpcMethodIds with one get_id() value per method, in
    service.methods() order, each annotated with the hashed method name.
    """
    methods = service.methods()
    gen.line(
        f'static constexpr std::array<uint32_t, {len(methods)}> '
        'kPwRpcMethodIds = {'
    )

    with gen.indent(4):
        for method in methods:
            method_id = get_id(method)
            gen.line(f'{method_id},  // Hash of "{method.name()}"')

    gen.line('};')
409
410
class StubGenerator(abc.ABC):
    """Generates stub method implementations that can be copied-and-pasted.

    Each RPC method type has a pair of hooks:

      * <type>_signature(method, prefix) returns the C++ signature of the
        method as a string.
      * <type>_stub(method, output) writes the body of the stub
        implementation to output and returns nothing.

    The streaming <type>_stub hooks have default implementations. Each
    writes a placeholder comment and passes the call parameters through
    static_cast<void>(...) to mark them as used. All other hooks must be
    provided by a subclass.

    Callers in this module pass prefix='' when declaring methods inside the
    stub class, and prefix='<ServiceName>::' when defining them out of
    class.
    """

    @abc.abstractmethod
    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns the signature of this unary method."""

    @abc.abstractmethod
    def unary_stub(
        self, method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes the body of the stub for this unary method to output."""

    @abc.abstractmethod
    def server_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the signature of this server streaming method."""

    def server_streaming_stub(  # pylint: disable=no-self-use
        self, unused_method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes the default server streaming stub body to output.

        The body holds placeholder comments for reading the request and
        using the writer, and marks request and writer as used.
        """
        output.write_line(STUB_REQUEST_TODO)
        output.write_line('static_cast<void>(request);')
        output.write_line(STUB_WRITER_TODO)
        output.write_line('static_cast<void>(writer);')

    @abc.abstractmethod
    def client_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the signature of this client streaming method."""

    def client_streaming_stub(  # pylint: disable=no-self-use
        self, unused_method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes the default client streaming stub body to output.

        The body holds a placeholder comment for using the reader and marks
        reader as used.
        """
        output.write_line(STUB_READER_TODO)
        output.write_line('static_cast<void>(reader);')

    @abc.abstractmethod
    def bidirectional_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the signature of this bidirectional streaming method."""

    def bidirectional_streaming_stub(  # pylint: disable=no-self-use
        self, unused_method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes the default bidirectional streaming stub body to output.

        The body holds a placeholder comment for using the reader/writer
        and marks reader_writer as used.
        """
        output.write_line(STUB_READER_WRITER_TODO)
        output.write_line('static_cast<void>(reader_writer);')
464
465
def _select_stub_methods(gen: StubGenerator, method: ProtoServiceMethod):
    """Returns the (signature, stub) hooks of gen for the method's type.

    Both elements are bound methods of gen. The first returns the C++
    signature; the second writes the stub body.

    Raises:
      NotImplementedError: if method.type() is not a known method type.
    """
    method_type = method.type()
    hooks_by_type = (
        (
            ProtoServiceMethod.Type.UNARY,
            (gen.unary_signature, gen.unary_stub),
        ),
        (
            ProtoServiceMethod.Type.SERVER_STREAMING,
            (gen.server_streaming_signature, gen.server_streaming_stub),
        ),
        (
            ProtoServiceMethod.Type.CLIENT_STREAMING,
            (gen.client_streaming_signature, gen.client_streaming_stub),
        ),
        (
            ProtoServiceMethod.Type.BIDIRECTIONAL_STREAMING,
            (
                gen.bidirectional_streaming_signature,
                gen.bidirectional_streaming_stub,
            ),
        ),
    )

    for candidate_type, hooks in hooks_by_type:
        if method_type is candidate_type:
            return hooks

    raise NotImplementedError(f'Unrecognized method type {method_type}')
483
484
# Banner and explanation written at the start of the generated stubs section
# (package_stubs emits it right after the #ifdef). It is a raw string so the
# backslashes in the ASCII art are kept literally.
_STUBS_COMMENT = r'''
/*
    ____                __                          __        __  _
   /  _/___ ___  ____  / /__  ____ ___  ___  ____  / /_____ _/ /_(_)___  ____
   / // __ `__ \/ __ \/ / _ \/ __ `__ \/ _ \/ __ \/ __/ __ `/ __/ / __ \/ __ \
 _/ // / / / / / /_/ / /  __/ / / / / /  __/ / / / /_/ /_/ / /_/ / /_/ / / / /
/___/_/ /_/ /_/ .___/_/\___/_/ /_/ /_/\___/_/ /_/\__/\__,_/\__/_/\____/_/ /_/
             /_/
   _____ __        __         __
  / ___// /___  __/ /_  _____/ /
  \__ \/ __/ / / / __ \/ ___/ /
 ___/ / /_/ /_/ / /_/ (__  )_/
/____/\__/\__,_/_.___/____(_)

*/
// This section provides stub implementations of the RPC services in this file.
// The code below may be referenced or copied to serve as a starting point for
// your RPC service implementations.
'''
504
505
def package_stubs(
    proto_package: ProtoNode, gen: CodeGenerator, stub_generator: StubGenerator
) -> None:
    """Generates the RPC stubs for a package.

    Everything is wrapped in #ifdef _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS.
    The section starts with the _STUBS_COMMENT banner and an #include of
    gen's own output file. Next come the stub class declarations for all
    services, then their out-of-class method definitions. If the package
    has a C++ namespace, each of the two groups is wrapped in its own copy
    of it.
    """
    namespace_lines = None
    cpp_namespace = proto_package.cpp_namespace(codegen_subnamespace=None)
    if cpp_namespace:
        if cpp_namespace.startswith('::'):
            cpp_namespace = cpp_namespace[2:]
        namespace_lines = (
            f'namespace {cpp_namespace} {{\n',
            f'}}  // namespace {cpp_namespace}\n',
        )

    def open_namespace() -> None:
        if namespace_lines is not None:
            gen.line(namespace_lines[0])

    def close_namespace() -> None:
        if namespace_lines is not None:
            gen.line(namespace_lines[1])

    services = [
        cast(ProtoService, child)
        for child in proto_package
        if child.type() == ProtoNode.Type.SERVICE
    ]

    gen.line('#ifdef _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS')
    gen.line(_STUBS_COMMENT)
    gen.line(f'#include "{gen.output.name()}"\n')

    # Stub class declarations for every service.
    open_namespace()
    for service in services:
        _service_declaration_stub(service, gen, stub_generator)
    gen.line()
    close_namespace()

    # Out-of-class method definitions for every service.
    open_namespace()
    for service in services:
        _service_definition_stub(service, gen, stub_generator)
        gen.line()
    close_namespace()

    gen.line('#endif  // _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS')
549
550
def _service_declaration_stub(
    service: ProtoService, gen: CodeGenerator, stub_generator: StubGenerator
) -> None:
    """Writes the stub implementation class declaration for a service.

    The class derives from the generated Service<> template and declares
    one member function per RPC method, separated by blank lines.
    """
    service_name = service.name()

    gen.line(f'// Implementation class for {service.proto_path()}.')
    gen.line(
        f'class {service_name} : public pw_rpc::{gen.name()}::'
        f'{service_name}::Service<{service_name}> {{'
    )
    gen.line(' public:')

    with gen.indent():
        for index, method in enumerate(service.methods()):
            if index:
                gen.line()

            signature, _ = _select_stub_methods(stub_generator, method)
            declaration = signature(method, '')
            gen.line(declaration + ';')

    gen.line('};\n')
576
577
def _service_definition_stub(
    service: ProtoService, gen: CodeGenerator, stub_generator: StubGenerator
) -> None:
    """Writes out-of-class stub definitions for each method of a service.

    Each definition uses the method signature qualified with
    '<ServiceName>::' and a body written by the stub generator. Definitions
    are separated by blank lines.
    """
    gen.line(f'// Method definitions for {service.proto_path()}.')
    qualifier = f'{service.name()}::'

    for index, method in enumerate(service.methods()):
        if index > 0:
            gen.line()

        signature, write_body = _select_stub_methods(stub_generator, method)
        gen.line(f'{signature(method, qualifier)} {{')
        with gen.indent():
            write_body(method, gen.output)
        gen.line('}')
597