#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests decoding a proto with tokenized fields."""

import base64
import unittest

from pw_tokenizer_tests.detokenize_proto_test_pb2 import TheMessage

from pw_tokenizer import detokenize, encode, tokens
from pw_tokenizer.proto import detokenize_fields, decode_optionally_tokenized

_DATABASE = tokens.Database(
    [
        tokens.TokenizedStringEntry(0xAABBCCDD, "Luke, we're gonna have %s"),
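        # "$oeQAAA==" is the prefixed Base64 form of token 0x0000E4A1, so the
        # entry below contains a nested token that detokenizes recursively.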
        tokens.TokenizedStringEntry(0x12345678, "This string has a $oeQAAA=="),
        tokens.TokenizedStringEntry(0x0000E4A1, "recursive token"),
    ]
)
_DETOKENIZER = detokenize.Detokenizer(_DATABASE)


class TestDetokenizeProtoFields(unittest.TestCase):
    """Tests detokenizing optionally tokenized proto fields."""

    def test_plain_text(self) -> None:
        proto = TheMessage(message=b'boring conversation anyway!')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b'boring conversation anyway!')

    def test_binary(self) -> None:
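        # Token 0xAABBCCDD in little-endian byte order, followed by the string
        # argument "company" (0x07 is the argument's length prefix).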
        proto = TheMessage(message=b'\xDD\xCC\xBB\xAA\x07company')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b"Luke, we're gonna have company")

    def test_binary_missing_arguments(self) -> None:
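        # The token is present but the %s argument bytes are missing, so the
        # format string is returned unfilled.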
        proto = TheMessage(message=b'\xDD\xCC\xBB\xAA')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b"Luke, we're gonna have %s")

    def test_recursive_binary(self) -> None:
        proto = TheMessage(message=b'\x78\x56\x34\x12')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b"This string has a recursive token")

    def test_base64(self) -> None:
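        # prefixed_base64() wraps the binary tokenized message in '$'-prefixed
        # Base64 text.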
        base64_msg = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x07company')
        proto = TheMessage(message=base64_msg.encode())
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b"Luke, we're gonna have company")

    def test_recursive_base64(self) -> None:
        base64_msg = encode.prefixed_base64(b'\x78\x56\x34\x12')
        proto = TheMessage(message=base64_msg.encode())
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(proto.message, b"This string has a recursive token")

    def test_plain_text_with_prefixed_base64(self) -> None:
        base64_msg = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x09pancakes!')
        proto = TheMessage(message=f'Good morning, {base64_msg}'.encode())
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(
            proto.message, b"Good morning, Luke, we're gonna have pancakes!"
        )

    def test_unknown_token_not_utf8(self) -> None:
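        # Unrecognized binary data that is not valid UTF-8 is re-encoded as
        # prefixed Base64 rather than passed through.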
        proto = TheMessage(message=b'\xFE\xED\xF0\x0D')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(
            proto.message.decode(), encode.prefixed_base64(b'\xFE\xED\xF0\x0D')
        )

    def test_only_control_characters(self) -> None:
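        # Valid UTF-8, but only unprintable control characters, so the bytes
        # are also re-encoded as prefixed Base64.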
        proto = TheMessage(message=b'\1\2\3\4')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(
            proto.message.decode(), encode.prefixed_base64(b'\1\2\3\4')
        )


class TestDecodeOptionallyTokenized(unittest.TestCase):
    """Tests optional detokenization directly."""

    def setUp(self):
        self.detok = detokenize.Detokenizer(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(0, 'cheese'),
                    tokens.TokenizedStringEntry(1, 'on pizza'),
                    tokens.TokenizedStringEntry(2, 'is quite good'),
                    tokens.TokenizedStringEntry(3, 'they say'),
                ]
            )
        )

    def test_found_binary_token(self):
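        # Token 1 ("on pizza") encoded as four little-endian bytes.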
        self.assertEqual(
            'on pizza',
            decode_optionally_tokenized(self.detok, b'\x01\x00\x00\x00'),
        )

    def test_missing_binary_token(self):
        self.assertEqual(
            '$' + base64.b64encode(b'\xD5\x8A\xF9\x2A\x8A').decode(),
            decode_optionally_tokenized(self.detok, b'\xD5\x8A\xF9\x2A\x8A'),
        )

    def test_found_b64_token(self):
        b64_bytes = b'$' + base64.b64encode(b'\x03\x00\x00\x00')
        self.assertEqual(
            'they say', decode_optionally_tokenized(self.detok, b64_bytes)
        )

    def test_missing_b64_token(self):
        b64_bytes = b'$' + base64.b64encode(b'\xD5\x8A\xF9\x2A\x8A')
        self.assertEqual(
            b64_bytes.decode(),
            decode_optionally_tokenized(self.detok, b64_bytes),
        )

    def test_found_alternate_prefix(self):
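        # A custom '~' prefix is used in place of the default '$'.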
        b64_bytes = b'~' + base64.b64encode(b'\x00\x00\x00\x00')
        self.assertEqual(
            'cheese', decode_optionally_tokenized(self.detok, b64_bytes, b'~')
        )

    def test_missing_alternate_prefix(self):
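        # The message uses the '~' prefix but the decoder expects '^', so the
        # text is returned unchanged.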
        b64_bytes = b'~' + base64.b64encode(b'\x02\x00\x00\x00')
        self.assertEqual(
            b64_bytes.decode(),
            decode_optionally_tokenized(self.detok, b64_bytes, b'^'),
        )


if __name__ == '__main__':
    unittest.main()