1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_protobuf/encoder.h"
16
17 #include <algorithm>
18 #include <cstddef>
19 #include <cstring>
20 #include <optional>
21
22 #include "pw_assert/check.h"
23 #include "pw_bytes/span.h"
24 #include "pw_protobuf/internal/codegen.h"
25 #include "pw_protobuf/serialized_size.h"
26 #include "pw_protobuf/stream_decoder.h"
27 #include "pw_protobuf/wire_format.h"
28 #include "pw_span/span.h"
29 #include "pw_status/status.h"
30 #include "pw_status/try.h"
31 #include "pw_stream/memory_stream.h"
32 #include "pw_stream/stream.h"
33 #include "pw_string/string.h"
34 #include "pw_varint/varint.h"
35
36 namespace pw::protobuf {
37
38 using internal::VarintType;
39
GetNestedEncoder(uint32_t field_number,bool write_when_empty)40 StreamEncoder StreamEncoder::GetNestedEncoder(uint32_t field_number,
41 bool write_when_empty) {
42 PW_CHECK(!nested_encoder_open());
43
44 nested_field_number_ = field_number;
45 if (!ValidFieldNumber(field_number)) {
46 status_.Update(Status::InvalidArgument());
47 return StreamEncoder(*this, ByteSpan(), false);
48 }
49
50 // Pass the unused space of the scratch buffer to the nested encoder to use
51 // as their scratch buffer.
52 size_t key_size =
53 varint::EncodedSize(FieldKey(field_number, WireType::kDelimited));
54 size_t reserved_size = key_size + config::kMaxVarintSize;
55 size_t max_size = std::min(memory_writer_.ConservativeWriteLimit(),
56 writer_.ConservativeWriteLimit());
57 // Cap based on max varint size.
58 max_size = std::min(varint::MaxValueInBytes(config::kMaxVarintSize),
59 static_cast<uint64_t>(max_size));
60
61 // Account for reserved bytes.
62 max_size = max_size > reserved_size ? max_size - reserved_size : 0;
63
64 ByteSpan nested_buffer;
65 if (max_size > 0) {
66 nested_buffer = ByteSpan(
67 memory_writer_.data() + reserved_size + memory_writer_.bytes_written(),
68 max_size);
69 } else {
70 nested_buffer = ByteSpan();
71 }
72 return StreamEncoder(*this, nested_buffer, write_when_empty);
73 }
74
CloseEncoder()75 void StreamEncoder::CloseEncoder() {
76 // If this was an invalidated StreamEncoder which cannot be used, permit the
77 // object to be cleanly destructed by doing nothing.
78 if (nested_field_number_ == kFirstReservedNumber) {
79 return;
80 }
81
82 PW_CHECK(
83 !nested_encoder_open(),
84 "Tried to destruct a proto encoder with an active submessage encoder");
85
86 if (parent_ != nullptr) {
87 parent_->CloseNestedMessage(*this);
88 }
89 }
90
CloseNestedMessage(StreamEncoder & nested)91 void StreamEncoder::CloseNestedMessage(StreamEncoder& nested) {
92 PW_DCHECK_PTR_EQ(nested.parent_,
93 this,
94 "CloseNestedMessage() called on the wrong Encoder parent");
95
96 // Make the nested encoder look like it has an open child to block writes for
97 // the remainder of the object's life.
98 nested.nested_field_number_ = kFirstReservedNumber;
99 nested.parent_ = nullptr;
100 // Temporarily cache the field number of the child so we can re-enable
101 // writing to this encoder.
102 uint32_t temp_field_number = nested_field_number_;
103 nested_field_number_ = 0;
104
105 // TODO(amontanez): If a submessage fails, we could optionally discard
106 // it and continue happily. For now, we'll always invalidate the entire
107 // encoder if a single submessage fails.
108 status_.Update(nested.status_);
109 if (!status_.ok()) {
110 return;
111 }
112
113 if (varint::EncodedSize(nested.memory_writer_.bytes_written()) >
114 config::kMaxVarintSize) {
115 status_ = Status::OutOfRange();
116 return;
117 }
118
119 if (!nested.memory_writer_.bytes_written() && !nested.write_when_empty_) {
120 return;
121 }
122
123 status_ = WriteLengthDelimitedField(temp_field_number,
124 nested.memory_writer_.WrittenData());
125 }
126
WriteVarintField(uint32_t field_number,uint64_t value)127 Status StreamEncoder::WriteVarintField(uint32_t field_number, uint64_t value) {
128 PW_TRY(UpdateStatusForWrite(
129 field_number, WireType::kVarint, varint::EncodedSize(value)));
130
131 WriteVarint(FieldKey(field_number, WireType::kVarint))
132 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
133 return WriteVarint(value);
134 }
135
WriteLengthDelimitedField(uint32_t field_number,ConstByteSpan data)136 Status StreamEncoder::WriteLengthDelimitedField(uint32_t field_number,
137 ConstByteSpan data) {
138 PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, data.size()));
139 status_.Update(WriteLengthDelimitedKeyAndLengthPrefix(
140 field_number, data.size(), writer_));
141 PW_TRY(status_);
142 if (Status status = writer_.Write(data); !status.ok()) {
143 status_ = status;
144 }
145 return status_;
146 }
147
WriteLengthDelimitedFieldFromStream(uint32_t field_number,stream::Reader & bytes_reader,size_t num_bytes,ByteSpan stream_pipe_buffer)148 Status StreamEncoder::WriteLengthDelimitedFieldFromStream(
149 uint32_t field_number,
150 stream::Reader& bytes_reader,
151 size_t num_bytes,
152 ByteSpan stream_pipe_buffer) {
153 PW_CHECK_UINT_GT(
154 stream_pipe_buffer.size(), 0, "Transfer buffer cannot be 0 size");
155 PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, num_bytes));
156 status_.Update(
157 WriteLengthDelimitedKeyAndLengthPrefix(field_number, num_bytes, writer_));
158 PW_TRY(status_);
159
160 // Stream data from `bytes_reader` to `writer_`.
161 // TODO(pwbug/468): move the following logic to pw_stream/copy.h at a later
162 // time.
163 for (size_t bytes_written = 0; bytes_written < num_bytes;) {
164 const size_t chunk_size_bytes =
165 std::min(num_bytes - bytes_written, stream_pipe_buffer.size_bytes());
166 const Result<ByteSpan> read_result =
167 bytes_reader.Read(stream_pipe_buffer.data(), chunk_size_bytes);
168 status_.Update(read_result.status());
169 PW_TRY(status_);
170
171 status_.Update(writer_.Write(read_result.value()));
172 PW_TRY(status_);
173
174 bytes_written += read_result.value().size();
175 }
176
177 return OkStatus();
178 }
179
WriteFixed(uint32_t field_number,ConstByteSpan data)180 Status StreamEncoder::WriteFixed(uint32_t field_number, ConstByteSpan data) {
181 WireType type =
182 data.size() == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
183
184 PW_TRY(UpdateStatusForWrite(field_number, type, data.size()));
185
186 WriteVarint(FieldKey(field_number, type))
187 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
188 if (Status status = writer_.Write(data); !status.ok()) {
189 status_ = status;
190 }
191 return status_;
192 }
193
WritePackedFixed(uint32_t field_number,span<const std::byte> values,size_t elem_size)194 Status StreamEncoder::WritePackedFixed(uint32_t field_number,
195 span<const std::byte> values,
196 size_t elem_size) {
197 if (values.empty()) {
198 return status_;
199 }
200
201 PW_CHECK_NOTNULL(values.data());
202 PW_DCHECK(elem_size == sizeof(uint32_t) || elem_size == sizeof(uint64_t));
203
204 PW_TRY(UpdateStatusForWrite(
205 field_number, WireType::kDelimited, values.size_bytes()));
206 WriteVarint(FieldKey(field_number, WireType::kDelimited))
207 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
208 WriteVarint(values.size_bytes())
209 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
210
211 for (auto val_start = values.begin(); val_start != values.end();
212 val_start += elem_size) {
213 // Allocates 8 bytes so both 4-byte and 8-byte types can be encoded as
214 // little-endian for serialization.
215 std::array<std::byte, sizeof(uint64_t)> data;
216 if (endian::native == endian::little) {
217 std::copy(val_start, val_start + elem_size, std::begin(data));
218 } else {
219 std::reverse_copy(val_start, val_start + elem_size, std::begin(data));
220 }
221 status_.Update(writer_.Write(span(data).first(elem_size)));
222 PW_TRY(status_);
223 }
224 return status_;
225 }
226
UpdateStatusForWrite(uint32_t field_number,WireType type,size_t data_size)227 Status StreamEncoder::UpdateStatusForWrite(uint32_t field_number,
228 WireType type,
229 size_t data_size) {
230 PW_CHECK(!nested_encoder_open());
231 PW_TRY(status_);
232
233 if (!ValidFieldNumber(field_number)) {
234 return status_ = Status::InvalidArgument();
235 }
236
237 const Result<size_t> field_size = SizeOfField(field_number, type, data_size);
238 status_.Update(field_size.status());
239 PW_TRY(status_);
240
241 if (field_size.value() > writer_.ConservativeWriteLimit()) {
242 status_ = Status::ResourceExhausted();
243 }
244
245 return status_;
246 }
247
// Encodes every field described by `table` from the in-memory struct bytes
// `message` (presumably the raw bytes of a codegen'd message struct; confirm
// against callers).
//
// For each MessageField, the backing struct member is located with
// field_offset()/field_size() and encoded according to its wire type and its
// container kind: fixed-size array, pw::Vector, std::optional, or plain scalar.
// Single-field and oneof callbacks are delegated to the user-supplied callback
// object. All-zero scalars, fixed-size arrays and fixed-size bytes members are
// skipped, as are empty vectors and disengaged optionals.
//
// The first error is returned immediately (via PW_TRY). On success,
// ResetOneOfCallbacks() clears the oneof callbacks' `invoked_` flags and the
// encoder status is returned.
Status StreamEncoder::Write(span<const std::byte> message,
                            span<const internal::MessageField> table) {
  PW_CHECK(!nested_encoder_open());
  PW_TRY(status_);

  for (const auto& field : table) {
    // Calculate the span of bytes corresponding to the structure field to
    // read from.
    ConstByteSpan values =
        message.subspan(field.field_offset(), field.field_size());
    PW_CHECK(values.begin() >= message.begin() &&
             values.end() <= message.end());

    // If the field is using callbacks, interpret the input field accordingly
    // and allow the caller to provide custom handling.
    if (field.callback_type() == internal::CallbackType::kSingleField) {
      const Callback<StreamEncoder, StreamDecoder>* callback =
          reinterpret_cast<const Callback<StreamEncoder, StreamDecoder>*>(
              values.data());
      PW_TRY(callback->Encode(*this));
      continue;
    } else if (field.callback_type() == internal::CallbackType::kOneOfGroup) {
      const OneOf<StreamEncoder, StreamDecoder>* callback =
          reinterpret_cast<const OneOf<StreamEncoder, StreamDecoder>*>(
              values.data());
      PW_TRY(callback->Encode(*this));
      continue;
    }

    // Note: the `continue` statements inside this switch skip to the next
    // field of the outer for loop. Wire types without a case are ignored.
    switch (field.wire_type()) {
      case WireType::kFixed64:
      case WireType::kFixed32: {
        // Fixed fields call WriteFixed() for singular case and
        // WritePackedFixed() for repeated fields.
        PW_CHECK(field.elem_size() == (field.wire_type() == WireType::kFixed32
                                           ? sizeof(uint32_t)
                                           : sizeof(uint64_t)),
                 "Mismatched message field type and size");
        if (field.is_fixed_size()) {
          PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
          // Skip the array entirely when every byte is zero.
          if (static_cast<size_t>(
                  std::count(values.begin(), values.end(), std::byte{0})) <
              values.size()) {
            PW_TRY(WritePackedFixed(
                field.field_number(), values, field.elem_size()));
          }
        } else if (field.is_repeated()) {
          // The struct member for this field is a vector of a type
          // corresponding to the field element size. Cast to the correct
          // vector type so we're not performing type aliasing (except for
          // unsigned vs signed which is explicitly allowed).
          if (field.elem_size() == sizeof(uint64_t)) {
            const auto* vector =
                reinterpret_cast<const pw::Vector<const uint64_t>*>(
                    values.data());
            if (!vector->empty()) {
              PW_TRY(WritePackedFixed(
                  field.field_number(),
                  as_bytes(span(vector->data(), vector->size())),
                  field.elem_size()));
            }
          } else if (field.elem_size() == sizeof(uint32_t)) {
            const auto* vector =
                reinterpret_cast<const pw::Vector<const uint32_t>*>(
                    values.data());
            if (!vector->empty()) {
              PW_TRY(WritePackedFixed(
                  field.field_number(),
                  as_bytes(span(vector->data(), vector->size())),
                  field.elem_size()));
            }
          }
        } else if (field.is_optional()) {
          // The struct member for this field is a std::optional of a type
          // corresponding to the field element size. Cast to the correct
          // optional type so we're not performing type aliasing (except for
          // unsigned vs signed which is explicitly allowed), and write from
          // a temporary.
          // NOTE(review): the temporary is handed to WriteFixed() in host
          // byte order, whereas WritePackedFixed() converts to little-endian;
          // confirm this path is only used on little-endian targets.
          if (field.elem_size() == sizeof(uint64_t)) {
            const auto* optional =
                reinterpret_cast<const std::optional<uint64_t>*>(values.data());
            if (optional->has_value()) {
              uint64_t value = optional->value();
              PW_TRY(
                  WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
            }
          } else if (field.elem_size() == sizeof(uint32_t)) {
            const auto* optional =
                reinterpret_cast<const std::optional<uint32_t>*>(values.data());
            if (optional->has_value()) {
              uint32_t value = optional->value();
              PW_TRY(
                  WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
            }
          }
        } else {
          // Singular scalar: written only if some byte is non-zero.
          // NOTE(review): the struct bytes go to WriteFixed() as-is (host
          // byte order); see the endianness note on the optional case above.
          PW_CHECK(values.size() == field.elem_size(),
                   "Mismatched message field type and size");
          if (static_cast<size_t>(
                  std::count(values.begin(), values.end(), std::byte{0})) <
              values.size()) {
            PW_TRY(WriteFixed(field.field_number(), values));
          }
        }
        break;
      }
      case WireType::kVarint: {
        // Varint fields call WriteVarintField() for singular case and
        // WritePackedVarints() for repeated fields.
        PW_CHECK(field.elem_size() == sizeof(uint64_t) ||
                     field.elem_size() == sizeof(uint32_t) ||
                     field.elem_size() == sizeof(bool),
                 "Mismatched message field type and size");
        if (field.is_fixed_size()) {
          // The struct member for this field is an array of type corresponding
          // to the field element size. Cast to a span of the correct type over
          // the array so we're not performing type aliasing (except for
          // unsigned vs signed which is explicitly allowed).
          PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
          // An all-zero array is not written at all.
          if (static_cast<size_t>(
                  std::count(values.begin(), values.end(), std::byte{0})) ==
              values.size()) {
            continue;
          }
          if (field.elem_size() == sizeof(uint64_t)) {
            PW_TRY(WritePackedVarints(
                field.field_number(),
                span(reinterpret_cast<const uint64_t*>(values.data()),
                     values.size() / field.elem_size()),
                field.varint_type()));
          } else if (field.elem_size() == sizeof(uint32_t)) {
            PW_TRY(WritePackedVarints(
                field.field_number(),
                span(reinterpret_cast<const uint32_t*>(values.data()),
                     values.size() / field.elem_size()),
                field.varint_type()));
          } else if (field.elem_size() == sizeof(bool)) {
            static_assert(sizeof(bool) == sizeof(uint8_t),
                          "bool must be same size as uint8_t");
            PW_TRY(WritePackedVarints(
                field.field_number(),
                span(reinterpret_cast<const uint8_t*>(values.data()),
                     values.size() / field.elem_size()),
                field.varint_type()));
          }
        } else if (field.is_repeated()) {
          // The struct member for this field is a vector of a type
          // corresponding to the field element size. Cast to the correct
          // vector type so we're not performing type aliasing (except for
          // unsigned vs signed which is explicitly allowed).
          if (field.elem_size() == sizeof(uint64_t)) {
            const auto* vector =
                reinterpret_cast<const pw::Vector<const uint64_t>*>(
                    values.data());
            if (!vector->empty()) {
              PW_TRY(WritePackedVarints(field.field_number(),
                                        span(vector->data(), vector->size()),
                                        field.varint_type()));
            }
          } else if (field.elem_size() == sizeof(uint32_t)) {
            const auto* vector =
                reinterpret_cast<const pw::Vector<const uint32_t>*>(
                    values.data());
            if (!vector->empty()) {
              PW_TRY(WritePackedVarints(field.field_number(),
                                        span(vector->data(), vector->size()),
                                        field.varint_type()));
            }
          } else if (field.elem_size() == sizeof(bool)) {
            static_assert(sizeof(bool) == sizeof(uint8_t),
                          "bool must be same size as uint8_t");
            const auto* vector =
                reinterpret_cast<const pw::Vector<const uint8_t>*>(
                    values.data());
            if (!vector->empty()) {
              PW_TRY(WritePackedVarints(field.field_number(),
                                        span(vector->data(), vector->size()),
                                        field.varint_type()));
            }
          }
        } else if (field.is_optional()) {
          // The struct member for this field is a std::optional of a type
          // corresponding to the field element size. Cast to the correct
          // optional type so we're not performing type aliasing (except for
          // unsigned vs signed which is explicitly allowed), and write from
          // a temporary.
          // Disengaged optionals `continue` to the next field; engaged ones
          // are written even when the contained value is zero.
          uint64_t value = 0;
          if (field.elem_size() == sizeof(uint64_t)) {
            if (field.varint_type() == VarintType::kUnsigned) {
              const auto* optional =
                  reinterpret_cast<const std::optional<uint64_t>*>(
                      values.data());
              if (!optional->has_value()) {
                continue;
              }
              value = optional->value();
            } else {
              const auto* optional =
                  reinterpret_cast<const std::optional<int64_t>*>(
                      values.data());
              if (!optional->has_value()) {
                continue;
              }
              value = field.varint_type() == VarintType::kZigZag
                          ? varint::ZigZagEncode(optional->value())
                          : optional->value();
            }
          } else if (field.elem_size() == sizeof(uint32_t)) {
            if (field.varint_type() == VarintType::kUnsigned) {
              const auto* optional =
                  reinterpret_cast<const std::optional<uint32_t>*>(
                      values.data());
              if (!optional->has_value()) {
                continue;
              }
              value = optional->value();
            } else {
              const auto* optional =
                  reinterpret_cast<const std::optional<int32_t>*>(
                      values.data());
              if (!optional->has_value()) {
                continue;
              }
              value = field.varint_type() == VarintType::kZigZag
                          ? varint::ZigZagEncode(optional->value())
                          : optional->value();
            }
          } else if (field.elem_size() == sizeof(bool)) {
            const auto* optional =
                reinterpret_cast<const std::optional<bool>*>(values.data());
            if (!optional->has_value()) {
              continue;
            }
            value = optional->value();
          }
          PW_TRY(WriteVarintField(field.field_number(), value));
        } else {
          // The struct member for this field is a scalar of a type
          // corresponding to the field element size. Cast to the correct
          // type to retrieve the value before passing to WriteVarintField()
          // so we're not performing type aliasing (except for unsigned vs
          // signed which is explicitly allowed).
          // A scalar whose encoded value is zero is skipped (`continue`).
          // kNormal signed values are sign-extended into the uint64_t.
          PW_CHECK(values.size() == field.elem_size(),
                   "Mismatched message field type and size");
          uint64_t value = 0;
          if (field.elem_size() == sizeof(uint64_t)) {
            if (field.varint_type() == VarintType::kZigZag) {
              value = varint::ZigZagEncode(
                  *reinterpret_cast<const int64_t*>(values.data()));
            } else if (field.varint_type() == VarintType::kNormal) {
              value = *reinterpret_cast<const int64_t*>(values.data());
            } else {
              value = *reinterpret_cast<const uint64_t*>(values.data());
            }
            if (!value) {
              continue;
            }
          } else if (field.elem_size() == sizeof(uint32_t)) {
            if (field.varint_type() == VarintType::kZigZag) {
              value = varint::ZigZagEncode(
                  *reinterpret_cast<const int32_t*>(values.data()));
            } else if (field.varint_type() == VarintType::kNormal) {
              value = *reinterpret_cast<const int32_t*>(values.data());
            } else {
              value = *reinterpret_cast<const uint32_t*>(values.data());
            }
            if (!value) {
              continue;
            }
          } else if (field.elem_size() == sizeof(bool)) {
            value = *reinterpret_cast<const bool*>(values.data());
            if (!value) {
              continue;
            }
          }
          PW_TRY(WriteVarintField(field.field_number(), value));
        }
        break;
      }
      case WireType::kDelimited: {
        // Delimited fields are always a singular case because of the
        // inability to cast to a generic vector with an element of a certain
        // size (we always need a type).
        PW_CHECK(!field.is_repeated(),
                 "Repeated delimited messages always require a callback");
        if (field.nested_message_fields()) {
          // Nested Message. Struct member is an embedded struct for the
          // nested field. Obtain a nested encoder and recursively call Write()
          // using the fields table pointer from this field.
          // The child is closed when `nested_encoder` leaves this scope
          // (presumably its destructor calls CloseEncoder(), which hands the
          // bytes back via CloseNestedMessage()). An empty child is not
          // emitted because write_when_empty is false.
          auto nested_encoder = GetNestedEncoder(field.field_number(),
                                                 /*write_when_empty=*/false);
          PW_TRY(nested_encoder.Write(values, *field.nested_message_fields()));
        } else if (field.is_fixed_size()) {
          // Fixed-length bytes field. Struct member is a std::array<std::byte>.
          // Call WriteLengthDelimitedField() to output it to the stream.
          // An all-zero array is not written.
          PW_CHECK(field.elem_size() == sizeof(std::byte),
                   "Mismatched message field type and size");
          if (static_cast<size_t>(
                  std::count(values.begin(), values.end(), std::byte{0})) <
              values.size()) {
            PW_TRY(WriteLengthDelimitedField(field.field_number(), values));
          }
        } else {
          // bytes or string field with a maximum size. Struct member is
          // pw::Vector<std::byte> for bytes or pw::InlineString<> for string.
          // Use the contents as a span and call WriteLengthDelimitedField() to
          // output it to the stream.
          PW_CHECK(field.elem_size() == sizeof(std::byte),
                   "Mismatched message field type and size");
          if (field.is_string()) {
            PW_TRY(WriteStringOrBytes<const InlineString<>>(
                field.field_number(), values.data()));
          } else {
            PW_TRY(WriteStringOrBytes<const Vector<const std::byte>>(
                field.field_number(), values.data()));
          }
        }
        break;
      }
    }
  }

  // Only reached when every field was written successfully.
  ResetOneOfCallbacks(message, table);

  return status_;
}
574
ResetOneOfCallbacks(ConstByteSpan message,span<const internal::MessageField> table)575 void StreamEncoder::ResetOneOfCallbacks(
576 ConstByteSpan message, span<const internal::MessageField> table) {
577 for (const auto& field : table) {
578 // Calculate the span of bytes corresponding to the structure field to
579 // read from.
580 ConstByteSpan values =
581 message.subspan(field.field_offset(), field.field_size());
582 PW_CHECK(values.begin() >= message.begin() &&
583 values.end() <= message.end());
584
585 if (field.callback_type() == internal::CallbackType::kOneOfGroup) {
586 const OneOf<StreamEncoder, StreamDecoder>* callback =
587 reinterpret_cast<const OneOf<StreamEncoder, StreamDecoder>*>(
588 values.data());
589 callback->invoked_ = false;
590 }
591 }
592 }
593
594 } // namespace pw::protobuf
595