• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
16"""Tests for EncodeBase64 and DecodeBase64."""
17
18import base64
19
20import numpy as np
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import errors
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import string_ops
27from tensorflow.python.platform import test
28
29
@test_util.run_deprecated_v1
class Base64OpsTest(test_util.TensorFlowTestCase):
  """Tests `string_ops.encode_base64` / `string_ops.decode_base64`.

  Encoded output is compared against Python's `base64.urlsafe_b64encode`
  (with the trailing '=' padding stripped when `pad=False`), and decoded
  output is compared against the original message.
  """

  def setUp(self):
    """Builds the shared graph: placeholder -> encode -> decode (both pad modes)."""
    # Run the base-class per-test setup before building this test's graph.
    super().setUp()
    self._msg = array_ops.placeholder(dtype=dtypes.string)
    # Unpadded round trip.
    self._encoded_f = string_ops.encode_base64(self._msg, pad=False)
    self._decoded_f = string_ops.decode_base64(self._encoded_f)
    # Padded round trip.
    self._encoded_t = string_ops.encode_base64(self._msg, pad=True)
    self._decoded_t = string_ops.decode_base64(self._encoded_t)

  def _RemovePad(self, msg, base64_msg):
    """Drops the '=' padding that base64 adds for a message of this length.

    Args:
      msg: The raw message whose encoding is `base64_msg`.
      base64_msg: Padded base64 encoding of `msg`.

    Returns:
      `base64_msg` without its trailing padding: two characters are removed
      when `len(msg) % 3 == 1`, one when `len(msg) % 3 == 2`, none otherwise.
    """
    if len(msg) % 3 == 1:
      return base64_msg[:-2]
    if len(msg) % 3 == 2:
      return base64_msg[:-1]
    return base64_msg

  def _RunTest(self, msg, pad):
    """Round-trips `msg` through the graph and checks it against Python.

    Args:
      msg: A `bytes` message, or a list/tuple of `bytes` messages.
      pad: If True, use the padded encoder; otherwise the unpadded one.
    """
    if pad:
      fetches = [self._encoded_t, self._decoded_t]
    else:
      fetches = [self._encoded_f, self._decoded_f]
    with self.cached_session() as sess:
      encoded, decoded = sess.run(fetches, feed_dict={self._msg: msg})

    # Treat a single message as a one-element batch so that both input
    # forms share the comparison code below.
    if not isinstance(msg, (list, tuple)):
      msg = [msg]
      encoded = [encoded]
      decoded = [decoded]

    expected = [base64.urlsafe_b64encode(m) for m in msg]
    if not pad:
      expected = [self._RemovePad(m, b) for m, b in zip(msg, expected)]

    # zip() stops at the shortest input, so verify the sizes explicitly first.
    self.assertEqual(len(msg), len(encoded))
    self.assertEqual(len(msg), len(decoded))
    for original, want, got_encoded, got_decoded in zip(
        msg, expected, encoded, decoded):
      self.assertEqual(want, got_encoded)
      self.assertEqual(original, got_decoded)

  def testWithPythonBase64(self):
    """Round-trips the empty message and random messages under 1 MiB."""
    for pad in (False, True):
      self._RunTest(b"", pad=pad)

      for _ in range(100):
        length = np.random.randint(1024 * 1024)
        msg = np.random.bytes(length)
        self._RunTest(msg, pad=pad)

  def testShape(self):
    """Checks batched input and zero-element 2-D input of shape (0, k)."""
    for pad in (False, True):
      for _ in range(10):
        msg = [np.random.bytes(np.random.randint(20))
               for _ in range(np.random.randint(10))]
        self._RunTest(msg, pad=pad)

      # Zero-element, non-trivial shapes.
      for _ in range(10):
        k = np.random.randint(10)
        msg = np.empty((0, k), dtype=bytes)
        encoded = string_ops.encode_base64(msg, pad=pad)
        decoded = string_ops.decode_base64(encoded)

        with self.cached_session():
          encoded_value, decoded_value = self.evaluate([encoded, decoded])

        self.assertEqual(encoded_value.shape, msg.shape)
        self.assertEqual(decoded_value.shape, msg.shape)

  def testInvalidInput(self):
    """Checks that malformed base64 input raises InvalidArgumentError."""

    def try_decode(enc):
      # Feed `enc` in place of the unpadded encoder's output so the decoder
      # receives it as input.
      self._decoded_f.eval(feed_dict={self._encoded_f: enc})

    with self.cached_session():
      # Invalid length.
      msg = np.random.bytes(99)
      enc = base64.urlsafe_b64encode(msg)
      with self.assertRaisesRegex(errors.InvalidArgumentError, "1 modulo 4"):
        try_decode(enc + b"a")

      # Invalid char used in encoding. Only the first len(msg) positions of
      # the encoding are perturbed.
      msg = np.random.bytes(34)
      enc = base64.urlsafe_b64encode(msg)
      for i in range(len(msg)):
        with self.assertRaises(errors.InvalidArgumentError):
          try_decode(enc[:i] + b"?" + enc[(i + 1):])
        with self.assertRaises(errors.InvalidArgumentError):
          try_decode(enc[:i] + b"\x80" + enc[(i + 1):])  # outside ascii range.
        with self.assertRaises(errors.InvalidArgumentError):
          try_decode(enc[:i] + b"+" + enc[(i + 1):])  # not url-safe.
        with self.assertRaises(errors.InvalidArgumentError):
          try_decode(enc[:i] + b"/" + enc[(i + 1):])  # not url-safe.

      # Partial padding.
      msg = np.random.bytes(34)
      enc = base64.urlsafe_b64encode(msg)
      with self.assertRaises(errors.InvalidArgumentError):
        # enc contains == at the end. Partial padding is not allowed.
        try_decode(enc[:-1])

      # Unnecessary padding.
      msg = np.random.bytes(33)
      enc = base64.urlsafe_b64encode(msg)
      with self.assertRaises(errors.InvalidArgumentError):
        try_decode(enc + b"==")
      with self.assertRaises(errors.InvalidArgumentError):
        try_decode(enc + b"===")
      with self.assertRaises(errors.InvalidArgumentError):
        try_decode(enc + b"====")

      # Padding in the middle. (The previous implementation accepted this as
      # long as the padding char location was 2 or 3 (mod 4).)
      msg = np.random.bytes(33)
      enc = base64.urlsafe_b64encode(msg)
      for i in range(len(msg) - 1):
        with self.assertRaises(errors.InvalidArgumentError):
          try_decode(enc[:i] + b"=" + enc[(i + 1):])
      for i in range(len(msg) - 2):
        with self.assertRaises(errors.InvalidArgumentError):
          try_decode(enc[:i] + b"==" + enc[(i + 2):])
149
150
if __name__ == "__main__":
  # Entry point when this file is executed directly; delegates to the
  # TensorFlow test runner.
  test.main()
153