• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module for extracting object files from a compiled archive (.a) file.
16
17This module provides functionality almost identical to the 'ar -x' command,
18which extracts out all object files from a given archive file. This module
19assumes the archive is in the BSD variant format used in Apple platforms.
20
21See: https://en.wikipedia.org/wiki/Ar_(Unix)#BSD_variant
22
23This extractor has two important differences compared to the 'ar -x' command
24shipped with Xcode.
25
261.  When there are multiple object files with the same name in a given archive,
27    each file is renamed so that they are all correctly extracted without
28    overwriting each other.
29
302.  This module takes the destination directory as an additional parameter.
31
32    Example Usage:
33
34    archive_path = ...
35    dest_dir = ...
36    extract_object_files(archive_path, dest_dir)
37"""
38
39import hashlib
40import io
41import itertools
42import os
43import struct
44from typing import Iterator, Tuple
45
46
def extract_object_files(archive_file: io.BufferedIOBase,
                         dest_dir: str) -> None:
  """Extracts object files from the archive into the destination directory.

  Reads every member of the given BSD variant archive and writes it as a file
  under `dest_dir`. The directory is created first if it does not exist yet.

  When several members share a name, the later ones are written under a
  modified name ('foo.o', 'foo_1.o', 'foo_2.o', ...) so that nothing is
  overwritten. A member whose name and content hash both match an already
  extracted file is skipped.

  Args:
    archive_file: The archive file object pointing at its beginning.
    dest_dir: The destination directory path in which the extracted object files
      will be written. The directory will be created if it does not exist.
  """
  # exist_ok=True removes the window between a separate existence check and
  # the creation, so a concurrent creator of the same directory is harmless.
  os.makedirs(dest_dir, exist_ok=True)

  _check_archive_signature(archive_file)

  # Maps each file name written so far to the hash of its content, in order to
  # handle duplicate names correctly.
  extracted_files = {}

  for name, file_content in _extract_next_file(archive_file):
    digest = hashlib.md5(file_content).digest()

    # Walk the candidate names ('foo.o', 'foo_1.o', 'foo_2.o', ...) until one
    # is either unused or already holds identical content. The candidate
    # generator is endless, so this loop always terminates through `break`.
    for final_name in _generate_modified_filenames(name):
      known_digest = extracted_files.get(final_name)

      if known_digest is None:
        extracted_files[final_name] = digest

        # NOTE(review): the member name is joined to dest_dir as-is; names
        # containing path separators are not sanitized here.
        with open(os.path.join(dest_dir, final_name), 'wb') as object_file:
          object_file.write(file_content)
        break

      # Skip writing this file if the same file was already extracted.
      if known_digest == digest:
        break
90
91
def _generate_modified_filenames(filename: str) -> Iterator[str]:
  """Yields candidate filenames: the original, then numbered variants.

  The very first value is the filename exactly as given. Every following value
  inserts an increasing counter, starting at 1, between the base name and the
  extension: 'foo.o' -> 'foo_1.o', 'foo_2.o', and so on without end.

  Args:
    filename: The original filename to be modified.

  Yields:
    The original filename and then modified filenames with incremental suffix.
  """
  yield filename

  stem, extension = os.path.splitext(filename)
  counter = 1
  # Endless by design; the caller stops consuming once a name is acceptable.
  while True:
    yield f'{stem}_{counter}{extension}'
    counter += 1
109
110
def _check_archive_signature(archive_file: io.BufferedIOBase) -> None:
  """Verifies that the file starts with the archive magic string.

  Consumes the leading signature bytes from the file object, so on success the
  cursor is left at the first file header section.

  Args:
    archive_file: The archive file object pointing at its beginning.

  Raises:
    RuntimeError: The archive signature is invalid.
  """
  expected_magic = b'!<arch>\n'
  if archive_file.read(len(expected_magic)) == expected_magic:
    return
  raise RuntimeError('Invalid archive file format.')
126
127
def _extract_next_file(
    archive_file: io.BufferedIOBase) -> Iterator[Tuple[str, bytes]]:
  """Extracts the next available file from the archive.

  Reads the next available file header section and yields its filename and
  content in bytes as a tuple. Stops when there are no more available files in
  the provided archive_file.

  Args:
    archive_file: The archive file object, of which cursor is pointing to the
      next available file header section.

  Yields:
    The name and content of the next available file in the given archive file.

  Raises:
    RuntimeError: The archive_file is in an unknown format, declares an
      impossible size, or ends before the declared member data is complete.
  """
  while True:
    header = archive_file.read(60)
    if not header:
      # Clean end of archive: nothing left after the previous member.
      return
    if len(header) < 60:
      raise RuntimeError('Invalid file header format.')

    # For the details of the file header format, see:
    # https://en.wikipedia.org/wiki/Ar_(Unix)#File_header
    # We only need the file name and the size values.
    name, _, _, _, _, size, end = struct.unpack('=16s12s6s6s8s10s2s', header)
    if end != b'`\n':
      raise RuntimeError('Invalid file header format.')

    # Convert the bytes into more natural types.
    name = name.decode('ascii').strip()
    size = int(size, base=10)
    # A negative count passed to read() would mean "read everything", so a
    # negative size must be rejected rather than forwarded.
    if size < 0:
      raise RuntimeError('Invalid file size in the file header.')
    # Alignment depends on the full section size from the header, so this is
    # computed before the extended filename length is subtracted below.
    odd_size = size % 2 == 1

    # Handle the extended filename scheme: the header name is '#1/<length>' and
    # the real name occupies the first <length> bytes of the data section.
    if name.startswith('#1/'):
      filename_size = int(name[3:])
      if filename_size < 0 or filename_size > size:
        raise RuntimeError('Invalid extended filename size in the file header.')
      raw_name = archive_file.read(filename_size)
      if len(raw_name) < filename_size:
        raise RuntimeError('Unexpected end of archive file.')
      name = raw_name.decode('utf-8').strip(' \x00')
      size -= filename_size

    file_content = archive_file.read(size)
    # A short read means the archive is truncated; yielding the partial data
    # would silently produce a corrupt object file.
    if len(file_content) < size:
      raise RuntimeError('Unexpected end of archive file.')

    # The file contents are always 2 byte aligned, and 1 byte is padded at the
    # end in case the size is odd. The padding byte itself is not validated.
    if odd_size:
      archive_file.read(1)

    yield (name, file_content)
178