#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Registry of scoped ``functools.cache`` wrappers with hit/miss bookkeeping.

Functions decorated with :func:`type_cache` are memoized and registered under a
scope (``"test"`` or ``"session"``). :func:`clear_cache` flushes the caches of
the requested scopes and :func:`log_stats` reports the recorded hit statistics.
"""

from dataclasses import dataclass, field, replace
# NOTE: _CacheInfo is a private functools name (the namedtuple returned by
# cache_info()); it is used here only for type annotations.
from functools import _CacheInfo as CacheInfo
from functools import cache as func_cach
from typing import Callable, Iterator, List, Literal, Protocol, Tuple

from ..logs import logger

LOG = logger(__name__)


class CacheWrapper(Protocol):
    """Structural type of a ``functools.cache``-decorated function."""

    def cache_info(self) -> CacheInfo:
        """Return the wrapper's current hits/misses/maxsize/currsize counters."""

    def cache_clear(self) -> None:
        """Drop all memoized entries."""


@dataclass
class HitStat:
    """Counters accumulated for one registered cache."""

    hits: int = 0
    misses: int = 0
    # Sum of ``currsize`` over every snapshot folded in via add().
    count: int = 0

    def add(self, info: CacheInfo) -> None:
        """Fold one ``cache_info()`` snapshot into the running totals."""
        self.hits += info.hits
        self.misses += info.misses
        self.count += info.currsize


@dataclass
class Cache:
    """A registered cache wrapper together with its scope and statistics."""

    cache: CacheWrapper
    scope: str
    name: str = ""
    # Totals recorded at previous clear() calls; the wrapper's live counters
    # are only folded in when clear() runs (or via total()).
    stats: HitStat = field(default_factory=HitStat)

    def clear(self) -> None:
        """Record the current counters into ``stats``, then flush the wrapper."""
        self._accum_info()
        self.cache.cache_clear()

    def info(self) -> CacheInfo:
        """Return the wrapper's live ``cache_info()``."""
        return self.cache.cache_info()

    def empty(self) -> bool:
        """Return True when the cache holds no entries and has recorded no lookups.

        Checking ``hits`` alone would treat a populated-but-never-hit cache as
        empty, so clear_cache() would leave its entries in place.
        """
        info = self.cache.cache_info()
        return info.currsize == 0 and info.hits == 0 and info.misses == 0

    def total(self) -> HitStat:
        """Return recorded stats plus the live counters, as a new HitStat.

        Assumes ``cache_clear()`` resets the wrapper's counters (as the
        functools implementation does), so nothing is counted twice.
        """
        snapshot = replace(self.stats)
        snapshot.add(self.cache.cache_info())
        return snapshot

    def _accum_info(self) -> None:
        self.stats.add(self.cache.cache_info())


CacheScope = Literal["test", "session"]


class MirrorTypeCaches:
    """Registry of every cache created through :func:`type_cache`."""

    _caches: List[Cache]

    def __init__(self) -> None:
        self._caches = []

    def add(self, cache: Cache) -> None:
        """Register *cache*."""
        self._caches.append(cache)

    def find(self, *scopes: CacheScope) -> Iterator[Cache]:
        """Yield registered caches whose scope is in *scopes* (all if none given)."""
        for c in self._caches:
            if not scopes or c.scope in scopes:
                yield c

    def stats(self, *, include_live: bool = False) -> Iterator[Tuple[str, HitStat]]:
        """Yield ``(name, stats)`` for every registered cache.

        By default the yielded object is the cache's recorded ``stats`` (totals
        from previous clears). With ``include_live=True`` a fresh HitStat that
        also includes the current, not-yet-cleared counters is yielded instead.
        """
        for c in self._caches:
            yield c.name, (c.total() if include_live else c.stats)


_CACHE_WRAPPERS = MirrorTypeCaches()


def type_cache(*, scope: CacheScope):
    """Return a decorator that memoizes a function and registers it under *scope*.

    The decorated function is wrapped with ``functools.cache`` and recorded in
    the module registry so that :func:`clear_cache` can flush it.
    """

    def wrapper(user_function: Callable, /):
        w = func_cach(user_function)
        _CACHE_WRAPPERS.add(Cache(cache=w, scope=scope, name=user_function.__name__))
        return w

    return wrapper


def clear_cache(*scopes: CacheScope) -> None:
    """Flush every registered cache in *scopes* (all scopes if none are given)."""
    for c in _CACHE_WRAPPERS.find(*scopes):
        if c.empty():
            # No entries and no counters to record: clearing would be a no-op.
            continue
        LOG.debug("Clear cache '%s': %s", c.name, c.info())
        c.clear()


def log_stats() -> None:
    """Log statistics for every registered cache that has recorded a hit."""
    # include_live: caches that were never cleared (e.g. long-lived ones) would
    # otherwise be missing from the report.
    lines = [
        f"'{name}': {stat}"
        for name, stat in _CACHE_WRAPPERS.stats(include_live=True)
        if stat.hits
    ]
    LOG.info("Cache statistics:\n%s", "\n".join(lines))