• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 // clang-format off
16 #include "pw_rpc/internal/log_config.h" // PW_LOG_* macros must be first.
17 
18 #include "pw_rpc/server.h"
19 // clang-format on
20 
21 #include <algorithm>
22 
23 #include "pw_log/log.h"
24 #include "pw_rpc/internal/endpoint.h"
25 #include "pw_rpc/internal/packet.h"
26 #include "pw_rpc/service_id.h"
27 
28 namespace pw::rpc {
29 namespace {
30 
31 using internal::Packet;
32 using internal::pwpb::PacketType;
33 
34 }  // namespace
35 
ProcessPacket(ConstByteSpan packet_data)36 Status Server::ProcessPacket(ConstByteSpan packet_data) {
37   PW_TRY_ASSIGN(Packet packet,
38                 Endpoint::ProcessPacket(packet_data, Packet::kServer));
39 
40   internal::rpc_lock().lock();
41 
42   // Verbose log for debugging.
43   // PW_LOG_DEBUG("RPC server received packet type %u for %u:%08x/%08x",
44   //              static_cast<unsigned>(packet.type()),
45   //              static_cast<unsigned>(packet.channel_id()),
46   //              static_cast<unsigned>(packet.service_id()),
47   //              static_cast<unsigned>(packet.method_id()));
48 
49   internal::Channel* channel = GetInternalChannel(packet.channel_id());
50   if (channel == nullptr) {
51     internal::rpc_lock().unlock();
52     PW_LOG_WARN("RPC server received packet for unknown channel %u",
53                 static_cast<unsigned>(packet.channel_id()));
54     return Status::Unavailable();
55   }
56 
57   const auto [service, method] = FindMethod(packet);
58 
59   if (method == nullptr) {
60     // Don't send responses to errors to avoid infinite error cycles.
61     if (packet.type() != PacketType::CLIENT_ERROR) {
62       channel->Send(Packet::ServerError(packet, Status::NotFound()))
63           .IgnoreError();
64     }
65     internal::rpc_lock().unlock();
66     PW_LOG_DEBUG("Received packet on channel %u for unknown RPC %08x/%08x",
67                  static_cast<unsigned>(packet.channel_id()),
68                  static_cast<unsigned>(packet.service_id()),
69                  static_cast<unsigned>(packet.method_id()));
70     return OkStatus();  // OK since the packet was handled.
71   }
72 
73   // Handle request packets separately to avoid an unnecessary call lookup. The
74   // Call constructor looks up and cancels any duplicate calls.
75   if (packet.type() == PacketType::REQUEST) {
76     const internal::CallContext context(
77         *this, packet.channel_id(), *service, *method, packet.call_id());
78     method->Invoke(context, packet);
79     return OkStatus();
80   }
81 
82   IntrusiveList<internal::Call>::iterator call = FindCall(packet);
83 
84   switch (packet.type()) {
85     case PacketType::CLIENT_STREAM:
86       HandleClientStreamPacket(packet, *channel, call);
87       break;
88     case PacketType::CLIENT_ERROR:
89       if (call != calls_end()) {
90         call->HandleError(packet.status());
91       } else {
92         internal::rpc_lock().unlock();
93       }
94       break;
95     case PacketType::CLIENT_STREAM_END:
96       HandleClientStreamPacket(packet, *channel, call);
97       break;
98     case PacketType::REQUEST:  // Handled above
99     case PacketType::RESPONSE:
100     case PacketType::SERVER_ERROR:
101     case PacketType::SERVER_STREAM:
102     default:
103       internal::rpc_lock().unlock();
104       PW_LOG_WARN("pw_rpc server unable to handle packet of type %u",
105                   unsigned(packet.type()));
106   }
107 
108   return OkStatus();  // OK since the packet was handled
109 }
110 
FindMethod(const internal::Packet & packet)111 std::tuple<Service*, const internal::Method*> Server::FindMethod(
112     const internal::Packet& packet) {
113   // Packets always include service and method IDs.
114   auto service = std::find_if(services_.begin(), services_.end(), [&](auto& s) {
115     return internal::UnwrapServiceId(s.service_id()) == packet.service_id();
116   });
117 
118   if (service == services_.end()) {
119     return {};
120   }
121 
122   return {&(*service), service->FindMethod(packet.method_id())};
123 }
124 
HandleClientStreamPacket(const internal::Packet & packet,internal::Channel & channel,IntrusiveList<internal::Call>::iterator call) const125 void Server::HandleClientStreamPacket(
126     const internal::Packet& packet,
127     internal::Channel& channel,
128     IntrusiveList<internal::Call>::iterator call) const {
129   if (call == calls_end()) {
130     channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
131         .IgnoreError();  // Errors are logged in Channel::Send.
132     internal::rpc_lock().unlock();
133     PW_LOG_DEBUG(
134         "Received client stream packet for %u:%08x/%08x, which is not pending",
135         static_cast<unsigned>(packet.channel_id()),
136         static_cast<unsigned>(packet.service_id()),
137         static_cast<unsigned>(packet.method_id()));
138     return;
139   }
140 
141   if (!call->has_client_stream()) {
142     channel.Send(Packet::ServerError(packet, Status::InvalidArgument()))
143         .IgnoreError();  // Errors are logged in Channel::Send.
144     internal::rpc_lock().unlock();
145     PW_LOG_DEBUG(
146         "Received client stream packet for %u:%08x/%08x, which doesn't have a "
147         "client stream",
148         static_cast<unsigned>(packet.channel_id()),
149         static_cast<unsigned>(packet.service_id()),
150         static_cast<unsigned>(packet.method_id()));
151     return;
152   }
153 
154   if (!call->client_stream_open()) {
155     channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
156         .IgnoreError();  // Errors are logged in Channel::Send.
157     internal::rpc_lock().unlock();
158     PW_LOG_DEBUG(
159         "Received client stream packet for %u:%08x/%08x, but its client stream "
160         "is closed",
161         static_cast<unsigned>(packet.channel_id()),
162         static_cast<unsigned>(packet.service_id()),
163         static_cast<unsigned>(packet.method_id()));
164     return;
165   }
166 
167   if (packet.type() == PacketType::CLIENT_STREAM) {
168     call->HandlePayload(packet.payload());
169   } else {  // Handle PacketType::CLIENT_STREAM_END.
170     static_cast<internal::ServerCall&>(*call).HandleClientStreamEnd();
171   }
172 }
173 
174 }  // namespace pw::rpc
175