1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 // clang-format off
16 #include "pw_rpc/internal/log_config.h" // PW_LOG_* macros must be first.
17
18 #include "pw_rpc/server.h"
19 // clang-format on
20
21 #include <algorithm>
22
23 #include "pw_log/log.h"
24 #include "pw_rpc/internal/endpoint.h"
25 #include "pw_rpc/internal/packet.h"
26 #include "pw_rpc/service_id.h"
27
28 namespace pw::rpc {
29 namespace {
30
31 using internal::Packet;
32 using internal::pwpb::PacketType;
33
34 } // namespace
35
ProcessPacket(ConstByteSpan packet_data)36 Status Server::ProcessPacket(ConstByteSpan packet_data) {
37 PW_TRY_ASSIGN(Packet packet,
38 Endpoint::ProcessPacket(packet_data, Packet::kServer));
39 return ProcessPacket(packet);
40 }
41
// Dispatches an already-decoded packet to the matching service method or call.
//
// Locking: rpc_lock() is acquired on entry. Every path either releases it here
// with rpc_lock().unlock(), or passes control to a callee without unlocking
// (method->Invoke, call->HandleError, HandleClientStreamPacket,
// HandleCompletionRequest). NOTE(review): those callees are therefore presumed
// to release the lock themselves; confirm against their declarations.
//
// Returns Unavailable() if the packet's channel is not registered. Every other
// outcome returns OkStatus(), including cases where an error response is sent
// or the packet is ignored, because the packet counts as handled.
Status Server::ProcessPacket(internal::Packet packet) {
  internal::rpc_lock().lock();

  // Verbose log for debugging.
  // PW_LOG_DEBUG("RPC server received packet type %u for %u:%08x/%08x",
  //              static_cast<unsigned>(packet.type()),
  //              static_cast<unsigned>(packet.channel_id()),
  //              static_cast<unsigned>(packet.service_id()),
  //              static_cast<unsigned>(packet.method_id()));

  internal::Channel* channel = GetInternalChannel(packet.channel_id());
  if (channel == nullptr) {
    // Unknown channel: there is nowhere to send a reply. Release the lock
    // before logging.
    internal::rpc_lock().unlock();
    PW_LOG_WARN("RPC server received packet for unknown channel %u",
                static_cast<unsigned>(packet.channel_id()));
    return Status::Unavailable();
  }

  // Lock is held, so the *Locked lookup variant is used.
  const auto [service, method] = FindMethodLocked(packet);

  if (method == nullptr) {
    // Unknown service or method: reply NOT_FOUND on the same channel.
    // Don't send responses to errors to avoid infinite error cycles.
    if (packet.type() != PacketType::CLIENT_ERROR) {
      channel->Send(Packet::ServerError(packet, Status::NotFound()))
          .IgnoreError();
    }
    internal::rpc_lock().unlock();
    PW_LOG_DEBUG("Received packet on channel %u for unknown RPC %08x/%08x",
                 static_cast<unsigned>(packet.channel_id()),
                 static_cast<unsigned>(packet.service_id()),
                 static_cast<unsigned>(packet.method_id()));
    return OkStatus();  // OK since the packet was handled.
  }

  // Handle request packets separately to avoid an unnecessary call lookup. The
  // Call constructor looks up and cancels any duplicate calls.
  if (packet.type() == PacketType::REQUEST) {
    const internal::CallContext context(
        *this, packet.channel_id(), *service, *method, packet.call_id());
    // The lock is still held here; Invoke is presumably responsible for
    // releasing it (NOTE(review): confirm).
    method->Invoke(context, packet);
    return OkStatus();
  }

  // Every remaining packet type targets an existing call. `call` equals
  // calls_end() when no matching call is pending.
  IntrusiveList<internal::Call>::iterator call = FindCall(packet);

  switch (packet.type()) {
    case PacketType::CLIENT_STREAM:
      HandleClientStreamPacket(packet, *channel, call);
      break;
    case PacketType::CLIENT_ERROR:
      // Error packets never get a response. With a matching call, HandleError
      // takes over the lock; otherwise the lock is released here.
      if (call != calls_end()) {
        call->HandleError(packet.status());
      } else {
        internal::rpc_lock().unlock();
      }
      break;
    case PacketType::CLIENT_REQUEST_COMPLETION:
      HandleCompletionRequest(packet, *channel, call);
      break;
    // Server-to-client packet types are not valid input to the server.
    case PacketType::REQUEST:  // Handled above
    case PacketType::RESPONSE:
    case PacketType::SERVER_ERROR:
    case PacketType::SERVER_STREAM:
    default:
      internal::rpc_lock().unlock();
      PW_LOG_WARN("pw_rpc server unable to handle packet of type %u",
                  unsigned(packet.type()));
  }

  return OkStatus();  // OK since the packet was handled
}
113
FindMethod(uint32_t service_id,uint32_t method_id)114 std::tuple<Service*, const internal::Method*> Server::FindMethod(
115 uint32_t service_id, uint32_t method_id) {
116 internal::RpcLockGuard lock;
117 return FindMethodLocked(service_id, method_id);
118 }
119
FindMethodLocked(uint32_t service_id,uint32_t method_id)120 std::tuple<Service*, const internal::Method*> Server::FindMethodLocked(
121 uint32_t service_id, uint32_t method_id) {
122 auto service = std::find_if(services_.begin(), services_.end(), [&](auto& s) {
123 return internal::UnwrapServiceId(s.service_id()) == service_id;
124 });
125
126 if (service == services_.end()) {
127 return {};
128 }
129
130 return {&(*service), service->FindMethod(method_id)};
131 }
132
HandleCompletionRequest(const internal::Packet & packet,internal::Channel & channel,IntrusiveList<internal::Call>::iterator call) const133 void Server::HandleCompletionRequest(
134 const internal::Packet& packet,
135 internal::Channel& channel,
136 IntrusiveList<internal::Call>::iterator call) const {
137 if (call == calls_end()) {
138 channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
139 .IgnoreError(); // Errors are logged in Channel::Send.
140 internal::rpc_lock().unlock();
141 PW_LOG_DEBUG(
142 "Received a request completion packet for %u:%08x/%08x, which is not a"
143 "pending call",
144 static_cast<unsigned>(packet.channel_id()),
145 static_cast<unsigned>(packet.service_id()),
146 static_cast<unsigned>(packet.method_id()));
147 return;
148 }
149
150 if (call->client_requested_completion()) {
151 internal::rpc_lock().unlock();
152 PW_LOG_DEBUG("Received multiple completion requests for %u:%08x/%08x",
153 static_cast<unsigned>(packet.channel_id()),
154 static_cast<unsigned>(packet.service_id()),
155 static_cast<unsigned>(packet.method_id()));
156 return;
157 }
158
159 static_cast<internal::ServerCall&>(*call).HandleClientRequestedCompletion();
160 }
161
HandleClientStreamPacket(const internal::Packet & packet,internal::Channel & channel,IntrusiveList<internal::Call>::iterator call) const162 void Server::HandleClientStreamPacket(
163 const internal::Packet& packet,
164 internal::Channel& channel,
165 IntrusiveList<internal::Call>::iterator call) const {
166 if (call == calls_end()) {
167 channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
168 .IgnoreError(); // Errors are logged in Channel::Send.
169 internal::rpc_lock().unlock();
170 PW_LOG_DEBUG(
171 "Received client stream packet for %u:%08x/%08x, which is not pending",
172 static_cast<unsigned>(packet.channel_id()),
173 static_cast<unsigned>(packet.service_id()),
174 static_cast<unsigned>(packet.method_id()));
175 return;
176 }
177
178 if (!call->has_client_stream()) {
179 channel.Send(Packet::ServerError(packet, Status::InvalidArgument()))
180 .IgnoreError(); // Errors are logged in Channel::Send.
181 internal::rpc_lock().unlock();
182 PW_LOG_DEBUG(
183 "Received client stream packet for %u:%08x/%08x, which doesn't have a "
184 "client stream",
185 static_cast<unsigned>(packet.channel_id()),
186 static_cast<unsigned>(packet.service_id()),
187 static_cast<unsigned>(packet.method_id()));
188 return;
189 }
190
191 if (call->client_requested_completion()) {
192 channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
193 .IgnoreError(); // Errors are logged in Channel::Send.
194 internal::rpc_lock().unlock();
195 PW_LOG_DEBUG(
196 "Received client stream packet for %u:%08x/%08x, but its client stream "
197 "is closed",
198 static_cast<unsigned>(packet.channel_id()),
199 static_cast<unsigned>(packet.service_id()),
200 static_cast<unsigned>(packet.method_id()));
201 return;
202 }
203
204 call->HandlePayload(packet.payload());
205 }
206
207 } // namespace pw::rpc
208