1 // Copyright 2021 The Pigweed Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 // use this file except in compliance with the License. You may obtain a copy of 5 // the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 // License for the specific language governing permissions and limitations under 13 // the License. 14 // 15 // The header provides a set of helper utils for protobuf related operations. 16 // The APIs may not be finalized yet. 17 18 #pragma once 19 20 #include <cstddef> 21 #include <string_view> 22 23 #include "pw_assert/check.h" 24 #include "pw_protobuf/internal/proto_integer_base.h" 25 #include "pw_protobuf/stream_decoder.h" 26 #include "pw_status/status.h" 27 #include "pw_status/try.h" 28 #include "pw_stream/interval_reader.h" 29 #include "pw_stream/stream.h" 30 31 namespace pw::protobuf { 32 33 // The following defines classes that represent various parsed proto integer 34 // types or an error code to indicate parsing failure. 35 // 36 // For normal uses, the class should be created from `class Message`. See 37 // comment for `class Message` for usage. 38 39 class Uint32 : public internal::ProtoIntegerBase<uint32_t> { 40 public: 41 using ProtoIntegerBase<uint32_t>::ProtoIntegerBase; 42 }; 43 44 class Int32 : public internal::ProtoIntegerBase<int32_t> { 45 public: 46 using ProtoIntegerBase<int32_t>::ProtoIntegerBase; 47 }; 48 49 class Sint32 : public internal::ProtoIntegerBase<int32_t> { 50 public: 51 using ProtoIntegerBase<int32_t>::ProtoIntegerBase; 52 }; 53 54 class Fixed32 : public internal::ProtoIntegerBase<uint32_t> { 55 public: 56 using ProtoIntegerBase<uint32_t>::ProtoIntegerBase; 57 }; 58 59 class Sfixed32 : public internal::ProtoIntegerBase<int32_t> { 60 public: 61 using ProtoIntegerBase<int32_t>::ProtoIntegerBase; 62 }; 63 64 class Uint64 : public internal::ProtoIntegerBase<uint64_t> { 65 public: 66 using ProtoIntegerBase<uint64_t>::ProtoIntegerBase; 67 }; 68 69 class Int64 : public internal::ProtoIntegerBase<int64_t> { 70 public: 71 using ProtoIntegerBase<int64_t>::ProtoIntegerBase; 72 }; 73 74 class Sint64 : public internal::ProtoIntegerBase<int64_t> { 75 public: 76 using ProtoIntegerBase<int64_t>::ProtoIntegerBase; 77 }; 78 79 class Fixed64 : public internal::ProtoIntegerBase<uint64_t> { 80 public: 81 using ProtoIntegerBase<uint64_t>::ProtoIntegerBase; 82 }; 83 84 class Sfixed64 : public internal::ProtoIntegerBase<int64_t> { 85 public: 86 using ProtoIntegerBase<int64_t>::ProtoIntegerBase; 87 }; 88 89 class Float : public internal::ProtoIntegerBase<float> { 90 public: 91 using ProtoIntegerBase<float>::ProtoIntegerBase; 92 }; 93 94 class Double : public internal::ProtoIntegerBase<double> { 95 public: 96 using ProtoIntegerBase<double>::ProtoIntegerBase; 97 }; 98 99 class Bool : public internal::ProtoIntegerBase<bool> { 100 public: 101 using ProtoIntegerBase<bool>::ProtoIntegerBase; 102 }; 103 104 // An object that represents a parsed `bytes` field or an error code. The 105 // bytes are available via an stream::IntervalReader by GetBytesReader(). 106 // 107 // For normal uses, the class should be created from `class Message`. See 108 // comment for `class Message` for usage. 109 class Bytes { 110 public: 111 Bytes() = default; Bytes(Status status)112 Bytes(Status status) : reader_(status) {} Bytes(stream::IntervalReader reader)113 Bytes(stream::IntervalReader reader) : reader_(reader) {} GetBytesReader()114 stream::IntervalReader GetBytesReader() { return reader_; } 115 116 // TODO(pwbug/363): Migrate this to Result<> once we have StatusOr like 117 // support. ok()118 bool ok() { return reader_.ok(); } status()119 Status status() { return reader_.status(); } 120 121 // Check whether the bytes value equals the given `bytes`. 122 // TODO(pwbug/456): Should this return `bool`? In the case of error, is it ok 123 // to just return false? 124 Result<bool> Equal(ConstByteSpan bytes); 125 126 private: 127 stream::IntervalReader reader_; 128 }; 129 130 // An object that represents a parsed `string` field or an error code. The 131 // string value is available via an stream::IntervalReader by 132 // GetBytesReader(). 133 // 134 // For normal uses, the class should be created from `class Message`. See 135 // comment for `class Message` for usage. 136 class String : public Bytes { 137 public: 138 using Bytes::Bytes; 139 140 // Check whether the string value equals the given `str` 141 Result<bool> Equal(std::string_view str); 142 }; 143 144 // Forward declaration of parser classes. 145 template <typename FieldType> 146 class RepeatedFieldParser; 147 template <typename FieldType> 148 class StringMapEntryParser; 149 template <typename FieldType> 150 class StringMapParser; 151 class Message; 152 153 using RepeatedBytes = RepeatedFieldParser<Bytes>; 154 using RepeatedStrings = RepeatedFieldParser<String>; 155 using RepeatedMessages = RepeatedFieldParser<Message>; 156 using StringToBytesMapEntry = StringMapEntryParser<Bytes>; 157 using StringToStringMapEntry = StringMapEntryParser<String>; 158 using StringToMessageMapEntry = StringMapEntryParser<Message>; 159 using StringToBytesMap = StringMapParser<Bytes>; 160 using StringToStringMap = StringMapParser<String>; 161 using StringToMessageMap = StringMapParser<Message>; 162 163 // Message - A helper class for parsing a proto message. 164 // 165 // Examples: 166 // 167 // message Nested { 168 // string nested_str = 1; 169 // bytes nested_bytes = 2; 170 // } 171 // 172 // message { 173 // string str = 1; 174 // bytes bytes = 2; 175 // uint32 integer = 3 176 // repeated string rep_str = 4; 177 // map<string, bytes> str_to_bytes = 5; 178 // Nested nested = 6; 179 // } 180 // 181 // // Given a seekable `reader` that reads the top-level proto message, and 182 // // a <size> that gives the size of the proto message: 183 // Message message(reader, <size>); 184 // 185 // // Prase simple basic value fields 186 // String str = message.AsString(1); // string 187 // Bytes bytes = message.AsBytes(2); // bytes 188 // Uint32 integer = messasge_parser.AsUint32(3); // uint32 integer 189 // 190 // // Parse repeated field `repeated string rep_str = 4;` 191 // RepeatedStrings rep_str = message.AsRepeatedString(4); 192 // // Iterate through the entries. If proto is malformed when 193 // // iterating, the next element (`str` in this case) will be invalid 194 // // and loop will end in the iteration after. 195 // for (String str : rep_str) { 196 // // Check status 197 // if (!str.ok()) { 198 // // In the case of error, loop will end in the next iteration if 199 // // continues. This is the chance for code to catch the error. 200 // ... 201 // } 202 // ... 203 // } 204 // 205 // // Parse map field `map<string, bytes> str_to_bytes = 5;` 206 // StringToBytesMap str_to_bytes = message.AsStringToBytesMap(5); 207 // 208 // // Access the entry by a given key value 209 // Bytes bytes_for_key = str_to_bytes["key"]; 210 // 211 // // Or iterate through map entries 212 // for (StringToBytesMapEntry entry : str_to_bytes) { 213 // if (!entry.ok()) { 214 // // In the case of error, loop will end in the next iteration if 215 // // continues. This is the chance for code to catch the error. 216 // ... 217 // } 218 // String key = entry.Key(); 219 // Bytes value = entry.Value(); 220 // ... 221 // } 222 // 223 // // Parse nested message `Nested nested = 6;` 224 // Message nested = message.AsMessage(6). 225 // String nested_str = nested.AsString(1); 226 // Bytes nested_bytes = nested.AsBytes(2); 227 // 228 // // The `AsXXX()` methods above internally traverse all the fields to find 229 // // the one with the give field number. This can be expensive if called 230 // // multiple times. Therefore, whenever possible, it is recommended to use 231 // // the following iteration to iterate and process each field directly. 232 // for (Message::Field field : message) { 233 // if (!field.ok()) { 234 // // In the case of error, loop will end in the next iteration if 235 // // continues. This is the chance for code to catch the error. 236 // ... 237 // } 238 // if (field.field_number() == 1) { 239 // String str = field.As<String>(); 240 // ... 241 // } else if (field.field_number() == 2) { 242 // Bytes bytes = field.As<Bytes>(); 243 // ... 244 // } else if (field.field_number() == 6) { 245 // Message nested = field.As<Message>(); 246 // ... 247 // } 248 // } 249 // 250 // All parser objects created above internally hold the same reference 251 // to `reader`. Therefore it needs to maintain valid lifespan throughout the 252 // operations. The parser objects can work independently and without blocking 253 // each other. All method calls and for-iterations above are re-enterable. 254 class Message { 255 public: 256 class Field { 257 public: field_number()258 uint32_t field_number() { return field_number_; } field_reader()259 const stream::IntervalReader& field_reader() { return field_reader_; } ok()260 bool ok() { return field_reader_.ok(); } status()261 Status status() { return field_reader_.status(); } 262 263 // Create a helper parser type of `FieldType` for the field. 264 // The default implementation below assumes the field is a length-delimited 265 // field. Other cases such as primitive integer uint32 will be handled by 266 // template specialization. 267 template <typename FieldType> As()268 FieldType As() { 269 if (!field_reader_.ok()) { 270 return FieldType(field_reader_.status()); 271 } 272 273 StreamDecoder decoder(field_reader_.Reset()); 274 PW_TRY(decoder.Next()); 275 Result<StreamDecoder::Bounds> payload_bounds = 276 decoder.GetLengthDelimitedPayloadBounds(); 277 PW_TRY(payload_bounds.status()); 278 // The bounds is relative to the given stream::IntervalReader. Convert 279 // it to the interval relative to the source_reader. 280 return FieldType(stream::IntervalReader( 281 field_reader_.source_reader(), 282 payload_bounds.value().low + field_reader_.start(), 283 payload_bounds.value().high + field_reader_.start())); 284 } 285 286 private: 287 Field() = default; Field(Status status)288 Field(Status status) : field_reader_(status), field_number_(0) {} Field(stream::IntervalReader reader,uint32_t field_number)289 Field(stream::IntervalReader reader, uint32_t field_number) 290 : field_reader_(reader), field_number_(field_number) {} 291 292 stream::IntervalReader field_reader_; 293 uint32_t field_number_; 294 295 friend class Message; 296 }; 297 298 class iterator { 299 public: 300 iterator& operator++(); 301 302 iterator operator++(int) { 303 iterator iter = *this; 304 this->operator++(); 305 return iter; 306 } 307 ok()308 bool ok() { return status_.ok(); } status()309 Status status() { return status_; } 310 Field operator*() { return current_; } 311 Field* operator->() { return ¤t_; } 312 bool operator!=(const iterator& other) const { return !(*this == other); } 313 314 bool operator==(const iterator& other) const { 315 return eof_ == other.eof_ && reader_ == other.reader_; 316 } 317 318 private: 319 stream::IntervalReader reader_; 320 bool eof_ = false; 321 Field current_; 322 Status status_ = OkStatus(); 323 iterator(stream::IntervalReader reader)324 iterator(stream::IntervalReader reader) : reader_(reader) { 325 this->operator++(); 326 } 327 328 friend class Message; 329 }; 330 331 Message() = default; Message(Status status)332 Message(Status status) : reader_(status) {} Message(stream::IntervalReader reader)333 Message(stream::IntervalReader reader) : reader_(reader) {} Message(stream::SeekableReader & proto_source,size_t size)334 Message(stream::SeekableReader& proto_source, size_t size) 335 : reader_(proto_source, 0, size) {} 336 337 // Parse a sub-field in the message given by `field_number` as bytes. AsBytes(uint32_t field_number)338 Bytes AsBytes(uint32_t field_number) { return As<Bytes>(field_number); } 339 340 // Parse a sub-field in the message given by `field_number` as string. AsString(uint32_t field_number)341 String AsString(uint32_t field_number) { return As<String>(field_number); } 342 343 // Parse a sub-field in the message given by `field_number` as one of the 344 // proto integer type. AsInt32(uint32_t field_number)345 Int32 AsInt32(uint32_t field_number) { return As<Int32>(field_number); } AsSint32(uint32_t field_number)346 Sint32 AsSint32(uint32_t field_number) { return As<Sint32>(field_number); } AsUint32(uint32_t field_number)347 Uint32 AsUint32(uint32_t field_number) { return As<Uint32>(field_number); } AsFixed32(uint32_t field_number)348 Fixed32 AsFixed32(uint32_t field_number) { return As<Fixed32>(field_number); } AsInt64(uint32_t field_number)349 Int64 AsInt64(uint32_t field_number) { return As<Int64>(field_number); } AsSint64(uint32_t field_number)350 Sint64 AsSint64(uint32_t field_number) { return As<Sint64>(field_number); } AsUint64(uint32_t field_number)351 Uint64 AsUint64(uint32_t field_number) { return As<Uint64>(field_number); } AsFixed64(uint32_t field_number)352 Fixed64 AsFixed64(uint32_t field_number) { return As<Fixed64>(field_number); } 353 AsSfixed32(uint32_t field_number)354 Sfixed32 AsSfixed32(uint32_t field_number) { 355 return As<Sfixed32>(field_number); 356 } 357 AsSfixed64(uint32_t field_number)358 Sfixed64 AsSfixed64(uint32_t field_number) { 359 return As<Sfixed64>(field_number); 360 } 361 AsFloat(uint32_t field_number)362 Float AsFloat(uint32_t field_number) { return As<Float>(field_number); } AsDouble(uint32_t field_number)363 Double AsDouble(uint32_t field_number) { return As<Double>(field_number); } 364 AsBool(uint32_t field_number)365 Bool AsBool(uint32_t field_number) { return As<Bool>(field_number); } 366 367 // Parse a sub-field in the message given by `field_number` as another 368 // message. AsMessage(uint32_t field_number)369 Message AsMessage(uint32_t field_number) { return As<Message>(field_number); } 370 371 // Parse a sub-field in the message given by `field_number` as `repeated 372 // string`. 373 RepeatedBytes AsRepeatedBytes(uint32_t field_number); 374 375 // Parse a sub-field in the message given by `field_number` as `repeated 376 // string`. 377 RepeatedStrings AsRepeatedStrings(uint32_t field_number); 378 379 // Parse a sub-field in the message given by `field_number` as `repeated 380 // message`. 381 RepeatedMessages AsRepeatedMessages(uint32_t field_number); 382 383 // Parse a sub-field in the message given by `field_number` as `map<string, 384 // message>`. 385 StringToMessageMap AsStringToMessageMap(uint32_t field_number); 386 387 // Parse a sub-field in the message given by `field_number` as 388 // `map<string, bytes>`. 389 StringToBytesMap AsStringToBytesMap(uint32_t field_number); 390 391 // Parse a sub-field in the message given by `field_number` as 392 // `map<string, string>`. 393 StringToStringMap AsStringToStringMap(uint32_t field_number); 394 395 // Convert the message to a Bytes that represents the raw bytes of this 396 // message. This can be used to obatained the serialized wire-format of the 397 // message. ToBytes()398 Bytes ToBytes() { return Bytes(reader_.Reset()); } 399 400 // TODO(pwbug/363): Migrate this to Result<> once we have StatusOr like 401 // support. ok()402 bool ok() { return reader_.ok(); } status()403 Status status() { return reader_.status(); } 404 405 iterator begin(); 406 iterator end(); 407 408 // Parse a field given by `field_number` as the target parser type 409 // `FieldType`. 410 // 411 // Note: This method assumes that the message has only 1 field with the given 412 // <field_number>. It returns the first matching it find. It does not perform 413 // value overridding or string concatenation for multiple fields with the same 414 // <field_number>. 415 // 416 // Since the method needs to traverse all fields, it can be inefficient if 417 // called multiple times exepcially on slow reader. 418 template <typename FieldType> As(uint32_t field_number)419 FieldType As(uint32_t field_number) { 420 for (Field field : *this) { 421 if (field.field_number() == field_number) { 422 return field.As<FieldType>(); 423 } 424 } 425 426 return FieldType(Status::NotFound()); 427 } 428 429 template <typename FieldType> AsRepeated(uint32_t field_number)430 RepeatedFieldParser<FieldType> AsRepeated(uint32_t field_number) { 431 return RepeatedFieldParser<FieldType>(*this, field_number); 432 } 433 434 template <typename FieldParser> AsStringMap(uint32_t field_number)435 StringMapParser<FieldParser> AsStringMap(uint32_t field_number) { 436 return StringMapParser<FieldParser>(*this, field_number); 437 } 438 439 private: 440 stream::IntervalReader reader_; 441 442 // Consume the current field. If the field has already been processed, i.e. 443 // by calling one of the Read..() method, nothing is done. After calling this 444 // method, the reader will be pointing either to the start of the next 445 // field (i.e. the starting offset of the field key), or the end of the 446 // stream. The method is for use by Message for computing field interval. ConsumeCurrentField(StreamDecoder & decoder)447 static Status ConsumeCurrentField(StreamDecoder& decoder) { 448 return decoder.field_consumed_ ? OkStatus() : decoder.SkipField(); 449 } 450 }; 451 452 // The following are template specialization for proto integer types. 453 template <> 454 Uint32 Message::Field::As<Uint32>(); 455 456 template <> 457 Int32 Message::Field::As<Int32>(); 458 459 template <> 460 Sint32 Message::Field::As<Sint32>(); 461 462 template <> 463 Fixed32 Message::Field::As<Fixed32>(); 464 465 template <> 466 Sfixed32 Message::Field::As<Sfixed32>(); 467 468 template <> 469 Uint64 Message::Field::As<Uint64>(); 470 471 template <> 472 Int64 Message::Field::As<Int64>(); 473 474 template <> 475 Sint64 Message::Field::As<Sint64>(); 476 477 template <> 478 Fixed64 Message::Field::As<Fixed64>(); 479 480 template <> 481 Sfixed64 Message::Field::As<Sfixed64>(); 482 483 template <> 484 Float Message::Field::As<Float>(); 485 486 template <> 487 Double Message::Field::As<Double>(); 488 489 template <> 490 Bool Message::Field::As<Bool>(); 491 492 // A helper for parsing `repeated` field. It implements an iterator interface 493 // that only iterates through the fields of a given `field_number`. 494 // 495 // For normal uses, the class should be created from `class Message`. See 496 // comment for `class Message` for usage. 497 template <typename FieldType> 498 class RepeatedFieldParser { 499 public: 500 class iterator { 501 public: 502 // Precondition: iter_ is not pointing to the end. 503 iterator& operator++() { 504 iter_++; 505 MoveToNext(); 506 return *this; 507 } 508 509 iterator operator++(int) { 510 iterator iter = *this; 511 this->operator++(); 512 return iter; 513 } 514 ok()515 bool ok() { return iter_.ok(); } status()516 Status status() { return iter_.status(); } 517 FieldType operator*() { return current_; } 518 FieldType* operator->() { return ¤t_; } 519 bool operator!=(const iterator& other) const { return !(*this == other); } 520 bool operator==(const iterator& other) const { 521 return &host_ == &other.host_ && iter_ == other.iter_; 522 } 523 524 private: 525 RepeatedFieldParser& host_; 526 Message::iterator iter_; 527 FieldType current_ = FieldType(Status::Unavailable()); 528 iterator(RepeatedFieldParser & host,Message::iterator init_iter)529 iterator(RepeatedFieldParser& host, Message::iterator init_iter) 530 : host_(host), iter_(init_iter), current_(Status::Unavailable()) { 531 // Move to the first element of the target field number. 532 MoveToNext(); 533 } 534 MoveToNext()535 void MoveToNext() { 536 // Move the iterator to the next element with the target field number 537 for (; iter_ != host_.message_.end(); ++iter_) { 538 if (!iter_.ok() || iter_->field_number() == host_.field_number_) { 539 current_ = iter_->As<FieldType>(); 540 break; 541 } 542 } 543 } 544 545 friend class RepeatedFieldParser; 546 }; 547 548 // `message` -- The containing message. 549 // `field_number` -- The field number of the repeated field. RepeatedFieldParser(Message & message,uint32_t field_number)550 RepeatedFieldParser(Message& message, uint32_t field_number) 551 : message_(message), field_number_(field_number) {} 552 RepeatedFieldParser(Status status)553 RepeatedFieldParser(Status status) : message_(status) {} 554 555 // TODO(pwbug/363): Migrate this to Result<> once we have StatusOr like 556 // support. ok()557 bool ok() { return message_.ok(); } status()558 Status status() { return message_.status(); } 559 begin()560 iterator begin() { return iterator(*this, message_.begin()); } end()561 iterator end() { return iterator(*this, message_.end()); } 562 563 private: 564 Message message_; 565 uint32_t field_number_ = 0; 566 }; 567 568 // A helper for pasring the entry type of map<string, <value>>. 569 // An entry for a proto map is essentially a message of a key(k=1) and 570 // value(k=2) field, i.e.: 571 // 572 // message Entry { 573 // string key = 1; 574 // bytes value = 2; 575 // } 576 // 577 // For normal uses, the class should be created from `class Message`. See 578 // comment for `class Message` for usage. 579 template <typename ValueParser> 580 class StringMapEntryParser { 581 public: ok()582 bool ok() { return entry_.ok(); } status()583 Status status() { return entry_.status(); } StringMapEntryParser(Status status)584 StringMapEntryParser(Status status) : entry_(status) {} StringMapEntryParser(stream::IntervalReader reader)585 StringMapEntryParser(stream::IntervalReader reader) : entry_(reader) {} Key()586 String Key() { return entry_.AsString(kMapKeyFieldNumber); } Value()587 ValueParser Value() { return entry_.As<ValueParser>(kMapValueFieldNumber); } 588 589 private: 590 static constexpr uint32_t kMapKeyFieldNumber = 1; 591 static constexpr uint32_t kMapValueFieldNumber = 2; 592 Message entry_; 593 }; 594 595 // A helper class for parsing a string-keyed map field. i.e. map<string, 596 // <value>>. The template argument `ValueParser` indicates the type the value 597 // will be parsed as, i.e. String, Bytes, Uint32, Message etc. 598 // 599 // For normal uses, the class should be created from `class Message`. See 600 // comment for `class Message` for usage. 601 template <typename ValueParser> 602 class StringMapParser 603 : public RepeatedFieldParser<StringMapEntryParser<ValueParser>> { 604 public: 605 using RepeatedFieldParser< 606 StringMapEntryParser<ValueParser>>::RepeatedFieldParser; 607 608 // Operator overload for value access of a given key. 609 ValueParser operator[](std::string_view target) { 610 // Iterate over all entries and find the one whose key matches `target` 611 for (StringMapEntryParser<ValueParser> entry : *this) { 612 String key = entry.Key(); 613 PW_TRY(key.status()); 614 615 // Compare key value with the given string 616 Result<bool> cmp_res = key.Equal(target); 617 PW_TRY(cmp_res.status()); 618 if (cmp_res.value()) { 619 return entry.Value(); 620 } 621 } 622 623 return ValueParser(Status::NotFound()); 624 } 625 }; 626 627 } // namespace pw::protobuf 628