1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
#include "pw_protobuf/encoder.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstring>
#include <optional>

#include "pw_assert/check.h"
#include "pw_bytes/span.h"
#include "pw_protobuf/internal/codegen.h"
#include "pw_protobuf/serialized_size.h"
#include "pw_protobuf/stream_decoder.h"
#include "pw_protobuf/wire_format.h"
#include "pw_span/span.h"
#include "pw_status/status.h"
#include "pw_status/try.h"
#include "pw_stream/memory_stream.h"
#include "pw_stream/stream.h"
#include "pw_string/string.h"
#include "pw_varint/varint.h"
35
36 namespace pw::protobuf {
37
38 using internal::VarintType;
39
GetNestedEncoder(uint32_t field_number,bool write_when_empty)40 StreamEncoder StreamEncoder::GetNestedEncoder(uint32_t field_number,
41 bool write_when_empty) {
42 PW_CHECK(!nested_encoder_open());
43
44 nested_field_number_ = field_number;
45 if (!ValidFieldNumber(field_number)) {
46 status_.Update(Status::InvalidArgument());
47 return StreamEncoder(*this, ByteSpan(), false);
48 }
49
50 // Pass the unused space of the scratch buffer to the nested encoder to use
51 // as their scratch buffer.
52 size_t key_size =
53 varint::EncodedSize(FieldKey(field_number, WireType::kDelimited));
54 size_t reserved_size = key_size + config::kMaxVarintSize;
55 size_t max_size = std::min(memory_writer_.ConservativeWriteLimit(),
56 writer_.ConservativeWriteLimit());
57 // Account for reserved bytes.
58 max_size = max_size > reserved_size ? max_size - reserved_size : 0;
59 // Cap based on max varint size.
60 max_size = std::min(varint::MaxValueInBytes(config::kMaxVarintSize),
61 static_cast<uint64_t>(max_size));
62
63 ByteSpan nested_buffer;
64 if (max_size > 0) {
65 nested_buffer = ByteSpan(
66 memory_writer_.data() + reserved_size + memory_writer_.bytes_written(),
67 max_size);
68 } else {
69 nested_buffer = ByteSpan();
70 }
71 return StreamEncoder(*this, nested_buffer, write_when_empty);
72 }
73
CloseEncoder()74 void StreamEncoder::CloseEncoder() {
75 // If this was an invalidated StreamEncoder which cannot be used, permit the
76 // object to be cleanly destructed by doing nothing.
77 if (nested_field_number_ == kFirstReservedNumber) {
78 return;
79 }
80
81 PW_CHECK(
82 !nested_encoder_open(),
83 "Tried to destruct a proto encoder with an active submessage encoder");
84
85 if (parent_ != nullptr) {
86 parent_->CloseNestedMessage(*this);
87 }
88 }
89
CloseNestedMessage(StreamEncoder & nested)90 void StreamEncoder::CloseNestedMessage(StreamEncoder& nested) {
91 PW_DCHECK_PTR_EQ(nested.parent_,
92 this,
93 "CloseNestedMessage() called on the wrong Encoder parent");
94
95 // Make the nested encoder look like it has an open child to block writes for
96 // the remainder of the object's life.
97 nested.nested_field_number_ = kFirstReservedNumber;
98 nested.parent_ = nullptr;
99 // Temporarily cache the field number of the child so we can re-enable
100 // writing to this encoder.
101 uint32_t temp_field_number = nested_field_number_;
102 nested_field_number_ = 0;
103
104 // TODO(amontanez): If a submessage fails, we could optionally discard
105 // it and continue happily. For now, we'll always invalidate the entire
106 // encoder if a single submessage fails.
107 status_.Update(nested.status_);
108 if (!status_.ok()) {
109 return;
110 }
111
112 if (varint::EncodedSize(nested.memory_writer_.bytes_written()) >
113 config::kMaxVarintSize) {
114 status_ = Status::OutOfRange();
115 return;
116 }
117
118 if (!nested.memory_writer_.bytes_written() && !nested.write_when_empty_) {
119 return;
120 }
121
122 status_ = WriteLengthDelimitedField(temp_field_number,
123 nested.memory_writer_.WrittenData());
124 }
125
WriteVarintField(uint32_t field_number,uint64_t value)126 Status StreamEncoder::WriteVarintField(uint32_t field_number, uint64_t value) {
127 PW_TRY(UpdateStatusForWrite(
128 field_number, WireType::kVarint, varint::EncodedSize(value)));
129
130 WriteVarint(FieldKey(field_number, WireType::kVarint))
131 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
132 return WriteVarint(value);
133 }
134
WriteLengthDelimitedField(uint32_t field_number,ConstByteSpan data)135 Status StreamEncoder::WriteLengthDelimitedField(uint32_t field_number,
136 ConstByteSpan data) {
137 PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, data.size()));
138 status_.Update(WriteLengthDelimitedKeyAndLengthPrefix(
139 field_number, data.size(), writer_));
140 PW_TRY(status_);
141 if (Status status = writer_.Write(data); !status.ok()) {
142 status_ = status;
143 }
144 return status_;
145 }
146
WriteLengthDelimitedFieldFromStream(uint32_t field_number,stream::Reader & bytes_reader,size_t num_bytes,ByteSpan stream_pipe_buffer)147 Status StreamEncoder::WriteLengthDelimitedFieldFromStream(
148 uint32_t field_number,
149 stream::Reader& bytes_reader,
150 size_t num_bytes,
151 ByteSpan stream_pipe_buffer) {
152 PW_CHECK_UINT_GT(
153 stream_pipe_buffer.size(), 0, "Transfer buffer cannot be 0 size");
154 PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, num_bytes));
155 status_.Update(
156 WriteLengthDelimitedKeyAndLengthPrefix(field_number, num_bytes, writer_));
157 PW_TRY(status_);
158
159 // Stream data from `bytes_reader` to `writer_`.
160 // TODO(pwbug/468): move the following logic to pw_stream/copy.h at a later
161 // time.
162 for (size_t bytes_written = 0; bytes_written < num_bytes;) {
163 const size_t chunk_size_bytes =
164 std::min(num_bytes - bytes_written, stream_pipe_buffer.size_bytes());
165 const Result<ByteSpan> read_result =
166 bytes_reader.Read(stream_pipe_buffer.data(), chunk_size_bytes);
167 status_.Update(read_result.status());
168 PW_TRY(status_);
169
170 status_.Update(writer_.Write(read_result.value()));
171 PW_TRY(status_);
172
173 bytes_written += read_result.value().size();
174 }
175
176 return OkStatus();
177 }
178
WriteFixed(uint32_t field_number,ConstByteSpan data)179 Status StreamEncoder::WriteFixed(uint32_t field_number, ConstByteSpan data) {
180 WireType type =
181 data.size() == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
182
183 PW_TRY(UpdateStatusForWrite(field_number, type, data.size()));
184
185 WriteVarint(FieldKey(field_number, type))
186 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
187 if (Status status = writer_.Write(data); !status.ok()) {
188 status_ = status;
189 }
190 return status_;
191 }
192
WritePackedFixed(uint32_t field_number,span<const std::byte> values,size_t elem_size)193 Status StreamEncoder::WritePackedFixed(uint32_t field_number,
194 span<const std::byte> values,
195 size_t elem_size) {
196 if (values.empty()) {
197 return status_;
198 }
199
200 PW_CHECK_NOTNULL(values.data());
201 PW_DCHECK(elem_size == sizeof(uint32_t) || elem_size == sizeof(uint64_t));
202
203 PW_TRY(UpdateStatusForWrite(
204 field_number, WireType::kDelimited, values.size_bytes()));
205 WriteVarint(FieldKey(field_number, WireType::kDelimited))
206 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
207 WriteVarint(values.size_bytes())
208 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
209
210 for (auto val_start = values.begin(); val_start != values.end();
211 val_start += elem_size) {
212 // Allocates 8 bytes so both 4-byte and 8-byte types can be encoded as
213 // little-endian for serialization.
214 std::array<std::byte, sizeof(uint64_t)> data;
215 if (endian::native == endian::little) {
216 std::copy(val_start, val_start + elem_size, std::begin(data));
217 } else {
218 std::reverse_copy(val_start, val_start + elem_size, std::begin(data));
219 }
220 status_.Update(writer_.Write(span(data).first(elem_size)));
221 PW_TRY(status_);
222 }
223 return status_;
224 }
225
UpdateStatusForWrite(uint32_t field_number,WireType type,size_t data_size)226 Status StreamEncoder::UpdateStatusForWrite(uint32_t field_number,
227 WireType type,
228 size_t data_size) {
229 PW_CHECK(!nested_encoder_open());
230 PW_TRY(status_);
231
232 if (!ValidFieldNumber(field_number)) {
233 return status_ = Status::InvalidArgument();
234 }
235
236 const Result<size_t> field_size = SizeOfField(field_number, type, data_size);
237 status_.Update(field_size.status());
238 PW_TRY(status_);
239
240 if (field_size.value() > writer_.ConservativeWriteLimit()) {
241 status_ = Status::ResourceExhausted();
242 }
243
244 return status_;
245 }
246
Write(span<const std::byte> message,span<const internal::MessageField> table)247 Status StreamEncoder::Write(span<const std::byte> message,
248 span<const internal::MessageField> table) {
249 PW_CHECK(!nested_encoder_open());
250 PW_TRY(status_);
251
252 for (const auto& field : table) {
253 // Calculate the span of bytes corresponding to the structure field to
254 // read from.
255 const auto values =
256 message.subspan(field.field_offset(), field.field_size());
257 PW_CHECK(values.begin() >= message.begin() &&
258 values.end() <= message.end());
259
260 // If the field is using callbacks, interpret the input field accordingly
261 // and allow the caller to provide custom handling.
262 if (field.use_callback()) {
263 const Callback<StreamEncoder, StreamDecoder>* callback =
264 reinterpret_cast<const Callback<StreamEncoder, StreamDecoder>*>(
265 values.data());
266 PW_TRY(callback->Encode(*this));
267 continue;
268 }
269
270 switch (field.wire_type()) {
271 case WireType::kFixed64:
272 case WireType::kFixed32: {
273 // Fixed fields call WriteFixed() for singular case and
274 // WritePackedFixed() for repeated fields.
275 PW_CHECK(field.elem_size() == (field.wire_type() == WireType::kFixed32
276 ? sizeof(uint32_t)
277 : sizeof(uint64_t)),
278 "Mismatched message field type and size");
279 if (field.is_fixed_size()) {
280 PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
281 if (static_cast<size_t>(
282 std::count(values.begin(), values.end(), std::byte{0})) <
283 values.size()) {
284 PW_TRY(WritePackedFixed(
285 field.field_number(), values, field.elem_size()));
286 }
287 } else if (field.is_repeated()) {
288 // The struct member for this field is a vector of a type
289 // corresponding to the field element size. Cast to the correct
290 // vector type so we're not performing type aliasing (except for
291 // unsigned vs signed which is explicitly allowed).
292 if (field.elem_size() == sizeof(uint64_t)) {
293 const auto* vector =
294 reinterpret_cast<const pw::Vector<const uint64_t>*>(
295 values.data());
296 if (!vector->empty()) {
297 PW_TRY(WritePackedFixed(
298 field.field_number(),
299 as_bytes(span(vector->data(), vector->size())),
300 field.elem_size()));
301 }
302 } else if (field.elem_size() == sizeof(uint32_t)) {
303 const auto* vector =
304 reinterpret_cast<const pw::Vector<const uint32_t>*>(
305 values.data());
306 if (!vector->empty()) {
307 PW_TRY(WritePackedFixed(
308 field.field_number(),
309 as_bytes(span(vector->data(), vector->size())),
310 field.elem_size()));
311 }
312 }
313 } else if (field.is_optional()) {
314 // The struct member for this field is a std::optional of a type
315 // corresponding to the field element size. Cast to the correct
316 // optional type so we're not performing type aliasing (except for
317 // unsigned vs signed which is explicitly allowed), and write from
318 // a temporary.
319 if (field.elem_size() == sizeof(uint64_t)) {
320 const auto* optional =
321 reinterpret_cast<const std::optional<uint64_t>*>(values.data());
322 if (optional->has_value()) {
323 uint64_t value = optional->value();
324 PW_TRY(
325 WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
326 }
327 } else if (field.elem_size() == sizeof(uint32_t)) {
328 const auto* optional =
329 reinterpret_cast<const std::optional<uint32_t>*>(values.data());
330 if (optional->has_value()) {
331 uint32_t value = optional->value();
332 PW_TRY(
333 WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
334 }
335 }
336 } else {
337 PW_CHECK(values.size() == field.elem_size(),
338 "Mismatched message field type and size");
339 if (static_cast<size_t>(
340 std::count(values.begin(), values.end(), std::byte{0})) <
341 values.size()) {
342 PW_TRY(WriteFixed(field.field_number(), values));
343 }
344 }
345 break;
346 }
347 case WireType::kVarint: {
348 // Varint fields call WriteVarintField() for singular case and
349 // WritePackedVarints() for repeated fields.
350 PW_CHECK(field.elem_size() == sizeof(uint64_t) ||
351 field.elem_size() == sizeof(uint32_t) ||
352 field.elem_size() == sizeof(bool),
353 "Mismatched message field type and size");
354 if (field.is_fixed_size()) {
355 // The struct member for this field is an array of type corresponding
356 // to the field element size. Cast to a span of the correct type over
357 // the array so we're not performing type aliasing (except for
358 // unsigned vs signed which is explicitly allowed).
359 PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
360 if (static_cast<size_t>(
361 std::count(values.begin(), values.end(), std::byte{0})) ==
362 values.size()) {
363 continue;
364 }
365 if (field.elem_size() == sizeof(uint64_t)) {
366 PW_TRY(WritePackedVarints(
367 field.field_number(),
368 span(reinterpret_cast<const uint64_t*>(values.data()),
369 values.size() / field.elem_size()),
370 field.varint_type()));
371 } else if (field.elem_size() == sizeof(uint32_t)) {
372 PW_TRY(WritePackedVarints(
373 field.field_number(),
374 span(reinterpret_cast<const uint32_t*>(values.data()),
375 values.size() / field.elem_size()),
376 field.varint_type()));
377 } else if (field.elem_size() == sizeof(bool)) {
378 static_assert(sizeof(bool) == sizeof(uint8_t),
379 "bool must be same size as uint8_t");
380 PW_TRY(WritePackedVarints(
381 field.field_number(),
382 span(reinterpret_cast<const uint8_t*>(values.data()),
383 values.size() / field.elem_size()),
384 field.varint_type()));
385 }
386 } else if (field.is_repeated()) {
387 // The struct member for this field is a vector of a type
388 // corresponding to the field element size. Cast to the correct
389 // vector type so we're not performing type aliasing (except for
390 // unsigned vs signed which is explicitly allowed).
391 if (field.elem_size() == sizeof(uint64_t)) {
392 const auto* vector =
393 reinterpret_cast<const pw::Vector<const uint64_t>*>(
394 values.data());
395 if (!vector->empty()) {
396 PW_TRY(WritePackedVarints(field.field_number(),
397 span(vector->data(), vector->size()),
398 field.varint_type()));
399 }
400 } else if (field.elem_size() == sizeof(uint32_t)) {
401 const auto* vector =
402 reinterpret_cast<const pw::Vector<const uint32_t>*>(
403 values.data());
404 if (!vector->empty()) {
405 PW_TRY(WritePackedVarints(field.field_number(),
406 span(vector->data(), vector->size()),
407 field.varint_type()));
408 }
409 } else if (field.elem_size() == sizeof(bool)) {
410 static_assert(sizeof(bool) == sizeof(uint8_t),
411 "bool must be same size as uint8_t");
412 const auto* vector =
413 reinterpret_cast<const pw::Vector<const uint8_t>*>(
414 values.data());
415 if (!vector->empty()) {
416 PW_TRY(WritePackedVarints(field.field_number(),
417 span(vector->data(), vector->size()),
418 field.varint_type()));
419 }
420 }
421 } else if (field.is_optional()) {
422 // The struct member for this field is a std::optional of a type
423 // corresponding to the field element size. Cast to the correct
424 // optional type so we're not performing type aliasing (except for
425 // unsigned vs signed which is explicitly allowed), and write from
426 // a temporary.
427 uint64_t value = 0;
428 if (field.elem_size() == sizeof(uint64_t)) {
429 if (field.varint_type() == VarintType::kUnsigned) {
430 const auto* optional =
431 reinterpret_cast<const std::optional<uint64_t>*>(
432 values.data());
433 if (!optional->has_value()) {
434 continue;
435 }
436 value = optional->value();
437 } else {
438 const auto* optional =
439 reinterpret_cast<const std::optional<int64_t>*>(
440 values.data());
441 if (!optional->has_value()) {
442 continue;
443 }
444 value = field.varint_type() == VarintType::kZigZag
445 ? varint::ZigZagEncode(optional->value())
446 : optional->value();
447 }
448 } else if (field.elem_size() == sizeof(uint32_t)) {
449 if (field.varint_type() == VarintType::kUnsigned) {
450 const auto* optional =
451 reinterpret_cast<const std::optional<uint32_t>*>(
452 values.data());
453 if (!optional->has_value()) {
454 continue;
455 }
456 value = optional->value();
457 } else {
458 const auto* optional =
459 reinterpret_cast<const std::optional<int32_t>*>(
460 values.data());
461 if (!optional->has_value()) {
462 continue;
463 }
464 value = field.varint_type() == VarintType::kZigZag
465 ? varint::ZigZagEncode(optional->value())
466 : optional->value();
467 }
468 } else if (field.elem_size() == sizeof(bool)) {
469 const auto* optional =
470 reinterpret_cast<const std::optional<bool>*>(values.data());
471 if (!optional->has_value()) {
472 continue;
473 }
474 value = optional->value();
475 }
476 PW_TRY(WriteVarintField(field.field_number(), value));
477 } else {
478 // The struct member for this field is a scalar of a type
479 // corresponding to the field element size. Cast to the correct
480 // type to retrieve the value before passing to WriteVarintField()
481 // so we're not performing type aliasing (except for unsigned vs
482 // signed which is explicitly allowed).
483 PW_CHECK(values.size() == field.elem_size(),
484 "Mismatched message field type and size");
485 uint64_t value = 0;
486 if (field.elem_size() == sizeof(uint64_t)) {
487 if (field.varint_type() == VarintType::kZigZag) {
488 value = varint::ZigZagEncode(
489 *reinterpret_cast<const int64_t*>(values.data()));
490 } else if (field.varint_type() == VarintType::kNormal) {
491 value = *reinterpret_cast<const int64_t*>(values.data());
492 } else {
493 value = *reinterpret_cast<const uint64_t*>(values.data());
494 }
495 if (!value) {
496 continue;
497 }
498 } else if (field.elem_size() == sizeof(uint32_t)) {
499 if (field.varint_type() == VarintType::kZigZag) {
500 value = varint::ZigZagEncode(
501 *reinterpret_cast<const int32_t*>(values.data()));
502 } else if (field.varint_type() == VarintType::kNormal) {
503 value = *reinterpret_cast<const int32_t*>(values.data());
504 } else {
505 value = *reinterpret_cast<const uint32_t*>(values.data());
506 }
507 if (!value) {
508 continue;
509 }
510 } else if (field.elem_size() == sizeof(bool)) {
511 value = *reinterpret_cast<const bool*>(values.data());
512 if (!value) {
513 continue;
514 }
515 }
516 PW_TRY(WriteVarintField(field.field_number(), value));
517 }
518 break;
519 }
520 case WireType::kDelimited: {
521 // Delimited fields are always a singular case because of the
522 // inability to cast to a generic vector with an element of a certain
523 // size (we always need a type).
524 PW_CHECK(!field.is_repeated(),
525 "Repeated delimited messages always require a callback");
526 if (field.nested_message_fields()) {
527 // Nested Message. Struct member is an embedded struct for the
528 // nested field. Obtain a nested encoder and recursively call Write()
529 // using the fields table pointer from this field.
530 auto nested_encoder = GetNestedEncoder(field.field_number(),
531 /*write_when_empty=*/false);
532 PW_TRY(nested_encoder.Write(values, *field.nested_message_fields()));
533 } else if (field.is_fixed_size()) {
534 // Fixed-length bytes field. Struct member is a std::array<std::byte>.
535 // Call WriteLengthDelimitedField() to output it to the stream.
536 PW_CHECK(field.elem_size() == sizeof(std::byte),
537 "Mismatched message field type and size");
538 if (static_cast<size_t>(
539 std::count(values.begin(), values.end(), std::byte{0})) <
540 values.size()) {
541 PW_TRY(WriteLengthDelimitedField(field.field_number(), values));
542 }
543 } else {
544 // bytes or string field with a maximum size. Struct member is
545 // pw::Vector<std::byte> for bytes or pw::InlineString<> for string.
546 // Use the contents as a span and call WriteLengthDelimitedField() to
547 // output it to the stream.
548 PW_CHECK(field.elem_size() == sizeof(std::byte),
549 "Mismatched message field type and size");
550 if (field.is_string()) {
551 PW_TRY(WriteStringOrBytes<const InlineString<>>(
552 field.field_number(), values.data()));
553 } else {
554 PW_TRY(WriteStringOrBytes<const Vector<const std::byte>>(
555 field.field_number(), values.data()));
556 }
557 }
558 break;
559 }
560 }
561 }
562
563 return status_;
564 }
565
566 } // namespace pw::protobuf
567