• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/platform/base64.h"
17 
18 #include <cstring>
19 #include <memory>
20 #include "tensorflow/core/platform/errors.h"
21 
22 namespace tensorflow {
23 namespace {
// Decoding table for the URL-safe base64 alphabet, indexed by 7-bit ASCII code.
// A valid character maps to its 6-bit value (0x00..0x3F): 'A'-'Z' -> 0x00-0x19,
// 'a'-'z' -> 0x1A-0x33, '0'-'9' -> 0x34-0x3D, '-' -> 0x3E, '_' -> 0x3F.
// Every other entry is -1.
//
// This array must have signed type: Convert() relies on the -1 entries being
// sign-extended so that invalid characters set the high bits of its result.
// The element type is spelled int8_t to match how Convert() reads the table.
// clang-format off
constexpr int8_t kBase64Bytes[128] = {
     -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
     -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
     -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
     -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,  0x3E,  -1,   -1,
    0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D,  -1,   -1,
     -1,   -1,   -1,   -1,   -1,  0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
    0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12,
    0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19,  -1,   -1,   -1,   -1,  0x3F,
     -1,  0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24,
    0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30,
    0x31, 0x32, 0x33,  -1,   -1,   -1,   -1,   -1};
// clang-format on
39 
// Encoding alphabet: the character at index i represents the 6-bit value i.
// The array holds 64 characters plus the terminating NUL.
constexpr char kBase64UrlSafeChars[65] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  // values  0..25
    "abcdefghijklmnopqrstuvwxyz"  // values 26..51
    "0123456789"                  // values 52..61
    "-_";                         // values 62..63
42 
// Padding character: appended by the encoder when padding is requested and
// stripped from the final group by the decoder.
constexpr char kPadChar{'='};
44 
45 // Converts a char (8 bits) into a 6-bit value for decoding. If the input char
46 // is invalid for base64 encoding, the return value has at least its upper 25
47 // bits set.
Convert(char x)48 inline uint32 Convert(char x) {
49   // If x < 128, then we look up x in the table. If x is valid, then the table
50   // will have a value <= 0x3F, otherwise the table will have -1. If x >= 128,
51   // we still do some table lookup, but the value is ignored since we explicitly
52   // set the high bit of y to 1. Either way, y is negative (high bit set) in
53   // case of error.
54   const int8_t y = kBase64Bytes[x & 0x7F] | (x & 0x80);
55   // Casting from int8 to int32 preserves sign by sign extension. If y was
56   // negative, at least its 25 high bits of the return value are set.
57   const int32_t z = static_cast<int32>(y);
58   return static_cast<uint32>(z);
59 }
60 
DecodeThreeChars(const char * codes,char * result)61 Status DecodeThreeChars(const char* codes, char* result) {
62   const uint32 packed = (Convert(codes[0]) << 18) | (Convert(codes[1]) << 12) |
63                         (Convert(codes[2]) << 6) | (Convert(codes[3]));
64   // Convert() return value has upper 25 bits set if input is invalid.
65   // Therefore `packed` has high bits set iff at least one of code is invalid.
66   if (TF_PREDICT_FALSE((packed & 0xFF000000) != 0)) {
67     return errors::InvalidArgument("Invalid character found in base64.");
68   }
69   result[0] = static_cast<char>(packed >> 16);
70   result[1] = static_cast<char>(packed >> 8);
71   result[2] = static_cast<char>(packed);
72   return OkStatus();
73 }
74 }  // namespace
75 
76 template <typename T>
Base64Decode(StringPiece data,T * decoded)77 Status Base64Decode(StringPiece data, T* decoded) {
78   if (decoded == nullptr) {
79     return errors::Internal("'decoded' cannot be nullptr.");
80   }
81 
82   if (data.empty()) {
83     decoded->clear();
84     return OkStatus();
85   }
86 
87   // This decoding procedure will write 3 * ceil(data.size() / 4) bytes to be
88   // output buffer, then truncate if necessary. Therefore we must overestimate
89   // and allocate sufficient amount. Currently max_decoded_size may overestimate
90   // by up to 3 bytes.
91   const size_t max_decoded_size = 3 * (data.size() / 4) + 3;
92   std::unique_ptr<char[]> buffer(new char[max_decoded_size]);
93   char* current = buffer.get();
94   if (current == nullptr) {
95     return errors::ResourceExhausted(
96         "Failed to allocate buffer for decoded string.");
97   }
98 
99   const char* b64 = data.data();
100   const char* end = data.data() + data.size();
101 
102   while (end - b64 > 4) {
103     TF_RETURN_IF_ERROR(DecodeThreeChars(b64, current));
104     b64 += 4;
105     current += 3;
106   }
107 
108   if (end - b64 == 4) {
109     // The data length is a multiple of 4. Check for padding.
110     // Base64 cannot have more than 2 paddings.
111     if (b64[2] == kPadChar && b64[3] == kPadChar) {
112       end -= 2;
113     }
114     if (b64[2] != kPadChar && b64[3] == kPadChar) {
115       end -= 1;
116     }
117   }
118 
119   const int remain = static_cast<int>(end - b64);
120   if (TF_PREDICT_FALSE(remain == 1)) {
121     // We may check this condition early by checking data.size() % 4 == 1.
122     return errors::InvalidArgument(
123         "Base64 string length cannot be 1 modulo 4.");
124   }
125 
126   // A valid base64 character will replace paddings, if any.
127   char tail[4] = {kBase64UrlSafeChars[0], kBase64UrlSafeChars[0],
128                   kBase64UrlSafeChars[0], kBase64UrlSafeChars[0]};
129   // Copy tail of the input into the array, then decode.
130   std::memcpy(tail, b64, remain * sizeof(*b64));
131   TF_RETURN_IF_ERROR(DecodeThreeChars(tail, current));
132   // We know how many parsed characters are valid.
133   current += remain - 1;
134 
135   decoded->assign(buffer.get(), current - buffer.get());
136   return OkStatus();
137 }
138 
139 template <typename T>
Base64Encode(StringPiece source,T * encoded)140 Status Base64Encode(StringPiece source, T* encoded) {
141   return Base64Encode(source, false, encoded);
142 }
143 
144 template <typename T>
Base64Encode(StringPiece source,bool with_padding,T * encoded)145 Status Base64Encode(StringPiece source, bool with_padding, T* encoded) {
146   const char* const base64_chars = kBase64UrlSafeChars;
147   if (encoded == nullptr) {
148     return errors::Internal("'encoded' cannot be nullptr.");
149   }
150 
151   // max_encoded_size may overestimate by up to 4 bytes.
152   const size_t max_encoded_size = 4 * (source.size() / 3) + 4;
153   std::unique_ptr<char[]> buffer(new char[max_encoded_size]);
154   char* current = buffer.get();
155   if (current == nullptr) {
156     return errors::ResourceExhausted(
157         "Failed to allocate buffer for encoded string.");
158   }
159 
160   const char* data = source.data();
161   const char* const end = source.data() + source.size();
162 
163   // Encode each block.
164   while (end - data >= 3) {
165     *current++ = base64_chars[(data[0] >> 2) & 0x3F];
166     *current++ =
167         base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)];
168     *current++ =
169         base64_chars[((data[1] & 0x0F) << 2) | ((data[2] >> 6) & 0x03)];
170     *current++ = base64_chars[data[2] & 0x3F];
171 
172     data += 3;
173   }
174 
175   // Take care of the tail.
176   if (end - data == 2) {
177     *current++ = base64_chars[(data[0] >> 2) & 0x3F];
178     *current++ =
179         base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)];
180     *current++ = base64_chars[(data[1] & 0x0F) << 2];
181     if (with_padding) {
182       *current++ = kPadChar;
183     }
184   } else if (end - data == 1) {
185     *current++ = base64_chars[(data[0] >> 2) & 0x3F];
186     *current++ = base64_chars[(data[0] & 0x03) << 4];
187     if (with_padding) {
188       *current++ = kPadChar;
189       *current++ = kPadChar;
190     }
191   }
192 
193   encoded->assign(buffer.get(), current - buffer.get());
194   return OkStatus();
195 }
196 
197 template Status Base64Decode<std::string>(StringPiece data,
198                                           std::string* decoded);
199 template Status Base64Encode<std::string>(StringPiece source,
200                                           std::string* encoded);
201 template Status Base64Encode<std::string>(StringPiece source, bool with_padding,
202                                           std::string* encoded);
203 
204 template Status Base64Decode<tstring>(StringPiece data, tstring* decoded);
205 template Status Base64Encode<tstring>(StringPiece source, tstring* encoded);
206 template Status Base64Encode<tstring>(StringPiece source, bool with_padding,
207                                       tstring* encoded);
208 
209 }  // namespace tensorflow
210