1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_rpc/server.h"
16
17 #include <algorithm>
18
19 #include "pw_log/log.h"
20 #include "pw_rpc/internal/packet.h"
21 #include "pw_rpc/internal/server.h"
22 #include "pw_rpc/server_context.h"
23
24 namespace pw::rpc {
25 namespace {
26
27 using std::byte;
28
29 using internal::Packet;
30 using internal::PacketType;
31
DecodePacket(ChannelOutput & interface,std::span<const byte> data,Packet & packet)32 bool DecodePacket(ChannelOutput& interface,
33 std::span<const byte> data,
34 Packet& packet) {
35 Result<Packet> result = Packet::FromBuffer(data);
36 if (!result.ok()) {
37 PW_LOG_WARN("Failed to decode packet on interface %s", interface.name());
38 return false;
39 }
40
41 packet = result.value();
42
43 // If the packet is malformed, don't try to process it.
44 if (packet.channel_id() == Channel::kUnassignedChannelId ||
45 packet.service_id() == 0 || packet.method_id() == 0) {
46 PW_LOG_WARN("Received incomplete packet on interface %s", interface.name());
47
48 // Only send an ERROR response if a valid channel ID was provided.
49 if (packet.channel_id() != Channel::kUnassignedChannelId) {
50 internal::Channel temp_channel(packet.channel_id(), &interface);
51 temp_channel.Send(Packet::ServerError(packet, Status::DataLoss()));
52 }
53 return false;
54 }
55
56 return true;
57 }
58
59 } // namespace
60
~Server()61 Server::~Server() {
62 // Since the writers remove themselves from the server in Finish(), remove the
63 // first writer until no writers remain.
64 while (!writers_.empty()) {
65 writers_.front().Finish();
66 }
67 }
68
ProcessPacket(std::span<const byte> data,ChannelOutput & interface)69 Status Server::ProcessPacket(std::span<const byte> data,
70 ChannelOutput& interface) {
71 Packet packet;
72 if (!DecodePacket(interface, data, packet)) {
73 return Status::DataLoss();
74 }
75
76 if (packet.destination() != Packet::kServer) {
77 return Status::InvalidArgument();
78 }
79
80 internal::Channel* channel = FindChannel(packet.channel_id());
81 if (channel == nullptr) {
82 // If the requested channel doesn't exist, try to dynamically assign one.
83 channel = AssignChannel(packet.channel_id(), interface);
84 if (channel == nullptr) {
85 // If a channel can't be assigned, send a RESOURCE_EXHAUSTED error.
86 internal::Channel temp_channel(packet.channel_id(), &interface);
87 temp_channel.Send(
88 Packet::ServerError(packet, Status::ResourceExhausted()));
89 return OkStatus(); // OK since the packet was handled
90 }
91 }
92
93 const auto [service, method] = FindMethod(packet);
94
95 if (method == nullptr) {
96 channel->Send(Packet::ServerError(packet, Status::NotFound()));
97 return OkStatus();
98 }
99
100 switch (packet.type()) {
101 case PacketType::REQUEST: {
102 internal::ServerCall call(
103 static_cast<internal::Server&>(*this), *channel, *service, *method);
104 method->Invoke(call, packet);
105 break;
106 }
107 case PacketType::CLIENT_STREAM_END:
108 // TODO(hepler): Support client streaming RPCs.
109 break;
110 case PacketType::CLIENT_ERROR:
111 HandleClientError(packet);
112 break;
113 case PacketType::CANCEL_SERVER_STREAM:
114 HandleCancelPacket(packet, *channel);
115 break;
116 default:
117 channel->Send(Packet::ServerError(packet, Status::Unimplemented()));
118 PW_LOG_WARN("Unable to handle packet of type %u",
119 unsigned(packet.type()));
120 }
121 return OkStatus();
122 }
123
FindMethod(const internal::Packet & packet)124 std::tuple<Service*, const internal::Method*> Server::FindMethod(
125 const internal::Packet& packet) {
126 // Packets always include service and method IDs.
127 auto service = std::find_if(services_.begin(), services_.end(), [&](auto& s) {
128 return s.id() == packet.service_id();
129 });
130
131 if (service == services_.end()) {
132 return {};
133 }
134
135 return {&(*service), service->FindMethod(packet.method_id())};
136 }
137
HandleCancelPacket(const Packet & packet,internal::Channel & channel)138 void Server::HandleCancelPacket(const Packet& packet,
139 internal::Channel& channel) {
140 auto writer = std::find_if(writers_.begin(), writers_.end(), [&](auto& w) {
141 return w.channel_id() == packet.channel_id() &&
142 w.service_id() == packet.service_id() &&
143 w.method_id() == packet.method_id();
144 });
145
146 if (writer == writers_.end()) {
147 channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()));
148 PW_LOG_WARN("Received CANCEL packet for method that is not pending");
149 } else {
150 writer->Finish(Status::Cancelled());
151 }
152 }
153
HandleClientError(const Packet & packet)154 void Server::HandleClientError(const Packet& packet) {
155 // A client error indicates that the client received a packet that it did not
156 // expect. If the packet belongs to a streaming RPC, cancel the stream without
157 // sending a final SERVER_STREAM_END packet.
158 auto writer = std::find_if(writers_.begin(), writers_.end(), [&](auto& w) {
159 return w.channel_id() == packet.channel_id() &&
160 w.service_id() == packet.service_id() &&
161 w.method_id() == packet.method_id();
162 });
163
164 if (writer != writers_.end()) {
165 writer->Close();
166 }
167 }
168
FindChannel(uint32_t id) const169 internal::Channel* Server::FindChannel(uint32_t id) const {
170 for (internal::Channel& c : channels_) {
171 if (c.id() == id) {
172 return &c;
173 }
174 }
175 return nullptr;
176 }
177
AssignChannel(uint32_t id,ChannelOutput & interface)178 internal::Channel* Server::AssignChannel(uint32_t id,
179 ChannelOutput& interface) {
180 internal::Channel* channel = FindChannel(Channel::kUnassignedChannelId);
181 if (channel == nullptr) {
182 return nullptr;
183 }
184
185 *channel = internal::Channel(id, &interface);
186 return channel;
187 }
188
189 } // namespace pw::rpc
190