• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 // clang-format off
16 #include "pw_rpc/internal/log_config.h" // PW_LOG_* macros must be first.
17 
18 #include "pw_rpc/server.h"
19 // clang-format on
20 
21 #include <algorithm>
22 
23 #include "pw_log/log.h"
24 #include "pw_rpc/internal/endpoint.h"
25 #include "pw_rpc/internal/packet.h"
26 #include "pw_rpc/service_id.h"
27 
28 namespace pw::rpc {
29 namespace {
30 
31 using internal::Packet;
32 using internal::pwpb::PacketType;
33 
34 }  // namespace
35 
ProcessPacket(ConstByteSpan packet_data)36 Status Server::ProcessPacket(ConstByteSpan packet_data) {
37   PW_TRY_ASSIGN(Packet packet,
38                 Endpoint::ProcessPacket(packet_data, Packet::kServer));
39   return ProcessPacket(packet);
40 }
41 
ProcessPacket(internal::Packet packet)42 Status Server::ProcessPacket(internal::Packet packet) {
43   internal::rpc_lock().lock();
44 
45   // Verbose log for debugging.
46   // PW_LOG_DEBUG("RPC server received packet type %u for %u:%08x/%08x",
47   //              static_cast<unsigned>(packet.type()),
48   //              static_cast<unsigned>(packet.channel_id()),
49   //              static_cast<unsigned>(packet.service_id()),
50   //              static_cast<unsigned>(packet.method_id()));
51 
52   internal::Channel* channel = GetInternalChannel(packet.channel_id());
53   if (channel == nullptr) {
54     internal::rpc_lock().unlock();
55     PW_LOG_WARN("RPC server received packet for unknown channel %u",
56                 static_cast<unsigned>(packet.channel_id()));
57     return Status::Unavailable();
58   }
59 
60   const auto [service, method] = FindMethodLocked(packet);
61 
62   if (method == nullptr) {
63     // Don't send responses to errors to avoid infinite error cycles.
64     if (packet.type() != PacketType::CLIENT_ERROR) {
65       channel->Send(Packet::ServerError(packet, Status::NotFound()))
66           .IgnoreError();
67     }
68     internal::rpc_lock().unlock();
69     PW_LOG_DEBUG("Received packet on channel %u for unknown RPC %08x/%08x",
70                  static_cast<unsigned>(packet.channel_id()),
71                  static_cast<unsigned>(packet.service_id()),
72                  static_cast<unsigned>(packet.method_id()));
73     return OkStatus();  // OK since the packet was handled.
74   }
75 
76   // Handle request packets separately to avoid an unnecessary call lookup. The
77   // Call constructor looks up and cancels any duplicate calls.
78   if (packet.type() == PacketType::REQUEST) {
79     const internal::CallContext context(
80         *this, packet.channel_id(), *service, *method, packet.call_id());
81     method->Invoke(context, packet);
82     return OkStatus();
83   }
84 
85   IntrusiveList<internal::Call>::iterator call = FindCall(packet);
86 
87   switch (packet.type()) {
88     case PacketType::CLIENT_STREAM:
89       HandleClientStreamPacket(packet, *channel, call);
90       break;
91     case PacketType::CLIENT_ERROR:
92       if (call != calls_end()) {
93         call->HandleError(packet.status());
94       } else {
95         internal::rpc_lock().unlock();
96       }
97       break;
98     case PacketType::CLIENT_REQUEST_COMPLETION:
99       HandleCompletionRequest(packet, *channel, call);
100       break;
101     case PacketType::REQUEST:  // Handled above
102     case PacketType::RESPONSE:
103     case PacketType::SERVER_ERROR:
104     case PacketType::SERVER_STREAM:
105     default:
106       internal::rpc_lock().unlock();
107       PW_LOG_WARN("pw_rpc server unable to handle packet of type %u",
108                   unsigned(packet.type()));
109   }
110 
111   return OkStatus();  // OK since the packet was handled
112 }
113 
FindMethod(uint32_t service_id,uint32_t method_id)114 std::tuple<Service*, const internal::Method*> Server::FindMethod(
115     uint32_t service_id, uint32_t method_id) {
116   internal::RpcLockGuard lock;
117   return FindMethodLocked(service_id, method_id);
118 }
119 
FindMethodLocked(uint32_t service_id,uint32_t method_id)120 std::tuple<Service*, const internal::Method*> Server::FindMethodLocked(
121     uint32_t service_id, uint32_t method_id) {
122   auto service = std::find_if(services_.begin(), services_.end(), [&](auto& s) {
123     return internal::UnwrapServiceId(s.service_id()) == service_id;
124   });
125 
126   if (service == services_.end()) {
127     return {};
128   }
129 
130   return {&(*service), service->FindMethod(method_id)};
131 }
132 
HandleCompletionRequest(const internal::Packet & packet,internal::Channel & channel,IntrusiveList<internal::Call>::iterator call) const133 void Server::HandleCompletionRequest(
134     const internal::Packet& packet,
135     internal::Channel& channel,
136     IntrusiveList<internal::Call>::iterator call) const {
137   if (call == calls_end()) {
138     channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
139         .IgnoreError();  // Errors are logged in Channel::Send.
140     internal::rpc_lock().unlock();
141     PW_LOG_DEBUG(
142         "Received a request completion packet for %u:%08x/%08x, which is not a"
143         "pending call",
144         static_cast<unsigned>(packet.channel_id()),
145         static_cast<unsigned>(packet.service_id()),
146         static_cast<unsigned>(packet.method_id()));
147     return;
148   }
149 
150   if (call->client_requested_completion()) {
151     internal::rpc_lock().unlock();
152     PW_LOG_DEBUG("Received multiple completion requests for %u:%08x/%08x",
153                  static_cast<unsigned>(packet.channel_id()),
154                  static_cast<unsigned>(packet.service_id()),
155                  static_cast<unsigned>(packet.method_id()));
156     return;
157   }
158 
159   static_cast<internal::ServerCall&>(*call).HandleClientRequestedCompletion();
160 }
161 
HandleClientStreamPacket(const internal::Packet & packet,internal::Channel & channel,IntrusiveList<internal::Call>::iterator call) const162 void Server::HandleClientStreamPacket(
163     const internal::Packet& packet,
164     internal::Channel& channel,
165     IntrusiveList<internal::Call>::iterator call) const {
166   if (call == calls_end()) {
167     channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
168         .IgnoreError();  // Errors are logged in Channel::Send.
169     internal::rpc_lock().unlock();
170     PW_LOG_DEBUG(
171         "Received client stream packet for %u:%08x/%08x, which is not pending",
172         static_cast<unsigned>(packet.channel_id()),
173         static_cast<unsigned>(packet.service_id()),
174         static_cast<unsigned>(packet.method_id()));
175     return;
176   }
177 
178   if (!call->has_client_stream()) {
179     channel.Send(Packet::ServerError(packet, Status::InvalidArgument()))
180         .IgnoreError();  // Errors are logged in Channel::Send.
181     internal::rpc_lock().unlock();
182     PW_LOG_DEBUG(
183         "Received client stream packet for %u:%08x/%08x, which doesn't have a "
184         "client stream",
185         static_cast<unsigned>(packet.channel_id()),
186         static_cast<unsigned>(packet.service_id()),
187         static_cast<unsigned>(packet.method_id()));
188     return;
189   }
190 
191   if (call->client_requested_completion()) {
192     channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
193         .IgnoreError();  // Errors are logged in Channel::Send.
194     internal::rpc_lock().unlock();
195     PW_LOG_DEBUG(
196         "Received client stream packet for %u:%08x/%08x, but its client stream "
197         "is closed",
198         static_cast<unsigned>(packet.channel_id()),
199         static_cast<unsigned>(packet.service_id()),
200         static_cast<unsigned>(packet.method_id()));
201     return;
202   }
203 
204   call->HandlePayload(packet.payload());
205 }
206 
207 }  // namespace pw::rpc
208