1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_rpc/internal/call.h"
16
17 #include "pw_assert/check.h"
18 #include "pw_bytes/span.h"
19 #include "pw_log/log.h"
20 #include "pw_preprocessor/util.h"
21 #include "pw_rpc/channel.h"
22 #include "pw_rpc/internal/encoding_buffer.h"
23 #include "pw_rpc/internal/endpoint.h"
24 #include "pw_rpc/internal/method.h"
25 #include "pw_rpc/internal/packet.pwpb.h"
26 #include "pw_status/status_with_size.h"
27 #include "pw_status/try.h"
28
// If the callback timeout is enabled, count the number of iterations of the
// waiting loop and crash if it exceeds PW_RPC_CALLBACK_TIMEOUT_TICKS.
//
// The expansion reads and updates a counter named `iterations`, which the
// caller must declare in the enclosing scope. `timeout_source` is a string
// literal spliced into the crash message, and `call` is the Call whose
// callbacks are being waited on. Both variants expand to a single
// do { ... } while (0) statement so the macro is safe to use anywhere a
// statement is expected, including an unbraced if/else body.
#if PW_RPC_CALLBACK_TIMEOUT_TICKS > 0
#define PW_RPC_CHECK_FOR_DEADLOCK(timeout_source, call)                        \
  do {                                                                         \
    iterations += 1;                                                           \
    PW_CHECK(                                                                  \
        iterations < PW_RPC_CALLBACK_TIMEOUT_TICKS,                            \
        "A callback for RPC %u:%08x/%08x has not finished after "              \
        PW_STRINGIFY(PW_RPC_CALLBACK_TIMEOUT_TICKS)                            \
        " ticks. This may indicate that an RPC callback attempted to "         \
        timeout_source                                                         \
        " its own call object, which is not permitted. Fix this condition or " \
        "change the value of PW_RPC_CALLBACK_TIMEOUT_TICKS to avoid this "     \
        "crash. See https://pigweed.dev/pw_rpc"                                \
        "#destructors-moves-wait-for-callbacks-to-complete for details.",      \
        static_cast<unsigned>((call).channel_id_),                             \
        static_cast<unsigned>((call).service_id_),                             \
        static_cast<unsigned>((call).method_id_));                             \
  } while (0)
#else
// Timeout disabled: only reference `iterations` so it is not reported unused.
#define PW_RPC_CHECK_FOR_DEADLOCK(timeout_source, call) \
  do {                                                  \
    static_cast<void>(iterations);                      \
  } while (0)
