• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_grpc/pw_rpc_handler.h"
16 
17 #include <cinttypes>
18 
19 namespace pw::grpc {
20 
21 using pw::rpc::internal::pwpb::PacketType;
22 
OnClose(StreamId id)23 void PwRpcHandler::OnClose(StreamId id) { ResetStream(id); }
24 
OnNewConnection()25 void PwRpcHandler::OnNewConnection() { ResetAllStreams(); }
26 
OnNew(StreamId id,InlineString<kMaxMethodNameSize> full_method_name)27 Status PwRpcHandler::OnNew(StreamId id,
28                            InlineString<kMaxMethodNameSize> full_method_name) {
29   // Parse out service and method from `/grpc.examples.echo.Echo/UnaryEcho`
30   // formatted name.
31   std::string_view view = std::string_view(full_method_name);
32   auto split_pos = view.find_last_of('/');
33   if (view.empty() || view[0] != '/' || split_pos == std::string_view::npos) {
34     PW_LOG_WARN("Can't determine service/method name id=%" PRIu32 " name=%s",
35                 id,
36                 full_method_name.c_str());
37     return Status::NotFound();
38   }
39 
40   auto service = view.substr(1, split_pos - 1);
41   auto method = view.substr(split_pos + 1);
42   return CreateStream(
43       id, rpc::internal::Hash(service), rpc::internal::Hash(method));
44 }
45 
OnMessage(StreamId id,ByteSpan message)46 Status PwRpcHandler::OnMessage(StreamId id, ByteSpan message) {
47   auto stream = LookupStream(id);
48   if (!stream.ok()) {
49     PW_LOG_INFO("Handler.OnMessage id=%" PRIu32 " size=%zu: unknown stream",
50                 id,
51                 message.size());
52     return Status::NotFound();
53   }
54 
55   const auto [service, method] =
56       server_.FindMethod(stream->service_id, stream->method_id);
57   if (service == nullptr || method == nullptr) {
58     PW_LOG_WARN("Could not find method type");
59     return Status::NotFound();
60   }
61 
62   switch (method->type()) {
63     case pw::rpc::MethodType::kUnary:
64     case pw::rpc::MethodType::kServerStreaming: {
65       auto packet = pw::rpc::internal::Packet(PacketType::kRequest,
66                                               channel_id_,
67                                               stream->service_id,
68                                               stream->method_id,
69                                               id,
70                                               message,
71                                               pw::OkStatus());
72       PW_TRY(server_.ProcessPacket(packet));
73       break;
74     }
75     case pw::rpc::MethodType::kClientStreaming:
76     case pw::rpc::MethodType::kBidirectionalStreaming: {
77       if (!stream->sent_request) {
78         auto packet = pw::rpc::internal::Packet(PacketType::kRequest,
79                                                 channel_id_,
80                                                 stream->service_id,
81                                                 stream->method_id,
82                                                 id,
83                                                 {},
84                                                 pw::OkStatus());
85         PW_TRY(server_.ProcessPacket(packet));
86         MarkSentRequest(id);
87       }
88 
89       auto packet = pw::rpc::internal::Packet(PacketType::kClientStream,
90                                               channel_id_,
91                                               stream->service_id,
92                                               stream->method_id,
93                                               id,
94                                               message,
95                                               pw::OkStatus());
96       PW_TRY(server_.ProcessPacket(packet));
97       break;
98     }
99     default:
100       PW_LOG_WARN("Unexpected method type");
101       return Status::Internal();
102   }
103 
104   return OkStatus();
105 }
106 
OnHalfClose(StreamId id)107 void PwRpcHandler::OnHalfClose(StreamId id) {
108   auto stream = LookupStream(id);
109   if (!stream.ok()) {
110     PW_LOG_INFO("OnHalfClose unknown stream");
111     return;
112   }
113 
114   const auto [service, method] =
115       server_.FindMethod(stream->service_id, stream->method_id);
116   if (service == nullptr || method == nullptr) {
117     PW_LOG_WARN("Could not find method type");
118     return;
119   }
120 
121   if (method->type() == pw::rpc::MethodType::kClientStreaming ||
122       method->type() == pw::rpc::MethodType::kBidirectionalStreaming) {
123     auto packet =
124         pw::rpc::internal::Packet(PacketType::kClientRequestCompletion,
125                                   channel_id_,
126                                   stream->service_id,
127                                   stream->method_id,
128                                   id,
129                                   {},
130                                   pw::OkStatus());
131     ResetStream(id);
132 
133     server_.ProcessPacket(packet).IgnoreError();
134   }
135 }
136 
OnCancel(StreamId id)137 void PwRpcHandler::OnCancel(StreamId id) {
138   auto stream = LookupStream(id);
139   if (!stream.ok()) {
140     PW_LOG_INFO("OnCancel unknown stream");
141     return;
142   }
143 
144   auto packet = pw::rpc::internal::Packet(PacketType::kClientError,
145                                           channel_id_,
146                                           stream->service_id,
147                                           stream->method_id,
148                                           id,
149                                           {},
150                                           pw::Status::Cancelled());
151   ResetStream(id);
152 
153   server_.ProcessPacket(packet).IgnoreError();
154 }
155 
LookupStream(StreamId id)156 Result<PwRpcHandler::Stream> PwRpcHandler::LookupStream(StreamId id) {
157   auto streams_locked = streams_.acquire();
158   for (size_t i = 0; i < streams_locked->size(); ++i) {
159     auto& stream = (*streams_locked)[i];
160     if (stream.id == id) {
161       return stream;
162     }
163   }
164   return Status::NotFound();
165 }
166 
ResetAllStreams()167 void PwRpcHandler::ResetAllStreams() {
168   auto streams_locked = streams_.acquire();
169   for (size_t i = 0; i < streams_locked->size(); ++i) {
170     auto& stream = (*streams_locked)[i];
171     stream.id = 0;
172   }
173 }
174 
ResetStream(StreamId id)175 void PwRpcHandler::ResetStream(StreamId id) {
176   auto streams_locked = streams_.acquire();
177   for (size_t i = 0; i < streams_locked->size(); ++i) {
178     auto& stream = (*streams_locked)[i];
179     if (stream.id == id) {
180       stream.id = 0;
181       break;
182     }
183   }
184 }
185 
MarkSentRequest(StreamId id)186 void PwRpcHandler::MarkSentRequest(StreamId id) {
187   auto streams_locked = streams_.acquire();
188   for (size_t i = 0; i < streams_locked->size(); ++i) {
189     auto& stream = (*streams_locked)[i];
190     if (stream.id == id) {
191       stream.sent_request = true;
192       break;
193     }
194   }
195 }
196 
CreateStream(StreamId id,uint32_t service_id,uint32_t method_id)197 Status PwRpcHandler::CreateStream(StreamId id,
198                                   uint32_t service_id,
199                                   uint32_t method_id) {
200   auto streams_locked = streams_.acquire();
201 
202   for (size_t i = 0; i < streams_locked->size(); ++i) {
203     auto& stream = (*streams_locked)[i];
204     if (!stream.id) {
205       stream.id = id;
206       stream.service_id = service_id;
207       stream.method_id = method_id;
208       stream.sent_request = false;
209       return OkStatus();
210     }
211   }
212   return Status::ResourceExhausted();
213 }
214 
215 }  // namespace pw::grpc
216