#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests decoding a proto with tokenized fields."""

import base64
import unittest

from pw_tokenizer_tests.detokenize_proto_test_pb2 import TheMessage

from pw_tokenizer import detokenize, encode, tokens
from pw_tokenizer.proto import detokenize_fields, decode_optionally_tokenized

# Token database shared by TestDetokenizeProtoFields. The 0x12345678 entry
# embeds "$oeQAAA==", which is '$' followed by the Base64 of the little-endian
# bytes of token 0x0000E4A1 (b'\xA1\xE4\x00\x00'), so detokenizing it
# exercises nested (recursive) tokens.
_DATABASE = tokens.Database(
    [
        tokens.TokenizedStringEntry(0xAABBCCDD, "Luke, we're gonna have %s"),
        tokens.TokenizedStringEntry(0x12345678, "This string has a $oeQAAA=="),
        tokens.TokenizedStringEntry(0x0000E4A1, "recursive token"),
    ]
)
_DETOKENIZER = detokenize.Detokenizer(_DATABASE)


class TestDetokenizeProtoFields(unittest.TestCase):
    """Tests detokenizing optionally tokenized proto fields."""

    def test_plain_text(self) -> None:
        # Text that is not tokenized is left unchanged.
        proto = TheMessage(message=b'boring conversation anyway!')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b'boring conversation anyway!')

    def test_binary(self) -> None:
        # Token 0xAABBCCDD in little-endian order, then a string argument
        # whose leading byte (0x07) equals len('company').
        proto = TheMessage(message=b'\xDD\xCC\xBB\xAA\x07company')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b"Luke, we're gonna have company")

    def test_binary_missing_arguments(self) -> None:
        # With no argument bytes, the format specifier is kept verbatim.
        proto = TheMessage(message=b'\xDD\xCC\xBB\xAA')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b"Luke, we're gonna have %s")

    def test_recursive_binary(self) -> None:
        # Token 0x12345678's string contains a nested Base64 token, which is
        # expanded as well.
        proto = TheMessage(message=b'\x78\x56\x34\x12')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b"This string has a recursive token")

    def test_base64(self) -> None:
        base64_msg = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x07company')
        proto = TheMessage(message=base64_msg.encode())
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b"Luke, we're gonna have company")

    def test_recursive_base64(self) -> None:
        base64_msg = encode.prefixed_base64(b'\x78\x56\x34\x12')
        proto = TheMessage(message=base64_msg.encode())
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b"This string has a recursive token")

    def test_plain_text_with_prefixed_base64(self) -> None:
        # A prefixed Base64 token embedded in plain text is expanded in place;
        # 0x09 equals len('pancakes!').
        base64_msg = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x09pancakes!')
        proto = TheMessage(message=f'Good morning, {base64_msg}'.encode())
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(
            proto.message, b"Good morning, Luke, we're gonna have pancakes!"
        )

    def test_unknown_token_not_utf8(self) -> None:
        # An unknown token whose bytes are not valid UTF-8 is re-encoded as
        # prefixed Base64.
        proto = TheMessage(message=b'\xFE\xED\xF0\x0D')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(
            proto.message.decode(), encode.prefixed_base64(b'\xFE\xED\xF0\x0D')
        )

    def test_only_control_characters(self) -> None:
        # These bytes decode as UTF-8 but are all control characters; they are
        # also re-encoded as prefixed Base64.
        proto = TheMessage(message=b'\1\2\3\4')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(
            proto.message.decode(), encode.prefixed_base64(b'\1\2\3\4')
        )


class TestDecodeOptionallyTokenized(unittest.TestCase):
    """Tests optional detokenization directly."""

    def setUp(self) -> None:
        # Tokens 0-3 map to short strings; any other token is unknown.
        self.detok = detokenize.Detokenizer(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(0, 'cheese'),
                    tokens.TokenizedStringEntry(1, 'on pizza'),
                    tokens.TokenizedStringEntry(2, 'is quite good'),
                    tokens.TokenizedStringEntry(3, 'they say'),
                ]
            )
        )

    def test_found_binary_token(self) -> None:
        self.assertEqual(
            'on pizza',
            decode_optionally_tokenized(self.detok, b'\x01\x00\x00\x00'),
        )

    def test_missing_binary_token(self) -> None:
        # The token is not in the database, so the raw bytes come back as
        # '$'-prefixed Base64.
        self.assertEqual(
            '$' + base64.b64encode(b'\xD5\x8A\xF9\x2A\x8A').decode(),
            decode_optionally_tokenized(self.detok, b'\xD5\x8A\xF9\x2A\x8A'),
        )

    def test_found_b64_token(self) -> None:
        b64_bytes = b'$' + base64.b64encode(b'\x03\x00\x00\x00')
        self.assertEqual(
            'they say', decode_optionally_tokenized(self.detok, b64_bytes)
        )

    def test_missing_b64_token(self) -> None:
        # Unknown Base64 tokens are returned unchanged.
        b64_bytes = b'$' + base64.b64encode(b'\xD5\x8A\xF9\x2A\x8A')
        self.assertEqual(
            b64_bytes.decode(),
            decode_optionally_tokenized(self.detok, b64_bytes),
        )

    def test_found_alternate_prefix(self) -> None:
        # A non-default prefix is honored when passed explicitly.
        b64_bytes = b'~' + base64.b64encode(b'\x00\x00\x00\x00')
        self.assertEqual(
            'cheese', decode_optionally_tokenized(self.detok, b64_bytes, b'~')
        )

    def test_missing_alternate_prefix(self) -> None:
        # The data uses '~' but the decoder expects '^', so the known token is
        # not recognized and the input is returned unchanged.
        b64_bytes = b'~' + base64.b64encode(b'\x02\x00\x00\x00')
        self.assertEqual(
            b64_bytes.decode(),
            decode_optionally_tokenized(self.detok, b64_bytes, b'^'),
        )


if __name__ == '__main__':
    unittest.main()