1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "cppbor.h"

#include <inttypes.h>
#include <openssl/sha.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <limits>
#include <string>
#include <string_view>
#include <type_traits>
#include <vector>

#include "cppbor_parse.h"

using std::string;
using std::vector;
27
28 #ifndef __TRUSTY__
29 #include <android-base/logging.h>
30 #define LOG_TAG "CppBor"
31 #else
32 #define CHECK(x) (void)(x)
33 #endif
34
35 namespace cppbor {
36
37 namespace {
38
// Writes `value` as sizeof(T) bytes, most significant byte first, through the output iterator
// `pos`, and returns the iterator positioned just past the last byte written.
//
// The default template argument is spelled with enable_if_t (i.e. enable_if<...>::type) so the
// "unsigned only" restriction really takes effect: std::enable_if<false> is itself a valid
// type, so naming it without ::type never removed signed types from overload resolution.
template <typename T, typename Iterator, typename = std::enable_if_t<std::is_unsigned<T>::value>>
Iterator writeBigEndian(T value, Iterator pos) {
    for (unsigned i = 0; i < sizeof(value); ++i) {
        // Emit the current top byte, then shift the next byte up into the top position.
        *pos++ = static_cast<uint8_t>(value >> (8 * (sizeof(value) - 1)));
        value = static_cast<T>(value << 8);
    }
    return pos;
}

// Same as the iterator overload, but hands each byte (most significant first) to `cb`.
template <typename T, typename = std::enable_if_t<std::is_unsigned<T>::value>>
void writeBigEndian(T value, std::function<void(uint8_t)>& cb) {
    for (unsigned i = 0; i < sizeof(value); ++i) {
        cb(static_cast<uint8_t>(value >> (8 * (sizeof(value) - 1))));
        value = static_cast<T>(value << 8);
    }
}
55
cborAreAllElementsNonCompound(const Item * compoundItem)56 bool cborAreAllElementsNonCompound(const Item* compoundItem) {
57 if (compoundItem->type() == ARRAY) {
58 const Array* array = compoundItem->asArray();
59 for (size_t n = 0; n < array->size(); n++) {
60 const Item* entry = (*array)[n].get();
61 switch (entry->type()) {
62 case ARRAY:
63 case MAP:
64 return false;
65 default:
66 break;
67 }
68 }
69 } else {
70 const Map* map = compoundItem->asMap();
71 for (auto& [keyEntry, valueEntry] : *map) {
72 switch (keyEntry->type()) {
73 case ARRAY:
74 case MAP:
75 return false;
76 default:
77 break;
78 }
79 switch (valueEntry->type()) {
80 case ARRAY:
81 case MAP:
82 return false;
83 default:
84 break;
85 }
86 }
87 }
88 return true;
89 }
90
prettyPrintInternal(const Item * item,string & out,size_t indent,size_t maxBStrSize,const vector<string> & mapKeysToNotPrint)91 bool prettyPrintInternal(const Item* item, string& out, size_t indent, size_t maxBStrSize,
92 const vector<string>& mapKeysToNotPrint) {
93 if (!item) {
94 out.append("<NULL>");
95 return false;
96 }
97
98 char buf[80];
99
100 string indentString(indent, ' ');
101
102 size_t tagCount = item->semanticTagCount();
103 while (tagCount > 0) {
104 --tagCount;
105 snprintf(buf, sizeof(buf), "tag %" PRIu64 " ", item->semanticTag(tagCount));
106 out.append(buf);
107 }
108
109 switch (item->type()) {
110 case SEMANTIC:
111 // Handled above.
112 break;
113
114 case UINT:
115 snprintf(buf, sizeof(buf), "%" PRIu64, item->asUint()->unsignedValue());
116 out.append(buf);
117 break;
118
119 case NINT:
120 snprintf(buf, sizeof(buf), "%" PRId64, item->asNint()->value());
121 out.append(buf);
122 break;
123
124 case BSTR: {
125 const uint8_t* valueData;
126 size_t valueSize;
127 const Bstr* bstr = item->asBstr();
128 if (bstr != nullptr) {
129 const vector<uint8_t>& value = bstr->value();
130 valueData = value.data();
131 valueSize = value.size();
132 } else {
133 const ViewBstr* viewBstr = item->asViewBstr();
134 assert(viewBstr != nullptr);
135
136 valueData = viewBstr->view().data();
137 valueSize = viewBstr->view().size();
138 }
139
140 if (valueSize > maxBStrSize) {
141 unsigned char digest[SHA_DIGEST_LENGTH];
142 SHA_CTX ctx;
143 SHA1_Init(&ctx);
144 SHA1_Update(&ctx, valueData, valueSize);
145 SHA1_Final(digest, &ctx);
146 char buf2[SHA_DIGEST_LENGTH * 2 + 1];
147 for (size_t n = 0; n < SHA_DIGEST_LENGTH; n++) {
148 snprintf(buf2 + n * 2, 3, "%02x", digest[n]);
149 }
150 snprintf(buf, sizeof(buf), "<bstr size=%zd sha1=%s>", valueSize, buf2);
151 out.append(buf);
152 } else {
153 out.append("{");
154 for (size_t n = 0; n < valueSize; n++) {
155 if (n > 0) {
156 out.append(", ");
157 }
158 snprintf(buf, sizeof(buf), "0x%02x", valueData[n]);
159 out.append(buf);
160 }
161 out.append("}");
162 }
163 } break;
164
165 case TSTR:
166 out.append("'");
167 {
168 // TODO: escape "'" characters
169 if (item->asTstr() != nullptr) {
170 out.append(item->asTstr()->value().c_str());
171 } else {
172 const ViewTstr* viewTstr = item->asViewTstr();
173 assert(viewTstr != nullptr);
174 out.append(viewTstr->view());
175 }
176 }
177 out.append("'");
178 break;
179
180 case ARRAY: {
181 const Array* array = item->asArray();
182 if (array->size() == 0) {
183 out.append("[]");
184 } else if (cborAreAllElementsNonCompound(array)) {
185 out.append("[");
186 for (size_t n = 0; n < array->size(); n++) {
187 if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize,
188 mapKeysToNotPrint)) {
189 return false;
190 }
191 out.append(", ");
192 }
193 out.append("]");
194 } else {
195 out.append("[\n" + indentString);
196 for (size_t n = 0; n < array->size(); n++) {
197 out.append(" ");
198 if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize,
199 mapKeysToNotPrint)) {
200 return false;
201 }
202 out.append(",\n" + indentString);
203 }
204 out.append("]");
205 }
206 } break;
207
208 case MAP: {
209 const Map* map = item->asMap();
210
211 if (map->size() == 0) {
212 out.append("{}");
213 } else {
214 out.append("{\n" + indentString);
215 for (auto& [map_key, map_value] : *map) {
216 out.append(" ");
217
218 if (!prettyPrintInternal(map_key.get(), out, indent + 2, maxBStrSize,
219 mapKeysToNotPrint)) {
220 return false;
221 }
222 out.append(" : ");
223 if (map_key->type() == TSTR &&
224 std::find(mapKeysToNotPrint.begin(), mapKeysToNotPrint.end(),
225 map_key->asTstr()->value()) != mapKeysToNotPrint.end()) {
226 out.append("<not printed>");
227 } else {
228 if (!prettyPrintInternal(map_value.get(), out, indent + 2, maxBStrSize,
229 mapKeysToNotPrint)) {
230 return false;
231 }
232 }
233 out.append(",\n" + indentString);
234 }
235 out.append("}");
236 }
237 } break;
238
239 case SIMPLE:
240 const Bool* asBool = item->asSimple()->asBool();
241 const Null* asNull = item->asSimple()->asNull();
242 if (asBool != nullptr) {
243 out.append(asBool->value() ? "true" : "false");
244 } else if (asNull != nullptr) {
245 out.append("null");
246 } else {
247 #ifndef __TRUSTY__
248 LOG(ERROR) << "Only boolean/null is implemented for SIMPLE";
249 #endif // __TRUSTY__
250 return false;
251 }
252 break;
253 }
254
255 return true;
256 }
257
258 } // namespace
259
headerSize(uint64_t addlInfo)260 size_t headerSize(uint64_t addlInfo) {
261 if (addlInfo < ONE_BYTE_LENGTH) return 1;
262 if (addlInfo <= std::numeric_limits<uint8_t>::max()) return 2;
263 if (addlInfo <= std::numeric_limits<uint16_t>::max()) return 3;
264 if (addlInfo <= std::numeric_limits<uint32_t>::max()) return 5;
265 return 9;
266 }
267
encodeHeader(MajorType type,uint64_t addlInfo,uint8_t * pos,const uint8_t * end)268 uint8_t* encodeHeader(MajorType type, uint64_t addlInfo, uint8_t* pos, const uint8_t* end) {
269 size_t sz = headerSize(addlInfo);
270 if (end - pos < static_cast<ssize_t>(sz)) return nullptr;
271 switch (sz) {
272 case 1:
273 *pos++ = type | static_cast<uint8_t>(addlInfo);
274 return pos;
275 case 2:
276 *pos++ = type | static_cast<MajorType>(ONE_BYTE_LENGTH);
277 *pos++ = static_cast<uint8_t>(addlInfo);
278 return pos;
279 case 3:
280 *pos++ = type | static_cast<MajorType>(TWO_BYTE_LENGTH);
281 return writeBigEndian(static_cast<uint16_t>(addlInfo), pos);
282 case 5:
283 *pos++ = type | static_cast<MajorType>(FOUR_BYTE_LENGTH);
284 return writeBigEndian(static_cast<uint32_t>(addlInfo), pos);
285 case 9:
286 *pos++ = type | static_cast<MajorType>(EIGHT_BYTE_LENGTH);
287 return writeBigEndian(addlInfo, pos);
288 default:
289 CHECK(false); // Impossible to get here.
290 return nullptr;
291 }
292 }
293
encodeHeader(MajorType type,uint64_t addlInfo,EncodeCallback encodeCallback)294 void encodeHeader(MajorType type, uint64_t addlInfo, EncodeCallback encodeCallback) {
295 size_t sz = headerSize(addlInfo);
296 switch (sz) {
297 case 1:
298 encodeCallback(type | static_cast<uint8_t>(addlInfo));
299 break;
300 case 2:
301 encodeCallback(type | static_cast<MajorType>(ONE_BYTE_LENGTH));
302 encodeCallback(static_cast<uint8_t>(addlInfo));
303 break;
304 case 3:
305 encodeCallback(type | static_cast<MajorType>(TWO_BYTE_LENGTH));
306 writeBigEndian(static_cast<uint16_t>(addlInfo), encodeCallback);
307 break;
308 case 5:
309 encodeCallback(type | static_cast<MajorType>(FOUR_BYTE_LENGTH));
310 writeBigEndian(static_cast<uint32_t>(addlInfo), encodeCallback);
311 break;
312 case 9:
313 encodeCallback(type | static_cast<MajorType>(EIGHT_BYTE_LENGTH));
314 writeBigEndian(addlInfo, encodeCallback);
315 break;
316 default:
317 CHECK(false); // Impossible to get here.
318 }
319 }
320
operator ==(const Item & other) const321 bool Item::operator==(const Item& other) const& {
322 if (type() != other.type()) return false;
323 switch (type()) {
324 case UINT:
325 return *asUint() == *(other.asUint());
326 case NINT:
327 return *asNint() == *(other.asNint());
328 case BSTR:
329 if (asBstr() != nullptr && other.asBstr() != nullptr) {
330 return *asBstr() == *(other.asBstr());
331 }
332 if (asViewBstr() != nullptr && other.asViewBstr() != nullptr) {
333 return *asViewBstr() == *(other.asViewBstr());
334 }
335 // Interesting corner case: comparing a Bstr and ViewBstr with
336 // identical contents. The function currently returns false for
337 // this case.
338 // TODO: if it should return true, this needs a deep comparison
339 return false;
340 case TSTR:
341 if (asTstr() != nullptr && other.asTstr() != nullptr) {
342 return *asTstr() == *(other.asTstr());
343 }
344 if (asViewTstr() != nullptr && other.asViewTstr() != nullptr) {
345 return *asViewTstr() == *(other.asViewTstr());
346 }
347 // Same corner case as Bstr
348 return false;
349 case ARRAY:
350 return *asArray() == *(other.asArray());
351 case MAP:
352 return *asMap() == *(other.asMap());
353 case SIMPLE:
354 return *asSimple() == *(other.asSimple());
355 case SEMANTIC:
356 return *asSemanticTag() == *(other.asSemanticTag());
357 default:
358 CHECK(false); // Impossible to get here.
359 return false;
360 }
361 }
362
Nint(int64_t v)363 Nint::Nint(int64_t v) : mValue(v) {
364 CHECK(v < 0);
365 }
366
operator ==(const Simple & other) const367 bool Simple::operator==(const Simple& other) const& {
368 if (simpleType() != other.simpleType()) return false;
369
370 switch (simpleType()) {
371 case BOOLEAN:
372 return *asBool() == *(other.asBool());
373 case NULL_T:
374 return true;
375 default:
376 CHECK(false); // Impossible to get here.
377 return false;
378 }
379 }
380
encode(uint8_t * pos,const uint8_t * end) const381 uint8_t* Bstr::encode(uint8_t* pos, const uint8_t* end) const {
382 pos = encodeHeader(mValue.size(), pos, end);
383 if (!pos || end - pos < static_cast<ptrdiff_t>(mValue.size())) return nullptr;
384 return std::copy(mValue.begin(), mValue.end(), pos);
385 }
386
encodeValue(EncodeCallback encodeCallback) const387 void Bstr::encodeValue(EncodeCallback encodeCallback) const {
388 for (auto c : mValue) {
389 encodeCallback(c);
390 }
391 }
392
encode(uint8_t * pos,const uint8_t * end) const393 uint8_t* ViewBstr::encode(uint8_t* pos, const uint8_t* end) const {
394 pos = encodeHeader(mView.size(), pos, end);
395 if (!pos || end - pos < static_cast<ptrdiff_t>(mView.size())) return nullptr;
396 return std::copy(mView.begin(), mView.end(), pos);
397 }
398
encodeValue(EncodeCallback encodeCallback) const399 void ViewBstr::encodeValue(EncodeCallback encodeCallback) const {
400 for (auto c : mView) {
401 encodeCallback(static_cast<uint8_t>(c));
402 }
403 }
404
encode(uint8_t * pos,const uint8_t * end) const405 uint8_t* Tstr::encode(uint8_t* pos, const uint8_t* end) const {
406 pos = encodeHeader(mValue.size(), pos, end);
407 if (!pos || end - pos < static_cast<ptrdiff_t>(mValue.size())) return nullptr;
408 return std::copy(mValue.begin(), mValue.end(), pos);
409 }
410
encodeValue(EncodeCallback encodeCallback) const411 void Tstr::encodeValue(EncodeCallback encodeCallback) const {
412 for (auto c : mValue) {
413 encodeCallback(static_cast<uint8_t>(c));
414 }
415 }
416
encode(uint8_t * pos,const uint8_t * end) const417 uint8_t* ViewTstr::encode(uint8_t* pos, const uint8_t* end) const {
418 pos = encodeHeader(mView.size(), pos, end);
419 if (!pos || end - pos < static_cast<ptrdiff_t>(mView.size())) return nullptr;
420 return std::copy(mView.begin(), mView.end(), pos);
421 }
422
encodeValue(EncodeCallback encodeCallback) const423 void ViewTstr::encodeValue(EncodeCallback encodeCallback) const {
424 for (auto c : mView) {
425 encodeCallback(static_cast<uint8_t>(c));
426 }
427 }
428
operator ==(const Array & other) const429 bool Array::operator==(const Array& other) const& {
430 return size() == other.size()
431 // Can't use vector::operator== because the contents are pointers. std::equal lets us
432 // provide a predicate that does the dereferencing.
433 && std::equal(mEntries.begin(), mEntries.end(), other.mEntries.begin(),
434 [](auto& a, auto& b) -> bool { return *a == *b; });
435 }
436
encode(uint8_t * pos,const uint8_t * end) const437 uint8_t* Array::encode(uint8_t* pos, const uint8_t* end) const {
438 pos = encodeHeader(size(), pos, end);
439 if (!pos) return nullptr;
440 for (auto& entry : mEntries) {
441 pos = entry->encode(pos, end);
442 if (!pos) return nullptr;
443 }
444 return pos;
445 }
446
encode(EncodeCallback encodeCallback) const447 void Array::encode(EncodeCallback encodeCallback) const {
448 encodeHeader(size(), encodeCallback);
449 for (auto& entry : mEntries) {
450 entry->encode(encodeCallback);
451 }
452 }
453
clone() const454 std::unique_ptr<Item> Array::clone() const {
455 auto res = std::make_unique<Array>();
456 for (size_t i = 0; i < mEntries.size(); i++) {
457 res->add(mEntries[i]->clone());
458 }
459 return res;
460 }
461
operator ==(const Map & other) const462 bool Map::operator==(const Map& other) const& {
463 return size() == other.size()
464 // Can't use vector::operator== because the contents are pairs of pointers. std::equal
465 // lets us provide a predicate that does the dereferencing.
466 && std::equal(begin(), end(), other.begin(), [](auto& a, auto& b) {
467 return *a.first == *b.first && *a.second == *b.second;
468 });
469 }
470
encode(uint8_t * pos,const uint8_t * end) const471 uint8_t* Map::encode(uint8_t* pos, const uint8_t* end) const {
472 pos = encodeHeader(size(), pos, end);
473 if (!pos) return nullptr;
474 for (auto& entry : mEntries) {
475 pos = entry.first->encode(pos, end);
476 if (!pos) return nullptr;
477 pos = entry.second->encode(pos, end);
478 if (!pos) return nullptr;
479 }
480 return pos;
481 }
482
encode(EncodeCallback encodeCallback) const483 void Map::encode(EncodeCallback encodeCallback) const {
484 encodeHeader(size(), encodeCallback);
485 for (auto& entry : mEntries) {
486 entry.first->encode(encodeCallback);
487 entry.second->encode(encodeCallback);
488 }
489 }
490
keyLess(const Item * a,const Item * b)491 bool Map::keyLess(const Item* a, const Item* b) {
492 // CBOR map canonicalization rules are:
493
494 // 1. If two keys have different lengths, the shorter one sorts earlier.
495 if (a->encodedSize() < b->encodedSize()) return true;
496 if (a->encodedSize() > b->encodedSize()) return false;
497
498 // 2. If two keys have the same length, the one with the lower value in (byte-wise) lexical
499 // order sorts earlier. This requires encoding both items.
500 auto encodedA = a->encode();
501 auto encodedB = b->encode();
502
503 return std::lexicographical_compare(encodedA.begin(), encodedA.end(), //
504 encodedB.begin(), encodedB.end());
505 }
506
recursivelyCanonicalize(std::unique_ptr<Item> & item)507 void recursivelyCanonicalize(std::unique_ptr<Item>& item) {
508 switch (item->type()) {
509 case UINT:
510 case NINT:
511 case BSTR:
512 case TSTR:
513 case SIMPLE:
514 return;
515
516 case ARRAY:
517 std::for_each(item->asArray()->begin(), item->asArray()->end(),
518 recursivelyCanonicalize);
519 return;
520
521 case MAP:
522 item->asMap()->canonicalize(true /* recurse */);
523 return;
524
525 case SEMANTIC:
526 // This can't happen. SemanticTags delegate their type() method to the contained Item's
527 // type.
528 assert(false);
529 return;
530 }
531 }
532
canonicalize(bool recurse)533 Map& Map::canonicalize(bool recurse) & {
534 if (recurse) {
535 for (auto& entry : mEntries) {
536 recursivelyCanonicalize(entry.first);
537 recursivelyCanonicalize(entry.second);
538 }
539 }
540
541 if (size() < 2 || mCanonicalized) {
542 // Trivially or already canonical; do nothing.
543 return *this;
544 }
545
546 std::sort(begin(), end(),
547 [](auto& a, auto& b) { return keyLess(a.first.get(), b.first.get()); });
548 mCanonicalized = true;
549 return *this;
550 }
551
clone() const552 std::unique_ptr<Item> Map::clone() const {
553 auto res = std::make_unique<Map>();
554 for (auto& [key, value] : *this) {
555 res->add(key->clone(), value->clone());
556 }
557 res->mCanonicalized = mCanonicalized;
558 return res;
559 }
560
clone() const561 std::unique_ptr<Item> SemanticTag::clone() const {
562 return std::make_unique<SemanticTag>(mValue, mTaggedItem->clone());
563 }
564
encode(uint8_t * pos,const uint8_t * end) const565 uint8_t* SemanticTag::encode(uint8_t* pos, const uint8_t* end) const {
566 // Can't use the encodeHeader() method that calls type() to get the major type, since that will
567 // return the tagged Item's type.
568 pos = ::cppbor::encodeHeader(kMajorType, mValue, pos, end);
569 if (!pos) return nullptr;
570 return mTaggedItem->encode(pos, end);
571 }
572
encode(EncodeCallback encodeCallback) const573 void SemanticTag::encode(EncodeCallback encodeCallback) const {
574 // Can't use the encodeHeader() method that calls type() to get the major type, since that will
575 // return the tagged Item's type.
576 ::cppbor::encodeHeader(kMajorType, mValue, encodeCallback);
577 mTaggedItem->encode(encodeCallback);
578 }
579
semanticTagCount() const580 size_t SemanticTag::semanticTagCount() const {
581 size_t levelCount = 1; // Count this level.
582 const SemanticTag* cur = this;
583 while (cur->mTaggedItem && (cur = cur->mTaggedItem->asSemanticTag()) != nullptr) ++levelCount;
584 return levelCount;
585 }
586
semanticTag(size_t nesting) const587 uint64_t SemanticTag::semanticTag(size_t nesting) const {
588 // Getting the value of a specific nested tag is a bit tricky, because we start with the outer
589 // tag and don't know how many are inside. We count the number of nesting levels to find out
590 // how many there are in total, then to get the one we want we have to walk down levelCount -
591 // nesting steps.
592 size_t levelCount = semanticTagCount();
593 if (nesting >= levelCount) return 0;
594
595 levelCount -= nesting;
596 const SemanticTag* cur = this;
597 while (--levelCount > 0) cur = cur->mTaggedItem->asSemanticTag();
598
599 return cur->mValue;
600 }
601
prettyPrint(const Item * item,size_t maxBStrSize,const vector<string> & mapKeysToNotPrint)602 string prettyPrint(const Item* item, size_t maxBStrSize, const vector<string>& mapKeysToNotPrint) {
603 string out;
604 prettyPrintInternal(item, out, 0, maxBStrSize, mapKeysToNotPrint);
605 return out;
606 }
prettyPrint(const vector<uint8_t> & encodedCbor,size_t maxBStrSize,const vector<string> & mapKeysToNotPrint)607 string prettyPrint(const vector<uint8_t>& encodedCbor, size_t maxBStrSize,
608 const vector<string>& mapKeysToNotPrint) {
609 auto [item, _, message] = parse(encodedCbor);
610 if (item == nullptr) {
611 #ifndef __TRUSTY__
612 LOG(ERROR) << "Data to pretty print is not valid CBOR: " << message;
613 #endif // __TRUSTY__
614 return "";
615 }
616
617 return prettyPrint(item.get(), maxBStrSize, mapKeysToNotPrint);
618 }
619
620 } // namespace cppbor
621