• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_protobuf/encoder.h"
16 
17 #include <algorithm>
18 #include <cstddef>
19 #include <cstring>
20 #include <optional>
21 
22 #include "pw_assert/check.h"
23 #include "pw_bytes/span.h"
24 #include "pw_protobuf/internal/codegen.h"
25 #include "pw_protobuf/serialized_size.h"
26 #include "pw_protobuf/stream_decoder.h"
27 #include "pw_protobuf/wire_format.h"
28 #include "pw_span/span.h"
29 #include "pw_status/status.h"
30 #include "pw_status/try.h"
31 #include "pw_stream/memory_stream.h"
32 #include "pw_stream/stream.h"
33 #include "pw_string/string.h"
34 #include "pw_varint/varint.h"
35 
36 namespace pw::protobuf {
37 
38 using internal::VarintType;
39 
GetNestedEncoder(uint32_t field_number,bool write_when_empty)40 StreamEncoder StreamEncoder::GetNestedEncoder(uint32_t field_number,
41                                               bool write_when_empty) {
42   PW_CHECK(!nested_encoder_open());
43 
44   nested_field_number_ = field_number;
45   if (!ValidFieldNumber(field_number)) {
46     status_.Update(Status::InvalidArgument());
47     return StreamEncoder(*this, ByteSpan(), false);
48   }
49 
50   // Pass the unused space of the scratch buffer to the nested encoder to use
51   // as their scratch buffer.
52   size_t key_size =
53       varint::EncodedSize(FieldKey(field_number, WireType::kDelimited));
54   size_t reserved_size = key_size + config::kMaxVarintSize;
55   size_t max_size = std::min(memory_writer_.ConservativeWriteLimit(),
56                              writer_.ConservativeWriteLimit());
57   // Account for reserved bytes.
58   max_size = max_size > reserved_size ? max_size - reserved_size : 0;
59   // Cap based on max varint size.
60   max_size = std::min(varint::MaxValueInBytes(config::kMaxVarintSize),
61                       static_cast<uint64_t>(max_size));
62 
63   ByteSpan nested_buffer;
64   if (max_size > 0) {
65     nested_buffer = ByteSpan(
66         memory_writer_.data() + reserved_size + memory_writer_.bytes_written(),
67         max_size);
68   } else {
69     nested_buffer = ByteSpan();
70   }
71   return StreamEncoder(*this, nested_buffer, write_when_empty);
72 }
73 
CloseEncoder()74 void StreamEncoder::CloseEncoder() {
75   // If this was an invalidated StreamEncoder which cannot be used, permit the
76   // object to be cleanly destructed by doing nothing.
77   if (nested_field_number_ == kFirstReservedNumber) {
78     return;
79   }
80 
81   PW_CHECK(
82       !nested_encoder_open(),
83       "Tried to destruct a proto encoder with an active submessage encoder");
84 
85   if (parent_ != nullptr) {
86     parent_->CloseNestedMessage(*this);
87   }
88 }
89 
CloseNestedMessage(StreamEncoder & nested)90 void StreamEncoder::CloseNestedMessage(StreamEncoder& nested) {
91   PW_DCHECK_PTR_EQ(nested.parent_,
92                    this,
93                    "CloseNestedMessage() called on the wrong Encoder parent");
94 
95   // Make the nested encoder look like it has an open child to block writes for
96   // the remainder of the object's life.
97   nested.nested_field_number_ = kFirstReservedNumber;
98   nested.parent_ = nullptr;
99   // Temporarily cache the field number of the child so we can re-enable
100   // writing to this encoder.
101   uint32_t temp_field_number = nested_field_number_;
102   nested_field_number_ = 0;
103 
104   // TODO(amontanez): If a submessage fails, we could optionally discard
105   // it and continue happily. For now, we'll always invalidate the entire
106   // encoder if a single submessage fails.
107   status_.Update(nested.status_);
108   if (!status_.ok()) {
109     return;
110   }
111 
112   if (varint::EncodedSize(nested.memory_writer_.bytes_written()) >
113       config::kMaxVarintSize) {
114     status_ = Status::OutOfRange();
115     return;
116   }
117 
118   if (!nested.memory_writer_.bytes_written() && !nested.write_when_empty_) {
119     return;
120   }
121 
122   status_ = WriteLengthDelimitedField(temp_field_number,
123                                       nested.memory_writer_.WrittenData());
124 }
125 
WriteVarintField(uint32_t field_number,uint64_t value)126 Status StreamEncoder::WriteVarintField(uint32_t field_number, uint64_t value) {
127   PW_TRY(UpdateStatusForWrite(
128       field_number, WireType::kVarint, varint::EncodedSize(value)));
129 
130   WriteVarint(FieldKey(field_number, WireType::kVarint))
131       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
132   return WriteVarint(value);
133 }
134 
WriteLengthDelimitedField(uint32_t field_number,ConstByteSpan data)135 Status StreamEncoder::WriteLengthDelimitedField(uint32_t field_number,
136                                                 ConstByteSpan data) {
137   PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, data.size()));
138   status_.Update(WriteLengthDelimitedKeyAndLengthPrefix(
139       field_number, data.size(), writer_));
140   PW_TRY(status_);
141   if (Status status = writer_.Write(data); !status.ok()) {
142     status_ = status;
143   }
144   return status_;
145 }
146 
WriteLengthDelimitedFieldFromStream(uint32_t field_number,stream::Reader & bytes_reader,size_t num_bytes,ByteSpan stream_pipe_buffer)147 Status StreamEncoder::WriteLengthDelimitedFieldFromStream(
148     uint32_t field_number,
149     stream::Reader& bytes_reader,
150     size_t num_bytes,
151     ByteSpan stream_pipe_buffer) {
152   PW_CHECK_UINT_GT(
153       stream_pipe_buffer.size(), 0, "Transfer buffer cannot be 0 size");
154   PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, num_bytes));
155   status_.Update(
156       WriteLengthDelimitedKeyAndLengthPrefix(field_number, num_bytes, writer_));
157   PW_TRY(status_);
158 
159   // Stream data from `bytes_reader` to `writer_`.
160   // TODO(pwbug/468): move the following logic to pw_stream/copy.h at a later
161   // time.
162   for (size_t bytes_written = 0; bytes_written < num_bytes;) {
163     const size_t chunk_size_bytes =
164         std::min(num_bytes - bytes_written, stream_pipe_buffer.size_bytes());
165     const Result<ByteSpan> read_result =
166         bytes_reader.Read(stream_pipe_buffer.data(), chunk_size_bytes);
167     status_.Update(read_result.status());
168     PW_TRY(status_);
169 
170     status_.Update(writer_.Write(read_result.value()));
171     PW_TRY(status_);
172 
173     bytes_written += read_result.value().size();
174   }
175 
176   return OkStatus();
177 }
178 
WriteFixed(uint32_t field_number,ConstByteSpan data)179 Status StreamEncoder::WriteFixed(uint32_t field_number, ConstByteSpan data) {
180   WireType type =
181       data.size() == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
182 
183   PW_TRY(UpdateStatusForWrite(field_number, type, data.size()));
184 
185   WriteVarint(FieldKey(field_number, type))
186       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
187   if (Status status = writer_.Write(data); !status.ok()) {
188     status_ = status;
189   }
190   return status_;
191 }
192 
WritePackedFixed(uint32_t field_number,span<const std::byte> values,size_t elem_size)193 Status StreamEncoder::WritePackedFixed(uint32_t field_number,
194                                        span<const std::byte> values,
195                                        size_t elem_size) {
196   if (values.empty()) {
197     return status_;
198   }
199 
200   PW_CHECK_NOTNULL(values.data());
201   PW_DCHECK(elem_size == sizeof(uint32_t) || elem_size == sizeof(uint64_t));
202 
203   PW_TRY(UpdateStatusForWrite(
204       field_number, WireType::kDelimited, values.size_bytes()));
205   WriteVarint(FieldKey(field_number, WireType::kDelimited))
206       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
207   WriteVarint(values.size_bytes())
208       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
209 
210   for (auto val_start = values.begin(); val_start != values.end();
211        val_start += elem_size) {
212     // Allocates 8 bytes so both 4-byte and 8-byte types can be encoded as
213     // little-endian for serialization.
214     std::array<std::byte, sizeof(uint64_t)> data;
215     if (endian::native == endian::little) {
216       std::copy(val_start, val_start + elem_size, std::begin(data));
217     } else {
218       std::reverse_copy(val_start, val_start + elem_size, std::begin(data));
219     }
220     status_.Update(writer_.Write(span(data).first(elem_size)));
221     PW_TRY(status_);
222   }
223   return status_;
224 }
225 
UpdateStatusForWrite(uint32_t field_number,WireType type,size_t data_size)226 Status StreamEncoder::UpdateStatusForWrite(uint32_t field_number,
227                                            WireType type,
228                                            size_t data_size) {
229   PW_CHECK(!nested_encoder_open());
230   PW_TRY(status_);
231 
232   if (!ValidFieldNumber(field_number)) {
233     return status_ = Status::InvalidArgument();
234   }
235 
236   const Result<size_t> field_size = SizeOfField(field_number, type, data_size);
237   status_.Update(field_size.status());
238   PW_TRY(status_);
239 
240   if (field_size.value() > writer_.ConservativeWriteLimit()) {
241     status_ = Status::ResourceExhausted();
242   }
243 
244   return status_;
245 }
246 
Write(span<const std::byte> message,span<const internal::MessageField> table)247 Status StreamEncoder::Write(span<const std::byte> message,
248                             span<const internal::MessageField> table) {
249   PW_CHECK(!nested_encoder_open());
250   PW_TRY(status_);
251 
252   for (const auto& field : table) {
253     // Calculate the span of bytes corresponding to the structure field to
254     // read from.
255     const auto values =
256         message.subspan(field.field_offset(), field.field_size());
257     PW_CHECK(values.begin() >= message.begin() &&
258              values.end() <= message.end());
259 
260     // If the field is using callbacks, interpret the input field accordingly
261     // and allow the caller to provide custom handling.
262     if (field.use_callback()) {
263       const Callback<StreamEncoder, StreamDecoder>* callback =
264           reinterpret_cast<const Callback<StreamEncoder, StreamDecoder>*>(
265               values.data());
266       PW_TRY(callback->Encode(*this));
267       continue;
268     }
269 
270     switch (field.wire_type()) {
271       case WireType::kFixed64:
272       case WireType::kFixed32: {
273         // Fixed fields call WriteFixed() for singular case and
274         // WritePackedFixed() for repeated fields.
275         PW_CHECK(field.elem_size() == (field.wire_type() == WireType::kFixed32
276                                            ? sizeof(uint32_t)
277                                            : sizeof(uint64_t)),
278                  "Mismatched message field type and size");
279         if (field.is_fixed_size()) {
280           PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
281           if (static_cast<size_t>(
282                   std::count(values.begin(), values.end(), std::byte{0})) <
283               values.size()) {
284             PW_TRY(WritePackedFixed(
285                 field.field_number(), values, field.elem_size()));
286           }
287         } else if (field.is_repeated()) {
288           // The struct member for this field is a vector of a type
289           // corresponding to the field element size. Cast to the correct
290           // vector type so we're not performing type aliasing (except for
291           // unsigned vs signed which is explicitly allowed).
292           if (field.elem_size() == sizeof(uint64_t)) {
293             const auto* vector =
294                 reinterpret_cast<const pw::Vector<const uint64_t>*>(
295                     values.data());
296             if (!vector->empty()) {
297               PW_TRY(WritePackedFixed(
298                   field.field_number(),
299                   as_bytes(span(vector->data(), vector->size())),
300                   field.elem_size()));
301             }
302           } else if (field.elem_size() == sizeof(uint32_t)) {
303             const auto* vector =
304                 reinterpret_cast<const pw::Vector<const uint32_t>*>(
305                     values.data());
306             if (!vector->empty()) {
307               PW_TRY(WritePackedFixed(
308                   field.field_number(),
309                   as_bytes(span(vector->data(), vector->size())),
310                   field.elem_size()));
311             }
312           }
313         } else if (field.is_optional()) {
314           // The struct member for this field is a std::optional of a type
315           // corresponding to the field element size. Cast to the correct
316           // optional type so we're not performing type aliasing (except for
317           // unsigned vs signed which is explicitly allowed), and write from
318           // a temporary.
319           if (field.elem_size() == sizeof(uint64_t)) {
320             const auto* optional =
321                 reinterpret_cast<const std::optional<uint64_t>*>(values.data());
322             if (optional->has_value()) {
323               uint64_t value = optional->value();
324               PW_TRY(
325                   WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
326             }
327           } else if (field.elem_size() == sizeof(uint32_t)) {
328             const auto* optional =
329                 reinterpret_cast<const std::optional<uint32_t>*>(values.data());
330             if (optional->has_value()) {
331               uint32_t value = optional->value();
332               PW_TRY(
333                   WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
334             }
335           }
336         } else {
337           PW_CHECK(values.size() == field.elem_size(),
338                    "Mismatched message field type and size");
339           if (static_cast<size_t>(
340                   std::count(values.begin(), values.end(), std::byte{0})) <
341               values.size()) {
342             PW_TRY(WriteFixed(field.field_number(), values));
343           }
344         }
345         break;
346       }
347       case WireType::kVarint: {
348         // Varint fields call WriteVarintField() for singular case and
349         // WritePackedVarints() for repeated fields.
350         PW_CHECK(field.elem_size() == sizeof(uint64_t) ||
351                      field.elem_size() == sizeof(uint32_t) ||
352                      field.elem_size() == sizeof(bool),
353                  "Mismatched message field type and size");
354         if (field.is_fixed_size()) {
355           // The struct member for this field is an array of type corresponding
356           // to the field element size. Cast to a span of the correct type over
357           // the array so we're not performing type aliasing (except for
358           // unsigned vs signed which is explicitly allowed).
359           PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
360           if (static_cast<size_t>(
361                   std::count(values.begin(), values.end(), std::byte{0})) ==
362               values.size()) {
363             continue;
364           }
365           if (field.elem_size() == sizeof(uint64_t)) {
366             PW_TRY(WritePackedVarints(
367                 field.field_number(),
368                 span(reinterpret_cast<const uint64_t*>(values.data()),
369                      values.size() / field.elem_size()),
370                 field.varint_type()));
371           } else if (field.elem_size() == sizeof(uint32_t)) {
372             PW_TRY(WritePackedVarints(
373                 field.field_number(),
374                 span(reinterpret_cast<const uint32_t*>(values.data()),
375                      values.size() / field.elem_size()),
376                 field.varint_type()));
377           } else if (field.elem_size() == sizeof(bool)) {
378             static_assert(sizeof(bool) == sizeof(uint8_t),
379                           "bool must be same size as uint8_t");
380             PW_TRY(WritePackedVarints(
381                 field.field_number(),
382                 span(reinterpret_cast<const uint8_t*>(values.data()),
383                      values.size() / field.elem_size()),
384                 field.varint_type()));
385           }
386         } else if (field.is_repeated()) {
387           // The struct member for this field is a vector of a type
388           // corresponding to the field element size. Cast to the correct
389           // vector type so we're not performing type aliasing (except for
390           // unsigned vs signed which is explicitly allowed).
391           if (field.elem_size() == sizeof(uint64_t)) {
392             const auto* vector =
393                 reinterpret_cast<const pw::Vector<const uint64_t>*>(
394                     values.data());
395             if (!vector->empty()) {
396               PW_TRY(WritePackedVarints(field.field_number(),
397                                         span(vector->data(), vector->size()),
398                                         field.varint_type()));
399             }
400           } else if (field.elem_size() == sizeof(uint32_t)) {
401             const auto* vector =
402                 reinterpret_cast<const pw::Vector<const uint32_t>*>(
403                     values.data());
404             if (!vector->empty()) {
405               PW_TRY(WritePackedVarints(field.field_number(),
406                                         span(vector->data(), vector->size()),
407                                         field.varint_type()));
408             }
409           } else if (field.elem_size() == sizeof(bool)) {
410             static_assert(sizeof(bool) == sizeof(uint8_t),
411                           "bool must be same size as uint8_t");
412             const auto* vector =
413                 reinterpret_cast<const pw::Vector<const uint8_t>*>(
414                     values.data());
415             if (!vector->empty()) {
416               PW_TRY(WritePackedVarints(field.field_number(),
417                                         span(vector->data(), vector->size()),
418                                         field.varint_type()));
419             }
420           }
421         } else if (field.is_optional()) {
422           // The struct member for this field is a std::optional of a type
423           // corresponding to the field element size. Cast to the correct
424           // optional type so we're not performing type aliasing (except for
425           // unsigned vs signed which is explicitly allowed), and write from
426           // a temporary.
427           uint64_t value = 0;
428           if (field.elem_size() == sizeof(uint64_t)) {
429             if (field.varint_type() == VarintType::kUnsigned) {
430               const auto* optional =
431                   reinterpret_cast<const std::optional<uint64_t>*>(
432                       values.data());
433               if (!optional->has_value()) {
434                 continue;
435               }
436               value = optional->value();
437             } else {
438               const auto* optional =
439                   reinterpret_cast<const std::optional<int64_t>*>(
440                       values.data());
441               if (!optional->has_value()) {
442                 continue;
443               }
444               value = field.varint_type() == VarintType::kZigZag
445                           ? varint::ZigZagEncode(optional->value())
446                           : optional->value();
447             }
448           } else if (field.elem_size() == sizeof(uint32_t)) {
449             if (field.varint_type() == VarintType::kUnsigned) {
450               const auto* optional =
451                   reinterpret_cast<const std::optional<uint32_t>*>(
452                       values.data());
453               if (!optional->has_value()) {
454                 continue;
455               }
456               value = optional->value();
457             } else {
458               const auto* optional =
459                   reinterpret_cast<const std::optional<int32_t>*>(
460                       values.data());
461               if (!optional->has_value()) {
462                 continue;
463               }
464               value = field.varint_type() == VarintType::kZigZag
465                           ? varint::ZigZagEncode(optional->value())
466                           : optional->value();
467             }
468           } else if (field.elem_size() == sizeof(bool)) {
469             const auto* optional =
470                 reinterpret_cast<const std::optional<bool>*>(values.data());
471             if (!optional->has_value()) {
472               continue;
473             }
474             value = optional->value();
475           }
476           PW_TRY(WriteVarintField(field.field_number(), value));
477         } else {
478           // The struct member for this field is a scalar of a type
479           // corresponding to the field element size. Cast to the correct
480           // type to retrieve the value before passing to WriteVarintField()
481           // so we're not performing type aliasing (except for unsigned vs
482           // signed which is explicitly allowed).
483           PW_CHECK(values.size() == field.elem_size(),
484                    "Mismatched message field type and size");
485           uint64_t value = 0;
486           if (field.elem_size() == sizeof(uint64_t)) {
487             if (field.varint_type() == VarintType::kZigZag) {
488               value = varint::ZigZagEncode(
489                   *reinterpret_cast<const int64_t*>(values.data()));
490             } else if (field.varint_type() == VarintType::kNormal) {
491               value = *reinterpret_cast<const int64_t*>(values.data());
492             } else {
493               value = *reinterpret_cast<const uint64_t*>(values.data());
494             }
495             if (!value) {
496               continue;
497             }
498           } else if (field.elem_size() == sizeof(uint32_t)) {
499             if (field.varint_type() == VarintType::kZigZag) {
500               value = varint::ZigZagEncode(
501                   *reinterpret_cast<const int32_t*>(values.data()));
502             } else if (field.varint_type() == VarintType::kNormal) {
503               value = *reinterpret_cast<const int32_t*>(values.data());
504             } else {
505               value = *reinterpret_cast<const uint32_t*>(values.data());
506             }
507             if (!value) {
508               continue;
509             }
510           } else if (field.elem_size() == sizeof(bool)) {
511             value = *reinterpret_cast<const bool*>(values.data());
512             if (!value) {
513               continue;
514             }
515           }
516           PW_TRY(WriteVarintField(field.field_number(), value));
517         }
518         break;
519       }
520       case WireType::kDelimited: {
521         // Delimited fields are always a singular case because of the
522         // inability to cast to a generic vector with an element of a certain
523         // size (we always need a type).
524         PW_CHECK(!field.is_repeated(),
525                  "Repeated delimited messages always require a callback");
526         if (field.nested_message_fields()) {
527           // Nested Message. Struct member is an embedded struct for the
528           // nested field. Obtain a nested encoder and recursively call Write()
529           // using the fields table pointer from this field.
530           auto nested_encoder = GetNestedEncoder(field.field_number(),
531                                                  /*write_when_empty=*/false);
532           PW_TRY(nested_encoder.Write(values, *field.nested_message_fields()));
533         } else if (field.is_fixed_size()) {
534           // Fixed-length bytes field. Struct member is a std::array<std::byte>.
535           // Call WriteLengthDelimitedField() to output it to the stream.
536           PW_CHECK(field.elem_size() == sizeof(std::byte),
537                    "Mismatched message field type and size");
538           if (static_cast<size_t>(
539                   std::count(values.begin(), values.end(), std::byte{0})) <
540               values.size()) {
541             PW_TRY(WriteLengthDelimitedField(field.field_number(), values));
542           }
543         } else {
544           // bytes or string field with a maximum size. Struct member is
545           // pw::Vector<std::byte> for bytes or pw::InlineString<> for string.
546           // Use the contents as a span and call WriteLengthDelimitedField() to
547           // output it to the stream.
548           PW_CHECK(field.elem_size() == sizeof(std::byte),
549                    "Mismatched message field type and size");
550           if (field.is_string()) {
551             PW_TRY(WriteStringOrBytes<const InlineString<>>(
552                 field.field_number(), values.data()));
553           } else {
554             PW_TRY(WriteStringOrBytes<const Vector<const std::byte>>(
555                 field.field_number(), values.data()));
556           }
557         }
558         break;
559       }
560     }
561   }
562 
563   return status_;
564 }
565 
566 }  // namespace pw::protobuf
567