• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_protobuf/stream_decoder.h"
16 
17 #include <algorithm>
18 #include <bit>
19 #include <cstdint>
20 #include <cstring>
21 #include <limits>
22 
23 #include "pw_assert/check.h"
24 #include "pw_status/status.h"
25 #include "pw_status/status_with_size.h"
26 #include "pw_status/try.h"
27 #include "pw_varint/stream.h"
28 #include "pw_varint/varint.h"
29 
30 namespace pw::protobuf {
31 
DoSeek(ptrdiff_t offset,Whence origin)32 Status StreamDecoder::BytesReader::DoSeek(ptrdiff_t offset, Whence origin) {
33   PW_TRY(status_);
34   if (!decoder_.reader_.seekable()) {
35     return Status::Unimplemented();
36   }
37 
38   ptrdiff_t absolute_position = std::numeric_limits<ptrdiff_t>::min();
39 
40   // Convert from the position within the bytes field to the position within the
41   // proto stream.
42   switch (origin) {
43     case Whence::kBeginning:
44       absolute_position = start_offset_ + offset;
45       break;
46 
47     case Whence::kCurrent:
48       absolute_position = decoder_.position_ + offset;
49       break;
50 
51     case Whence::kEnd:
52       absolute_position = end_offset_ + offset;
53       break;
54   }
55 
56   if (absolute_position < 0) {
57     return Status::InvalidArgument();
58   }
59 
60   if (static_cast<size_t>(absolute_position) < start_offset_ ||
61       static_cast<size_t>(absolute_position) >= end_offset_) {
62     return Status::OutOfRange();
63   }
64 
65   PW_TRY(decoder_.reader_.Seek(absolute_position, Whence::kBeginning));
66   decoder_.position_ = absolute_position;
67   return OkStatus();
68 }
69 
DoRead(ByteSpan destination)70 StatusWithSize StreamDecoder::BytesReader::DoRead(ByteSpan destination) {
71   if (!status_.ok()) {
72     return StatusWithSize(status_, 0);
73   }
74 
75   // Bound the read buffer to the size of the bytes field.
76   size_t max_length = end_offset_ - decoder_.position_;
77   if (destination.size() > max_length) {
78     destination = destination.first(max_length);
79   }
80 
81   Result<ByteSpan> result = decoder_.reader_.Read(destination);
82   if (!result.ok()) {
83     return StatusWithSize(result.status(), 0);
84   }
85 
86   decoder_.position_ += result.value().size();
87   return StatusWithSize(result.value().size());
88 }
89 
~StreamDecoder()90 StreamDecoder::~StreamDecoder() {
91   if (parent_ != nullptr) {
92     parent_->CloseNestedDecoder(*this);
93   } else if (stream_bounds_.high < std::numeric_limits<size_t>::max()) {
94     if (status_.ok()) {
95       // Advance the stream to the end of the bounds.
96       PW_CHECK(Advance(stream_bounds_.high).ok());
97     }
98   }
99 }
100 
Next()101 Status StreamDecoder::Next() {
102   PW_CHECK(!nested_reader_open_,
103            "Cannot use parent decoder while a nested one is open");
104 
105   PW_TRY(status_);
106 
107   if (!field_consumed_) {
108     PW_TRY(SkipField());
109   }
110 
111   if (position_ >= stream_bounds_.high) {
112     return Status::OutOfRange();
113   }
114 
115   status_ = ReadFieldKey();
116   return status_;
117 }
118 
GetBytesReader()119 StreamDecoder::BytesReader StreamDecoder::GetBytesReader() {
120   Status status = CheckOkToRead(WireType::kDelimited);
121 
122   if (reader_.ConservativeReadLimit() < delimited_field_size_) {
123     status.Update(Status::DataLoss());
124   }
125 
126   nested_reader_open_ = true;
127 
128   if (!status.ok()) {
129     return BytesReader(*this, status);
130   }
131 
132   size_t low = position_;
133   size_t high = low + delimited_field_size_;
134 
135   return BytesReader(*this, low, high);
136 }
137 
GetNestedDecoder()138 StreamDecoder StreamDecoder::GetNestedDecoder() {
139   Status status = CheckOkToRead(WireType::kDelimited);
140 
141   if (reader_.ConservativeReadLimit() < delimited_field_size_) {
142     status.Update(Status::DataLoss());
143   }
144 
145   nested_reader_open_ = true;
146 
147   if (!status.ok()) {
148     return StreamDecoder(reader_, this, status);
149   }
150 
151   size_t low = position_;
152   size_t high = low + delimited_field_size_;
153 
154   return StreamDecoder(reader_, this, low, high);
155 }
156 
Advance(size_t end_position)157 Status StreamDecoder::Advance(size_t end_position) {
158   if (reader_.seekable()) {
159     PW_TRY(reader_.Seek(end_position - position_, stream::Stream::kCurrent));
160     position_ = end_position;
161     return OkStatus();
162   }
163 
164   while (position_ < end_position) {
165     std::byte b;
166     PW_TRY(reader_.Read(std::span(&b, 1)));
167     position_++;
168   }
169   return OkStatus();
170 }
171 
CloseBytesReader(BytesReader & reader)172 void StreamDecoder::CloseBytesReader(BytesReader& reader) {
173   status_ = reader.status_;
174   if (status_.ok()) {
175     // Advance the stream to the end of the bytes field.
176     // The BytesReader already updated our position_ field as bytes were read.
177     PW_CHECK(Advance(reader.end_offset_).ok());
178   }
179 
180   field_consumed_ = true;
181   nested_reader_open_ = false;
182 }
183 
CloseNestedDecoder(StreamDecoder & nested)184 void StreamDecoder::CloseNestedDecoder(StreamDecoder& nested) {
185   PW_CHECK_PTR_EQ(nested.parent_, this);
186 
187   nested.nested_reader_open_ = true;
188   nested.parent_ = nullptr;
189 
190   status_ = nested.status_;
191   position_ = nested.position_;
192   if (status_.ok()) {
193     // Advance the stream to the end of the nested message field.
194     PW_CHECK(Advance(nested.stream_bounds_.high).ok());
195   }
196 
197   field_consumed_ = true;
198   nested_reader_open_ = false;
199 }
200 
ReadFieldKey()201 Status StreamDecoder::ReadFieldKey() {
202   PW_DCHECK(field_consumed_);
203 
204   uint64_t varint = 0;
205   PW_TRY_ASSIGN(size_t bytes_read, varint::Read(reader_, &varint));
206   position_ += bytes_read;
207 
208   if (!FieldKey::IsValidKey(varint)) {
209     return Status::DataLoss();
210   }
211 
212   current_field_ = FieldKey(varint);
213 
214   if (current_field_.wire_type() == WireType::kDelimited) {
215     // Read the length varint of length-delimited fields immediately to simplify
216     // later processing of the field.
217     PW_TRY_ASSIGN(bytes_read, varint::Read(reader_, &varint));
218     position_ += bytes_read;
219 
220     if (varint > std::numeric_limits<uint32_t>::max()) {
221       return Status::DataLoss();
222     }
223 
224     delimited_field_size_ = varint;
225     delimited_field_offset_ = position_;
226   }
227 
228   field_consumed_ = false;
229   return OkStatus();
230 }
231 
GetLengthDelimitedPayloadBounds()232 Result<StreamDecoder::Bounds> StreamDecoder::GetLengthDelimitedPayloadBounds() {
233   PW_TRY(CheckOkToRead(WireType::kDelimited));
234   return StreamDecoder::Bounds{delimited_field_offset_,
235                                delimited_field_size_ + delimited_field_offset_};
236 }
237 
238 // Consumes the current protobuf field, advancing the stream to the key of the
239 // next field (if one exists).
SkipField()240 Status StreamDecoder::SkipField() {
241   PW_DCHECK(!field_consumed_);
242 
243   size_t bytes_to_skip = 0;
244   uint64_t value = 0;
245 
246   switch (current_field_.wire_type()) {
247     case WireType::kVarint: {
248       // Consume the varint field; nothing more to skip afterward.
249       PW_TRY_ASSIGN(size_t bytes_read, varint::Read(reader_, &value));
250       position_ += bytes_read;
251       break;
252     }
253     case WireType::kDelimited:
254       bytes_to_skip = delimited_field_size_;
255       break;
256 
257     case WireType::kFixed32:
258       bytes_to_skip = sizeof(uint32_t);
259       break;
260 
261     case WireType::kFixed64:
262       bytes_to_skip = sizeof(uint64_t);
263       break;
264   }
265 
266   if (bytes_to_skip > 0) {
267     // Check if the stream has the field available. If not, report it as a
268     // DATA_LOSS since the proto is invalid (as opposed to OUT_OF_BOUNDS if we
269     // just tried to seek beyond the end).
270     if (reader_.ConservativeReadLimit() < bytes_to_skip) {
271       status_ = Status::DataLoss();
272       return status_;
273     }
274 
275     PW_TRY(Advance(position_ + bytes_to_skip));
276   }
277 
278   field_consumed_ = true;
279   return OkStatus();
280 }
281 
ReadVarintField(std::span<std::byte> out,VarintDecodeType decode_type)282 Status StreamDecoder::ReadVarintField(std::span<std::byte> out,
283                                       VarintDecodeType decode_type) {
284   PW_CHECK(out.size() == sizeof(bool) || out.size() == sizeof(uint32_t) ||
285                out.size() == sizeof(uint64_t),
286            "Protobuf varints must only be used with bool, int32_t, uint32_t, "
287            "int64_t, or uint64_t");
288   PW_TRY(CheckOkToRead(WireType::kVarint));
289 
290   const StatusWithSize sws = ReadOneVarint(out, decode_type);
291   if (sws.status() != Status::DataLoss())
292     field_consumed_ = true;
293   return sws.status();
294 }
295 
ReadOneVarint(std::span<std::byte> out,VarintDecodeType decode_type)296 StatusWithSize StreamDecoder::ReadOneVarint(std::span<std::byte> out,
297                                             VarintDecodeType decode_type) {
298   uint64_t value;
299   StatusWithSize sws = varint::Read(reader_, &value);
300   if (sws.IsOutOfRange()) {
301     // Out of range indicates the end of the stream. As a value is expected
302     // here, report it as a data loss and terminate the decode operation.
303     status_ = Status::DataLoss();
304     return StatusWithSize(status_, sws.size());
305   }
306   if (!sws.ok()) {
307     return sws;
308   }
309 
310   position_ += sws.size();
311 
312   if (out.size() == sizeof(uint64_t)) {
313     if (decode_type == VarintDecodeType::kUnsigned) {
314       std::memcpy(out.data(), &value, out.size());
315     } else {
316       const int64_t signed_value = decode_type == VarintDecodeType::kZigZag
317                                        ? varint::ZigZagDecode(value)
318                                        : static_cast<int64_t>(value);
319       std::memcpy(out.data(), &signed_value, out.size());
320     }
321   } else if (out.size() == sizeof(uint32_t)) {
322     if (decode_type == VarintDecodeType::kUnsigned) {
323       if (value > std::numeric_limits<uint32_t>::max()) {
324         return StatusWithSize(Status::OutOfRange(), sws.size());
325       }
326       std::memcpy(out.data(), &value, out.size());
327     } else {
328       const int64_t signed_value = decode_type == VarintDecodeType::kZigZag
329                                        ? varint::ZigZagDecode(value)
330                                        : static_cast<int64_t>(value);
331       if (signed_value > std::numeric_limits<int32_t>::max() ||
332           signed_value < std::numeric_limits<int32_t>::min()) {
333         return StatusWithSize(Status::OutOfRange(), sws.size());
334       }
335       std::memcpy(out.data(), &signed_value, out.size());
336     }
337   } else if (out.size() == sizeof(bool)) {
338     PW_CHECK(decode_type == VarintDecodeType::kUnsigned,
339              "Protobuf bool can never be signed");
340     std::memcpy(out.data(), &value, out.size());
341   }
342 
343   return sws;
344 }
345 
ReadFixedField(std::span<std::byte> out)346 Status StreamDecoder::ReadFixedField(std::span<std::byte> out) {
347   WireType expected_wire_type =
348       out.size() == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
349   PW_TRY(CheckOkToRead(expected_wire_type));
350 
351   if (reader_.ConservativeReadLimit() < out.size()) {
352     status_ = Status::DataLoss();
353     return status_;
354   }
355 
356   PW_TRY(reader_.Read(out));
357   position_ += out.size();
358   field_consumed_ = true;
359 
360   if (std::endian::native != std::endian::little) {
361     std::reverse(out.begin(), out.end());
362   }
363 
364   return OkStatus();
365 }
366 
ReadDelimitedField(std::span<std::byte> out)367 StatusWithSize StreamDecoder::ReadDelimitedField(std::span<std::byte> out) {
368   if (Status status = CheckOkToRead(WireType::kDelimited); !status.ok()) {
369     return StatusWithSize(status, 0);
370   }
371 
372   if (reader_.ConservativeReadLimit() < delimited_field_size_) {
373     status_ = Status::DataLoss();
374     return StatusWithSize(status_, 0);
375   }
376 
377   if (out.size() < delimited_field_size_) {
378     // Value can't fit into the provided buffer. Don't advance the cursor so
379     // that the field can be re-read with a larger buffer or through the stream
380     // API.
381     return StatusWithSize::ResourceExhausted();
382   }
383 
384   Result<ByteSpan> result = reader_.Read(out.first(delimited_field_size_));
385   if (!result.ok()) {
386     return StatusWithSize(result.status(), 0);
387   }
388 
389   position_ += result.value().size();
390   field_consumed_ = true;
391   return StatusWithSize(result.value().size());
392 }
393 
ReadPackedFixedField(std::span<std::byte> out,size_t elem_size)394 StatusWithSize StreamDecoder::ReadPackedFixedField(std::span<std::byte> out,
395                                                    size_t elem_size) {
396   if (Status status = CheckOkToRead(WireType::kDelimited); !status.ok()) {
397     return StatusWithSize(status, 0);
398   }
399 
400   if (reader_.ConservativeReadLimit() < delimited_field_size_) {
401     status_ = Status::DataLoss();
402     return StatusWithSize(status_, 0);
403   }
404 
405   if (out.size() < delimited_field_size_) {
406     // Value can't fit into the provided buffer. Don't advance the cursor so
407     // that the field can be re-read with a larger buffer or through the stream
408     // API.
409     return StatusWithSize::ResourceExhausted();
410   }
411 
412   Result<ByteSpan> result = reader_.Read(out.first(delimited_field_size_));
413   if (!result.ok()) {
414     return StatusWithSize(result.status(), 0);
415   }
416 
417   position_ += result.value().size();
418   field_consumed_ = true;
419 
420   // Decode little-endian serialized packed fields.
421   if (std::endian::native != std::endian::little) {
422     for (auto out_start = out.begin(); out_start != out.end();
423          out_start += elem_size) {
424       std::reverse(out_start, out_start + elem_size);
425     }
426   }
427 
428   return StatusWithSize(result.value().size() / elem_size);
429 }
430 
ReadPackedVarintField(std::span<std::byte> out,size_t elem_size,VarintDecodeType decode_type)431 StatusWithSize StreamDecoder::ReadPackedVarintField(
432     std::span<std::byte> out, size_t elem_size, VarintDecodeType decode_type) {
433   PW_CHECK(elem_size == sizeof(bool) || elem_size == sizeof(uint32_t) ||
434                elem_size == sizeof(uint64_t),
435            "Protobuf varints must only be used with bool, int32_t, uint32_t, "
436            "int64_t, or uint64_t");
437 
438   if (Status status = CheckOkToRead(WireType::kDelimited); !status.ok()) {
439     return StatusWithSize(status, 0);
440   }
441 
442   if (reader_.ConservativeReadLimit() < delimited_field_size_) {
443     status_ = Status::DataLoss();
444     return StatusWithSize(status_, 0);
445   }
446 
447   size_t bytes_read = 0;
448   size_t number_out = 0;
449   while (bytes_read < delimited_field_size_ && !out.empty()) {
450     const StatusWithSize sws = ReadOneVarint(out.first(elem_size), decode_type);
451     if (!sws.ok()) {
452       return StatusWithSize(sws.status(), number_out);
453     }
454 
455     bytes_read += sws.size();
456     out = out.subspan(elem_size);
457     ++number_out;
458   }
459 
460   if (bytes_read < delimited_field_size_) {
461     return StatusWithSize(Status::ResourceExhausted(), number_out);
462   }
463 
464   field_consumed_ = true;
465   return StatusWithSize(OkStatus(), number_out);
466 }
467 
CheckOkToRead(WireType type)468 Status StreamDecoder::CheckOkToRead(WireType type) {
469   PW_CHECK(!nested_reader_open_,
470            "Cannot read from a decoder while a nested decoder is open");
471   PW_CHECK(!field_consumed_,
472            "Attempting to read from protobuf decoder without first calling "
473            "Next()");
474 
475   // Attempting to read the wrong type is typically a programmer error;
476   // however, it could also occur due to data corruption. As we don't want to
477   // crash on bad data, return NOT_FOUND here to distinguish it from other
478   // corruption cases.
479   if (current_field_.wire_type() != type) {
480     status_ = Status::NotFound();
481   }
482 
483   return status_;
484 }
485 
486 }  // namespace pw::protobuf
487