• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2024 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
14 
15 #define PW_LOG_MODULE_NAME "TRN"
16 #define PW_LOG_LEVEL PW_TRANSFER_CONFIG_LOG_LEVEL
17 
18 #include "pw_transfer/transfer_thread.h"
19 
20 #include "pw_assert/check.h"
21 #include "pw_log/log.h"
22 #include "pw_transfer/internal/chunk.h"
23 #include "pw_transfer/internal/client_context.h"
24 #include "pw_transfer/internal/config.h"
25 #include "pw_transfer/internal/event.h"
26 
27 PW_MODIFY_DIAGNOSTICS_PUSH();
28 PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers");
29 
30 namespace pw::transfer::internal {
31 
Terminate()32 void TransferThread::Terminate() {
33   next_event_ownership_.acquire();
34   next_event_.type = EventType::kTerminate;
35   event_notification_.release();
36 }
37 
SimulateTimeout(EventType type,uint32_t session_id)38 void TransferThread::SimulateTimeout(EventType type, uint32_t session_id) {
39   next_event_ownership_.acquire();
40 
41   next_event_.type = type;
42   next_event_.chunk = {};
43   next_event_.chunk.context_identifier = session_id;
44 
45   event_notification_.release();
46 
47   WaitUntilEventIsProcessed();
48 }
49 
Run()50 void TransferThread::Run() {
51   // Next event starts freed.
52   next_event_ownership_.release();
53 
54   while (true) {
55     if (event_notification_.try_acquire_until(GetNextTransferTimeout())) {
56       HandleEvent(next_event_);
57 
58       // Sample event type before we release ownership of next_event_.
59       bool is_terminating = next_event_.type == EventType::kTerminate;
60 
61       // Finished processing the event. Allow the next_event struct to be
62       // overwritten.
63       next_event_ownership_.release();
64 
65       if (is_terminating) {
66         return;
67       }
68     }
69 
70     // Regardless of whether an event was received or not, check for any
71     // transfers which have timed out and process them if so.
72     for (Context& context : client_transfers_) {
73       if (context.timed_out()) {
74         context.HandleEvent({.type = EventType::kClientTimeout});
75       }
76     }
77     for (Context& context : server_transfers_) {
78       if (context.timed_out()) {
79         context.HandleEvent({.type = EventType::kServerTimeout});
80       }
81     }
82   }
83 }
84 
GetNextTransferTimeout() const85 chrono::SystemClock::time_point TransferThread::GetNextTransferTimeout() const {
86   chrono::SystemClock::time_point timeout =
87       chrono::SystemClock::TimePointAfterAtLeast(kMaxTimeout);
88 
89   for (Context& context : client_transfers_) {
90     auto ctx_timeout = context.timeout();
91     if (ctx_timeout.has_value() && ctx_timeout.value() < timeout) {
92       timeout = ctx_timeout.value();
93     }
94   }
95   for (Context& context : server_transfers_) {
96     auto ctx_timeout = context.timeout();
97     if (ctx_timeout.has_value() && ctx_timeout.value() < timeout) {
98       timeout = ctx_timeout.value();
99     }
100   }
101 
102   return timeout;
103 }
104 
StartTransfer(TransferType type,ProtocolVersion version,uint32_t session_id,uint32_t resource_id,uint32_t handle_id,ConstByteSpan raw_chunk,stream::Stream * stream,const TransferParameters & max_parameters,Function<void (Status)> && on_completion,chrono::SystemClock::duration timeout,chrono::SystemClock::duration initial_timeout,uint8_t max_retries,uint32_t max_lifetime_retries,uint32_t initial_offset)105 void TransferThread::StartTransfer(
106     TransferType type,
107     ProtocolVersion version,
108     uint32_t session_id,
109     uint32_t resource_id,
110     uint32_t handle_id,
111     ConstByteSpan raw_chunk,
112     stream::Stream* stream,
113     const TransferParameters& max_parameters,
114     Function<void(Status)>&& on_completion,
115     chrono::SystemClock::duration timeout,
116     chrono::SystemClock::duration initial_timeout,
117     uint8_t max_retries,
118     uint32_t max_lifetime_retries,
119     uint32_t initial_offset) {
120   // Block until the last event has been processed.
121   next_event_ownership_.acquire();
122 
123   bool is_client_transfer = stream != nullptr;
124 
125   if (is_client_transfer) {
126     if (version == ProtocolVersion::kLegacy) {
127       session_id = resource_id;
128     } else if (session_id == Context::kUnassignedSessionId) {
129       session_id = AssignSessionId();
130     }
131   }
132 
133   next_event_.type = is_client_transfer ? EventType::kNewClientTransfer
134                                         : EventType::kNewServerTransfer;
135 
136   if (!raw_chunk.empty()) {
137     std::memcpy(chunk_buffer_.data(), raw_chunk.data(), raw_chunk.size());
138   }
139 
140   next_event_.new_transfer = {
141       .type = type,
142       .protocol_version = version,
143       .session_id = session_id,
144       .resource_id = resource_id,
145       .handle_id = handle_id,
146       .max_parameters = &max_parameters,
147       .timeout = timeout,
148       .initial_timeout = initial_timeout,
149       .max_retries = max_retries,
150       .max_lifetime_retries = max_lifetime_retries,
151       .transfer_thread = this,
152       .raw_chunk_data = chunk_buffer_.data(),
153       .raw_chunk_size = raw_chunk.size(),
154       .initial_offset = initial_offset,
155   };
156 
157   staged_on_completion_ = std::move(on_completion);
158 
159   // The transfer is initialized with either a stream (client-side) or a handler
160   // (server-side). If no stream is provided, try to find a registered handler
161   // with the specified ID.
162   if (is_client_transfer) {
163     next_event_.new_transfer.stream = stream;
164     next_event_.new_transfer.rpc_writer =
165         &(type == TransferType::kTransmit ? client_write_stream_
166                                           : client_read_stream_)
167              .as_writer();
168   } else {
169     auto handler = std::find_if(handlers_.begin(),
170                                 handlers_.end(),
171                                 [&](auto& h) { return h.id() == resource_id; });
172     if (handler != handlers_.end()) {
173       next_event_.new_transfer.handler = &*handler;
174       next_event_.new_transfer.rpc_writer =
175           &(type == TransferType::kTransmit ? server_read_stream_
176                                             : server_write_stream_)
177                .as_writer();
178     } else {
179       // No handler exists for the transfer: return a NOT_FOUND.
180       next_event_.type = EventType::kSendStatusChunk;
181       next_event_.send_status_chunk = {
182           .session_id = session_id,
183           .protocol_version = version,
184           .status = Status::NotFound().code(),
185           .stream = type == TransferType::kTransmit
186                         ? TransferStream::kServerRead
187                         : TransferStream::kServerWrite,
188       };
189     }
190   }
191 
192   event_notification_.release();
193 }
194 
ProcessChunk(EventType type,ConstByteSpan chunk)195 void TransferThread::ProcessChunk(EventType type, ConstByteSpan chunk) {
196   // If this assert is hit, there is a bug in the transfer implementation.
197   // Contexts' max_chunk_size_bytes fields should be set based on the size of
198   // chunk_buffer_.
199   PW_CHECK(chunk.size() <= chunk_buffer_.size(),
200            "Transfer received a larger chunk than it can handle.");
201 
202   Result<Chunk::Identifier> identifier = Chunk::ExtractIdentifier(chunk);
203   if (!identifier.ok()) {
204     PW_LOG_ERROR("Received a malformed chunk without a context identifier");
205     return;
206   }
207 
208   // Block until the last event has been processed.
209   next_event_ownership_.acquire();
210 
211   std::memcpy(chunk_buffer_.data(), chunk.data(), chunk.size());
212 
213   next_event_.type = type;
214   next_event_.chunk = {
215       .context_identifier = identifier->value(),
216       .match_resource_id = identifier->is_legacy(),
217       .data = chunk_buffer_.data(),
218       .size = chunk.size(),
219   };
220 
221   event_notification_.release();
222 }
223 
SendStatus(TransferStream stream,uint32_t session_id,ProtocolVersion version,Status status)224 void TransferThread::SendStatus(TransferStream stream,
225                                 uint32_t session_id,
226                                 ProtocolVersion version,
227                                 Status status) {
228   // Block until the last event has been processed.
229   next_event_ownership_.acquire();
230 
231   next_event_.type = EventType::kSendStatusChunk;
232   next_event_.send_status_chunk = {
233       .session_id = session_id,
234       .protocol_version = version,
235       .status = status.code(),
236       .stream = stream,
237   };
238 
239   event_notification_.release();
240 }
241 
EndTransfer(EventType type,IdentifierType id_type,uint32_t id,Status status,bool send_status_chunk)242 void TransferThread::EndTransfer(EventType type,
243                                  IdentifierType id_type,
244                                  uint32_t id,
245                                  Status status,
246                                  bool send_status_chunk) {
247   // Block until the last event has been processed.
248   next_event_ownership_.acquire();
249 
250   next_event_.type = type;
251   next_event_.end_transfer = {
252       .id_type = id_type,
253       .id = id,
254       .status = status.code(),
255       .send_status_chunk = send_status_chunk,
256   };
257 
258   event_notification_.release();
259 }
260 
SetStream(TransferStream stream)261 void TransferThread::SetStream(TransferStream stream) {
262   // Block until the last event has been processed.
263   next_event_ownership_.acquire();
264 
265   next_event_.type = EventType::kSetStream;
266   next_event_.set_stream = {
267       .stream = stream,
268   };
269 
270   event_notification_.release();
271 }
272 
UpdateClientTransfer(uint32_t handle_id,size_t transfer_size_bytes)273 void TransferThread::UpdateClientTransfer(uint32_t handle_id,
274                                           size_t transfer_size_bytes) {
275   // Block until the last event has been processed.
276   next_event_ownership_.acquire();
277 
278   next_event_.type = EventType::kUpdateClientTransfer;
279   next_event_.update_transfer.handle_id = handle_id;
280   next_event_.update_transfer.transfer_size_bytes = transfer_size_bytes;
281 
282   event_notification_.release();
283 }
284 
TransferHandlerEvent(EventType type,Handler & handler)285 void TransferThread::TransferHandlerEvent(EventType type, Handler& handler) {
286   // Block until the last event has been processed.
287   next_event_ownership_.acquire();
288 
289   next_event_.type = type;
290   if (type == EventType::kAddTransferHandler) {
291     next_event_.add_transfer_handler = &handler;
292   } else {
293     next_event_.remove_transfer_handler = &handler;
294   }
295 
296   event_notification_.release();
297 }
298 
HandleEvent(const internal::Event & event)299 void TransferThread::HandleEvent(const internal::Event& event) {
300   switch (event.type) {
301     case EventType::kTerminate:
302       // Terminate server contexts.
303       for (ServerContext& server_context : server_transfers_) {
304         server_context.HandleEvent(Event{
305             .type = EventType::kServerEndTransfer,
306             .end_transfer =
307                 EndTransferEvent{
308                     .id_type = IdentifierType::Session,
309                     .id = server_context.session_id(),
310                     .status = Status::Aborted().code(),
311                     .send_status_chunk = false,
312                 },
313         });
314       }
315 
316       // Terminate client contexts.
317       for (ClientContext& client_context : client_transfers_) {
318         client_context.HandleEvent(Event{
319             .type = EventType::kClientEndTransfer,
320             .end_transfer =
321                 EndTransferEvent{
322                     .id_type = IdentifierType::Session,
323                     .id = client_context.session_id(),
324                     .status = Status::Aborted().code(),
325                     .send_status_chunk = false,
326                 },
327         });
328       }
329 
330       // Cancel/Finish streams.
331       client_read_stream_.Cancel().IgnoreError();
332       client_write_stream_.Cancel().IgnoreError();
333       server_read_stream_.Finish(Status::Aborted()).IgnoreError();
334       server_write_stream_.Finish(Status::Aborted()).IgnoreError();
335       return;
336 
337     case EventType::kSendStatusChunk:
338       SendStatusChunk(event.send_status_chunk);
339       break;
340 
341     case EventType::kAddTransferHandler:
342       handlers_.push_front(*event.add_transfer_handler);
343       return;
344 
345     case EventType::kRemoveTransferHandler:
346       for (ServerContext& server_context : server_transfers_) {
347         if (server_context.handler() == event.remove_transfer_handler) {
348           server_context.HandleEvent(Event{
349               .type = EventType::kServerEndTransfer,
350               .end_transfer =
351                   EndTransferEvent{
352                       .id_type = IdentifierType::Session,
353                       .id = server_context.session_id(),
354                       .status = Status::Aborted().code(),
355                       .send_status_chunk = false,
356                   },
357           });
358         }
359       }
360       handlers_.remove(*event.remove_transfer_handler);
361       return;
362 
363     case EventType::kSetStream:
364       HandleSetStreamEvent(event.set_stream.stream);
365       return;
366 
367     case EventType::kGetResourceStatus:
368       GetResourceState(event.resource_status.resource_id);
369       return;
370 
371     case EventType::kNewClientTransfer:
372     case EventType::kNewServerTransfer:
373     case EventType::kClientChunk:
374     case EventType::kServerChunk:
375     case EventType::kClientTimeout:
376     case EventType::kServerTimeout:
377     case EventType::kClientEndTransfer:
378     case EventType::kServerEndTransfer:
379     case EventType::kUpdateClientTransfer:
380     default:
381       // Other events are handled by individual transfer contexts.
382       break;
383   }
384 
385   Context* ctx = FindContextForEvent(event);
386   if (ctx == nullptr) {
387     // No context was found. For new transfer events, report a
388     // RESOURCE_EXHAUSTED error with starting the transfer.
389     if (event.type == EventType::kNewClientTransfer) {
390       // On the client, invoke the completion callback directly.
391       staged_on_completion_(Status::ResourceExhausted());
392     } else if (event.type == EventType::kNewServerTransfer) {
393       // On the server, send a status chunk back to the client.
394       SendStatusChunk(
395           {.session_id = event.new_transfer.session_id,
396            .protocol_version = event.new_transfer.protocol_version,
397            .status = Status::ResourceExhausted().code(),
398            .stream = event.new_transfer.type == TransferType::kTransmit
399                          ? TransferStream::kServerRead
400                          : TransferStream::kServerWrite});
401     }
402     return;
403   }
404 
405   if (event.type == EventType::kNewClientTransfer) {
406     // TODO(frolv): This is terrible.
407     ClientContext* cctx = static_cast<ClientContext*>(ctx);
408     cctx->set_on_completion(std::move(staged_on_completion_));
409     cctx->set_handle_id(event.new_transfer.handle_id);
410   }
411 
412   if (event.type == EventType::kUpdateClientTransfer) {
413     static_cast<ClientContext&>(*ctx).set_transfer_size_bytes(
414         event.update_transfer.transfer_size_bytes);
415     return;
416   }
417 
418   ctx->HandleEvent(event);
419 }
420 
FindContextForEvent(const internal::Event & event) const421 Context* TransferThread::FindContextForEvent(
422     const internal::Event& event) const {
423   switch (event.type) {
424     case EventType::kNewClientTransfer:
425       return FindNewTransfer(client_transfers_, event.new_transfer.session_id);
426     case EventType::kNewServerTransfer:
427       return FindNewTransfer(server_transfers_, event.new_transfer.session_id);
428 
429     case EventType::kClientChunk:
430       if (event.chunk.match_resource_id) {
431         return FindActiveTransferByResourceId(client_transfers_,
432                                               event.chunk.context_identifier);
433       }
434       return FindActiveTransferByLegacyId(client_transfers_,
435                                           event.chunk.context_identifier);
436 
437     case EventType::kServerChunk:
438       if (event.chunk.match_resource_id) {
439         return FindActiveTransferByResourceId(server_transfers_,
440                                               event.chunk.context_identifier);
441       }
442       return FindActiveTransferByLegacyId(server_transfers_,
443                                           event.chunk.context_identifier);
444 
445     case EventType::kClientTimeout:  // Manually triggered client timeout
446       return FindActiveTransferByLegacyId(client_transfers_,
447                                           event.chunk.context_identifier);
448     case EventType::kServerTimeout:  // Manually triggered server timeout
449       return FindActiveTransferByLegacyId(server_transfers_,
450                                           event.chunk.context_identifier);
451 
452     case EventType::kClientEndTransfer:
453       if (event.end_transfer.id_type == IdentifierType::Handle) {
454         return FindClientTransferByHandleId(event.end_transfer.id);
455       }
456       return FindActiveTransferByLegacyId(client_transfers_,
457                                           event.end_transfer.id);
458     case EventType::kServerEndTransfer:
459       PW_DCHECK(event.end_transfer.id_type != IdentifierType::Handle);
460       return FindActiveTransferByLegacyId(server_transfers_,
461                                           event.end_transfer.id);
462 
463     case EventType::kUpdateClientTransfer:
464       return FindClientTransferByHandleId(event.update_transfer.handle_id);
465 
466     case EventType::kSendStatusChunk:
467     case EventType::kAddTransferHandler:
468     case EventType::kRemoveTransferHandler:
469     case EventType::kSetStream:
470     case EventType::kTerminate:
471     case EventType::kGetResourceStatus:
472     default:
473       return nullptr;
474   }
475 }
476 
SendStatusChunk(const internal::SendStatusChunkEvent & event)477 void TransferThread::SendStatusChunk(
478     const internal::SendStatusChunkEvent& event) {
479   rpc::Writer& destination = stream_for(event.stream);
480 
481   Chunk chunk =
482       Chunk::Final(event.protocol_version, event.session_id, event.status);
483 
484   Result<ConstByteSpan> result = chunk.Encode(chunk_buffer_);
485   if (!result.ok()) {
486     PW_LOG_ERROR("Failed to encode final chunk for transfer %u",
487                  static_cast<unsigned>(event.session_id));
488     return;
489   }
490 
491   if (!destination.Write(result.value()).ok()) {
492     PW_LOG_ERROR("Failed to send final chunk for transfer %u",
493                  static_cast<unsigned>(event.session_id));
494     return;
495   }
496 }
497 
498 // Should only be called with the `next_event_ownership_` lock held.
AssignSessionId()499 uint32_t TransferThread::AssignSessionId() {
500   uint32_t session_id = next_session_id_++;
501   if (session_id == 0) {
502     session_id = next_session_id_++;
503   }
504   return session_id;
505 }
506 
507 template <typename T>
TerminateTransfers(span<T> contexts,TransferType type,EventType event_type,Status status)508 void TerminateTransfers(span<T> contexts,
509                         TransferType type,
510                         EventType event_type,
511                         Status status) {
512   for (Context& context : contexts) {
513     if (context.active() && context.type() == type) {
514       context.HandleEvent(Event{
515           .type = event_type,
516           .end_transfer =
517               EndTransferEvent{
518                   .id_type = IdentifierType::Session,
519                   .id = context.session_id(),
520                   .status = status.code(),
521                   .send_status_chunk = false,
522               },
523       });
524     }
525   }
526 }
527 
HandleSetStreamEvent(TransferStream stream)528 void TransferThread::HandleSetStreamEvent(TransferStream stream) {
529   switch (stream) {
530     case TransferStream::kClientRead:
531       TerminateTransfers(client_transfers_,
532                          TransferType::kReceive,
533                          EventType::kClientEndTransfer,
534                          Status::Aborted());
535       client_read_stream_ = std::move(staged_client_stream_);
536       client_read_stream_.set_on_next(std::move(staged_client_on_next_));
537       break;
538     case TransferStream::kClientWrite:
539       TerminateTransfers(client_transfers_,
540                          TransferType::kTransmit,
541                          EventType::kClientEndTransfer,
542                          Status::Aborted());
543       client_write_stream_ = std::move(staged_client_stream_);
544       client_write_stream_.set_on_next(std::move(staged_client_on_next_));
545       break;
546     case TransferStream::kServerRead:
547       TerminateTransfers(server_transfers_,
548                          TransferType::kTransmit,
549                          EventType::kServerEndTransfer,
550                          Status::Aborted());
551       server_read_stream_ = std::move(staged_server_stream_);
552       server_read_stream_.set_on_next(std::move(staged_server_on_next_));
553       break;
554     case TransferStream::kServerWrite:
555       TerminateTransfers(server_transfers_,
556                          TransferType::kReceive,
557                          EventType::kServerEndTransfer,
558                          Status::Aborted());
559       server_write_stream_ = std::move(staged_server_stream_);
560       server_write_stream_.set_on_next(std::move(staged_server_on_next_));
561       break;
562   }
563 }
564 
565 // Adds GetResourceStatusEvent to the queue. Will fail if there is already a
566 // GetResourceStatusEvent in process.
EnqueueResourceEvent(uint32_t resource_id,ResourceStatusCallback && callback)567 void TransferThread::EnqueueResourceEvent(uint32_t resource_id,
568                                           ResourceStatusCallback&& callback) {
569   // Block until the last event has been processed.
570   next_event_ownership_.acquire();
571 
572   next_event_.type = EventType::kGetResourceStatus;
573 
574   resource_status_callback_ = std::move(callback);
575 
576   next_event_.resource_status.resource_id = resource_id;
577 
578   event_notification_.release();
579 }
580 
581 // Should only be called when we got a valid callback and RPC responder from
582 // GetResourceStatus transfer RPC.
GetResourceState(uint32_t resource_id)583 void TransferThread::GetResourceState(uint32_t resource_id) {
584   PW_ASSERT(resource_status_callback_ != nullptr);
585 
586   auto handler = std::find_if(handlers_.begin(), handlers_.end(), [&](auto& h) {
587     return h.id() == resource_id;
588   });
589   internal::ResourceStatus stats;
590   stats.resource_id = resource_id;
591 
592   if (handler != handlers_.end()) {
593     Status status = handler->GetStatus(stats.readable_offset,
594                                        stats.writeable_offset,
595                                        stats.read_checksum,
596                                        stats.write_checksum);
597 
598     resource_status_callback_(status, stats);
599   } else {
600     resource_status_callback_(Status::NotFound(), stats);
601   }
602 }
603 
604 }  // namespace pw::transfer::internal
605 
606 PW_MODIFY_DIAGNOSTICS_POP();
607