1 // Copyright 2021 The Pigweed Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 // use this file except in compliance with the License. You may obtain a copy of 5 // the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 // License for the specific language governing permissions and limitations under 13 // the License. 14 // 15 // The header provides a set of helper utils for protobuf related operations. 16 // The APIs may not be finalized yet. 17 18 #pragma once 19 20 #include <cstddef> 21 #include <string_view> 22 23 #include "pw_assert/check.h" 24 #include "pw_protobuf/internal/proto_integer_base.h" 25 #include "pw_protobuf/stream_decoder.h" 26 #include "pw_status/status.h" 27 #include "pw_status/try.h" 28 #include "pw_stream/interval_reader.h" 29 #include "pw_stream/stream.h" 30 31 namespace pw::protobuf { 32 33 // The following defines classes that represent various parsed proto integer 34 // types or an error code to indicate parsing failure. 35 // 36 // For normal uses, the class should be created from `class Message`. See 37 // comment for `class Message` for usage. 
38 39 class Uint32 : public internal::ProtoIntegerBase<uint32_t> { 40 public: 41 using ProtoIntegerBase<uint32_t>::ProtoIntegerBase; 42 }; 43 44 class Int32 : public internal::ProtoIntegerBase<int32_t> { 45 public: 46 using ProtoIntegerBase<int32_t>::ProtoIntegerBase; 47 }; 48 49 class Sint32 : public internal::ProtoIntegerBase<int32_t> { 50 public: 51 using ProtoIntegerBase<int32_t>::ProtoIntegerBase; 52 }; 53 54 class Fixed32 : public internal::ProtoIntegerBase<uint32_t> { 55 public: 56 using ProtoIntegerBase<uint32_t>::ProtoIntegerBase; 57 }; 58 59 class Sfixed32 : public internal::ProtoIntegerBase<int32_t> { 60 public: 61 using ProtoIntegerBase<int32_t>::ProtoIntegerBase; 62 }; 63 64 class Uint64 : public internal::ProtoIntegerBase<uint64_t> { 65 public: 66 using ProtoIntegerBase<uint64_t>::ProtoIntegerBase; 67 }; 68 69 class Int64 : public internal::ProtoIntegerBase<int64_t> { 70 public: 71 using ProtoIntegerBase<int64_t>::ProtoIntegerBase; 72 }; 73 74 class Sint64 : public internal::ProtoIntegerBase<int64_t> { 75 public: 76 using ProtoIntegerBase<int64_t>::ProtoIntegerBase; 77 }; 78 79 class Fixed64 : public internal::ProtoIntegerBase<uint64_t> { 80 public: 81 using ProtoIntegerBase<uint64_t>::ProtoIntegerBase; 82 }; 83 84 class Sfixed64 : public internal::ProtoIntegerBase<int64_t> { 85 public: 86 using ProtoIntegerBase<int64_t>::ProtoIntegerBase; 87 }; 88 89 class Float : public internal::ProtoIntegerBase<float> { 90 public: 91 using ProtoIntegerBase<float>::ProtoIntegerBase; 92 }; 93 94 class Double : public internal::ProtoIntegerBase<double> { 95 public: 96 using ProtoIntegerBase<double>::ProtoIntegerBase; 97 }; 98 99 class Bool : public internal::ProtoIntegerBase<bool> { 100 public: 101 using ProtoIntegerBase<bool>::ProtoIntegerBase; 102 }; 103 104 // An object that represents a parsed `bytes` field or an error code. The 105 // bytes are available via an stream::IntervalReader by GetBytesReader(). 
106 // 107 // For normal uses, the class should be created from `class Message`. See 108 // comment for `class Message` for usage. 109 class Bytes { 110 public: 111 Bytes() = default; Bytes(Status status)112 Bytes(Status status) : reader_(status) {} Bytes(stream::IntervalReader reader)113 Bytes(stream::IntervalReader reader) : reader_(reader) {} GetBytesReader()114 stream::IntervalReader GetBytesReader() { return reader_; } 115 ok()116 bool ok() { return reader_.ok(); } status()117 Status status() { return reader_.status(); } 118 119 // Check whether the bytes value equals the given `bytes`. 120 Result<bool> Equal(ConstByteSpan bytes); 121 122 private: 123 stream::IntervalReader reader_; 124 }; 125 126 // An object that represents a parsed `string` field or an error code. The 127 // string value is available via an stream::IntervalReader by 128 // GetBytesReader(). 129 // 130 // For normal uses, the class should be created from `class Message`. See 131 // comment for `class Message` for usage. 132 class String : public Bytes { 133 public: 134 using Bytes::Bytes; 135 136 // Check whether the string value equals the given `str` 137 Result<bool> Equal(std::string_view str); 138 }; 139 140 // Forward declaration of parser classes. 
141 template <typename FieldType> 142 class RepeatedFieldParser; 143 template <typename FieldType> 144 class StringMapEntryParser; 145 template <typename FieldType> 146 class StringMapParser; 147 class Message; 148 149 using RepeatedBytes = RepeatedFieldParser<Bytes>; 150 using RepeatedStrings = RepeatedFieldParser<String>; 151 using RepeatedMessages = RepeatedFieldParser<Message>; 152 using StringToBytesMapEntry = StringMapEntryParser<Bytes>; 153 using StringToStringMapEntry = StringMapEntryParser<String>; 154 using StringToMessageMapEntry = StringMapEntryParser<Message>; 155 using StringToBytesMap = StringMapParser<Bytes>; 156 using StringToStringMap = StringMapParser<String>; 157 using StringToMessageMap = StringMapParser<Message>; 158 159 // Message - A helper class for parsing a proto message. 160 // 161 // Examples: 162 // 163 // message Nested { 164 // string nested_str = 1; 165 // bytes nested_bytes = 2; 166 // } 167 // 168 // message { 169 // string str = 1; 170 // bytes bytes = 2; 171 // uint32 integer = 3 172 // repeated string rep_str = 4; 173 // map<string, bytes> str_to_bytes = 5; 174 // Nested nested = 6; 175 // } 176 // 177 // // Given a seekable `reader` that reads the top-level proto message, and 178 // // a <size> that gives the size of the proto message: 179 // Message message(reader, <size>); 180 // 181 // // Prase simple basic value fields 182 // String str = message.AsString(1); // string 183 // Bytes bytes = message.AsBytes(2); // bytes 184 // Uint32 integer = messasge_parser.AsUint32(3); // uint32 integer 185 // 186 // // Parse repeated field `repeated string rep_str = 4;` 187 // RepeatedStrings rep_str = message.AsRepeatedString(4); 188 // // Iterate through the entries. If proto is malformed when 189 // // iterating, the next element (`str` in this case) will be invalid 190 // // and loop will end in the iteration after. 
// for (String str : rep_str) {
//   // Check status
//   if (!str.ok()) {
//     // In the case of error, the loop will end in the next iteration if
//     // it continues. This is the chance for code to catch the error.
//     ...
//   }
//   ...
// }
//
// // Parse map field `map<string, bytes> str_to_bytes = 5;`
// StringToBytesMap str_to_bytes = message.AsStringToBytesMap(5);
//
// // Access the entry by a given key value
// Bytes bytes_for_key = str_to_bytes["key"];
//
// // Or iterate through map entries
// for (StringToBytesMapEntry entry : str_to_bytes) {
//   if (!entry.ok()) {
//     // In the case of error, the loop will end in the next iteration if
//     // it continues. This is the chance for code to catch the error.
//     ...
//   }
//   String key = entry.Key();
//   Bytes value = entry.Value();
//   ...
// }
//
// // Parse nested message `Nested nested = 6;`
// Message nested = message.AsMessage(6);
// String nested_str = nested.AsString(1);
// Bytes nested_bytes = nested.AsBytes(2);
//
// // The `AsXXX()` methods above internally traverse all the fields to find
// // the one with the given field number. This can be expensive if called
// // multiple times. Therefore, whenever possible, it is recommended to use
// // the following iteration to iterate and process each field directly.
// for (Message::Field field : message) {
//   if (!field.ok()) {
//     // In the case of error, the loop will end in the next iteration if
//     // it continues. This is the chance for code to catch the error.
//     ...
//   }
//   if (field.field_number() == 1) {
//     String str = field.As<String>();
//     ...
//   } else if (field.field_number() == 2) {
//     Bytes bytes = field.As<Bytes>();
//     ...
//   } else if (field.field_number() == 6) {
//     Message nested = field.As<Message>();
//     ...
243 // } 244 // } 245 // 246 // All parser objects created above internally hold the same reference 247 // to `reader`. Therefore it needs to maintain valid lifespan throughout the 248 // operations. The parser objects can work independently and without blocking 249 // each other. All method calls and for-iterations above are re-enterable. 250 class Message { 251 public: 252 class Field { 253 public: field_number()254 uint32_t field_number() { return field_number_; } field_reader()255 const stream::IntervalReader& field_reader() { return field_reader_; } ok()256 bool ok() { return field_reader_.ok(); } status()257 Status status() { return field_reader_.status(); } 258 259 // Create a helper parser type of `FieldType` for the field. 260 // The default implementation below assumes the field is a length-delimited 261 // field. Other cases such as primitive integer uint32 will be handled by 262 // template specialization. 263 template <typename FieldType> As()264 FieldType As() { 265 if (!field_reader_.ok()) { 266 return FieldType(field_reader_.status()); 267 } 268 269 StreamDecoder decoder(field_reader_.Reset()); 270 PW_TRY(decoder.Next()); 271 Result<StreamDecoder::Bounds> payload_bounds = 272 decoder.GetLengthDelimitedPayloadBounds(); 273 PW_TRY(payload_bounds.status()); 274 // The bounds is relative to the given stream::IntervalReader. Convert 275 // it to the interval relative to the source_reader. 
276 return FieldType(stream::IntervalReader( 277 field_reader_.source_reader(), 278 payload_bounds.value().low + field_reader_.start(), 279 payload_bounds.value().high + field_reader_.start())); 280 } 281 282 private: 283 Field() = default; Field(Status status)284 Field(Status status) : field_reader_(status), field_number_(0) {} Field(stream::IntervalReader reader,uint32_t field_number)285 Field(stream::IntervalReader reader, uint32_t field_number) 286 : field_reader_(reader), field_number_(field_number) {} 287 288 stream::IntervalReader field_reader_; 289 uint32_t field_number_; 290 291 friend class Message; 292 }; 293 294 class iterator { 295 public: 296 iterator& operator++(); 297 298 iterator operator++(int) { 299 iterator iter = *this; 300 this->operator++(); 301 return iter; 302 } 303 ok()304 bool ok() { return status_.ok(); } status()305 Status status() { return status_; } 306 Field operator*() { return current_; } 307 Field* operator->() { return ¤t_; } 308 bool operator!=(const iterator& other) const { return !(*this == other); } 309 310 bool operator==(const iterator& other) const { 311 return eof_ == other.eof_ && reader_ == other.reader_; 312 } 313 314 private: 315 stream::IntervalReader reader_; 316 bool eof_ = false; 317 Field current_; 318 Status status_ = OkStatus(); 319 iterator(stream::IntervalReader reader)320 iterator(stream::IntervalReader reader) : reader_(reader) { 321 this->operator++(); 322 } 323 324 friend class Message; 325 }; 326 327 Message() = default; Message(Status status)328 Message(Status status) : reader_(status) {} Message(stream::IntervalReader reader)329 Message(stream::IntervalReader reader) : reader_(reader) {} Message(stream::SeekableReader & proto_source,size_t size)330 Message(stream::SeekableReader& proto_source, size_t size) 331 : reader_(proto_source, 0, size) {} 332 333 // Parse a sub-field in the message given by `field_number` as bytes. 
AsBytes(uint32_t field_number)334 Bytes AsBytes(uint32_t field_number) { return As<Bytes>(field_number); } 335 336 // Parse a sub-field in the message given by `field_number` as string. AsString(uint32_t field_number)337 String AsString(uint32_t field_number) { return As<String>(field_number); } 338 339 // Parse a sub-field in the message given by `field_number` as one of the 340 // proto integer type. AsInt32(uint32_t field_number)341 Int32 AsInt32(uint32_t field_number) { return As<Int32>(field_number); } AsSint32(uint32_t field_number)342 Sint32 AsSint32(uint32_t field_number) { return As<Sint32>(field_number); } AsUint32(uint32_t field_number)343 Uint32 AsUint32(uint32_t field_number) { return As<Uint32>(field_number); } AsFixed32(uint32_t field_number)344 Fixed32 AsFixed32(uint32_t field_number) { return As<Fixed32>(field_number); } AsInt64(uint32_t field_number)345 Int64 AsInt64(uint32_t field_number) { return As<Int64>(field_number); } AsSint64(uint32_t field_number)346 Sint64 AsSint64(uint32_t field_number) { return As<Sint64>(field_number); } AsUint64(uint32_t field_number)347 Uint64 AsUint64(uint32_t field_number) { return As<Uint64>(field_number); } AsFixed64(uint32_t field_number)348 Fixed64 AsFixed64(uint32_t field_number) { return As<Fixed64>(field_number); } 349 AsSfixed32(uint32_t field_number)350 Sfixed32 AsSfixed32(uint32_t field_number) { 351 return As<Sfixed32>(field_number); 352 } 353 AsSfixed64(uint32_t field_number)354 Sfixed64 AsSfixed64(uint32_t field_number) { 355 return As<Sfixed64>(field_number); 356 } 357 AsFloat(uint32_t field_number)358 Float AsFloat(uint32_t field_number) { return As<Float>(field_number); } AsDouble(uint32_t field_number)359 Double AsDouble(uint32_t field_number) { return As<Double>(field_number); } 360 AsBool(uint32_t field_number)361 Bool AsBool(uint32_t field_number) { return As<Bool>(field_number); } 362 363 // Parse a sub-field in the message given by `field_number` as another 364 // message. 
AsMessage(uint32_t field_number)365 Message AsMessage(uint32_t field_number) { return As<Message>(field_number); } 366 367 // Parse a sub-field in the message given by `field_number` as `repeated 368 // string`. 369 RepeatedBytes AsRepeatedBytes(uint32_t field_number); 370 371 // Parse a sub-field in the message given by `field_number` as `repeated 372 // string`. 373 RepeatedStrings AsRepeatedStrings(uint32_t field_number); 374 375 // Parse a sub-field in the message given by `field_number` as `repeated 376 // message`. 377 RepeatedMessages AsRepeatedMessages(uint32_t field_number); 378 379 // Parse a sub-field in the message given by `field_number` as `map<string, 380 // message>`. 381 StringToMessageMap AsStringToMessageMap(uint32_t field_number); 382 383 // Parse a sub-field in the message given by `field_number` as 384 // `map<string, bytes>`. 385 StringToBytesMap AsStringToBytesMap(uint32_t field_number); 386 387 // Parse a sub-field in the message given by `field_number` as 388 // `map<string, string>`. 389 StringToStringMap AsStringToStringMap(uint32_t field_number); 390 391 // Convert the message to a Bytes that represents the raw bytes of this 392 // message. This can be used to obatained the serialized wire-format of the 393 // message. ToBytes()394 Bytes ToBytes() { return Bytes(reader_.Reset()); } 395 ok()396 bool ok() { return reader_.ok(); } status()397 Status status() { return reader_.status(); } 398 399 iterator begin(); 400 iterator end(); 401 402 // Parse a field given by `field_number` as the target parser type 403 // `FieldType`. 404 // 405 // Note: This method assumes that the message has only 1 field with the given 406 // <field_number>. It returns the first matching it find. It does not perform 407 // value overridding or string concatenation for multiple fields with the same 408 // <field_number>. 409 // 410 // Since the method needs to traverse all fields, it can be inefficient if 411 // called multiple times exepcially on slow reader. 
412 template <typename FieldType> As(uint32_t field_number)413 FieldType As(uint32_t field_number) { 414 for (Field field : *this) { 415 if (field.field_number() == field_number) { 416 return field.As<FieldType>(); 417 } 418 } 419 420 return FieldType(Status::NotFound()); 421 } 422 423 template <typename FieldType> AsRepeated(uint32_t field_number)424 RepeatedFieldParser<FieldType> AsRepeated(uint32_t field_number) { 425 return RepeatedFieldParser<FieldType>(*this, field_number); 426 } 427 428 template <typename FieldParser> AsStringMap(uint32_t field_number)429 StringMapParser<FieldParser> AsStringMap(uint32_t field_number) { 430 return StringMapParser<FieldParser>(*this, field_number); 431 } 432 433 private: 434 stream::IntervalReader reader_; 435 436 // Consume the current field. If the field has already been processed, i.e. 437 // by calling one of the Read..() method, nothing is done. After calling this 438 // method, the reader will be pointing either to the start of the next 439 // field (i.e. the starting offset of the field key), or the end of the 440 // stream. The method is for use by Message for computing field interval. ConsumeCurrentField(StreamDecoder & decoder)441 static Status ConsumeCurrentField(StreamDecoder& decoder) { 442 return decoder.field_consumed_ ? OkStatus() : decoder.SkipField(); 443 } 444 }; 445 446 // The following are template specialization for proto integer types. 
447 template <> 448 Uint32 Message::Field::As<Uint32>(); 449 450 template <> 451 Int32 Message::Field::As<Int32>(); 452 453 template <> 454 Sint32 Message::Field::As<Sint32>(); 455 456 template <> 457 Fixed32 Message::Field::As<Fixed32>(); 458 459 template <> 460 Sfixed32 Message::Field::As<Sfixed32>(); 461 462 template <> 463 Uint64 Message::Field::As<Uint64>(); 464 465 template <> 466 Int64 Message::Field::As<Int64>(); 467 468 template <> 469 Sint64 Message::Field::As<Sint64>(); 470 471 template <> 472 Fixed64 Message::Field::As<Fixed64>(); 473 474 template <> 475 Sfixed64 Message::Field::As<Sfixed64>(); 476 477 template <> 478 Float Message::Field::As<Float>(); 479 480 template <> 481 Double Message::Field::As<Double>(); 482 483 template <> 484 Bool Message::Field::As<Bool>(); 485 486 // A helper for parsing `repeated` field. It implements an iterator interface 487 // that only iterates through the fields of a given `field_number`. 488 // 489 // For normal uses, the class should be created from `class Message`. See 490 // comment for `class Message` for usage. 491 template <typename FieldType> 492 class RepeatedFieldParser { 493 public: 494 class iterator { 495 public: 496 // Precondition: iter_ is not pointing to the end. 
497 iterator& operator++() { 498 iter_++; 499 MoveToNext(); 500 return *this; 501 } 502 503 iterator operator++(int) { 504 iterator iter = *this; 505 this->operator++(); 506 return iter; 507 } 508 ok()509 bool ok() { return iter_.ok(); } status()510 Status status() { return iter_.status(); } 511 FieldType operator*() { return current_; } 512 FieldType* operator->() { return ¤t_; } 513 bool operator!=(const iterator& other) const { return !(*this == other); } 514 bool operator==(const iterator& other) const { 515 return &host_ == &other.host_ && iter_ == other.iter_; 516 } 517 518 private: 519 RepeatedFieldParser& host_; 520 Message::iterator iter_; 521 FieldType current_ = FieldType(Status::Unavailable()); 522 iterator(RepeatedFieldParser & host,Message::iterator init_iter)523 iterator(RepeatedFieldParser& host, Message::iterator init_iter) 524 : host_(host), iter_(init_iter), current_(Status::Unavailable()) { 525 // Move to the first element of the target field number. 526 MoveToNext(); 527 } 528 MoveToNext()529 void MoveToNext() { 530 // Move the iterator to the next element with the target field number 531 for (; iter_ != host_.message_.end(); ++iter_) { 532 if (!iter_.ok() || iter_->field_number() == host_.field_number_) { 533 current_ = iter_->As<FieldType>(); 534 break; 535 } 536 } 537 } 538 539 friend class RepeatedFieldParser; 540 }; 541 542 // `message` -- The containing message. 543 // `field_number` -- The field number of the repeated field. 
RepeatedFieldParser(Message & message,uint32_t field_number)544 RepeatedFieldParser(Message& message, uint32_t field_number) 545 : message_(message), field_number_(field_number) {} 546 RepeatedFieldParser(Status status)547 RepeatedFieldParser(Status status) : message_(status) {} 548 ok()549 bool ok() { return message_.ok(); } status()550 Status status() { return message_.status(); } 551 begin()552 iterator begin() { return iterator(*this, message_.begin()); } end()553 iterator end() { return iterator(*this, message_.end()); } 554 555 private: 556 Message message_; 557 uint32_t field_number_ = 0; 558 }; 559 560 // A helper for pasring the entry type of map<string, <value>>. 561 // An entry for a proto map is essentially a message of a key(k=1) and 562 // value(k=2) field, i.e.: 563 // 564 // message Entry { 565 // string key = 1; 566 // bytes value = 2; 567 // } 568 // 569 // For normal uses, the class should be created from `class Message`. See 570 // comment for `class Message` for usage. 571 template <typename ValueParser> 572 class StringMapEntryParser { 573 public: ok()574 bool ok() { return entry_.ok(); } status()575 Status status() { return entry_.status(); } StringMapEntryParser(Status status)576 StringMapEntryParser(Status status) : entry_(status) {} StringMapEntryParser(stream::IntervalReader reader)577 StringMapEntryParser(stream::IntervalReader reader) : entry_(reader) {} Key()578 String Key() { return entry_.AsString(kMapKeyFieldNumber); } Value()579 ValueParser Value() { return entry_.As<ValueParser>(kMapValueFieldNumber); } 580 581 private: 582 static constexpr uint32_t kMapKeyFieldNumber = 1; 583 static constexpr uint32_t kMapValueFieldNumber = 2; 584 Message entry_; 585 }; 586 587 // A helper class for parsing a string-keyed map field. i.e. map<string, 588 // <value>>. The template argument `ValueParser` indicates the type the value 589 // will be parsed as, i.e. String, Bytes, Uint32, Message etc. 
590 // 591 // For normal uses, the class should be created from `class Message`. See 592 // comment for `class Message` for usage. 593 template <typename ValueParser> 594 class StringMapParser 595 : public RepeatedFieldParser<StringMapEntryParser<ValueParser>> { 596 public: 597 using RepeatedFieldParser< 598 StringMapEntryParser<ValueParser>>::RepeatedFieldParser; 599 600 // Operator overload for value access of a given key. 601 ValueParser operator[](std::string_view target) { 602 // Iterate over all entries and find the one whose key matches `target` 603 for (StringMapEntryParser<ValueParser> entry : *this) { 604 String key = entry.Key(); 605 PW_TRY(key.status()); 606 607 // Compare key value with the given string 608 Result<bool> cmp_res = key.Equal(target); 609 PW_TRY(cmp_res.status()); 610 if (cmp_res.value()) { 611 return entry.Value(); 612 } 613 } 614 615 return ValueParser(Status::NotFound()); 616 } 617 }; 618 619 } // namespace pw::protobuf 620