1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_protobuf/encoder.h"
16
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstring>
#include <optional>
21
22 #include "pw_assert/check.h"
23 #include "pw_bytes/span.h"
24 #include "pw_protobuf/internal/codegen.h"
25 #include "pw_protobuf/serialized_size.h"
26 #include "pw_protobuf/stream_decoder.h"
27 #include "pw_protobuf/wire_format.h"
28 #include "pw_span/span.h"
29 #include "pw_status/status.h"
30 #include "pw_status/try.h"
31 #include "pw_stream/memory_stream.h"
32 #include "pw_stream/stream.h"
33 #include "pw_string/string.h"
34 #include "pw_varint/varint.h"
35
36 namespace pw::protobuf {
37
38 using internal::VarintType;
39
GetNestedEncoder(uint32_t field_number,bool write_when_empty)40 StreamEncoder StreamEncoder::GetNestedEncoder(uint32_t field_number,
41 bool write_when_empty) {
42 PW_CHECK(!nested_encoder_open());
43 PW_CHECK(ValidFieldNumber(field_number));
44
45 nested_field_number_ = field_number;
46
47 // Pass the unused space of the scratch buffer to the nested encoder to use
48 // as their scratch buffer.
49 size_t key_size =
50 varint::EncodedSize(FieldKey(field_number, WireType::kDelimited));
51 size_t reserved_size = key_size + config::kMaxVarintSize;
52 size_t max_size = std::min(memory_writer_.ConservativeWriteLimit(),
53 writer_.ConservativeWriteLimit());
54 // Account for reserved bytes.
55 max_size = max_size > reserved_size ? max_size - reserved_size : 0;
56 // Cap based on max varint size.
57 max_size = std::min(varint::MaxValueInBytes(config::kMaxVarintSize),
58 static_cast<uint64_t>(max_size));
59
60 ByteSpan nested_buffer;
61 if (max_size > 0) {
62 nested_buffer = ByteSpan(
63 memory_writer_.data() + reserved_size + memory_writer_.bytes_written(),
64 max_size);
65 } else {
66 nested_buffer = ByteSpan();
67 }
68 return StreamEncoder(*this, nested_buffer, write_when_empty);
69 }
70
CloseEncoder()71 void StreamEncoder::CloseEncoder() {
72 // If this was an invalidated StreamEncoder which cannot be used, permit the
73 // object to be cleanly destructed by doing nothing.
74 if (nested_field_number_ == kFirstReservedNumber) {
75 return;
76 }
77
78 PW_CHECK(
79 !nested_encoder_open(),
80 "Tried to destruct a proto encoder with an active submessage encoder");
81
82 if (parent_ != nullptr) {
83 parent_->CloseNestedMessage(*this);
84 }
85 }
86
CloseNestedMessage(StreamEncoder & nested)87 void StreamEncoder::CloseNestedMessage(StreamEncoder& nested) {
88 PW_DCHECK_PTR_EQ(nested.parent_,
89 this,
90 "CloseNestedMessage() called on the wrong Encoder parent");
91
92 // Make the nested encoder look like it has an open child to block writes for
93 // the remainder of the object's life.
94 nested.nested_field_number_ = kFirstReservedNumber;
95 nested.parent_ = nullptr;
96 // Temporarily cache the field number of the child so we can re-enable
97 // writing to this encoder.
98 uint32_t temp_field_number = nested_field_number_;
99 nested_field_number_ = 0;
100
101 // TODO(amontanez): If a submessage fails, we could optionally discard
102 // it and continue happily. For now, we'll always invalidate the entire
103 // encoder if a single submessage fails.
104 status_.Update(nested.status_);
105 if (!status_.ok()) {
106 return;
107 }
108
109 if (varint::EncodedSize(nested.memory_writer_.bytes_written()) >
110 config::kMaxVarintSize) {
111 status_ = Status::OutOfRange();
112 return;
113 }
114
115 if (!nested.memory_writer_.bytes_written() && !nested.write_when_empty_) {
116 return;
117 }
118
119 status_ = WriteLengthDelimitedField(temp_field_number,
120 nested.memory_writer_.WrittenData());
121 }
122
WriteVarintField(uint32_t field_number,uint64_t value)123 Status StreamEncoder::WriteVarintField(uint32_t field_number, uint64_t value) {
124 PW_TRY(UpdateStatusForWrite(
125 field_number, WireType::kVarint, varint::EncodedSize(value)));
126
127 WriteVarint(FieldKey(field_number, WireType::kVarint))
128 .IgnoreError(); // TODO(b/242598609): Handle Status properly
129 return WriteVarint(value);
130 }
131
WriteLengthDelimitedField(uint32_t field_number,ConstByteSpan data)132 Status StreamEncoder::WriteLengthDelimitedField(uint32_t field_number,
133 ConstByteSpan data) {
134 PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, data.size()));
135 status_.Update(WriteLengthDelimitedKeyAndLengthPrefix(
136 field_number, data.size(), writer_));
137 PW_TRY(status_);
138 if (Status status = writer_.Write(data); !status.ok()) {
139 status_ = status;
140 }
141 return status_;
142 }
143
WriteLengthDelimitedFieldFromStream(uint32_t field_number,stream::Reader & bytes_reader,size_t num_bytes,ByteSpan stream_pipe_buffer)144 Status StreamEncoder::WriteLengthDelimitedFieldFromStream(
145 uint32_t field_number,
146 stream::Reader& bytes_reader,
147 size_t num_bytes,
148 ByteSpan stream_pipe_buffer) {
149 PW_CHECK_UINT_GT(
150 stream_pipe_buffer.size(), 0, "Transfer buffer cannot be 0 size");
151 PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, num_bytes));
152 status_.Update(
153 WriteLengthDelimitedKeyAndLengthPrefix(field_number, num_bytes, writer_));
154 PW_TRY(status_);
155
156 // Stream data from `bytes_reader` to `writer_`.
157 // TODO(pwbug/468): move the following logic to pw_stream/copy.h at a later
158 // time.
159 for (size_t bytes_written = 0; bytes_written < num_bytes;) {
160 const size_t chunk_size_bytes =
161 std::min(num_bytes - bytes_written, stream_pipe_buffer.size_bytes());
162 const Result<ByteSpan> read_result =
163 bytes_reader.Read(stream_pipe_buffer.data(), chunk_size_bytes);
164 status_.Update(read_result.status());
165 PW_TRY(status_);
166
167 status_.Update(writer_.Write(read_result.value()));
168 PW_TRY(status_);
169
170 bytes_written += read_result.value().size();
171 }
172
173 return OkStatus();
174 }
175
WriteFixed(uint32_t field_number,ConstByteSpan data)176 Status StreamEncoder::WriteFixed(uint32_t field_number, ConstByteSpan data) {
177 WireType type =
178 data.size() == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
179
180 PW_TRY(UpdateStatusForWrite(field_number, type, data.size()));
181
182 WriteVarint(FieldKey(field_number, type))
183 .IgnoreError(); // TODO(b/242598609): Handle Status properly
184 if (Status status = writer_.Write(data); !status.ok()) {
185 status_ = status;
186 }
187 return status_;
188 }
189
WritePackedFixed(uint32_t field_number,span<const std::byte> values,size_t elem_size)190 Status StreamEncoder::WritePackedFixed(uint32_t field_number,
191 span<const std::byte> values,
192 size_t elem_size) {
193 if (values.empty()) {
194 return status_;
195 }
196
197 PW_CHECK_NOTNULL(values.data());
198 PW_DCHECK(elem_size == sizeof(uint32_t) || elem_size == sizeof(uint64_t));
199
200 PW_TRY(UpdateStatusForWrite(
201 field_number, WireType::kDelimited, values.size_bytes()));
202 WriteVarint(FieldKey(field_number, WireType::kDelimited))
203 .IgnoreError(); // TODO(b/242598609): Handle Status properly
204 WriteVarint(values.size_bytes())
205 .IgnoreError(); // TODO(b/242598609): Handle Status properly
206
207 for (auto val_start = values.begin(); val_start != values.end();
208 val_start += elem_size) {
209 // Allocates 8 bytes so both 4-byte and 8-byte types can be encoded as
210 // little-endian for serialization.
211 std::array<std::byte, sizeof(uint64_t)> data;
212 if (endian::native == endian::little) {
213 std::copy(val_start, val_start + elem_size, std::begin(data));
214 } else {
215 std::reverse_copy(val_start, val_start + elem_size, std::begin(data));
216 }
217 status_.Update(writer_.Write(span(data).first(elem_size)));
218 PW_TRY(status_);
219 }
220 return status_;
221 }
222
UpdateStatusForWrite(uint32_t field_number,WireType type,size_t data_size)223 Status StreamEncoder::UpdateStatusForWrite(uint32_t field_number,
224 WireType type,
225 size_t data_size) {
226 PW_CHECK(!nested_encoder_open());
227 PW_TRY(status_);
228
229 if (!ValidFieldNumber(field_number)) {
230 return status_ = Status::InvalidArgument();
231 }
232
233 const Result<size_t> field_size = SizeOfField(field_number, type, data_size);
234 status_.Update(field_size.status());
235 PW_TRY(status_);
236
237 if (field_size.value() > writer_.ConservativeWriteLimit()) {
238 status_ = Status::ResourceExhausted();
239 }
240
241 return status_;
242 }
243
Write(span<const std::byte> message,span<const internal::MessageField> table)244 Status StreamEncoder::Write(span<const std::byte> message,
245 span<const internal::MessageField> table) {
246 PW_CHECK(!nested_encoder_open());
247 PW_TRY(status_);
248
249 for (const auto& field : table) {
250 // Calculate the span of bytes corresponding to the structure field to
251 // read from.
252 const auto values =
253 message.subspan(field.field_offset(), field.field_size());
254 PW_CHECK(values.begin() >= message.begin() &&
255 values.end() <= message.end());
256
257 // If the field is using callbacks, interpret the input field accordingly
258 // and allow the caller to provide custom handling.
259 if (field.use_callback()) {
260 const Callback<StreamEncoder, StreamDecoder>* callback =
261 reinterpret_cast<const Callback<StreamEncoder, StreamDecoder>*>(
262 values.data());
263 PW_TRY(callback->Encode(*this));
264 continue;
265 }
266
267 switch (field.wire_type()) {
268 case WireType::kFixed64:
269 case WireType::kFixed32: {
270 // Fixed fields call WriteFixed() for singular case and
271 // WritePackedFixed() for repeated fields.
272 PW_CHECK(field.elem_size() == (field.wire_type() == WireType::kFixed32
273 ? sizeof(uint32_t)
274 : sizeof(uint64_t)),
275 "Mismatched message field type and size");
276 if (field.is_fixed_size()) {
277 PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
278 if (static_cast<size_t>(
279 std::count(values.begin(), values.end(), std::byte{0})) <
280 values.size()) {
281 PW_TRY(WritePackedFixed(
282 field.field_number(), values, field.elem_size()));
283 }
284 } else if (field.is_repeated()) {
285 // The struct member for this field is a vector of a type
286 // corresponding to the field element size. Cast to the correct
287 // vector type so we're not performing type aliasing (except for
288 // unsigned vs signed which is explicitly allowed).
289 if (field.elem_size() == sizeof(uint64_t)) {
290 const auto* vector =
291 reinterpret_cast<const pw::Vector<const uint64_t>*>(
292 values.data());
293 if (!vector->empty()) {
294 PW_TRY(WritePackedFixed(
295 field.field_number(),
296 as_bytes(span(vector->data(), vector->size())),
297 field.elem_size()));
298 }
299 } else if (field.elem_size() == sizeof(uint32_t)) {
300 const auto* vector =
301 reinterpret_cast<const pw::Vector<const uint32_t>*>(
302 values.data());
303 if (!vector->empty()) {
304 PW_TRY(WritePackedFixed(
305 field.field_number(),
306 as_bytes(span(vector->data(), vector->size())),
307 field.elem_size()));
308 }
309 }
310 } else if (field.is_optional()) {
311 // The struct member for this field is a std::optional of a type
312 // corresponding to the field element size. Cast to the correct
313 // optional type so we're not performing type aliasing (except for
314 // unsigned vs signed which is explicitly allowed), and write from
315 // a temporary.
316 if (field.elem_size() == sizeof(uint64_t)) {
317 const auto* optional =
318 reinterpret_cast<const std::optional<uint64_t>*>(values.data());
319 if (optional->has_value()) {
320 uint64_t value = optional->value();
321 PW_TRY(
322 WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
323 }
324 } else if (field.elem_size() == sizeof(uint32_t)) {
325 const auto* optional =
326 reinterpret_cast<const std::optional<uint32_t>*>(values.data());
327 if (optional->has_value()) {
328 uint32_t value = optional->value();
329 PW_TRY(
330 WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
331 }
332 }
333 } else {
334 PW_CHECK(values.size() == field.elem_size(),
335 "Mismatched message field type and size");
336 if (static_cast<size_t>(
337 std::count(values.begin(), values.end(), std::byte{0})) <
338 values.size()) {
339 PW_TRY(WriteFixed(field.field_number(), values));
340 }
341 }
342 break;
343 }
344 case WireType::kVarint: {
345 // Varint fields call WriteVarintField() for singular case and
346 // WritePackedVarints() for repeated fields.
347 PW_CHECK(field.elem_size() == sizeof(uint64_t) ||
348 field.elem_size() == sizeof(uint32_t) ||
349 field.elem_size() == sizeof(bool),
350 "Mismatched message field type and size");
351 if (field.is_fixed_size()) {
352 // The struct member for this field is an array of type corresponding
353 // to the field element size. Cast to a span of the correct type over
354 // the array so we're not performing type aliasing (except for
355 // unsigned vs signed which is explicitly allowed).
356 PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
357 if (static_cast<size_t>(
358 std::count(values.begin(), values.end(), std::byte{0})) ==
359 values.size()) {
360 continue;
361 }
362 if (field.elem_size() == sizeof(uint64_t)) {
363 PW_TRY(WritePackedVarints(
364 field.field_number(),
365 span(reinterpret_cast<const uint64_t*>(values.data()),
366 values.size() / field.elem_size()),
367 field.varint_type()));
368 } else if (field.elem_size() == sizeof(uint32_t)) {
369 PW_TRY(WritePackedVarints(
370 field.field_number(),
371 span(reinterpret_cast<const uint32_t*>(values.data()),
372 values.size() / field.elem_size()),
373 field.varint_type()));
374 } else if (field.elem_size() == sizeof(bool)) {
375 static_assert(sizeof(bool) == sizeof(uint8_t),
376 "bool must be same size as uint8_t");
377 PW_TRY(WritePackedVarints(
378 field.field_number(),
379 span(reinterpret_cast<const uint8_t*>(values.data()),
380 values.size() / field.elem_size()),
381 field.varint_type()));
382 }
383 } else if (field.is_repeated()) {
384 // The struct member for this field is a vector of a type
385 // corresponding to the field element size. Cast to the correct
386 // vector type so we're not performing type aliasing (except for
387 // unsigned vs signed which is explicitly allowed).
388 if (field.elem_size() == sizeof(uint64_t)) {
389 const auto* vector =
390 reinterpret_cast<const pw::Vector<const uint64_t>*>(
391 values.data());
392 if (!vector->empty()) {
393 PW_TRY(WritePackedVarints(field.field_number(),
394 span(vector->data(), vector->size()),
395 field.varint_type()));
396 }
397 } else if (field.elem_size() == sizeof(uint32_t)) {
398 const auto* vector =
399 reinterpret_cast<const pw::Vector<const uint32_t>*>(
400 values.data());
401 if (!vector->empty()) {
402 PW_TRY(WritePackedVarints(field.field_number(),
403 span(vector->data(), vector->size()),
404 field.varint_type()));
405 }
406 } else if (field.elem_size() == sizeof(bool)) {
407 static_assert(sizeof(bool) == sizeof(uint8_t),
408 "bool must be same size as uint8_t");
409 const auto* vector =
410 reinterpret_cast<const pw::Vector<const uint8_t>*>(
411 values.data());
412 if (!vector->empty()) {
413 PW_TRY(WritePackedVarints(field.field_number(),
414 span(vector->data(), vector->size()),
415 field.varint_type()));
416 }
417 }
418 } else if (field.is_optional()) {
419 // The struct member for this field is a std::optional of a type
420 // corresponding to the field element size. Cast to the correct
421 // optional type so we're not performing type aliasing (except for
422 // unsigned vs signed which is explicitly allowed), and write from
423 // a temporary.
424 uint64_t value = 0;
425 if (field.elem_size() == sizeof(uint64_t)) {
426 if (field.varint_type() == VarintType::kUnsigned) {
427 const auto* optional =
428 reinterpret_cast<const std::optional<uint64_t>*>(
429 values.data());
430 if (!optional->has_value()) {
431 continue;
432 }
433 value = optional->value();
434 } else {
435 const auto* optional =
436 reinterpret_cast<const std::optional<int64_t>*>(
437 values.data());
438 if (!optional->has_value()) {
439 continue;
440 }
441 value = field.varint_type() == VarintType::kZigZag
442 ? varint::ZigZagEncode(optional->value())
443 : optional->value();
444 }
445 } else if (field.elem_size() == sizeof(uint32_t)) {
446 if (field.varint_type() == VarintType::kUnsigned) {
447 const auto* optional =
448 reinterpret_cast<const std::optional<uint32_t>*>(
449 values.data());
450 if (!optional->has_value()) {
451 continue;
452 }
453 value = optional->value();
454 } else {
455 const auto* optional =
456 reinterpret_cast<const std::optional<int32_t>*>(
457 values.data());
458 if (!optional->has_value()) {
459 continue;
460 }
461 value = field.varint_type() == VarintType::kZigZag
462 ? varint::ZigZagEncode(optional->value())
463 : optional->value();
464 }
465 } else if (field.elem_size() == sizeof(bool)) {
466 const auto* optional =
467 reinterpret_cast<const std::optional<bool>*>(values.data());
468 if (!optional->has_value()) {
469 continue;
470 }
471 value = optional->value();
472 }
473 PW_TRY(WriteVarintField(field.field_number(), value));
474 } else {
475 // The struct member for this field is a scalar of a type
476 // corresponding to the field element size. Cast to the correct
477 // type to retrieve the value before passing to WriteVarintField()
478 // so we're not performing type aliasing (except for unsigned vs
479 // signed which is explicitly allowed).
480 PW_CHECK(values.size() == field.elem_size(),
481 "Mismatched message field type and size");
482 uint64_t value = 0;
483 if (field.elem_size() == sizeof(uint64_t)) {
484 if (field.varint_type() == VarintType::kZigZag) {
485 value = varint::ZigZagEncode(
486 *reinterpret_cast<const int64_t*>(values.data()));
487 } else if (field.varint_type() == VarintType::kNormal) {
488 value = *reinterpret_cast<const int64_t*>(values.data());
489 } else {
490 value = *reinterpret_cast<const uint64_t*>(values.data());
491 }
492 if (!value) {
493 continue;
494 }
495 } else if (field.elem_size() == sizeof(uint32_t)) {
496 if (field.varint_type() == VarintType::kZigZag) {
497 value = varint::ZigZagEncode(
498 *reinterpret_cast<const int32_t*>(values.data()));
499 } else if (field.varint_type() == VarintType::kNormal) {
500 value = *reinterpret_cast<const int32_t*>(values.data());
501 } else {
502 value = *reinterpret_cast<const uint32_t*>(values.data());
503 }
504 if (!value) {
505 continue;
506 }
507 } else if (field.elem_size() == sizeof(bool)) {
508 value = *reinterpret_cast<const bool*>(values.data());
509 if (!value) {
510 continue;
511 }
512 }
513 PW_TRY(WriteVarintField(field.field_number(), value));
514 }
515 break;
516 }
517 case WireType::kDelimited: {
518 // Delimited fields are always a singular case because of the
519 // inability to cast to a generic vector with an element of a certain
520 // size (we always need a type).
521 PW_CHECK(!field.is_repeated(),
522 "Repeated delimited messages always require a callback");
523 if (field.nested_message_fields()) {
524 // Nested Message. Struct member is an embedded struct for the
525 // nested field. Obtain a nested encoder and recursively call Write()
526 // using the fields table pointer from this field.
527 auto nested_encoder = GetNestedEncoder(field.field_number(),
528 /*write_when_empty=*/false);
529 PW_TRY(nested_encoder.Write(values, *field.nested_message_fields()));
530 } else if (field.is_fixed_size()) {
531 // Fixed-length bytes field. Struct member is a std::array<std::byte>.
532 // Call WriteLengthDelimitedField() to output it to the stream.
533 PW_CHECK(field.elem_size() == sizeof(std::byte),
534 "Mismatched message field type and size");
535 if (static_cast<size_t>(
536 std::count(values.begin(), values.end(), std::byte{0})) <
537 values.size()) {
538 PW_TRY(WriteLengthDelimitedField(field.field_number(), values));
539 }
540 } else {
541 // bytes or string field with a maximum size. Struct member is
542 // pw::Vector<std::byte> for bytes or pw::InlineString<> for string.
543 // Use the contents as a span and call WriteLengthDelimitedField() to
544 // output it to the stream.
545 PW_CHECK(field.elem_size() == sizeof(std::byte),
546 "Mismatched message field type and size");
547 if (field.is_string()) {
548 PW_TRY(WriteStringOrBytes<const InlineString<>>(
549 field.field_number(), values.data()));
550 } else {
551 PW_TRY(WriteStringOrBytes<const Vector<const std::byte>>(
552 field.field_number(), values.data()));
553 }
554 }
555 break;
556 }
557 }
558 }
559
560 return status_;
561 }
562
563 } // namespace pw::protobuf
564