• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16 
17 #include "tink/aead/cord_aead_wrapper.h"
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/status/status.h"
24 #include "absl/strings/cord.h"
25 #include "tink/aead/cord_aead.h"
26 #include "tink/crypto_format.h"
27 #include "tink/primitive_set.h"
28 #include "tink/util/status.h"
29 #include "tink/util/statusor.h"
30 
31 namespace crypto {
32 namespace tink {
33 
34 namespace {
35 
Validate(PrimitiveSet<CordAead> * aead_set)36 util::Status Validate(PrimitiveSet<CordAead>* aead_set) {
37   if (aead_set == nullptr) {
38     return util::Status(absl::StatusCode::kInternal,
39                         "aead_set must be non-NULL");
40   }
41   if (aead_set->get_primary() == nullptr) {
42     return util::Status(absl::StatusCode::kInvalidArgument,
43                         "aead_set has no primary");
44   }
45   return util::OkStatus();
46 }
47 
48 class CordAeadSetWrapper : public CordAead {
49  public:
CordAeadSetWrapper(std::unique_ptr<PrimitiveSet<CordAead>> aead_set)50   explicit CordAeadSetWrapper(std::unique_ptr<PrimitiveSet<CordAead>> aead_set)
51       : aead_set_(std::move(aead_set)) {}
52 
53   crypto::tink::util::StatusOr<absl::Cord> Encrypt(
54       absl::Cord plaintext, absl::Cord associated_data) const override;
55 
56   crypto::tink::util::StatusOr<absl::Cord> Decrypt(
57       absl::Cord ciphertext, absl::Cord associated_data) const override;
58 
59   ~CordAeadSetWrapper() override = default;
60 
61  private:
62   std::unique_ptr<PrimitiveSet<CordAead>> aead_set_;
63 };
64 
Encrypt(absl::Cord plaintext,absl::Cord associated_data) const65 util::StatusOr<absl::Cord> CordAeadSetWrapper::Encrypt(
66     absl::Cord plaintext, absl::Cord associated_data) const {
67   auto encrypt_result = aead_set_->get_primary()->get_primitive().Encrypt(
68       plaintext, associated_data);
69   if (!encrypt_result.ok()) return encrypt_result.status();
70   absl::Cord result;
71   result.Append(aead_set_->get_primary()->get_identifier());
72   result.Append(encrypt_result.value());
73   return result;
74 }
75 
Decrypt(absl::Cord ciphertext,absl::Cord associated_data) const76 util::StatusOr<absl::Cord> CordAeadSetWrapper::Decrypt(
77     absl::Cord ciphertext, absl::Cord associated_data) const {
78   if (ciphertext.size() > CryptoFormat::kNonRawPrefixSize) {
79     std::string key_id =
80         std::string(ciphertext.Subcord(0, CryptoFormat::kNonRawPrefixSize));
81     auto primitives_result = aead_set_->get_primitives(key_id);
82     if (primitives_result.ok()) {
83       auto raw_ciphertext =
84           ciphertext.Subcord(key_id.size(), ciphertext.size());
85       for (auto& aead_entry : *(primitives_result.value())) {
86         CordAead& aead = aead_entry->get_primitive();
87         auto decrypt_result = aead.Decrypt(raw_ciphertext, associated_data);
88         if (decrypt_result.ok()) {
89           return std::move(decrypt_result.value());
90         } else {
91           // LOG that a matching key didn't decrypt the ciphertext.
92         }
93       }
94     }
95   }
96 
97   // No matching key succeeded with decryption, try all RAW keys.
98   auto raw_primitives_result = aead_set_->get_raw_primitives();
99   if (raw_primitives_result.ok()) {
100     for (auto& aead_entry : *(raw_primitives_result.value())) {
101       CordAead& aead = aead_entry->get_primitive();
102       auto decrypt_result = aead.Decrypt(ciphertext, associated_data);
103       if (decrypt_result.ok()) {
104         return std::move(decrypt_result.value());
105       }
106     }
107   }
108   return util::Status(absl::StatusCode::kInvalidArgument, "decryption failed");
109 }
110 }  // anonymous namespace
111 
Wrap(std::unique_ptr<PrimitiveSet<CordAead>> aead_set) const112 util::StatusOr<std::unique_ptr<CordAead>> CordAeadWrapper::Wrap(
113     std::unique_ptr<PrimitiveSet<CordAead>> aead_set) const {
114   util::Status status = Validate(aead_set.get());
115   if (!status.ok()) return status;
116   std::unique_ptr<CordAead> aead(new CordAeadSetWrapper(std::move(aead_set)));
117   return std::move(aead);
118 }
119 
120 }  // namespace tink
121 }  // namespace crypto
122