/*
 *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 #include "net/dcsctp/rx/traditional_reassembly_streams.h"
11 
12 #include <stddef.h>
13 
14 #include <cstdint>
15 #include <functional>
16 #include <iterator>
17 #include <map>
18 #include <numeric>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/types/optional.h"
24 #include "api/array_view.h"
25 #include "net/dcsctp/common/sequence_numbers.h"
26 #include "net/dcsctp/packet/chunk/forward_tsn_common.h"
27 #include "net/dcsctp/packet/data.h"
28 #include "net/dcsctp/public/dcsctp_message.h"
29 #include "rtc_base/logging.h"
30 
31 namespace dcsctp {
32 namespace {
33 
34 // Given a map (`chunks`) and an iterator to within that map (`iter`), this
35 // function will return an iterator to the first chunk in that message, which
36 // has the `is_beginning` flag set. If there are any gaps, or if the beginning
37 // can't be found, `absl::nullopt` is returned.
FindBeginning(const std::map<UnwrappedTSN,Data> & chunks,std::map<UnwrappedTSN,Data>::iterator iter)38 absl::optional<std::map<UnwrappedTSN, Data>::iterator> FindBeginning(
39     const std::map<UnwrappedTSN, Data>& chunks,
40     std::map<UnwrappedTSN, Data>::iterator iter) {
41   UnwrappedTSN prev_tsn = iter->first;
42   for (;;) {
43     if (iter->second.is_beginning) {
44       return iter;
45     }
46     if (iter == chunks.begin()) {
47       return absl::nullopt;
48     }
49     --iter;
50     if (iter->first.next_value() != prev_tsn) {
51       return absl::nullopt;
52     }
53     prev_tsn = iter->first;
54   }
55 }
56 
57 // Given a map (`chunks`) and an iterator to within that map (`iter`), this
58 // function will return an iterator to the chunk after the last chunk in that
59 // message, which has the `is_end` flag set. If there are any gaps, or if the
60 // end can't be found, `absl::nullopt` is returned.
FindEnd(std::map<UnwrappedTSN,Data> & chunks,std::map<UnwrappedTSN,Data>::iterator iter)61 absl::optional<std::map<UnwrappedTSN, Data>::iterator> FindEnd(
62     std::map<UnwrappedTSN, Data>& chunks,
63     std::map<UnwrappedTSN, Data>::iterator iter) {
64   UnwrappedTSN prev_tsn = iter->first;
65   for (;;) {
66     if (iter->second.is_end) {
67       return ++iter;
68     }
69     ++iter;
70     if (iter == chunks.end()) {
71       return absl::nullopt;
72     }
73     if (iter->first != prev_tsn.next_value()) {
74       return absl::nullopt;
75     }
76     prev_tsn = iter->first;
77   }
78 }
79 }  // namespace
80 
TraditionalReassemblyStreams(absl::string_view log_prefix,OnAssembledMessage on_assembled_message)81 TraditionalReassemblyStreams::TraditionalReassemblyStreams(
82     absl::string_view log_prefix,
83     OnAssembledMessage on_assembled_message)
84     : log_prefix_(log_prefix),
85       on_assembled_message_(std::move(on_assembled_message)) {}
86 
Add(UnwrappedTSN tsn,Data data)87 int TraditionalReassemblyStreams::UnorderedStream::Add(UnwrappedTSN tsn,
88                                                        Data data) {
89   int queued_bytes = data.size();
90   auto [it, inserted] = chunks_.emplace(tsn, std::move(data));
91   if (!inserted) {
92     return 0;
93   }
94 
95   queued_bytes -= TryToAssembleMessage(it);
96 
97   return queued_bytes;
98 }
99 
TryToAssembleMessage(ChunkMap::iterator iter)100 size_t TraditionalReassemblyStreams::UnorderedStream::TryToAssembleMessage(
101     ChunkMap::iterator iter) {
102   // TODO(boivie): This method is O(N) with the number of fragments in a
103   // message, which can be inefficient for very large values of N. This could be
104   // optimized by e.g. only trying to assemble a message once _any_ beginning
105   // and _any_ end has been found.
106   absl::optional<ChunkMap::iterator> start = FindBeginning(chunks_, iter);
107   if (!start.has_value()) {
108     return 0;
109   }
110   absl::optional<ChunkMap::iterator> end = FindEnd(chunks_, iter);
111   if (!end.has_value()) {
112     return 0;
113   }
114 
115   size_t bytes_assembled = AssembleMessage(*start, *end);
116   chunks_.erase(*start, *end);
117   return bytes_assembled;
118 }
119 
AssembleMessage(const ChunkMap::iterator start,const ChunkMap::iterator end)120 size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage(
121     const ChunkMap::iterator start,
122     const ChunkMap::iterator end) {
123   size_t count = std::distance(start, end);
124 
125   if (count == 1) {
126     // Fast path - zero-copy
127     const Data& data = start->second;
128     size_t payload_size = start->second.size();
129     UnwrappedTSN tsns[1] = {start->first};
130     DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload));
131     parent_.on_assembled_message_(tsns, std::move(message));
132     return payload_size;
133   }
134 
135   // Slow path - will need to concatenate the payload.
136   std::vector<UnwrappedTSN> tsns;
137   std::vector<uint8_t> payload;
138 
139   size_t payload_size = std::accumulate(
140       start, end, 0,
141       [](size_t v, const auto& p) { return v + p.second.size(); });
142 
143   tsns.reserve(count);
144   payload.reserve(payload_size);
145   for (auto it = start; it != end; ++it) {
146     const Data& data = it->second;
147     tsns.push_back(it->first);
148     payload.insert(payload.end(), data.payload.begin(), data.payload.end());
149   }
150 
151   DcSctpMessage message(start->second.stream_id, start->second.ppid,
152                         std::move(payload));
153   parent_.on_assembled_message_(tsns, std::move(message));
154 
155   return payload_size;
156 }
157 
EraseTo(UnwrappedTSN tsn)158 size_t TraditionalReassemblyStreams::UnorderedStream::EraseTo(
159     UnwrappedTSN tsn) {
160   auto end_iter = chunks_.upper_bound(tsn);
161   size_t removed_bytes = std::accumulate(
162       chunks_.begin(), end_iter, 0,
163       [](size_t r, const auto& p) { return r + p.second.size(); });
164 
165   chunks_.erase(chunks_.begin(), end_iter);
166   return removed_bytes;
167 }
168 
TryToAssembleMessage()169 size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessage() {
170   if (chunks_by_ssn_.empty() || chunks_by_ssn_.begin()->first != next_ssn_) {
171     return 0;
172   }
173 
174   ChunkMap& chunks = chunks_by_ssn_.begin()->second;
175 
176   if (!chunks.begin()->second.is_beginning || !chunks.rbegin()->second.is_end) {
177     return 0;
178   }
179 
180   uint32_t tsn_diff =
181       UnwrappedTSN::Difference(chunks.rbegin()->first, chunks.begin()->first);
182   if (tsn_diff != chunks.size() - 1) {
183     return 0;
184   }
185 
186   size_t assembled_bytes = AssembleMessage(chunks.begin(), chunks.end());
187   chunks_by_ssn_.erase(chunks_by_ssn_.begin());
188   next_ssn_.Increment();
189   return assembled_bytes;
190 }
191 
TryToAssembleMessages()192 size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessages() {
193   size_t assembled_bytes = 0;
194 
195   for (;;) {
196     size_t assembled_bytes_this_iter = TryToAssembleMessage();
197     if (assembled_bytes_this_iter == 0) {
198       break;
199     }
200     assembled_bytes += assembled_bytes_this_iter;
201   }
202   return assembled_bytes;
203 }
204 
Add(UnwrappedTSN tsn,Data data)205 int TraditionalReassemblyStreams::OrderedStream::Add(UnwrappedTSN tsn,
206                                                      Data data) {
207   int queued_bytes = data.size();
208 
209   UnwrappedSSN ssn = ssn_unwrapper_.Unwrap(data.ssn);
210   auto [unused, inserted] = chunks_by_ssn_[ssn].emplace(tsn, std::move(data));
211   if (!inserted) {
212     return 0;
213   }
214 
215   if (ssn == next_ssn_) {
216     queued_bytes -= TryToAssembleMessages();
217   }
218 
219   return queued_bytes;
220 }
221 
EraseTo(SSN ssn)222 size_t TraditionalReassemblyStreams::OrderedStream::EraseTo(SSN ssn) {
223   UnwrappedSSN unwrapped_ssn = ssn_unwrapper_.Unwrap(ssn);
224 
225   auto end_iter = chunks_by_ssn_.upper_bound(unwrapped_ssn);
226   size_t removed_bytes = std::accumulate(
227       chunks_by_ssn_.begin(), end_iter, 0, [](size_t r1, const auto& p) {
228         return r1 +
229                absl::c_accumulate(p.second, 0, [](size_t r2, const auto& q) {
230                  return r2 + q.second.size();
231                });
232       });
233   chunks_by_ssn_.erase(chunks_by_ssn_.begin(), end_iter);
234 
235   if (unwrapped_ssn >= next_ssn_) {
236     unwrapped_ssn.Increment();
237     next_ssn_ = unwrapped_ssn;
238   }
239 
240   removed_bytes += TryToAssembleMessages();
241   return removed_bytes;
242 }
243 
Add(UnwrappedTSN tsn,Data data)244 int TraditionalReassemblyStreams::Add(UnwrappedTSN tsn, Data data) {
245   if (data.is_unordered) {
246     auto it = unordered_streams_.try_emplace(data.stream_id, this).first;
247     return it->second.Add(tsn, std::move(data));
248   }
249 
250   auto it = ordered_streams_.try_emplace(data.stream_id, this).first;
251   return it->second.Add(tsn, std::move(data));
252 }
253 
HandleForwardTsn(UnwrappedTSN new_cumulative_ack_tsn,rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams)254 size_t TraditionalReassemblyStreams::HandleForwardTsn(
255     UnwrappedTSN new_cumulative_ack_tsn,
256     rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) {
257   size_t bytes_removed = 0;
258   // The `skipped_streams` only cover ordered messages - need to
259   // iterate all unordered streams manually to remove those chunks.
260   for (auto& [unused, stream] : unordered_streams_) {
261     bytes_removed += stream.EraseTo(new_cumulative_ack_tsn);
262   }
263 
264   for (const auto& skipped_stream : skipped_streams) {
265     auto it =
266         ordered_streams_.try_emplace(skipped_stream.stream_id, this).first;
267     bytes_removed += it->second.EraseTo(skipped_stream.ssn);
268   }
269 
270   return bytes_removed;
271 }
272 
ResetStreams(rtc::ArrayView<const StreamID> stream_ids)273 void TraditionalReassemblyStreams::ResetStreams(
274     rtc::ArrayView<const StreamID> stream_ids) {
275   if (stream_ids.empty()) {
276     for (auto& [stream_id, stream] : ordered_streams_) {
277       RTC_DLOG(LS_VERBOSE) << log_prefix_
278                            << "Resetting implicit stream_id=" << *stream_id;
279       stream.Reset();
280     }
281   } else {
282     for (StreamID stream_id : stream_ids) {
283       auto it = ordered_streams_.find(stream_id);
284       if (it != ordered_streams_.end()) {
285         RTC_DLOG(LS_VERBOSE)
286             << log_prefix_ << "Resetting explicit stream_id=" << *stream_id;
287         it->second.Reset();
288       }
289     }
290   }
291 }
292 
GetHandoverReadiness() const293 HandoverReadinessStatus TraditionalReassemblyStreams::GetHandoverReadiness()
294     const {
295   HandoverReadinessStatus status;
296   for (const auto& [unused, stream] : ordered_streams_) {
297     if (stream.has_unassembled_chunks()) {
298       status.Add(HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks);
299       break;
300     }
301   }
302   for (const auto& [unused, stream] : unordered_streams_) {
303     if (stream.has_unassembled_chunks()) {
304       status.Add(
305           HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks);
306       break;
307     }
308   }
309   return status;
310 }
311 
AddHandoverState(DcSctpSocketHandoverState & state)312 void TraditionalReassemblyStreams::AddHandoverState(
313     DcSctpSocketHandoverState& state) {
314   for (const auto& [stream_id, stream] : ordered_streams_) {
315     DcSctpSocketHandoverState::OrderedStream state_stream;
316     state_stream.id = stream_id.value();
317     state_stream.next_ssn = stream.next_ssn().value();
318     state.rx.ordered_streams.push_back(std::move(state_stream));
319   }
320   for (const auto& [stream_id, unused] : unordered_streams_) {
321     DcSctpSocketHandoverState::UnorderedStream state_stream;
322     state_stream.id = stream_id.value();
323     state.rx.unordered_streams.push_back(std::move(state_stream));
324   }
325 }
326 
RestoreFromState(const DcSctpSocketHandoverState & state)327 void TraditionalReassemblyStreams::RestoreFromState(
328     const DcSctpSocketHandoverState& state) {
329   // Validate that the component is in pristine state.
330   RTC_DCHECK(ordered_streams_.empty());
331   RTC_DCHECK(unordered_streams_.empty());
332 
333   for (const DcSctpSocketHandoverState::OrderedStream& state_stream :
334        state.rx.ordered_streams) {
335     ordered_streams_.emplace(
336         std::piecewise_construct,
337         std::forward_as_tuple(StreamID(state_stream.id)),
338         std::forward_as_tuple(this, SSN(state_stream.next_ssn)));
339   }
340   for (const DcSctpSocketHandoverState::UnorderedStream& state_stream :
341        state.rx.unordered_streams) {
342     unordered_streams_.emplace(std::piecewise_construct,
343                                std::forward_as_tuple(StreamID(state_stream.id)),
344                                std::forward_as_tuple(this));
345   }
346 }
347 
348 }  // namespace dcsctp
349