• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_protobuf/encoder.h"
16 
17 #include <algorithm>
18 #include <cstddef>
19 #include <cstring>
20 #include <optional>
21 
22 #include "pw_assert/check.h"
23 #include "pw_bytes/span.h"
24 #include "pw_protobuf/internal/codegen.h"
25 #include "pw_protobuf/serialized_size.h"
26 #include "pw_protobuf/stream_decoder.h"
27 #include "pw_protobuf/wire_format.h"
28 #include "pw_span/span.h"
29 #include "pw_status/status.h"
30 #include "pw_status/try.h"
31 #include "pw_stream/memory_stream.h"
32 #include "pw_stream/stream.h"
33 #include "pw_string/string.h"
34 #include "pw_varint/varint.h"
35 
36 namespace pw::protobuf {
37 
38 using internal::VarintType;
39 
GetNestedEncoder(uint32_t field_number,bool write_when_empty)40 StreamEncoder StreamEncoder::GetNestedEncoder(uint32_t field_number,
41                                               bool write_when_empty) {
42   PW_CHECK(!nested_encoder_open());
43   PW_CHECK(ValidFieldNumber(field_number));
44 
45   nested_field_number_ = field_number;
46 
47   // Pass the unused space of the scratch buffer to the nested encoder to use
48   // as their scratch buffer.
49   size_t key_size =
50       varint::EncodedSize(FieldKey(field_number, WireType::kDelimited));
51   size_t reserved_size = key_size + config::kMaxVarintSize;
52   size_t max_size = std::min(memory_writer_.ConservativeWriteLimit(),
53                              writer_.ConservativeWriteLimit());
54   // Account for reserved bytes.
55   max_size = max_size > reserved_size ? max_size - reserved_size : 0;
56   // Cap based on max varint size.
57   max_size = std::min(varint::MaxValueInBytes(config::kMaxVarintSize),
58                       static_cast<uint64_t>(max_size));
59 
60   ByteSpan nested_buffer;
61   if (max_size > 0) {
62     nested_buffer = ByteSpan(
63         memory_writer_.data() + reserved_size + memory_writer_.bytes_written(),
64         max_size);
65   } else {
66     nested_buffer = ByteSpan();
67   }
68   return StreamEncoder(*this, nested_buffer, write_when_empty);
69 }
70 
CloseEncoder()71 void StreamEncoder::CloseEncoder() {
72   // If this was an invalidated StreamEncoder which cannot be used, permit the
73   // object to be cleanly destructed by doing nothing.
74   if (nested_field_number_ == kFirstReservedNumber) {
75     return;
76   }
77 
78   PW_CHECK(
79       !nested_encoder_open(),
80       "Tried to destruct a proto encoder with an active submessage encoder");
81 
82   if (parent_ != nullptr) {
83     parent_->CloseNestedMessage(*this);
84   }
85 }
86 
CloseNestedMessage(StreamEncoder & nested)87 void StreamEncoder::CloseNestedMessage(StreamEncoder& nested) {
88   PW_DCHECK_PTR_EQ(nested.parent_,
89                    this,
90                    "CloseNestedMessage() called on the wrong Encoder parent");
91 
92   // Make the nested encoder look like it has an open child to block writes for
93   // the remainder of the object's life.
94   nested.nested_field_number_ = kFirstReservedNumber;
95   nested.parent_ = nullptr;
96   // Temporarily cache the field number of the child so we can re-enable
97   // writing to this encoder.
98   uint32_t temp_field_number = nested_field_number_;
99   nested_field_number_ = 0;
100 
101   // TODO(amontanez): If a submessage fails, we could optionally discard
102   // it and continue happily. For now, we'll always invalidate the entire
103   // encoder if a single submessage fails.
104   status_.Update(nested.status_);
105   if (!status_.ok()) {
106     return;
107   }
108 
109   if (varint::EncodedSize(nested.memory_writer_.bytes_written()) >
110       config::kMaxVarintSize) {
111     status_ = Status::OutOfRange();
112     return;
113   }
114 
115   if (!nested.memory_writer_.bytes_written() && !nested.write_when_empty_) {
116     return;
117   }
118 
119   status_ = WriteLengthDelimitedField(temp_field_number,
120                                       nested.memory_writer_.WrittenData());
121 }
122 
WriteVarintField(uint32_t field_number,uint64_t value)123 Status StreamEncoder::WriteVarintField(uint32_t field_number, uint64_t value) {
124   PW_TRY(UpdateStatusForWrite(
125       field_number, WireType::kVarint, varint::EncodedSize(value)));
126 
127   WriteVarint(FieldKey(field_number, WireType::kVarint))
128       .IgnoreError();  // TODO(b/242598609): Handle Status properly
129   return WriteVarint(value);
130 }
131 
WriteLengthDelimitedField(uint32_t field_number,ConstByteSpan data)132 Status StreamEncoder::WriteLengthDelimitedField(uint32_t field_number,
133                                                 ConstByteSpan data) {
134   PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, data.size()));
135   status_.Update(WriteLengthDelimitedKeyAndLengthPrefix(
136       field_number, data.size(), writer_));
137   PW_TRY(status_);
138   if (Status status = writer_.Write(data); !status.ok()) {
139     status_ = status;
140   }
141   return status_;
142 }
143 
WriteLengthDelimitedFieldFromStream(uint32_t field_number,stream::Reader & bytes_reader,size_t num_bytes,ByteSpan stream_pipe_buffer)144 Status StreamEncoder::WriteLengthDelimitedFieldFromStream(
145     uint32_t field_number,
146     stream::Reader& bytes_reader,
147     size_t num_bytes,
148     ByteSpan stream_pipe_buffer) {
149   PW_CHECK_UINT_GT(
150       stream_pipe_buffer.size(), 0, "Transfer buffer cannot be 0 size");
151   PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, num_bytes));
152   status_.Update(
153       WriteLengthDelimitedKeyAndLengthPrefix(field_number, num_bytes, writer_));
154   PW_TRY(status_);
155 
156   // Stream data from `bytes_reader` to `writer_`.
157   // TODO(pwbug/468): move the following logic to pw_stream/copy.h at a later
158   // time.
159   for (size_t bytes_written = 0; bytes_written < num_bytes;) {
160     const size_t chunk_size_bytes =
161         std::min(num_bytes - bytes_written, stream_pipe_buffer.size_bytes());
162     const Result<ByteSpan> read_result =
163         bytes_reader.Read(stream_pipe_buffer.data(), chunk_size_bytes);
164     status_.Update(read_result.status());
165     PW_TRY(status_);
166 
167     status_.Update(writer_.Write(read_result.value()));
168     PW_TRY(status_);
169 
170     bytes_written += read_result.value().size();
171   }
172 
173   return OkStatus();
174 }
175 
WriteFixed(uint32_t field_number,ConstByteSpan data)176 Status StreamEncoder::WriteFixed(uint32_t field_number, ConstByteSpan data) {
177   WireType type =
178       data.size() == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
179 
180   PW_TRY(UpdateStatusForWrite(field_number, type, data.size()));
181 
182   WriteVarint(FieldKey(field_number, type))
183       .IgnoreError();  // TODO(b/242598609): Handle Status properly
184   if (Status status = writer_.Write(data); !status.ok()) {
185     status_ = status;
186   }
187   return status_;
188 }
189 
WritePackedFixed(uint32_t field_number,span<const std::byte> values,size_t elem_size)190 Status StreamEncoder::WritePackedFixed(uint32_t field_number,
191                                        span<const std::byte> values,
192                                        size_t elem_size) {
193   if (values.empty()) {
194     return status_;
195   }
196 
197   PW_CHECK_NOTNULL(values.data());
198   PW_DCHECK(elem_size == sizeof(uint32_t) || elem_size == sizeof(uint64_t));
199 
200   PW_TRY(UpdateStatusForWrite(
201       field_number, WireType::kDelimited, values.size_bytes()));
202   WriteVarint(FieldKey(field_number, WireType::kDelimited))
203       .IgnoreError();  // TODO(b/242598609): Handle Status properly
204   WriteVarint(values.size_bytes())
205       .IgnoreError();  // TODO(b/242598609): Handle Status properly
206 
207   for (auto val_start = values.begin(); val_start != values.end();
208        val_start += elem_size) {
209     // Allocates 8 bytes so both 4-byte and 8-byte types can be encoded as
210     // little-endian for serialization.
211     std::array<std::byte, sizeof(uint64_t)> data;
212     if (endian::native == endian::little) {
213       std::copy(val_start, val_start + elem_size, std::begin(data));
214     } else {
215       std::reverse_copy(val_start, val_start + elem_size, std::begin(data));
216     }
217     status_.Update(writer_.Write(span(data).first(elem_size)));
218     PW_TRY(status_);
219   }
220   return status_;
221 }
222 
UpdateStatusForWrite(uint32_t field_number,WireType type,size_t data_size)223 Status StreamEncoder::UpdateStatusForWrite(uint32_t field_number,
224                                            WireType type,
225                                            size_t data_size) {
226   PW_CHECK(!nested_encoder_open());
227   PW_TRY(status_);
228 
229   if (!ValidFieldNumber(field_number)) {
230     return status_ = Status::InvalidArgument();
231   }
232 
233   const Result<size_t> field_size = SizeOfField(field_number, type, data_size);
234   status_.Update(field_size.status());
235   PW_TRY(status_);
236 
237   if (field_size.value() > writer_.ConservativeWriteLimit()) {
238     status_ = Status::ResourceExhausted();
239   }
240 
241   return status_;
242 }
243 
// Encodes a message structure using a table that describes its fields.
//
// `message` is the raw storage of the structure; `table` has one entry per
// field, giving its offset and size within `message`, its wire type and its
// shape (callback, fixed-size array, repeated vector, optional or scalar).
// Entries are processed in table order.
//
// Values that are not written:
//   - scalar fixed/varint fields and fixed-size arrays whose bytes are all 0,
//   - empty vectors,
//   - optionals that hold no value.
//
// Encoding stops at the first failing field and that error is returned;
// otherwise the encoder status is returned. Crashes if a nested encoder is
// open, if a table entry lies outside `message`, or if an entry's element size
// does not match its wire type.
Write(span<const std::byte> message,span<const internal::MessageField> table)244 Status StreamEncoder::Write(span<const std::byte> message,
245                             span<const internal::MessageField> table) {
246   PW_CHECK(!nested_encoder_open());
247   PW_TRY(status_);
248 
249   for (const auto& field : table) {
250     // Calculate the span of bytes corresponding to the structure field to
251     // read from.
252     const auto values =
253         message.subspan(field.field_offset(), field.field_size());
      // The field's bytes must lie entirely within the message storage.
254     PW_CHECK(values.begin() >= message.begin() &&
255              values.end() <= message.end());
256 
257     // If the field is using callbacks, interpret the input field accordingly
258     // and allow the caller to provide custom handling.
259     if (field.use_callback()) {
260       const Callback<StreamEncoder, StreamDecoder>* callback =
261           reinterpret_cast<const Callback<StreamEncoder, StreamDecoder>*>(
262               values.data());
263       PW_TRY(callback->Encode(*this));
264       continue;
265     }
266 
      // Note: `continue` statements inside this switch apply to the enclosing
      // for loop, i.e. they skip to the next table entry.
267     switch (field.wire_type()) {
268       case WireType::kFixed64:
269       case WireType::kFixed32: {
270         // Fixed fields call WriteFixed() for singular case and
271         // WritePackedFixed() for repeated fields.
272         PW_CHECK(field.elem_size() == (field.wire_type() == WireType::kFixed32
273                                            ? sizeof(uint32_t)
274                                            : sizeof(uint64_t)),
275                  "Mismatched message field type and size");
276         if (field.is_fixed_size()) {
277           PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
            // Only write the array if at least one of its bytes is non-zero.
278           if (static_cast<size_t>(
279                   std::count(values.begin(), values.end(), std::byte{0})) <
280               values.size()) {
281             PW_TRY(WritePackedFixed(
282                 field.field_number(), values, field.elem_size()));
283           }
284         } else if (field.is_repeated()) {
285           // The struct member for this field is a vector of a type
286           // corresponding to the field element size. Cast to the correct
287           // vector type so we're not performing type aliasing (except for
288           // unsigned vs signed which is explicitly allowed).
289           if (field.elem_size() == sizeof(uint64_t)) {
290             const auto* vector =
291                 reinterpret_cast<const pw::Vector<const uint64_t>*>(
292                     values.data());
293             if (!vector->empty()) {
294               PW_TRY(WritePackedFixed(
295                   field.field_number(),
296                   as_bytes(span(vector->data(), vector->size())),
297                   field.elem_size()));
298             }
299           } else if (field.elem_size() == sizeof(uint32_t)) {
300             const auto* vector =
301                 reinterpret_cast<const pw::Vector<const uint32_t>*>(
302                     values.data());
303             if (!vector->empty()) {
304               PW_TRY(WritePackedFixed(
305                   field.field_number(),
306                   as_bytes(span(vector->data(), vector->size())),
307                   field.elem_size()));
308             }
309           }
310         } else if (field.is_optional()) {
311           // The struct member for this field is a std::optional of a type
312           // corresponding to the field element size. Cast to the correct
313           // optional type so we're not performing type aliasing (except for
314           // unsigned vs signed which is explicitly allowed), and write from
315           // a temporary.
316           if (field.elem_size() == sizeof(uint64_t)) {
317             const auto* optional =
318                 reinterpret_cast<const std::optional<uint64_t>*>(values.data());
319             if (optional->has_value()) {
320               uint64_t value = optional->value();
321               PW_TRY(
322                   WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
323             }
324           } else if (field.elem_size() == sizeof(uint32_t)) {
325             const auto* optional =
326                 reinterpret_cast<const std::optional<uint32_t>*>(values.data());
327             if (optional->has_value()) {
328               uint32_t value = optional->value();
329               PW_TRY(
330                   WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
331             }
332           }
333         } else {
            // Plain scalar: the field's bytes are the value itself.
334           PW_CHECK(values.size() == field.elem_size(),
335                    "Mismatched message field type and size");
            // A value whose bytes are all zero is not written.
336           if (static_cast<size_t>(
337                   std::count(values.begin(), values.end(), std::byte{0})) <
338               values.size()) {
339             PW_TRY(WriteFixed(field.field_number(), values));
340           }
341         }
342         break;
343       }
344       case WireType::kVarint: {
345         // Varint fields call WriteVarintField() for singular case and
346         // WritePackedVarints() for repeated fields.
347         PW_CHECK(field.elem_size() == sizeof(uint64_t) ||
348                      field.elem_size() == sizeof(uint32_t) ||
349                      field.elem_size() == sizeof(bool),
350                  "Mismatched message field type and size");
351         if (field.is_fixed_size()) {
352           // The struct member for this field is an array of type corresponding
353           // to the field element size. Cast to a span of the correct type over
354           // the array so we're not performing type aliasing (except for
355           // unsigned vs signed which is explicitly allowed).
356           PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
            // An array whose bytes are all zero is skipped entirely.
357           if (static_cast<size_t>(
358                   std::count(values.begin(), values.end(), std::byte{0})) ==
359               values.size()) {
360             continue;
361           }
362           if (field.elem_size() == sizeof(uint64_t)) {
363             PW_TRY(WritePackedVarints(
364                 field.field_number(),
365                 span(reinterpret_cast<const uint64_t*>(values.data()),
366                      values.size() / field.elem_size()),
367                 field.varint_type()));
368           } else if (field.elem_size() == sizeof(uint32_t)) {
369             PW_TRY(WritePackedVarints(
370                 field.field_number(),
371                 span(reinterpret_cast<const uint32_t*>(values.data()),
372                      values.size() / field.elem_size()),
373                 field.varint_type()));
374           } else if (field.elem_size() == sizeof(bool)) {
375             static_assert(sizeof(bool) == sizeof(uint8_t),
376                           "bool must be same size as uint8_t");
377             PW_TRY(WritePackedVarints(
378                 field.field_number(),
379                 span(reinterpret_cast<const uint8_t*>(values.data()),
380                      values.size() / field.elem_size()),
381                 field.varint_type()));
382           }
383         } else if (field.is_repeated()) {
384           // The struct member for this field is a vector of a type
385           // corresponding to the field element size. Cast to the correct
386           // vector type so we're not performing type aliasing (except for
387           // unsigned vs signed which is explicitly allowed).
388           if (field.elem_size() == sizeof(uint64_t)) {
389             const auto* vector =
390                 reinterpret_cast<const pw::Vector<const uint64_t>*>(
391                     values.data());
392             if (!vector->empty()) {
393               PW_TRY(WritePackedVarints(field.field_number(),
394                                         span(vector->data(), vector->size()),
395                                         field.varint_type()));
396             }
397           } else if (field.elem_size() == sizeof(uint32_t)) {
398             const auto* vector =
399                 reinterpret_cast<const pw::Vector<const uint32_t>*>(
400                     values.data());
401             if (!vector->empty()) {
402               PW_TRY(WritePackedVarints(field.field_number(),
403                                         span(vector->data(), vector->size()),
404                                         field.varint_type()));
405             }
406           } else if (field.elem_size() == sizeof(bool)) {
407             static_assert(sizeof(bool) == sizeof(uint8_t),
408                           "bool must be same size as uint8_t");
409             const auto* vector =
410                 reinterpret_cast<const pw::Vector<const uint8_t>*>(
411                     values.data());
412             if (!vector->empty()) {
413               PW_TRY(WritePackedVarints(field.field_number(),
414                                         span(vector->data(), vector->size()),
415                                         field.varint_type()));
416             }
417           }
418         } else if (field.is_optional()) {
419           // The struct member for this field is a std::optional of a type
420           // corresponding to the field element size. Cast to the correct
421           // optional type so we're not performing type aliasing (except for
422           // unsigned vs signed which is explicitly allowed), and write from
423           // a temporary.
            // Unlike the scalar case below, a present optional is written even
            // when its value is zero; only an absent optional is skipped.
424           uint64_t value = 0;
425           if (field.elem_size() == sizeof(uint64_t)) {
426             if (field.varint_type() == VarintType::kUnsigned) {
427               const auto* optional =
428                   reinterpret_cast<const std::optional<uint64_t>*>(
429                       values.data());
430               if (!optional->has_value()) {
431                 continue;
432               }
433               value = optional->value();
434             } else {
435               const auto* optional =
436                   reinterpret_cast<const std::optional<int64_t>*>(
437                       values.data());
438               if (!optional->has_value()) {
439                 continue;
440               }
441               value = field.varint_type() == VarintType::kZigZag
442                           ? varint::ZigZagEncode(optional->value())
443                           : optional->value();
444             }
445           } else if (field.elem_size() == sizeof(uint32_t)) {
446             if (field.varint_type() == VarintType::kUnsigned) {
447               const auto* optional =
448                   reinterpret_cast<const std::optional<uint32_t>*>(
449                       values.data());
450               if (!optional->has_value()) {
451                 continue;
452               }
453               value = optional->value();
454             } else {
455               const auto* optional =
456                   reinterpret_cast<const std::optional<int32_t>*>(
457                       values.data());
458               if (!optional->has_value()) {
459                 continue;
460               }
461               value = field.varint_type() == VarintType::kZigZag
462                           ? varint::ZigZagEncode(optional->value())
463                           : optional->value();
464             }
465           } else if (field.elem_size() == sizeof(bool)) {
466             const auto* optional =
467                 reinterpret_cast<const std::optional<bool>*>(values.data());
468             if (!optional->has_value()) {
469               continue;
470             }
471             value = optional->value();
472           }
473           PW_TRY(WriteVarintField(field.field_number(), value));
474         } else {
475           // The struct member for this field is a scalar of a type
476           // corresponding to the field element size. Cast to the correct
477           // type to retrieve the value before passing to WriteVarintField()
478           // so we're not performing type aliasing (except for unsigned vs
479           // signed which is explicitly allowed).
480           PW_CHECK(values.size() == field.elem_size(),
481                    "Mismatched message field type and size");
            // In each branch below a zero value is skipped rather than
            // written.
482           uint64_t value = 0;
483           if (field.elem_size() == sizeof(uint64_t)) {
484             if (field.varint_type() == VarintType::kZigZag) {
485               value = varint::ZigZagEncode(
486                   *reinterpret_cast<const int64_t*>(values.data()));
487             } else if (field.varint_type() == VarintType::kNormal) {
488               value = *reinterpret_cast<const int64_t*>(values.data());
489             } else {
490               value = *reinterpret_cast<const uint64_t*>(values.data());
491             }
492             if (!value) {
493               continue;
494             }
495           } else if (field.elem_size() == sizeof(uint32_t)) {
496             if (field.varint_type() == VarintType::kZigZag) {
497               value = varint::ZigZagEncode(
498                   *reinterpret_cast<const int32_t*>(values.data()));
499             } else if (field.varint_type() == VarintType::kNormal) {
500               value = *reinterpret_cast<const int32_t*>(values.data());
501             } else {
502               value = *reinterpret_cast<const uint32_t*>(values.data());
503             }
504             if (!value) {
505               continue;
506             }
507           } else if (field.elem_size() == sizeof(bool)) {
508             value = *reinterpret_cast<const bool*>(values.data());
509             if (!value) {
510               continue;
511             }
512           }
513           PW_TRY(WriteVarintField(field.field_number(), value));
514         }
515         break;
516       }
517       case WireType::kDelimited: {
518         // Delimited fields are always a singular case because of the
519         // inability to cast to a generic vector with an element of a certain
520         // size (we always need a type).
521         PW_CHECK(!field.is_repeated(),
522                  "Repeated delimited messages always require a callback");
523         if (field.nested_message_fields()) {
524           // Nested Message. Struct member is an embedded struct for the
525           // nested field. Obtain a nested encoder and recursively call Write()
526           // using the fields table pointer from this field.
            // NOTE(review): the nested encoder is a local object; presumably
            // its destructor closes the submessage into this encoder when the
            // scope ends - confirm against the class declaration.
527           auto nested_encoder = GetNestedEncoder(field.field_number(),
528                                                  /*write_when_empty=*/false);
529           PW_TRY(nested_encoder.Write(values, *field.nested_message_fields()));
530         } else if (field.is_fixed_size()) {
531           // Fixed-length bytes field. Struct member is a std::array<std::byte>.
532           // Call WriteLengthDelimitedField() to output it to the stream.
533           PW_CHECK(field.elem_size() == sizeof(std::byte),
534                    "Mismatched message field type and size");
            // An array whose bytes are all zero is not written.
535           if (static_cast<size_t>(
536                   std::count(values.begin(), values.end(), std::byte{0})) <
537               values.size()) {
538             PW_TRY(WriteLengthDelimitedField(field.field_number(), values));
539           }
540         } else {
541           // bytes or string field with a maximum size. Struct member is
542           // pw::Vector<std::byte> for bytes or pw::InlineString<> for string.
543           // Use the contents as a span and call WriteLengthDelimitedField() to
544           // output it to the stream.
545           PW_CHECK(field.elem_size() == sizeof(std::byte),
546                    "Mismatched message field type and size");
547           if (field.is_string()) {
548             PW_TRY(WriteStringOrBytes<const InlineString<>>(
549                 field.field_number(), values.data()));
550           } else {
551             PW_TRY(WriteStringOrBytes<const Vector<const std::byte>>(
552                 field.field_number(), values.data()));
553           }
554         }
555         break;
556       }
        // There is no default case: a wire type other than the ones handled
        // above results in nothing being written for that field.
557     }
558   }
559 
560   return status_;
561 }
562 
563 }  // namespace pw::protobuf
564