• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_rpc/server.h"
16 
17 #include <algorithm>
18 
19 #include "pw_log/log.h"
20 #include "pw_rpc/internal/packet.h"
21 #include "pw_rpc/internal/server.h"
22 #include "pw_rpc/server_context.h"
23 
24 namespace pw::rpc {
25 namespace {
26 
27 using std::byte;
28 
29 using internal::Packet;
30 using internal::PacketType;
31 
DecodePacket(ChannelOutput & interface,std::span<const byte> data,Packet & packet)32 bool DecodePacket(ChannelOutput& interface,
33                   std::span<const byte> data,
34                   Packet& packet) {
35   Result<Packet> result = Packet::FromBuffer(data);
36   if (!result.ok()) {
37     PW_LOG_WARN("Failed to decode packet on interface %s", interface.name());
38     return false;
39   }
40 
41   packet = result.value();
42 
43   // If the packet is malformed, don't try to process it.
44   if (packet.channel_id() == Channel::kUnassignedChannelId ||
45       packet.service_id() == 0 || packet.method_id() == 0) {
46     PW_LOG_WARN("Received incomplete packet on interface %s", interface.name());
47 
48     // Only send an ERROR response if a valid channel ID was provided.
49     if (packet.channel_id() != Channel::kUnassignedChannelId) {
50       internal::Channel temp_channel(packet.channel_id(), &interface);
51       temp_channel.Send(Packet::ServerError(packet, Status::DataLoss()));
52     }
53     return false;
54   }
55 
56   return true;
57 }
58 
59 }  // namespace
60 
~Server()61 Server::~Server() {
62   // Since the writers remove themselves from the server in Finish(), remove the
63   // first writer until no writers remain.
64   while (!writers_.empty()) {
65     writers_.front().Finish();
66   }
67 }
68 
ProcessPacket(std::span<const byte> data,ChannelOutput & interface)69 Status Server::ProcessPacket(std::span<const byte> data,
70                              ChannelOutput& interface) {
71   Packet packet;
72   if (!DecodePacket(interface, data, packet)) {
73     return Status::DataLoss();
74   }
75 
76   if (packet.destination() != Packet::kServer) {
77     return Status::InvalidArgument();
78   }
79 
80   internal::Channel* channel = FindChannel(packet.channel_id());
81   if (channel == nullptr) {
82     // If the requested channel doesn't exist, try to dynamically assign one.
83     channel = AssignChannel(packet.channel_id(), interface);
84     if (channel == nullptr) {
85       // If a channel can't be assigned, send a RESOURCE_EXHAUSTED error.
86       internal::Channel temp_channel(packet.channel_id(), &interface);
87       temp_channel.Send(
88           Packet::ServerError(packet, Status::ResourceExhausted()));
89       return OkStatus();  // OK since the packet was handled
90     }
91   }
92 
93   const auto [service, method] = FindMethod(packet);
94 
95   if (method == nullptr) {
96     channel->Send(Packet::ServerError(packet, Status::NotFound()));
97     return OkStatus();
98   }
99 
100   switch (packet.type()) {
101     case PacketType::REQUEST: {
102       internal::ServerCall call(
103           static_cast<internal::Server&>(*this), *channel, *service, *method);
104       method->Invoke(call, packet);
105       break;
106     }
107     case PacketType::CLIENT_STREAM_END:
108       // TODO(hepler): Support client streaming RPCs.
109       break;
110     case PacketType::CLIENT_ERROR:
111       HandleClientError(packet);
112       break;
113     case PacketType::CANCEL_SERVER_STREAM:
114       HandleCancelPacket(packet, *channel);
115       break;
116     default:
117       channel->Send(Packet::ServerError(packet, Status::Unimplemented()));
118       PW_LOG_WARN("Unable to handle packet of type %u",
119                   unsigned(packet.type()));
120   }
121   return OkStatus();
122 }
123 
FindMethod(const internal::Packet & packet)124 std::tuple<Service*, const internal::Method*> Server::FindMethod(
125     const internal::Packet& packet) {
126   // Packets always include service and method IDs.
127   auto service = std::find_if(services_.begin(), services_.end(), [&](auto& s) {
128     return s.id() == packet.service_id();
129   });
130 
131   if (service == services_.end()) {
132     return {};
133   }
134 
135   return {&(*service), service->FindMethod(packet.method_id())};
136 }
137 
HandleCancelPacket(const Packet & packet,internal::Channel & channel)138 void Server::HandleCancelPacket(const Packet& packet,
139                                 internal::Channel& channel) {
140   auto writer = std::find_if(writers_.begin(), writers_.end(), [&](auto& w) {
141     return w.channel_id() == packet.channel_id() &&
142            w.service_id() == packet.service_id() &&
143            w.method_id() == packet.method_id();
144   });
145 
146   if (writer == writers_.end()) {
147     channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()));
148     PW_LOG_WARN("Received CANCEL packet for method that is not pending");
149   } else {
150     writer->Finish(Status::Cancelled());
151   }
152 }
153 
HandleClientError(const Packet & packet)154 void Server::HandleClientError(const Packet& packet) {
155   // A client error indicates that the client received a packet that it did not
156   // expect. If the packet belongs to a streaming RPC, cancel the stream without
157   // sending a final SERVER_STREAM_END packet.
158   auto writer = std::find_if(writers_.begin(), writers_.end(), [&](auto& w) {
159     return w.channel_id() == packet.channel_id() &&
160            w.service_id() == packet.service_id() &&
161            w.method_id() == packet.method_id();
162   });
163 
164   if (writer != writers_.end()) {
165     writer->Close();
166   }
167 }
168 
FindChannel(uint32_t id) const169 internal::Channel* Server::FindChannel(uint32_t id) const {
170   for (internal::Channel& c : channels_) {
171     if (c.id() == id) {
172       return &c;
173     }
174   }
175   return nullptr;
176 }
177 
AssignChannel(uint32_t id,ChannelOutput & interface)178 internal::Channel* Server::AssignChannel(uint32_t id,
179                                          ChannelOutput& interface) {
180   internal::Channel* channel = FindChannel(Channel::kUnassignedChannelId);
181   if (channel == nullptr) {
182     return nullptr;
183   }
184 
185   *channel = internal::Channel(id, &interface);
186   return channel;
187 }
188 
189 }  // namespace pw::rpc
190