#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
15r"""Decodes and detokenizes strings from binary or Base64 input.
16
The main class provided by this module is the Detokenizer class. To use it,
construct it with the path to an ELF or CSV database, a tokens.Database,
or a file object for an ELF file or CSV. Then, call the detokenize method with
encoded messages, one at a time. The detokenize method returns a
DetokenizedString object with the result.

For example,

  from pw_tokenizer import detokenize

  detok = detokenize.Detokenizer('path/to/my/image.elf')
  print(detok.detokenize(b'\x12\x34\x56\x78\x03hi!'))

This module also provides a command line interface for decoding and detokenizing
messages from a file or stdin.
"""

import argparse
import base64
import binascii
import io
import logging
import os
from pathlib import Path
import re
import string
import struct
import sys
import time
from typing import (AnyStr, BinaryIO, Callable, Dict, List, Iterable, IO,
                    Iterator, Match, NamedTuple, Optional, Pattern, Tuple,
                    Union)

try:
    from pw_tokenizer import database, decode, encode, tokens
except ImportError:
    # Append this path to the module search path to allow running this module
    # without installing the pw_tokenizer package.
    sys.path.append(os.path.dirname(os.path.dirname(
        os.path.abspath(__file__))))
    from pw_tokenizer import database, decode, encode, tokens

_LOG = logging.getLogger('pw_tokenizer')

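# Encoded tokens are little-endian 32-bit unsigned integers; for example,
# ENCODED_TOKEN.pack(0x12345678) yields b'\x78\x56\x34\x12'.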
ENCODED_TOKEN = struct.Struct('<I')
BASE64_PREFIX = encode.BASE64_PREFIX.encode()
DEFAULT_RECURSION = 9


class DetokenizedString:
    """A detokenized string, with all results if there are collisions."""
    def __init__(self,
                 token: Optional[int],
                 format_string_entries: Iterable[tuple],
                 encoded_message: bytes,
                 show_errors: bool = False):
        self.token = token
        self.encoded_message = encoded_message
        self._show_errors = show_errors

        self.successes: List[decode.FormattedString] = []
        self.failures: List[decode.FormattedString] = []

        decode_attempts: List[Tuple[Tuple, decode.FormattedString]] = []

        for entry, fmt in format_string_entries:
            result = fmt.format(encoded_message[ENCODED_TOKEN.size:],
                                show_errors)
            decode_attempts.append((result.score(entry.date_removed), result))

        # Sort the attempts by the score so the most likely results are first.
        decode_attempts.sort(key=lambda value: value[0], reverse=True)

        # Split out the successful decodes from the failures.
        for score, result in decode_attempts:
            if score[0]:
                self.successes.append(result)
            else:
                self.failures.append(result)

    def ok(self) -> bool:
        """True if exactly one string decoded the arguments successfully."""
        return len(self.successes) == 1

    def matches(self) -> List[decode.FormattedString]:
        """Returns the strings that matched the token, best matches first."""
        return self.successes + self.failures

    def best_result(self) -> Optional[decode.FormattedString]:
        """Returns the string and args for the most likely decoded string."""
        for string_and_args in self.matches():
            return string_and_args

        return None

    def error_message(self) -> str:
        """If detokenization failed, returns a descriptive message."""
        if self.ok():
            return ''

        if not self.matches():
            if self.token is None:
                return 'missing token'

            return 'unknown token {:08x}'.format(self.token)

        if len(self.matches()) == 1:
            return 'decoding failed for {!r}'.format(self.matches()[0].value)

        return '{} matches'.format(len(self.matches()))

    def __str__(self) -> str:
        """Returns the string for the most likely result."""
        result = self.best_result()
        if result:
            return result[0]

        if self._show_errors:
            return '<[ERROR: {}|{!r}]>'.format(self.error_message(),
                                               self.encoded_message)

        # Display the string as prefixed Base64 if it cannot be decoded.
        return encode.prefixed_base64(self.encoded_message)

    def __repr__(self) -> str:
        if self.ok():
            message = repr(str(self))
        else:
            message = 'ERROR: {}|{!r}'.format(self.error_message(),
                                              self.encoded_message)

        return '{}({})'.format(type(self).__name__, message)


class _TokenizedFormatString(NamedTuple):
    entry: tokens.TokenizedStringEntry
    format: decode.FormatString


class Detokenizer:
    """Main detokenization class; detokenizes strings and caches results."""
    def __init__(self, *token_database_or_elf, show_errors: bool = False):
        """Decodes and detokenizes binary messages.

        Args:
          *token_database_or_elf: a path or file object for an ELF or CSV
              database, a tokens.Database, or an elf_reader.Elf
          show_errors: if True, an error message is used in place of the %
              conversion specifier when an argument fails to decode
        """
        self.show_errors = show_errors

        # Cache FormatStrings for faster lookup & formatting.
        self._cache: Dict[int, List[_TokenizedFormatString]] = {}

        self._initialize_database(token_database_or_elf)

    def _initialize_database(self, token_sources: Iterable) -> None:
        self.database = database.load_token_database(*token_sources)
        self._cache.clear()

    def lookup(self, token: int) -> List[_TokenizedFormatString]:
        """Returns (TokenizedStringEntry, FormatString) list for matches."""
        try:
            return self._cache[token]
        except KeyError:
            format_strings = [
                _TokenizedFormatString(entry, decode.FormatString(str(entry)))
                for entry in self.database.token_to_entries[token]
            ]
            self._cache[token] = format_strings
            return format_strings

    def detokenize(self, encoded_message: bytes) -> DetokenizedString:
        """Decodes and detokenizes a message as a DetokenizedString."""
        if len(encoded_message) < ENCODED_TOKEN.size:
            return DetokenizedString(None, (), encoded_message,
                                     self.show_errors)

        token, = ENCODED_TOKEN.unpack_from(encoded_message)
        return DetokenizedString(token, self.lookup(token), encoded_message,
                                 self.show_errors)
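    # Example: detokenizing a manually assembled message (a sketch; the
    # token value 0x12345678 and the database path are hypothetical):
    #
    #   detok = Detokenizer('tokens.csv')
    #   message = ENCODED_TOKEN.pack(0x12345678) + b'\x03hi!'
    #   result = detok.detokenize(message)
    #   if result.ok():
    #       print(result)  # The single successful decode.
    #   else:
    #       print(result.error_message())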

    def detokenize_base64(self,
                          data: AnyStr,
                          prefix: Union[str, bytes] = BASE64_PREFIX,
                          recursion: int = DEFAULT_RECURSION) -> AnyStr:
        """Decodes and replaces prefixed Base64 messages in the provided data.

        Args:
          data: the binary data to decode
          prefix: one-character byte string that signals the start of a message
          recursion: how many levels to recursively decode

        Returns:
          copy of the data with all recognized tokens decoded
        """
        output = io.BytesIO()
        self.detokenize_base64_to_file(data, output, prefix, recursion)
        result = output.getvalue()
        return result.decode() if isinstance(data, str) else result
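    # Example (a sketch; the '$czRCNwGB' message and the database path are
    # hypothetical):
    #
    #   detok = Detokenizer('tokens.csv')
    #   print(detok.detokenize_base64('Log: $czRCNwGB'))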

    def detokenize_base64_to_file(self,
                                  data: Union[str, bytes],
                                  output: BinaryIO,
                                  prefix: Union[str, bytes] = BASE64_PREFIX,
                                  recursion: int = DEFAULT_RECURSION) -> None:
        """Decodes prefixed Base64 messages in data; decodes to output file."""
        data = data.encode() if isinstance(data, str) else data
        prefix = prefix.encode() if isinstance(prefix, str) else prefix

        output.write(
            _base64_message_regex(prefix).sub(
                self._detokenize_prefixed_base64(prefix, recursion), data))

    def detokenize_base64_live(self,
                               input_file: BinaryIO,
                               output: BinaryIO,
                               prefix: Union[str, bytes] = BASE64_PREFIX,
                               recursion: int = DEFAULT_RECURSION) -> None:
        """Reads chars one-at-a-time, decoding messages; SLOW for big files."""
        prefix_bytes = prefix.encode() if isinstance(prefix, str) else prefix

        base64_message = _base64_message_regex(prefix_bytes)

        def transform(data: bytes) -> bytes:
            return base64_message.sub(
                self._detokenize_prefixed_base64(prefix_bytes, recursion),
                data)

        for message in PrefixedMessageDecoder(
                prefix,
                string.ascii_letters + string.digits + '+/-_=').transform(
                    input_file, transform):
            output.write(message)

            # Flush each line to prevent delays when piping between processes.
            if b'\n' in message:
                output.flush()

    def _detokenize_prefixed_base64(
            self, prefix: bytes,
            recursion: int) -> Callable[[Match[bytes]], bytes]:
        """Returns a function that decodes prefixed Base64."""
        def decode_and_detokenize(match: Match[bytes]) -> bytes:
            """Decodes prefixed base64 with this detokenizer."""
            original = match.group(0)

            try:
                detokenized_string = self.detokenize(
                    base64.b64decode(original[1:], validate=True))
                if detokenized_string.matches():
                    result = str(detokenized_string).encode()

                    if recursion > 0 and original != result:
                        result = self.detokenize_base64(
                            result, prefix, recursion - 1)

                    return result
            except binascii.Error:
                pass

            return original

        return decode_and_detokenize


_PathOrFile = Union[IO, str, Path]


class AutoUpdatingDetokenizer(Detokenizer):
    """Loads and updates a detokenizer from database paths."""
    class _DatabasePath:
        """Tracks the modified time of a path or file object."""
        def __init__(self, path: _PathOrFile) -> None:
            self.path = path if isinstance(path, (str, Path)) else path.name
            self._modified_time: Optional[float] = self._last_modified_time()

        def updated(self) -> bool:
            """True if the path has been updated since the last call."""
            modified_time = self._last_modified_time()
            if modified_time is None or modified_time == self._modified_time:
                return False

            self._modified_time = modified_time
            return True

        def _last_modified_time(self) -> Optional[float]:
            try:
                return os.path.getmtime(self.path)
            except FileNotFoundError:
                return None

        def load(self) -> tokens.Database:
            try:
                return database.load_token_database(self.path)
            except FileNotFoundError:
                return database.load_token_database()

    def __init__(self,
                 *paths_or_files: _PathOrFile,
                 min_poll_period_s: float = 1.0) -> None:
        self.paths = tuple(self._DatabasePath(path) for path in paths_or_files)
        self.min_poll_period_s = min_poll_period_s
        self._last_checked_time: float = time.time()
        super().__init__(*(path.load() for path in self.paths))

    def _reload_if_changed(self) -> None:
        if time.time() - self._last_checked_time >= self.min_poll_period_s:
            self._last_checked_time = time.time()

            if any(path.updated() for path in self.paths):
                _LOG.info('Changes detected; reloading token database')
                self._initialize_database(path.load() for path in self.paths)

    def lookup(self, token: int) -> List[_TokenizedFormatString]:
        self._reload_if_changed()
        return super().lookup(token)

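# Example: an AutoUpdatingDetokenizer reloads its databases when the files
# change on disk (a sketch; the path and read_message() are hypothetical):
#
#   detok = AutoUpdatingDetokenizer('tokens.csv', min_poll_period_s=5.0)
#   while True:
#       print(detok.detokenize(read_message()))  # Uses the latest database.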

class PrefixedMessageDecoder:
    """Parses messages that start with a prefix character from a byte stream."""
    def __init__(self, prefix: Union[str, bytes], chars: Union[str, bytes]):
        """Parses prefixed messages.

        Args:
          prefix: one character that signifies the start of a message
          chars: characters allowed in a message
        """
        self._prefix = prefix.encode() if isinstance(prefix, str) else prefix

        if isinstance(chars, str):
            chars = chars.encode()

        # Store the valid message bytes as a set of binary strings.
        self._message_bytes = frozenset(chars[i:i + 1]
                                        for i in range(len(chars)))

        if len(self._prefix) != 1 or self._prefix in self._message_bytes:
            raise ValueError(
                'Invalid prefix {!r}: the prefix must be a single '
                'character that is not a valid message character.'.format(
                    prefix))

        self.data = bytearray()

    def _read_next(self, fd: BinaryIO) -> Tuple[bytes, int]:
        """Returns the next character and its index."""
        char = fd.read(1)
        index = len(self.data)
        self.data += char
        return char, index

    def read_messages(self,
                      binary_fd: BinaryIO) -> Iterator[Tuple[bool, bytes]]:
        """Parses prefixed messages; yields (is_message, contents) chunks."""
        message_start = None

        while True:
            # This reads the file character-by-character. Non-message
            # characters are yielded right away; message characters are
            # grouped.
            char, index = self._read_next(binary_fd)

            # If in a message, keep reading until the message completes.
            if message_start is not None:
                if char in self._message_bytes:
                    continue

                yield True, self.data[message_start:index]
                message_start = None

            # Handle a non-message character.
            if not char:
                return

            if char == self._prefix:
                message_start = index
            else:
                yield False, char

    def transform(self, binary_fd: BinaryIO,
                  transform: Callable[[bytes], bytes]) -> Iterator[bytes]:
        """Yields the file with a transformation applied to the messages."""
        for is_message, chunk in self.read_messages(binary_fd):
            yield transform(chunk) if is_message else chunk

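# Example of how PrefixedMessageDecoder chunks a stream (a sketch using a toy
# message alphabet):
#
#   decoder = PrefixedMessageDecoder('$', 'abc')
#   list(decoder.read_messages(io.BytesIO(b'x$abcy')))
#   # -> [(False, b'x'), (True, bytearray(b'$abc')), (False, b'y')]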

def _base64_message_regex(prefix: bytes) -> Pattern[bytes]:
    """Returns a regular expression for prefixed base64 tokenized strings."""
    return re.compile(
        # Base64 tokenized strings start with the prefix character ($)
        re.escape(prefix) + (
            # Tokenized strings contain 0 or more blocks of four Base64 chars.
            br'(?:[A-Za-z0-9+/\-_]{4})*'
            # The last block of 4 chars may have one or two padding chars (=).
            br'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?'))

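# For example (a sketch; the Base64 payload is hypothetical):
#
#   regex = _base64_message_regex(b'$')
#   regex.findall(b'before $czRCNw== after')  # -> [b'$czRCNw==']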

# TODO(hepler): Remove this unnecessary function.
def detokenize_base64(detokenizer: Detokenizer,
                      data: bytes,
                      prefix: Union[str, bytes] = BASE64_PREFIX,
                      recursion: int = DEFAULT_RECURSION) -> bytes:
    """Alias for detokenizer.detokenize_base64 for backwards compatibility."""
    return detokenizer.detokenize_base64(data, prefix, recursion)


def _follow_and_detokenize_file(detokenizer: Detokenizer,
                                file: BinaryIO,
                                output: BinaryIO,
                                prefix: Union[str, bytes],
                                poll_period_s: float = 0.01) -> None:
    """Polls a file to detokenize it and any appended data."""

    try:
        while True:
            data = file.read()
            if data:
                detokenizer.detokenize_base64_to_file(data, output, prefix)
                output.flush()
            else:
                time.sleep(poll_period_s)
    except KeyboardInterrupt:
        pass


def _handle_base64(databases, input_file: BinaryIO, output: BinaryIO,
                   prefix: str, show_errors: bool, follow: bool) -> None:
    """Handles the base64 command line option."""
    # argparse.FileType doesn't correctly handle - for binary files.
    if input_file is sys.stdin:
        input_file = sys.stdin.buffer

    if output is sys.stdout:
        output = sys.stdout.buffer

    detokenizer = Detokenizer(tokens.Database.merged(*databases),
                              show_errors=show_errors)

    if follow:
        _follow_and_detokenize_file(detokenizer, input_file, output, prefix)
    elif input_file.seekable():
        # Process seekable files all at once, which is MUCH faster.
        detokenizer.detokenize_base64_to_file(input_file.read(), output,
                                              prefix)
    else:
        # For non-seekable inputs (e.g. pipes), read one character at a time.
        detokenizer.detokenize_base64_live(input_file, output, prefix)


def _parse_args() -> argparse.Namespace:
    """Parses and returns command line arguments."""

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.set_defaults(handler=lambda **_: parser.print_help())

    subparsers = parser.add_subparsers(help='Encoding of the input.')

    base64_help = 'Detokenize Base64-encoded data from a file or stdin.'
    subparser = subparsers.add_parser(
        'base64',
        description=base64_help,
        parents=[database.token_databases_parser()],
        help=base64_help)
    subparser.set_defaults(handler=_handle_base64)
    subparser.add_argument(
        '-i',
        '--input',
        dest='input_file',
        type=argparse.FileType('rb'),
        default=sys.stdin.buffer,
        help='The file from which to read; provide - or omit for stdin.')
    subparser.add_argument(
        '-f',
        '--follow',
        action='store_true',
        help=('Detokenize data appended to input_file as it grows; similar to '
              'tail -f.'))
    subparser.add_argument('-o',
                           '--output',
                           type=argparse.FileType('wb'),
                           default=sys.stdout.buffer,
                           help=('The file to which to write the output; '
                                 'provide - or omit for stdout.'))
    subparser.add_argument(
        '-p',
        '--prefix',
        default=BASE64_PREFIX,
        help=('The one-character prefix that signals the start of a '
              'Base64-encoded message. (default: $)'))
    subparser.add_argument(
        '-s',
        '--show_errors',
        action='store_true',
        help=('Show error messages instead of conversion specifiers when '
              'arguments cannot be decoded.'))

    return parser.parse_args()


def main() -> int:
    args = _parse_args()

    handler = args.handler
    del args.handler

    handler(**vars(args))
    return 0


if __name__ == '__main__':
    if sys.version_info[0] < 3:
        sys.exit('ERROR: The detokenizer command line tools require Python 3.')
    sys.exit(main())