1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
#include "pw_protobuf/decoder.h"

#include <cstring>
#include <limits>

#include "pw_varint/varint.h"
20
21 namespace pw::protobuf {
22
Next()23 Status Decoder::Next() {
24 if (!previous_field_consumed_) {
25 if (Status status = SkipField(); !status.ok()) {
26 return status;
27 }
28 }
29 if (proto_.empty()) {
30 return Status::OutOfRange();
31 }
32 previous_field_consumed_ = false;
33 return FieldSize() == 0 ? Status::DataLoss() : OkStatus();
34 }
35
SkipField()36 Status Decoder::SkipField() {
37 if (proto_.empty()) {
38 return Status::OutOfRange();
39 }
40
41 size_t bytes_to_skip = FieldSize();
42 if (bytes_to_skip == 0) {
43 return Status::DataLoss();
44 }
45
46 proto_ = proto_.subspan(bytes_to_skip);
47 return proto_.empty() ? Status::OutOfRange() : OkStatus();
48 }
49
FieldNumber() const50 uint32_t Decoder::FieldNumber() const {
51 uint64_t key;
52 varint::Decode(proto_, &key);
53 if (!FieldKey::IsValidKey(key)) {
54 return 0;
55 }
56 return FieldKey(key).field_number();
57 }
58
ReadUint32(uint32_t * out)59 Status Decoder::ReadUint32(uint32_t* out) {
60 uint64_t value = 0;
61 Status status = ReadUint64(&value);
62 if (!status.ok()) {
63 return status;
64 }
65 if (value > std::numeric_limits<uint32_t>::max()) {
66 return Status::OutOfRange();
67 }
68 *out = value;
69 return OkStatus();
70 }
71
ReadSint32(int32_t * out)72 Status Decoder::ReadSint32(int32_t* out) {
73 int64_t value = 0;
74 Status status = ReadSint64(&value);
75 if (!status.ok()) {
76 return status;
77 }
78 if (value > std::numeric_limits<int32_t>::max()) {
79 return Status::OutOfRange();
80 }
81 *out = value;
82 return OkStatus();
83 }
84
ReadSint64(int64_t * out)85 Status Decoder::ReadSint64(int64_t* out) {
86 uint64_t value = 0;
87 Status status = ReadUint64(&value);
88 if (!status.ok()) {
89 return status;
90 }
91 *out = varint::ZigZagDecode(value);
92 return OkStatus();
93 }
94
ReadBool(bool * out)95 Status Decoder::ReadBool(bool* out) {
96 uint64_t value = 0;
97 Status status = ReadUint64(&value);
98 if (!status.ok()) {
99 return status;
100 }
101 *out = value;
102 return OkStatus();
103 }
104
ReadString(std::string_view * out)105 Status Decoder::ReadString(std::string_view* out) {
106 span<const std::byte> bytes;
107 Status status = ReadDelimited(&bytes);
108 if (!status.ok()) {
109 return status;
110 }
111 *out = std::string_view(reinterpret_cast<const char*>(bytes.data()),
112 bytes.size());
113 return OkStatus();
114 }
115
FieldSize() const116 size_t Decoder::FieldSize() const {
117 uint64_t key;
118 size_t key_size = varint::Decode(proto_, &key);
119 if (key_size == 0 || !FieldKey::IsValidKey(key)) {
120 return 0;
121 }
122
123 span<const std::byte> remainder = proto_.subspan(key_size);
124 uint64_t value = 0;
125 size_t expected_size = 0;
126
127 switch (FieldKey(key).wire_type()) {
128 case WireType::kVarint:
129 expected_size = varint::Decode(remainder, &value);
130 if (expected_size == 0) {
131 return 0;
132 }
133 break;
134
135 case WireType::kDelimited:
136 // Varint at cursor indicates size of the field.
137 expected_size = varint::Decode(remainder, &value);
138 if (expected_size == 0) {
139 return 0;
140 }
141 expected_size += value;
142 break;
143
144 case WireType::kFixed32:
145 expected_size = sizeof(uint32_t);
146 break;
147
148 case WireType::kFixed64:
149 expected_size = sizeof(uint64_t);
150 break;
151 }
152
153 if (remainder.size() < expected_size) {
154 return 0;
155 }
156
157 return key_size + expected_size;
158 }
159
ConsumeKey(WireType expected_type)160 Status Decoder::ConsumeKey(WireType expected_type) {
161 uint64_t key;
162 size_t bytes_read = varint::Decode(proto_, &key);
163 if (bytes_read == 0) {
164 return Status::FailedPrecondition();
165 }
166
167 if (!FieldKey::IsValidKey(key)) {
168 return Status::DataLoss();
169 }
170
171 if (FieldKey(key).wire_type() != expected_type) {
172 return Status::FailedPrecondition();
173 }
174
175 // Advance past the key.
176 proto_ = proto_.subspan(bytes_read);
177 return OkStatus();
178 }
179
ReadVarint(uint64_t * out)180 Status Decoder::ReadVarint(uint64_t* out) {
181 if (Status status = ConsumeKey(WireType::kVarint); !status.ok()) {
182 return status;
183 }
184
185 size_t bytes_read = varint::Decode(proto_, out);
186 if (bytes_read == 0) {
187 return Status::DataLoss();
188 }
189
190 // Advance to the next field.
191 proto_ = proto_.subspan(bytes_read);
192 previous_field_consumed_ = true;
193 return OkStatus();
194 }
195
ReadFixed(std::byte * out,size_t size)196 Status Decoder::ReadFixed(std::byte* out, size_t size) {
197 WireType expected_wire_type =
198 size == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
199 Status status = ConsumeKey(expected_wire_type);
200 if (!status.ok()) {
201 return status;
202 }
203
204 if (proto_.size() < size) {
205 return Status::DataLoss();
206 }
207
208 std::memcpy(out, proto_.data(), size);
209 proto_ = proto_.subspan(size);
210 previous_field_consumed_ = true;
211
212 return OkStatus();
213 }
214
ReadDelimited(span<const std::byte> * out)215 Status Decoder::ReadDelimited(span<const std::byte>* out) {
216 Status status = ConsumeKey(WireType::kDelimited);
217 if (!status.ok()) {
218 return status;
219 }
220
221 uint64_t length;
222 size_t bytes_read = varint::Decode(proto_, &length);
223 if (bytes_read == 0) {
224 return Status::DataLoss();
225 }
226
227 proto_ = proto_.subspan(bytes_read);
228 if (proto_.size() < length) {
229 return Status::DataLoss();
230 }
231
232 *out = proto_.first(length);
233 proto_ = proto_.subspan(length);
234 previous_field_consumed_ = true;
235
236 return OkStatus();
237 }
238
// Decodes `proto` field by field, invoking the handler's ProcessField with
// this decoder and the current field number for each field.
//
// Returns:
//   FailedPrecondition - no handler is set, or state_ is not kReady (e.g. a
//                        decode is already in progress).
//   error from Next()  - the data is malformed (any non-OutOfRange status).
//   handler's error    - ProcessField failed; state_ is set to
//                        kDecodeCancelled for Cancelled, otherwise
//                        kDecodeFailed.
//   DataLoss           - state_ was moved away from kDecodeInProgress during
//                        a callback.
//   OK                 - every field was processed; state_ returns to kReady.
Status CallbackDecoder::Decode(span<const std::byte> proto) {
  if (handler_ == nullptr || state_ != kReady) {
    return Status::FailedPrecondition();
  }

  state_ = kDecodeInProgress;
  decoder_.Reset(proto);

  // Iterate the proto, calling the handler with each field number.
  while (state_ == kDecodeInProgress) {
    if (Status status = decoder_.Next(); !status.ok()) {
      if (status.IsOutOfRange()) {
        // Reached the end of the proto.
        break;
      }

      // Proto data is malformed.
      // NOTE(review): unlike the handler-failure path below, this return
      // leaves state_ as kDecodeInProgress, so a later Decode() call fails
      // the kReady check. Confirm whether something outside this function
      // resets state_, or whether this should record kDecodeFailed.
      return status;
    }

    // The handler receives *this, so it may read the current field and may
    // change state_ while it runs.
    Status status = handler_->ProcessField(*this, decoder_.FieldNumber());
    if (!status.ok()) {
      state_ = status.IsCancelled() ? kDecodeCancelled : kDecodeFailed;
      return status;
    }

    // The callback function can modify the decoder's state; check that
    // everything is still okay.
    if (state_ == kDecodeFailed) {
      break;
    }
  }

  // Reaching here with any state other than kDecodeInProgress means a
  // callback changed state_ (the end-of-data break keeps kDecodeInProgress).
  if (state_ != kDecodeInProgress) {
    return Status::DataLoss();
  }

  state_ = kReady;
  return OkStatus();
}
279
280 } // namespace pw::protobuf
281