1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/platform/base64.h"

#include <cstring>
#include <memory>
#include <new>

#include "tensorflow/core/platform/errors.h"
21
22 namespace tensorflow {
23 namespace {
24 // This array must have signed type.
25 // clang-format off
26 constexpr int8 kBase64Bytes[128] = {
27 -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
28 -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
29 -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
30 -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x3E, -1, -1,
31 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, -1, -1,
32 -1, -1, -1, -1, -1, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
33 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12,
34 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, -1, -1, -1, -1, 0x3F,
35 -1, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24,
36 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30,
37 0x31, 0x32, 0x33, -1, -1, -1, -1, -1};
38 // clang-format on
39
40 constexpr char kBase64UrlSafeChars[65] =
41 "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
42
43 constexpr char kPadChar = '=';
44
45 // Converts a char (8 bits) into a 6-bit value for decoding. If the input char
46 // is invalid for base64 encoding, the return value has at least its upper 25
47 // bits set.
Convert(char x)48 inline uint32 Convert(char x) {
49 // If x < 128, then we look up x in the table. If x is valid, then the table
50 // will have a value <= 0x3F, otherwise the table will have -1. If x >= 128,
51 // we still do some table lookup, but the value is ignored since we explicitly
52 // set the high bit of y to 1. Either way, y is negative (high bit set) in
53 // case of error.
54 const int8_t y = kBase64Bytes[x & 0x7F] | (x & 0x80);
55 // Casting from int8 to int32 preserves sign by sign extension. If y was
56 // negative, at least its 25 high bits of the return value are set.
57 const int32_t z = static_cast<int32>(y);
58 return static_cast<uint32>(z);
59 }
60
DecodeThreeChars(const char * codes,char * result)61 Status DecodeThreeChars(const char* codes, char* result) {
62 const uint32 packed = (Convert(codes[0]) << 18) | (Convert(codes[1]) << 12) |
63 (Convert(codes[2]) << 6) | (Convert(codes[3]));
64 // Convert() return value has upper 25 bits set if input is invalid.
65 // Therefore `packed` has high bits set iff at least one of code is invalid.
66 if (TF_PREDICT_FALSE((packed & 0xFF000000) != 0)) {
67 return errors::InvalidArgument("Invalid character found in base64.");
68 }
69 result[0] = static_cast<char>(packed >> 16);
70 result[1] = static_cast<char>(packed >> 8);
71 result[2] = static_cast<char>(packed);
72 return OkStatus();
73 }
74 } // namespace
75
76 template <typename T>
Base64Decode(StringPiece data,T * decoded)77 Status Base64Decode(StringPiece data, T* decoded) {
78 if (decoded == nullptr) {
79 return errors::Internal("'decoded' cannot be nullptr.");
80 }
81
82 if (data.empty()) {
83 decoded->clear();
84 return OkStatus();
85 }
86
87 // This decoding procedure will write 3 * ceil(data.size() / 4) bytes to be
88 // output buffer, then truncate if necessary. Therefore we must overestimate
89 // and allocate sufficient amount. Currently max_decoded_size may overestimate
90 // by up to 3 bytes.
91 const size_t max_decoded_size = 3 * (data.size() / 4) + 3;
92 std::unique_ptr<char[]> buffer(new char[max_decoded_size]);
93 char* current = buffer.get();
94 if (current == nullptr) {
95 return errors::ResourceExhausted(
96 "Failed to allocate buffer for decoded string.");
97 }
98
99 const char* b64 = data.data();
100 const char* end = data.data() + data.size();
101
102 while (end - b64 > 4) {
103 TF_RETURN_IF_ERROR(DecodeThreeChars(b64, current));
104 b64 += 4;
105 current += 3;
106 }
107
108 if (end - b64 == 4) {
109 // The data length is a multiple of 4. Check for padding.
110 // Base64 cannot have more than 2 paddings.
111 if (b64[2] == kPadChar && b64[3] == kPadChar) {
112 end -= 2;
113 }
114 if (b64[2] != kPadChar && b64[3] == kPadChar) {
115 end -= 1;
116 }
117 }
118
119 const int remain = static_cast<int>(end - b64);
120 if (TF_PREDICT_FALSE(remain == 1)) {
121 // We may check this condition early by checking data.size() % 4 == 1.
122 return errors::InvalidArgument(
123 "Base64 string length cannot be 1 modulo 4.");
124 }
125
126 // A valid base64 character will replace paddings, if any.
127 char tail[4] = {kBase64UrlSafeChars[0], kBase64UrlSafeChars[0],
128 kBase64UrlSafeChars[0], kBase64UrlSafeChars[0]};
129 // Copy tail of the input into the array, then decode.
130 std::memcpy(tail, b64, remain * sizeof(*b64));
131 TF_RETURN_IF_ERROR(DecodeThreeChars(tail, current));
132 // We know how many parsed characters are valid.
133 current += remain - 1;
134
135 decoded->assign(buffer.get(), current - buffer.get());
136 return OkStatus();
137 }
138
139 template <typename T>
Base64Encode(StringPiece source,T * encoded)140 Status Base64Encode(StringPiece source, T* encoded) {
141 return Base64Encode(source, false, encoded);
142 }
143
144 template <typename T>
Base64Encode(StringPiece source,bool with_padding,T * encoded)145 Status Base64Encode(StringPiece source, bool with_padding, T* encoded) {
146 const char* const base64_chars = kBase64UrlSafeChars;
147 if (encoded == nullptr) {
148 return errors::Internal("'encoded' cannot be nullptr.");
149 }
150
151 // max_encoded_size may overestimate by up to 4 bytes.
152 const size_t max_encoded_size = 4 * (source.size() / 3) + 4;
153 std::unique_ptr<char[]> buffer(new char[max_encoded_size]);
154 char* current = buffer.get();
155 if (current == nullptr) {
156 return errors::ResourceExhausted(
157 "Failed to allocate buffer for encoded string.");
158 }
159
160 const char* data = source.data();
161 const char* const end = source.data() + source.size();
162
163 // Encode each block.
164 while (end - data >= 3) {
165 *current++ = base64_chars[(data[0] >> 2) & 0x3F];
166 *current++ =
167 base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)];
168 *current++ =
169 base64_chars[((data[1] & 0x0F) << 2) | ((data[2] >> 6) & 0x03)];
170 *current++ = base64_chars[data[2] & 0x3F];
171
172 data += 3;
173 }
174
175 // Take care of the tail.
176 if (end - data == 2) {
177 *current++ = base64_chars[(data[0] >> 2) & 0x3F];
178 *current++ =
179 base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)];
180 *current++ = base64_chars[(data[1] & 0x0F) << 2];
181 if (with_padding) {
182 *current++ = kPadChar;
183 }
184 } else if (end - data == 1) {
185 *current++ = base64_chars[(data[0] >> 2) & 0x3F];
186 *current++ = base64_chars[(data[0] & 0x03) << 4];
187 if (with_padding) {
188 *current++ = kPadChar;
189 *current++ = kPadChar;
190 }
191 }
192
193 encoded->assign(buffer.get(), current - buffer.get());
194 return OkStatus();
195 }
196
197 template Status Base64Decode<std::string>(StringPiece data,
198 std::string* decoded);
199 template Status Base64Encode<std::string>(StringPiece source,
200 std::string* encoded);
201 template Status Base64Encode<std::string>(StringPiece source, bool with_padding,
202 std::string* encoded);
203
204 template Status Base64Decode<tstring>(StringPiece data, tstring* decoded);
205 template Status Base64Encode<tstring>(StringPiece source, tstring* encoded);
206 template Status Base64Encode<tstring>(StringPiece source, bool with_padding,
207 tstring* encoded);
208
209 } // namespace tensorflow
210