• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""This module generates the code for nanopb-based pw_rpc services."""
15
16import os
17from typing import Iterable, Iterator
18
19from pw_protobuf.output_file import OutputFile
20from pw_protobuf.proto_tree import ProtoNode, ProtoService, ProtoServiceMethod
21from pw_protobuf.proto_tree import build_node_tree
22from pw_rpc import codegen
23from pw_rpc.codegen import RPC_NAMESPACE
24import pw_rpc.ids
25
# Extension used for generated C++ headers; the RPC header produced by this
# module is '<stem>.rpc' + PROTO_H_EXTENSION.
PROTO_H_EXTENSION: str = '.pb.h'
# Presumably the extension for generated C++ sources; not referenced by the
# code visible in this module.
PROTO_CC_EXTENSION: str = '.pb.cc'
# Extension of the header nanopb emits for a .proto file's messages and enums.
NANOPB_H_EXTENSION: str = '.pb.h'
29
30
def _proto_filename_to_nanopb_header(proto_file: str) -> str:
    """Maps a .proto path to the path of its nanopb-generated header.

    Only the final extension is replaced, e.g. 'foo/bar.proto' becomes
    'foo/bar' + NANOPB_H_EXTENSION.
    """
    stem, _ = os.path.splitext(proto_file)
    return f'{stem}{NANOPB_H_EXTENSION}'
34
35
def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Maps a .proto path to the RPC header this generator writes for it.

    The '.rpc' infix keeps the name distinct from the nanopb message header,
    which uses the same extension: 'foo/bar.proto' becomes
    'foo/bar.rpc' + PROTO_H_EXTENSION.
    """
    stem = os.path.splitext(proto_file)[0]
    return stem + '.rpc' + PROTO_H_EXTENSION
40
41
def _generate_method_descriptor(method: ProtoServiceMethod, method_id: int,
                                output: OutputFile) -> None:
    """Writes the nanopb method descriptor expression for one RPC method.

    The emitted expression is a call to
    internal::GetNanopbOrRawMethodFor<&Implementation::Name, type, Req, Res>
    with the method ID and the nanopb field descriptors of the request and
    response structs. It ends in '),', presumably because codegen.service_class
    places it in a comma-separated list (not visible here).

    Args:
      method: The RPC method to describe.
      method_id: Numeric method ID, written as an 8-digit hex literal.
      output: Destination for the generated C++.
    """
    request_struct = method.request_type().nanopb_name()
    response_struct = method.response_type().nanopb_name()
    template_args = ', '.join((
        f'&Implementation::{method.name()}',
        f'{method.type().cc_enum()}',
        f'{request_struct}',
        f'{response_struct}',
    ))

    output.write_line(f'{RPC_NAMESPACE}::internal::GetNanopbOrRawMethodFor<'
                      f'{template_args}>(')

    call_arguments = (
        f'0x{method_id:08x},  // Hash of "{method.name()}"',
        f'{request_struct}_fields,',
        f'{response_struct}_fields),',
    )
    with output.indent(4):
        for argument in call_arguments:
            output.write_line(argument)
59
60
def _generate_server_writer_alias(output: OutputFile) -> None:
    """Writes a `ServerWriter<T>` alias template for ::pw::rpc::ServerWriter."""
    alias_lines = (
        'template <typename T>',
        'using ServerWriter = ::pw::rpc::ServerWriter<T>;',
    )
    for line in alias_lines:
        output.write_line(line)
64
65
def _generate_code_for_service(service: ProtoService, root: ProtoNode,
                               output: OutputFile) -> None:
    """Writes the C++ service class for a nanopb RPC service.

    All of the class layout is delegated to codegen.service_class. This module
    supplies the ServerWriter alias writer, the method union type name, and
    the per-method descriptor writer.
    """
    method_union_type = 'NanopbMethodUnion'
    codegen.service_class(
        service,
        root,
        output,
        _generate_server_writer_alias,
        method_union_type,
        _generate_method_descriptor,
    )
71
72
def _generate_code_for_client_method(method: ProtoServiceMethod,
                                     output: OutputFile) -> None:
    """Writes a static client function that starts one RPC call.

    The generated function builds a NanopbClientCall on the given channel,
    sends the request, and returns the call object.

    Args:
      method: The RPC method to generate a client function for.
      output: Destination for the generated C++.

    Raises:
      NotImplementedError: The method is neither unary nor server streaming.
    """
    name = method.name()
    request_struct = method.request_type().nanopb_name()
    response_struct = method.response_type().nanopb_name()
    method_id = pw_rpc.ids.calculate(name)

    # Response handler template for each supported method type.
    handler_templates = {
        ProtoServiceMethod.Type.UNARY: 'UnaryResponseHandler',
        ProtoServiceMethod.Type.SERVER_STREAMING:
        'ServerStreamingResponseHandler',
    }
    handler_template = handler_templates.get(method.type())
    if handler_template is None:
        raise NotImplementedError(
            'Only unary and server streaming RPCs are currently supported')
    callback = f'{RPC_NAMESPACE}::{handler_template}<{response_struct}>'

    output.write_line()
    output.write_line(f'static NanopbClientCall<\n    {callback}>')
    output.write_line(f'{name}({RPC_NAMESPACE}::Channel& channel,')
    # Align the remaining parameters with the first one, just past "name(".
    with output.indent(len(name) + 1):
        output.write_line(f'const {request_struct}& request,')
        output.write_line(f'{callback}& callback) {{')

    constructor_arguments = (
        'kServiceId,',
        f'0x{method_id:08x},  // Hash of "{name}"',
        'callback,',
        f'{request_struct}_fields,',
        f'{response_struct}_fields);',
    )
    with output.indent():
        output.write_line(f'NanopbClientCall<{callback}>')
        output.write_line('    call(&channel,')
        # 9 == len('    call('): lines the arguments up under "&channel".
        with output.indent(9):
            for argument in constructor_arguments:
                output.write_line(argument)
        output.write_line('call.SendRequest(&request);')
        output.write_line('return call;')

    output.write_line('}')
