1 // Copyright 2024 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #define PW_LOG_MODULE_NAME "TRN"
16 #define PW_LOG_LEVEL PW_TRANSFER_CONFIG_LOG_LEVEL
17
18 #include "pw_transfer/transfer_thread.h"
19
20 #include "pw_assert/check.h"
21 #include "pw_log/log.h"
22 #include "pw_transfer/internal/chunk.h"
23 #include "pw_transfer/internal/client_context.h"
24 #include "pw_transfer/internal/config.h"
25 #include "pw_transfer/internal/event.h"
26
27 PW_MODIFY_DIAGNOSTICS_PUSH();
28 PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers");
29
30 namespace pw::transfer::internal {
31
Terminate()32 void TransferThread::Terminate() {
33 next_event_ownership_.acquire();
34 next_event_.type = EventType::kTerminate;
35 event_notification_.release();
36 }
37
SimulateTimeout(EventType type,uint32_t session_id)38 void TransferThread::SimulateTimeout(EventType type, uint32_t session_id) {
39 next_event_ownership_.acquire();
40
41 next_event_.type = type;
42 next_event_.chunk = {};
43 next_event_.chunk.context_identifier = session_id;
44
45 event_notification_.release();
46
47 WaitUntilEventIsProcessed();
48 }
49
Run()50 void TransferThread::Run() {
51 // Next event starts freed.
52 next_event_ownership_.release();
53
54 while (true) {
55 if (event_notification_.try_acquire_until(GetNextTransferTimeout())) {
56 HandleEvent(next_event_);
57
58 // Sample event type before we release ownership of next_event_.
59 bool is_terminating = next_event_.type == EventType::kTerminate;
60
61 // Finished processing the event. Allow the next_event struct to be
62 // overwritten.
63 next_event_ownership_.release();
64
65 if (is_terminating) {
66 return;
67 }
68 }
69
70 // Regardless of whether an event was received or not, check for any
71 // transfers which have timed out and process them if so.
72 for (Context& context : client_transfers_) {
73 if (context.timed_out()) {
74 context.HandleEvent({.type = EventType::kClientTimeout});
75 }
76 }
77 for (Context& context : server_transfers_) {
78 if (context.timed_out()) {
79 context.HandleEvent({.type = EventType::kServerTimeout});
80 }
81 }
82 }
83 }
84
GetNextTransferTimeout() const85 chrono::SystemClock::time_point TransferThread::GetNextTransferTimeout() const {
86 chrono::SystemClock::time_point timeout =
87 chrono::SystemClock::TimePointAfterAtLeast(kMaxTimeout);
88
89 for (Context& context : client_transfers_) {
90 auto ctx_timeout = context.timeout();
91 if (ctx_timeout.has_value() && ctx_timeout.value() < timeout) {
92 timeout = ctx_timeout.value();
93 }
94 }
95 for (Context& context : server_transfers_) {
96 auto ctx_timeout = context.timeout();
97 if (ctx_timeout.has_value() && ctx_timeout.value() < timeout) {
98 timeout = ctx_timeout.value();
99 }
100 }
101
102 return timeout;
103 }
104
StartTransfer(TransferType type,ProtocolVersion version,uint32_t session_id,uint32_t resource_id,uint32_t handle_id,ConstByteSpan raw_chunk,stream::Stream * stream,const TransferParameters & max_parameters,Function<void (Status)> && on_completion,chrono::SystemClock::duration timeout,chrono::SystemClock::duration initial_timeout,uint8_t max_retries,uint32_t max_lifetime_retries,uint32_t initial_offset)105 void TransferThread::StartTransfer(
106 TransferType type,
107 ProtocolVersion version,
108 uint32_t session_id,
109 uint32_t resource_id,
110 uint32_t handle_id,
111 ConstByteSpan raw_chunk,
112 stream::Stream* stream,
113 const TransferParameters& max_parameters,
114 Function<void(Status)>&& on_completion,
115 chrono::SystemClock::duration timeout,
116 chrono::SystemClock::duration initial_timeout,
117 uint8_t max_retries,
118 uint32_t max_lifetime_retries,
119 uint32_t initial_offset) {
120 // Block until the last event has been processed.
121 next_event_ownership_.acquire();
122
123 bool is_client_transfer = stream != nullptr;
124
125 if (is_client_transfer) {
126 if (version == ProtocolVersion::kLegacy) {
127 session_id = resource_id;
128 } else if (session_id == Context::kUnassignedSessionId) {
129 session_id = AssignSessionId();
130 }
131 }
132
133 next_event_.type = is_client_transfer ? EventType::kNewClientTransfer
134 : EventType::kNewServerTransfer;
135
136 if (!raw_chunk.empty()) {
137 std::memcpy(chunk_buffer_.data(), raw_chunk.data(), raw_chunk.size());
138 }
139
140 next_event_.new_transfer = {
141 .type = type,
142 .protocol_version = version,
143 .session_id = session_id,
144 .resource_id = resource_id,
145 .handle_id = handle_id,
146 .max_parameters = &max_parameters,
147 .timeout = timeout,
148 .initial_timeout = initial_timeout,
149 .max_retries = max_retries,
150 .max_lifetime_retries = max_lifetime_retries,
151 .transfer_thread = this,
152 .raw_chunk_data = chunk_buffer_.data(),
153 .raw_chunk_size = raw_chunk.size(),
154 .initial_offset = initial_offset,
155 };
156
157 staged_on_completion_ = std::move(on_completion);
158
159 // The transfer is initialized with either a stream (client-side) or a handler
160 // (server-side). If no stream is provided, try to find a registered handler
161 // with the specified ID.
162 if (is_client_transfer) {
163 next_event_.new_transfer.stream = stream;
164 next_event_.new_transfer.rpc_writer =
165 &(type == TransferType::kTransmit ? client_write_stream_
166 : client_read_stream_)
167 .as_writer();
168 } else {
169 auto handler = std::find_if(handlers_.begin(),
170 handlers_.end(),
171 [&](auto& h) { return h.id() == resource_id; });
172 if (handler != handlers_.end()) {
173 next_event_.new_transfer.handler = &*handler;
174 next_event_.new_transfer.rpc_writer =
175 &(type == TransferType::kTransmit ? server_read_stream_
176 : server_write_stream_)
177 .as_writer();
178 } else {
179 // No handler exists for the transfer: return a NOT_FOUND.
180 next_event_.type = EventType::kSendStatusChunk;
181 next_event_.send_status_chunk = {
182 .session_id = session_id,
183 .protocol_version = version,
184 .status = Status::NotFound().code(),
185 .stream = type == TransferType::kTransmit
186 ? TransferStream::kServerRead
187 : TransferStream::kServerWrite,
188 };
189 }
190 }
191
192 event_notification_.release();
193 }
194
ProcessChunk(EventType type,ConstByteSpan chunk)195 void TransferThread::ProcessChunk(EventType type, ConstByteSpan chunk) {
196 // If this assert is hit, there is a bug in the transfer implementation.
197 // Contexts' max_chunk_size_bytes fields should be set based on the size of
198 // chunk_buffer_.
199 PW_CHECK(chunk.size() <= chunk_buffer_.size(),
200 "Transfer received a larger chunk than it can handle.");
201
202 Result<Chunk::Identifier> identifier = Chunk::ExtractIdentifier(chunk);
203 if (!identifier.ok()) {
204 PW_LOG_ERROR("Received a malformed chunk without a context identifier");
205 return;
206 }
207
208 // Block until the last event has been processed.
209 next_event_ownership_.acquire();
210
211 std::memcpy(chunk_buffer_.data(), chunk.data(), chunk.size());
212
213 next_event_.type = type;
214 next_event_.chunk = {
215 .context_identifier = identifier->value(),
216 .match_resource_id = identifier->is_legacy(),
217 .data = chunk_buffer_.data(),
218 .size = chunk.size(),
219 };
220
221 event_notification_.release();
222 }
223
SendStatus(TransferStream stream,uint32_t session_id,ProtocolVersion version,Status status)224 void TransferThread::SendStatus(TransferStream stream,
225 uint32_t session_id,
226 ProtocolVersion version,
227 Status status) {
228 // Block until the last event has been processed.
229 next_event_ownership_.acquire();
230
231 next_event_.type = EventType::kSendStatusChunk;
232 next_event_.send_status_chunk = {
233 .session_id = session_id,
234 .protocol_version = version,
235 .status = status.code(),
236 .stream = stream,
237 };
238
239 event_notification_.release();
240 }
241
EndTransfer(EventType type,IdentifierType id_type,uint32_t id,Status status,bool send_status_chunk)242 void TransferThread::EndTransfer(EventType type,
243 IdentifierType id_type,
244 uint32_t id,
245 Status status,
246 bool send_status_chunk) {
247 // Block until the last event has been processed.
248 next_event_ownership_.acquire();
249
250 next_event_.type = type;
251 next_event_.end_transfer = {
252 .id_type = id_type,
253 .id = id,
254 .status = status.code(),
255 .send_status_chunk = send_status_chunk,
256 };
257
258 event_notification_.release();
259 }
260
SetStream(TransferStream stream)261 void TransferThread::SetStream(TransferStream stream) {
262 // Block until the last event has been processed.
263 next_event_ownership_.acquire();
264
265 next_event_.type = EventType::kSetStream;
266 next_event_.set_stream = {
267 .stream = stream,
268 };
269
270 event_notification_.release();
271 }
272
UpdateClientTransfer(uint32_t handle_id,size_t transfer_size_bytes)273 void TransferThread::UpdateClientTransfer(uint32_t handle_id,
274 size_t transfer_size_bytes) {
275 // Block until the last event has been processed.
276 next_event_ownership_.acquire();
277
278 next_event_.type = EventType::kUpdateClientTransfer;
279 next_event_.update_transfer.handle_id = handle_id;
280 next_event_.update_transfer.transfer_size_bytes = transfer_size_bytes;
281
282 event_notification_.release();
283 }
284
TransferHandlerEvent(EventType type,Handler & handler)285 void TransferThread::TransferHandlerEvent(EventType type, Handler& handler) {
286 // Block until the last event has been processed.
287 next_event_ownership_.acquire();
288
289 next_event_.type = type;
290 if (type == EventType::kAddTransferHandler) {
291 next_event_.add_transfer_handler = &handler;
292 } else {
293 next_event_.remove_transfer_handler = &handler;
294 }
295
296 event_notification_.release();
297 }
298
HandleEvent(const internal::Event & event)299 void TransferThread::HandleEvent(const internal::Event& event) {
300 switch (event.type) {
301 case EventType::kTerminate:
302 // Terminate server contexts.
303 for (ServerContext& server_context : server_transfers_) {
304 server_context.HandleEvent(Event{
305 .type = EventType::kServerEndTransfer,
306 .end_transfer =
307 EndTransferEvent{
308 .id_type = IdentifierType::Session,
309 .id = server_context.session_id(),
310 .status = Status::Aborted().code(),
311 .send_status_chunk = false,
312 },
313 });
314 }
315
316 // Terminate client contexts.
317 for (ClientContext& client_context : client_transfers_) {
318 client_context.HandleEvent(Event{
319 .type = EventType::kClientEndTransfer,
320 .end_transfer =
321 EndTransferEvent{
322 .id_type = IdentifierType::Session,
323 .id = client_context.session_id(),
324 .status = Status::Aborted().code(),
325 .send_status_chunk = false,
326 },
327 });
328 }
329
330 // Cancel/Finish streams.
331 client_read_stream_.Cancel().IgnoreError();
332 client_write_stream_.Cancel().IgnoreError();
333 server_read_stream_.Finish(Status::Aborted()).IgnoreError();
334 server_write_stream_.Finish(Status::Aborted()).IgnoreError();
335 return;
336
337 case EventType::kSendStatusChunk:
338 SendStatusChunk(event.send_status_chunk);
339 break;
340
341 case EventType::kAddTransferHandler:
342 handlers_.push_front(*event.add_transfer_handler);
343 return;
344
345 case EventType::kRemoveTransferHandler:
346 for (ServerContext& server_context : server_transfers_) {
347 if (server_context.handler() == event.remove_transfer_handler) {
348 server_context.HandleEvent(Event{
349 .type = EventType::kServerEndTransfer,
350 .end_transfer =
351 EndTransferEvent{
352 .id_type = IdentifierType::Session,
353 .id = server_context.session_id(),
354 .status = Status::Aborted().code(),
355 .send_status_chunk = false,
356 },
357 });
358 }
359 }
360 handlers_.remove(*event.remove_transfer_handler);
361 return;
362
363 case EventType::kSetStream:
364 HandleSetStreamEvent(event.set_stream.stream);
365 return;
366
367 case EventType::kGetResourceStatus:
368 GetResourceState(event.resource_status.resource_id);
369 return;
370
371 case EventType::kNewClientTransfer:
372 case EventType::kNewServerTransfer:
373 case EventType::kClientChunk:
374 case EventType::kServerChunk:
375 case EventType::kClientTimeout:
376 case EventType::kServerTimeout:
377 case EventType::kClientEndTransfer:
378 case EventType::kServerEndTransfer:
379 case EventType::kUpdateClientTransfer:
380 default:
381 // Other events are handled by individual transfer contexts.
382 break;
383 }
384
385 Context* ctx = FindContextForEvent(event);
386 if (ctx == nullptr) {
387 // No context was found. For new transfer events, report a
388 // RESOURCE_EXHAUSTED error with starting the transfer.
389 if (event.type == EventType::kNewClientTransfer) {
390 // On the client, invoke the completion callback directly.
391 staged_on_completion_(Status::ResourceExhausted());
392 } else if (event.type == EventType::kNewServerTransfer) {
393 // On the server, send a status chunk back to the client.
394 SendStatusChunk(
395 {.session_id = event.new_transfer.session_id,
396 .protocol_version = event.new_transfer.protocol_version,
397 .status = Status::ResourceExhausted().code(),
398 .stream = event.new_transfer.type == TransferType::kTransmit
399 ? TransferStream::kServerRead
400 : TransferStream::kServerWrite});
401 }
402 return;
403 }
404
405 if (event.type == EventType::kNewClientTransfer) {
406 // TODO(frolv): This is terrible.
407 ClientContext* cctx = static_cast<ClientContext*>(ctx);
408 cctx->set_on_completion(std::move(staged_on_completion_));
409 cctx->set_handle_id(event.new_transfer.handle_id);
410 }
411
412 if (event.type == EventType::kUpdateClientTransfer) {
413 static_cast<ClientContext&>(*ctx).set_transfer_size_bytes(
414 event.update_transfer.transfer_size_bytes);
415 return;
416 }
417
418 ctx->HandleEvent(event);
419 }
420
FindContextForEvent(const internal::Event & event) const421 Context* TransferThread::FindContextForEvent(
422 const internal::Event& event) const {
423 switch (event.type) {
424 case EventType::kNewClientTransfer:
425 return FindNewTransfer(client_transfers_, event.new_transfer.session_id);
426 case EventType::kNewServerTransfer:
427 return FindNewTransfer(server_transfers_, event.new_transfer.session_id);
428
429 case EventType::kClientChunk:
430 if (event.chunk.match_resource_id) {
431 return FindActiveTransferByResourceId(client_transfers_,
432 event.chunk.context_identifier);
433 }
434 return FindActiveTransferByLegacyId(client_transfers_,
435 event.chunk.context_identifier);
436
437 case EventType::kServerChunk:
438 if (event.chunk.match_resource_id) {
439 return FindActiveTransferByResourceId(server_transfers_,
440 event.chunk.context_identifier);
441 }
442 return FindActiveTransferByLegacyId(server_transfers_,
443 event.chunk.context_identifier);
444
445 case EventType::kClientTimeout: // Manually triggered client timeout
446 return FindActiveTransferByLegacyId(client_transfers_,
447 event.chunk.context_identifier);
448 case EventType::kServerTimeout: // Manually triggered server timeout
449 return FindActiveTransferByLegacyId(server_transfers_,
450 event.chunk.context_identifier);
451
452 case EventType::kClientEndTransfer:
453 if (event.end_transfer.id_type == IdentifierType::Handle) {
454 return FindClientTransferByHandleId(event.end_transfer.id);
455 }
456 return FindActiveTransferByLegacyId(client_transfers_,
457 event.end_transfer.id);
458 case EventType::kServerEndTransfer:
459 PW_DCHECK(event.end_transfer.id_type != IdentifierType::Handle);
460 return FindActiveTransferByLegacyId(server_transfers_,
461 event.end_transfer.id);
462
463 case EventType::kUpdateClientTransfer:
464 return FindClientTransferByHandleId(event.update_transfer.handle_id);
465
466 case EventType::kSendStatusChunk:
467 case EventType::kAddTransferHandler:
468 case EventType::kRemoveTransferHandler:
469 case EventType::kSetStream:
470 case EventType::kTerminate:
471 case EventType::kGetResourceStatus:
472 default:
473 return nullptr;
474 }
475 }
476
SendStatusChunk(const internal::SendStatusChunkEvent & event)477 void TransferThread::SendStatusChunk(
478 const internal::SendStatusChunkEvent& event) {
479 rpc::Writer& destination = stream_for(event.stream);
480
481 Chunk chunk =
482 Chunk::Final(event.protocol_version, event.session_id, event.status);
483
484 Result<ConstByteSpan> result = chunk.Encode(chunk_buffer_);
485 if (!result.ok()) {
486 PW_LOG_ERROR("Failed to encode final chunk for transfer %u",
487 static_cast<unsigned>(event.session_id));
488 return;
489 }
490
491 if (!destination.Write(result.value()).ok()) {
492 PW_LOG_ERROR("Failed to send final chunk for transfer %u",
493 static_cast<unsigned>(event.session_id));
494 return;
495 }
496 }
497
498 // Should only be called with the `next_event_ownership_` lock held.
AssignSessionId()499 uint32_t TransferThread::AssignSessionId() {
500 uint32_t session_id = next_session_id_++;
501 if (session_id == 0) {
502 session_id = next_session_id_++;
503 }
504 return session_id;
505 }
506
507 template <typename T>
TerminateTransfers(span<T> contexts,TransferType type,EventType event_type,Status status)508 void TerminateTransfers(span<T> contexts,
509 TransferType type,
510 EventType event_type,
511 Status status) {
512 for (Context& context : contexts) {
513 if (context.active() && context.type() == type) {
514 context.HandleEvent(Event{
515 .type = event_type,
516 .end_transfer =
517 EndTransferEvent{
518 .id_type = IdentifierType::Session,
519 .id = context.session_id(),
520 .status = status.code(),
521 .send_status_chunk = false,
522 },
523 });
524 }
525 }
526 }
527
HandleSetStreamEvent(TransferStream stream)528 void TransferThread::HandleSetStreamEvent(TransferStream stream) {
529 switch (stream) {
530 case TransferStream::kClientRead:
531 TerminateTransfers(client_transfers_,
532 TransferType::kReceive,
533 EventType::kClientEndTransfer,
534 Status::Aborted());
535 client_read_stream_ = std::move(staged_client_stream_);
536 client_read_stream_.set_on_next(std::move(staged_client_on_next_));
537 break;
538 case TransferStream::kClientWrite:
539 TerminateTransfers(client_transfers_,
540 TransferType::kTransmit,
541 EventType::kClientEndTransfer,
542 Status::Aborted());
543 client_write_stream_ = std::move(staged_client_stream_);
544 client_write_stream_.set_on_next(std::move(staged_client_on_next_));
545 break;
546 case TransferStream::kServerRead:
547 TerminateTransfers(server_transfers_,
548 TransferType::kTransmit,
549 EventType::kServerEndTransfer,
550 Status::Aborted());
551 server_read_stream_ = std::move(staged_server_stream_);
552 server_read_stream_.set_on_next(std::move(staged_server_on_next_));
553 break;
554 case TransferStream::kServerWrite:
555 TerminateTransfers(server_transfers_,
556 TransferType::kReceive,
557 EventType::kServerEndTransfer,
558 Status::Aborted());
559 server_write_stream_ = std::move(staged_server_stream_);
560 server_write_stream_.set_on_next(std::move(staged_server_on_next_));
561 break;
562 }
563 }
564
565 // Adds GetResourceStatusEvent to the queue. Will fail if there is already a
566 // GetResourceStatusEvent in process.
EnqueueResourceEvent(uint32_t resource_id,ResourceStatusCallback && callback)567 void TransferThread::EnqueueResourceEvent(uint32_t resource_id,
568 ResourceStatusCallback&& callback) {
569 // Block until the last event has been processed.
570 next_event_ownership_.acquire();
571
572 next_event_.type = EventType::kGetResourceStatus;
573
574 resource_status_callback_ = std::move(callback);
575
576 next_event_.resource_status.resource_id = resource_id;
577
578 event_notification_.release();
579 }
580
581 // Should only be called when we got a valid callback and RPC responder from
582 // GetResourceStatus transfer RPC.
GetResourceState(uint32_t resource_id)583 void TransferThread::GetResourceState(uint32_t resource_id) {
584 PW_ASSERT(resource_status_callback_ != nullptr);
585
586 auto handler = std::find_if(handlers_.begin(), handlers_.end(), [&](auto& h) {
587 return h.id() == resource_id;
588 });
589 internal::ResourceStatus stats;
590 stats.resource_id = resource_id;
591
592 if (handler != handlers_.end()) {
593 Status status = handler->GetStatus(stats.readable_offset,
594 stats.writeable_offset,
595 stats.read_checksum,
596 stats.write_checksum);
597
598 resource_status_callback_(status, stats);
599 } else {
600 resource_status_callback_(Status::NotFound(), stats);
601 }
602 }
603
604 } // namespace pw::transfer::internal
605
606 PW_MODIFY_DIAGNOSTICS_POP();
607