import errno import logging import os import platform import stat import struct import tempfile import typing as _t from contextlib import contextmanager from hashlib import md5 from pathlib import Path from time import sleep from time import time from cachelib.base import BaseCache from cachelib.serializers import FileSystemSerializer class FileSystemCache(BaseCache): """A cache that stores the items on the file system. This cache depends on being the only user of the `cache_dir`. Make absolutely sure that nobody but this cache stores files there or otherwise the cache will randomly delete files therein. :param cache_dir: the directory where cache files are stored. :param threshold: the maximum number of items the cache stores before it starts deleting some. A threshold value of 0 indicates no threshold. :param default_timeout: the default timeout that is used if no timeout is specified on :meth:`~BaseCache.set`. A timeout of 0 indicates that the cache never expires. :param mode: the file mode wanted for the cache files, default 0600 :param hash_method: Default hashlib.md5. The hash method used to generate the filename for cached results. """ #: used for temporary files by the FileSystemCache _fs_transaction_suffix = ".__wz_cache" #: keep amount of files in a cache element _fs_count_file = "__wz_cache_count" serializer = FileSystemSerializer() def __init__( self, cache_dir: str, threshold: int = 500, default_timeout: int = 300, mode: _t.Optional[int] = None, hash_method: _t.Any = md5, ): BaseCache.__init__(self, default_timeout) self._path = cache_dir self._threshold = threshold self._hash_method = hash_method # Mode set by user takes precedence. If no mode has # been given, we need to set the correct default based # on user platform. self._mode = mode if self._mode is None: self._mode = self._get_compatible_platform_mode() try: os.makedirs(self._path) except OSError as ex: if ex.errno != errno.EEXIST: raise # If there are many files and a zero threshold, # the list_dir can slow initialisation massively if self._threshold != 0: self._update_count(value=len(list(self._list_dir()))) def _get_compatible_platform_mode(self) -> int: mode = 0o600 # nix systems if platform.system() == "Windows": mode = stat.S_IWRITE return mode @property def _file_count(self) -> int: return self.get(self._fs_count_file) or 0 def _update_count( self, delta: _t.Optional[int] = None, value: _t.Optional[int] = None ) -> None: # If we have no threshold, don't count files if self._threshold == 0: return if delta: new_count = self._file_count + delta else: new_count = value or 0 self.set(self._fs_count_file, new_count, mgmt_element=True) def _normalize_timeout(self, timeout: _t.Optional[int]) -> int: timeout = BaseCache._normalize_timeout(self, timeout) if timeout != 0: timeout = int(time()) + timeout return int(timeout) def _is_mgmt(self, name: str) -> bool: fshash = self._get_filename(self._fs_count_file).split(os.sep)[-1] return name == fshash or name.endswith(self._fs_transaction_suffix) def _list_dir(self) -> _t.Generator[str, None, None]: """return a list of (fully qualified) cache filenames""" return ( os.path.join(self._path, fn) for fn in os.listdir(self._path) if not self._is_mgmt(fn) ) def _over_threshold(self) -> bool: return self._threshold != 0 and self._file_count > self._threshold def _remove_expired(self, now: float) -> None: for fname in self._list_dir(): try: with self._safe_stream_open(fname, "rb") as f: expires = struct.unpack("I", f.read(4))[0] if expires != 0 and expires < now: os.remove(fname) self._update_count(delta=-1) except FileNotFoundError: pass except (OSError, EOFError, struct.error): logging.warning( "Exception raised while handling cache file '%s'", fname, exc_info=True, ) def _remove_older(self) -> bool: exp_fname_tuples = [] for fname in self._list_dir(): try: with self._safe_stream_open(fname, "rb") as f: timestamp = struct.unpack("I", f.read(4))[0] exp_fname_tuples.append((timestamp, fname)) except FileNotFoundError: pass except (OSError, EOFError, struct.error): logging.warning( "Exception raised while handling cache file '%s'", fname, exc_info=True, ) fname_sorted = ( fname for _, fname in sorted( exp_fname_tuples, key=lambda item: item[0] # type: ignore ) ) for fname in fname_sorted: try: os.remove(fname) self._update_count(delta=-1) except FileNotFoundError: pass except OSError: logging.warning( "Exception raised while handling cache file '%s'", fname, exc_info=True, ) return False if not self._over_threshold(): break return True def _prune(self) -> None: if self._over_threshold(): now = time() self._remove_expired(now) # if still over threshold if self._over_threshold(): self._remove_older() def clear(self) -> bool: for i, fname in enumerate(self._list_dir()): try: os.remove(fname) except FileNotFoundError: pass except OSError: logging.warning( "Exception raised while handling cache file '%s'", fname, exc_info=True, ) self._update_count(delta=-i) return False self._update_count(value=0) return True def _get_filename(self, key: str) -> str: if isinstance(key, str): bkey = key.encode("utf-8") # XXX unicode review bkey_hash = self._hash_method(bkey).hexdigest() else: raise TypeError(f"Key must be a string, received type {type(key)}") return os.path.join(self._path, bkey_hash) def get(self, key: str) -> _t.Any: filename = self._get_filename(key) try: with self._safe_stream_open(filename, "rb") as f: pickle_time = struct.unpack("I", f.read(4))[0] if pickle_time == 0 or pickle_time >= time(): return self.serializer.load(f) except FileNotFoundError: pass except (OSError, EOFError, struct.error): logging.warning( "Exception raised while handling cache file '%s'", filename, exc_info=True, ) return None def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool: filename = self._get_filename(key) if not os.path.exists(filename): return self.set(key, value, timeout) return False def set( self, key: str, value: _t.Any, timeout: _t.Optional[int] = None, mgmt_element: bool = False, ) -> bool: # Management elements have no timeout if mgmt_element: timeout = 0 # Don't prune on management element update, to avoid loop else: self._prune() timeout = self._normalize_timeout(timeout) filename = self._get_filename(key) overwrite = os.path.isfile(filename) try: fd, tmp = tempfile.mkstemp( suffix=self._fs_transaction_suffix, dir=self._path ) with os.fdopen(fd, "wb") as f: f.write(struct.pack("I", timeout)) self.serializer.dump(value, f) self._run_safely(os.replace, tmp, filename) self._run_safely(os.chmod, filename, self._mode) fsize = Path(filename).stat().st_size except OSError: logging.warning( "Exception raised while handling cache file '%s'", filename, exc_info=True, ) return False else: # Management elements should not count towards threshold if not overwrite and not mgmt_element: self._update_count(delta=1) return fsize > 0 # function should fail if file is empty def delete(self, key: str, mgmt_element: bool = False) -> bool: try: os.remove(self._get_filename(key)) except FileNotFoundError: # if file doesn't exist we consider it deleted return True except OSError: logging.warning("Exception raised while handling cache file", exc_info=True) return False else: # Management elements should not count towards threshold if not mgmt_element: self._update_count(delta=-1) return True def has(self, key: str) -> bool: filename = self._get_filename(key) try: with self._safe_stream_open(filename, "rb") as f: pickle_time = struct.unpack("I", f.read(4))[0] if pickle_time == 0 or pickle_time >= time(): return True else: return False except FileNotFoundError: # if there is no file there is no key return False except (OSError, EOFError, struct.error): logging.warning( "Exception raised while handling cache file '%s'", filename, exc_info=True, ) return False def _run_safely(self, fn: _t.Callable, *args: _t.Any, **kwargs: _t.Any) -> _t.Any: """On Windows os.replace, os.chmod and open can yield permission errors if executed by two different processes.""" if platform.system() == "Windows": output = None wait_step = 0.001 max_sleep_time = 10.0 total_sleep_time = 0.0 while total_sleep_time < max_sleep_time: try: output = fn(*args, **kwargs) except PermissionError: sleep(wait_step) total_sleep_time += wait_step wait_step *= 2 else: break else: output = fn(*args, **kwargs) return output @contextmanager def _safe_stream_open(self, path: str, mode: str) -> _t.Generator: fs = self._run_safely(open, path, mode) if fs is None: raise OSError try: yield fs finally: fs.close()