1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_protobuf/stream_decoder.h"
16
17 #include <algorithm>
18 #include <bit>
19 #include <cstdint>
20 #include <cstring>
21 #include <limits>
22
23 #include "pw_assert/check.h"
24 #include "pw_status/status.h"
25 #include "pw_status/status_with_size.h"
26 #include "pw_status/try.h"
27 #include "pw_varint/stream.h"
28 #include "pw_varint/varint.h"
29
30 namespace pw::protobuf {
31
DoSeek(ptrdiff_t offset,Whence origin)32 Status StreamDecoder::BytesReader::DoSeek(ptrdiff_t offset, Whence origin) {
33 PW_TRY(status_);
34 if (!decoder_.reader_.seekable()) {
35 return Status::Unimplemented();
36 }
37
38 ptrdiff_t absolute_position = std::numeric_limits<ptrdiff_t>::min();
39
40 // Convert from the position within the bytes field to the position within the
41 // proto stream.
42 switch (origin) {
43 case Whence::kBeginning:
44 absolute_position = start_offset_ + offset;
45 break;
46
47 case Whence::kCurrent:
48 absolute_position = decoder_.position_ + offset;
49 break;
50
51 case Whence::kEnd:
52 absolute_position = end_offset_ + offset;
53 break;
54 }
55
56 if (absolute_position < 0) {
57 return Status::InvalidArgument();
58 }
59
60 if (static_cast<size_t>(absolute_position) < start_offset_ ||
61 static_cast<size_t>(absolute_position) >= end_offset_) {
62 return Status::OutOfRange();
63 }
64
65 PW_TRY(decoder_.reader_.Seek(absolute_position, Whence::kBeginning));
66 decoder_.position_ = absolute_position;
67 return OkStatus();
68 }
69
DoRead(ByteSpan destination)70 StatusWithSize StreamDecoder::BytesReader::DoRead(ByteSpan destination) {
71 if (!status_.ok()) {
72 return StatusWithSize(status_, 0);
73 }
74
75 // Bound the read buffer to the size of the bytes field.
76 size_t max_length = end_offset_ - decoder_.position_;
77 if (destination.size() > max_length) {
78 destination = destination.first(max_length);
79 }
80
81 Result<ByteSpan> result = decoder_.reader_.Read(destination);
82 if (!result.ok()) {
83 return StatusWithSize(result.status(), 0);
84 }
85
86 decoder_.position_ += result.value().size();
87 return StatusWithSize(result.value().size());
88 }
89
~StreamDecoder()90 StreamDecoder::~StreamDecoder() {
91 if (parent_ != nullptr) {
92 parent_->CloseNestedDecoder(*this);
93 } else if (stream_bounds_.high < std::numeric_limits<size_t>::max()) {
94 if (status_.ok()) {
95 // Advance the stream to the end of the bounds.
96 PW_CHECK(Advance(stream_bounds_.high).ok());
97 }
98 }
99 }
100
Next()101 Status StreamDecoder::Next() {
102 PW_CHECK(!nested_reader_open_,
103 "Cannot use parent decoder while a nested one is open");
104
105 PW_TRY(status_);
106
107 if (!field_consumed_) {
108 PW_TRY(SkipField());
109 }
110
111 if (position_ >= stream_bounds_.high) {
112 return Status::OutOfRange();
113 }
114
115 status_ = ReadFieldKey();
116 return status_;
117 }
118
GetBytesReader()119 StreamDecoder::BytesReader StreamDecoder::GetBytesReader() {
120 Status status = CheckOkToRead(WireType::kDelimited);
121
122 if (reader_.ConservativeReadLimit() < delimited_field_size_) {
123 status.Update(Status::DataLoss());
124 }
125
126 nested_reader_open_ = true;
127
128 if (!status.ok()) {
129 return BytesReader(*this, status);
130 }
131
132 size_t low = position_;
133 size_t high = low + delimited_field_size_;
134
135 return BytesReader(*this, low, high);
136 }
137
GetNestedDecoder()138 StreamDecoder StreamDecoder::GetNestedDecoder() {
139 Status status = CheckOkToRead(WireType::kDelimited);
140
141 if (reader_.ConservativeReadLimit() < delimited_field_size_) {
142 status.Update(Status::DataLoss());
143 }
144
145 nested_reader_open_ = true;
146
147 if (!status.ok()) {
148 return StreamDecoder(reader_, this, status);
149 }
150
151 size_t low = position_;
152 size_t high = low + delimited_field_size_;
153
154 return StreamDecoder(reader_, this, low, high);
155 }
156
Advance(size_t end_position)157 Status StreamDecoder::Advance(size_t end_position) {
158 if (reader_.seekable()) {
159 PW_TRY(reader_.Seek(end_position - position_, stream::Stream::kCurrent));
160 position_ = end_position;
161 return OkStatus();
162 }
163
164 while (position_ < end_position) {
165 std::byte b;
166 PW_TRY(reader_.Read(std::span(&b, 1)));
167 position_++;
168 }
169 return OkStatus();
170 }
171
CloseBytesReader(BytesReader & reader)172 void StreamDecoder::CloseBytesReader(BytesReader& reader) {
173 status_ = reader.status_;
174 if (status_.ok()) {
175 // Advance the stream to the end of the bytes field.
176 // The BytesReader already updated our position_ field as bytes were read.
177 PW_CHECK(Advance(reader.end_offset_).ok());
178 }
179
180 field_consumed_ = true;
181 nested_reader_open_ = false;
182 }
183
CloseNestedDecoder(StreamDecoder & nested)184 void StreamDecoder::CloseNestedDecoder(StreamDecoder& nested) {
185 PW_CHECK_PTR_EQ(nested.parent_, this);
186
187 nested.nested_reader_open_ = true;
188 nested.parent_ = nullptr;
189
190 status_ = nested.status_;
191 position_ = nested.position_;
192 if (status_.ok()) {
193 // Advance the stream to the end of the nested message field.
194 PW_CHECK(Advance(nested.stream_bounds_.high).ok());
195 }
196
197 field_consumed_ = true;
198 nested_reader_open_ = false;
199 }
200
ReadFieldKey()201 Status StreamDecoder::ReadFieldKey() {
202 PW_DCHECK(field_consumed_);
203
204 uint64_t varint = 0;
205 PW_TRY_ASSIGN(size_t bytes_read, varint::Read(reader_, &varint));
206 position_ += bytes_read;
207
208 if (!FieldKey::IsValidKey(varint)) {
209 return Status::DataLoss();
210 }
211
212 current_field_ = FieldKey(varint);
213
214 if (current_field_.wire_type() == WireType::kDelimited) {
215 // Read the length varint of length-delimited fields immediately to simplify
216 // later processing of the field.
217 PW_TRY_ASSIGN(bytes_read, varint::Read(reader_, &varint));
218 position_ += bytes_read;
219
220 if (varint > std::numeric_limits<uint32_t>::max()) {
221 return Status::DataLoss();
222 }
223
224 delimited_field_size_ = varint;
225 delimited_field_offset_ = position_;
226 }
227
228 field_consumed_ = false;
229 return OkStatus();
230 }
231
GetLengthDelimitedPayloadBounds()232 Result<StreamDecoder::Bounds> StreamDecoder::GetLengthDelimitedPayloadBounds() {
233 PW_TRY(CheckOkToRead(WireType::kDelimited));
234 return StreamDecoder::Bounds{delimited_field_offset_,
235 delimited_field_size_ + delimited_field_offset_};
236 }
237
238 // Consumes the current protobuf field, advancing the stream to the key of the
239 // next field (if one exists).
SkipField()240 Status StreamDecoder::SkipField() {
241 PW_DCHECK(!field_consumed_);
242
243 size_t bytes_to_skip = 0;
244 uint64_t value = 0;
245
246 switch (current_field_.wire_type()) {
247 case WireType::kVarint: {
248 // Consume the varint field; nothing more to skip afterward.
249 PW_TRY_ASSIGN(size_t bytes_read, varint::Read(reader_, &value));
250 position_ += bytes_read;
251 break;
252 }
253 case WireType::kDelimited:
254 bytes_to_skip = delimited_field_size_;
255 break;
256
257 case WireType::kFixed32:
258 bytes_to_skip = sizeof(uint32_t);
259 break;
260
261 case WireType::kFixed64:
262 bytes_to_skip = sizeof(uint64_t);
263 break;
264 }
265
266 if (bytes_to_skip > 0) {
267 // Check if the stream has the field available. If not, report it as a
268 // DATA_LOSS since the proto is invalid (as opposed to OUT_OF_BOUNDS if we
269 // just tried to seek beyond the end).
270 if (reader_.ConservativeReadLimit() < bytes_to_skip) {
271 status_ = Status::DataLoss();
272 return status_;
273 }
274
275 PW_TRY(Advance(position_ + bytes_to_skip));
276 }
277
278 field_consumed_ = true;
279 return OkStatus();
280 }
281
ReadVarintField(std::span<std::byte> out,VarintDecodeType decode_type)282 Status StreamDecoder::ReadVarintField(std::span<std::byte> out,
283 VarintDecodeType decode_type) {
284 PW_CHECK(out.size() == sizeof(bool) || out.size() == sizeof(uint32_t) ||
285 out.size() == sizeof(uint64_t),
286 "Protobuf varints must only be used with bool, int32_t, uint32_t, "
287 "int64_t, or uint64_t");
288 PW_TRY(CheckOkToRead(WireType::kVarint));
289
290 const StatusWithSize sws = ReadOneVarint(out, decode_type);
291 if (sws.status() != Status::DataLoss())
292 field_consumed_ = true;
293 return sws.status();
294 }
295
ReadOneVarint(std::span<std::byte> out,VarintDecodeType decode_type)296 StatusWithSize StreamDecoder::ReadOneVarint(std::span<std::byte> out,
297 VarintDecodeType decode_type) {
298 uint64_t value;
299 StatusWithSize sws = varint::Read(reader_, &value);
300 if (sws.IsOutOfRange()) {
301 // Out of range indicates the end of the stream. As a value is expected
302 // here, report it as a data loss and terminate the decode operation.
303 status_ = Status::DataLoss();
304 return StatusWithSize(status_, sws.size());
305 }
306 if (!sws.ok()) {
307 return sws;
308 }
309
310 position_ += sws.size();
311
312 if (out.size() == sizeof(uint64_t)) {
313 if (decode_type == VarintDecodeType::kUnsigned) {
314 std::memcpy(out.data(), &value, out.size());
315 } else {
316 const int64_t signed_value = decode_type == VarintDecodeType::kZigZag
317 ? varint::ZigZagDecode(value)
318 : static_cast<int64_t>(value);
319 std::memcpy(out.data(), &signed_value, out.size());
320 }
321 } else if (out.size() == sizeof(uint32_t)) {
322 if (decode_type == VarintDecodeType::kUnsigned) {
323 if (value > std::numeric_limits<uint32_t>::max()) {
324 return StatusWithSize(Status::OutOfRange(), sws.size());
325 }
326 std::memcpy(out.data(), &value, out.size());
327 } else {
328 const int64_t signed_value = decode_type == VarintDecodeType::kZigZag
329 ? varint::ZigZagDecode(value)
330 : static_cast<int64_t>(value);
331 if (signed_value > std::numeric_limits<int32_t>::max() ||
332 signed_value < std::numeric_limits<int32_t>::min()) {
333 return StatusWithSize(Status::OutOfRange(), sws.size());
334 }
335 std::memcpy(out.data(), &signed_value, out.size());
336 }
337 } else if (out.size() == sizeof(bool)) {
338 PW_CHECK(decode_type == VarintDecodeType::kUnsigned,
339 "Protobuf bool can never be signed");
340 std::memcpy(out.data(), &value, out.size());
341 }
342
343 return sws;
344 }
345
ReadFixedField(std::span<std::byte> out)346 Status StreamDecoder::ReadFixedField(std::span<std::byte> out) {
347 WireType expected_wire_type =
348 out.size() == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
349 PW_TRY(CheckOkToRead(expected_wire_type));
350
351 if (reader_.ConservativeReadLimit() < out.size()) {
352 status_ = Status::DataLoss();
353 return status_;
354 }
355
356 PW_TRY(reader_.Read(out));
357 position_ += out.size();
358 field_consumed_ = true;
359
360 if (std::endian::native != std::endian::little) {
361 std::reverse(out.begin(), out.end());
362 }
363
364 return OkStatus();
365 }
366
ReadDelimitedField(std::span<std::byte> out)367 StatusWithSize StreamDecoder::ReadDelimitedField(std::span<std::byte> out) {
368 if (Status status = CheckOkToRead(WireType::kDelimited); !status.ok()) {
369 return StatusWithSize(status, 0);
370 }
371
372 if (reader_.ConservativeReadLimit() < delimited_field_size_) {
373 status_ = Status::DataLoss();
374 return StatusWithSize(status_, 0);
375 }
376
377 if (out.size() < delimited_field_size_) {
378 // Value can't fit into the provided buffer. Don't advance the cursor so
379 // that the field can be re-read with a larger buffer or through the stream
380 // API.
381 return StatusWithSize::ResourceExhausted();
382 }
383
384 Result<ByteSpan> result = reader_.Read(out.first(delimited_field_size_));
385 if (!result.ok()) {
386 return StatusWithSize(result.status(), 0);
387 }
388
389 position_ += result.value().size();
390 field_consumed_ = true;
391 return StatusWithSize(result.value().size());
392 }
393
ReadPackedFixedField(std::span<std::byte> out,size_t elem_size)394 StatusWithSize StreamDecoder::ReadPackedFixedField(std::span<std::byte> out,
395 size_t elem_size) {
396 if (Status status = CheckOkToRead(WireType::kDelimited); !status.ok()) {
397 return StatusWithSize(status, 0);
398 }
399
400 if (reader_.ConservativeReadLimit() < delimited_field_size_) {
401 status_ = Status::DataLoss();
402 return StatusWithSize(status_, 0);
403 }
404
405 if (out.size() < delimited_field_size_) {
406 // Value can't fit into the provided buffer. Don't advance the cursor so
407 // that the field can be re-read with a larger buffer or through the stream
408 // API.
409 return StatusWithSize::ResourceExhausted();
410 }
411
412 Result<ByteSpan> result = reader_.Read(out.first(delimited_field_size_));
413 if (!result.ok()) {
414 return StatusWithSize(result.status(), 0);
415 }
416
417 position_ += result.value().size();
418 field_consumed_ = true;
419
420 // Decode little-endian serialized packed fields.
421 if (std::endian::native != std::endian::little) {
422 for (auto out_start = out.begin(); out_start != out.end();
423 out_start += elem_size) {
424 std::reverse(out_start, out_start + elem_size);
425 }
426 }
427
428 return StatusWithSize(result.value().size() / elem_size);
429 }
430
ReadPackedVarintField(std::span<std::byte> out,size_t elem_size,VarintDecodeType decode_type)431 StatusWithSize StreamDecoder::ReadPackedVarintField(
432 std::span<std::byte> out, size_t elem_size, VarintDecodeType decode_type) {
433 PW_CHECK(elem_size == sizeof(bool) || elem_size == sizeof(uint32_t) ||
434 elem_size == sizeof(uint64_t),
435 "Protobuf varints must only be used with bool, int32_t, uint32_t, "
436 "int64_t, or uint64_t");
437
438 if (Status status = CheckOkToRead(WireType::kDelimited); !status.ok()) {
439 return StatusWithSize(status, 0);
440 }
441
442 if (reader_.ConservativeReadLimit() < delimited_field_size_) {
443 status_ = Status::DataLoss();
444 return StatusWithSize(status_, 0);
445 }
446
447 size_t bytes_read = 0;
448 size_t number_out = 0;
449 while (bytes_read < delimited_field_size_ && !out.empty()) {
450 const StatusWithSize sws = ReadOneVarint(out.first(elem_size), decode_type);
451 if (!sws.ok()) {
452 return StatusWithSize(sws.status(), number_out);
453 }
454
455 bytes_read += sws.size();
456 out = out.subspan(elem_size);
457 ++number_out;
458 }
459
460 if (bytes_read < delimited_field_size_) {
461 return StatusWithSize(Status::ResourceExhausted(), number_out);
462 }
463
464 field_consumed_ = true;
465 return StatusWithSize(OkStatus(), number_out);
466 }
467
CheckOkToRead(WireType type)468 Status StreamDecoder::CheckOkToRead(WireType type) {
469 PW_CHECK(!nested_reader_open_,
470 "Cannot read from a decoder while a nested decoder is open");
471 PW_CHECK(!field_consumed_,
472 "Attempting to read from protobuf decoder without first calling "
473 "Next()");
474
475 // Attempting to read the wrong type is typically a programmer error;
476 // however, it could also occur due to data corruption. As we don't want to
477 // crash on bad data, return NOT_FOUND here to distinguish it from other
478 // corruption cases.
479 if (current_field_.wire_type() != type) {
480 status_ = Status::NotFound();
481 }
482
483 return status_;
484 }
485
486 } // namespace pw::protobuf
487