#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
r"""Decodes and detokenizes strings from binary or Base64 input.

The main class provided by this module is the Detokenizer class. To use it,
construct it with the path to an ELF or CSV database, a tokens.Database,
or a file object for an ELF file or CSV. Then, call the detokenize method with
encoded messages, one at a time. The detokenize method returns a
DetokenizedString object with the result.

For example,

  from pw_tokenizer import detokenize

  detok = detokenize.Detokenizer('path/to/my/image.elf')
  print(detok.detokenize(b'\x12\x34\x56\x78\x03hi!'))

This module also provides a command line interface for decoding and
detokenizing messages from a file or stdin.
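
For example, a base64 invocation might look like the following (an illustrative
sketch; the paths are hypothetical and the token database arguments are defined
by database.token_databases_parser, so run the command with --help for the
exact form):

  python detokenize.py base64 path/to/tokens.csv -i log.txt -o decoded.txt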
32"""
33
34import argparse
35import base64
36import binascii
37import io
38import logging
39import os
40from pathlib import Path
41import re
42import string
43import struct
44import sys
45import time
46from typing import (
47    AnyStr,
48    BinaryIO,
49    Callable,
50    Dict,
51    List,
52    Iterable,
53    Iterator,
54    Match,
55    NamedTuple,
56    Optional,
57    Pattern,
58    Tuple,
59    Union,
60)
61
62try:
63    from pw_tokenizer import database, decode, encode, tokens
64except ImportError:
65    # Append this path to the module search path to allow running this module
66    # without installing the pw_tokenizer package.
67    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
68    from pw_tokenizer import database, decode, encode, tokens
69
70_LOG = logging.getLogger('pw_tokenizer')
71
72ENCODED_TOKEN = struct.Struct('<I')
73BASE64_PREFIX = encode.BASE64_PREFIX.encode()
74DEFAULT_RECURSION = 9
75
76_RawIO = Union[io.RawIOBase, BinaryIO]
77
78
79class DetokenizedString:
80    """A detokenized string, with all results if there are collisions."""
81
82    def __init__(
83        self,
84        token: Optional[int],
85        format_string_entries: Iterable[tuple],
86        encoded_message: bytes,
87        show_errors: bool = False,
88    ):
89        self.token = token
90        self.encoded_message = encoded_message
91        self._show_errors = show_errors
92
93        self.successes: List[decode.FormattedString] = []
94        self.failures: List[decode.FormattedString] = []
95
96        decode_attempts: List[Tuple[Tuple, decode.FormattedString]] = []
97
98        for entry, fmt in format_string_entries:
99            result = fmt.format(
100                encoded_message[ENCODED_TOKEN.size :], show_errors
101            )
102            decode_attempts.append((result.score(entry.date_removed), result))
103
104        # Sort the attempts by the score so the most likely results are first.
105        decode_attempts.sort(key=lambda value: value[0], reverse=True)
106
107        # Split out the successesful decodes from the failures.
108        for score, result in decode_attempts:
109            if score[0]:
110                self.successes.append(result)
111            else:
112                self.failures.append(result)
113
114    def ok(self) -> bool:
115        """True if exactly one string decoded the arguments successfully."""
116        return len(self.successes) == 1
117
118    def matches(self) -> List[decode.FormattedString]:
119        """Returns the strings that matched the token, best matches first."""
120        return self.successes + self.failures
121
122    def best_result(self) -> Optional[decode.FormattedString]:
123        """Returns the string and args for the most likely decoded string."""
124        for string_and_args in self.matches():
125            return string_and_args
126
127        return None
128
129    def error_message(self) -> str:
130        """If detokenization failed, returns a descriptive message."""
131        if self.ok():
132            return ''
133
134        if not self.matches():
135            if self.token is None:
136                return 'missing token'
137
138            return 'unknown token {:08x}'.format(self.token)
139
140        if len(self.matches()) == 1:
141            return 'decoding failed for {!r}'.format(self.matches()[0].value)
142
143        return '{} matches'.format(len(self.matches()))
144
145    def __str__(self) -> str:
146        """Returns the string for the most likely result."""
147        result = self.best_result()
148        if result:
149            return result[0]
150
151        if self._show_errors:
152            return '<[ERROR: {}|{!r}]>'.format(
153                self.error_message(), self.encoded_message
154            )
155
156        # Display the string as prefixed Base64 if it cannot be decoded.
157        return encode.prefixed_base64(self.encoded_message)
158
159    def __repr__(self) -> str:
160        if self.ok():
161            message = repr(str(self))
162        else:
163            message = 'ERROR: {}|{!r}'.format(
164                self.error_message(), self.encoded_message
165            )
166
167        return '{}({})'.format(type(self).__name__, message)
168
169
170class _TokenizedFormatString(NamedTuple):
171    entry: tokens.TokenizedStringEntry
172    format: decode.FormatString
173
174
175class Detokenizer:
176    """Main detokenization class; detokenizes strings and caches results."""
177
178    def __init__(self, *token_database_or_elf, show_errors: bool = False):
179        """Decodes and detokenizes binary messages.
180
181        Args:
182          *token_database_or_elf: a path or file object for an ELF or CSV
183              database, a tokens.Database, or an elf_reader.Elf
184          show_errors: if True, an error message is used in place of the %
185              conversion specifier when an argument fails to decode
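
        For example (with illustrative paths; any mix of the sources described
        above may be passed):

          detok = Detokenizer('device_image.elf')
          detok = Detokenizer('tokens.csv', 'extra.csv', show_errors=True)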
        """
        self.show_errors = show_errors

        # Cache FormatStrings for faster lookup & formatting.
        self._cache: Dict[int, List[_TokenizedFormatString]] = {}

        self._initialize_database(token_database_or_elf)

    def _initialize_database(self, token_sources: Iterable) -> None:
        self.database = database.load_token_database(*token_sources)
        self._cache.clear()

    def lookup(self, token: int) -> List[_TokenizedFormatString]:
        """Returns (TokenizedStringEntry, FormatString) list for matches."""
        try:
            return self._cache[token]
        except KeyError:
            format_strings = [
                _TokenizedFormatString(entry, decode.FormatString(str(entry)))
                for entry in self.database.token_to_entries[token]
            ]
            self._cache[token] = format_strings
            return format_strings

    def detokenize(self, encoded_message: bytes) -> DetokenizedString:
        """Decodes and detokenizes a message as a DetokenizedString."""
        if not encoded_message:
            return DetokenizedString(
                None, (), encoded_message, self.show_errors
            )

        # Pad messages smaller than ENCODED_TOKEN.size with zeroes to support
        # tokens smaller than a uint32. Messages with arguments must always use
        # a full 32-bit token.
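        # For example, a two-byte message b'\x12\x34' is padded to
        # b'\x12\x34\x00\x00' and looked up as the little-endian token
        # 0x00003412.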
        missing_token_bytes = ENCODED_TOKEN.size - len(encoded_message)
        if missing_token_bytes > 0:
            encoded_message += b'\0' * missing_token_bytes

        (token,) = ENCODED_TOKEN.unpack_from(encoded_message)
        return DetokenizedString(
            token, self.lookup(token), encoded_message, self.show_errors
        )

    def detokenize_base64(
        self,
        data: AnyStr,
        prefix: Union[str, bytes] = BASE64_PREFIX,
        recursion: int = DEFAULT_RECURSION,
    ) -> AnyStr:
        """Decodes and replaces prefixed Base64 messages in the provided data.

        Args:
          data: the binary data to decode
          prefix: one-character byte string that signals the start of a message
          recursion: how many levels to recursively decode

        Returns:
          copy of the data with all recognized tokens decoded
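
        Example (illustrative; the Base64 payload below is hypothetical):

          detok.detokenize_base64(b'Log line with a $pEVTYQ== message')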
        """
        output = io.BytesIO()
        self.detokenize_base64_to_file(data, output, prefix, recursion)
        result = output.getvalue()
        return result.decode() if isinstance(data, str) else result

    def detokenize_base64_to_file(
        self,
        data: Union[str, bytes],
        output: BinaryIO,
        prefix: Union[str, bytes] = BASE64_PREFIX,
        recursion: int = DEFAULT_RECURSION,
    ) -> None:
        """Decodes prefixed Base64 messages in data, writing to output."""
        data = data.encode() if isinstance(data, str) else data
        prefix = prefix.encode() if isinstance(prefix, str) else prefix

        output.write(
            _base64_message_regex(prefix).sub(
                self._detokenize_prefixed_base64(prefix, recursion), data
            )
        )

    def detokenize_base64_live(
        self,
        input_file: _RawIO,
        output: BinaryIO,
        prefix: Union[str, bytes] = BASE64_PREFIX,
        recursion: int = DEFAULT_RECURSION,
    ) -> None:
        """Reads chars one-at-a-time, decoding messages; SLOW for big files."""
        prefix_bytes = prefix.encode() if isinstance(prefix, str) else prefix

        base64_message = _base64_message_regex(prefix_bytes)

        def transform(data: bytes) -> bytes:
            return base64_message.sub(
                self._detokenize_prefixed_base64(prefix_bytes, recursion), data
            )

        for message in PrefixedMessageDecoder(
            prefix, string.ascii_letters + string.digits + '+/-_='
        ).transform(input_file, transform):
            output.write(message)

            # Flush each line to prevent delays when piping between processes.
            if b'\n' in message:
                output.flush()

    def _detokenize_prefixed_base64(
        self, prefix: bytes, recursion: int
    ) -> Callable[[Match[bytes]], bytes]:
        """Returns a function that decodes prefixed Base64."""

        def decode_and_detokenize(match: Match[bytes]) -> bytes:
            """Decodes prefixed base64 with this detokenizer."""
            original = match.group(0)

            try:
                detokenized_string = self.detokenize(
                    base64.b64decode(original[1:], validate=True)
                )
                if detokenized_string.matches():
                    result = str(detokenized_string).encode()

                    if recursion > 0 and original != result:
                        result = self.detokenize_base64(
                            result, prefix, recursion - 1
                        )

                    return result
            except binascii.Error:
                pass

            return original

        return decode_and_detokenize


_PathOrStr = Union[Path, str]


# TODO(b/265334753): Reuse this function in database.py:LoadTokenDatabases
def _parse_domain(path: _PathOrStr) -> Tuple[Path, Optional[Pattern[str]]]:
    """Extracts an optional domain regex pattern suffix from a path."""

    if isinstance(path, Path):
        path = str(path)

    delimiters = path.count('#')

    if delimiters == 0:
        return Path(path), None

    if delimiters == 1:
        path, domain = path.split('#')
        return Path(path), re.compile(domain)

    raise ValueError(
        f'Too many # delimiters. Expected 0 or 1, found {delimiters}'
    )
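
# For example (illustrative values):
#
#   _parse_domain('tokens.csv')          -> (Path('tokens.csv'), None)
#   _parse_domain('image.elf#my_domain') -> (Path('image.elf'),
#                                            re.compile('my_domain'))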


class AutoUpdatingDetokenizer(Detokenizer):
    """Loads and updates a detokenizer from database paths."""

    class _DatabasePath:
        """Tracks the modified time of a path or file object."""

        def __init__(self, path: _PathOrStr) -> None:
            self.path, self.domain = _parse_domain(path)
            self._modified_time: Optional[float] = self._last_modified_time()

        def updated(self) -> bool:
            """True if the path has been updated since the last call."""
            modified_time = self._last_modified_time()
            if modified_time is None or modified_time == self._modified_time:
                return False

            self._modified_time = modified_time
            return True

        def _last_modified_time(self) -> Optional[float]:
            try:
                return os.path.getmtime(self.path)
            except FileNotFoundError:
                return None

        def load(self) -> tokens.Database:
            try:
                if self.domain is not None:
                    return database.load_token_database(
                        self.path, domain=self.domain
                    )
                return database.load_token_database(self.path)
            except FileNotFoundError:
                return database.load_token_database()

    def __init__(
        self, *paths_or_files: _PathOrStr, min_poll_period_s: float = 1.0
    ) -> None:
        self.paths = tuple(self._DatabasePath(path) for path in paths_or_files)
        self.min_poll_period_s = min_poll_period_s
        self._last_checked_time: float = time.time()
        super().__init__(*(path.load() for path in self.paths))

    def _reload_if_changed(self) -> None:
        if time.time() - self._last_checked_time >= self.min_poll_period_s:
            self._last_checked_time = time.time()

            if any(path.updated() for path in self.paths):
                _LOG.info('Changes detected; reloading token database')
                self._initialize_database(path.load() for path in self.paths)

    def lookup(self, token: int) -> List[_TokenizedFormatString]:
        self._reload_if_changed()
        return super().lookup(token)
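
# Illustrative use of AutoUpdatingDetokenizer (the CSV path is hypothetical):
#
#   detok = AutoUpdatingDetokenizer('out/tokens.csv', min_poll_period_s=5.0)
#   detok.detokenize(encoded_message)
#
# Each lookup reloads the database if the backing file has changed, checking
# at most once per min_poll_period_s seconds.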


class PrefixedMessageDecoder:
    """Parses messages that start with a prefix character from a byte stream."""

    def __init__(self, prefix: Union[str, bytes], chars: Union[str, bytes]):
        """Parses prefixed messages.

        Args:
          prefix: one character that signifies the start of a message
          chars: characters allowed in a message
        """
        self._prefix = prefix.encode() if isinstance(prefix, str) else prefix

        if isinstance(chars, str):
            chars = chars.encode()

        # Store the valid message bytes as a set of binary strings.
        self._message_bytes = frozenset(
            chars[i : i + 1] for i in range(len(chars))
        )

        if len(self._prefix) != 1 or self._prefix in self._message_bytes:
            raise ValueError(
                'Invalid prefix {!r}: the prefix must be a single '
                'character that is not a valid message character.'.format(
                    prefix
                )
            )

        self.data = bytearray()

    def _read_next(self, fd: _RawIO) -> Tuple[bytes, int]:
        """Returns the next character and its index."""
        char = fd.read(1) or b''
        index = len(self.data)
        self.data += char
        return char, index

    def read_messages(self, binary_fd: _RawIO) -> Iterator[Tuple[bool, bytes]]:
        """Parses prefixed messages; yields (is_message, contents) chunks."""
        message_start = None

        while True:
            # This reads the file character-by-character. Non-message characters
            # are yielded right away; message characters are grouped.
            char, index = self._read_next(binary_fd)

            # If in a message, keep reading until the message completes.
            if message_start is not None:
                if char in self._message_bytes:
                    continue

                yield True, self.data[message_start:index]
                message_start = None

            # Handle a non-message character.
            if not char:
                return

            if char == self._prefix:
                message_start = index
            else:
                yield False, char

    def transform(
        self, binary_fd: _RawIO, transform: Callable[[bytes], bytes]
    ) -> Iterator[bytes]:
        """Yields the file with a transformation applied to the messages."""
        for is_message, chunk in self.read_messages(binary_fd):
            yield transform(chunk) if is_message else chunk
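
# For example (an illustrative sketch with an arbitrary payload):
#
#   decoder = PrefixedMessageDecoder('$', 'A')
#   list(decoder.read_messages(io.BytesIO(b'x$AAA y')))
#
# produces chunks equal to (False, b'x'), (True, b'$AAA'), (False, b' '), and
# (False, b'y').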


def _base64_message_regex(prefix: bytes) -> Pattern[bytes]:
    """Returns a regular expression for prefixed base64 tokenized strings."""
    return re.compile(
        # Base64 tokenized strings start with the prefix character ($)
        re.escape(prefix)
        + (
            # Tokenized strings contain 0 or more blocks of four Base64 chars.
            br'(?:[A-Za-z0-9+/\-_]{4})*'
            # The last block of 4 chars may have one or two padding chars (=).
            br'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?'
        )
    )
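
# With the default prefix b'$', the pattern above matches spans such as
# b'$pEVTYQ==' (a hypothetical Base64 payload).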


# TODO(hepler): Remove this unnecessary function.
def detokenize_base64(
    detokenizer: Detokenizer,
    data: bytes,
    prefix: Union[str, bytes] = BASE64_PREFIX,
    recursion: int = DEFAULT_RECURSION,
) -> bytes:
    """Alias for detokenizer.detokenize_base64 for backwards compatibility."""
    return detokenizer.detokenize_base64(data, prefix, recursion)


def _follow_and_detokenize_file(
    detokenizer: Detokenizer,
    file: BinaryIO,
    output: BinaryIO,
    prefix: Union[str, bytes],
    poll_period_s: float = 0.01,
) -> None:
    """Polls a file to detokenize it and any appended data."""

    try:
        while True:
            data = file.read()
            if data:
                detokenizer.detokenize_base64_to_file(data, output, prefix)
                output.flush()
            else:
                time.sleep(poll_period_s)
    except KeyboardInterrupt:
        pass


def _handle_base64(
    databases,
    input_file: BinaryIO,
    output: BinaryIO,
    prefix: str,
    show_errors: bool,
    follow: bool,
) -> None:
    """Handles the base64 command line option."""
    # argparse.FileType doesn't correctly handle - for binary files.
    if input_file is sys.stdin:
        input_file = sys.stdin.buffer

    if output is sys.stdout:
        output = sys.stdout.buffer

    detokenizer = Detokenizer(
        tokens.Database.merged(*databases), show_errors=show_errors
    )

    if follow:
        _follow_and_detokenize_file(detokenizer, input_file, output, prefix)
    elif input_file.seekable():
        # Process seekable files all at once, which is MUCH faster.
        detokenizer.detokenize_base64_to_file(input_file.read(), output, prefix)
    else:
        # For non-seekable inputs (e.g. pipes), read one character at a time.
        detokenizer.detokenize_base64_live(input_file, output, prefix)


def _parse_args() -> argparse.Namespace:
    """Parses and returns command line arguments."""

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.set_defaults(handler=lambda **_: parser.print_help())

    subparsers = parser.add_subparsers(help='Encoding of the input.')

    base64_help = 'Detokenize Base64-encoded data from a file or stdin.'
    subparser = subparsers.add_parser(
        'base64',
        description=base64_help,
        parents=[database.token_databases_parser()],
        help=base64_help,
    )
    subparser.set_defaults(handler=_handle_base64)
    subparser.add_argument(
        '-i',
        '--input',
        dest='input_file',
        type=argparse.FileType('rb'),
        default=sys.stdin.buffer,
        help='The file from which to read; provide - or omit for stdin.',
    )
    subparser.add_argument(
        '-f',
        '--follow',
        action='store_true',
        help=(
            'Detokenize data appended to input_file as it grows; similar to '
            'tail -f.'
        ),
    )
    subparser.add_argument(
        '-o',
        '--output',
        type=argparse.FileType('wb'),
        default=sys.stdout.buffer,
        help=(
            'The file to which to write the output; '
            'provide - or omit for stdout.'
        ),
    )
    subparser.add_argument(
        '-p',
        '--prefix',
        default=BASE64_PREFIX,
        help=(
            'The one-character prefix that signals the start of a '
            'Base64-encoded message. (default: $)'
        ),
    )
    subparser.add_argument(
        '-s',
        '--show_errors',
        action='store_true',
        help=(
            'Show error messages instead of conversion specifiers when '
            'arguments cannot be decoded.'
        ),
    )

    return parser.parse_args()


def main() -> int:
    args = _parse_args()

    handler = args.handler
    del args.handler

    handler(**vars(args))
    return 0


if __name__ == '__main__':
    if sys.version_info[0] < 3:
        sys.exit('ERROR: The detokenizer command line tools require Python 3.')
    sys.exit(main())