""" Functions to manipulate conninfo strings """ # Copyright (C) 2020 The Psycopg Team from __future__ import annotations import os import re import socket import asyncio from typing import Any, Iterator, AsyncIterator from random import shuffle from pathlib import Path from datetime import tzinfo from functools import lru_cache from ipaddress import ip_address from typing_extensions import TypeAlias from . import pq from . import errors as e from ._tz import get_tzinfo from ._compat import cache from ._encodings import pgconn_encoding ConnDict: TypeAlias = "dict[str, Any]" def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: """ Merge a string and keyword params into a single conninfo string. :param conninfo: A `connection string`__ as accepted by PostgreSQL. :param kwargs: Parameters overriding the ones specified in `!conninfo`. :return: A connection string valid for PostgreSQL, with the `!kwargs` parameters merged. Raise `~psycopg.ProgrammingError` if the input doesn't make a valid conninfo string. .. __: https://www.postgresql.org/docs/current/libpq-connect.html #LIBPQ-CONNSTRING """ if not conninfo and not kwargs: return "" # If no kwarg specified don't mung the conninfo but check if it's correct. # Make sure to return a string, not a subtype, to avoid making Liskov sad. if not kwargs: _parse_conninfo(conninfo) return str(conninfo) # Override the conninfo with the parameters # Drop the None arguments kwargs = {k: v for (k, v) in kwargs.items() if v is not None} if conninfo: tmp = conninfo_to_dict(conninfo) tmp.update(kwargs) kwargs = tmp conninfo = " ".join(f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items()) # Verify the result is valid _parse_conninfo(conninfo) return conninfo def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict: """ Convert the `!conninfo` string into a dictionary of parameters. :param conninfo: A `connection string`__ as accepted by PostgreSQL. :param kwargs: Parameters overriding the ones specified in `!conninfo`. :return: Dictionary with the parameters parsed from `!conninfo` and `!kwargs`. Raise `~psycopg.ProgrammingError` if `!conninfo` is not a a valid connection string. .. __: https://www.postgresql.org/docs/current/libpq-connect.html #LIBPQ-CONNSTRING """ opts = _parse_conninfo(conninfo) rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None} for k, v in kwargs.items(): if v is not None: rv[k] = v return rv def _parse_conninfo(conninfo: str) -> list[pq.ConninfoOption]: """ Verify that `!conninfo` is a valid connection string. Raise ProgrammingError if the string is not valid. Return the result of pq.Conninfo.parse() on success. """ try: return pq.Conninfo.parse(conninfo.encode()) except e.OperationalError as ex: raise e.ProgrammingError(str(ex)) from None re_escape = re.compile(r"([\\'])") re_space = re.compile(r"\s") def _param_escape(s: str) -> str: """ Apply the escaping rule required by PQconnectdb """ if not s: return "''" s = re_escape.sub(r"\\\1", s) if re_space.search(s): s = "'" + s + "'" return s class ConnectionInfo: """Allow access to information about the connection.""" __module__ = "psycopg" def __init__(self, pgconn: pq.abc.PGconn): self.pgconn = pgconn @property def vendor(self) -> str: """A string representing the database vendor connected to.""" return "PostgreSQL" @property def host(self) -> str: """The server host name of the active connection. See :pq:`PQhost()`.""" return self._get_pgconn_attr("host") @property def hostaddr(self) -> str: """The server IP address of the connection. See :pq:`PQhostaddr()`.""" return self._get_pgconn_attr("hostaddr") @property def port(self) -> int: """The port of the active connection. See :pq:`PQport()`.""" return int(self._get_pgconn_attr("port")) @property def dbname(self) -> str: """The database name of the connection. See :pq:`PQdb()`.""" return self._get_pgconn_attr("db") @property def user(self) -> str: """The user name of the connection. See :pq:`PQuser()`.""" return self._get_pgconn_attr("user") @property def password(self) -> str: """The password of the connection. See :pq:`PQpass()`.""" return self._get_pgconn_attr("password") @property def options(self) -> str: """ The command-line options passed in the connection request. See :pq:`PQoptions`. """ return self._get_pgconn_attr("options") def get_parameters(self) -> dict[str, str]: """Return the connection parameters values. Return all the parameters set to a non-default value, which might come either from the connection string and parameters passed to `~Connection.connect()` or from environment variables. The password is never returned (you can read it using the `password` attribute). """ pyenc = self.encoding # Get the known defaults to avoid reporting them defaults = { i.keyword: i.compiled for i in pq.Conninfo.get_defaults() if i.compiled is not None } # Not returned by the libq. Bug? Bet we're using SSH. defaults.setdefault(b"channel_binding", b"prefer") defaults[b"passfile"] = str(Path.home() / ".pgpass").encode() return { i.keyword.decode(pyenc): i.val.decode(pyenc) for i in self.pgconn.info if i.val is not None and i.keyword != b"password" and i.val != defaults.get(i.keyword) } @property def dsn(self) -> str: """Return the connection string to connect to the database. The string contains all the parameters set to a non-default value, which might come either from the connection string and parameters passed to `~Connection.connect()` or from environment variables. The password is never returned (you can read it using the `password` attribute). """ return make_conninfo(**self.get_parameters()) @property def status(self) -> pq.ConnStatus: """The status of the connection. See :pq:`PQstatus()`.""" return pq.ConnStatus(self.pgconn.status) @property def transaction_status(self) -> pq.TransactionStatus: """ The current in-transaction status of the session. See :pq:`PQtransactionStatus()`. """ return pq.TransactionStatus(self.pgconn.transaction_status) @property def pipeline_status(self) -> pq.PipelineStatus: """ The current pipeline status of the client. See :pq:`PQpipelineStatus()`. """ return pq.PipelineStatus(self.pgconn.pipeline_status) def parameter_status(self, param_name: str) -> str | None: """ Return a parameter setting of the connection. Return `None` is the parameter is unknown. """ res = self.pgconn.parameter_status(param_name.encode(self.encoding)) return res.decode(self.encoding) if res is not None else None @property def server_version(self) -> int: """ An integer representing the server version. See :pq:`PQserverVersion()`. """ return self.pgconn.server_version @property def backend_pid(self) -> int: """ The process ID (PID) of the backend process handling this connection. See :pq:`PQbackendPID()`. """ return self.pgconn.backend_pid @property def error_message(self) -> str: """ The error message most recently generated by an operation on the connection. See :pq:`PQerrorMessage()`. """ return self._get_pgconn_attr("error_message") @property def timezone(self) -> tzinfo: """The Python timezone info of the connection's timezone.""" return get_tzinfo(self.pgconn) @property def encoding(self) -> str: """The Python codec name of the connection's client encoding.""" return pgconn_encoding(self.pgconn) def _get_pgconn_attr(self, name: str) -> str: value: bytes = getattr(self.pgconn, name) return value.decode(self.encoding) def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]: """Split a set of connection params on the single attempts to perforn. A connection param can perform more than one attempt more than one ``host`` is provided. Because the libpq async function doesn't honour the timeout, we need to reimplement the repeated attempts. """ if params.get("load_balance_hosts", "disable") == "random": attempts = list(_split_attempts(_inject_defaults(params))) shuffle(attempts) for attempt in attempts: yield attempt else: for attempt in _split_attempts(_inject_defaults(params)): yield attempt async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]: """Split a set of connection params on the single attempts to perforn. A connection param can perform more than one attempt more than one ``host`` is provided. Also perform async resolution of the hostname into hostaddr in order to avoid blocking. Because a host can resolve to more than one address, this can lead to yield more attempts too. Raise `OperationalError` if no host could be resolved. Because the libpq async function doesn't honour the timeout, we need to reimplement the repeated attempts. """ yielded = False last_exc = None for attempt in _split_attempts(_inject_defaults(params)): try: async for a2 in _split_attempts_and_resolve(attempt): yielded = True yield a2 except OSError as ex: last_exc = ex if not yielded: assert last_exc # We couldn't resolve anything raise e.OperationalError(str(last_exc)) def _inject_defaults(params: ConnDict) -> ConnDict: """ Add defaults to a dictionary of parameters. This avoids the need to look up for env vars at various stages during processing. Note that a port is always specified. 5432 likely comes from here. The `host`, `hostaddr`, `port` will be always set to a string. """ defaults = _conn_defaults() out = params.copy() def inject(name: str, envvar: str) -> None: value = out.get(name) if not value: out[name] = os.environ.get(envvar, defaults[name]) else: out[name] = str(value) inject("host", "PGHOST") inject("hostaddr", "PGHOSTADDR") inject("port", "PGPORT") return out def _split_attempts(params: ConnDict) -> Iterator[ConnDict]: """ Split connection parameters with a sequence of hosts into separate attempts. Assume that `host`, `hostaddr`, `port` are always present and a string (as emitted from `_inject_defaults()`). """ def split_val(key: str) -> list[str]: # Assume all keys are present and strings. val: str = params[key] return val.split(",") if val else [] hosts = split_val("host") hostaddrs = split_val("hostaddr") ports = split_val("port") if hosts and hostaddrs and len(hosts) != len(hostaddrs): raise e.OperationalError( f"could not match {len(hosts)} host names" f" with {len(hostaddrs)} hostaddr values" ) nhosts = max(len(hosts), len(hostaddrs)) if 1 < len(ports) != nhosts: raise e.OperationalError( f"could not match {len(ports)} port numbers to {len(hosts)} hosts" ) elif len(ports) == 1: ports *= nhosts # A single attempt to make if nhosts <= 1: yield params return # Now all lists are either empty or have the same length for i in range(nhosts): attempt = params.copy() if hosts: attempt["host"] = hosts[i] if hostaddrs: attempt["hostaddr"] = hostaddrs[i] if ports: attempt["port"] = ports[i] yield attempt async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDict]: """ Perform async DNS lookup of the hosts and return a new params dict. :param params: The input parameters, for instance as returned by `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most a single entry for host, hostaddr, port and doesn't check for env vars because it is designed to further process the input of _split_attempts() If a ``host`` param is present but not ``hostname``, resolve the host addresses dynamically. The function may change the input ``host``, ``hostname``, ``port`` to allow connecting without further DNS lookups. Raise `~psycopg.OperationalError` if resolution fails. """ host = params["host"] if not host or host.startswith("/") or host[1:2] == ":": # Local path, or no host to resolve yield params return hostaddr = params["hostaddr"] if hostaddr: # Already resolved yield params return if is_ip_address(host): # If the host is already an ip address don't try to resolve it params["hostaddr"] = host yield params return loop = asyncio.get_running_loop() port = params["port"] ans = await loop.getaddrinfo( host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM ) attempt = params.copy() for item in ans: attempt["hostaddr"] = item[4][0] yield attempt @cache def _conn_defaults() -> dict[str, str]: """ Return a dictionary of defaults for connection strings parameters. """ defs = pq.Conninfo.get_defaults() return { d.keyword.decode(): d.compiled.decode() if d.compiled is not None else "" for d in defs } @lru_cache() def is_ip_address(s: str) -> bool: """Return True if the string represent a valid ip address.""" try: ip_address(s) except ValueError: return False return True