• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_protobuf/encoder.h"
16 
17 #include <algorithm>
18 #include <cstddef>
19 #include <cstring>
20 #include <optional>
21 
22 #include "pw_assert/check.h"
23 #include "pw_bytes/span.h"
24 #include "pw_protobuf/internal/codegen.h"
25 #include "pw_protobuf/serialized_size.h"
26 #include "pw_protobuf/stream_decoder.h"
27 #include "pw_protobuf/wire_format.h"
28 #include "pw_span/span.h"
29 #include "pw_status/status.h"
30 #include "pw_status/try.h"
31 #include "pw_stream/memory_stream.h"
32 #include "pw_stream/stream.h"
33 #include "pw_string/string.h"
34 #include "pw_varint/varint.h"
35 
36 namespace pw::protobuf {
37 
38 using internal::VarintType;
39 
GetNestedEncoder(uint32_t field_number,bool write_when_empty)40 StreamEncoder StreamEncoder::GetNestedEncoder(uint32_t field_number,
41                                               bool write_when_empty) {
42   PW_CHECK(!nested_encoder_open());
43 
44   nested_field_number_ = field_number;
45   if (!ValidFieldNumber(field_number)) {
46     status_.Update(Status::InvalidArgument());
47     return StreamEncoder(*this, ByteSpan(), false);
48   }
49 
50   // Pass the unused space of the scratch buffer to the nested encoder to use
51   // as their scratch buffer.
52   size_t key_size =
53       varint::EncodedSize(FieldKey(field_number, WireType::kDelimited));
54   size_t reserved_size = key_size + config::kMaxVarintSize;
55   size_t max_size = std::min(memory_writer_.ConservativeWriteLimit(),
56                              writer_.ConservativeWriteLimit());
57   // Cap based on max varint size.
58   max_size = std::min(varint::MaxValueInBytes(config::kMaxVarintSize),
59                       static_cast<uint64_t>(max_size));
60 
61   // Account for reserved bytes.
62   max_size = max_size > reserved_size ? max_size - reserved_size : 0;
63 
64   ByteSpan nested_buffer;
65   if (max_size > 0) {
66     nested_buffer = ByteSpan(
67         memory_writer_.data() + reserved_size + memory_writer_.bytes_written(),
68         max_size);
69   } else {
70     nested_buffer = ByteSpan();
71   }
72   return StreamEncoder(*this, nested_buffer, write_when_empty);
73 }
74 
CloseEncoder()75 void StreamEncoder::CloseEncoder() {
76   // If this was an invalidated StreamEncoder which cannot be used, permit the
77   // object to be cleanly destructed by doing nothing.
78   if (nested_field_number_ == kFirstReservedNumber) {
79     return;
80   }
81 
82   PW_CHECK(
83       !nested_encoder_open(),
84       "Tried to destruct a proto encoder with an active submessage encoder");
85 
86   if (parent_ != nullptr) {
87     parent_->CloseNestedMessage(*this);
88   }
89 }
90 
CloseNestedMessage(StreamEncoder & nested)91 void StreamEncoder::CloseNestedMessage(StreamEncoder& nested) {
92   PW_DCHECK_PTR_EQ(nested.parent_,
93                    this,
94                    "CloseNestedMessage() called on the wrong Encoder parent");
95 
96   // Make the nested encoder look like it has an open child to block writes for
97   // the remainder of the object's life.
98   nested.nested_field_number_ = kFirstReservedNumber;
99   nested.parent_ = nullptr;
100   // Temporarily cache the field number of the child so we can re-enable
101   // writing to this encoder.
102   uint32_t temp_field_number = nested_field_number_;
103   nested_field_number_ = 0;
104 
105   // TODO(amontanez): If a submessage fails, we could optionally discard
106   // it and continue happily. For now, we'll always invalidate the entire
107   // encoder if a single submessage fails.
108   status_.Update(nested.status_);
109   if (!status_.ok()) {
110     return;
111   }
112 
113   if (varint::EncodedSize(nested.memory_writer_.bytes_written()) >
114       config::kMaxVarintSize) {
115     status_ = Status::OutOfRange();
116     return;
117   }
118 
119   if (!nested.memory_writer_.bytes_written() && !nested.write_when_empty_) {
120     return;
121   }
122 
123   status_ = WriteLengthDelimitedField(temp_field_number,
124                                       nested.memory_writer_.WrittenData());
125 }
126 
WriteVarintField(uint32_t field_number,uint64_t value)127 Status StreamEncoder::WriteVarintField(uint32_t field_number, uint64_t value) {
128   PW_TRY(UpdateStatusForWrite(
129       field_number, WireType::kVarint, varint::EncodedSize(value)));
130 
131   WriteVarint(FieldKey(field_number, WireType::kVarint))
132       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
133   return WriteVarint(value);
134 }
135 
WriteLengthDelimitedField(uint32_t field_number,ConstByteSpan data)136 Status StreamEncoder::WriteLengthDelimitedField(uint32_t field_number,
137                                                 ConstByteSpan data) {
138   PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, data.size()));
139   status_.Update(WriteLengthDelimitedKeyAndLengthPrefix(
140       field_number, data.size(), writer_));
141   PW_TRY(status_);
142   if (Status status = writer_.Write(data); !status.ok()) {
143     status_ = status;
144   }
145   return status_;
146 }
147 
WriteLengthDelimitedFieldFromStream(uint32_t field_number,stream::Reader & bytes_reader,size_t num_bytes,ByteSpan stream_pipe_buffer)148 Status StreamEncoder::WriteLengthDelimitedFieldFromStream(
149     uint32_t field_number,
150     stream::Reader& bytes_reader,
151     size_t num_bytes,
152     ByteSpan stream_pipe_buffer) {
153   PW_CHECK_UINT_GT(
154       stream_pipe_buffer.size(), 0, "Transfer buffer cannot be 0 size");
155   PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, num_bytes));
156   status_.Update(
157       WriteLengthDelimitedKeyAndLengthPrefix(field_number, num_bytes, writer_));
158   PW_TRY(status_);
159 
160   // Stream data from `bytes_reader` to `writer_`.
161   // TODO(pwbug/468): move the following logic to pw_stream/copy.h at a later
162   // time.
163   for (size_t bytes_written = 0; bytes_written < num_bytes;) {
164     const size_t chunk_size_bytes =
165         std::min(num_bytes - bytes_written, stream_pipe_buffer.size_bytes());
166     const Result<ByteSpan> read_result =
167         bytes_reader.Read(stream_pipe_buffer.data(), chunk_size_bytes);
168     status_.Update(read_result.status());
169     PW_TRY(status_);
170 
171     status_.Update(writer_.Write(read_result.value()));
172     PW_TRY(status_);
173 
174     bytes_written += read_result.value().size();
175   }
176 
177   return OkStatus();
178 }
179 
WriteFixed(uint32_t field_number,ConstByteSpan data)180 Status StreamEncoder::WriteFixed(uint32_t field_number, ConstByteSpan data) {
181   WireType type =
182       data.size() == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
183 
184   PW_TRY(UpdateStatusForWrite(field_number, type, data.size()));
185 
186   WriteVarint(FieldKey(field_number, type))
187       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
188   if (Status status = writer_.Write(data); !status.ok()) {
189     status_ = status;
190   }
191   return status_;
192 }
193 
WritePackedFixed(uint32_t field_number,span<const std::byte> values,size_t elem_size)194 Status StreamEncoder::WritePackedFixed(uint32_t field_number,
195                                        span<const std::byte> values,
196                                        size_t elem_size) {
197   if (values.empty()) {
198     return status_;
199   }
200 
201   PW_CHECK_NOTNULL(values.data());
202   PW_DCHECK(elem_size == sizeof(uint32_t) || elem_size == sizeof(uint64_t));
203 
204   PW_TRY(UpdateStatusForWrite(
205       field_number, WireType::kDelimited, values.size_bytes()));
206   WriteVarint(FieldKey(field_number, WireType::kDelimited))
207       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
208   WriteVarint(values.size_bytes())
209       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
210 
211   for (auto val_start = values.begin(); val_start != values.end();
212        val_start += elem_size) {
213     // Allocates 8 bytes so both 4-byte and 8-byte types can be encoded as
214     // little-endian for serialization.
215     std::array<std::byte, sizeof(uint64_t)> data;
216     if (endian::native == endian::little) {
217       std::copy(val_start, val_start + elem_size, std::begin(data));
218     } else {
219       std::reverse_copy(val_start, val_start + elem_size, std::begin(data));
220     }
221     status_.Update(writer_.Write(span(data).first(elem_size)));
222     PW_TRY(status_);
223   }
224   return status_;
225 }
226 
UpdateStatusForWrite(uint32_t field_number,WireType type,size_t data_size)227 Status StreamEncoder::UpdateStatusForWrite(uint32_t field_number,
228                                            WireType type,
229                                            size_t data_size) {
230   PW_CHECK(!nested_encoder_open());
231   PW_TRY(status_);
232 
233   if (!ValidFieldNumber(field_number)) {
234     return status_ = Status::InvalidArgument();
235   }
236 
237   const Result<size_t> field_size = SizeOfField(field_number, type, data_size);
238   status_.Update(field_size.status());
239   PW_TRY(status_);
240 
241   if (field_size.value() > writer_.ConservativeWriteLimit()) {
242     status_ = Status::ResourceExhausted();
243   }
244 
245   return status_;
246 }
247 
Write(span<const std::byte> message,span<const internal::MessageField> table)248 Status StreamEncoder::Write(span<const std::byte> message,
249                             span<const internal::MessageField> table) {
250   PW_CHECK(!nested_encoder_open());
251   PW_TRY(status_);
252 
253   for (const auto& field : table) {
254     // Calculate the span of bytes corresponding to the structure field to
255     // read from.
256     ConstByteSpan values =
257         message.subspan(field.field_offset(), field.field_size());
258     PW_CHECK(values.begin() >= message.begin() &&
259              values.end() <= message.end());
260 
261     // If the field is using callbacks, interpret the input field accordingly
262     // and allow the caller to provide custom handling.
263     if (field.callback_type() == internal::CallbackType::kSingleField) {
264       const Callback<StreamEncoder, StreamDecoder>* callback =
265           reinterpret_cast<const Callback<StreamEncoder, StreamDecoder>*>(
266               values.data());
267       PW_TRY(callback->Encode(*this));
268       continue;
269     } else if (field.callback_type() == internal::CallbackType::kOneOfGroup) {
270       const OneOf<StreamEncoder, StreamDecoder>* callback =
271           reinterpret_cast<const OneOf<StreamEncoder, StreamDecoder>*>(
272               values.data());
273       PW_TRY(callback->Encode(*this));
274       continue;
275     }
276 
277     switch (field.wire_type()) {
278       case WireType::kFixed64:
279       case WireType::kFixed32: {
280         // Fixed fields call WriteFixed() for singular case and
281         // WritePackedFixed() for repeated fields.
282         PW_CHECK(field.elem_size() == (field.wire_type() == WireType::kFixed32
283                                            ? sizeof(uint32_t)
284                                            : sizeof(uint64_t)),
285                  "Mismatched message field type and size");
286         if (field.is_fixed_size()) {
287           PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
288           if (static_cast<size_t>(
289                   std::count(values.begin(), values.end(), std::byte{0})) <
290               values.size()) {
291             PW_TRY(WritePackedFixed(
292                 field.field_number(), values, field.elem_size()));
293           }
294         } else if (field.is_repeated()) {
295           // The struct member for this field is a vector of a type
296           // corresponding to the field element size. Cast to the correct
297           // vector type so we're not performing type aliasing (except for
298           // unsigned vs signed which is explicitly allowed).
299           if (field.elem_size() == sizeof(uint64_t)) {
300             const auto* vector =
301                 reinterpret_cast<const pw::Vector<const uint64_t>*>(
302                     values.data());
303             if (!vector->empty()) {
304               PW_TRY(WritePackedFixed(
305                   field.field_number(),
306                   as_bytes(span(vector->data(), vector->size())),
307                   field.elem_size()));
308             }
309           } else if (field.elem_size() == sizeof(uint32_t)) {
310             const auto* vector =
311                 reinterpret_cast<const pw::Vector<const uint32_t>*>(
312                     values.data());
313             if (!vector->empty()) {
314               PW_TRY(WritePackedFixed(
315                   field.field_number(),
316                   as_bytes(span(vector->data(), vector->size())),
317                   field.elem_size()));
318             }
319           }
320         } else if (field.is_optional()) {
321           // The struct member for this field is a std::optional of a type
322           // corresponding to the field element size. Cast to the correct
323           // optional type so we're not performing type aliasing (except for
324           // unsigned vs signed which is explicitly allowed), and write from
325           // a temporary.
326           if (field.elem_size() == sizeof(uint64_t)) {
327             const auto* optional =
328                 reinterpret_cast<const std::optional<uint64_t>*>(values.data());
329             if (optional->has_value()) {
330               uint64_t value = optional->value();
331               PW_TRY(
332                   WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
333             }
334           } else if (field.elem_size() == sizeof(uint32_t)) {
335             const auto* optional =
336                 reinterpret_cast<const std::optional<uint32_t>*>(values.data());
337             if (optional->has_value()) {
338               uint32_t value = optional->value();
339               PW_TRY(
340                   WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
341             }
342           }
343         } else {
344           PW_CHECK(values.size() == field.elem_size(),
345                    "Mismatched message field type and size");
346           if (static_cast<size_t>(
347                   std::count(values.begin(), values.end(), std::byte{0})) <
348               values.size()) {
349             PW_TRY(WriteFixed(field.field_number(), values));
350           }
351         }
352         break;
353       }
354       case WireType::kVarint: {
355         // Varint fields call WriteVarintField() for singular case and
356         // WritePackedVarints() for repeated fields.
357         PW_CHECK(field.elem_size() == sizeof(uint64_t) ||
358                      field.elem_size() == sizeof(uint32_t) ||
359                      field.elem_size() == sizeof(bool),
360                  "Mismatched message field type and size");
361         if (field.is_fixed_size()) {
362           // The struct member for this field is an array of type corresponding
363           // to the field element size. Cast to a span of the correct type over
364           // the array so we're not performing type aliasing (except for
365           // unsigned vs signed which is explicitly allowed).
366           PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
367           if (static_cast<size_t>(
368                   std::count(values.begin(), values.end(), std::byte{0})) ==
369               values.size()) {
370             continue;
371           }
372           if (field.elem_size() == sizeof(uint64_t)) {
373             PW_TRY(WritePackedVarints(
374                 field.field_number(),
375                 span(reinterpret_cast<const uint64_t*>(values.data()),
376                      values.size() / field.elem_size()),
377                 field.varint_type()));
378           } else if (field.elem_size() == sizeof(uint32_t)) {
379             PW_TRY(WritePackedVarints(
380                 field.field_number(),
381                 span(reinterpret_cast<const uint32_t*>(values.data()),
382                      values.size() / field.elem_size()),
383                 field.varint_type()));
384           } else if (field.elem_size() == sizeof(bool)) {
385             static_assert(sizeof(bool) == sizeof(uint8_t),
386                           "bool must be same size as uint8_t");
387             PW_TRY(WritePackedVarints(
388                 field.field_number(),
389                 span(reinterpret_cast<const uint8_t*>(values.data()),
390                      values.size() / field.elem_size()),
391                 field.varint_type()));
392           }
393         } else if (field.is_repeated()) {
394           // The struct member for this field is a vector of a type
395           // corresponding to the field element size. Cast to the correct
396           // vector type so we're not performing type aliasing (except for
397           // unsigned vs signed which is explicitly allowed).
398           if (field.elem_size() == sizeof(uint64_t)) {
399             const auto* vector =
400                 reinterpret_cast<const pw::Vector<const uint64_t>*>(
401                     values.data());
402             if (!vector->empty()) {
403               PW_TRY(WritePackedVarints(field.field_number(),
404                                         span(vector->data(), vector->size()),
405                                         field.varint_type()));
406             }
407           } else if (field.elem_size() == sizeof(uint32_t)) {
408             const auto* vector =
409                 reinterpret_cast<const pw::Vector<const uint32_t>*>(
410                     values.data());
411             if (!vector->empty()) {
412               PW_TRY(WritePackedVarints(field.field_number(),
413                                         span(vector->data(), vector->size()),
414                                         field.varint_type()));
415             }
416           } else if (field.elem_size() == sizeof(bool)) {
417             static_assert(sizeof(bool) == sizeof(uint8_t),
418                           "bool must be same size as uint8_t");
419             const auto* vector =
420                 reinterpret_cast<const pw::Vector<const uint8_t>*>(
421                     values.data());
422             if (!vector->empty()) {
423               PW_TRY(WritePackedVarints(field.field_number(),
424                                         span(vector->data(), vector->size()),
425                                         field.varint_type()));
426             }
427           }
428         } else if (field.is_optional()) {
429           // The struct member for this field is a std::optional of a type
430           // corresponding to the field element size. Cast to the correct
431           // optional type so we're not performing type aliasing (except for
432           // unsigned vs signed which is explicitly allowed), and write from
433           // a temporary.
434           uint64_t value = 0;
435           if (field.elem_size() == sizeof(uint64_t)) {
436             if (field.varint_type() == VarintType::kUnsigned) {
437               const auto* optional =
438                   reinterpret_cast<const std::optional<uint64_t>*>(
439                       values.data());
440               if (!optional->has_value()) {
441                 continue;
442               }
443               value = optional->value();
444             } else {
445               const auto* optional =
446                   reinterpret_cast<const std::optional<int64_t>*>(
447                       values.data());
448               if (!optional->has_value()) {
449                 continue;
450               }
451               value = field.varint_type() == VarintType::kZigZag
452                           ? varint::ZigZagEncode(optional->value())
453                           : optional->value();
454             }
455           } else if (field.elem_size() == sizeof(uint32_t)) {
456             if (field.varint_type() == VarintType::kUnsigned) {
457               const auto* optional =
458                   reinterpret_cast<const std::optional<uint32_t>*>(
459                       values.data());
460               if (!optional->has_value()) {
461                 continue;
462               }
463               value = optional->value();
464             } else {
465               const auto* optional =
466                   reinterpret_cast<const std::optional<int32_t>*>(
467                       values.data());
468               if (!optional->has_value()) {
469                 continue;
470               }
471               value = field.varint_type() == VarintType::kZigZag
472                           ? varint::ZigZagEncode(optional->value())
473                           : optional->value();
474             }
475           } else if (field.elem_size() == sizeof(bool)) {
476             const auto* optional =
477                 reinterpret_cast<const std::optional<bool>*>(values.data());
478             if (!optional->has_value()) {
479               continue;
480             }
481             value = optional->value();
482           }
483           PW_TRY(WriteVarintField(field.field_number(), value));
484         } else {
485           // The struct member for this field is a scalar of a type
486           // corresponding to the field element size. Cast to the correct
487           // type to retrieve the value before passing to WriteVarintField()
488           // so we're not performing type aliasing (except for unsigned vs
489           // signed which is explicitly allowed).
490           PW_CHECK(values.size() == field.elem_size(),
491                    "Mismatched message field type and size");
492           uint64_t value = 0;
493           if (field.elem_size() == sizeof(uint64_t)) {
494             if (field.varint_type() == VarintType::kZigZag) {
495               value = varint::ZigZagEncode(
496                   *reinterpret_cast<const int64_t*>(values.data()));
497             } else if (field.varint_type() == VarintType::kNormal) {
498               value = *reinterpret_cast<const int64_t*>(values.data());
499             } else {
500               value = *reinterpret_cast<const uint64_t*>(values.data());
501             }
502             if (!value) {
503               continue;
504             }
505           } else if (field.elem_size() == sizeof(uint32_t)) {
506             if (field.varint_type() == VarintType::kZigZag) {
507               value = varint::ZigZagEncode(
508                   *reinterpret_cast<const int32_t*>(values.data()));
509             } else if (field.varint_type() == VarintType::kNormal) {
510               value = *reinterpret_cast<const int32_t*>(values.data());
511             } else {
512               value = *reinterpret_cast<const uint32_t*>(values.data());
513             }
514             if (!value) {
515               continue;
516             }
517           } else if (field.elem_size() == sizeof(bool)) {
518             value = *reinterpret_cast<const bool*>(values.data());
519             if (!value) {
520               continue;
521             }
522           }
523           PW_TRY(WriteVarintField(field.field_number(), value));
524         }
525         break;
526       }
527       case WireType::kDelimited: {
528         // Delimited fields are always a singular case because of the
529         // inability to cast to a generic vector with an element of a certain
530         // size (we always need a type).
531         PW_CHECK(!field.is_repeated(),
532                  "Repeated delimited messages always require a callback");
533         if (field.nested_message_fields()) {
534           // Nested Message. Struct member is an embedded struct for the
535           // nested field. Obtain a nested encoder and recursively call Write()
536           // using the fields table pointer from this field.
537           auto nested_encoder = GetNestedEncoder(field.field_number(),
538                                                  /*write_when_empty=*/false);
539           PW_TRY(nested_encoder.Write(values, *field.nested_message_fields()));
540         } else if (field.is_fixed_size()) {
541           // Fixed-length bytes field. Struct member is a std::array<std::byte>.
542           // Call WriteLengthDelimitedField() to output it to the stream.
543           PW_CHECK(field.elem_size() == sizeof(std::byte),
544                    "Mismatched message field type and size");
545           if (static_cast<size_t>(
546                   std::count(values.begin(), values.end(), std::byte{0})) <
547               values.size()) {
548             PW_TRY(WriteLengthDelimitedField(field.field_number(), values));
549           }
550         } else {
551           // bytes or string field with a maximum size. Struct member is
552           // pw::Vector<std::byte> for bytes or pw::InlineString<> for string.
553           // Use the contents as a span and call WriteLengthDelimitedField() to
554           // output it to the stream.
555           PW_CHECK(field.elem_size() == sizeof(std::byte),
556                    "Mismatched message field type and size");
557           if (field.is_string()) {
558             PW_TRY(WriteStringOrBytes<const InlineString<>>(
559                 field.field_number(), values.data()));
560           } else {
561             PW_TRY(WriteStringOrBytes<const Vector<const std::byte>>(
562                 field.field_number(), values.data()));
563           }
564         }
565         break;
566       }
567     }
568   }
569 
570   ResetOneOfCallbacks(message, table);
571 
572   return status_;
573 }
574 
ResetOneOfCallbacks(ConstByteSpan message,span<const internal::MessageField> table)575 void StreamEncoder::ResetOneOfCallbacks(
576     ConstByteSpan message, span<const internal::MessageField> table) {
577   for (const auto& field : table) {
578     // Calculate the span of bytes corresponding to the structure field to
579     // read from.
580     ConstByteSpan values =
581         message.subspan(field.field_offset(), field.field_size());
582     PW_CHECK(values.begin() >= message.begin() &&
583              values.end() <= message.end());
584 
585     if (field.callback_type() == internal::CallbackType::kOneOfGroup) {
586       const OneOf<StreamEncoder, StreamDecoder>* callback =
587           reinterpret_cast<const OneOf<StreamEncoder, StreamDecoder>*>(
588               values.data());
589       callback->invoked_ = false;
590     }
591   }
592 }
593 
594 }  // namespace pw::protobuf
595