1 /*
2 * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10 #include "net/dcsctp/rx/traditional_reassembly_streams.h"
11
12 #include <stddef.h>
13
14 #include <cstdint>
15 #include <functional>
16 #include <iterator>
17 #include <map>
18 #include <numeric>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/types/optional.h"
24 #include "api/array_view.h"
25 #include "net/dcsctp/common/sequence_numbers.h"
26 #include "net/dcsctp/packet/chunk/forward_tsn_common.h"
27 #include "net/dcsctp/packet/data.h"
28 #include "net/dcsctp/public/dcsctp_message.h"
29 #include "rtc_base/logging.h"
30
31 namespace dcsctp {
32 namespace {
33
34 // Given a map (`chunks`) and an iterator to within that map (`iter`), this
35 // function will return an iterator to the first chunk in that message, which
36 // has the `is_beginning` flag set. If there are any gaps, or if the beginning
37 // can't be found, `absl::nullopt` is returned.
FindBeginning(const std::map<UnwrappedTSN,Data> & chunks,std::map<UnwrappedTSN,Data>::iterator iter)38 absl::optional<std::map<UnwrappedTSN, Data>::iterator> FindBeginning(
39 const std::map<UnwrappedTSN, Data>& chunks,
40 std::map<UnwrappedTSN, Data>::iterator iter) {
41 UnwrappedTSN prev_tsn = iter->first;
42 for (;;) {
43 if (iter->second.is_beginning) {
44 return iter;
45 }
46 if (iter == chunks.begin()) {
47 return absl::nullopt;
48 }
49 --iter;
50 if (iter->first.next_value() != prev_tsn) {
51 return absl::nullopt;
52 }
53 prev_tsn = iter->first;
54 }
55 }
56
57 // Given a map (`chunks`) and an iterator to within that map (`iter`), this
58 // function will return an iterator to the chunk after the last chunk in that
59 // message, which has the `is_end` flag set. If there are any gaps, or if the
60 // end can't be found, `absl::nullopt` is returned.
FindEnd(std::map<UnwrappedTSN,Data> & chunks,std::map<UnwrappedTSN,Data>::iterator iter)61 absl::optional<std::map<UnwrappedTSN, Data>::iterator> FindEnd(
62 std::map<UnwrappedTSN, Data>& chunks,
63 std::map<UnwrappedTSN, Data>::iterator iter) {
64 UnwrappedTSN prev_tsn = iter->first;
65 for (;;) {
66 if (iter->second.is_end) {
67 return ++iter;
68 }
69 ++iter;
70 if (iter == chunks.end()) {
71 return absl::nullopt;
72 }
73 if (iter->first != prev_tsn.next_value()) {
74 return absl::nullopt;
75 }
76 prev_tsn = iter->first;
77 }
78 }
79 } // namespace
80
TraditionalReassemblyStreams(absl::string_view log_prefix,OnAssembledMessage on_assembled_message)81 TraditionalReassemblyStreams::TraditionalReassemblyStreams(
82 absl::string_view log_prefix,
83 OnAssembledMessage on_assembled_message)
84 : log_prefix_(log_prefix),
85 on_assembled_message_(std::move(on_assembled_message)) {}
86
Add(UnwrappedTSN tsn,Data data)87 int TraditionalReassemblyStreams::UnorderedStream::Add(UnwrappedTSN tsn,
88 Data data) {
89 int queued_bytes = data.size();
90 auto [it, inserted] = chunks_.emplace(tsn, std::move(data));
91 if (!inserted) {
92 return 0;
93 }
94
95 queued_bytes -= TryToAssembleMessage(it);
96
97 return queued_bytes;
98 }
99
TryToAssembleMessage(ChunkMap::iterator iter)100 size_t TraditionalReassemblyStreams::UnorderedStream::TryToAssembleMessage(
101 ChunkMap::iterator iter) {
102 // TODO(boivie): This method is O(N) with the number of fragments in a
103 // message, which can be inefficient for very large values of N. This could be
104 // optimized by e.g. only trying to assemble a message once _any_ beginning
105 // and _any_ end has been found.
106 absl::optional<ChunkMap::iterator> start = FindBeginning(chunks_, iter);
107 if (!start.has_value()) {
108 return 0;
109 }
110 absl::optional<ChunkMap::iterator> end = FindEnd(chunks_, iter);
111 if (!end.has_value()) {
112 return 0;
113 }
114
115 size_t bytes_assembled = AssembleMessage(*start, *end);
116 chunks_.erase(*start, *end);
117 return bytes_assembled;
118 }
119
AssembleMessage(const ChunkMap::iterator start,const ChunkMap::iterator end)120 size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage(
121 const ChunkMap::iterator start,
122 const ChunkMap::iterator end) {
123 size_t count = std::distance(start, end);
124
125 if (count == 1) {
126 // Fast path - zero-copy
127 const Data& data = start->second;
128 size_t payload_size = start->second.size();
129 UnwrappedTSN tsns[1] = {start->first};
130 DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload));
131 parent_.on_assembled_message_(tsns, std::move(message));
132 return payload_size;
133 }
134
135 // Slow path - will need to concatenate the payload.
136 std::vector<UnwrappedTSN> tsns;
137 std::vector<uint8_t> payload;
138
139 size_t payload_size = std::accumulate(
140 start, end, 0,
141 [](size_t v, const auto& p) { return v + p.second.size(); });
142
143 tsns.reserve(count);
144 payload.reserve(payload_size);
145 for (auto it = start; it != end; ++it) {
146 const Data& data = it->second;
147 tsns.push_back(it->first);
148 payload.insert(payload.end(), data.payload.begin(), data.payload.end());
149 }
150
151 DcSctpMessage message(start->second.stream_id, start->second.ppid,
152 std::move(payload));
153 parent_.on_assembled_message_(tsns, std::move(message));
154
155 return payload_size;
156 }
157
EraseTo(UnwrappedTSN tsn)158 size_t TraditionalReassemblyStreams::UnorderedStream::EraseTo(
159 UnwrappedTSN tsn) {
160 auto end_iter = chunks_.upper_bound(tsn);
161 size_t removed_bytes = std::accumulate(
162 chunks_.begin(), end_iter, 0,
163 [](size_t r, const auto& p) { return r + p.second.size(); });
164
165 chunks_.erase(chunks_.begin(), end_iter);
166 return removed_bytes;
167 }
168
TryToAssembleMessage()169 size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessage() {
170 if (chunks_by_ssn_.empty() || chunks_by_ssn_.begin()->first != next_ssn_) {
171 return 0;
172 }
173
174 ChunkMap& chunks = chunks_by_ssn_.begin()->second;
175
176 if (!chunks.begin()->second.is_beginning || !chunks.rbegin()->second.is_end) {
177 return 0;
178 }
179
180 uint32_t tsn_diff =
181 UnwrappedTSN::Difference(chunks.rbegin()->first, chunks.begin()->first);
182 if (tsn_diff != chunks.size() - 1) {
183 return 0;
184 }
185
186 size_t assembled_bytes = AssembleMessage(chunks.begin(), chunks.end());
187 chunks_by_ssn_.erase(chunks_by_ssn_.begin());
188 next_ssn_.Increment();
189 return assembled_bytes;
190 }
191
TryToAssembleMessages()192 size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessages() {
193 size_t assembled_bytes = 0;
194
195 for (;;) {
196 size_t assembled_bytes_this_iter = TryToAssembleMessage();
197 if (assembled_bytes_this_iter == 0) {
198 break;
199 }
200 assembled_bytes += assembled_bytes_this_iter;
201 }
202 return assembled_bytes;
203 }
204
Add(UnwrappedTSN tsn,Data data)205 int TraditionalReassemblyStreams::OrderedStream::Add(UnwrappedTSN tsn,
206 Data data) {
207 int queued_bytes = data.size();
208
209 UnwrappedSSN ssn = ssn_unwrapper_.Unwrap(data.ssn);
210 auto [unused, inserted] = chunks_by_ssn_[ssn].emplace(tsn, std::move(data));
211 if (!inserted) {
212 return 0;
213 }
214
215 if (ssn == next_ssn_) {
216 queued_bytes -= TryToAssembleMessages();
217 }
218
219 return queued_bytes;
220 }
221
EraseTo(SSN ssn)222 size_t TraditionalReassemblyStreams::OrderedStream::EraseTo(SSN ssn) {
223 UnwrappedSSN unwrapped_ssn = ssn_unwrapper_.Unwrap(ssn);
224
225 auto end_iter = chunks_by_ssn_.upper_bound(unwrapped_ssn);
226 size_t removed_bytes = std::accumulate(
227 chunks_by_ssn_.begin(), end_iter, 0, [](size_t r1, const auto& p) {
228 return r1 +
229 absl::c_accumulate(p.second, 0, [](size_t r2, const auto& q) {
230 return r2 + q.second.size();
231 });
232 });
233 chunks_by_ssn_.erase(chunks_by_ssn_.begin(), end_iter);
234
235 if (unwrapped_ssn >= next_ssn_) {
236 unwrapped_ssn.Increment();
237 next_ssn_ = unwrapped_ssn;
238 }
239
240 removed_bytes += TryToAssembleMessages();
241 return removed_bytes;
242 }
243
Add(UnwrappedTSN tsn,Data data)244 int TraditionalReassemblyStreams::Add(UnwrappedTSN tsn, Data data) {
245 if (data.is_unordered) {
246 auto it = unordered_streams_.try_emplace(data.stream_id, this).first;
247 return it->second.Add(tsn, std::move(data));
248 }
249
250 auto it = ordered_streams_.try_emplace(data.stream_id, this).first;
251 return it->second.Add(tsn, std::move(data));
252 }
253
HandleForwardTsn(UnwrappedTSN new_cumulative_ack_tsn,rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams)254 size_t TraditionalReassemblyStreams::HandleForwardTsn(
255 UnwrappedTSN new_cumulative_ack_tsn,
256 rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) {
257 size_t bytes_removed = 0;
258 // The `skipped_streams` only cover ordered messages - need to
259 // iterate all unordered streams manually to remove those chunks.
260 for (auto& [unused, stream] : unordered_streams_) {
261 bytes_removed += stream.EraseTo(new_cumulative_ack_tsn);
262 }
263
264 for (const auto& skipped_stream : skipped_streams) {
265 auto it =
266 ordered_streams_.try_emplace(skipped_stream.stream_id, this).first;
267 bytes_removed += it->second.EraseTo(skipped_stream.ssn);
268 }
269
270 return bytes_removed;
271 }
272
ResetStreams(rtc::ArrayView<const StreamID> stream_ids)273 void TraditionalReassemblyStreams::ResetStreams(
274 rtc::ArrayView<const StreamID> stream_ids) {
275 if (stream_ids.empty()) {
276 for (auto& [stream_id, stream] : ordered_streams_) {
277 RTC_DLOG(LS_VERBOSE) << log_prefix_
278 << "Resetting implicit stream_id=" << *stream_id;
279 stream.Reset();
280 }
281 } else {
282 for (StreamID stream_id : stream_ids) {
283 auto it = ordered_streams_.find(stream_id);
284 if (it != ordered_streams_.end()) {
285 RTC_DLOG(LS_VERBOSE)
286 << log_prefix_ << "Resetting explicit stream_id=" << *stream_id;
287 it->second.Reset();
288 }
289 }
290 }
291 }
292
GetHandoverReadiness() const293 HandoverReadinessStatus TraditionalReassemblyStreams::GetHandoverReadiness()
294 const {
295 HandoverReadinessStatus status;
296 for (const auto& [unused, stream] : ordered_streams_) {
297 if (stream.has_unassembled_chunks()) {
298 status.Add(HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks);
299 break;
300 }
301 }
302 for (const auto& [unused, stream] : unordered_streams_) {
303 if (stream.has_unassembled_chunks()) {
304 status.Add(
305 HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks);
306 break;
307 }
308 }
309 return status;
310 }
311
AddHandoverState(DcSctpSocketHandoverState & state)312 void TraditionalReassemblyStreams::AddHandoverState(
313 DcSctpSocketHandoverState& state) {
314 for (const auto& [stream_id, stream] : ordered_streams_) {
315 DcSctpSocketHandoverState::OrderedStream state_stream;
316 state_stream.id = stream_id.value();
317 state_stream.next_ssn = stream.next_ssn().value();
318 state.rx.ordered_streams.push_back(std::move(state_stream));
319 }
320 for (const auto& [stream_id, unused] : unordered_streams_) {
321 DcSctpSocketHandoverState::UnorderedStream state_stream;
322 state_stream.id = stream_id.value();
323 state.rx.unordered_streams.push_back(std::move(state_stream));
324 }
325 }
326
RestoreFromState(const DcSctpSocketHandoverState & state)327 void TraditionalReassemblyStreams::RestoreFromState(
328 const DcSctpSocketHandoverState& state) {
329 // Validate that the component is in pristine state.
330 RTC_DCHECK(ordered_streams_.empty());
331 RTC_DCHECK(unordered_streams_.empty());
332
333 for (const DcSctpSocketHandoverState::OrderedStream& state_stream :
334 state.rx.ordered_streams) {
335 ordered_streams_.emplace(
336 std::piecewise_construct,
337 std::forward_as_tuple(StreamID(state_stream.id)),
338 std::forward_as_tuple(this, SSN(state_stream.next_ssn)));
339 }
340 for (const DcSctpSocketHandoverState::UnorderedStream& state_stream :
341 state.rx.unordered_streams) {
342 unordered_streams_.emplace(std::piecewise_construct,
343 std::forward_as_tuple(StreamID(state_stream.id)),
344 std::forward_as_tuple(this));
345 }
346 }
347
348 } // namespace dcsctp
349