110
111
def _generate_code_for_client(service: ProtoService, root: ProtoNode,
                              output: OutputFile) -> None:
    """Writes the nanopb client class for an RPC service.

    The class is placed in a `nanopb` namespace and has a deleted default
    constructor. It has one static call function per service method and a
    private kServiceId constant, which is the hash of the service's proto path.
    """
    output.write_line('namespace nanopb {')

    client_class = f'{service.cpp_namespace(root)}Client'
    output.write_line(f'\nclass {client_class} {{')
    output.write_line(' public:')

    with output.indent():
        output.write_line('template <typename T>')
        output.write_line(
            f'using NanopbClientCall = {RPC_NAMESPACE}::NanopbClientCall<T>;')
        output.write_line('')
        output.write_line(f'{client_class}() = delete;')

        for rpc in service.methods():
            _generate_code_for_client_method(rpc, output)

    proto_path = service.proto_path()
    service_id = pw_rpc.ids.calculate(proto_path)

    output.write_line('\n private:')
    with output.indent():
        output.write_line(f'// Hash of "{proto_path}".')
        output.write_line(
            f'static constexpr uint32_t kServiceId = 0x{service_id:08x};')

    output.write_line('};')
    output.write_line('\n}  // namespace nanopb\n')
145
146
def includes(proto_file, unused_package: ProtoNode) -> Iterator[str]:
    """Yields the #include directives needed by the generated RPC header.

    Args:
      proto_file: The .proto file being processed. Only its `name` attribute
          is read.
      unused_package: Ignored. Presumably present to match the callback
          signature that codegen.package expects.

    Yields:
      Complete `#include "..."` lines.
    """
    for header in ('pw_rpc/internal/nanopb_method_union.h',
                   'pw_rpc/nanopb_client_call.h'):
        yield f'#include "{header}"'

    # The nanopb header for this .proto file holds its generated messages and
    # enums, and itself #includes everything the .proto file imports.
    yield f'#include "{_proto_filename_to_nanopb_header(proto_file.name)}"'
156
157
def _generate_code_for_package(proto_file, package: ProtoNode,
                               output: OutputFile) -> None:
    """Writes the RPC header contents for one .proto file's package.

    The overall layout is driven by codegen.package. This module supplies the
    include list and the service and client class writers.
    """
    service_writer = _generate_code_for_service
    client_writer = _generate_code_for_client
    codegen.package(proto_file, package, output, includes, service_writer,
                    client_writer)
164
165
class StubGenerator(codegen.StubGenerator):
    """Supplies nanopb-specific C++ handler signatures and stub bodies.

    process_proto_file passes an instance to codegen.package_stubs().
    """

    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns the C++ signature of a unary RPC handler.

        The handler returns pw::Status. It takes an unnamed ServerContext&, a
        const reference to the nanopb request struct, and a reference to the
        nanopb response struct. `prefix` is inserted directly before the
        method name.
        """
        request = method.request_type().nanopb_name()
        response = method.response_type().nanopb_name()
        params = ', '.join(('ServerContext&', f'const {request}& request',
                            f'{response}& response'))
        return f'pw::Status {prefix}{method.name()}({params})'

    def unary_stub(self, method: ProtoServiceMethod,
                   output: OutputFile) -> None:
        """Writes a unary handler body that returns Unimplemented.

        Each TODO marker is followed by a void cast of the matching parameter,
        presumably to silence unused-parameter warnings in the stub.
        """
        stub_lines = (
            codegen.STUB_REQUEST_TODO,
            'static_cast<void>(request);',
            codegen.STUB_RESPONSE_TODO,
            'static_cast<void>(response);',
            'return pw::Status::Unimplemented();',
        )
        for line in stub_lines:
            output.write_line(line)

    def server_streaming_signature(self, method: ProtoServiceMethod,
                                   prefix: str) -> str:
        """Returns the C++ signature of a server streaming RPC handler.

        The handler returns void. It takes an unnamed ServerContext&, a const
        reference to the nanopb request struct, and a ServerWriter for the
        nanopb response type. `prefix` is inserted directly before the method
        name.
        """
        request = method.request_type().nanopb_name()
        response = method.response_type().nanopb_name()
        params = ', '.join(('ServerContext&', f'const {request}& request',
                            f'ServerWriter<{response}>& writer'))
        return f'void {prefix}{method.name()}({params})'
186
187
def process_proto_file(proto_file) -> Iterable[OutputFile]:
    """Generates the nanopb RPC header for a single .proto file.

    The header holds the package's service and client code, then a blank
    line, then method stubs written by codegen.package_stubs().

    Args:
      proto_file: The .proto file to process. Its `name` attribute determines
          the output filename.

    Returns:
      A one-element list holding the generated OutputFile.
    """
    _unused_root, package_root = build_node_tree(proto_file)
    header = OutputFile(_proto_filename_to_generated_header(proto_file.name))

    _generate_code_for_package(proto_file, package_root, header)

    # Blank line separating the generated code from the stubs.
    header.write_line()
    codegen.package_stubs(package_root, header, StubGenerator())

    return [header]
200