#endif  // PW_RPC_CALLBACK_TIMEOUT_TICKS > 0
51
52 namespace pw::rpc::internal {
53
54 using pwpb::PacketType;
55
EncodeCallbackToPayloadBuffer(const Function<StatusWithSize (ByteSpan)> & callback)56 Result<ConstByteSpan> EncodeCallbackToPayloadBuffer(
57 const Function<StatusWithSize(ByteSpan)>& callback)
58 PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
59 if (callback == nullptr) {
60 return Status::InvalidArgument();
61 }
62
63 ByteSpan payload_buffer =
64 encoding_buffer.AllocatePayloadBuffer(MaxSafePayloadSize());
65 PW_TRY_ASSIGN(const size_t payload_size, callback(payload_buffer));
66
67 return payload_buffer.first(payload_size);
68 }
69
70 // Creates an active server-side Call.
Call(const LockedCallContext & context,CallProperties properties)71 Call::Call(const LockedCallContext& context, CallProperties properties)
72 : Call(context.server().ClaimLocked(),
73 context.call_id(),
74 context.channel_id(),
75 UnwrapServiceId(context.service().service_id()),
76 context.method().id(),
77 properties) {}
78
79 // Creates an active client-side call, assigning it a new ID.
Call(LockedEndpoint & client,uint32_t channel_id,uint32_t service_id,uint32_t method_id,CallProperties properties)80 Call::Call(LockedEndpoint& client,
81 uint32_t channel_id,
82 uint32_t service_id,
83 uint32_t method_id,
84 CallProperties properties)
85 : Call(client,
86 client.NewCallId(),
87 channel_id,
88 service_id,
89 method_id,
90 properties) {}
91
Call(LockedEndpoint & endpoint_ref,uint32_t call_id,uint32_t channel_id,uint32_t service_id,uint32_t method_id,CallProperties properties)92 Call::Call(LockedEndpoint& endpoint_ref,
93 uint32_t call_id,
94 uint32_t channel_id,
95 uint32_t service_id,
96 uint32_t method_id,
97 CallProperties properties)
98 : endpoint_(&endpoint_ref),
99 channel_id_(channel_id),
100 id_(call_id),
101 service_id_(service_id),
102 method_id_(method_id),
103 // Note: Bit kActive set to 1 and kClientRequestedCompletion is set to 0.
104 state_(kActive),
105 awaiting_cleanup_(OkStatus().code()),
106 callbacks_executing_(0),
107 properties_(properties) {
108 PW_CHECK_UINT_NE(channel_id,
109 Channel::kUnassignedChannelId,
110 "Calls cannot be created with channel ID 0 "
111 "(Channel::kUnassignedChannelId)");
112 endpoint().RegisterCall(*this);
113 }
114
DestroyServerCall()115 void Call::DestroyServerCall() {
116 RpcLockGuard lock;
117 // Any errors are logged in Channel::Send.
118 CloseAndSendResponseLocked(OkStatus()).IgnoreError();
119 WaitForCallbacksToComplete();
120 state_ |= kHasBeenDestroyed;
121 }
122
DestroyClientCall()123 void Call::DestroyClientCall() {
124 RpcLockGuard lock;
125 CloseClientCall();
126 WaitForCallbacksToComplete();
127 state_ |= kHasBeenDestroyed;
128 }
129
WaitForCallbacksToComplete()130 void Call::WaitForCallbacksToComplete() {
131 do {
132 int iterations = 0;
133 while (CallbacksAreRunning()) {
134 PW_RPC_CHECK_FOR_DEADLOCK("destroy", *this);
135 YieldRpcLock();
136 }
137
138 } while (CleanUpIfRequired());
139 }
140
MoveFrom(Call & other)141 void Call::MoveFrom(Call& other) {
142 PW_DCHECK(!active_locked());
143 PW_DCHECK(!awaiting_cleanup() && !other.awaiting_cleanup());
144
145 // An active call with an executing callback cannot be moved. Derived call
146 // classes must wait for callbacks to finish before calling MoveFrom.
147 PW_DCHECK(!other.active_locked() || !other.CallbacksAreRunning());
148
149 // Copy all members from the other call.
150 endpoint_ = other.endpoint_;
151 channel_id_ = other.channel_id_;
152 id_ = other.id_;
153 service_id_ = other.service_id_;
154 method_id_ = other.method_id_;
155
156 state_ = other.state_;
157
158 // No need to move awaiting_cleanup_, since it is 0 in both calls here.
159
160 properties_ = other.properties_;
161
162 // callbacks_executing_ is not moved since it is associated with the object in
163 // memory, not the call.
164
165 on_error_ = std::move(other.on_error_);
166 on_next_ = std::move(other.on_next_);
167
168 if (other.active_locked()) {
169 // Mark the other call inactive, unregister it, and register this one.
170 other.MarkClosed();
171 endpoint().UnregisterCall(other);
172 endpoint().RegisterUniqueCall(*this);
173 }
174 }
175
WaitUntilReadyForMove(Call & destination,Call & source)176 void Call::WaitUntilReadyForMove(Call& destination, Call& source) {
177 do {
178 // Wait for the source's callbacks to finish if it is active.
179 int iterations = 0;
180 while (source.active_locked() && source.CallbacksAreRunning()) {
181 PW_RPC_CHECK_FOR_DEADLOCK("move", source);
182 YieldRpcLock();
183 }
184
185 // At this point, no callbacks are running in the source call. If cleanup
186 // is required for the destination call, perform it and retry since
187 // cleanup releases and reacquires the RPC lock.
188 } while (source.CleanUpIfRequired() || destination.CleanUpIfRequired());
189 }
190
CallOnError(Status error)191 void Call::CallOnError(Status error) {
192 auto on_error_local = std::move(on_error_);
193
194 CallbackStarted();
195
196 rpc_lock().unlock();
197 if (on_error_local) {
198 on_error_local(error);
199 }
200
201 // This mutex lock could be avoided by making callbacks_executing_ atomic.
202 RpcLockGuard lock;
203 CallbackFinished();
204 }
205
CleanUpIfRequired()206 bool Call::CleanUpIfRequired() PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) {
207 if (!awaiting_cleanup()) {
208 return false;
209 }
210 endpoint_->CleanUpCall(*this);
211 rpc_lock().lock();
212 return true;
213 }
214
SendPacket(PacketType type,ConstByteSpan payload,Status status)215 Status Call::SendPacket(PacketType type, ConstByteSpan payload, Status status) {
216 if (!active_locked()) {
217 encoding_buffer.ReleaseIfAllocated();
218 return Status::FailedPrecondition();
219 }
220
221 ChannelBase* channel = endpoint_->GetInternalChannel(channel_id_);
222 if (channel == nullptr) {
223 encoding_buffer.ReleaseIfAllocated();
224 return Status::Unavailable();
225 }
226 return channel->Send(MakePacket(type, payload, status));
227 }
228
CloseAndSendResponseCallbackLocked(const Function<StatusWithSize (ByteSpan)> & callback,Status status)229 Status Call::CloseAndSendResponseCallbackLocked(
230 const Function<StatusWithSize(ByteSpan)>& callback, Status status) {
231 PW_TRY_ASSIGN(ConstByteSpan payload, EncodeCallbackToPayloadBuffer(callback));
232 return CloseAndSendFinalPacketLocked(
233 pwpb::PacketType::RESPONSE, payload, status);
234 }
235
TryCloseAndSendResponseCallbackLocked(const Function<StatusWithSize (ByteSpan)> & callback,Status status)236 Status Call::TryCloseAndSendResponseCallbackLocked(
237 const Function<StatusWithSize(ByteSpan)>& callback, Status status) {
238 PW_TRY_ASSIGN(ConstByteSpan payload, EncodeCallbackToPayloadBuffer(callback));
239 return TryCloseAndSendFinalPacketLocked(
240 pwpb::PacketType::RESPONSE, payload, status);
241 }
242
CloseAndSendFinalPacketLocked(PacketType type,ConstByteSpan response,Status status)243 Status Call::CloseAndSendFinalPacketLocked(PacketType type,
244 ConstByteSpan response,
245 Status status) {
246 const Status send_status = SendPacket(type, response, status);
247 UnregisterAndMarkClosed();
248 return send_status;
249 }
250
TryCloseAndSendFinalPacketLocked(PacketType type,ConstByteSpan response,Status status)251 Status Call::TryCloseAndSendFinalPacketLocked(PacketType type,
252 ConstByteSpan response,
253 Status status) {
254 const Status send_status = SendPacket(type, response, status);
255 // Only close the call if the final packet gets sent out successfully.
256 if (send_status.ok()) {
257 UnregisterAndMarkClosed();
258 }
259 return send_status;
260 }
261
WriteLocked(ConstByteSpan payload)262 Status Call::WriteLocked(ConstByteSpan payload) {
263 return SendPacket(properties_.call_type() == kServerCall
264 ? PacketType::SERVER_STREAM
265 : PacketType::CLIENT_STREAM,
266 payload);
267 }
268
WriteCallbackLocked(const Function<StatusWithSize (ByteSpan)> & callback)269 Status Call::WriteCallbackLocked(
270 const Function<StatusWithSize(ByteSpan)>& callback) {
271 PW_TRY_ASSIGN(ConstByteSpan payload, EncodeCallbackToPayloadBuffer(callback));
272 return SendPacket(properties_.call_type() == kServerCall
273 ? PacketType::SERVER_STREAM
274 : PacketType::CLIENT_STREAM,
275 payload);
276 }
277
278 // This definition is in the .cc file because the Endpoint class is not defined
279 // in the Call header, due to circular dependencies between the two.
CloseAndMarkForCleanup(Status error)280 void Call::CloseAndMarkForCleanup(Status error) {
281 endpoint_->CloseCallAndMarkForCleanup(*this, error);
282 }
283
HandlePayload(ConstByteSpan payload)284 void Call::HandlePayload(ConstByteSpan payload) {
285 // pw_rpc only supports handling packets for a particular RPC one at a time.
286 // Check if any callbacks are running and drop the packet if they are.
287 //
288 // The on_next callback cannot support multiple packets at once since it is
289 // moved before it is invoked. on_error and on_completed are only called
290 // after the call is closed.
291 if (CallbacksAreRunning()) {
292 PW_LOG_WARN(
293 "Received stream packet for %u:%08x/%08x before the callback for a "
294 "previous packet completed! This packet will be dropped. This can be "
295 "avoided by handling packets for a particular RPC on only one thread.",
296 static_cast<unsigned>(channel_id_),
297 static_cast<unsigned>(service_id_),
298 static_cast<unsigned>(method_id_));
299 rpc_lock().unlock();
300 return;
301 }
302
303 if (on_next_ == nullptr) {
304 rpc_lock().unlock();
305 return;
306 }
307
308 const uint32_t original_id = id();
309 auto on_next_local = std::move(on_next_);
310 CallbackStarted();
311
312 if (hold_lock_while_invoking_callback_with_payload()) {
313 on_next_local(payload);
314 } else {
315 rpc_lock().unlock();
316 on_next_local(payload);
317 rpc_lock().lock();
318 }
319
320 CallbackFinished();
321
322 // Restore the original callback if the original call is still active and
323 // the callback has not been replaced.
324 // NOLINTNEXTLINE(bugprone-use-after-move)
325 if (active_locked() && id() == original_id && on_next_ == nullptr) {
326 on_next_ = std::move(on_next_local);
327 }
328
329 // The call could have been reinitialized and cleaned up already by another
330 // thread that acquired the rpc_lock() while on_next_local was executing
331 // without lock held.
332 if (endpoint_ != nullptr) {
333 // Clean up calls in case decoding failed.
334 endpoint_->CleanUpCalls();
335 } else {
336 rpc_lock().unlock();
337 }
338 }
339
CloseClientCall()340 void Call::CloseClientCall() {
341 // When a client call is closed, for bidirectional and client streaming RPCs,
342 // the server may be waiting for client stream messages, so we need to notify
343 // the server that the client has requested for completion and no further
344 // requests should be expected from the client. For unary and server streaming
345 // RPCs, since the client is not sending messages, server does not need to be
346 // notified.
347 if (has_client_stream() && !client_requested_completion()) {
348 RequestCompletionLocked().IgnoreError();
349 }
350 UnregisterAndMarkClosed();
351 }
352
UnregisterAndMarkClosed()353 void Call::UnregisterAndMarkClosed() {
354 if (active_locked()) {
355 endpoint().UnregisterCall(*this);
356 MarkClosed();
357 }
358 }
359
DebugLog() const360 void Call::DebugLog() const PW_NO_LOCK_SAFETY_ANALYSIS {
361 PW_LOG_INFO(
362 "Call %p\n"
363 "\tEndpoint: %p\n"
364 "\tCall ID: %8u\n"
365 "\tChannel: %8u\n"
366 "\tService: %08x\n"
367 "\tMethod: %08x\n"
368 "\tState: %8x\n"
369 "\tCleanup: %8s\n"
370 "\tBusy CBs: %8x\n"
371 "\tType: %8d\n"
372 "\tClient: %8d\n"
373 "\tWrapped: %8d\n"
374 "\ton_error: %8d\n"
375 "\ton_next: %8d\n",
376 static_cast<const void*>(this),
377 static_cast<const void*>(endpoint_),
378 static_cast<unsigned>(id_),
379 static_cast<unsigned>(channel_id_),
380 static_cast<unsigned>(service_id_),
381 static_cast<unsigned>(method_id_),
382 static_cast<int>(state_),
383 Status(static_cast<Status::Code>(awaiting_cleanup_)).str(),
384 static_cast<int>(callbacks_executing_),
385 static_cast<int>(properties_.method_type()),
386 static_cast<int>(properties_.call_type()),
387 static_cast<int>(hold_lock_while_invoking_callback_with_payload()),
388 static_cast<int>(on_error_ == nullptr),
389 static_cast<int>(on_next_ == nullptr));
390 }
391
392 } // namespace pw::rpc::internal
393