1 // Copyright 2024 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_grpc/pw_rpc_handler.h"
16
17 #include <cinttypes>
18
19 namespace pw::grpc {
20
21 using pw::rpc::internal::pwpb::PacketType;
22
OnClose(StreamId id)23 void PwRpcHandler::OnClose(StreamId id) { ResetStream(id); }
24
OnNewConnection()25 void PwRpcHandler::OnNewConnection() { ResetAllStreams(); }
26
OnNew(StreamId id,InlineString<kMaxMethodNameSize> full_method_name)27 Status PwRpcHandler::OnNew(StreamId id,
28 InlineString<kMaxMethodNameSize> full_method_name) {
29 // Parse out service and method from `/grpc.examples.echo.Echo/UnaryEcho`
30 // formatted name.
31 std::string_view view = std::string_view(full_method_name);
32 auto split_pos = view.find_last_of('/');
33 if (view.empty() || view[0] != '/' || split_pos == std::string_view::npos) {
34 PW_LOG_WARN("Can't determine service/method name id=%" PRIu32 " name=%s",
35 id,
36 full_method_name.c_str());
37 return Status::NotFound();
38 }
39
40 auto service = view.substr(1, split_pos - 1);
41 auto method = view.substr(split_pos + 1);
42 return CreateStream(
43 id, rpc::internal::Hash(service), rpc::internal::Hash(method));
44 }
45
OnMessage(StreamId id,ByteSpan message)46 Status PwRpcHandler::OnMessage(StreamId id, ByteSpan message) {
47 auto stream = LookupStream(id);
48 if (!stream.ok()) {
49 PW_LOG_INFO("Handler.OnMessage id=%" PRIu32 " size=%zu: unknown stream",
50 id,
51 message.size());
52 return Status::NotFound();
53 }
54
55 const auto [service, method] =
56 server_.FindMethod(stream->service_id, stream->method_id);
57 if (service == nullptr || method == nullptr) {
58 PW_LOG_WARN("Could not find method type");
59 return Status::NotFound();
60 }
61
62 switch (method->type()) {
63 case pw::rpc::MethodType::kUnary:
64 case pw::rpc::MethodType::kServerStreaming: {
65 auto packet = pw::rpc::internal::Packet(PacketType::kRequest,
66 channel_id_,
67 stream->service_id,
68 stream->method_id,
69 id,
70 message,
71 pw::OkStatus());
72 PW_TRY(server_.ProcessPacket(packet));
73 break;
74 }
75 case pw::rpc::MethodType::kClientStreaming:
76 case pw::rpc::MethodType::kBidirectionalStreaming: {
77 if (!stream->sent_request) {
78 auto packet = pw::rpc::internal::Packet(PacketType::kRequest,
79 channel_id_,
80 stream->service_id,
81 stream->method_id,
82 id,
83 {},
84 pw::OkStatus());
85 PW_TRY(server_.ProcessPacket(packet));
86 MarkSentRequest(id);
87 }
88
89 auto packet = pw::rpc::internal::Packet(PacketType::kClientStream,
90 channel_id_,
91 stream->service_id,
92 stream->method_id,
93 id,
94 message,
95 pw::OkStatus());
96 PW_TRY(server_.ProcessPacket(packet));
97 break;
98 }
99 default:
100 PW_LOG_WARN("Unexpected method type");
101 return Status::Internal();
102 }
103
104 return OkStatus();
105 }
106
OnHalfClose(StreamId id)107 void PwRpcHandler::OnHalfClose(StreamId id) {
108 auto stream = LookupStream(id);
109 if (!stream.ok()) {
110 PW_LOG_INFO("OnHalfClose unknown stream");
111 return;
112 }
113
114 const auto [service, method] =
115 server_.FindMethod(stream->service_id, stream->method_id);
116 if (service == nullptr || method == nullptr) {
117 PW_LOG_WARN("Could not find method type");
118 return;
119 }
120
121 if (method->type() == pw::rpc::MethodType::kClientStreaming ||
122 method->type() == pw::rpc::MethodType::kBidirectionalStreaming) {
123 auto packet =
124 pw::rpc::internal::Packet(PacketType::kClientRequestCompletion,
125 channel_id_,
126 stream->service_id,
127 stream->method_id,
128 id,
129 {},
130 pw::OkStatus());
131 ResetStream(id);
132
133 server_.ProcessPacket(packet).IgnoreError();
134 }
135 }
136
OnCancel(StreamId id)137 void PwRpcHandler::OnCancel(StreamId id) {
138 auto stream = LookupStream(id);
139 if (!stream.ok()) {
140 PW_LOG_INFO("OnCancel unknown stream");
141 return;
142 }
143
144 auto packet = pw::rpc::internal::Packet(PacketType::kClientError,
145 channel_id_,
146 stream->service_id,
147 stream->method_id,
148 id,
149 {},
150 pw::Status::Cancelled());
151 ResetStream(id);
152
153 server_.ProcessPacket(packet).IgnoreError();
154 }
155
LookupStream(StreamId id)156 Result<PwRpcHandler::Stream> PwRpcHandler::LookupStream(StreamId id) {
157 auto streams_locked = streams_.acquire();
158 for (size_t i = 0; i < streams_locked->size(); ++i) {
159 auto& stream = (*streams_locked)[i];
160 if (stream.id == id) {
161 return stream;
162 }
163 }
164 return Status::NotFound();
165 }
166
ResetAllStreams()167 void PwRpcHandler::ResetAllStreams() {
168 auto streams_locked = streams_.acquire();
169 for (size_t i = 0; i < streams_locked->size(); ++i) {
170 auto& stream = (*streams_locked)[i];
171 stream.id = 0;
172 }
173 }
174
ResetStream(StreamId id)175 void PwRpcHandler::ResetStream(StreamId id) {
176 auto streams_locked = streams_.acquire();
177 for (size_t i = 0; i < streams_locked->size(); ++i) {
178 auto& stream = (*streams_locked)[i];
179 if (stream.id == id) {
180 stream.id = 0;
181 break;
182 }
183 }
184 }
185
MarkSentRequest(StreamId id)186 void PwRpcHandler::MarkSentRequest(StreamId id) {
187 auto streams_locked = streams_.acquire();
188 for (size_t i = 0; i < streams_locked->size(); ++i) {
189 auto& stream = (*streams_locked)[i];
190 if (stream.id == id) {
191 stream.sent_request = true;
192 break;
193 }
194 }
195 }
196
CreateStream(StreamId id,uint32_t service_id,uint32_t method_id)197 Status PwRpcHandler::CreateStream(StreamId id,
198 uint32_t service_id,
199 uint32_t method_id) {
200 auto streams_locked = streams_.acquire();
201
202 for (size_t i = 0; i < streams_locked->size(); ++i) {
203 auto& stream = (*streams_locked)[i];
204 if (!stream.id) {
205 stream.id = id;
206 stream.service_id = service_id;
207 stream.method_id = method_id;
208 stream.sent_request = false;
209 return OkStatus();
210 }
211 }
212 return Status::ResourceExhausted();
213 }
214
215 } // namespace pw::grpc
216