# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for EncodeBase64 and DecodeBase64."""

import base64

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


@test_util.run_deprecated_v1
class Base64OpsTest(test_util.TensorFlowTestCase):
  """Checks the TF base64 ops against Python's URL-safe `base64` encoding."""

  def setUp(self):
    # A single string placeholder feeds two encode -> decode round-trip
    # graphs: the `_f` tensors encode with pad=False, the `_t` tensors with
    # pad=True.
    self._msg = array_ops.placeholder(dtype=dtypes.string)
    self._encoded_f = string_ops.encode_base64(self._msg, pad=False)
    self._decoded_f = string_ops.decode_base64(self._encoded_f)
    self._encoded_t = string_ops.encode_base64(self._msg, pad=True)
    self._decoded_t = string_ops.decode_base64(self._encoded_t)

  def _RemovePad(self, msg, base64_msg):
    """Strips the trailing '=' padding Python's encoder adds for `msg`.

    Args:
      msg: The raw bytes that were encoded.
      base64_msg: The padded base64 encoding of `msg`.

    Returns:
      `base64_msg` with its padding characters removed.
    """
    # A trailing 1-byte group is padded with "==", a 2-byte group with "=".
    if len(msg) % 3 == 1:
      return base64_msg[:-2]
    if len(msg) % 3 == 2:
      return base64_msg[:-1]
    return base64_msg

  def _RunTest(self, msg, pad):
    """Round-trips `msg` through the TF ops and compares with Python base64.

    Args:
      msg: A bytes string, or a list/tuple of bytes strings, to feed.
      pad: If True, use the padded encode graph; otherwise the unpadded one.
    """
    with self.cached_session() as sess:
      if pad:
        fetches = [self._encoded_t, self._decoded_t]
      else:
        fetches = [self._encoded_f, self._decoded_f]
      encoded, decoded = sess.run(fetches, feed_dict={self._msg: msg})

    # Treat a scalar input as a one-element batch so both cases share the
    # comparison below.
    if not isinstance(msg, (list, tuple)):
      msg = [msg]
      encoded = [encoded]
      decoded = [decoded]

    base64_msg = [base64.urlsafe_b64encode(m) for m in msg]
    if not pad:
      base64_msg = [self._RemovePad(m, b) for m, b in zip(msg, base64_msg)]

    # Check lengths first so `zip` below cannot silently skip elements.
    self.assertLen(encoded, len(msg))
    self.assertLen(decoded, len(msg))
    for m, expected_enc, actual_enc, actual_dec in zip(msg, base64_msg,
                                                       encoded, decoded):
      self.assertEqual(expected_enc, actual_enc)
      self.assertEqual(m, actual_dec)

  def testWithPythonBase64(self):
    for pad in (False, True):
      self._RunTest(b"", pad=pad)

      # Random messages of up to 1 MiB each.
      for _ in range(100):
        length = np.random.randint(1024 * 1024)
        msg = np.random.bytes(length)
        self._RunTest(msg, pad=pad)

  def testShape(self):
    for pad in (False, True):
      # Random 1-D batches (possibly empty) of short random messages.
      for _ in range(10):
        msg = [np.random.bytes(np.random.randint(20))
               for _ in range(np.random.randint(10))]
        self._RunTest(msg, pad=pad)

      # Zero-element, non-trivial shapes.
      for _ in range(10):
        k = np.random.randint(10)
        msg = np.empty((0, k), dtype=bytes)
        encoded = string_ops.encode_base64(msg, pad=pad)
        decoded = string_ops.decode_base64(encoded)

        with self.cached_session():
          encoded_value, decoded_value = self.evaluate([encoded, decoded])

        self.assertEqual(encoded_value.shape, msg.shape)
        self.assertEqual(decoded_value.shape, msg.shape)

  def testInvalidInput(self):

    def try_decode(enc):
      # Feed `enc` straight into the unpadded graph's encoded tensor, so only
      # the decode op runs on it.
      self._decoded_f.eval(feed_dict={self._encoded_f: enc})

    with self.cached_session():
      # Invalid length.
      msg = np.random.bytes(99)
      enc = base64.urlsafe_b64encode(msg)
      with self.assertRaisesRegex(errors.InvalidArgumentError, "1 modulo 4"):
        try_decode(enc + b"a")

      # Invalid char used in encoding. Corrupt every position of the encoded
      # string (48 chars for 34 bytes, including the trailing "=="), not only
      # the first len(msg) positions.
      msg = np.random.bytes(34)
      enc = base64.urlsafe_b64encode(msg)
      bad_chars = (
          b"?",
          b"\x80",  # outside ascii range.
          b"+",  # not url-safe.
          b"/",  # not url-safe.
      )
      for i in range(len(enc)):
        for bad_char in bad_chars:
          with self.assertRaises(errors.InvalidArgumentError):
            try_decode(enc[:i] + bad_char + enc[(i + 1):])

      # Partial padding.
      msg = np.random.bytes(34)
      enc = base64.urlsafe_b64encode(msg)
      with self.assertRaises(errors.InvalidArgumentError):
        # enc contains == at the end. Partial padding is not allowed.
        try_decode(enc[:-1])

      # Unnecessary padding: 33 bytes encode to 44 chars with no padding.
      msg = np.random.bytes(33)
      enc = base64.urlsafe_b64encode(msg)
      for extra_pad in (b"==", b"===", b"===="):
        with self.assertRaises(errors.InvalidArgumentError):
          try_decode(enc + extra_pad)

      # Padding in the middle. (The previous implementation accepted this as
      # long as the padding char was at position 2 or 3 mod 4.)
      msg = np.random.bytes(33)
      enc = base64.urlsafe_b64encode(msg)
      for i in range(len(msg) - 1):
        with self.assertRaises(errors.InvalidArgumentError):
          try_decode(enc[:i] + b"=" + enc[(i + 1):])
      for i in range(len(msg) - 2):
        with self.assertRaises(errors.InvalidArgumentError):
          try_decode(enc[:i] + b"==" + enc[(i + 2):])


if __name__ == "__main__":
  test.main()