#!/usr/bin/env python3 # # Copyright 2023, The Android Open Source Project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Preserves and restores the state of a repository. This module includes a `Snapshot` class that provides methods to: - Take snapshots of a directory, including or excluding specified paths. - Preserve environment variables. - Restore the directory state from previously taken snapshots, managing file and directory deletions and replacements. """ import functools import glob import hashlib import json import logging import os import pathlib import threading from typing import Any, Optional def _synchronized(func): """Ensures thread-safe execution of the wrapped function.""" lock = threading.Lock() @functools.wraps(func) def _synchronized_func(*args, **kwargs): with lock: return func(*args, **kwargs) return _synchronized_func class Snapshot: """Provides functionality to take and restore snapshots of a directory.""" def __init__(self, storage_dir: pathlib.Path): """Initializes a Snapshot object. Args: storage_dir: The directory where snapshots will be stored. """ self._dir_snapshot = _DirSnapshot(storage_dir) self._env_snapshot = _EnvSnapshot(storage_dir) self._obj_snapshot = _ObjectSnapshot(storage_dir) self._lock = self._get_threading_lock(storage_dir) @_synchronized def _get_threading_lock( self, name: str, ): """Gets a threading lock for the snapshot directory.""" locks_dict_attr_name = 'threading_locks' current_function = self._get_threading_lock.__func__ if not hasattr(current_function, locks_dict_attr_name): setattr(current_function, locks_dict_attr_name, {}) if name not in getattr(current_function, locks_dict_attr_name): getattr(current_function, locks_dict_attr_name)[name] = threading.Lock() return getattr(current_function, locks_dict_attr_name)[name] # pylint: disable=too-many-arguments def take_snapshot( self, name: str, root_path: str, include_paths: list[str], exclude_paths: Optional[list[str]] = None, env_keys: Optional[list[str]] = None, env: Optional[dict[str, str]] = None, objs: Optional[dict[str, Any]] = None, ) -> None: """Takes a snapshot of the directory at the given path. Args: name: The name of the snapshot. root_path: The path to the directory to snapshot. include_paths: A list of relative paths to include in the snapshot. exclude_paths: A list of relative paths to exclude from the snapshot. env_keys: A list of environment variable keys to save. env: Environment variables to use while restoring. objs: A dictionary of objects to save. The current implementation limits the type of objects to the types that can be serialized by the json module. """ with self._lock: self._dir_snapshot.take_snapshot( name, root_path, include_paths, exclude_paths, env ) self._env_snapshot.take_snapshot(name, env_keys) self._obj_snapshot.take_snapshot(name, objs) def restore_snapshot( self, name: str, root_path: str, exclude_paths: Optional[list[str]] = None, ) -> tuple[dict[str, str], dict[str, Any]]: """Restores directory at given path to a snapshot with the given name. Args: name: The name of the snapshot. root_path: The path to the target directory. exclude_paths: A list of paths to ignore during restore. Returns: A tuple of restored environment variables and object dictionary. """ env = self._env_snapshot.restore_snapshot(name, root_path) objs = self._obj_snapshot.restore_snapshot(name) with self._lock: self._dir_snapshot.restore_snapshot(name, root_path, exclude_paths, env) return env, objs class _ObjectSnapshot: """Save and restore a dictionary of objects through json.""" def __init__(self, storage_path: pathlib.Path): self._storage_path = storage_path def take_snapshot( self, name: str, objs: Optional[dict[str, Any]] = None, ) -> None: """Save a dictionary of objects in snapshot. Args: name: The name of the snapshot objs: A dictionary of objects to snapshot. Note: The current implementation limits the type of objects to the types that can be serialized by the json module. """ if objs is None: objs = {} with open( self._storage_path.joinpath('%s.objs.json' % name), 'w', encoding='utf-8', ) as f: json.dump(objs, f) def restore_snapshot(self, name: str) -> dict[str, Any]: """Restore saved objects from snapshot.""" with open( self._storage_path.joinpath('%s.objs.json' % name), 'r', encoding='utf-8', ) as f: return json.load(f) class _EnvSnapshot: """Save and restore environment variables.""" _repo_root_placeholder = '' def __init__(self, storage_path: pathlib.Path): self._storage_path = storage_path def take_snapshot( self, name: str, env_keys: Optional[list[str]] = None, ) -> None: """Save a subset of environment variables.""" if env_keys is None: env_keys = [] original_env = os.environ.copy() subset_env = { key: os.environ[key] for key in env_keys if key in original_env } modified_env = { key: value.replace( os.environ['ANDROID_BUILD_TOP'], self._repo_root_placeholder ) for key, value in subset_env.items() } with open(self._get_env_file_path(name), 'w', encoding='utf-8') as f: json.dump(modified_env, f) def restore_snapshot(self, name: str, root_path: str) -> dict[str, str]: """Load saved environment variables.""" with self._get_env_file_path(name).open('r') as f: loaded_env = json.load(f) restored_env = { key: value.replace( self._repo_root_placeholder, root_path, ) for key, value in loaded_env.items() } if 'PATH' in os.environ: if 'PATH' in restored_env: restored_env['PATH'] = restored_env['PATH'] + ':' + os.environ['PATH'] else: restored_env['PATH'] = os.environ['PATH'] return restored_env def _get_env_file_path(self, name: str) -> pathlib.Path: """Get environment file path.""" return self._storage_path / (name + '_env.json') class _FileInfo: """An object to save file information.""" # pylint: disable=too-many-arguments def __init__( self, path: str, timestamp: float, content_hash: str, permissions: int, symlink_target: str, is_directory: bool, is_target_in_workspace: bool = False, ): self.path = path self.timestamp = timestamp self.content_hash = content_hash self.permissions = permissions self.symlink_target = symlink_target self.is_directory = is_directory self.is_target_in_workspace = is_target_in_workspace class _BlobStore: """Class to save and load file content.""" def __init__(self, path: str): self.path = pathlib.Path(path) self.cache = self._load_cache() def add(self, path: pathlib.Path, timestamp: float) -> str: """Add a file path to the store.""" cache_key = path.as_posix() + str(timestamp) if cache_key in self.cache: return self.cache[cache_key] content = path.read_bytes() content_hash = hashlib.sha256(content).hexdigest() content_path = self.path.joinpath(content_hash[:2], content_hash[2:]) if not content_path.exists(): content_path.parent.mkdir(parents=True, exist_ok=True) content_path.write_bytes(content) self.cache[cache_key] = content_hash return content_hash def get(self, content_hash: str) -> bytes: """Read file content from a content hash.""" file_path = self.path.joinpath(content_hash[:2], content_hash[2:]) if file_path.exists(): return file_path.read_bytes() return None def dump_cache(self) -> None: """Dump the saved file path cache to speed up next run.""" self._get_cache_path().parent.mkdir(parents=True, exist_ok=True) with self._get_cache_path().open('w', encoding='utf-8') as f: json.dump(self.cache, f) def _load_cache(self) -> dict[str, str]: if not self._get_cache_path().exists(): return {} with self._get_cache_path().open('r', encoding='utf-8') as f: return json.load(f) def _get_cache_path(self) -> pathlib.Path: return self.path.joinpath('cache.json') class _DirSnapshot: """Class to take and restore snapshot for a directory path.""" def __init__(self, storage_path: pathlib.Path): self._storage_path = storage_path self._blob_store = _BlobStore(self._storage_path.joinpath('blobs')) def _expand_vars_paths( self, paths: list[str], variables: dict[str, str] ) -> list[str]: """Expand variables in paths with the given environment variables. This function is similar to os.path.expandvars(path) which relies on os.environ. Args: paths: A list of paths that might contains variables to expand. variables: A dictionary of variable names and values. Returns: A list containing paths whose variables have been expanded if known. """ if not variables: return paths path_result = paths.copy() for idx, _ in enumerate(path_result): for key, val in sorted( variables.items(), key=lambda item: len(item[0]), reverse=True ): path_result[idx] = path_result[idx].replace(f'${key}', val) return path_result def _expand_wildcard_paths( self, root_path: str, paths: list[str], env: Optional[dict[str, str]] = None, ) -> list[str]: """Expand wildcard paths.""" compose = lambda inner, outer: lambda path: outer(inner(path)) get_abs_path = ( lambda path: path if os.path.isabs(path) else os.path.join(root_path, path) ) glob_path = functools.partial(glob.glob, recursive=True) return sum( map( compose(get_abs_path, glob_path), self._expand_vars_paths(paths, env), ), [], ) def _is_excluded(self, path: str, exclude_paths: list[pathlib.Path]) -> bool: """Check whether a path should be excluded.""" return exclude_paths and any( path.startswith(exclude_path) for exclude_path in exclude_paths ) def _filter_excluded_paths( self, root: pathlib.Path, paths: list[pathlib.Path], exclude_paths: list[pathlib.Path], ) -> None: """Filter a list of paths with a list of exclude paths.""" new_paths = [ path for path in paths if not self._is_excluded(os.path.join(root, path), exclude_paths) ] if len(new_paths) == len(paths): return paths.clear() paths.extend(new_paths) def take_snapshot( self, name: str, root_path: str, include_paths: list[str], exclude_paths: Optional[list[str]] = None, env: Optional[dict[str, str]] = None, ) -> tuple[dict[str, _FileInfo], list[str]]: """Creates a snapshot of the directory at the given path. Args: name: The name of the snapshot. root_path: The path to the root directory. include_paths: A list of relative paths to include in the snapshot. exclude_paths: A list of relative paths to exclude from the snapshot. env: Environment variables to use while restoring. Returns: A tuple containing: - A dictionary of _FileInfo objects keyed by their relative path within the directory. """ include_paths = ( self._expand_wildcard_paths(root_path, include_paths, env) if include_paths else [] ) exclude_paths = ( self._expand_wildcard_paths(root_path, exclude_paths, env) if exclude_paths else [] ) file_infos = {} def process_directory(path: pathlib.Path) -> None: if path.is_symlink(): process_link(path) return relative_path = path.relative_to(root_path).as_posix() if relative_path == '.': return file_infos[relative_path] = _FileInfo( relative_path, timestamp=None, content_hash=None, permissions=path.stat().st_mode, symlink_target=None, is_directory=True, ) def process_file(path: pathlib.Path) -> None: if path.is_symlink(): process_link(path) return relative_path = path.relative_to(root_path).as_posix() timestamp = path.stat().st_mtime file_infos[relative_path] = _FileInfo( relative_path, timestamp=timestamp, content_hash=self._blob_store.add(path, timestamp) if path.stat().st_size else None, permissions=path.stat().st_mode, symlink_target=None, is_directory=False, ) def process_link(path: pathlib.Path) -> None: relative_path = path.relative_to(root_path).as_posix() symlink_target = path.readlink() is_target_in_workspace = False if symlink_target.is_relative_to(root_path): symlink_target = symlink_target.relative_to(root_path) is_target_in_workspace = True file_infos[relative_path] = _FileInfo( relative_path, timestamp=None, content_hash=None, permissions=None, symlink_target=symlink_target.as_posix(), is_target_in_workspace=is_target_in_workspace, is_directory=False, ) def process_path(path: pathlib.Path) -> None: if self._is_excluded(path.as_posix(), exclude_paths): return if path.is_symlink(): process_link(path) elif path.is_file(): process_file(path) elif path.is_dir(): process_directory(path) for root, directories, files in os.walk(path): self._filter_excluded_paths(root, directories, exclude_paths) self._filter_excluded_paths(root, files, exclude_paths) for directory in directories: process_directory(pathlib.Path(root).joinpath(directory)) for file in files: process_file(pathlib.Path(root).joinpath(file)) else: # We are not throwing error here because it might be just a # corner case which likely doesn't affect the test process. logging.error('Unexpected path type: %s', path.as_posix()) for path in include_paths: process_path(pathlib.Path(path)) snapshot_path = self._storage_path.joinpath(name + '_metadata.json') snapshot_path.parent.mkdir(parents=True, exist_ok=True) with snapshot_path.open('w') as f: json.dump(file_infos, f, default=lambda o: o.__dict__) self._blob_store.dump_cache() return file_infos def restore_snapshot( self, name: str, root_path: str, exclude_paths: Optional[list[str]] = None, env: Optional[dict[str, str]] = None, ) -> tuple[list[str], list[str], list[str]]: """Restores directory at given path to snapshot with given name. Args: name: The name of the snapshot. root_path: The path to the root directory. exclude_paths: A list of relative paths to ignore during restoring. env: Environment variables to use while restoring. Returns: A tuple containing 3 lists: - Files and directories that were deleted. - Files that were replaced. """ with self._storage_path.joinpath(name + '_metadata.json').open('r') as f: file_infos_dict = { key: _FileInfo(**val) for key, val in json.load(f).items() } exclude_paths = ( self._expand_wildcard_paths(root_path, exclude_paths, env) if exclude_paths else [] ) deleted = self._remove_extra_files( file_infos_dict, root_path, exclude_paths ) self._restore_directories(file_infos_dict, root_path, exclude_paths) replaced = self._restore_files(file_infos_dict, root_path, exclude_paths) return deleted, replaced def _remove_extra_files( self, file_infos_dict: dict[str, _FileInfo], root_path: str, exclude_paths: list[str], ): """Internal method to remove extra files during snapshot restore.""" deleted = [] for root, directories, files in os.walk(root_path): self._filter_excluded_paths(root, directories, exclude_paths) self._filter_excluded_paths(root, files, exclude_paths) for directory in directories: dir_path = pathlib.Path(root).joinpath(directory) # Ignore non link directories because complicated to deal # with file paths in include filters and unnecessary if dir_path.is_symlink(): dir_path.unlink() for file in files: file_path = pathlib.Path(root).joinpath(file) if file_path.is_symlink(): file_path.unlink() elif file_path.relative_to(root_path).as_posix() not in file_infos_dict: file_path.unlink() deleted.append(file_path.as_posix()) return deleted def _restore_directories( self, file_infos_dict: dict[str, _FileInfo], root_path: str, exclude_paths: list[str], ): """Internal method to restore directories during snapshot restore.""" for relative_path, file_info in file_infos_dict.items(): if not file_info.is_directory: continue dir_path = pathlib.Path(root_path).joinpath(relative_path) if self._is_excluded(dir_path.as_posix(), exclude_paths): continue dir_path.mkdir(parents=True, exist_ok=True) os.chmod(dir_path, file_info.permissions) def _restore_files( self, file_infos_dict: dict[str, _FileInfo], root_path: str, exclude_paths: list[str], ): """Internal method to restore files during snapshot restore.""" replaced = [] for relative_path, file_info in file_infos_dict.items(): file_path = pathlib.Path(root_path).joinpath(relative_path) if self._is_excluded(file_path.as_posix(), exclude_paths): continue if file_info.symlink_target: file_path.parent.mkdir(parents=True, exist_ok=True) target = file_info.symlink_target if bool(file_info.is_target_in_workspace): target = pathlib.Path(root_path).joinpath(target) file_path.parent.mkdir(parents=True, exist_ok=True) file_path.symlink_to(target) continue if file_info.is_directory: continue if ( file_path.exists() and file_path.stat().st_mtime == file_info.timestamp ): continue file_path.parent.mkdir(parents=True, exist_ok=True) file_path.unlink(missing_ok=True) if not file_info.content_hash: file_path.touch() else: file_path.write_bytes(self._blob_store.get(file_info.content_hash)) os.utime(file_path, (file_info.timestamp, file_info.timestamp)) os.chmod(file_path, file_info.permissions) replaced.append(file_path.as_posix()) return replaced