docker setup

This commit is contained in:
AdrienLSH
2023-11-23 16:43:30 +01:00
parent fd19180e1d
commit f29003c66a
5410 changed files with 869440 additions and 0 deletions

View File

@ -0,0 +1,110 @@
"""
psycopg -- PostgreSQL database adapter for Python
"""
# Copyright (C) 2020 The Psycopg Team
import logging
from . import pq # noqa: F401 import early to stabilize side effects
from . import types
from . import postgres
from ._tpc import Xid
from .copy import Copy, AsyncCopy
from ._enums import IsolationLevel
from .cursor import Cursor
from .errors import Warning, Error, InterfaceError, DatabaseError
from .errors import DataError, OperationalError, IntegrityError
from .errors import InternalError, ProgrammingError, NotSupportedError
from ._column import Column
from .conninfo import ConnectionInfo
from ._pipeline import Pipeline, AsyncPipeline
from .connection import BaseConnection, Connection, Notify
from .transaction import Rollback, Transaction, AsyncTransaction
from .cursor_async import AsyncCursor
from .server_cursor import AsyncServerCursor, ServerCursor
from .client_cursor import AsyncClientCursor, ClientCursor
from .connection_async import AsyncConnection
from . import dbapi20
from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks
from .dbapi20 import Timestamp, TimestampFromTicks
from .version import __version__ as __version__ # noqa: F401
# Set the logger to a quiet default; it can be enabled if needed
logger = logging.getLogger("psycopg")
if logger.level == logging.NOTSET:
logger.setLevel(logging.WARNING)
# DBAPI compliance
connect = Connection.connect
apilevel = "2.0"
threadsafety = 2
paramstyle = "pyformat"
# register default adapters for PostgreSQL
adapters = postgres.adapters # exposed by the package
postgres.register_default_adapters(adapters)
# After the default ones, because these can deal with the bytea oid better
dbapi20.register_dbapi20_adapters(adapters)
# Must come after all the types have been registered
types.array.register_all_arrays(adapters)
# Note: defining the exported methods helps both Sphinx, in documenting that
# this is the canonical place to obtain them, and MyPy, so that function
# signatures are consistent with the documentation.
__all__ = [
"AsyncClientCursor",
"AsyncConnection",
"AsyncCopy",
"AsyncCursor",
"AsyncPipeline",
"AsyncServerCursor",
"AsyncTransaction",
"BaseConnection",
"ClientCursor",
"Column",
"Connection",
"ConnectionInfo",
"Copy",
"Cursor",
"IsolationLevel",
"Notify",
"Pipeline",
"Rollback",
"ServerCursor",
"Transaction",
"Xid",
# DBAPI exports
"connect",
"apilevel",
"threadsafety",
"paramstyle",
"Warning",
"Error",
"InterfaceError",
"DatabaseError",
"DataError",
"OperationalError",
"IntegrityError",
"InternalError",
"ProgrammingError",
"NotSupportedError",
# DBAPI type constructors and singletons
"Binary",
"Date",
"DateFromTicks",
"Time",
"TimeFromTicks",
"Timestamp",
"TimestampFromTicks",
"BINARY",
"DATETIME",
"NUMBER",
"ROWID",
"STRING",
]
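A minimal usage sketch of the DBAPI surface wired up above (`connect` is bound to `Connection.connect` and the default adapters are registered at import time); the connection string and query are illustrative assumptions, not part of the file:

import psycopg

# Open a connection and run a parametrised query through the exported API.
with psycopg.connect("dbname=test user=postgres") as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT %s::int + %s::int", (20, 22))
        print(cur.fetchone())  # (42,)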

View File

@ -0,0 +1,296 @@
"""
Mapping from types/oids to Dumpers/Loaders
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from typing import cast, TYPE_CHECKING
from . import pq
from . import errors as e
from .abc import Dumper, Loader
from ._enums import PyFormat as PyFormat
from ._cmodule import _psycopg
from ._typeinfo import TypesRegistry
if TYPE_CHECKING:
from .connection import BaseConnection
RV = TypeVar("RV")
class AdaptersMap:
r"""
Establish how types should be converted between Python and PostgreSQL in
an `~psycopg.abc.AdaptContext`.
`!AdaptersMap` maps Python types to `~psycopg.adapt.Dumper` classes to
define how Python types are converted to PostgreSQL, and maps OIDs to
`~psycopg.adapt.Loader` classes to establish how query results are
converted to Python.
Every `!AdaptContext` object has an underlying `!AdaptersMap` defining how
types are converted in that context, exposed as the
`~psycopg.abc.AdaptContext.adapters` attribute: changing this map allows you
to customise adaptation in one context without affecting other contexts.
When a context is created from another context (for instance when a
`~psycopg.Cursor` is created from a `~psycopg.Connection`), the parent's
`!adapters` are used as a template for the child's `!adapters`, so that every
cursor created from the same connection uses the connection's types
configuration, while separate connections have independent mappings.
Once created, `!AdaptersMap` instances are independent. This means that objects
already created are not affected if a wider scope (e.g. the global one) is
changed.
The connections' adapters are initialised using a global `!AdaptersMap`
template, exposed as `psycopg.adapters`: changing this mapping customises the
type mapping for every connection created afterwards.
The object can start empty or copy from another object of the same class.
Copies are copy-on-write: the inner maps are only copied when they are updated.
This way, extending e.g. the global map in a connection, or a connection map
in a cursor, is cheap: a copy is only made on customisation.
"""
__module__ = "psycopg.adapt"
types: TypesRegistry
_dumpers: Dict[PyFormat, Dict[Union[type, str], Type[Dumper]]]
_dumpers_by_oid: List[Dict[int, Type[Dumper]]]
_loaders: List[Dict[int, Type[Loader]]]
# Record if a dumper or loader has an optimised version.
_optimised: Dict[type, type] = {}
def __init__(
self,
template: Optional["AdaptersMap"] = None,
types: Optional[TypesRegistry] = None,
):
if template:
self._dumpers = template._dumpers.copy()
self._own_dumpers = _dumpers_shared.copy()
template._own_dumpers = _dumpers_shared.copy()
self._dumpers_by_oid = template._dumpers_by_oid[:]
self._own_dumpers_by_oid = [False, False]
template._own_dumpers_by_oid = [False, False]
self._loaders = template._loaders[:]
self._own_loaders = [False, False]
template._own_loaders = [False, False]
self.types = TypesRegistry(template.types)
else:
self._dumpers = {fmt: {} for fmt in PyFormat}
self._own_dumpers = _dumpers_owned.copy()
self._dumpers_by_oid = [{}, {}]
self._own_dumpers_by_oid = [True, True]
self._loaders = [{}, {}]
self._own_loaders = [True, True]
self.types = types or TypesRegistry()
# implement the AdaptContext protocol too
@property
def adapters(self) -> "AdaptersMap":
return self
@property
def connection(self) -> Optional["BaseConnection[Any]"]:
return None
def register_dumper(
self, cls: Union[type, str, None], dumper: Type[Dumper]
) -> None:
"""
Configure the context to use `!dumper` to convert objects of type `!cls`.
If two dumpers with different `~Dumper.format` are registered for the
same type, the last one registered will be chosen when the query
doesn't specify a format (i.e. when the value is used with a ``%s``
"`~PyFormat.AUTO`" placeholder).
:param cls: The type to manage.
:param dumper: The dumper to register for `!cls`.
If `!cls` is specified as a string it will be lazy-loaded, so that it can be
registered without importing it first. In this
case it should be the fully qualified name of the object (e.g.
``"uuid.UUID"``).
If `!cls` is None, only use the dumper when looking up using
`get_dumper_by_oid()`, which happens when we know the Postgres type to
adapt to, but not the Python type that will be adapted (e.g. in COPY
after using `~psycopg.Copy.set_types()`).
"""
if not (cls is None or isinstance(cls, (str, type))):
raise TypeError(
f"dumpers should be registered on classes, got {cls} instead"
)
if _psycopg:
dumper = self._get_optimised(dumper)
# Register the dumper both as its format and as auto
# so that the last dumper registered is used in auto (%s) format
if cls:
for fmt in (PyFormat.from_pq(dumper.format), PyFormat.AUTO):
if not self._own_dumpers[fmt]:
self._dumpers[fmt] = self._dumpers[fmt].copy()
self._own_dumpers[fmt] = True
self._dumpers[fmt][cls] = dumper
# Register the dumper by oid, if the oid of the dumper is fixed
if dumper.oid:
if not self._own_dumpers_by_oid[dumper.format]:
self._dumpers_by_oid[dumper.format] = self._dumpers_by_oid[
dumper.format
].copy()
self._own_dumpers_by_oid[dumper.format] = True
self._dumpers_by_oid[dumper.format][dumper.oid] = dumper
def register_loader(self, oid: Union[int, str], loader: Type["Loader"]) -> None:
"""
Configure the context to use `!loader` to convert data of oid `!oid`.
:param oid: The PostgreSQL OID or type name to manage.
:param loader: The loader to register for `!oid`.
If `oid` is specified as a string, it refers to a type name, which is
looked up in the `types` registry.
"""
if isinstance(oid, str):
oid = self.types[oid].oid
if not isinstance(oid, int):
raise TypeError(f"loaders should be registered on oid, got {oid} instead")
if _psycopg:
loader = self._get_optimised(loader)
fmt = loader.format
if not self._own_loaders[fmt]:
self._loaders[fmt] = self._loaders[fmt].copy()
self._own_loaders[fmt] = True
self._loaders[fmt][oid] = loader
def get_dumper(self, cls: type, format: PyFormat) -> Type["Dumper"]:
"""
Return the dumper class for the given type and format.
Raise `~psycopg.ProgrammingError` if a class is not available.
:param cls: The class to adapt.
:param format: The format to dump to. If `~psycopg.adapt.PyFormat.AUTO`,
use the last one of the dumpers registered on `!cls`.
"""
try:
# Fast path: the class has a known dumper.
return self._dumpers[format][cls]
except KeyError:
if format not in self._dumpers:
raise ValueError(f"bad dumper format: {format}")
# If the KeyError was caused by cls missing from dmap, let's
# look for different cases.
dmap = self._dumpers[format]
# Look for the right class, including looking at superclasses
for scls in cls.__mro__:
if scls in dmap:
return dmap[scls]
# If the adapter is not found, look for its name as a string
fqn = scls.__module__ + "." + scls.__qualname__
if fqn in dmap:
# Replace the class name with the class itself
d = dmap[scls] = dmap.pop(fqn)
return d
format = PyFormat(format)
raise e.ProgrammingError(
f"cannot adapt type {cls.__name__!r} using placeholder '%{format.value}'"
f" (format: {format.name})"
)
def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Type["Dumper"]:
"""
Return the dumper class for the given oid and format.
Raise `~psycopg.ProgrammingError` if a class is not available.
:param oid: The oid of the type to dump to.
:param format: The format to dump to.
"""
try:
dmap = self._dumpers_by_oid[format]
except KeyError:
raise ValueError(f"bad dumper format: {format}")
try:
return dmap[oid]
except KeyError:
info = self.types.get(oid)
if info:
msg = (
f"cannot find a dumper for type {info.name} (oid {oid})"
f" format {pq.Format(format).name}"
)
else:
msg = (
f"cannot find a dumper for unknown type with oid {oid}"
f" format {pq.Format(format).name}"
)
raise e.ProgrammingError(msg)
def get_loader(self, oid: int, format: pq.Format) -> Optional[Type["Loader"]]:
"""
Return the loader class for the given oid and format.
Return `!None` if not found.
:param oid: The oid of the type to load.
:param format: The format to load from.
"""
return self._loaders[format].get(oid)
@classmethod
def _get_optimised(self, cls: Type[RV]) -> Type[RV]:
"""Return the optimised version of a Dumper or Loader class.
Return the input class itself if there is no optimised version.
"""
try:
return self._optimised[cls]
except KeyError:
pass
# Check if the class comes from psycopg.types and there is a class
# with the same name in psycopg_c._psycopg.
from psycopg import types
if cls.__module__.startswith(types.__name__):
new = cast(Type[RV], getattr(_psycopg, cls.__name__, None))
if new:
self._optimised[cls] = new
return new
self._optimised[cls] = cls
return cls
# Micro-optimization: copying these objects is faster than creating new dicts
_dumpers_owned = dict.fromkeys(PyFormat, True)
_dumpers_shared = dict.fromkeys(PyFormat, False)
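A hedged sketch of the `register_dumper()` contract described above, registering a custom dumper on the global map exposed as `psycopg.adapters`; the `Point` class and its text representation are assumptions for illustration only:

import psycopg
from psycopg.adapt import Dumper

class Point:
    def __init__(self, x: float, y: float):
        self.x, self.y = x, y

class PointDumper(Dumper):
    def dump(self, obj: Point) -> bytes:
        # Render the value in PostgreSQL's point text syntax: (x,y)
        return f"({obj.x},{obj.y})".encode()

# Affects connections created afterwards; use conn.adapters or
# cursor.adapters to customise a narrower context instead.
psycopg.adapters.register_dumper(Point, PointDumper)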

View File

@ -0,0 +1,24 @@
"""
Simplify access to the _psycopg module
"""
# Copyright (C) 2021 The Psycopg Team
from typing import Optional
from . import pq
__version__: Optional[str] = None
# Note: "c" must the first attempt so that mypy associates the variable the
# right module interface. It will not result Optional, but hey.
if pq.__impl__ == "c":
from psycopg_c import _psycopg as _psycopg
from psycopg_c import __version__ as __version__ # noqa: F401
elif pq.__impl__ == "binary":
from psycopg_binary import _psycopg as _psycopg # type: ignore
from psycopg_binary import __version__ as __version__ # type: ignore # noqa: F401
elif pq.__impl__ == "python":
_psycopg = None # type: ignore
else:
raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}")

View File

@ -0,0 +1,142 @@
"""
The Column object in Cursor.description
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING
from operator import attrgetter
if TYPE_CHECKING:
from .cursor import BaseCursor
class ColumnData(NamedTuple):
ftype: int
fmod: int
fsize: int
class Column(Sequence[Any]):
__module__ = "psycopg"
def __init__(self, cursor: "BaseCursor[Any, Any]", index: int):
res = cursor.pgresult
assert res
fname = res.fname(index)
if fname:
self._name = fname.decode(cursor._encoding)
else:
# COPY_OUT results have columns but no name
self._name = f"column_{index + 1}"
self._data = ColumnData(
ftype=res.ftype(index),
fmod=res.fmod(index),
fsize=res.fsize(index),
)
self._type = cursor.adapters.types.get(self._data.ftype)
_attrs = tuple(
attrgetter(attr)
for attr in """
name type_code display_size internal_size precision scale null_ok
""".split()
)
def __repr__(self) -> str:
return (
f"<Column {self.name!r},"
f" type: {self._type_display()} (oid: {self.type_code})>"
)
def __len__(self) -> int:
return 7
def _type_display(self) -> str:
parts = []
parts.append(self._type.name if self._type else str(self.type_code))
mod1 = self.precision
if mod1 is None:
mod1 = self.display_size
if mod1:
parts.append(f"({mod1}")
if self.scale:
parts.append(f", {self.scale}")
parts.append(")")
if self._type and self.type_code == self._type.array_oid:
parts.append("[]")
return "".join(parts)
def __getitem__(self, index: Any) -> Any:
if isinstance(index, slice):
return tuple(getter(self) for getter in self._attrs[index])
else:
return self._attrs[index](self)
@property
def name(self) -> str:
"""The name of the column."""
return self._name
@property
def type_code(self) -> int:
"""The numeric OID of the column."""
return self._data.ftype
@property
def display_size(self) -> Optional[int]:
"""The field size, for :sql:`varchar(n)`, None otherwise."""
if not self._type:
return None
if self._type.name in ("varchar", "char"):
fmod = self._data.fmod
if fmod >= 0:
return fmod - 4
return None
@property
def internal_size(self) -> Optional[int]:
"""The internal field size for fixed-size types, None otherwise."""
fsize = self._data.fsize
return fsize if fsize >= 0 else None
@property
def precision(self) -> Optional[int]:
"""The number of digits for fixed precision types."""
if not self._type:
return None
dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval")
if self._type.name == "numeric":
fmod = self._data.fmod
if fmod >= 0:
return fmod >> 16
elif self._type.name in dttypes:
fmod = self._data.fmod
if fmod >= 0:
return fmod & 0xFFFF
return None
@property
def scale(self) -> Optional[int]:
"""The number of digits after the decimal point if available."""
if self._type and self._type.name == "numeric":
fmod = self._data.fmod - 4
if fmod >= 0:
return fmod & 0xFFFF
return None
@property
def null_ok(self) -> Optional[bool]:
"""Always `!None`"""
return None
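For context, these `Column` objects are what `cursor.description` yields; a small sketch (connection parameters are assumptions):

import psycopg

with psycopg.connect("dbname=test") as conn:
    cur = conn.execute("SELECT 1 AS id, 'hello'::varchar(10) AS greeting")
    for col in cur.description:
        # The sequence protocol also works: col[0] is the same as col.name.
        print(col.name, col.type_code, col.display_size)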

View File

@ -0,0 +1,72 @@
"""
compatibility functions for different Python versions
"""
# Copyright (C) 2021 The Psycopg Team
import sys
import asyncio
from typing import Any, Awaitable, Generator, Optional, Sequence, Union, TypeVar
# NOTE: TypeAlias cannot be exported by this module, as pyright special-cases it.
# For this reason it must be imported directly from typing_extensions where used.
# See https://github.com/microsoft/pyright/issues/4197
from typing_extensions import TypeAlias
if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol
T = TypeVar("T")
FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]]
if sys.version_info >= (3, 8):
create_task = asyncio.create_task
from math import prod
else:
def create_task(
coro: FutureT[T], name: Optional[str] = None
) -> "asyncio.Future[T]":
return asyncio.create_task(coro)
from functools import reduce
def prod(seq: Sequence[int]) -> int:
return reduce(int.__mul__, seq, 1)
if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
from functools import cache
from collections import Counter, deque as Deque
else:
from typing import Counter, Deque
from functools import lru_cache
from backports.zoneinfo import ZoneInfo
cache = lru_cache(maxsize=None)
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
if sys.version_info >= (3, 11):
from typing import LiteralString
else:
from typing_extensions import LiteralString
__all__ = [
"Counter",
"Deque",
"LiteralString",
"Protocol",
"TypeGuard",
"ZoneInfo",
"cache",
"create_task",
"prod",
]

View File

@ -0,0 +1,252 @@
# type: ignore # dnspython is currently optional and mypy fails if missing
"""
DNS query support
"""
# Copyright (C) 2021 The Psycopg Team
import os
import re
import warnings
from random import randint
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence
from typing import TYPE_CHECKING
from collections import defaultdict
try:
from dns.resolver import Resolver, Cache
from dns.asyncresolver import Resolver as AsyncResolver
from dns.exception import DNSException
except ImportError:
raise ImportError(
"the module psycopg._dns requires the package 'dnspython' installed"
)
from . import errors as e
from . import conninfo
if TYPE_CHECKING:
from dns.rdtypes.IN.SRV import SRV
resolver = Resolver()
resolver.cache = Cache()
async_resolver = AsyncResolver()
async_resolver.cache = Cache()
async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
"""
Perform async DNS lookup of the hosts and return a new params dict.
.. deprecated:: 3.1
The use of this function is not necessary anymore, because
`psycopg.AsyncConnection.connect()` performs non-blocking name
resolution automatically.
"""
warnings.warn(
"from psycopg 3.1, resolve_hostaddr_async() is not needed anymore",
DeprecationWarning,
)
hosts: list[str] = []
hostaddrs: list[str] = []
ports: list[str] = []
for attempt in conninfo._split_attempts(conninfo._inject_defaults(params)):
try:
async for a2 in conninfo._split_attempts_and_resolve(attempt):
hosts.append(a2["host"])
hostaddrs.append(a2["hostaddr"])
if "port" in params:
ports.append(a2["port"])
except OSError as ex:
last_exc = ex
if params.get("host") and not hosts:
# We couldn't resolve anything
raise e.OperationalError(str(last_exc))
out = params.copy()
shosts = ",".join(hosts)
if shosts:
out["host"] = shosts
shostaddrs = ",".join(hostaddrs)
if shostaddrs:
out["hostaddr"] = shostaddrs
sports = ",".join(ports)
if ports:
out["port"] = sports
return out
def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]:
"""Apply SRV DNS lookup as defined in :RFC:`2782`."""
return Rfc2782Resolver().resolve(params)
async def resolve_srv_async(params: Dict[str, Any]) -> Dict[str, Any]:
"""Async equivalent of `resolve_srv()`."""
return await Rfc2782Resolver().resolve_async(params)
class HostPort(NamedTuple):
host: str
port: str
totry: bool = False
target: Optional[str] = None
class Rfc2782Resolver:
"""Implement SRV RR Resolution as per RFC 2782
The class is organised to minimise code duplication between the sync and
the async paths.
"""
re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)")
def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Update the parameters host and port after SRV lookup."""
attempts = self._get_attempts(params)
if not attempts:
return params
hps = []
for hp in attempts:
if hp.totry:
hps.extend(self._resolve_srv(hp))
else:
hps.append(hp)
return self._return_params(params, hps)
async def resolve_async(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Update the parameters host and port after SRV lookup."""
attempts = self._get_attempts(params)
if not attempts:
return params
hps = []
for hp in attempts:
if hp.totry:
hps.extend(await self._resolve_srv_async(hp))
else:
hps.append(hp)
return self._return_params(params, hps)
def _get_attempts(self, params: Dict[str, Any]) -> List[HostPort]:
"""
Return the list of hosts and, for each host, whether SRV lookup must be tried.
Return an empty list if no lookup is requested.
"""
# If hostaddr is defined don't do any resolution.
if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")):
return []
host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
hosts_in = host_arg.split(",")
port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
ports_in = port_arg.split(",")
if len(ports_in) == 1:
# If only one port is specified, it applies to all the hosts.
ports_in *= len(hosts_in)
if len(ports_in) != len(hosts_in):
# ProgrammingError would have been more appropriate, but this is
# what the libpq raises if it fails to connect in the same case.
raise e.OperationalError(
f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
)
out = []
srv_found = False
for host, port in zip(hosts_in, ports_in):
m = self.re_srv_rr.match(host)
if m or port.lower() == "srv":
srv_found = True
target = m.group("target") if m else None
hp = HostPort(host=host, port=port, totry=True, target=target)
else:
hp = HostPort(host=host, port=port)
out.append(hp)
return out if srv_found else []
def _resolve_srv(self, hp: HostPort) -> List[HostPort]:
try:
ans = resolver.resolve(hp.host, "SRV")
except DNSException:
ans = ()
return self._get_solved_entries(hp, ans)
async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]:
try:
ans = await async_resolver.resolve(hp.host, "SRV")
except DNSException:
ans = ()
return self._get_solved_entries(hp, ans)
def _get_solved_entries(
self, hp: HostPort, entries: "Sequence[SRV]"
) -> List[HostPort]:
if not entries:
# No SRV entry found. Delegate a QNAME=target lookup to the libpq
if hp.target and hp.port.lower() != "srv":
return [HostPort(host=hp.target, port=hp.port)]
else:
return []
# If there is precisely one SRV RR, and its Target is "." (the root
# domain), abort.
if len(entries) == 1 and str(entries[0].target) == ".":
return []
return [
HostPort(host=str(entry.target).rstrip("."), port=str(entry.port))
for entry in self.sort_rfc2782(entries)
]
def _return_params(
self, params: Dict[str, Any], hps: List[HostPort]
) -> Dict[str, Any]:
if not hps:
# Nothing found, we ended up with an empty list
raise e.OperationalError("no host found after SRV RR lookup")
out = params.copy()
out["host"] = ",".join(hp.host for hp in hps)
out["port"] = ",".join(str(hp.port) for hp in hps)
return out
def sort_rfc2782(self, ans: "Sequence[SRV]") -> "List[SRV]":
"""
Implement the priority/weight ordering defined in RFC 2782.
"""
# Divide the entries by priority:
priorities: DefaultDict[int, "List[SRV]"] = defaultdict(list)
out: "List[SRV]" = []
for entry in ans:
priorities[entry.priority].append(entry)
for pri, entries in sorted(priorities.items()):
if len(entries) == 1:
out.append(entries[0])
continue
entries.sort(key=lambda ent: ent.weight)
total_weight = sum(ent.weight for ent in entries)
while entries:
r = randint(0, total_weight)
csum = 0
for i, ent in enumerate(entries):
csum += ent.weight
if csum >= r:
break
out.append(ent)
total_weight -= ent.weight
del entries[i]
return out
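A hedged sketch of the SRV entry point defined above; the record name is an assumption and `dnspython` must be installed for the module to import:

from psycopg._dns import resolve_srv

# Either an _Service._Proto.Target host or port="srv" triggers the lookup.
params = {"host": "_postgres._tcp.db.example.com", "port": "srv", "user": "app"}
resolved = resolve_srv(params)
# host/port now contain the SRV targets, ordered by priority and by the
# weighted random selection of RFC 2782.
print(resolved.get("host"), resolved.get("port"))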

View File

@ -0,0 +1,170 @@
"""
Mappings between PostgreSQL and Python encodings.
"""
# Copyright (C) 2020 The Psycopg Team
import re
import string
import codecs
from typing import Any, Dict, Optional, TYPE_CHECKING
from .pq._enums import ConnStatus
from .errors import NotSupportedError
from ._compat import cache
if TYPE_CHECKING:
from .pq.abc import PGconn
from .connection import BaseConnection
OK = ConnStatus.OK
_py_codecs = {
"BIG5": "big5",
"EUC_CN": "gb2312",
"EUC_JIS_2004": "euc_jis_2004",
"EUC_JP": "euc_jp",
"EUC_KR": "euc_kr",
# "EUC_TW": not available in Python
"GB18030": "gb18030",
"GBK": "gbk",
"ISO_8859_5": "iso8859-5",
"ISO_8859_6": "iso8859-6",
"ISO_8859_7": "iso8859-7",
"ISO_8859_8": "iso8859-8",
"JOHAB": "johab",
"KOI8R": "koi8-r",
"KOI8U": "koi8-u",
"LATIN1": "iso8859-1",
"LATIN10": "iso8859-16",
"LATIN2": "iso8859-2",
"LATIN3": "iso8859-3",
"LATIN4": "iso8859-4",
"LATIN5": "iso8859-9",
"LATIN6": "iso8859-10",
"LATIN7": "iso8859-13",
"LATIN8": "iso8859-14",
"LATIN9": "iso8859-15",
# "MULE_INTERNAL": not available in Python
"SHIFT_JIS_2004": "shift_jis_2004",
"SJIS": "shift_jis",
# this actually means no encoding, see PostgreSQL docs
# it is special-cased by the text loader.
"SQL_ASCII": "ascii",
"UHC": "cp949",
"UTF8": "utf-8",
"WIN1250": "cp1250",
"WIN1251": "cp1251",
"WIN1252": "cp1252",
"WIN1253": "cp1253",
"WIN1254": "cp1254",
"WIN1255": "cp1255",
"WIN1256": "cp1256",
"WIN1257": "cp1257",
"WIN1258": "cp1258",
"WIN866": "cp866",
"WIN874": "cp874",
}
py_codecs: Dict[bytes, str] = {}
py_codecs.update((k.encode(), v) for k, v in _py_codecs.items())
# Add an alias without underscore, for lenient lookups
py_codecs.update(
(k.replace("_", "").encode(), v) for k, v in _py_codecs.items() if "_" in k
)
pg_codecs = {v: k.encode() for k, v in _py_codecs.items()}
def conn_encoding(conn: "Optional[BaseConnection[Any]]") -> str:
"""
Return the Python encoding name of a psycopg connection.
Default to utf8 if the connection has no encoding info.
"""
if not conn or conn.closed:
return "utf-8"
pgenc = conn.pgconn.parameter_status(b"client_encoding") or b"UTF8"
return pg2pyenc(pgenc)
def pgconn_encoding(pgconn: "PGconn") -> str:
"""
Return the Python encoding name of a libpq connection.
Default to utf8 if the connection has no encoding info.
"""
if pgconn.status != OK:
return "utf-8"
pgenc = pgconn.parameter_status(b"client_encoding") or b"UTF8"
return pg2pyenc(pgenc)
def conninfo_encoding(conninfo: str) -> str:
"""
Return the Python encoding name passed in a conninfo string. Default to utf8.
Because the input is likely to come from the user and not normalised by the
server, be somewhat lenient (case-insensitive lookup, ignore noise chars).
"""
from .conninfo import conninfo_to_dict
params = conninfo_to_dict(conninfo)
pgenc = params.get("client_encoding")
if pgenc:
try:
return pg2pyenc(pgenc.encode())
except NotSupportedError:
pass
return "utf-8"
@cache
def py2pgenc(name: str) -> bytes:
"""Convert a Python encoding name to PostgreSQL encoding name.
Raise LookupError if the Python encoding is unknown.
"""
return pg_codecs[codecs.lookup(name).name]
@cache
def pg2pyenc(name: bytes) -> str:
"""Convert a PostgreSQL encoding name to Python encoding name.
Raise NotSupportedError if the PostgreSQL encoding is not supported by
Python.
"""
try:
return py_codecs[name.replace(b"-", b"").replace(b"_", b"").upper()]
except KeyError:
sname = name.decode("utf8", "replace")
raise NotSupportedError(f"codec not available in Python: {sname!r}")
def _as_python_identifier(s: str, prefix: str = "f") -> str:
"""
Reduce a string to a valid Python identifier.
Replace all non-valid chars with '_' and prefix the value with `!prefix` if
the first letter is an '_'.
"""
if not s.isidentifier():
if s[0] in "1234567890":
s = prefix + s
if not s.isidentifier():
s = _re_clean.sub("_", s)
# namedtuple fields cannot start with underscore. So...
if s[0] == "_":
s = prefix + s
return s
_re_clean = re.compile(
f"[^{string.ascii_lowercase}{string.ascii_uppercase}{string.digits}_]"
)
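A quick demonstration of the two lookup directions defined above (values taken from the tables in this file):

from psycopg._encodings import py2pgenc, pg2pyenc

print(py2pgenc("utf-8"))     # b'UTF8'
print(pg2pyenc(b"LATIN1"))   # 'iso8859-1'
print(pg2pyenc(b"latin-1"))  # lenient lookup: dashes/underscores are ignored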

View File

@ -0,0 +1,79 @@
"""
Enum values for psycopg
These values are defined by us and are not necessarily dependent on
libpq-defined enums.
"""
# Copyright (C) 2020 The Psycopg Team
from enum import Enum, IntEnum
from selectors import EVENT_READ, EVENT_WRITE
from . import pq
class Wait(IntEnum):
R = EVENT_READ
W = EVENT_WRITE
RW = EVENT_READ | EVENT_WRITE
class Ready(IntEnum):
R = EVENT_READ
W = EVENT_WRITE
RW = EVENT_READ | EVENT_WRITE
class PyFormat(str, Enum):
"""
Enum representing the format wanted for a query argument.
The value `AUTO` allows psycopg to choose the best format for a certain
parameter.
"""
__module__ = "psycopg.adapt"
AUTO = "s"
"""Automatically chosen (``%s`` placeholder)."""
TEXT = "t"
"""Text parameter (``%t`` placeholder)."""
BINARY = "b"
"""Binary parameter (``%b`` placeholder)."""
@classmethod
def from_pq(cls, fmt: pq.Format) -> "PyFormat":
return _pg2py[fmt]
@classmethod
def as_pq(cls, fmt: "PyFormat") -> pq.Format:
return _py2pg[fmt]
class IsolationLevel(IntEnum):
"""
Enum representing the isolation level for a transaction.
"""
__module__ = "psycopg"
READ_UNCOMMITTED = 1
""":sql:`READ UNCOMMITTED` isolation level."""
READ_COMMITTED = 2
""":sql:`READ COMMITTED` isolation level."""
REPEATABLE_READ = 3
""":sql:`REPEATABLE READ` isolation level."""
SERIALIZABLE = 4
""":sql:`SERIALIZABLE` isolation level."""
_py2pg = {
PyFormat.TEXT: pq.Format.TEXT,
PyFormat.BINARY: pq.Format.BINARY,
}
_pg2py = {
pq.Format.TEXT: PyFormat.TEXT,
pq.Format.BINARY: PyFormat.BINARY,
}
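For reference, a round-trip between the two format enums defined above:

from psycopg import pq
from psycopg._enums import PyFormat

assert PyFormat.from_pq(pq.Format.BINARY) is PyFormat.BINARY
assert PyFormat.as_pq(PyFormat.TEXT) is pq.Format.TEXT
print(PyFormat.AUTO.value)  # 's', i.e. the '%s' placeholder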

View File

@ -0,0 +1,298 @@
"""
commands pipeline management
"""
# Copyright (C) 2021 The Psycopg Team
import logging
from types import TracebackType
from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING
from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from .abc import PipelineCommand, PQGen
from ._compat import Deque
from .pq.misc import connection_summary
from ._encodings import pgconn_encoding
from ._preparing import Key, Prepare
from .generators import pipeline_communicate, fetch_many, send
if TYPE_CHECKING:
from .pq.abc import PGresult
from .cursor import BaseCursor
from .connection import BaseConnection, Connection
from .connection_async import AsyncConnection
PendingResult: TypeAlias = Union[
None, Tuple["BaseCursor[Any, Any]", Optional[Tuple[Key, Prepare, bytes]]]
]
FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
BAD = pq.ConnStatus.BAD
ACTIVE = pq.TransactionStatus.ACTIVE
logger = logging.getLogger("psycopg")
class BasePipeline:
command_queue: Deque[PipelineCommand]
result_queue: Deque[PendingResult]
_is_supported: Optional[bool] = None
def __init__(self, conn: "BaseConnection[Any]") -> None:
self._conn = conn
self.pgconn = conn.pgconn
self.command_queue = Deque[PipelineCommand]()
self.result_queue = Deque[PendingResult]()
self.level = 0
def __repr__(self) -> str:
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
info = connection_summary(self._conn.pgconn)
return f"<{cls} {info} at 0x{id(self):x}>"
@property
def status(self) -> pq.PipelineStatus:
return pq.PipelineStatus(self.pgconn.pipeline_status)
@classmethod
def is_supported(cls) -> bool:
"""Return `!True` if the psycopg libpq wrapper supports pipeline mode."""
if BasePipeline._is_supported is None:
BasePipeline._is_supported = not cls._not_supported_reason()
return BasePipeline._is_supported
@classmethod
def _not_supported_reason(cls) -> str:
"""Return the reason why the pipeline mode is not supported.
Return an empty string if pipeline mode is supported.
"""
# Support only depends on the libpq functions available in the pq
# wrapper, not on the database version.
if pq.version() < 140000:
return (
f"libpq too old {pq.version()};"
" v14 or greater required for pipeline mode"
)
if pq.__build_version__ < 140000:
return (
f"libpq too old: module built for {pq.__build_version__};"
" v14 or greater required for pipeline mode"
)
return ""
def _enter_gen(self) -> PQGen[None]:
if not self.is_supported():
raise e.NotSupportedError(
f"pipeline mode not supported: {self._not_supported_reason()}"
)
if self.level == 0:
self.pgconn.enter_pipeline_mode()
elif self.command_queue or self.pgconn.transaction_status == ACTIVE:
# Nested pipeline case.
# Transaction might be ACTIVE when the pipeline uses an "implicit
# transaction", typically in autocommit mode. But when entering a
# Psycopg transaction(), we expect the IDLE state. By sync()-ing,
# we make sure all previous commands are completed and the
# transaction gets back to IDLE.
yield from self._sync_gen()
self.level += 1
def _exit(self, exc: Optional[BaseException]) -> None:
self.level -= 1
if self.level == 0 and self.pgconn.status != BAD:
try:
self.pgconn.exit_pipeline_mode()
except e.OperationalError as exc2:
# Notice that this error might be pretty irrecoverable. It
# happens on COPY, for instance: even if sync succeeds, exiting
# fails with "cannot exit pipeline mode with uncollected results"
if exc:
logger.warning("error ignored exiting %r: %s", self, exc2)
else:
raise exc2.with_traceback(None)
def _sync_gen(self) -> PQGen[None]:
self._enqueue_sync()
yield from self._communicate_gen()
yield from self._fetch_gen(flush=False)
def _exit_gen(self) -> PQGen[None]:
"""
Exit current pipeline by sending a Sync and fetch back all remaining results.
"""
try:
self._enqueue_sync()
yield from self._communicate_gen()
finally:
# No need to force flush since we emitted a sync just before.
yield from self._fetch_gen(flush=False)
def _communicate_gen(self) -> PQGen[None]:
"""Communicate with pipeline to send commands and possibly fetch
results, which are then processed.
"""
fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
exception = None
for results in fetched:
queued = self.result_queue.popleft()
try:
self._process_results(queued, results)
except e.Error as exc:
if exception is None:
exception = exc
if exception is not None:
raise exception
def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
"""Fetch available results from the connection and process them with
pipeline queued items.
If 'flush' is True, a PQsendFlushRequest() is issued in order to make
sure results can be fetched. Otherwise, the caller may emit a
PQpipelineSync() call to ensure the output buffer gets flushed before
fetching.
"""
if not self.result_queue:
return
if flush:
self.pgconn.send_flush_request()
yield from send(self.pgconn)
exception = None
while self.result_queue:
results = yield from fetch_many(self.pgconn)
if not results:
# No more results to fetch, but there may still be pending
# commands.
break
queued = self.result_queue.popleft()
try:
self._process_results(queued, results)
except e.Error as exc:
if exception is None:
exception = exc
if exception is not None:
raise exception
def _process_results(
self, queued: PendingResult, results: List["PGresult"]
) -> None:
"""Process a results set fetched from the current pipeline.
This matches 'results' with its respective element in the pipeline
queue. For commands (None value in the pipeline queue), results are
checked directly. For prepared statement creation requests, the cache is
updated. Otherwise, results are attached to their respective cursor.
"""
if queued is None:
(result,) = results
if result.status == FATAL_ERROR:
raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
elif result.status == PIPELINE_ABORTED:
raise e.PipelineAborted("pipeline aborted")
else:
cursor, prepinfo = queued
if prepinfo:
key, prep, name = prepinfo
# Update the prepare state of the query.
cursor._conn._prepared.validate(key, prep, name, results)
cursor._set_results_from_pipeline(results)
def _enqueue_sync(self) -> None:
"""Enqueue a PQpipelineSync() command."""
self.command_queue.append(self.pgconn.pipeline_sync)
self.result_queue.append(None)
class Pipeline(BasePipeline):
"""Handler for connection in pipeline mode."""
__module__ = "psycopg"
_conn: "Connection[Any]"
_Self = TypeVar("_Self", bound="Pipeline")
def __init__(self, conn: "Connection[Any]") -> None:
super().__init__(conn)
def sync(self) -> None:
"""Sync the pipeline, send any pending command and receive and process
all available results.
"""
try:
with self._conn.lock:
self._conn.wait(self._sync_gen())
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
def __enter__(self: _Self) -> _Self:
with self._conn.lock:
self._conn.wait(self._enter_gen())
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
try:
with self._conn.lock:
self._conn.wait(self._exit_gen())
except Exception as exc2:
# Don't clobber an exception raised in the block with this one
if exc_val:
logger.warning("error ignored terminating %r: %s", self, exc2)
else:
raise exc2.with_traceback(None)
finally:
self._exit(exc_val)
class AsyncPipeline(BasePipeline):
"""Handler for async connection in pipeline mode."""
__module__ = "psycopg"
_conn: "AsyncConnection[Any]"
_Self = TypeVar("_Self", bound="AsyncPipeline")
def __init__(self, conn: "AsyncConnection[Any]") -> None:
super().__init__(conn)
async def sync(self) -> None:
try:
async with self._conn.lock:
await self._conn.wait(self._sync_gen())
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
async def __aenter__(self: _Self) -> _Self:
async with self._conn.lock:
await self._conn.wait(self._enter_gen())
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
try:
async with self._conn.lock:
await self._conn.wait(self._exit_gen())
except Exception as exc2:
# Don't clobber an exception raised in the block with this one
if exc_val:
logger.warning("error ignored terminating %r: %s", self, exc2)
else:
raise exc2.with_traceback(None)
finally:
self._exit(exc_val)
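These handlers are normally entered through `Connection.pipeline()`; a hedged usage sketch (connection string and statements are assumptions):

import psycopg

with psycopg.connect("dbname=test", autocommit=True) as conn:
    with conn.pipeline():
        cur1 = conn.execute("SELECT 1")
        cur2 = conn.execute("SELECT 2")
        # Statements are sent without waiting for individual results;
        # results are collected when the pipeline syncs (e.g. on exit).
    print(cur1.fetchone(), cur2.fetchone())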

View File

@ -0,0 +1,194 @@
"""
Support for prepared statements
"""
# Copyright (C) 2020 The Psycopg Team
from enum import IntEnum, auto
from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
from collections import OrderedDict
from typing_extensions import TypeAlias
from . import pq
from ._compat import Deque
from ._queries import PostgresQuery
if TYPE_CHECKING:
from .pq.abc import PGresult
Key: TypeAlias = Tuple[bytes, Tuple[int, ...]]
COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK
class Prepare(IntEnum):
NO = auto()
YES = auto()
SHOULD = auto()
class PrepareManager:
# Number of times a query is executed before it is prepared.
prepare_threshold: Optional[int] = 5
# Maximum number of prepared statements on the connection.
prepared_max: int = 100
def __init__(self) -> None:
# Map (query, types) to the number of times the query was seen.
self._counts: OrderedDict[Key, int] = OrderedDict()
# Map (query, types) to the name of the statement if prepared.
self._names: OrderedDict[Key, bytes] = OrderedDict()
# Counter to generate prepared statement names
self._prepared_idx = 0
self._maint_commands = Deque[bytes]()
@staticmethod
def key(query: PostgresQuery) -> Key:
return (query.query, query.types)
def get(
self, query: PostgresQuery, prepare: Optional[bool] = None
) -> Tuple[Prepare, bytes]:
"""
Check if a query is prepared, and tell whether it should be prepared.
"""
if prepare is False or self.prepare_threshold is None:
# The user doesn't want this query to be prepared
return Prepare.NO, b""
key = self.key(query)
name = self._names.get(key)
if name:
# The query was already prepared in this session
return Prepare.YES, name
count = self._counts.get(key, 0)
if count >= self.prepare_threshold or prepare:
# The query has been executed enough times and needs to be prepared
name = f"_pg3_{self._prepared_idx}".encode()
self._prepared_idx += 1
return Prepare.SHOULD, name
else:
# The query is not to be prepared yet
return Prepare.NO, b""
def _should_discard(self, prep: Prepare, results: Sequence["PGresult"]) -> bool:
"""Check if we need to discard our entire state: it should happen on
rollback or on dropping objects, because the same object may get
recreated and postgres would fail internal lookups.
"""
if self._names or prep == Prepare.SHOULD:
for result in results:
if result.status != COMMAND_OK:
continue
cmdstat = result.command_status
if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"):
return self.clear()
return False
@staticmethod
def _check_results(results: Sequence["PGresult"]) -> bool:
"""Return False if 'results' are invalid for prepared statement cache."""
if len(results) != 1:
# We cannot prepare a multiple statement
return False
status = results[0].status
if COMMAND_OK != status != TUPLES_OK:
# We don't prepare failed queries or other weird results
return False
return True
def _rotate(self) -> None:
"""Evict an old value from the cache.
If it was prepared, deallocate it. Do it only once: if the cache was
resized, deallocate gradually.
"""
if len(self._counts) > self.prepared_max:
self._counts.popitem(last=False)
if len(self._names) > self.prepared_max:
name = self._names.popitem(last=False)[1]
self._maint_commands.append(b"DEALLOCATE " + name)
def maybe_add_to_cache(
self, query: PostgresQuery, prep: Prepare, name: bytes
) -> Optional[Key]:
"""Handle 'query' for possible addition to the cache.
If a new entry has been added, return its key. Return None otherwise
(meaning the query is already in cache or cache is not enabled).
"""
# don't do anything if prepared statements are disabled
if self.prepare_threshold is None:
return None
key = self.key(query)
if key in self._counts:
if prep is Prepare.SHOULD:
del self._counts[key]
self._names[key] = name
else:
self._counts[key] += 1
self._counts.move_to_end(key)
return None
elif key in self._names:
self._names.move_to_end(key)
return None
else:
if prep is Prepare.SHOULD:
self._names[key] = name
else:
self._counts[key] = 1
return key
def validate(
self,
key: Key,
prep: Prepare,
name: bytes,
results: Sequence["PGresult"],
) -> None:
"""Validate cached entry with 'key' by checking query 'results'.
Possibly record a command to perform maintenance on database side.
"""
if self._should_discard(prep, results):
return
if not self._check_results(results):
self._names.pop(key, None)
self._counts.pop(key, None)
else:
self._rotate()
def clear(self) -> bool:
"""Clear the cache of the maintenance commands.
Clear the internal state and prepare a command to clear the state of
the server.
"""
self._counts.clear()
if self._names:
self._names.clear()
self._maint_commands.clear()
self._maint_commands.append(b"DEALLOCATE ALL")
return True
else:
return False
def get_maintenance_commands(self) -> Iterator[bytes]:
"""
Iterate over the commands needed to align the server state to our state
"""
while self._maint_commands:
yield self._maint_commands.popleft()
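A sketch of how this cache surfaces in the public API: the connection exposes `prepare_threshold` and `prepared_max`, and `execute()` accepts a `prepare` flag (connection string is an assumption):

import psycopg

with psycopg.connect("dbname=test") as conn:
    conn.prepare_threshold = 2  # prepare after the second execution (default 5)
    for i in range(5):
        conn.execute("SELECT %s::int", (i,))
    # Force preparation regardless of the threshold:
    conn.execute("SELECT now()", prepare=True)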

View File

@ -0,0 +1,415 @@
"""
Utility module to manipulate queries
"""
# Copyright (C) 2020 The Psycopg Team
import re
from typing import Any, Callable, Dict, List, Mapping, Match, NamedTuple, Optional
from typing import Sequence, Tuple, Union, TYPE_CHECKING
from functools import lru_cache
from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from .sql import Composable
from .abc import Buffer, Query, Params
from ._enums import PyFormat
from ._encodings import conn_encoding
if TYPE_CHECKING:
from .abc import Transformer
MAX_CACHED_STATEMENT_LENGTH = 4096
MAX_CACHED_STATEMENT_PARAMS = 50
class QueryPart(NamedTuple):
pre: bytes
item: Union[int, str]
format: PyFormat
class PostgresQuery:
"""
Helper to convert a Python query and parameters into Postgres format.
"""
__slots__ = """
query params types formats
_tx _want_formats _parts _encoding _order
""".split()
def __init__(self, transformer: "Transformer"):
self._tx = transformer
self.params: Optional[Sequence[Optional[Buffer]]] = None
# these are tuples so they can be used as keys e.g. in prepared stmts
self.types: Tuple[int, ...] = ()
# The formats requested by the user and the ones to really pass to Postgres
self._want_formats: Optional[List[PyFormat]] = None
self.formats: Optional[Sequence[pq.Format]] = None
self._encoding = conn_encoding(transformer.connection)
self._parts: List[QueryPart]
self.query = b""
self._order: Optional[List[str]] = None
def convert(self, query: Query, vars: Optional[Params]) -> None:
"""
Set up the query and parameters to convert.
The results of this function can be obtained accessing the object
attributes (`query`, `params`, `types`, `formats`).
"""
if isinstance(query, str):
bquery = query.encode(self._encoding)
elif isinstance(query, Composable):
bquery = query.as_bytes(self._tx)
else:
bquery = query
if vars is not None:
# Avoid caching extremely long queries or queries with a huge number of
# parameters. They are usually generated by ORMs and have poor
# cacheability (e.g. INSERT ... VALUES (...), (...) with varying
# numbers of tuples).
# see https://github.com/psycopg/psycopg/discussions/628
if (
len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
):
f: _Query2Pg = _query2pg
else:
f = _query2pg_nocache
(self.query, self._want_formats, self._order, self._parts) = f(
bquery, self._encoding
)
else:
self.query = bquery
self._want_formats = self._order = None
self.dump(vars)
def dump(self, vars: Optional[Params]) -> None:
"""
Process a new set of variables on the query processed by `convert()`.
This method updates `params` and `types`.
"""
if vars is not None:
params = _validate_and_reorder_params(self._parts, vars, self._order)
assert self._want_formats is not None
self.params = self._tx.dump_sequence(params, self._want_formats)
self.types = self._tx.types or ()
self.formats = self._tx.formats
else:
self.params = None
self.types = ()
self.formats = None
# The type of the _query2pg() and _query2pg_nocache() methods
_Query2Pg: TypeAlias = Callable[
[bytes, str], Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]
]
def _query2pg_nocache(
query: bytes, encoding: str
) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]:
"""
Convert Python query and params into something Postgres understands.
- Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
format (``$1``, ``$2``)
- placeholders can be %s, %t, or %b (auto, text or binary)
- return ``query`` (bytes), ``formats`` (list of formats), ``order``
(sequence of names used in the query, in the positions they appear),
and ``parts`` (splits of query and placeholders).
"""
parts = _split_query(query, encoding)
order: Optional[List[str]] = None
chunks: List[bytes] = []
formats = []
if isinstance(parts[0].item, int):
for part in parts[:-1]:
assert isinstance(part.item, int)
chunks.append(part.pre)
chunks.append(b"$%d" % (part.item + 1))
formats.append(part.format)
elif isinstance(parts[0].item, str):
seen: Dict[str, Tuple[bytes, PyFormat]] = {}
order = []
for part in parts[:-1]:
assert isinstance(part.item, str)
chunks.append(part.pre)
if part.item not in seen:
ph = b"$%d" % (len(seen) + 1)
seen[part.item] = (ph, part.format)
order.append(part.item)
chunks.append(ph)
formats.append(part.format)
else:
if seen[part.item][1] != part.format:
raise e.ProgrammingError(
f"placeholder '{part.item}' cannot have different formats"
)
chunks.append(seen[part.item][0])
# last part
chunks.append(parts[-1].pre)
return b"".join(chunks), formats, order, parts
# Note: the cache size is 128 items, but someone has reported throwing ~12k
# queries (of type `INSERT ... VALUES (...), (...)` with a varying number of
# records), and the resulting cache size is >100Mb. So, we avoid caching
# large queries or queries with a large number of params. See
# https://github.com/sqlalchemy/sqlalchemy/discussions/10270
_query2pg = lru_cache()(_query2pg_nocache)
class PostgresClientQuery(PostgresQuery):
"""
PostgresQuery subclass merging query and arguments client-side.
"""
__slots__ = ("template",)
def convert(self, query: Query, vars: Optional[Params]) -> None:
"""
Set up the query and parameters to convert.
The results of this function can be obtained accessing the object
attributes (`query`, `params`, `types`, `formats`).
"""
if isinstance(query, str):
bquery = query.encode(self._encoding)
elif isinstance(query, Composable):
bquery = query.as_bytes(self._tx)
else:
bquery = query
if vars is not None:
if (
len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
):
f: _Query2PgClient = _query2pg_client
else:
f = _query2pg_client_nocache
(self.template, self._order, self._parts) = f(bquery, self._encoding)
else:
self.query = bquery
self._order = None
self.dump(vars)
def dump(self, vars: Optional[Params]) -> None:
"""
Process a new set of variables on the query processed by `convert()`.
This method updates `params` and `types`.
"""
if vars is not None:
params = _validate_and_reorder_params(self._parts, vars, self._order)
self.params = tuple(
self._tx.as_literal(p) if p is not None else b"NULL" for p in params
)
self.query = self.template % self.params
else:
self.params = None
_Query2PgClient: TypeAlias = Callable[
[bytes, str], Tuple[bytes, Optional[List[str]], List[QueryPart]]
]
def _query2pg_client_nocache(
query: bytes, encoding: str
) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]:
"""
Convert Python query and params into a template to perform client-side binding
"""
parts = _split_query(query, encoding, collapse_double_percent=False)
order: Optional[List[str]] = None
chunks: List[bytes] = []
if isinstance(parts[0].item, int):
for part in parts[:-1]:
assert isinstance(part.item, int)
chunks.append(part.pre)
chunks.append(b"%s")
elif isinstance(parts[0].item, str):
seen: Dict[str, Tuple[bytes, PyFormat]] = {}
order = []
for part in parts[:-1]:
assert isinstance(part.item, str)
chunks.append(part.pre)
if part.item not in seen:
ph = b"%s"
seen[part.item] = (ph, part.format)
order.append(part.item)
chunks.append(ph)
else:
chunks.append(seen[part.item][0])
order.append(part.item)
# last part
chunks.append(parts[-1].pre)
return b"".join(chunks), order, parts
_query2pg_client = lru_cache()(_query2pg_client_nocache)
def _validate_and_reorder_params(
parts: List[QueryPart], vars: Params, order: Optional[List[str]]
) -> Sequence[Any]:
"""
Verify the compatibility between a query and a set of params.
"""
# Try concrete types, then abstract types
t = type(vars)
if t is list or t is tuple:
sequence = True
elif t is dict:
sequence = False
elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
sequence = True
elif isinstance(vars, Mapping):
sequence = False
else:
raise TypeError(
"query parameters should be a sequence or a mapping,"
f" got {type(vars).__name__}"
)
if sequence:
if len(vars) != len(parts) - 1:
raise e.ProgrammingError(
f"the query has {len(parts) - 1} placeholders but"
f" {len(vars)} parameters were passed"
)
if vars and not isinstance(parts[0].item, int):
raise TypeError("named placeholders require a mapping of parameters")
return vars # type: ignore[return-value]
else:
if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
raise TypeError(
"positional placeholders (%s) require a sequence of parameters"
)
try:
return [vars[item] for item in order or ()] # type: ignore[call-overload]
except KeyError:
raise e.ProgrammingError(
"query parameter missing:"
f" {', '.join(sorted(i for i in order or () if i not in vars))}"
)
_re_placeholder = re.compile(
rb"""(?x)
% # a literal %
(?:
(?:
\( ([^)]+) \) # or a name in (braces)
. # followed by a format
)
|
(?:.) # or any char, really
)
"""
)
def _split_query(
query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True
) -> List[QueryPart]:
parts: List[Tuple[bytes, Optional[Match[bytes]]]] = []
cur = 0
# pairs (fragment, match), with the last match None
m = None
for m in _re_placeholder.finditer(query):
pre = query[cur : m.span(0)[0]]
parts.append((pre, m))
cur = m.span(0)[1]
if m:
parts.append((query[cur:], None))
else:
parts.append((query, None))
rv = []
# drop the "%%", validate
i = 0
phtype = None
while i < len(parts):
pre, m = parts[i]
if m is None:
# last part
rv.append(QueryPart(pre, 0, PyFormat.AUTO))
break
ph = m.group(0)
if ph == b"%%":
# unescape '%%' to '%' if necessary, then merge the parts
if collapse_double_percent:
ph = b"%"
pre1, m1 = parts[i + 1]
parts[i + 1] = (pre + ph + pre1, m1)
del parts[i]
continue
if ph == b"%(":
raise e.ProgrammingError(
"incomplete placeholder:"
f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'"
)
elif ph == b"% ":
# explicit message for a typical error
raise e.ProgrammingError(
"incomplete placeholder: '%'; if you want to use '%' as an"
" operator you can double it up, i.e. use '%%'"
)
elif ph[-1:] not in b"sbt":
raise e.ProgrammingError(
"only '%s', '%b', '%t' are allowed as placeholders, got"
f" '{m.group(0).decode(encoding)}'"
)
# Index or name
item: Union[int, str]
item = m.group(1).decode(encoding) if m.group(1) else i
if not phtype:
phtype = type(item)
elif phtype is not type(item):
raise e.ProgrammingError(
"positional and named placeholders cannot be mixed"
)
format = _ph_to_fmt[ph[-1:]]
rv.append(QueryPart(pre, item, format))
i += 1
return rv
_ph_to_fmt = {
b"s": PyFormat.AUTO,
b"t": PyFormat.TEXT,
b"b": PyFormat.BINARY,
}
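An illustration of the placeholder parsing performed by `_split_query()` above (internal API, shown only to clarify the conversion):

from psycopg._queries import _split_query

parts = _split_query(b"insert into t (a, b) values (%(a)s, %(b)b)")
for part in parts:
    print(part.pre, part.item, part.format)
# _query2pg() would build the server-side form
# b"insert into t (a, b) values ($1, $2)" from these parts,
# with formats [PyFormat.AUTO, PyFormat.BINARY].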

View File

@ -0,0 +1,57 @@
"""
Utility functions to deal with binary structs.
"""
# Copyright (C) 2020 The Psycopg Team
import struct
from typing import Callable, cast, Optional, Tuple
from typing_extensions import TypeAlias
from .abc import Buffer
from . import errors as e
from ._compat import Protocol
PackInt: TypeAlias = Callable[[int], bytes]
UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]]
PackFloat: TypeAlias = Callable[[float], bytes]
UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]]
class UnpackLen(Protocol):
def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]:
...
pack_int2 = cast(PackInt, struct.Struct("!h").pack)
pack_uint2 = cast(PackInt, struct.Struct("!H").pack)
pack_int4 = cast(PackInt, struct.Struct("!i").pack)
pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
pack_int8 = cast(PackInt, struct.Struct("!q").pack)
pack_float4 = cast(PackFloat, struct.Struct("!f").pack)
pack_float8 = cast(PackFloat, struct.Struct("!d").pack)
unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack)
unpack_uint2 = cast(UnpackInt, struct.Struct("!H").unpack)
unpack_int4 = cast(UnpackInt, struct.Struct("!i").unpack)
unpack_uint4 = cast(UnpackInt, struct.Struct("!I").unpack)
unpack_int8 = cast(UnpackInt, struct.Struct("!q").unpack)
unpack_float4 = cast(UnpackFloat, struct.Struct("!f").unpack)
unpack_float8 = cast(UnpackFloat, struct.Struct("!d").unpack)
_struct_len = struct.Struct("!i")
pack_len = cast(Callable[[int], bytes], _struct_len.pack)
unpack_len = cast(UnpackLen, _struct_len.unpack_from)
def pack_float4_bug_304(x: float) -> bytes:
raise e.InterfaceError(
"cannot dump Float4: Python affected by bug #304. Note that the psycopg-c"
" and psycopg-binary packages are not affected by this issue."
" See https://github.com/psycopg/psycopg/issues/304"
)
# If issue #304 is detected, raise an error instead of dumping wrong data.
if struct.Struct("!f").pack(1.0) != bytes.fromhex("3f800000"):
pack_float4 = pack_float4_bug_304
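A quick check of the network-order helpers defined above:

from psycopg._struct import pack_int4, unpack_int4, pack_len

assert pack_int4(42) == b"\x00\x00\x00\x2a"
assert unpack_int4(b"\x00\x00\x00\x2a") == (42,)
assert pack_len(7) == b"\x00\x00\x00\x07"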

View File

@ -0,0 +1,116 @@
"""
psycopg two-phase commit support
"""
# Copyright (C) 2021 The Psycopg Team
import re
import datetime as dt
from base64 import b64encode, b64decode
from typing import Optional, Union
from dataclasses import dataclass, replace
_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$")
@dataclass(frozen=True)
class Xid:
"""A two-phase commit transaction identifier.
The object can also be unpacked as a 3-item tuple (`format_id`, `gtrid`,
`bqual`).
"""
format_id: Optional[int]
gtrid: str
bqual: Optional[str]
prepared: Optional[dt.datetime] = None
owner: Optional[str] = None
database: Optional[str] = None
@classmethod
def from_string(cls, s: str) -> "Xid":
"""Try to parse an XA triple from the string.
This may fail for several reasons. In that case, return an unparsed Xid.
"""
try:
return cls._parse_string(s)
except Exception:
return Xid(None, s, None)
def __str__(self) -> str:
return self._as_tid()
def __len__(self) -> int:
return 3
def __getitem__(self, index: int) -> Union[int, str, None]:
return (self.format_id, self.gtrid, self.bqual)[index]
@classmethod
def _parse_string(cls, s: str) -> "Xid":
m = _re_xid.match(s)
if not m:
raise ValueError("bad Xid format")
format_id = int(m.group(1))
gtrid = b64decode(m.group(2)).decode()
bqual = b64decode(m.group(3)).decode()
return cls.from_parts(format_id, gtrid, bqual)
@classmethod
def from_parts(
cls, format_id: Optional[int], gtrid: str, bqual: Optional[str]
) -> "Xid":
if format_id is not None:
if bqual is None:
raise TypeError("if format_id is specified, bqual must be too")
if not 0 <= format_id < 0x80000000:
raise ValueError("format_id must be a non-negative 32-bit integer")
if len(bqual) > 64:
raise ValueError("bqual must be not longer than 64 chars")
if len(gtrid) > 64:
raise ValueError("gtrid must be not longer than 64 chars")
elif bqual is None:
raise TypeError("if format_id is None, bqual must be None too")
return Xid(format_id, gtrid, bqual)
def _as_tid(self) -> str:
"""
Return the PostgreSQL transaction_id for this XA xid.
PostgreSQL wants just a string, while the DBAPI supports the XA
standard and thus a triple. We use the same conversion algorithm
implemented by JDBC in order to allow some form of interoperation.
see also: the pgjdbc implementation
http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/
postgresql/xa/RecoveredXid.java?rev=1.2
"""
if self.format_id is None or self.bqual is None:
# Unparsed xid: return the gtrid.
return self.gtrid
# XA xid: mash together the components.
egtrid = b64encode(self.gtrid.encode()).decode()
ebqual = b64encode(self.bqual.encode()).decode()
return f"{self.format_id}_{egtrid}_{ebqual}"
@classmethod
def _get_recover_query(cls) -> str:
return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts"
@classmethod
def _from_record(
cls, gid: str, prepared: dt.datetime, owner: str, database: str
) -> "Xid":
xid = Xid.from_string(gid)
return replace(xid, prepared=prepared, owner=owner, database=database)
Xid.__module__ = "psycopg"
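# --- Illustrative sketch (not part of the original module) ------------------
# Shows how an XA triple round-trips through the JDBC-compatible string form
# used as the PostgreSQL prepared transaction id.
if __name__ == "__main__":
    xid = Xid.from_parts(42, "my-gtrid", "my-bqual")
    tid = str(xid)                    # e.g. "42_bXktZ3RyaWQ=_bXktYnF1YWw="
    assert Xid.from_string(tid) == xid
    fid, gtrid, bqual = xid           # the object unpacks as a 3-item tuple
    print(tid, fid, gtrid, bqual)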

View File

@ -0,0 +1,354 @@
"""
Helper object to transform values between Python and PostgreSQL
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import DefaultDict, TYPE_CHECKING
from collections import defaultdict
from typing_extensions import TypeAlias
from . import pq
from . import postgres
from . import errors as e
from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType
from .rows import Row, RowMaker
from .postgres import INVALID_OID, TEXT_OID
from ._encodings import pgconn_encoding
if TYPE_CHECKING:
from .abc import Dumper, Loader
from .adapt import AdaptersMap
from .pq.abc import PGresult
from .connection import BaseConnection
DumperCache: TypeAlias = Dict[DumperKey, "Dumper"]
OidDumperCache: TypeAlias = Dict[int, "Dumper"]
LoaderCache: TypeAlias = Dict[int, "Loader"]
TEXT = pq.Format.TEXT
PY_TEXT = PyFormat.TEXT
class Transformer(AdaptContext):
"""
An object that can adapt efficiently between Python and PostgreSQL.
The life cycle of the object is the query, so it is assumed that attributes
such as the server version or the connection encoding will not change. The
object has its own state, so adapting several values of the same type can be
optimised.
"""
__module__ = "psycopg.adapt"
__slots__ = """
types formats
_conn _adapters _pgresult _dumpers _loaders _encoding _none_oid
_oid_dumpers _oid_types _row_dumpers _row_loaders
""".split()
types: Optional[Tuple[int, ...]]
formats: Optional[List[pq.Format]]
_adapters: "AdaptersMap"
_pgresult: Optional["PGresult"]
_none_oid: int
def __init__(self, context: Optional[AdaptContext] = None):
self._pgresult = self.types = self.formats = None
# WARNING: don't store context, or you'll create a loop with the Cursor
if context:
self._adapters = context.adapters
self._conn = context.connection
else:
self._adapters = postgres.adapters
self._conn = None
# mapping fmt, class -> Dumper instance
self._dumpers: DefaultDict[PyFormat, DumperCache]
self._dumpers = defaultdict(dict)
# mapping fmt, oid -> Dumper instance
# Not often used, so create it only if needed.
self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]]
self._oid_dumpers = None
# mapping fmt, oid -> Loader instance
self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {})
self._row_dumpers: Optional[List["Dumper"]] = None
# sequence of load functions from value to python
# the length of the result columns
self._row_loaders: List[LoadFunc] = []
# mapping oid -> type sql representation
self._oid_types: Dict[int, bytes] = {}
self._encoding = ""
@classmethod
def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
"""
Return a Transformer from an AdaptContext.
If the context is a Transformer instance, just return it.
"""
if isinstance(context, Transformer):
return context
else:
return cls(context)
@property
def connection(self) -> Optional["BaseConnection[Any]"]:
return self._conn
@property
def encoding(self) -> str:
if not self._encoding:
conn = self.connection
self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8"
return self._encoding
@property
def adapters(self) -> "AdaptersMap":
return self._adapters
@property
def pgresult(self) -> Optional["PGresult"]:
return self._pgresult
def set_pgresult(
self,
result: Optional["PGresult"],
*,
set_loaders: bool = True,
format: Optional[pq.Format] = None,
) -> None:
self._pgresult = result
if not result:
self._nfields = self._ntuples = 0
if set_loaders:
self._row_loaders = []
return
self._ntuples = result.ntuples
nf = self._nfields = result.nfields
if not set_loaders:
return
if not nf:
self._row_loaders = []
return
fmt: pq.Format
fmt = result.fformat(0) if format is None else format # type: ignore
self._row_loaders = [
self.get_loader(result.ftype(i), fmt).load for i in range(nf)
]
def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types]
self.types = tuple(types)
self.formats = [format] * len(types)
def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
self._row_loaders = [self.get_loader(oid, format).load for oid in types]
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
) -> Sequence[Optional[Buffer]]:
nparams = len(params)
out: List[Optional[Buffer]] = [None] * nparams
# If we have dumpers, it means set_dumper_types had been called, in
# which case self.types and self.formats are set to sequences of the
# right size.
if self._row_dumpers:
for i in range(nparams):
param = params[i]
if param is not None:
out[i] = self._row_dumpers[i].dump(param)
return out
types = [self._get_none_oid()] * nparams
pqformats = [TEXT] * nparams
for i in range(nparams):
param = params[i]
if param is None:
continue
dumper = self.get_dumper(param, formats[i])
out[i] = dumper.dump(param)
types[i] = dumper.oid
pqformats[i] = dumper.format
self.types = tuple(types)
self.formats = pqformats
return out
def as_literal(self, obj: Any) -> bytes:
dumper = self.get_dumper(obj, PY_TEXT)
rv = dumper.quote(obj)
# If the result is quoted, and the oid is neither unknown nor text,
# add an explicit type cast.
# Check the last char because the first one might be 'E'.
oid = dumper.oid
if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID:
try:
type_sql = self._oid_types[oid]
except KeyError:
ti = self.adapters.types.get(oid)
if ti:
if oid < 8192:
# builtin: prefer "timestamptz" to "timestamp with time zone"
type_sql = ti.name.encode(self.encoding)
else:
type_sql = ti.regtype.encode(self.encoding)
if oid == ti.array_oid:
type_sql += b"[]"
else:
type_sql = b""
self._oid_types[oid] = type_sql
if type_sql:
rv = b"%s::%s" % (rv, type_sql)
if not isinstance(rv, bytes):
rv = bytes(rv)
return rv
def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
"""
Return a Dumper instance to dump `!obj`.
"""
# Normally, the type of the object dictates how to dump it
key = type(obj)
# Reuse an existing Dumper class for objects of the same type
cache = self._dumpers[format]
try:
dumper = cache[key]
except KeyError:
# If it's the first time we see this type, look for a dumper
# configured for it.
try:
dcls = self.adapters.get_dumper(key, format)
except e.ProgrammingError as ex:
raise ex from None
else:
cache[key] = dumper = dcls(key, self)
# Check if the dumper requires an upgrade to handle this specific value
key1 = dumper.get_key(obj, format)
if key1 is key:
return dumper
# If it does, ask the dumper to create its own upgraded version
try:
return cache[key1]
except KeyError:
dumper = cache[key1] = dumper.upgrade(obj, format)
return dumper
def _get_none_oid(self) -> int:
try:
return self._none_oid
except AttributeError:
pass
try:
rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid
except KeyError:
raise e.InterfaceError("None dumper not found")
return rv
def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper":
"""
Return a Dumper to dump an object to the type with given oid.
"""
if not self._oid_dumpers:
self._oid_dumpers = ({}, {})
# Reuse an existing Dumper class for objects of the same type
cache = self._oid_dumpers[format]
try:
return cache[oid]
except KeyError:
# If it's the first time we see this type, look for a dumper
# configured for it.
dcls = self.adapters.get_dumper_by_oid(oid, format)
cache[oid] = dumper = dcls(NoneType, self)
return dumper
def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]:
res = self._pgresult
if not res:
raise e.InterfaceError("result not set")
if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
raise e.InterfaceError(
f"rows must be included between 0 and {self._ntuples}"
)
records = []
for row in range(row0, row1):
record: List[Any] = [None] * self._nfields
for col in range(self._nfields):
val = res.get_value(row, col)
if val is not None:
record[col] = self._row_loaders[col](val)
records.append(make_row(record))
return records
def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]:
res = self._pgresult
if not res:
return None
if not 0 <= row < self._ntuples:
return None
record: List[Any] = [None] * self._nfields
for col in range(self._nfields):
val = res.get_value(row, col)
if val is not None:
record[col] = self._row_loaders[col](val)
return make_row(record)
def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
if len(self._row_loaders) != len(record):
raise e.ProgrammingError(
f"cannot load sequence of {len(record)} items:"
f" {len(self._row_loaders)} loaders registered"
)
return tuple(
(self._row_loaders[i](val) if val is not None else None)
for i, val in enumerate(record)
)
def get_loader(self, oid: int, format: pq.Format) -> "Loader":
try:
return self._loaders[format][oid]
except KeyError:
pass
loader_cls = self._adapters.get_loader(oid, format)
if not loader_cls:
loader_cls = self._adapters.get_loader(INVALID_OID, format)
if not loader_cls:
raise e.InterfaceError("unknown oid loader not found")
loader = self._loaders[format][oid] = loader_cls(oid, self)
return loader

View File

@ -0,0 +1,500 @@
"""
Information about PostgreSQL types
These types allow reading information from the system catalog and providing
it to the adapters if needed.
"""
# Copyright (C) 2020 The Psycopg Team
from enum import Enum
from typing import Any, Dict, Iterator, Optional, overload
from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
from typing_extensions import TypeAlias
from . import errors as e
from .abc import AdaptContext, Query
from .rows import dict_row
from ._encodings import conn_encoding
if TYPE_CHECKING:
from .connection import BaseConnection, Connection
from .connection_async import AsyncConnection
from .sql import Identifier, SQL
T = TypeVar("T", bound="TypeInfo")
RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]]
class TypeInfo:
"""
Hold information about a PostgreSQL base type.
"""
__module__ = "psycopg.types"
def __init__(
self,
name: str,
oid: int,
array_oid: int,
*,
regtype: str = "",
delimiter: str = ",",
):
self.name = name
self.oid = oid
self.array_oid = array_oid
self.regtype = regtype or name
self.delimiter = delimiter
def __repr__(self) -> str:
return (
f"<{self.__class__.__qualname__}:"
f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>"
)
@overload
@classmethod
def fetch(
cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"]
) -> Optional[T]:
...
@overload
@classmethod
async def fetch(
cls: Type[T], conn: "AsyncConnection[Any]", name: Union[str, "Identifier"]
) -> Optional[T]:
...
@classmethod
def fetch(
cls: Type[T], conn: "BaseConnection[Any]", name: Union[str, "Identifier"]
) -> Any:
"""Query a system catalog to read information about a type."""
from .sql import Composable
from .connection import Connection
from .connection_async import AsyncConnection
if isinstance(name, Composable):
name = name.as_string(conn)
if isinstance(conn, Connection):
return cls._fetch(conn, name)
elif isinstance(conn, AsyncConnection):
return cls._fetch_async(conn, name)
else:
raise TypeError(
f"expected Connection or AsyncConnection, got {type(conn).__name__}"
)
@classmethod
def _fetch(cls: Type[T], conn: "Connection[Any]", name: str) -> Optional[T]:
# This might result in a nested transaction. What we want is to leave
# the function with the connection in the state we found (either idle
# or intrans)
try:
with conn.transaction():
if conn_encoding(conn) == "ascii":
conn.execute("set local client_encoding to utf8")
with conn.cursor(row_factory=dict_row) as cur:
cur.execute(cls._get_info_query(conn), {"name": name})
recs = cur.fetchall()
except e.UndefinedObject:
return None
return cls._from_records(name, recs)
@classmethod
async def _fetch_async(
cls: Type[T], conn: "AsyncConnection[Any]", name: str
) -> Optional[T]:
try:
async with conn.transaction():
if conn_encoding(conn) == "ascii":
await conn.execute("set local client_encoding to utf8")
async with conn.cursor(row_factory=dict_row) as cur:
await cur.execute(cls._get_info_query(conn), {"name": name})
recs = await cur.fetchall()
except e.UndefinedObject:
return None
return cls._from_records(name, recs)
@classmethod
def _from_records(
cls: Type[T], name: str, recs: Sequence[Dict[str, Any]]
) -> Optional[T]:
if len(recs) == 1:
return cls(**recs[0])
elif not recs:
return None
else:
raise e.ProgrammingError(f"found {len(recs)} different types named {name}")
def register(self, context: Optional[AdaptContext] = None) -> None:
"""
Register the type information, globally or in the specified `!context`.
"""
if context:
types = context.adapters.types
else:
from . import postgres
types = postgres.types
types.add(self)
if self.array_oid:
from .types.array import register_array
register_array(self, context)
@classmethod
def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
from .sql import SQL
return SQL(
"""\
SELECT
typname AS name, oid, typarray AS array_oid,
oid::regtype::text AS regtype, typdelim AS delimiter
FROM pg_type t
WHERE t.oid = {regtype}
ORDER BY t.oid
"""
).format(regtype=cls._to_regtype(conn))
@classmethod
def _has_to_regtype_function(cls, conn: "BaseConnection[Any]") -> bool:
# to_regtype() introduced in PostgreSQL 9.4 and CockroachDB 22.2
info = conn.info
if info.vendor == "PostgreSQL":
return info.server_version >= 90400
elif info.vendor == "CockroachDB":
return info.server_version >= 220200
else:
return False
@classmethod
def _to_regtype(cls, conn: "BaseConnection[Any]") -> "SQL":
# `to_regtype()` returns the type oid or NULL, unlike the :: operator,
# which returns the type or raises an exception, which requires
# a transaction rollback and leaves traces in the server logs.
from .sql import SQL
if cls._has_to_regtype_function(conn):
return SQL("to_regtype(%(name)s)")
else:
return SQL("%(name)s::regtype")
def _added(self, registry: "TypesRegistry") -> None:
"""Method called by the `!registry` when the object is added there."""
pass
class RangeInfo(TypeInfo):
"""Manage information about a range type."""
__module__ = "psycopg.types.range"
def __init__(
self,
name: str,
oid: int,
array_oid: int,
*,
regtype: str = "",
subtype_oid: int,
):
super().__init__(name, oid, array_oid, regtype=regtype)
self.subtype_oid = subtype_oid
@classmethod
def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
from .sql import SQL
return SQL(
"""\
SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
t.oid::regtype::text AS regtype,
r.rngsubtype AS subtype_oid
FROM pg_type t
JOIN pg_range r ON t.oid = r.rngtypid
WHERE t.oid = {regtype}
"""
).format(regtype=cls._to_regtype(conn))
def _added(self, registry: "TypesRegistry") -> None:
# Map ranges subtypes to info
registry._registry[RangeInfo, self.subtype_oid] = self
class MultirangeInfo(TypeInfo):
"""Manage information about a multirange type."""
__module__ = "psycopg.types.multirange"
def __init__(
self,
name: str,
oid: int,
array_oid: int,
*,
regtype: str = "",
range_oid: int,
subtype_oid: int,
):
super().__init__(name, oid, array_oid, regtype=regtype)
self.range_oid = range_oid
self.subtype_oid = subtype_oid
@classmethod
def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
from .sql import SQL
if conn.info.server_version < 140000:
raise e.NotSupportedError(
"multirange types are only available from PostgreSQL 14"
)
return SQL(
"""\
SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
t.oid::regtype::text AS regtype,
r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid
FROM pg_type t
JOIN pg_range r ON t.oid = r.rngmultitypid
WHERE t.oid = {regtype}
"""
).format(regtype=cls._to_regtype(conn))
def _added(self, registry: "TypesRegistry") -> None:
# Map multiranges ranges and subtypes to info
registry._registry[MultirangeInfo, self.range_oid] = self
registry._registry[MultirangeInfo, self.subtype_oid] = self
class CompositeInfo(TypeInfo):
"""Manage information about a composite type."""
__module__ = "psycopg.types.composite"
def __init__(
self,
name: str,
oid: int,
array_oid: int,
*,
regtype: str = "",
field_names: Sequence[str],
field_types: Sequence[int],
):
super().__init__(name, oid, array_oid, regtype=regtype)
self.field_names = field_names
self.field_types = field_types
# Will be set by register() if the `factory` is a type
self.python_type: Optional[type] = None
@classmethod
def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
from .sql import SQL
return SQL(
"""\
SELECT
t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
t.oid::regtype::text AS regtype,
coalesce(a.fnames, '{{}}') AS field_names,
coalesce(a.ftypes, '{{}}') AS field_types
FROM pg_type t
LEFT JOIN (
SELECT
attrelid,
array_agg(attname) AS fnames,
array_agg(atttypid) AS ftypes
FROM (
SELECT a.attrelid, a.attname, a.atttypid
FROM pg_attribute a
JOIN pg_type t ON t.typrelid = a.attrelid
WHERE t.oid = {regtype}
AND a.attnum > 0
AND NOT a.attisdropped
ORDER BY a.attnum
) x
GROUP BY attrelid
) a ON a.attrelid = t.typrelid
WHERE t.oid = {regtype}
"""
).format(regtype=cls._to_regtype(conn))
class EnumInfo(TypeInfo):
"""Manage information about an enum type."""
__module__ = "psycopg.types.enum"
def __init__(
self,
name: str,
oid: int,
array_oid: int,
labels: Sequence[str],
):
super().__init__(name, oid, array_oid)
self.labels = labels
# Will be set by register_enum()
self.enum: Optional[Type[Enum]] = None
@classmethod
def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
from .sql import SQL
return SQL(
"""\
SELECT name, oid, array_oid, array_agg(label) AS labels
FROM (
SELECT
t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
e.enumlabel AS label
FROM pg_type t
LEFT JOIN pg_enum e
ON e.enumtypid = t.oid
WHERE t.oid = {regtype}
ORDER BY e.enumsortorder
) x
GROUP BY name, oid, array_oid
"""
).format(regtype=cls._to_regtype(conn))
class TypesRegistry:
"""
Container for the information about types in a database.
"""
__module__ = "psycopg.types"
def __init__(self, template: Optional["TypesRegistry"] = None):
self._registry: Dict[RegistryKey, TypeInfo]
# Make a shallow copy: it will become a proper copy if the registry
# is edited.
if template:
self._registry = template._registry
self._own_state = False
template._own_state = False
else:
self.clear()
def clear(self) -> None:
self._registry = {}
self._own_state = True
def add(self, info: TypeInfo) -> None:
self._ensure_own_state()
if info.oid:
self._registry[info.oid] = info
if info.array_oid:
self._registry[info.array_oid] = info
self._registry[info.name] = info
if info.regtype and info.regtype not in self._registry:
self._registry[info.regtype] = info
# Allow the info object to further customise its relation with the registry
info._added(self)
def __iter__(self) -> Iterator[TypeInfo]:
seen = set()
for t in self._registry.values():
if id(t) not in seen:
seen.add(id(t))
yield t
@overload
def __getitem__(self, key: Union[str, int]) -> TypeInfo:
...
@overload
def __getitem__(self, key: Tuple[Type[T], int]) -> T:
...
def __getitem__(self, key: RegistryKey) -> TypeInfo:
"""
Return info about a type, specified by name or oid
:param key: the name or oid of the type to look for.
Raise KeyError if not found.
"""
if isinstance(key, str):
if key.endswith("[]"):
key = key[:-2]
elif not isinstance(key, (int, tuple)):
raise TypeError(f"the key must be an oid or a name, got {type(key)}")
try:
return self._registry[key]
except KeyError:
raise KeyError(f"couldn't find the type {key!r} in the types registry")
@overload
def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
...
@overload
def get(self, key: Tuple[Type[T], int]) -> Optional[T]:
...
def get(self, key: RegistryKey) -> Optional[TypeInfo]:
"""
Return info about a type, specified by name or oid
:param key: the name or oid of the type to look for.
Unlike `__getitem__`, return None if not found.
"""
try:
return self[key]
except KeyError:
return None
def get_oid(self, name: str) -> int:
"""
Return the oid of a PostgreSQL type by name.
:param name: the name of the type to look for.
Return the array oid if the type ends with "``[]``"
Raise KeyError if the name is unknown.
"""
t = self[name]
if name.endswith("[]"):
return t.array_oid
else:
return t.oid
def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]:
"""
Return info about a `TypeInfo` subclass by its element name or oid.
:param cls: the subtype of `!TypeInfo` to look for. Currently
supported are `~psycopg.types.range.RangeInfo` and
`~psycopg.types.multirange.MultirangeInfo`.
:param subtype: The name or OID of the subtype of the element to look for.
:return: The `!TypeInfo` object of class `!cls` whose subtype is
`!subtype`. `!None` if the element or its range are not found.
"""
try:
info = self[subtype]
except KeyError:
return None
return self.get((cls, info.oid))
def _ensure_own_state(self) -> None:
# About to write: copy the shared registry first (copy-on-write).
if not self._own_state:
self._registry = self._registry.copy()
self._own_state = True
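# --- Illustrative sketch (not part of the original module) ------------------
# Typical use of TypeInfo with an extension type; the conninfo string and the
# "hstore" extension are assumptions, and a reachable database is required.
if __name__ == "__main__":
    import psycopg

    with psycopg.connect("dbname=test") as conn:
        info = TypeInfo.fetch(conn, "hstore")
        if info is None:
            print("type not found; is the extension installed?")
        else:
            info.register(conn)   # make the type (and its array) known here
            print(info)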

View File

@ -0,0 +1,44 @@
"""
Timezone utility functions.
"""
# Copyright (C) 2020 The Psycopg Team
import logging
from typing import Dict, Optional, Union
from datetime import timezone, tzinfo
from .pq.abc import PGconn
from ._compat import ZoneInfo
logger = logging.getLogger("psycopg")
_timezones: Dict[Union[None, bytes], tzinfo] = {
None: timezone.utc,
b"UTC": timezone.utc,
}
def get_tzinfo(pgconn: Optional[PGconn]) -> tzinfo:
"""Return the Python timezone info of the connection's timezone."""
tzname = pgconn.parameter_status(b"TimeZone") if pgconn else None
try:
return _timezones[tzname]
except KeyError:
sname = tzname.decode() if tzname else "UTC"
try:
zi: tzinfo = ZoneInfo(sname)
except (KeyError, OSError):
logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname)
zi = timezone.utc
except Exception as ex:
logger.warning(
"error handling PostgreSQL timezone: %r; will use UTC (%s - %s)",
sname,
type(ex).__name__,
ex,
)
zi = timezone.utc
_timezones[tzname] = zi
return zi
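# --- Illustrative sketch (not part of the original module) ------------------
# Without a connection the function falls back to UTC; with a connection it
# reflects the server's TimeZone setting and caches the result.
if __name__ == "__main__":
    print(get_tzinfo(None))   # UTC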

View File

@ -0,0 +1,137 @@
"""
Wrappers for numeric types.
"""
# Copyright (C) 2020 The Psycopg Team
# Wrappers to force numbers to be cast as specific PostgreSQL types
# These types are implemented here but exposed by `psycopg.types.numeric`.
# They are defined here to avoid a circular import.
_MODULE = "psycopg.types.numeric"
class Int2(int):
"""
Force dumping a Python `!int` as a PostgreSQL :sql:`smallint/int2`.
"""
__module__ = _MODULE
__slots__ = ()
def __new__(cls, arg: int) -> "Int2":
return super().__new__(cls, arg)
def __str__(self) -> str:
return super().__repr__()
def __repr__(self) -> str:
return f"{self.__class__.__name__}({super().__repr__()})"
class Int4(int):
"""
Force dumping a Python `!int` as a PostgreSQL :sql:`integer/int4`.
"""
__module__ = _MODULE
__slots__ = ()
def __new__(cls, arg: int) -> "Int4":
return super().__new__(cls, arg)
def __str__(self) -> str:
return super().__repr__()
def __repr__(self) -> str:
return f"{self.__class__.__name__}({super().__repr__()})"
class Int8(int):
"""
Force dumping a Python `!int` as a PostgreSQL :sql:`bigint/int8`.
"""
__module__ = _MODULE
__slots__ = ()
def __new__(cls, arg: int) -> "Int8":
return super().__new__(cls, arg)
def __str__(self) -> str:
return super().__repr__()
def __repr__(self) -> str:
return f"{self.__class__.__name__}({super().__repr__()})"
class IntNumeric(int):
"""
Force dumping a Python `!int` as a PostgreSQL :sql:`numeric/decimal`.
"""
__module__ = _MODULE
__slots__ = ()
def __new__(cls, arg: int) -> "IntNumeric":
return super().__new__(cls, arg)
def __str__(self) -> str:
return super().__repr__()
def __repr__(self) -> str:
return f"{self.__class__.__name__}({super().__repr__()})"
class Float4(float):
"""
Force dumping a Python `!float` as a PostgreSQL :sql:`float4/real`.
"""
__module__ = _MODULE
__slots__ = ()
def __new__(cls, arg: float) -> "Float4":
return super().__new__(cls, arg)
def __str__(self) -> str:
return super().__repr__()
def __repr__(self) -> str:
return f"{self.__class__.__name__}({super().__repr__()})"
class Float8(float):
"""
Force dumping a Python `!float` as a PostgreSQL :sql:`float8/double precision`.
"""
__module__ = _MODULE
__slots__ = ()
def __new__(cls, arg: float) -> "Float8":
return super().__new__(cls, arg)
def __str__(self) -> str:
return super().__repr__()
def __repr__(self) -> str:
return f"{self.__class__.__name__}({super().__repr__()})"
class Oid(int):
"""
Force dumping a Python `!int` as a PostgreSQL :sql:`oid`.
"""
__module__ = _MODULE
__slots__ = ()
def __new__(cls, arg: int) -> "Oid":
return super().__new__(cls, arg)
def __str__(self) -> str:
return super().__repr__()
def __repr__(self) -> str:
return f"{self.__class__.__name__}({super().__repr__()})"

View File

@ -0,0 +1,265 @@
"""
Protocol objects representing different implementations of the same classes.
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, Callable, Generator, Mapping
from typing import List, Optional, Sequence, Tuple, TypeVar, Union
from typing import TYPE_CHECKING
from typing_extensions import TypeAlias
from . import pq
from ._enums import PyFormat as PyFormat
from ._compat import Protocol, LiteralString
if TYPE_CHECKING:
from . import sql
from .rows import Row, RowMaker
from .pq.abc import PGresult
from .waiting import Wait, Ready
from .connection import BaseConnection
from ._adapters_map import AdaptersMap
NoneType: type = type(None)
# An object implementing the buffer protocol
Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
Query: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"]
Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]]
ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
PipelineCommand: TypeAlias = Callable[[], None]
DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]]
# Waiting protocol types
RV = TypeVar("RV")
PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV]
"""Generator for processes where the connection file number can change.
This can happen in connection and reset, but not in normal querying.
"""
PQGen: TypeAlias = Generator["Wait", "Ready", RV]
"""Generator for processes where the connection file number won't change.
"""
class WaitFunc(Protocol):
"""
Wait on the connection which generated `PQgen` and return its final result.
"""
def __call__(
self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
) -> RV:
...
# Adaptation types
DumpFunc: TypeAlias = Callable[[Any], Buffer]
LoadFunc: TypeAlias = Callable[[Buffer], Any]
class AdaptContext(Protocol):
"""
A context describing how types are adapted.
Examples of `~AdaptContext` are `~psycopg.Connection`, `~psycopg.Cursor`,
`~psycopg.adapt.Transformer`, `~psycopg.adapt.AdaptersMap`.
Note that this is a `~typing.Protocol`, so objects implementing
`!AdaptContext` don't need to explicitly inherit from this class.
"""
@property
def adapters(self) -> "AdaptersMap":
"""The adapters configuration that this object uses."""
...
@property
def connection(self) -> Optional["BaseConnection[Any]"]:
"""The connection used by this object, if available.
:rtype: `~psycopg.Connection` or `~psycopg.AsyncConnection` or `!None`
"""
...
class Dumper(Protocol):
"""
Convert Python objects of type `!cls` to PostgreSQL representation.
"""
format: pq.Format
"""
The format that this class `dump()` method produces,
`~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
This is a class attribute.
"""
oid: int
"""The oid to pass to the server, if known; 0 otherwise (class attribute)."""
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
...
def dump(self, obj: Any) -> Buffer:
"""Convert the object `!obj` to PostgreSQL representation.
:param obj: the object to convert.
"""
...
def quote(self, obj: Any) -> Buffer:
"""Convert the object `!obj` to escaped representation.
:param obj: the object to convert.
"""
...
def get_key(self, obj: Any, format: PyFormat) -> DumperKey:
"""Return an alternative key to upgrade the dumper to represent `!obj`.
:param obj: The object to convert
:param format: The format to convert to
Normally the type of the object is all it takes to define how to dump
the object to the database. For instance, a Python `~datetime.date` can
be simply converted into a PostgreSQL :sql:`date`.
In a few cases, just the type is not enough. For example:
- A Python `~datetime.datetime` could be represented as a
:sql:`timestamptz` or a :sql:`timestamp`, according to whether it
specifies a `!tzinfo` or not.
- A Python int could be stored as several Postgres types: int2, int4,
int8, numeric. If a type too small is used, it may result in an
overflow. If a type too large is used, PostgreSQL may not want to
cast it to a smaller type.
- Python lists should be dumped according to the type they contain to
convert them to e.g. array of strings, array of ints (and which
size of int?...)
In these cases, a dumper can implement `!get_key()` and return a new
class, or sequence of classes, that can be used to identify the same
dumper again. If the mechanism is not needed, the method should return
the same `!cls` object passed in the constructor.
If a dumper implements `get_key()` it should also implement
`upgrade()`.
"""
...
def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
"""Return a new dumper to manage `!obj`.
:param obj: The object to convert
:param format: The format to convert to
Once `Transformer.get_dumper()` has been notified by `get_key()` that
this Dumper class cannot handle `!obj` itself, it will invoke
`!upgrade()`, which should return a new `Dumper` instance, which will
be reused for every object for which `!get_key()` returns the same
result.
"""
...
class Loader(Protocol):
"""
Convert PostgreSQL values with type OID `!oid` to Python objects.
"""
format: pq.Format
"""
The format that this class `load()` method can convert,
`~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
This is a class attribute.
"""
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
...
def load(self, data: Buffer) -> Any:
"""
Convert the data returned by the database into a Python object.
:param data: the data to convert.
"""
...
class Transformer(Protocol):
types: Optional[Tuple[int, ...]]
formats: Optional[List[pq.Format]]
def __init__(self, context: Optional[AdaptContext] = None):
...
@classmethod
def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
...
@property
def connection(self) -> Optional["BaseConnection[Any]"]:
...
@property
def encoding(self) -> str:
...
@property
def adapters(self) -> "AdaptersMap":
...
@property
def pgresult(self) -> Optional["PGresult"]:
...
def set_pgresult(
self,
result: Optional["PGresult"],
*,
set_loaders: bool = True,
format: Optional[pq.Format] = None
) -> None:
...
def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
...
def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
...
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
) -> Sequence[Optional[Buffer]]:
...
def as_literal(self, obj: Any) -> bytes:
...
def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
...
def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]:
...
def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]:
...
def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
...
def get_loader(self, oid: int, format: pq.Format) -> Loader:
...

View File

@ -0,0 +1,162 @@
"""
Entry point into the adaptation system.
"""
# Copyright (C) 2020 The Psycopg Team
from abc import ABC, abstractmethod
from typing import Any, Optional, Type, TYPE_CHECKING
from . import pq, abc
from . import _adapters_map
from ._enums import PyFormat as PyFormat
from ._cmodule import _psycopg
if TYPE_CHECKING:
from .connection import BaseConnection
AdaptersMap = _adapters_map.AdaptersMap
Buffer = abc.Buffer
ORD_BS = ord("\\")
class Dumper(abc.Dumper, ABC):
"""
Convert Python object of the type `!cls` to PostgreSQL representation.
"""
oid: int = 0
"""The oid to pass to the server, if known."""
format: pq.Format = pq.Format.TEXT
"""The format of the data dumped."""
def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
self.cls = cls
self.connection: Optional["BaseConnection[Any]"] = (
context.connection if context else None
)
def __repr__(self) -> str:
return (
f"<{type(self).__module__}.{type(self).__qualname__}"
f" (oid={self.oid}) at 0x{id(self):x}>"
)
@abstractmethod
def dump(self, obj: Any) -> Buffer:
...
def quote(self, obj: Any) -> Buffer:
"""
By default return the `dump()` value quoted and sanitised, so
that the result can be used to build a SQL string. This works well
for most types, and you will likely not have to implement this method in a
subclass.
"""
value = self.dump(obj)
if self.connection:
esc = pq.Escaping(self.connection.pgconn)
# escaping and quoting
return esc.escape_literal(value)
# This path is taken when quote is asked without a connection,
# usually by psycopg.sql.quote() or by
# 'Composable.as_string(None)'. More often than not this is done by
# someone generating a SQL file to consume elsewhere.
# No libpq quoting here: only quote escaping and, possibly, backslash
# escaping. See below.
esc = pq.Escaping()
out = esc.escape_string(value)
# b"\\" in memoryview doesn't work so search for the ascii value
if ORD_BS not in out:
# If the string has no backslash, the result is correct and we
# don't need to bother with standard_conforming_strings.
return b"'" + out + b"'"
# The libpq has a crazy behaviour: PQescapeString uses the last
# standard_conforming_strings setting seen on a connection. This
# means that backslashes might be escaped or might not.
#
# A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
# if scs is off, '\\' raises a warning and '\' is an error.
#
# Check what the libpq does, and if it doesn't escape the backslash
# let's do it on our own. Never mind the race condition.
rv: bytes = b" E'" + out + b"'"
if esc.escape_string(b"\\") == b"\\":
rv = rv.replace(b"\\", b"\\\\")
return rv
def get_key(self, obj: Any, format: PyFormat) -> abc.DumperKey:
"""
Implementation of the `~psycopg.abc.Dumper.get_key()` member of the
`~psycopg.abc.Dumper` protocol. Look at its definition for details.
This implementation returns the `!cls` passed in the constructor.
Subclasses needing to specialise the PostgreSQL type according to the
*value* of the object dumped (not only according to its type)
should override this method.
"""
return self.cls
def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
"""
Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the
`~psycopg.abc.Dumper` protocol. Look at its definition for details.
This implementation just returns `!self`. If a subclass implements
`get_key()` it should probably override `!upgrade()` too.
"""
return self
class Loader(abc.Loader, ABC):
"""
Convert PostgreSQL values with type OID `!oid` to Python objects.
"""
format: pq.Format = pq.Format.TEXT
"""The format of the data loaded."""
def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
self.oid = oid
self.connection: Optional["BaseConnection[Any]"] = (
context.connection if context else None
)
@abstractmethod
def load(self, data: Buffer) -> Any:
"""Convert a PostgreSQL value to a Python object."""
...
Transformer: Type["abc.Transformer"]
# Override it with fast object if available
if _psycopg:
Transformer = _psycopg.Transformer
else:
from . import _transform
Transformer = _transform.Transformer
class RecursiveDumper(Dumper):
"""Dumper with a transformer to help dumping recursive types."""
def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
super().__init__(cls, context)
self._tx = Transformer.from_context(context)
class RecursiveLoader(Loader):
"""Loader with a transformer to help loading recursive types."""
def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
super().__init__(oid, context)
self._tx = Transformer.from_context(context)
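# --- Illustrative sketch (not part of the original module) ------------------
# A minimal custom dumper built on the Dumper ABC above; the choice of the
# `complex` type and of its text representation is only an example.
if __name__ == "__main__":
    import psycopg

    class ComplexDumper(Dumper):
        oid = 0  # unknown oid: let the server infer the type

        def dump(self, obj: Any) -> bytes:
            return str(obj).encode()

    psycopg.adapters.register_dumper(complex, ComplexDumper)
    tx = Transformer()
    print(tx.as_literal(complex(1, 2)))   # e.g. b"'(1+2j)'"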

View File

@ -0,0 +1,95 @@
"""
psycopg client-side binding cursors
"""
# Copyright (C) 2022 The Psycopg Team
from typing import Optional, Tuple, TYPE_CHECKING
from functools import partial
from ._queries import PostgresQuery, PostgresClientQuery
from . import pq
from . import adapt
from . import errors as e
from .abc import ConnectionType, Query, Params
from .rows import Row
from .cursor import BaseCursor, Cursor
from ._preparing import Prepare
from .cursor_async import AsyncCursor
if TYPE_CHECKING:
from typing import Any # noqa: F401
from .connection import Connection # noqa: F401
from .connection_async import AsyncConnection # noqa: F401
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
class ClientCursorMixin(BaseCursor[ConnectionType, Row]):
def mogrify(self, query: Query, params: Optional[Params] = None) -> str:
"""
Return the query and parameters merged.
Parameters are adapted and merged into the query the same way that
`!execute()` would do.
"""
self._tx = adapt.Transformer(self)
pgq = self._convert_query(query, params)
return pgq.query.decode(self._tx.encoding)
def _execute_send(
self,
query: PostgresQuery,
*,
force_extended: bool = False,
binary: Optional[bool] = None,
) -> None:
if binary is None:
fmt = self.format
else:
fmt = BINARY if binary else TEXT
if fmt == BINARY:
raise e.NotSupportedError(
"client-side cursors don't support binary results"
)
self._query = query
if self._conn._pipeline:
# In pipeline mode always use PQsendQueryParams - see #314
# Multiple statements in the same query are not allowed anyway.
self._conn._pipeline.command_queue.append(
partial(self._pgconn.send_query_params, query.query, None)
)
elif force_extended:
self._pgconn.send_query_params(query.query, None)
else:
# If we can, let's use simple query protocol,
# as it can execute more than one statement in a single query.
self._pgconn.send_query(query.query)
def _convert_query(
self, query: Query, params: Optional[Params] = None
) -> PostgresQuery:
pgq = PostgresClientQuery(self._tx)
pgq.convert(query, params)
return pgq
def _get_prepared(
self, pgq: PostgresQuery, prepare: Optional[bool] = None
) -> Tuple[Prepare, bytes]:
return (Prepare.NO, b"")
class ClientCursor(ClientCursorMixin["Connection[Any]", Row], Cursor[Row]):
__module__ = "psycopg"
class AsyncClientCursor(
ClientCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
):
__module__ = "psycopg"

File diff suppressed because it is too large

View File

@ -0,0 +1,428 @@
"""
psycopg async connection objects
"""
# Copyright (C) 2020 The Psycopg Team
import sys
import asyncio
import logging
from types import TracebackType
from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
from contextlib import asynccontextmanager
from . import pq
from . import errors as e
from . import waiting
from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
from ._tpc import Xid
from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts_async
from ._pipeline import AsyncPipeline
from ._encodings import pgconn_encoding
from .connection import BaseConnection, CursorRow, Notify
from .generators import notifies
from .transaction import AsyncTransaction
from .cursor_async import AsyncCursor
from .server_cursor import AsyncServerCursor
if TYPE_CHECKING:
from .pq.abc import PGconn
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
IDLE = pq.TransactionStatus.IDLE
INTRANS = pq.TransactionStatus.INTRANS
logger = logging.getLogger("psycopg")
class AsyncConnection(BaseConnection[Row]):
"""
Asynchronous wrapper for a connection to the database.
"""
__module__ = "psycopg"
cursor_factory: Type[AsyncCursor[Row]]
server_cursor_factory: Type[AsyncServerCursor[Row]]
row_factory: AsyncRowFactory[Row]
_pipeline: Optional[AsyncPipeline]
_Self = TypeVar("_Self", bound="AsyncConnection[Any]")
def __init__(
self,
pgconn: "PGconn",
row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row),
):
super().__init__(pgconn)
self.row_factory = row_factory
self.lock = asyncio.Lock()
self.cursor_factory = AsyncCursor
self.server_cursor_factory = AsyncServerCursor
@overload
@classmethod
async def connect(
cls,
conninfo: str = "",
*,
autocommit: bool = False,
prepare_threshold: Optional[int] = 5,
row_factory: AsyncRowFactory[Row],
cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
context: Optional[AdaptContext] = None,
**kwargs: Union[None, int, str],
) -> "AsyncConnection[Row]":
# TODO: returned type should be _Self. See #308.
...
@overload
@classmethod
async def connect(
cls,
conninfo: str = "",
*,
autocommit: bool = False,
prepare_threshold: Optional[int] = 5,
cursor_factory: Optional[Type[AsyncCursor[Any]]] = None,
context: Optional[AdaptContext] = None,
**kwargs: Union[None, int, str],
) -> "AsyncConnection[TupleRow]":
...
@classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004
async def connect(
cls,
conninfo: str = "",
*,
autocommit: bool = False,
prepare_threshold: Optional[int] = 5,
context: Optional[AdaptContext] = None,
row_factory: Optional[AsyncRowFactory[Row]] = None,
cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
**kwargs: Any,
) -> "AsyncConnection[Any]":
if sys.platform == "win32":
loop = asyncio.get_running_loop()
if isinstance(loop, asyncio.ProactorEventLoop):
raise e.InterfaceError(
"Psycopg cannot use the 'ProactorEventLoop' to run in async"
" mode. Please use a compatible event loop, for instance by"
" setting 'asyncio.set_event_loop_policy"
"(WindowsSelectorEventLoopPolicy())'"
)
params = await cls._get_connection_params(conninfo, **kwargs)
timeout = int(params["connect_timeout"])
rv = None
async for attempt in conninfo_attempts_async(params):
try:
conninfo = make_conninfo(**attempt)
rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
break
except e._NO_TRACEBACK as ex:
last_ex = ex
if not rv:
assert last_ex
raise last_ex.with_traceback(None)
rv._autocommit = bool(autocommit)
if row_factory:
rv.row_factory = row_factory
if cursor_factory:
rv.cursor_factory = cursor_factory
if context:
rv._adapters = AdaptersMap(context.adapters)
rv.prepare_threshold = prepare_threshold
return rv
async def __aenter__(self: _Self) -> _Self:
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if self.closed:
return
if exc_type:
# try to rollback, but if there are problems (connection in a bad
# state) just warn without clobbering the exception bubbling up.
try:
await self.rollback()
except Exception as exc2:
logger.warning(
"error ignored in rollback on %s: %s",
self,
exc2,
)
else:
await self.commit()
# Close the connection only if it doesn't belong to a pool.
if not getattr(self, "_pool", None):
await self.close()
@classmethod
async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
"""Manipulate connection parameters before connecting."""
params = conninfo_to_dict(conninfo, **kwargs)
# Make sure there is a usable connect_timeout
if "connect_timeout" in params:
params["connect_timeout"] = int(params["connect_timeout"])
else:
# The sync connect function will stop on the default socket timeout.
# Because in async connection mode we need to enforce the timeout
# ourselves, we need a finite value.
params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
return params
async def close(self) -> None:
if self.closed:
return
self._closed = True
# TODO: maybe send a cancel on close, if the connection is ACTIVE?
self.pgconn.finish()
@overload
def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
...
@overload
def cursor(
self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
) -> AsyncCursor[CursorRow]:
...
@overload
def cursor(
self,
name: str,
*,
binary: bool = False,
scrollable: Optional[bool] = None,
withhold: bool = False,
) -> AsyncServerCursor[Row]:
...
@overload
def cursor(
self,
name: str,
*,
binary: bool = False,
row_factory: AsyncRowFactory[CursorRow],
scrollable: Optional[bool] = None,
withhold: bool = False,
) -> AsyncServerCursor[CursorRow]:
...
def cursor(
self,
name: str = "",
*,
binary: bool = False,
row_factory: Optional[AsyncRowFactory[Any]] = None,
scrollable: Optional[bool] = None,
withhold: bool = False,
) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]:
"""
Return a new `AsyncCursor` to send commands and queries to the connection.
"""
self._check_connection_ok()
if not row_factory:
row_factory = self.row_factory
cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]]
if name:
cur = self.server_cursor_factory(
self,
name=name,
row_factory=row_factory,
scrollable=scrollable,
withhold=withhold,
)
else:
cur = self.cursor_factory(self, row_factory=row_factory)
if binary:
cur.format = BINARY
return cur
async def execute(
self,
query: Query,
params: Optional[Params] = None,
*,
prepare: Optional[bool] = None,
binary: bool = False,
) -> AsyncCursor[Row]:
try:
cur = self.cursor()
if binary:
cur.format = BINARY
return await cur.execute(query, params, prepare=prepare)
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
async def commit(self) -> None:
async with self.lock:
await self.wait(self._commit_gen())
async def rollback(self) -> None:
async with self.lock:
await self.wait(self._rollback_gen())
@asynccontextmanager
async def transaction(
self,
savepoint_name: Optional[str] = None,
force_rollback: bool = False,
) -> AsyncIterator[AsyncTransaction]:
"""
Start a context block with a new transaction or nested transaction.
:rtype: AsyncTransaction
"""
tx = AsyncTransaction(self, savepoint_name, force_rollback)
if self._pipeline:
async with self.pipeline(), tx, self.pipeline():
yield tx
else:
async with tx:
yield tx
async def notifies(self) -> AsyncGenerator[Notify, None]:
while True:
async with self.lock:
try:
ns = await self.wait(notifies(self.pgconn))
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
enc = pgconn_encoding(self.pgconn)
for pgn in ns:
n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
yield n
@asynccontextmanager
async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
"""Context manager to switch the connection into pipeline mode."""
async with self.lock:
self._check_connection_ok()
pipeline = self._pipeline
if pipeline is None:
# WARNING: reference loop, broken ahead.
pipeline = self._pipeline = AsyncPipeline(self)
try:
async with pipeline:
yield pipeline
finally:
if pipeline.level == 0:
async with self.lock:
assert pipeline is self._pipeline
self._pipeline = None
async def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
try:
return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
except (asyncio.CancelledError, KeyboardInterrupt):
# On Ctrl-C, try to cancel the query in the server, otherwise
# the connection will remain stuck in ACTIVE state.
self._try_cancel(self.pgconn)
try:
await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
except e.QueryCanceled:
pass # as expected
raise
@classmethod
async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
return await waiting.wait_conn_async(gen, timeout)
def _set_autocommit(self, value: bool) -> None:
self._no_set_async("autocommit")
async def set_autocommit(self, value: bool) -> None:
"""Async version of the `~Connection.autocommit` setter."""
async with self.lock:
await self.wait(self._set_autocommit_gen(value))
def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
self._no_set_async("isolation_level")
async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
"""Async version of the `~Connection.isolation_level` setter."""
async with self.lock:
await self.wait(self._set_isolation_level_gen(value))
def _set_read_only(self, value: Optional[bool]) -> None:
self._no_set_async("read_only")
async def set_read_only(self, value: Optional[bool]) -> None:
"""Async version of the `~Connection.read_only` setter."""
async with self.lock:
await self.wait(self._set_read_only_gen(value))
def _set_deferrable(self, value: Optional[bool]) -> None:
self._no_set_async("deferrable")
async def set_deferrable(self, value: Optional[bool]) -> None:
"""Async version of the `~Connection.deferrable` setter."""
async with self.lock:
await self.wait(self._set_deferrable_gen(value))
def _no_set_async(self, attribute: str) -> None:
raise AttributeError(
f"'the {attribute!r} property is read-only on async connections:"
f" please use 'await .set_{attribute}()' instead."
)
async def tpc_begin(self, xid: Union[Xid, str]) -> None:
async with self.lock:
await self.wait(self._tpc_begin_gen(xid))
async def tpc_prepare(self) -> None:
try:
async with self.lock:
await self.wait(self._tpc_prepare_gen())
except e.ObjectNotInPrerequisiteState as ex:
raise e.NotSupportedError(str(ex)) from None
async def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
async with self.lock:
await self.wait(self._tpc_finish_gen("commit", xid))
async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
async with self.lock:
await self.wait(self._tpc_finish_gen("rollback", xid))
async def tpc_recover(self) -> List[Xid]:
self._check_tpc()
status = self.info.transaction_status
async with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
await cur.execute(Xid._get_recover_query())
res = await cur.fetchall()
if status == IDLE and self.info.transaction_status == INTRANS:
await self.rollback()
return res
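# --- Illustrative sketch (not part of the original module) ------------------
# Basic asynchronous usage; the conninfo string is an assumption and a
# reachable database is required to run this.
if __name__ == "__main__":
    async def _demo() -> None:
        async with await AsyncConnection.connect("dbname=test") as aconn:
            cur = await aconn.execute("SELECT version()")
            print(await cur.fetchone())
            async with aconn.transaction():
                await aconn.execute("SELECT 1")

    asyncio.run(_demo())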

View File

@ -0,0 +1,478 @@
"""
Functions to manipulate conninfo strings
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import os
import re
import socket
import asyncio
from typing import Any, Iterator, AsyncIterator
from random import shuffle
from pathlib import Path
from datetime import tzinfo
from functools import lru_cache
from ipaddress import ip_address
from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from ._tz import get_tzinfo
from ._compat import cache
from ._encodings import pgconn_encoding
ConnDict: TypeAlias = "dict[str, Any]"
def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
"""
Merge a string and keyword params into a single conninfo string.
:param conninfo: A `connection string`__ as accepted by PostgreSQL.
:param kwargs: Parameters overriding the ones specified in `!conninfo`.
:return: A connection string valid for PostgreSQL, with the `!kwargs`
parameters merged.
Raise `~psycopg.ProgrammingError` if the input doesn't make a valid
conninfo string.
.. __: https://www.postgresql.org/docs/current/libpq-connect.html
#LIBPQ-CONNSTRING
"""
if not conninfo and not kwargs:
return ""
# If no kwargs are specified, don't mung the conninfo but check that it's correct.
# Make sure to return a string, not a subtype, to avoid making Liskov sad.
if not kwargs:
_parse_conninfo(conninfo)
return str(conninfo)
# Override the conninfo with the parameters
# Drop the None arguments
kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
if conninfo:
tmp = conninfo_to_dict(conninfo)
tmp.update(kwargs)
kwargs = tmp
conninfo = " ".join(f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items())
# Verify the result is valid
_parse_conninfo(conninfo)
return conninfo
def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict:
"""
Convert the `!conninfo` string into a dictionary of parameters.
:param conninfo: A `connection string`__ as accepted by PostgreSQL.
:param kwargs: Parameters overriding the ones specified in `!conninfo`.
:return: Dictionary with the parameters parsed from `!conninfo` and
`!kwargs`.
Raise `~psycopg.ProgrammingError` if `!conninfo` is not a valid connection
string.
.. __: https://www.postgresql.org/docs/current/libpq-connect.html
#LIBPQ-CONNSTRING
"""
opts = _parse_conninfo(conninfo)
rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None}
for k, v in kwargs.items():
if v is not None:
rv[k] = v
return rv
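# --- Illustrative sketch (not part of the original module) ------------------
# How the two helpers above complement each other; the parameter values are
# only examples and the key order in the output may vary.
if __name__ == "__main__":
    s = make_conninfo("dbname=test user=alice", port=5433, password=None)
    print(s)                    # e.g. dbname=test user=alice port=5433
    print(conninfo_to_dict(s))  # e.g. {'dbname': 'test', 'user': 'alice', 'port': '5433'}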
def _parse_conninfo(conninfo: str) -> list[pq.ConninfoOption]:
"""
Verify that `!conninfo` is a valid connection string.
Raise ProgrammingError if the string is not valid.
Return the result of pq.Conninfo.parse() on success.
"""
try:
return pq.Conninfo.parse(conninfo.encode())
except e.OperationalError as ex:
raise e.ProgrammingError(str(ex)) from None
re_escape = re.compile(r"([\\'])")
re_space = re.compile(r"\s")
def _param_escape(s: str) -> str:
"""
Apply the escaping rule required by PQconnectdb
"""
if not s:
return "''"
s = re_escape.sub(r"\\\1", s)
if re_space.search(s):
s = "'" + s + "'"
return s
class ConnectionInfo:
"""Allow access to information about the connection."""
__module__ = "psycopg"
def __init__(self, pgconn: pq.abc.PGconn):
self.pgconn = pgconn
@property
def vendor(self) -> str:
"""A string representing the database vendor connected to."""
return "PostgreSQL"
@property
def host(self) -> str:
"""The server host name of the active connection. See :pq:`PQhost()`."""
return self._get_pgconn_attr("host")
@property
def hostaddr(self) -> str:
"""The server IP address of the connection. See :pq:`PQhostaddr()`."""
return self._get_pgconn_attr("hostaddr")
@property
def port(self) -> int:
"""The port of the active connection. See :pq:`PQport()`."""
return int(self._get_pgconn_attr("port"))
@property
def dbname(self) -> str:
"""The database name of the connection. See :pq:`PQdb()`."""
return self._get_pgconn_attr("db")
@property
def user(self) -> str:
"""The user name of the connection. See :pq:`PQuser()`."""
return self._get_pgconn_attr("user")
@property
def password(self) -> str:
"""The password of the connection. See :pq:`PQpass()`."""
return self._get_pgconn_attr("password")
@property
def options(self) -> str:
"""
The command-line options passed in the connection request.
See :pq:`PQoptions`.
"""
return self._get_pgconn_attr("options")
def get_parameters(self) -> dict[str, str]:
"""Return the connection parameters values.
Return all the parameters set to a non-default value, which might come
either from the connection string and parameters passed to
`~Connection.connect()` or from environment variables. The password
is never returned (you can read it using the `password` attribute).
"""
pyenc = self.encoding
# Get the known defaults to avoid reporting them
defaults = {
i.keyword: i.compiled
for i in pq.Conninfo.get_defaults()
if i.compiled is not None
}
# Not returned by the libpq. Bug? Bet we're using SSH.
defaults.setdefault(b"channel_binding", b"prefer")
defaults[b"passfile"] = str(Path.home() / ".pgpass").encode()
return {
i.keyword.decode(pyenc): i.val.decode(pyenc)
for i in self.pgconn.info
if i.val is not None
and i.keyword != b"password"
and i.val != defaults.get(i.keyword)
}
@property
def dsn(self) -> str:
"""Return the connection string to connect to the database.
The string contains all the parameters set to a non-default value,
which might come either from the connection string and parameters
passed to `~Connection.connect()` or from environment variables. The
password is never returned (you can read it using the `password`
attribute).
"""
return make_conninfo(**self.get_parameters())
@property
def status(self) -> pq.ConnStatus:
"""The status of the connection. See :pq:`PQstatus()`."""
return pq.ConnStatus(self.pgconn.status)
@property
def transaction_status(self) -> pq.TransactionStatus:
"""
The current in-transaction status of the session.
See :pq:`PQtransactionStatus()`.
"""
return pq.TransactionStatus(self.pgconn.transaction_status)
@property
def pipeline_status(self) -> pq.PipelineStatus:
"""
The current pipeline status of the client.
See :pq:`PQpipelineStatus()`.
"""
return pq.PipelineStatus(self.pgconn.pipeline_status)
def parameter_status(self, param_name: str) -> str | None:
"""
Return a parameter setting of the connection.
Return `None` if the parameter is unknown.
"""
res = self.pgconn.parameter_status(param_name.encode(self.encoding))
return res.decode(self.encoding) if res is not None else None
@property
def server_version(self) -> int:
"""
An integer representing the server version. See :pq:`PQserverVersion()`.
"""
return self.pgconn.server_version
@property
def backend_pid(self) -> int:
"""
The process ID (PID) of the backend process handling this connection.
See :pq:`PQbackendPID()`.
"""
return self.pgconn.backend_pid
@property
def error_message(self) -> str:
"""
The error message most recently generated by an operation on the connection.
See :pq:`PQerrorMessage()`.
"""
return self._get_pgconn_attr("error_message")
@property
def timezone(self) -> tzinfo:
"""The Python timezone info of the connection's timezone."""
return get_tzinfo(self.pgconn)
@property
def encoding(self) -> str:
"""The Python codec name of the connection's client encoding."""
return pgconn_encoding(self.pgconn)
def _get_pgconn_attr(self, name: str) -> str:
value: bytes = getattr(self.pgconn, name)
return value.decode(self.encoding)
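# Hedged usage sketch (not part of the original module): how the ConnectionInfo
# attributes above are typically read from a live connection. The DSN is
# hypothetical and a reachable server is assumed.
#
#     import psycopg
#
#     with psycopg.connect("dbname=example") as conn:
#         info = conn.info
#         print(info.vendor, info.server_version)
#         print(info.host, info.port, info.dbname, info.user)
#         print(info.get_parameters())  # non-default params, password excluded
#         print(info.dsn)               # the same data as a connection string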
def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
"""Split a set of connection params on the single attempts to perforn.
A connection param can perform more than one attempt more than one ``host``
is provided.
Because the libpq async function doesn't honour the timeout, we need to
reimplement the repeated attempts.
"""
if params.get("load_balance_hosts", "disable") == "random":
attempts = list(_split_attempts(_inject_defaults(params)))
shuffle(attempts)
for attempt in attempts:
yield attempt
else:
for attempt in _split_attempts(_inject_defaults(params)):
yield attempt
async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
"""Split a set of connection params on the single attempts to perforn.
A connection param can perform more than one attempt more than one ``host``
is provided.
Also perform async resolution of the hostname into hostaddr in order to
avoid blocking. Because a host can resolve to more than one address, this
can lead to yield more attempts too. Raise `OperationalError` if no host
could be resolved.
Because the libpq async function doesn't honour the timeout, we need to
reimplement the repeated attempts.
"""
yielded = False
last_exc = None
for attempt in _split_attempts(_inject_defaults(params)):
try:
async for a2 in _split_attempts_and_resolve(attempt):
yielded = True
yield a2
except OSError as ex:
last_exc = ex
if not yielded:
assert last_exc
# We couldn't resolve anything
raise e.OperationalError(str(last_exc))
def _inject_defaults(params: ConnDict) -> ConnDict:
"""
Add defaults to a dictionary of parameters.
    This avoids the need to look up env vars at various stages during
    processing.
    Note that a port is always specified: if not passed explicitly, it is taken
    from the ``PGPORT`` env var or from the libpq compiled-in default
    (normally 5432).
    The `host`, `hostaddr`, `port` keys will always be set to a string.
"""
defaults = _conn_defaults()
out = params.copy()
def inject(name: str, envvar: str) -> None:
value = out.get(name)
if not value:
out[name] = os.environ.get(envvar, defaults[name])
else:
out[name] = str(value)
inject("host", "PGHOST")
inject("hostaddr", "PGHOSTADDR")
inject("port", "PGPORT")
return out
def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
"""
Split connection parameters with a sequence of hosts into separate attempts.
Assume that `host`, `hostaddr`, `port` are always present and a string (as
emitted from `_inject_defaults()`).
"""
def split_val(key: str) -> list[str]:
# Assume all keys are present and strings.
val: str = params[key]
return val.split(",") if val else []
hosts = split_val("host")
hostaddrs = split_val("hostaddr")
ports = split_val("port")
if hosts and hostaddrs and len(hosts) != len(hostaddrs):
raise e.OperationalError(
f"could not match {len(hosts)} host names"
f" with {len(hostaddrs)} hostaddr values"
)
nhosts = max(len(hosts), len(hostaddrs))
if 1 < len(ports) != nhosts:
raise e.OperationalError(
f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
)
elif len(ports) == 1:
ports *= nhosts
# A single attempt to make
if nhosts <= 1:
yield params
return
# Now all lists are either empty or have the same length
for i in range(nhosts):
attempt = params.copy()
if hosts:
attempt["host"] = hosts[i]
if hostaddrs:
attempt["hostaddr"] = hostaddrs[i]
if ports:
attempt["port"] = ports[i]
yield attempt
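# Illustrative sketch (not part of the original module): what _split_attempts()
# is expected to yield for a multi-host params dict, assuming the keys have
# already been normalized to strings as _inject_defaults() does.
#
#     >>> list(_split_attempts({"host": "db1,db2", "hostaddr": "", "port": "5432,5433"}))
#     [{'host': 'db1', 'hostaddr': '', 'port': '5432'},
#      {'host': 'db2', 'hostaddr': '', 'port': '5433'}]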
async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDict]:
"""
Perform async DNS lookup of the hosts and return a new params dict.
:param params: The input parameters, for instance as returned by
`~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
a single entry for host, hostaddr, port and doesn't check for env vars
        because it is designed to further process the input of _split_attempts().
    If a ``host`` param is present but not ``hostaddr``, resolve the host
    addresses dynamically.
    The function may change the input ``host``, ``hostaddr``, ``port`` to allow
    connecting without further DNS lookups.
Raise `~psycopg.OperationalError` if resolution fails.
"""
host = params["host"]
if not host or host.startswith("/") or host[1:2] == ":":
# Local path, or no host to resolve
yield params
return
hostaddr = params["hostaddr"]
if hostaddr:
# Already resolved
yield params
return
if is_ip_address(host):
# If the host is already an ip address don't try to resolve it
params["hostaddr"] = host
yield params
return
loop = asyncio.get_running_loop()
port = params["port"]
ans = await loop.getaddrinfo(
host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
)
attempt = params.copy()
for item in ans:
attempt["hostaddr"] = item[4][0]
yield attempt
@cache
def _conn_defaults() -> dict[str, str]:
"""
Return a dictionary of defaults for connection strings parameters.
"""
defs = pq.Conninfo.get_defaults()
return {
d.keyword.decode(): d.compiled.decode() if d.compiled is not None else ""
for d in defs
}
@lru_cache()
def is_ip_address(s: str) -> bool:
"""Return True if the string represent a valid ip address."""
try:
ip_address(s)
except ValueError:
return False
return True
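# For illustration (not part of the original module):
#
#     >>> is_ip_address("192.168.0.1"), is_ip_address("::1")
#     (True, True)
#     >>> is_ip_address("db.example.com")
#     False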

View File

@ -0,0 +1,919 @@
"""
psycopg copy support
"""
# Copyright (C) 2020 The Psycopg Team
import re
import queue
import struct
import asyncio
import threading
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO
from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
from . import pq
from . import adapt
from . import errors as e
from .abc import Buffer, ConnectionType, PQGen, Transformer
from ._compat import create_task
from .pq.misc import connection_summary
from ._cmodule import _psycopg
from ._encodings import pgconn_encoding
from .generators import copy_from, copy_to, copy_end
if TYPE_CHECKING:
from .cursor import BaseCursor, Cursor
from .cursor_async import AsyncCursor
from .connection import Connection # noqa: F401
from .connection_async import AsyncConnection # noqa: F401
PY_TEXT = adapt.PyFormat.TEXT
PY_BINARY = adapt.PyFormat.BINARY
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
COPY_IN = pq.ExecStatus.COPY_IN
COPY_OUT = pq.ExecStatus.COPY_OUT
ACTIVE = pq.TransactionStatus.ACTIVE
# Size of data to accumulate before sending it down the network. We fill a
# buffer this size field by field, and when it passes the threshold size
# we ship it, so it may end up being bigger than this.
BUFFER_SIZE = 32 * 1024
# Maximum data size we want to queue to send to the libpq copy. Sending a
# buffer too big to be handled can cause an infinite loop in the libpq
# (#255) so we want to split it in more digestable chunks.
MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
# Note: making this buffer too large, e.g.
# MAX_BUFFER_SIZE = 1024 * 1024
# makes operations *way* slower! Probably triggering some quadraticity
# in the libpq memory management and data sending.
# Max size of the write queue of buffers. Beyond that, the copy will block.
# Each buffer should be around BUFFER_SIZE in size.
QUEUE_SIZE = 1024
class BaseCopy(Generic[ConnectionType]):
"""
Base implementation for the copy user interface.
Two subclasses expose real methods with the sync/async differences.
The difference between the text and binary format is managed by two
different `Formatter` subclasses.
Writing (the I/O part) is implemented in the subclasses by a `Writer` or
`AsyncWriter` instance. Normally writing implies sending copy data to a
database, but a different writer might be chosen, e.g. to stream data into
a file for later use.
"""
_Self = TypeVar("_Self", bound="BaseCopy[Any]")
formatter: "Formatter"
def __init__(
self,
cursor: "BaseCursor[ConnectionType, Any]",
*,
binary: Optional[bool] = None,
):
self.cursor = cursor
self.connection = cursor.connection
self._pgconn = self.connection.pgconn
result = cursor.pgresult
if result:
self._direction = result.status
if self._direction != COPY_IN and self._direction != COPY_OUT:
raise e.ProgrammingError(
"the cursor should have performed a COPY operation;"
f" its status is {pq.ExecStatus(self._direction).name} instead"
)
else:
self._direction = COPY_IN
if binary is None:
binary = bool(result and result.binary_tuples)
tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
if binary:
self.formatter = BinaryFormatter(tx)
else:
self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
self._finished = False
def __repr__(self) -> str:
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
info = connection_summary(self._pgconn)
return f"<{cls} {info} at 0x{id(self):x}>"
def _enter(self) -> None:
if self._finished:
raise TypeError("copy blocks can be used only once")
def set_types(self, types: Sequence[Union[int, str]]) -> None:
"""
Set the types expected in a COPY operation.
The types must be specified as a sequence of oid or PostgreSQL type
names (e.g. ``int4``, ``timestamptz[]``).
This operation overcomes the lack of metadata returned by PostgreSQL
when a COPY operation begins:
        - On :sql:`COPY TO`, `!set_types()` allows specifying what types the
operation returns. If `!set_types()` is not used, the data will be
returned as unparsed strings or bytes instead of Python objects.
        - On :sql:`COPY FROM`, `!set_types()` allows choosing what type the
database expects. This is especially useful in binary copy, because
PostgreSQL will apply no cast rule.
"""
registry = self.cursor.adapters.types
oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
if self._direction == COPY_IN:
self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
else:
self.formatter.transformer.set_loader_types(oids, self.formatter.format)
# High level copy protocol generators (state change of the Copy object)
def _read_gen(self) -> PQGen[Buffer]:
if self._finished:
return memoryview(b"")
res = yield from copy_from(self._pgconn)
if isinstance(res, memoryview):
return res
# res is the final PGresult
self._finished = True
# This result is a COMMAND_OK which has info about the number of rows
# returned, but not about the columns, which is instead an information
# that was received on the COPY_OUT result at the beginning of COPY.
# So, don't replace the results in the cursor, just update the rowcount.
nrows = res.command_tuples
self.cursor._rowcount = nrows if nrows is not None else -1
return memoryview(b"")
def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
data = yield from self._read_gen()
if not data:
return None
row = self.formatter.parse_row(data)
if row is None:
# Get the final result to finish the copy operation
yield from self._read_gen()
self._finished = True
return None
return row
def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
if not exc:
return
if self._pgconn.transaction_status != ACTIVE:
            # The server has already finished sending copy data. The connection
# is already in a good state.
return
# Throw a cancel to the server, then consume the rest of the copy data
# (which might or might not have been already transferred entirely to
        # the client, so we won't necessarily see the exception associated with
# canceling).
self.connection.cancel()
try:
while (yield from self._read_gen()):
pass
except e.QueryCanceled:
pass
class Copy(BaseCopy["Connection[Any]"]):
"""Manage a :sql:`COPY` operation.
:param cursor: the cursor where the operation is performed.
:param binary: if `!True`, write binary format.
:param writer: the object to write to destination. If not specified, write
to the `!cursor` connection.
Choosing `!binary` is not necessary if the cursor has executed a
:sql:`COPY` operation, because the operation result describes the format
too. The parameter is useful when a `!Copy` object is created manually and
no operation is performed on the cursor, such as when using ``writer=``\\
`~psycopg.copy.FileWriter`.
"""
__module__ = "psycopg"
writer: "Writer"
def __init__(
self,
cursor: "Cursor[Any]",
*,
binary: Optional[bool] = None,
writer: Optional["Writer"] = None,
):
super().__init__(cursor, binary=binary)
if not writer:
writer = LibpqWriter(cursor)
self.writer = writer
self._write = writer.write
def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
self._enter()
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.finish(exc_val)
# End user sync interface
def __iter__(self) -> Iterator[Buffer]:
"""Implement block-by-block iteration on :sql:`COPY TO`."""
while True:
data = self.read()
if not data:
break
yield data
def read(self) -> Buffer:
"""
Read an unparsed row after a :sql:`COPY TO` operation.
Return an empty string when the data is finished.
"""
return self.connection.wait(self._read_gen())
def rows(self) -> Iterator[Tuple[Any, ...]]:
"""
Iterate on the result of a :sql:`COPY TO` operation record by record.
Note that the records returned will be tuples of unparsed strings or
bytes, unless data types are specified using `set_types()`.
"""
while True:
record = self.read_row()
if record is None:
break
yield record
def read_row(self) -> Optional[Tuple[Any, ...]]:
"""
Read a parsed row of data from a table after a :sql:`COPY TO` operation.
Return `!None` when the data is finished.
Note that the records returned will be tuples of unparsed strings or
bytes, unless data types are specified using `set_types()`.
"""
return self.connection.wait(self._read_row_gen())
def write(self, buffer: Union[Buffer, str]) -> None:
"""
Write a block of data to a table after a :sql:`COPY FROM` operation.
If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
text mode it can be either `!bytes` or `!str`.
"""
data = self.formatter.write(buffer)
if data:
self._write(data)
def write_row(self, row: Sequence[Any]) -> None:
"""Write a record to a table after a :sql:`COPY FROM` operation."""
data = self.formatter.write_row(row)
if data:
self._write(data)
def finish(self, exc: Optional[BaseException]) -> None:
"""Terminate the copy operation and free the resources allocated.
You shouldn't need to call this function yourself: it is usually called
        on block exit. It is available if, despite what is documented, you end up
using the `Copy` object outside a block.
"""
if self._direction == COPY_IN:
data = self.formatter.end()
if data:
self._write(data)
self.writer.finish(exc)
self._finished = True
else:
self.connection.wait(self._end_copy_out_gen(exc))
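# Hedged usage sketch (not part of the original module): a typical COPY FROM
# with the Copy object above, obtained from `Cursor.copy()`. The DSN and the
# table "t (num int, data text)" are hypothetical.
#
#     import psycopg
#
#     with psycopg.connect("dbname=example") as conn:
#         with conn.cursor() as cur:
#             with cur.copy("COPY t (num, data) FROM STDIN") as copy:
#                 copy.write_row((1, "hello"))
#                 copy.write_row((2, None))  # None is sent as SQL NULL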
class Writer(ABC):
"""
A class to write copy data somewhere.
"""
@abstractmethod
def write(self, data: Buffer) -> None:
"""
Write some data to destination.
"""
...
def finish(self, exc: Optional[BaseException] = None) -> None:
"""
Called when write operations are finished.
If operations finished with an error, it will be passed to ``exc``.
"""
pass
class LibpqWriter(Writer):
"""
A `Writer` to write copy data to a Postgres database.
"""
def __init__(self, cursor: "Cursor[Any]"):
self.cursor = cursor
self.connection = cursor.connection
self._pgconn = self.connection.pgconn
def write(self, data: Buffer) -> None:
if len(data) <= MAX_BUFFER_SIZE:
# Most used path: we don't need to split the buffer in smaller
# bits, so don't make a copy.
self.connection.wait(copy_to(self._pgconn, data))
else:
# Copy a buffer too large in chunks to avoid causing a memory
# error in the libpq, which may cause an infinite loop (#255).
for i in range(0, len(data), MAX_BUFFER_SIZE):
self.connection.wait(
copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
)
def finish(self, exc: Optional[BaseException] = None) -> None:
bmsg: Optional[bytes]
if exc:
msg = f"error from Python: {type(exc).__qualname__} - {exc}"
bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
else:
bmsg = None
try:
res = self.connection.wait(copy_end(self._pgconn, bmsg))
# The QueryCanceled is expected if we sent an exception message to
# pgconn.put_copy_end(). The Python exception that generated that
# cancelling is more important, so don't clobber it.
except e.QueryCanceled:
if not bmsg:
raise
else:
self.cursor._results = [res]
class QueuedLibpqWriter(LibpqWriter):
"""
A writer using a buffer to queue data to write to a Postgres database.
`write()` returns immediately, so that the main thread can be CPU-bound
formatting messages, while a worker thread can be IO-bound waiting to write
on the connection.
"""
def __init__(self, cursor: "Cursor[Any]"):
super().__init__(cursor)
self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE)
self._worker: Optional[threading.Thread] = None
self._worker_error: Optional[BaseException] = None
def worker(self) -> None:
"""Push data to the server when available from the copy queue.
Terminate reading when the queue receives a false-y value, or in case
of error.
The function is designed to be run in a separate thread.
"""
try:
while True:
data = self._queue.get(block=True, timeout=24 * 60 * 60)
if not data:
break
self.connection.wait(copy_to(self._pgconn, data))
except BaseException as ex:
# Propagate the error to the main thread.
self._worker_error = ex
def write(self, data: Buffer) -> None:
if not self._worker:
            # warning: reference loop, broken by finish()
self._worker = threading.Thread(target=self.worker)
self._worker.daemon = True
self._worker.start()
        # If the worker thread raised an exception, re-raise it to the caller.
if self._worker_error:
raise self._worker_error
if len(data) <= MAX_BUFFER_SIZE:
# Most used path: we don't need to split the buffer in smaller
# bits, so don't make a copy.
self._queue.put(data)
else:
# Copy a buffer too large in chunks to avoid causing a memory
# error in the libpq, which may cause an infinite loop (#255).
for i in range(0, len(data), MAX_BUFFER_SIZE):
self._queue.put(data[i : i + MAX_BUFFER_SIZE])
def finish(self, exc: Optional[BaseException] = None) -> None:
self._queue.put(b"")
if self._worker:
self._worker.join()
self._worker = None # break the loop
# Check if the worker thread raised any exception before terminating.
if self._worker_error:
raise self._worker_error
super().finish(exc)
class FileWriter(Writer):
"""
A `Writer` to write copy data to a file-like object.
:param file: the file where to write copy data. It must be open for writing
in binary mode.
"""
def __init__(self, file: IO[bytes]):
self.file = file
def write(self, data: Buffer) -> None:
self.file.write(data)
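# Hedged sketch (not part of the original module): using FileWriter to format
# COPY data into a local file instead of sending it to the server. The Copy
# object is created manually on a cursor with no pending result; the DSN, the
# file name and the records are hypothetical.
#
#     records = [(1, "hello"), (2, "world")]
#     with psycopg.connect("dbname=example") as conn, conn.cursor() as cur:
#         with open("data.copy", "wb") as f:
#             with Copy(cur, writer=FileWriter(f)) as copy:
#                 for record in records:
#                     copy.write_row(record)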
class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
"""Manage an asynchronous :sql:`COPY` operation."""
__module__ = "psycopg"
writer: "AsyncWriter"
def __init__(
self,
cursor: "AsyncCursor[Any]",
*,
binary: Optional[bool] = None,
writer: Optional["AsyncWriter"] = None,
):
super().__init__(cursor, binary=binary)
if not writer:
writer = AsyncLibpqWriter(cursor)
self.writer = writer
self._write = writer.write
async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
self._enter()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.finish(exc_val)
async def __aiter__(self) -> AsyncIterator[Buffer]:
while True:
data = await self.read()
if not data:
break
yield data
async def read(self) -> Buffer:
return await self.connection.wait(self._read_gen())
async def rows(self) -> AsyncIterator[Tuple[Any, ...]]:
while True:
record = await self.read_row()
if record is None:
break
yield record
async def read_row(self) -> Optional[Tuple[Any, ...]]:
return await self.connection.wait(self._read_row_gen())
async def write(self, buffer: Union[Buffer, str]) -> None:
data = self.formatter.write(buffer)
if data:
await self._write(data)
async def write_row(self, row: Sequence[Any]) -> None:
data = self.formatter.write_row(row)
if data:
await self._write(data)
async def finish(self, exc: Optional[BaseException]) -> None:
if self._direction == COPY_IN:
data = self.formatter.end()
if data:
await self._write(data)
await self.writer.finish(exc)
self._finished = True
else:
await self.connection.wait(self._end_copy_out_gen(exc))
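# Hedged usage sketch (not part of the original module): an async COPY TO that
# parses rows back to Python objects via set_types(). The DSN and the table
# are hypothetical.
#
#     import asyncio
#     import psycopg
#
#     async def dump_table() -> None:
#         async with await psycopg.AsyncConnection.connect("dbname=example") as conn:
#             async with conn.cursor() as cur:
#                 async with cur.copy("COPY t (num, data) TO STDOUT") as copy:
#                     copy.set_types(["int4", "text"])
#                     async for row in copy.rows():
#                         print(row)
#
#     asyncio.run(dump_table())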
class AsyncWriter(ABC):
"""
A class to write copy data somewhere (for async connections).
"""
@abstractmethod
async def write(self, data: Buffer) -> None:
...
async def finish(self, exc: Optional[BaseException] = None) -> None:
pass
class AsyncLibpqWriter(AsyncWriter):
"""
An `AsyncWriter` to write copy data to a Postgres database.
"""
def __init__(self, cursor: "AsyncCursor[Any]"):
self.cursor = cursor
self.connection = cursor.connection
self._pgconn = self.connection.pgconn
async def write(self, data: Buffer) -> None:
if len(data) <= MAX_BUFFER_SIZE:
# Most used path: we don't need to split the buffer in smaller
# bits, so don't make a copy.
await self.connection.wait(copy_to(self._pgconn, data))
else:
# Copy a buffer too large in chunks to avoid causing a memory
# error in the libpq, which may cause an infinite loop (#255).
for i in range(0, len(data), MAX_BUFFER_SIZE):
await self.connection.wait(
copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
)
async def finish(self, exc: Optional[BaseException] = None) -> None:
bmsg: Optional[bytes]
if exc:
msg = f"error from Python: {type(exc).__qualname__} - {exc}"
bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
else:
bmsg = None
try:
res = await self.connection.wait(copy_end(self._pgconn, bmsg))
# The QueryCanceled is expected if we sent an exception message to
# pgconn.put_copy_end(). The Python exception that generated that
# cancelling is more important, so don't clobber it.
except e.QueryCanceled:
if not bmsg:
raise
else:
self.cursor._results = [res]
class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
"""
An `AsyncWriter` using a buffer to queue data to write.
`write()` returns immediately, so that the main thread can be CPU-bound
formatting messages, while a worker thread can be IO-bound waiting to write
on the connection.
"""
def __init__(self, cursor: "AsyncCursor[Any]"):
super().__init__(cursor)
self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE)
self._worker: Optional[asyncio.Future[None]] = None
async def worker(self) -> None:
"""Push data to the server when available from the copy queue.
Terminate reading when the queue receives a false-y value.
The function is designed to be run in a separate task.
"""
while True:
data = await self._queue.get()
if not data:
break
await self.connection.wait(copy_to(self._pgconn, data))
async def write(self, data: Buffer) -> None:
if not self._worker:
self._worker = create_task(self.worker())
if len(data) <= MAX_BUFFER_SIZE:
# Most used path: we don't need to split the buffer in smaller
# bits, so don't make a copy.
await self._queue.put(data)
else:
# Copy a buffer too large in chunks to avoid causing a memory
# error in the libpq, which may cause an infinite loop (#255).
for i in range(0, len(data), MAX_BUFFER_SIZE):
await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
async def finish(self, exc: Optional[BaseException] = None) -> None:
await self._queue.put(b"")
if self._worker:
await asyncio.gather(self._worker)
self._worker = None # break reference loops if any
await super().finish(exc)
class Formatter(ABC):
"""
    A class which understands a copy format (text, binary).
"""
format: pq.Format
def __init__(self, transformer: Transformer):
self.transformer = transformer
self._write_buffer = bytearray()
self._row_mode = False # true if the user is using write_row()
@abstractmethod
def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
...
@abstractmethod
def write(self, buffer: Union[Buffer, str]) -> Buffer:
...
@abstractmethod
def write_row(self, row: Sequence[Any]) -> Buffer:
...
@abstractmethod
def end(self) -> Buffer:
...
class TextFormatter(Formatter):
format = TEXT
def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
super().__init__(transformer)
self._encoding = encoding
def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
if data:
return parse_row_text(data, self.transformer)
else:
return None
def write(self, buffer: Union[Buffer, str]) -> Buffer:
data = self._ensure_bytes(buffer)
self._signature_sent = True
return data
def write_row(self, row: Sequence[Any]) -> Buffer:
# Note down that we are writing in row mode: it means we will have
# to take care of the end-of-copy marker too
self._row_mode = True
format_row_text(row, self.transformer, self._write_buffer)
if len(self._write_buffer) > BUFFER_SIZE:
buffer, self._write_buffer = self._write_buffer, bytearray()
return buffer
else:
return b""
def end(self) -> Buffer:
buffer, self._write_buffer = self._write_buffer, bytearray()
return buffer
def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
if isinstance(data, str):
return data.encode(self._encoding)
else:
# Assume, for simplicity, that the user is not passing stupid
# things to the write function. If that's the case, things
# will fail downstream.
return data
class BinaryFormatter(Formatter):
format = BINARY
def __init__(self, transformer: Transformer):
super().__init__(transformer)
self._signature_sent = False
def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
if not self._signature_sent:
if data[: len(_binary_signature)] != _binary_signature:
raise e.DataError(
"binary copy doesn't start with the expected signature"
)
self._signature_sent = True
data = data[len(_binary_signature) :]
elif data == _binary_trailer:
return None
return parse_row_binary(data, self.transformer)
def write(self, buffer: Union[Buffer, str]) -> Buffer:
data = self._ensure_bytes(buffer)
self._signature_sent = True
return data
def write_row(self, row: Sequence[Any]) -> Buffer:
# Note down that we are writing in row mode: it means we will have
# to take care of the end-of-copy marker too
self._row_mode = True
if not self._signature_sent:
self._write_buffer += _binary_signature
self._signature_sent = True
format_row_binary(row, self.transformer, self._write_buffer)
if len(self._write_buffer) > BUFFER_SIZE:
buffer, self._write_buffer = self._write_buffer, bytearray()
return buffer
else:
return b""
def end(self) -> Buffer:
# If we have sent no data we need to send the signature
# and the trailer
if not self._signature_sent:
self._write_buffer += _binary_signature
self._write_buffer += _binary_trailer
elif self._row_mode:
# if we have sent data already, we have sent the signature
# too (either with the first row, or we assume that in
# block mode the signature is included).
# Write the trailer only if we are sending rows (with the
# assumption that who is copying binary data is sending the
# whole format).
self._write_buffer += _binary_trailer
buffer, self._write_buffer = self._write_buffer, bytearray()
return buffer
def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
if isinstance(data, str):
raise TypeError("cannot copy str data in binary mode: use bytes instead")
else:
# Assume, for simplicity, that the user is not passing stupid
# things to the write function. If that's the case, things
# will fail downstream.
return data
def _format_row_text(
row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
) -> bytearray:
"""Convert a row of objects to the data to send for copy."""
if out is None:
out = bytearray()
if not row:
out += b"\n"
return out
for item in row:
if item is not None:
dumper = tx.get_dumper(item, PY_TEXT)
b = dumper.dump(item)
out += _dump_re.sub(_dump_sub, b)
else:
out += rb"\N"
out += b"\t"
out[-1:] = b"\n"
return out
def _format_row_binary(
row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
) -> bytearray:
"""Convert a row of objects to the data to send for binary copy."""
if out is None:
out = bytearray()
out += _pack_int2(len(row))
adapted = tx.dump_sequence(row, [PY_BINARY] * len(row))
for b in adapted:
if b is not None:
out += _pack_int4(len(b))
out += b
else:
out += _binary_null
return out
def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
if not isinstance(data, bytes):
data = bytes(data)
fields = data.split(b"\t")
fields[-1] = fields[-1][:-1] # drop \n
row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
return tx.load_sequence(row)
def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
row: List[Optional[Buffer]] = []
nfields = _unpack_int2(data, 0)[0]
pos = 2
for i in range(nfields):
length = _unpack_int4(data, pos)[0]
pos += 4
if length >= 0:
row.append(data[pos : pos + length])
pos += length
else:
row.append(None)
return tx.load_sequence(row)
_pack_int2 = struct.Struct("!h").pack
_pack_int4 = struct.Struct("!i").pack
_unpack_int2 = struct.Struct("!h").unpack_from
_unpack_int4 = struct.Struct("!i").unpack_from
_binary_signature = (
b"PGCOPY\n\xff\r\n\0" # Signature
b"\x00\x00\x00\x00" # flags
b"\x00\x00\x00\x00" # extra length
)
_binary_trailer = b"\xff\xff"
_binary_null = b"\xff\xff\xff\xff"
_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
_dump_repl = {
b"\b": b"\\b",
b"\t": b"\\t",
b"\n": b"\\n",
b"\v": b"\\v",
b"\f": b"\\f",
b"\r": b"\\r",
b"\\": b"\\\\",
}
def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
return __map[m.group(0)]
_load_re = re.compile(b"\\\\[btnvfr\\\\]")
_load_repl = {v: k for k, v in _dump_repl.items()}
def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
return __map[m.group(0)]
# Override functions with fast versions if available
if _psycopg:
format_row_text = _psycopg.format_row_text
format_row_binary = _psycopg.format_row_binary
parse_row_text = _psycopg.parse_row_text
parse_row_binary = _psycopg.parse_row_binary
else:
format_row_text = _format_row_text
format_row_binary = _format_row_binary
parse_row_text = _parse_row_text
parse_row_binary = _parse_row_binary
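# Illustrative self-check (not part of the original module) of the COPY text
# escaping tables above: special bytes are escaped on dump and restored on load.
if __name__ == "__main__":
    raw = b"a\tb\nc\\d"
    escaped = _dump_re.sub(_dump_sub, raw)
    assert escaped == b"a\\tb\\nc\\\\d"
    assert _load_re.sub(_load_sub, escaped) == raw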

View File

@ -0,0 +1,19 @@
"""
CockroachDB support package.
"""
# Copyright (C) 2022 The Psycopg Team
from . import _types
from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo
adapters = _types.adapters # exposed by the package
connect = CrdbConnection.connect
_types.register_crdb_adapters(adapters)
__all__ = [
"AsyncCrdbConnection",
"CrdbConnection",
"CrdbConnectionInfo",
]
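# Hedged usage sketch (not part of the original package): connecting to a
# CockroachDB server through the wrapper exposed above. The DSN is hypothetical.
#
#     import psycopg.crdb
#
#     with psycopg.crdb.connect("host=localhost port=26257 dbname=defaultdb user=root") as conn:
#         print(conn.info.vendor)          # "CockroachDB"
#         print(conn.info.server_version)  # e.g. 220100 for v22.1.0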

View File

@ -0,0 +1,163 @@
"""
Types configuration specific for CockroachDB.
"""
# Copyright (C) 2022 The Psycopg Team
from enum import Enum
from .._typeinfo import TypeInfo, TypesRegistry
from ..abc import AdaptContext, NoneType
from ..postgres import TEXT_OID
from .._adapters_map import AdaptersMap
from ..types.enum import EnumDumper, EnumBinaryDumper
from ..types.none import NoneDumper
types = TypesRegistry()
# Global adapter maps with PostgreSQL types configuration
adapters = AdaptersMap(types=types)
class CrdbEnumDumper(EnumDumper):
oid = TEXT_OID
class CrdbEnumBinaryDumper(EnumBinaryDumper):
oid = TEXT_OID
class CrdbNoneDumper(NoneDumper):
oid = TEXT_OID
def register_postgres_adapters(context: AdaptContext) -> None:
# Same adapters used by PostgreSQL, or a good starting point for customization
from ..types import array, bool, composite, datetime
from ..types import numeric, string, uuid
array.register_default_adapters(context)
bool.register_default_adapters(context)
composite.register_default_adapters(context)
datetime.register_default_adapters(context)
numeric.register_default_adapters(context)
string.register_default_adapters(context)
uuid.register_default_adapters(context)
def register_crdb_adapters(context: AdaptContext) -> None:
from .. import dbapi20
from ..types import array
register_postgres_adapters(context)
# String must come after enum to map text oid -> string dumper
register_crdb_enum_adapters(context)
register_crdb_string_adapters(context)
register_crdb_json_adapters(context)
register_crdb_net_adapters(context)
register_crdb_none_adapters(context)
dbapi20.register_dbapi20_adapters(adapters)
array.register_all_arrays(adapters)
def register_crdb_string_adapters(context: AdaptContext) -> None:
from ..types import string
# Dump strings with text oid instead of unknown.
# Unlike PostgreSQL, CRDB seems able to cast text to most types.
context.adapters.register_dumper(str, string.StrDumper)
context.adapters.register_dumper(str, string.StrBinaryDumper)
def register_crdb_enum_adapters(context: AdaptContext) -> None:
context.adapters.register_dumper(Enum, CrdbEnumBinaryDumper)
context.adapters.register_dumper(Enum, CrdbEnumDumper)
def register_crdb_json_adapters(context: AdaptContext) -> None:
from ..types import json
adapters = context.adapters
# CRDB doesn't have json/jsonb: both names map to the jsonb oid
adapters.register_dumper(json.Json, json.JsonbBinaryDumper)
adapters.register_dumper(json.Json, json.JsonbDumper)
adapters.register_dumper(json.Jsonb, json.JsonbBinaryDumper)
adapters.register_dumper(json.Jsonb, json.JsonbDumper)
adapters.register_loader("json", json.JsonLoader)
adapters.register_loader("jsonb", json.JsonbLoader)
adapters.register_loader("json", json.JsonBinaryLoader)
adapters.register_loader("jsonb", json.JsonbBinaryLoader)
def register_crdb_net_adapters(context: AdaptContext) -> None:
from ..types import net
adapters = context.adapters
adapters.register_dumper("ipaddress.IPv4Address", net.InterfaceDumper)
adapters.register_dumper("ipaddress.IPv6Address", net.InterfaceDumper)
adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceDumper)
adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceDumper)
adapters.register_dumper("ipaddress.IPv4Address", net.AddressBinaryDumper)
adapters.register_dumper("ipaddress.IPv6Address", net.AddressBinaryDumper)
adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceBinaryDumper)
adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceBinaryDumper)
adapters.register_dumper(None, net.InetBinaryDumper)
adapters.register_loader("inet", net.InetLoader)
adapters.register_loader("inet", net.InetBinaryLoader)
def register_crdb_none_adapters(context: AdaptContext) -> None:
context.adapters.register_dumper(NoneType, CrdbNoneDumper)
for t in [
TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb.
TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8
TypeInfo('"char"', 18, 1002), # special case, not generated
# autogenerated: start
# Generated from CockroachDB 22.1.0
TypeInfo("bit", 1560, 1561),
TypeInfo("bool", 16, 1000, regtype="boolean"),
TypeInfo("bpchar", 1042, 1014, regtype="character"),
TypeInfo("bytea", 17, 1001),
TypeInfo("date", 1082, 1182),
TypeInfo("float4", 700, 1021, regtype="real"),
TypeInfo("float8", 701, 1022, regtype="double precision"),
TypeInfo("inet", 869, 1041),
TypeInfo("int2", 21, 1005, regtype="smallint"),
TypeInfo("int2vector", 22, 1006),
TypeInfo("int4", 23, 1007),
TypeInfo("int8", 20, 1016, regtype="bigint"),
TypeInfo("interval", 1186, 1187),
TypeInfo("jsonb", 3802, 3807),
TypeInfo("name", 19, 1003),
TypeInfo("numeric", 1700, 1231),
TypeInfo("oid", 26, 1028),
TypeInfo("oidvector", 30, 1013),
TypeInfo("record", 2249, 2287),
TypeInfo("regclass", 2205, 2210),
TypeInfo("regnamespace", 4089, 4090),
TypeInfo("regproc", 24, 1008),
TypeInfo("regprocedure", 2202, 2207),
TypeInfo("regrole", 4096, 4097),
TypeInfo("regtype", 2206, 2211),
TypeInfo("text", 25, 1009),
TypeInfo("time", 1083, 1183, regtype="time without time zone"),
TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
TypeInfo("unknown", 705, 0),
TypeInfo("uuid", 2950, 2951),
TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
TypeInfo("varchar", 1043, 1015, regtype="character varying"),
# autogenerated: end
]:
types.add(t)

View File

@ -0,0 +1,185 @@
"""
CockroachDB-specific connections.
"""
# Copyright (C) 2022 The Psycopg Team
import re
from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
from .. import errors as e
from ..abc import AdaptContext
from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow
from ..conninfo import ConnectionInfo
from ..connection import Connection
from .._adapters_map import AdaptersMap
from ..connection_async import AsyncConnection
from ._types import adapters
if TYPE_CHECKING:
from ..pq.abc import PGconn
from ..cursor import Cursor
from ..cursor_async import AsyncCursor
class _CrdbConnectionMixin:
_adapters: Optional[AdaptersMap]
pgconn: "PGconn"
@classmethod
def is_crdb(
cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"]
) -> bool:
"""
Return `!True` if the server connected to `!conn` is CockroachDB.
"""
if isinstance(conn, (Connection, AsyncConnection)):
conn = conn.pgconn
return bool(conn.parameter_status(b"crdb_version"))
@property
def adapters(self) -> AdaptersMap:
if not self._adapters:
# By default, use CockroachDB adapters map
self._adapters = AdaptersMap(adapters)
return self._adapters
@property
def info(self) -> "CrdbConnectionInfo":
return CrdbConnectionInfo(self.pgconn)
def _check_tpc(self) -> None:
if self.is_crdb(self.pgconn):
raise e.NotSupportedError("CockroachDB doesn't support prepared statements")
class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
"""
Wrapper for a connection to a CockroachDB database.
"""
__module__ = "psycopg.crdb"
# TODO: this method shouldn't require re-definition if the base class
# implements a generic self.
# https://github.com/psycopg/psycopg/issues/308
@overload
@classmethod
def connect(
cls,
conninfo: str = "",
*,
autocommit: bool = False,
row_factory: RowFactory[Row],
prepare_threshold: Optional[int] = 5,
cursor_factory: "Optional[Type[Cursor[Row]]]" = None,
context: Optional[AdaptContext] = None,
**kwargs: Union[None, int, str],
) -> "CrdbConnection[Row]":
...
@overload
@classmethod
def connect(
cls,
conninfo: str = "",
*,
autocommit: bool = False,
prepare_threshold: Optional[int] = 5,
cursor_factory: "Optional[Type[Cursor[Any]]]" = None,
context: Optional[AdaptContext] = None,
**kwargs: Union[None, int, str],
) -> "CrdbConnection[TupleRow]":
...
@classmethod
def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]":
"""
Connect to a database server and return a new `CrdbConnection` instance.
"""
return super().connect(conninfo, **kwargs) # type: ignore[return-value]
class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
"""
Wrapper for an async connection to a CockroachDB database.
"""
__module__ = "psycopg.crdb"
# TODO: this method shouldn't require re-definition if the base class
# implements a generic self.
# https://github.com/psycopg/psycopg/issues/308
@overload
@classmethod
async def connect(
cls,
conninfo: str = "",
*,
autocommit: bool = False,
prepare_threshold: Optional[int] = 5,
row_factory: AsyncRowFactory[Row],
cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None,
context: Optional[AdaptContext] = None,
**kwargs: Union[None, int, str],
) -> "AsyncCrdbConnection[Row]":
...
@overload
@classmethod
async def connect(
cls,
conninfo: str = "",
*,
autocommit: bool = False,
prepare_threshold: Optional[int] = 5,
cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None,
context: Optional[AdaptContext] = None,
**kwargs: Union[None, int, str],
) -> "AsyncCrdbConnection[TupleRow]":
...
@classmethod
async def connect(
cls, conninfo: str = "", **kwargs: Any
) -> "AsyncCrdbConnection[Any]":
return await super().connect(conninfo, **kwargs) # type: ignore [no-any-return]
class CrdbConnectionInfo(ConnectionInfo):
"""
`~psycopg.ConnectionInfo` subclass to get info about a CockroachDB database.
"""
__module__ = "psycopg.crdb"
@property
def vendor(self) -> str:
return "CockroachDB"
@property
def server_version(self) -> int:
"""
        Return the version of the connected CockroachDB server.
Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210).
"""
sver = self.parameter_status("crdb_version")
if not sver:
raise e.InternalError("'crdb_version' parameter status not set")
ver = self.parse_crdb_version(sver)
if ver is None:
raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")
return ver
@classmethod
    def parse_crdb_version(cls, sver: str) -> Optional[int]:
m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
if not m:
return None
return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
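# Illustrative check (not part of the original module): the classmethod above
# turns the crdb_version parameter status string into a PostgreSQL-style
# version number. The sample string is hypothetical.
if __name__ == "__main__":
    sver = "CockroachDB CCL v22.1.0 (dummy build string)"
    assert CrdbConnectionInfo.parse_crdb_version(sver) == 220100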

View File

@ -0,0 +1,929 @@
"""
psycopg cursor objects
"""
# Copyright (C) 2020 The Psycopg Team
from functools import partial
from types import TracebackType
from typing import Any, Generic, Iterable, Iterator, List
from typing import Optional, NoReturn, Sequence, Tuple, Type, TypeVar
from typing import overload, TYPE_CHECKING
from warnings import warn
from contextlib import contextmanager
from . import pq
from . import adapt
from . import errors as e
from .abc import ConnectionType, Query, Params, PQGen
from .copy import Copy, Writer as CopyWriter
from .rows import Row, RowMaker, RowFactory
from ._column import Column
from .pq.misc import connection_summary
from ._queries import PostgresQuery, PostgresClientQuery
from ._pipeline import Pipeline
from ._encodings import pgconn_encoding
from ._preparing import Prepare
from .generators import execute, fetch, send
if TYPE_CHECKING:
from .abc import Transformer
from .pq.abc import PGconn, PGresult
from .connection import Connection
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY
COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK
COPY_OUT = pq.ExecStatus.COPY_OUT
COPY_IN = pq.ExecStatus.COPY_IN
COPY_BOTH = pq.ExecStatus.COPY_BOTH
FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
ACTIVE = pq.TransactionStatus.ACTIVE
class BaseCursor(Generic[ConnectionType, Row]):
__slots__ = """
_conn format _adapters arraysize _closed _results pgresult _pos
_iresult _rowcount _query _tx _last_query _row_factory _make_row
_pgconn _execmany_returning
__weakref__
""".split()
ExecStatus = pq.ExecStatus
_tx: "Transformer"
_make_row: RowMaker[Row]
_pgconn: "PGconn"
def __init__(self, connection: ConnectionType):
self._conn = connection
self.format = TEXT
self._pgconn = connection.pgconn
self._adapters = adapt.AdaptersMap(connection.adapters)
self.arraysize = 1
self._closed = False
self._last_query: Optional[Query] = None
self._reset()
def _reset(self, reset_query: bool = True) -> None:
self._results: List["PGresult"] = []
self.pgresult: Optional["PGresult"] = None
self._pos = 0
self._iresult = 0
self._rowcount = -1
self._query: Optional[PostgresQuery]
# None if executemany() not executing, True/False according to returning state
self._execmany_returning: Optional[bool] = None
if reset_query:
self._query = None
def __repr__(self) -> str:
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
info = connection_summary(self._pgconn)
if self._closed:
status = "closed"
elif self.pgresult:
status = pq.ExecStatus(self.pgresult.status).name
else:
status = "no result"
return f"<{cls} [{status}] {info} at 0x{id(self):x}>"
@property
def connection(self) -> ConnectionType:
"""The connection this cursor is using."""
return self._conn
@property
def adapters(self) -> adapt.AdaptersMap:
return self._adapters
@property
def closed(self) -> bool:
"""`True` if the cursor is closed."""
return self._closed
@property
def description(self) -> Optional[List[Column]]:
"""
A list of `Column` objects describing the current resultset.
`!None` if the current resultset didn't return tuples.
"""
res = self.pgresult
# We return columns if we have nfields, but also if we don't but
# the query said we got tuples (mostly to handle the super useful
# query "SELECT ;"
if res and (
res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE
):
return [Column(self, i) for i in range(res.nfields)]
else:
return None
@property
def rowcount(self) -> int:
"""Number of records affected by the precedent operation."""
return self._rowcount
@property
def rownumber(self) -> Optional[int]:
"""Index of the next row to fetch in the current result.
`!None` if there is no result to fetch.
"""
tuples = self.pgresult and self.pgresult.status == TUPLES_OK
return self._pos if tuples else None
def setinputsizes(self, sizes: Sequence[Any]) -> None:
# no-op
pass
def setoutputsize(self, size: Any, column: Optional[int] = None) -> None:
# no-op
pass
def nextset(self) -> Optional[bool]:
"""
Move to the result set of the next query executed through `executemany()`
or to the next result set if `execute()` returned more than one.
Return `!True` if a new result is available, which will be the one
methods `!fetch*()` will operate on.
"""
        # Raise a warning if someone is calling nextset() in pipeline mode
# after a sequence of execute() in pipeline mode. Pipeline accumulating
# execute() results in the cursor is an unintended difference w.r.t.
# non-pipeline mode.
if self._execmany_returning is None and self._conn._pipeline:
warn(
"using nextset() in pipeline mode for several execute() is"
" deprecated and will be dropped in 3.2; please use different"
" cursors to receive more than one result",
DeprecationWarning,
)
if self._iresult < len(self._results) - 1:
self._select_current_result(self._iresult + 1)
return True
else:
return None
@property
def statusmessage(self) -> Optional[str]:
"""
The command status tag from the last SQL command executed.
`!None` if the cursor doesn't have a result available.
"""
msg = self.pgresult.command_status if self.pgresult else None
return msg.decode() if msg else None
def _make_row_maker(self) -> RowMaker[Row]:
raise NotImplementedError
#
# Generators for the high level operations on the cursor
#
# Like for sync/async connections, these are implemented as generators
    # so that different concurrency strategies (threads, asyncio) can use their
# own way of waiting (or better, `connection.wait()`).
#
def _execute_gen(
self,
query: Query,
params: Optional[Params] = None,
*,
prepare: Optional[bool] = None,
binary: Optional[bool] = None,
) -> PQGen[None]:
"""Generator implementing `Cursor.execute()`."""
yield from self._start_query(query)
pgq = self._convert_query(query, params)
results = yield from self._maybe_prepare_gen(
pgq, prepare=prepare, binary=binary
)
if self._conn._pipeline:
yield from self._conn._pipeline._communicate_gen()
else:
assert results is not None
self._check_results(results)
self._results = results
self._select_current_result(0)
self._last_query = query
for cmd in self._conn._prepared.get_maintenance_commands():
yield from self._conn._exec_command(cmd)
def _executemany_gen_pipeline(
self, query: Query, params_seq: Iterable[Params], returning: bool
) -> PQGen[None]:
"""
Generator implementing `Cursor.executemany()` with pipelines available.
"""
pipeline = self._conn._pipeline
assert pipeline
yield from self._start_query(query)
if not returning:
self._rowcount = 0
assert self._execmany_returning is None
self._execmany_returning = returning
first = True
for params in params_seq:
if first:
pgq = self._convert_query(query, params)
self._query = pgq
first = False
else:
pgq.dump(params)
yield from self._maybe_prepare_gen(pgq, prepare=True)
yield from pipeline._communicate_gen()
self._last_query = query
if returning:
yield from pipeline._fetch_gen(flush=True)
for cmd in self._conn._prepared.get_maintenance_commands():
yield from self._conn._exec_command(cmd)
def _executemany_gen_no_pipeline(
self, query: Query, params_seq: Iterable[Params], returning: bool
) -> PQGen[None]:
"""
Generator implementing `Cursor.executemany()` with pipelines not available.
"""
yield from self._start_query(query)
if not returning:
self._rowcount = 0
first = True
for params in params_seq:
if first:
pgq = self._convert_query(query, params)
self._query = pgq
first = False
else:
pgq.dump(params)
results = yield from self._maybe_prepare_gen(pgq, prepare=True)
assert results is not None
self._check_results(results)
if returning:
self._results.extend(results)
else:
                # In the non-returning case, set rowcount to the cumulative
                # number of rows of the executed queries.
for res in results:
self._rowcount += res.command_tuples or 0
if self._results:
self._select_current_result(0)
self._last_query = query
for cmd in self._conn._prepared.get_maintenance_commands():
yield from self._conn._exec_command(cmd)
def _maybe_prepare_gen(
self,
pgq: PostgresQuery,
*,
prepare: Optional[bool] = None,
binary: Optional[bool] = None,
) -> PQGen[Optional[List["PGresult"]]]:
# Check if the query is prepared or needs preparing
prep, name = self._get_prepared(pgq, prepare)
if prep is Prepare.NO:
# The query must be executed without preparing
self._execute_send(pgq, binary=binary)
else:
# If the query is not already prepared, prepare it.
if prep is Prepare.SHOULD:
self._send_prepare(name, pgq)
if not self._conn._pipeline:
(result,) = yield from execute(self._pgconn)
if result.status == FATAL_ERROR:
raise e.error_from_result(result, encoding=self._encoding)
# Then execute it.
self._send_query_prepared(name, pgq, binary=binary)
# Update the prepare state of the query.
# If an operation requires to flush our prepared statements cache,
# it will be added to the maintenance commands to execute later.
key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)
if self._conn._pipeline:
queued = None
if key is not None:
queued = (key, prep, name)
self._conn._pipeline.result_queue.append((self, queued))
return None
# run the query
results = yield from execute(self._pgconn)
if key is not None:
self._conn._prepared.validate(key, prep, name, results)
return results
def _get_prepared(
self, pgq: PostgresQuery, prepare: Optional[bool] = None
) -> Tuple[Prepare, bytes]:
return self._conn._prepared.get(pgq, prepare)
def _stream_send_gen(
self,
query: Query,
params: Optional[Params] = None,
*,
binary: Optional[bool] = None,
) -> PQGen[None]:
"""Generator to send the query for `Cursor.stream()`."""
yield from self._start_query(query)
pgq = self._convert_query(query, params)
self._execute_send(pgq, binary=binary, force_extended=True)
self._pgconn.set_single_row_mode()
self._last_query = query
yield from send(self._pgconn)
def _stream_fetchone_gen(self, first: bool) -> PQGen[Optional["PGresult"]]:
res = yield from fetch(self._pgconn)
if res is None:
return None
status = res.status
if status == SINGLE_TUPLE:
self.pgresult = res
self._tx.set_pgresult(res, set_loaders=first)
if first:
self._make_row = self._make_row_maker()
return res
elif status == TUPLES_OK or status == COMMAND_OK:
# End of single row results
while res:
res = yield from fetch(self._pgconn)
if status != TUPLES_OK:
raise e.ProgrammingError(
"the operation in stream() didn't produce a result"
)
return None
else:
# Errors, unexpected values
return self._raise_for_result(res)
def _start_query(self, query: Optional[Query] = None) -> PQGen[None]:
"""Generator to start the processing of a query.
        It is implemented as a generator because it may send additional queries,
such as `begin`.
"""
if self.closed:
raise e.InterfaceError("the cursor is closed")
self._reset()
if not self._last_query or (self._last_query is not query):
self._last_query = None
self._tx = adapt.Transformer(self)
yield from self._conn._start_query()
def _start_copy_gen(
self, statement: Query, params: Optional[Params] = None
) -> PQGen[None]:
"""Generator implementing sending a command for `Cursor.copy()."""
# The connection gets in an unrecoverable state if we attempt COPY in
# pipeline mode. Forbid it explicitly.
if self._conn._pipeline:
raise e.NotSupportedError("COPY cannot be used in pipeline mode")
yield from self._start_query()
# Merge the params client-side
if params:
pgq = PostgresClientQuery(self._tx)
pgq.convert(statement, params)
statement = pgq.query
query = self._convert_query(statement)
self._execute_send(query, binary=False)
results = yield from execute(self._pgconn)
if len(results) != 1:
raise e.ProgrammingError("COPY cannot be mixed with other operations")
self._check_copy_result(results[0])
self._results = results
self._select_current_result(0)
def _execute_send(
self,
query: PostgresQuery,
*,
force_extended: bool = False,
binary: Optional[bool] = None,
) -> None:
"""
Implement part of execute() before waiting common to sync and async.
This is not a generator, but a normal non-blocking function.
"""
if binary is None:
fmt = self.format
else:
fmt = BINARY if binary else TEXT
self._query = query
if self._conn._pipeline:
# In pipeline mode always use PQsendQueryParams - see #314
# Multiple statements in the same query are not allowed anyway.
self._conn._pipeline.command_queue.append(
partial(
self._pgconn.send_query_params,
query.query,
query.params,
param_formats=query.formats,
param_types=query.types,
result_format=fmt,
)
)
elif force_extended or query.params or fmt == BINARY:
self._pgconn.send_query_params(
query.query,
query.params,
param_formats=query.formats,
param_types=query.types,
result_format=fmt,
)
else:
# If we can, let's use simple query protocol,
# as it can execute more than one statement in a single query.
self._pgconn.send_query(query.query)
def _convert_query(
self, query: Query, params: Optional[Params] = None
) -> PostgresQuery:
pgq = PostgresQuery(self._tx)
pgq.convert(query, params)
return pgq
def _check_results(self, results: List["PGresult"]) -> None:
"""
Verify that the results of a query are valid.
Verify that the query returned at least one result and that they all
represent a valid result from the database.
"""
if not results:
raise e.InternalError("got no result from the query")
for res in results:
status = res.status
if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY:
self._raise_for_result(res)
def _raise_for_result(self, result: "PGresult") -> NoReturn:
"""
Raise an appropriate error message for an unexpected database result
"""
status = result.status
if status == FATAL_ERROR:
raise e.error_from_result(result, encoding=self._encoding)
elif status == PIPELINE_ABORTED:
raise e.PipelineAborted("pipeline aborted")
elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
raise e.ProgrammingError(
"COPY cannot be used with this method; use copy() instead"
)
else:
raise e.InternalError(
"unexpected result status from query:" f" {pq.ExecStatus(status).name}"
)
def _select_current_result(
self, i: int, format: Optional[pq.Format] = None
) -> None:
"""
Select one of the results in the cursor as the active one.
"""
self._iresult = i
res = self.pgresult = self._results[i]
# Note: the only reason to override format is to correctly set
# binary loaders on server-side cursors, because send_describe_portal
# only returns a text result.
self._tx.set_pgresult(res, format=format)
self._pos = 0
if res.status == TUPLES_OK:
self._rowcount = self.pgresult.ntuples
        # COPY_OUT never has info about nrows. We need such a result for the
# columns in order to return a `description`, but not overwrite the
# cursor rowcount (which was set by the Copy object).
elif res.status != COPY_OUT:
nrows = self.pgresult.command_tuples
self._rowcount = nrows if nrows is not None else -1
self._make_row = self._make_row_maker()
def _set_results_from_pipeline(self, results: List["PGresult"]) -> None:
self._check_results(results)
first_batch = not self._results
if self._execmany_returning is None:
# Received from execute()
self._results.extend(results)
if first_batch:
self._select_current_result(0)
else:
# Received from executemany()
if self._execmany_returning:
self._results.extend(results)
if first_batch:
self._select_current_result(0)
else:
                # In the non-returning case, set rowcount to the cumulative
                # number of rows of the executed queries.
for res in results:
self._rowcount += res.command_tuples or 0
def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
if self._conn._pipeline:
self._conn._pipeline.command_queue.append(
partial(
self._pgconn.send_prepare,
name,
query.query,
param_types=query.types,
)
)
self._conn._pipeline.result_queue.append(None)
else:
self._pgconn.send_prepare(name, query.query, param_types=query.types)
def _send_query_prepared(
self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None
) -> None:
if binary is None:
fmt = self.format
else:
fmt = BINARY if binary else TEXT
if self._conn._pipeline:
self._conn._pipeline.command_queue.append(
partial(
self._pgconn.send_query_prepared,
name,
pgq.params,
param_formats=pgq.formats,
result_format=fmt,
)
)
else:
self._pgconn.send_query_prepared(
name, pgq.params, param_formats=pgq.formats, result_format=fmt
)
def _check_result_for_fetch(self) -> None:
if self.closed:
raise e.InterfaceError("the cursor is closed")
res = self.pgresult
if not res:
raise e.ProgrammingError("no result available")
status = res.status
if status == TUPLES_OK:
return
elif status == FATAL_ERROR:
raise e.error_from_result(res, encoding=self._encoding)
elif status == PIPELINE_ABORTED:
raise e.PipelineAborted("pipeline aborted")
else:
raise e.ProgrammingError("the last operation didn't produce a result")
def _check_copy_result(self, result: "PGresult") -> None:
"""
Check that the value returned in a copy() operation is a legit COPY.
"""
status = result.status
if status == COPY_IN or status == COPY_OUT:
return
elif status == FATAL_ERROR:
raise e.error_from_result(result, encoding=self._encoding)
else:
raise e.ProgrammingError(
"copy() should be used only with COPY ... TO STDOUT or COPY ..."
f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
)
def _scroll(self, value: int, mode: str) -> None:
self._check_result_for_fetch()
assert self.pgresult
if mode == "relative":
newpos = self._pos + value
elif mode == "absolute":
newpos = value
else:
raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
if not 0 <= newpos < self.pgresult.ntuples:
raise IndexError("position out of bound")
self._pos = newpos
def _close(self) -> None:
"""Non-blocking part of closing. Common to sync/async."""
# Don't reset the query because it may be useful to investigate after
# an error.
self._reset(reset_query=False)
self._closed = True
@property
def _encoding(self) -> str:
return pgconn_encoding(self._pgconn)
class Cursor(BaseCursor["Connection[Any]", Row]):
__module__ = "psycopg"
__slots__ = ()
_Self = TypeVar("_Self", bound="Cursor[Any]")
@overload
def __init__(self: "Cursor[Row]", connection: "Connection[Row]"):
...
@overload
def __init__(
self: "Cursor[Row]",
connection: "Connection[Any]",
*,
row_factory: RowFactory[Row],
):
...
def __init__(
self,
connection: "Connection[Any]",
*,
row_factory: Optional[RowFactory[Row]] = None,
):
super().__init__(connection)
self._row_factory = row_factory or connection.row_factory
def __enter__(self: _Self) -> _Self:
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.close()
def close(self) -> None:
"""
Close the current cursor and free associated resources.
"""
self._close()
@property
def row_factory(self) -> RowFactory[Row]:
"""Writable attribute to control how result rows are formed."""
return self._row_factory
@row_factory.setter
def row_factory(self, row_factory: RowFactory[Row]) -> None:
self._row_factory = row_factory
if self.pgresult:
self._make_row = row_factory(self)
def _make_row_maker(self) -> RowMaker[Row]:
return self._row_factory(self)
def execute(
self: _Self,
query: Query,
params: Optional[Params] = None,
*,
prepare: Optional[bool] = None,
binary: Optional[bool] = None,
) -> _Self:
"""
Execute a query or command to the database.
"""
try:
with self._conn.lock:
self._conn.wait(
self._execute_gen(query, params, prepare=prepare, binary=binary)
)
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
return self
def executemany(
self,
query: Query,
params_seq: Iterable[Params],
*,
returning: bool = False,
) -> None:
"""
Execute the same command with a sequence of input data.
"""
try:
if Pipeline.is_supported():
# If there is already a pipeline, ride it, in order to avoid
# sending unnecessary Sync.
with self._conn.lock:
p = self._conn._pipeline
if p:
self._conn.wait(
self._executemany_gen_pipeline(query, params_seq, returning)
)
# Otherwise, make a new one
if not p:
with self._conn.pipeline(), self._conn.lock:
self._conn.wait(
self._executemany_gen_pipeline(query, params_seq, returning)
)
else:
with self._conn.lock:
self._conn.wait(
self._executemany_gen_no_pipeline(query, params_seq, returning)
)
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
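# Illustrative usage, not part of this class (the table name and values below
# are assumptions): with returning=True each input set produces its own result
# set, which can be walked with fetchone()/nextset():
#
#     cur.executemany(
#         "insert into mytable (x) values (%s) returning id",
#         [(10,), (20,)],
#         returning=True,
#     )
#     while True:
#         print(cur.fetchone())
#         if not cur.nextset():
#             break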
def stream(
self,
query: Query,
params: Optional[Params] = None,
*,
binary: Optional[bool] = None,
) -> Iterator[Row]:
"""
Iterate row-by-row on a result from the database.
"""
if self._pgconn.pipeline_status:
raise e.ProgrammingError("stream() cannot be used in pipeline mode")
with self._conn.lock:
try:
self._conn.wait(self._stream_send_gen(query, params, binary=binary))
first = True
while self._conn.wait(self._stream_fetchone_gen(first)):
# We know that, if we got a result, it has a single row.
rec: Row = self._tx.load_row(0, self._make_row) # type: ignore
yield rec
first = False
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
finally:
if self._pgconn.transaction_status == ACTIVE:
# Try to cancel the query, then consume the results
# already received.
self._conn.cancel()
try:
while self._conn.wait(self._stream_fetchone_gen(first=False)):
pass
except Exception:
pass
# Try to get out of ACTIVE state. Just do a single attempt, which
# should work to recover from an error or query cancelled.
try:
self._conn.wait(self._stream_fetchone_gen(first=False))
except Exception:
pass
def fetchone(self) -> Optional[Row]:
"""
Return the next record from the current recordset.
Return `!None` when the recordset is finished.
:rtype: Optional[Row], with Row defined by `row_factory`
"""
self._fetch_pipeline()
self._check_result_for_fetch()
record = self._tx.load_row(self._pos, self._make_row)
if record is not None:
self._pos += 1
return record
def fetchmany(self, size: int = 0) -> List[Row]:
"""
Return the next `!size` records from the current recordset.
`!size` defaults to `!self.arraysize` if not specified.
:rtype: Sequence[Row], with Row defined by `row_factory`
"""
self._fetch_pipeline()
self._check_result_for_fetch()
assert self.pgresult
if not size:
size = self.arraysize
records = self._tx.load_rows(
self._pos,
min(self._pos + size, self.pgresult.ntuples),
self._make_row,
)
self._pos += len(records)
return records
def fetchall(self) -> List[Row]:
"""
Return all the remaining records from the current recordset.
:rtype: Sequence[Row], with Row defined by `row_factory`
"""
self._fetch_pipeline()
self._check_result_for_fetch()
assert self.pgresult
records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
self._pos = self.pgresult.ntuples
return records
def __iter__(self) -> Iterator[Row]:
self._fetch_pipeline()
self._check_result_for_fetch()
def load(pos: int) -> Optional[Row]:
return self._tx.load_row(pos, self._make_row)
while True:
row = load(self._pos)
if row is None:
break
self._pos += 1
yield row
def scroll(self, value: int, mode: str = "relative") -> None:
"""
Move the cursor in the result set to a new position according to mode.
If `!mode` is ``'relative'`` (default), `!value` is taken as offset to
the current position in the result set; if set to ``'absolute'``,
`!value` states an absolute target position.
Raise `!IndexError` in case a scroll operation would leave the result
set. In this case the position will not change.
"""
self._fetch_pipeline()
self._scroll(value, mode)
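# A brief illustration (assumptions: `cur` holds a result with at least three
# rows):
#
#     cur.scroll(2)              # skip the next two rows of the current result
#     cur.scroll(0, "absolute")  # rewind to the first row
#
# Both raise IndexError if the target position falls outside the result set.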
@contextmanager
def copy(
self,
statement: Query,
params: Optional[Params] = None,
*,
writer: Optional[CopyWriter] = None,
) -> Iterator[Copy]:
"""
Initiate a :sql:`COPY` operation and return an object to manage it.
:rtype: Copy
"""
try:
with self._conn.lock:
self._conn.wait(self._start_copy_gen(statement, params))
with Copy(self, writer=writer) as copy:
yield copy
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
# If a fresher result has been set on the cursor by the Copy object,
# read its properties (especially rowcount).
self._select_current_result(0)
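# Illustrative usage, not part of this class (the table and column names are
# assumptions): feed rows to the server with COPY ... FROM STDIN:
#
#     with cur.copy("COPY mytable (id, name) FROM STDIN") as copy:
#         copy.write_row((1, "hello"))
#         copy.write_row((2, "world"))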
def _fetch_pipeline(self) -> None:
if (
self._execmany_returning is not False
and not self.pgresult
and self._conn._pipeline
):
with self._conn.lock:
self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))

View File

@ -0,0 +1,251 @@
"""
psycopg async cursor objects
"""
# Copyright (C) 2020 The Psycopg Team
from types import TracebackType
from typing import Any, AsyncIterator, Iterable, List
from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload
from contextlib import asynccontextmanager
from . import pq
from . import errors as e
from .abc import Query, Params
from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter
from .rows import Row, RowMaker, AsyncRowFactory
from .cursor import BaseCursor
from ._pipeline import Pipeline
if TYPE_CHECKING:
from .connection_async import AsyncConnection
ACTIVE = pq.TransactionStatus.ACTIVE
class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
__module__ = "psycopg"
__slots__ = ()
_Self = TypeVar("_Self", bound="AsyncCursor[Any]")
@overload
def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"):
...
@overload
def __init__(
self: "AsyncCursor[Row]",
connection: "AsyncConnection[Any]",
*,
row_factory: AsyncRowFactory[Row],
):
...
def __init__(
self,
connection: "AsyncConnection[Any]",
*,
row_factory: Optional[AsyncRowFactory[Row]] = None,
):
super().__init__(connection)
self._row_factory = row_factory or connection.row_factory
async def __aenter__(self: _Self) -> _Self:
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.close()
async def close(self) -> None:
self._close()
@property
def row_factory(self) -> AsyncRowFactory[Row]:
return self._row_factory
@row_factory.setter
def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None:
self._row_factory = row_factory
if self.pgresult:
self._make_row = row_factory(self)
def _make_row_maker(self) -> RowMaker[Row]:
return self._row_factory(self)
async def execute(
self: _Self,
query: Query,
params: Optional[Params] = None,
*,
prepare: Optional[bool] = None,
binary: Optional[bool] = None,
) -> _Self:
try:
async with self._conn.lock:
await self._conn.wait(
self._execute_gen(query, params, prepare=prepare, binary=binary)
)
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
return self
async def executemany(
self,
query: Query,
params_seq: Iterable[Params],
*,
returning: bool = False,
) -> None:
try:
if Pipeline.is_supported():
# If there is already a pipeline, ride it, in order to avoid
# sending unnecessary Sync.
async with self._conn.lock:
p = self._conn._pipeline
if p:
await self._conn.wait(
self._executemany_gen_pipeline(query, params_seq, returning)
)
# Otherwise, make a new one
if not p:
async with self._conn.pipeline(), self._conn.lock:
await self._conn.wait(
self._executemany_gen_pipeline(query, params_seq, returning)
)
else:
async with self._conn.lock:
await self._conn.wait(
self._executemany_gen_no_pipeline(query, params_seq, returning)
)
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
async def stream(
self,
query: Query,
params: Optional[Params] = None,
*,
binary: Optional[bool] = None,
) -> AsyncIterator[Row]:
if self._pgconn.pipeline_status:
raise e.ProgrammingError("stream() cannot be used in pipeline mode")
async with self._conn.lock:
try:
await self._conn.wait(
self._stream_send_gen(query, params, binary=binary)
)
first = True
while await self._conn.wait(self._stream_fetchone_gen(first)):
# We know that, if we got a result, it has a single row.
rec: Row = self._tx.load_row(0, self._make_row) # type: ignore
yield rec
first = False
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
finally:
if self._pgconn.transaction_status == ACTIVE:
# Try to cancel the query, then consume the results
# already received.
self._conn.cancel()
try:
while await self._conn.wait(
self._stream_fetchone_gen(first=False)
):
pass
except Exception:
pass
# Try to get out of ACTIVE state. Just do a single attempt, which
# should work to recover from an error or query cancelled.
try:
await self._conn.wait(self._stream_fetchone_gen(first=False))
except Exception:
pass
async def fetchone(self) -> Optional[Row]:
await self._fetch_pipeline()
self._check_result_for_fetch()
record = self._tx.load_row(self._pos, self._make_row)
if record is not None:
self._pos += 1
return record
async def fetchmany(self, size: int = 0) -> List[Row]:
await self._fetch_pipeline()
self._check_result_for_fetch()
assert self.pgresult
if not size:
size = self.arraysize
records = self._tx.load_rows(
self._pos,
min(self._pos + size, self.pgresult.ntuples),
self._make_row,
)
self._pos += len(records)
return records
async def fetchall(self) -> List[Row]:
await self._fetch_pipeline()
self._check_result_for_fetch()
assert self.pgresult
records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
self._pos = self.pgresult.ntuples
return records
async def __aiter__(self) -> AsyncIterator[Row]:
await self._fetch_pipeline()
self._check_result_for_fetch()
def load(pos: int) -> Optional[Row]:
return self._tx.load_row(pos, self._make_row)
while True:
row = load(self._pos)
if row is None:
break
self._pos += 1
yield row
async def scroll(self, value: int, mode: str = "relative") -> None:
await self._fetch_pipeline()
self._scroll(value, mode)
@asynccontextmanager
async def copy(
self,
statement: Query,
params: Optional[Params] = None,
*,
writer: Optional[AsyncCopyWriter] = None,
) -> AsyncIterator[AsyncCopy]:
"""
:rtype: AsyncCopy
"""
try:
async with self._conn.lock:
await self._conn.wait(self._start_copy_gen(statement, params))
async with AsyncCopy(self, writer=writer) as copy:
yield copy
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
self._select_current_result(0)
async def _fetch_pipeline(self) -> None:
if (
self._execmany_returning is not False
and not self.pgresult
and self._conn._pipeline
):
async with self._conn.lock:
await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))

View File

@ -0,0 +1,112 @@
"""
Compatibility objects with DBAPI 2.0
"""
# Copyright (C) 2020 The Psycopg Team
import time
import datetime as dt
from math import floor
from typing import Any, Sequence, Union
from . import postgres
from .abc import AdaptContext, Buffer
from .types.string import BytesDumper, BytesBinaryDumper
class DBAPITypeObject:
def __init__(self, name: str, type_names: Sequence[str]):
self.name = name
self.values = tuple(postgres.types[n].oid for n in type_names)
def __repr__(self) -> str:
return f"psycopg.{self.name}"
def __eq__(self, other: Any) -> bool:
if isinstance(other, int):
return other in self.values
else:
return NotImplemented
def __ne__(self, other: Any) -> bool:
if isinstance(other, int):
return other not in self.values
else:
return NotImplemented
BINARY = DBAPITypeObject("BINARY", ("bytea",))
DATETIME = DBAPITypeObject(
"DATETIME", "timestamp timestamptz date time timetz interval".split()
)
NUMBER = DBAPITypeObject("NUMBER", "int2 int4 int8 float4 float8 numeric".split())
ROWID = DBAPITypeObject("ROWID", ("oid",))
STRING = DBAPITypeObject("STRING", "text varchar bpchar".split())
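# A minimal sketch of how these singletons are meant to be used (assuming a
# psycopg connection `conn`); the comparison relies on DBAPITypeObject.__eq__
# above, which matches any oid in the group:
#
#     cur = conn.execute("select 'hello'::text")
#     cur.description[0].type_code == STRING   # True: the text oid is in STRING
#     cur.description[0].type_code == NUMBER   # False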
class Binary:
def __init__(self, obj: Any):
self.obj = obj
def __repr__(self) -> str:
sobj = repr(self.obj)
if len(sobj) > 40:
sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
return f"{self.__class__.__name__}({sobj})"
class BinaryBinaryDumper(BytesBinaryDumper):
def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
if isinstance(obj, Binary):
return super().dump(obj.obj)
else:
return super().dump(obj)
class BinaryTextDumper(BytesDumper):
def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
if isinstance(obj, Binary):
return super().dump(obj.obj)
else:
return super().dump(obj)
def Date(year: int, month: int, day: int) -> dt.date:
return dt.date(year, month, day)
def DateFromTicks(ticks: float) -> dt.date:
return TimestampFromTicks(ticks).date()
def Time(hour: int, minute: int, second: int) -> dt.time:
return dt.time(hour, minute, second)
def TimeFromTicks(ticks: float) -> dt.time:
return TimestampFromTicks(ticks).time()
def Timestamp(
year: int, month: int, day: int, hour: int, minute: int, second: int
) -> dt.datetime:
return dt.datetime(year, month, day, hour, minute, second)
def TimestampFromTicks(ticks: float) -> dt.datetime:
secs = floor(ticks)
frac = ticks - secs
t = time.localtime(ticks)
tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff))
rv = dt.datetime(*t[:6], round(frac * 1_000_000), tzinfo=tzinfo)
return rv
def register_dbapi20_adapters(context: AdaptContext) -> None:
adapters = context.adapters
adapters.register_dumper(Binary, BinaryTextDumper)
adapters.register_dumper(Binary, BinaryBinaryDumper)
# Make them also the default dumpers when dumping by bytea oid
adapters.register_dumper(None, BinaryTextDumper)
adapters.register_dumper(None, BinaryBinaryDumper)

File diff suppressed because it is too large

View File

@ -0,0 +1,333 @@
"""
Generators implementing communication protocols with the libpq
Certain operations (connection, querying) are an interleave of libpq calls and
waiting for the socket to be ready. This module contains the code to execute
the operations, yielding a polling state whenever there is to wait. The
functions in the `waiting` module are the ones that wait more or less
cooperatively for the socket to be ready and make these generators continue.
All these generators yield pairs (fileno, `Wait`) whenever an operation would
block. The generator can be restarted sending the appropriate `Ready` state
when the file descriptor is ready.
"""
# Copyright (C) 2020 The Psycopg Team
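# A rough sketch, for illustration only, of the loop a caller (one of the
# `waiting` functions) runs to drive one of these generators with a blocking
# `select()` call. The generator is assumed to yield (fileno, Wait) pairs as
# described above; `wait_blocking` is a hypothetical name:
#
#     import select
#
#     def wait_blocking(gen):
#         try:
#             fileno, s = next(gen)  # first state the generator waits on
#             while True:
#                 if s == Wait.R:
#                     select.select([fileno], [], [])
#                     fileno, s = gen.send(Ready.R)
#                 elif s == Wait.W:
#                     select.select([], [fileno], [])
#                     fileno, s = gen.send(Ready.W)
#                 else:  # Wait.RW
#                     r, w, _ = select.select([fileno], [fileno], [])
#                     ready = Ready.RW if (r and w) else (Ready.R if r else Ready.W)
#                     fileno, s = gen.send(ready)
#         except StopIteration as ex:
#             return ex.value  # the value returned by the generator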
import logging
from typing import List, Optional, Union
from . import pq
from . import errors as e
from .abc import Buffer, PipelineCommand, PQGen, PQGenConn
from .pq.abc import PGconn, PGresult
from .waiting import Wait, Ready
from ._compat import Deque
from ._cmodule import _psycopg
from ._encodings import pgconn_encoding, conninfo_encoding
OK = pq.ConnStatus.OK
BAD = pq.ConnStatus.BAD
POLL_OK = pq.PollingStatus.OK
POLL_READING = pq.PollingStatus.READING
POLL_WRITING = pq.PollingStatus.WRITING
POLL_FAILED = pq.PollingStatus.FAILED
COMMAND_OK = pq.ExecStatus.COMMAND_OK
COPY_OUT = pq.ExecStatus.COPY_OUT
COPY_IN = pq.ExecStatus.COPY_IN
COPY_BOTH = pq.ExecStatus.COPY_BOTH
PIPELINE_SYNC = pq.ExecStatus.PIPELINE_SYNC
WAIT_R = Wait.R
WAIT_W = Wait.W
WAIT_RW = Wait.RW
READY_R = Ready.R
READY_W = Ready.W
READY_RW = Ready.RW
logger = logging.getLogger(__name__)
def _connect(conninfo: str) -> PQGenConn[PGconn]:
"""
Generator to create a database connection without blocking.
"""
conn = pq.PGconn.connect_start(conninfo.encode())
while True:
if conn.status == BAD:
encoding = conninfo_encoding(conninfo)
raise e.OperationalError(
f"connection is bad: {pq.error_message(conn, encoding=encoding)}",
pgconn=conn,
)
status = conn.connect_poll()
if status == POLL_OK:
break
elif status == POLL_READING:
yield conn.socket, WAIT_R
elif status == POLL_WRITING:
yield conn.socket, WAIT_W
elif status == POLL_FAILED:
encoding = conninfo_encoding(conninfo)
raise e.OperationalError(
f"connection failed: {pq.error_message(conn, encoding=encoding)}",
pgconn=e.finish_pgconn(conn),
)
else:
raise e.InternalError(
f"unexpected poll status: {status}", pgconn=e.finish_pgconn(conn)
)
conn.nonblocking = 1
return conn
def _execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
"""
Generator sending a query and returning results without blocking.
The query must have already been sent using `pgconn.send_query()` or
similar. Flush the query and then return the result using nonblocking
functions.
Return the list of results returned by the database (whether success
or error).
"""
yield from _send(pgconn)
rv = yield from _fetch_many(pgconn)
return rv
def _send(pgconn: PGconn) -> PQGen[None]:
"""
Generator to send a query to the server without blocking.
The query must have already been sent using `pgconn.send_query()` or
similar. Flush the query and then return the result using nonblocking
functions.
After this generator has finished you may want to cycle using `fetch()`
to retrieve the results available.
"""
while True:
f = pgconn.flush()
if f == 0:
break
ready = yield WAIT_RW
if ready & READY_R:
# This call may read notifies: they will be saved in the
# PGconn buffer and passed to Python later, in `fetch()`.
pgconn.consume_input()
def _fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
"""
Generator retrieving results from the database without blocking.
The query must have already been sent to the server, so pgconn.flush() has
already returned 0.
Return the list of results returned by the database (whether success
or error).
"""
results: List[PGresult] = []
while True:
res = yield from _fetch(pgconn)
if not res:
break
results.append(res)
status = res.status
if status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
# After entering copy mode the libpq will create a phony result
# for every request so let's break the endless loop.
break
if status == PIPELINE_SYNC:
# PIPELINE_SYNC is not followed by a NULL, but we return it alone
# similarly to other result sets.
assert len(results) == 1, results
break
return results
def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]:
"""
Generator retrieving a single result from the database without blocking.
The query must have already been sent to the server, so pgconn.flush() has
already returned 0.
Return a result from the database (whether success or error).
"""
if pgconn.is_busy():
yield WAIT_R
while True:
pgconn.consume_input()
if not pgconn.is_busy():
break
yield WAIT_R
_consume_notifies(pgconn)
return pgconn.get_result()
def _pipeline_communicate(
pgconn: PGconn, commands: Deque[PipelineCommand]
) -> PQGen[List[List[PGresult]]]:
"""Generator to send queries from a connection in pipeline mode while also
receiving results.
Return a list of results, including single PIPELINE_SYNC elements.
"""
results = []
while True:
ready = yield WAIT_RW
if ready & READY_R:
pgconn.consume_input()
_consume_notifies(pgconn)
res: List[PGresult] = []
while not pgconn.is_busy():
r = pgconn.get_result()
if r is None:
if not res:
break
results.append(res)
res = []
else:
status = r.status
if status == PIPELINE_SYNC:
assert not res
results.append([r])
elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
# This shouldn't happen, but insisting hard enough, it will
# (for instance, in test_executemany_badquery(), with the COPY
# statement and the AsyncClientCursor, which disables
# prepared statements).
# Bail out from the resulting infinite loop.
raise e.NotSupportedError(
"COPY cannot be used in pipeline mode"
)
else:
res.append(r)
if ready & READY_W:
pgconn.flush()
if not commands:
break
commands.popleft()()
return results
def _consume_notifies(pgconn: PGconn) -> None:
# Consume notifies
while True:
n = pgconn.notifies()
if not n:
break
if pgconn.notify_handler:
pgconn.notify_handler(n)
def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
yield WAIT_R
pgconn.consume_input()
ns = []
while True:
n = pgconn.notifies()
if n:
ns.append(n)
else:
break
return ns
def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
while True:
nbytes, data = pgconn.get_copy_data(1)
if nbytes != 0:
break
# would block
yield WAIT_R
pgconn.consume_input()
if nbytes > 0:
# some data
return data
# Retrieve the final result of copy
results = yield from _fetch_many(pgconn)
if len(results) > 1:
# TODO: too brutal? Copy worked.
raise e.ProgrammingError("you cannot mix COPY with other operations")
result = results[0]
if result.status != COMMAND_OK:
encoding = pgconn_encoding(pgconn)
raise e.error_from_result(result, encoding=encoding)
return result
def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]:
# Retry enqueuing data until successful.
#
# WARNING! This can cause an infinite loop if the buffer is too large (see
# ticket #255). We avoid it in the Copy object by splitting a large buffer
# into smaller ones. We prefer to do it there instead of here in order to
# do it upstream of the queue, decoupling the writer task from the producer one.
while pgconn.put_copy_data(buffer) == 0:
yield WAIT_W
def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
# Retry enqueuing end copy message until successful
while pgconn.put_copy_end(error) == 0:
yield WAIT_W
# Repeat until the message is flushed to the server
while True:
yield WAIT_W
f = pgconn.flush()
if f == 0:
break
# Retrieve the final result of copy
(result,) = yield from _fetch_many(pgconn)
if result.status != COMMAND_OK:
encoding = pgconn_encoding(pgconn)
raise e.error_from_result(result, encoding=encoding)
return result
# Override functions with fast versions if available
if _psycopg:
connect = _psycopg.connect
execute = _psycopg.execute
send = _psycopg.send
fetch_many = _psycopg.fetch_many
fetch = _psycopg.fetch
pipeline_communicate = _psycopg.pipeline_communicate
else:
connect = _connect
execute = _execute
send = _send
fetch_many = _fetch_many
fetch = _fetch
pipeline_communicate = _pipeline_communicate

View File

@ -0,0 +1,124 @@
"""
Types configuration specific to PostgreSQL.
"""
# Copyright (C) 2020 The Psycopg Team
from ._typeinfo import TypeInfo, RangeInfo, MultirangeInfo, TypesRegistry
from .abc import AdaptContext
from ._adapters_map import AdaptersMap
# Global objects with PostgreSQL builtins and globally registered user types.
types = TypesRegistry()
# Global adapter maps with PostgreSQL types configuration
adapters = AdaptersMap(types=types)
# Use tools/update_oids.py to update this data.
for t in [
TypeInfo('"char"', 18, 1002),
# autogenerated: start
# Generated from PostgreSQL 16.0
TypeInfo("aclitem", 1033, 1034),
TypeInfo("bit", 1560, 1561),
TypeInfo("bool", 16, 1000, regtype="boolean"),
TypeInfo("box", 603, 1020, delimiter=";"),
TypeInfo("bpchar", 1042, 1014, regtype="character"),
TypeInfo("bytea", 17, 1001),
TypeInfo("cid", 29, 1012),
TypeInfo("cidr", 650, 651),
TypeInfo("circle", 718, 719),
TypeInfo("date", 1082, 1182),
TypeInfo("float4", 700, 1021, regtype="real"),
TypeInfo("float8", 701, 1022, regtype="double precision"),
TypeInfo("gtsvector", 3642, 3644),
TypeInfo("inet", 869, 1041),
TypeInfo("int2", 21, 1005, regtype="smallint"),
TypeInfo("int2vector", 22, 1006),
TypeInfo("int4", 23, 1007, regtype="integer"),
TypeInfo("int8", 20, 1016, regtype="bigint"),
TypeInfo("interval", 1186, 1187),
TypeInfo("json", 114, 199),
TypeInfo("jsonb", 3802, 3807),
TypeInfo("jsonpath", 4072, 4073),
TypeInfo("line", 628, 629),
TypeInfo("lseg", 601, 1018),
TypeInfo("macaddr", 829, 1040),
TypeInfo("macaddr8", 774, 775),
TypeInfo("money", 790, 791),
TypeInfo("name", 19, 1003),
TypeInfo("numeric", 1700, 1231),
TypeInfo("oid", 26, 1028),
TypeInfo("oidvector", 30, 1013),
TypeInfo("path", 602, 1019),
TypeInfo("pg_lsn", 3220, 3221),
TypeInfo("point", 600, 1017),
TypeInfo("polygon", 604, 1027),
TypeInfo("record", 2249, 2287),
TypeInfo("refcursor", 1790, 2201),
TypeInfo("regclass", 2205, 2210),
TypeInfo("regcollation", 4191, 4192),
TypeInfo("regconfig", 3734, 3735),
TypeInfo("regdictionary", 3769, 3770),
TypeInfo("regnamespace", 4089, 4090),
TypeInfo("regoper", 2203, 2208),
TypeInfo("regoperator", 2204, 2209),
TypeInfo("regproc", 24, 1008),
TypeInfo("regprocedure", 2202, 2207),
TypeInfo("regrole", 4096, 4097),
TypeInfo("regtype", 2206, 2211),
TypeInfo("text", 25, 1009),
TypeInfo("tid", 27, 1010),
TypeInfo("time", 1083, 1183, regtype="time without time zone"),
TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
TypeInfo("tsquery", 3615, 3645),
TypeInfo("tsvector", 3614, 3643),
TypeInfo("txid_snapshot", 2970, 2949),
TypeInfo("uuid", 2950, 2951),
TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
TypeInfo("varchar", 1043, 1015, regtype="character varying"),
TypeInfo("xid", 28, 1011),
TypeInfo("xid8", 5069, 271),
TypeInfo("xml", 142, 143),
RangeInfo("daterange", 3912, 3913, subtype_oid=1082),
RangeInfo("int4range", 3904, 3905, subtype_oid=23),
RangeInfo("int8range", 3926, 3927, subtype_oid=20),
RangeInfo("numrange", 3906, 3907, subtype_oid=1700),
RangeInfo("tsrange", 3908, 3909, subtype_oid=1114),
RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184),
MultirangeInfo("datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082),
MultirangeInfo("int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23),
MultirangeInfo("int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20),
MultirangeInfo("nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700),
MultirangeInfo("tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114),
MultirangeInfo("tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184),
# autogenerated: end
]:
types.add(t)
# A few oids used a bit everywhere
INVALID_OID = 0
TEXT_OID = types["text"].oid
TEXT_ARRAY_OID = types["text"].array_oid
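# For example, reading the entries above: types["text"].oid == 25 and
# types["text"].array_oid == 1009, which is how TEXT_OID and TEXT_ARRAY_OID
# are obtained.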
def register_default_adapters(context: AdaptContext) -> None:
from .types import array, bool, composite, datetime, enum, json, multirange
from .types import net, none, numeric, range, string, uuid
array.register_default_adapters(context)
bool.register_default_adapters(context)
composite.register_default_adapters(context)
datetime.register_default_adapters(context)
enum.register_default_adapters(context)
json.register_default_adapters(context)
multirange.register_default_adapters(context)
net.register_default_adapters(context)
none.register_default_adapters(context)
numeric.register_default_adapters(context)
range.register_default_adapters(context)
string.register_default_adapters(context)
uuid.register_default_adapters(context)

View File

@ -0,0 +1,133 @@
"""
psycopg libpq wrapper
This package exposes the libpq functionalities as Python objects and functions.
The real implementation (the binding to the C library) is
implementation-dependent, but all the implementations share the same interface.
"""
# Copyright (C) 2020 The Psycopg Team
import os
import logging
from typing import Callable, List, Type
from . import abc
from .misc import ConninfoOption, PGnotify, PGresAttDesc
from .misc import error_message
from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format, Trace
from ._enums import Ping, PipelineStatus, PollingStatus, TransactionStatus
logger = logging.getLogger(__name__)
__impl__: str
"""The currently loaded implementation of the `!psycopg.pq` package.
Possible values include ``python``, ``c``, ``binary``.
"""
__build_version__: int
"""The libpq version the C package was built with.
A number in the same format as `~psycopg.ConnectionInfo.server_version`
representing the libpq used to build the speedup module (``c``, ``binary``) if
available.
Certain features might not be available if the built version is too old.
"""
version: Callable[[], int]
PGconn: Type[abc.PGconn]
PGresult: Type[abc.PGresult]
Conninfo: Type[abc.Conninfo]
Escaping: Type[abc.Escaping]
PGcancel: Type[abc.PGcancel]
def import_from_libpq() -> None:
"""
Import pq objects implementation from the best libpq wrapper available.
If an implementation is requested, try to import only it; otherwise
try to import the best implementation available.
"""
# import these names into the module on success as side effect
global __impl__, version, __build_version__
global PGconn, PGresult, Conninfo, Escaping, PGcancel
impl = os.environ.get("PSYCOPG_IMPL", "").lower()
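# For illustration (an assumption about usage, not part of this function): the
# implementation can be forced from the environment before importing psycopg:
#
#     PSYCOPG_IMPL=python python -c "import psycopg; print(psycopg.pq.__impl__)"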
module = None
attempts: List[str] = []
def handle_error(name: str, e: Exception) -> None:
if not impl:
msg = f"couldn't import psycopg '{name}' implementation: {e}"
logger.debug(msg)
attempts.append(msg)
else:
msg = f"couldn't import requested psycopg '{name}' implementation: {e}"
raise ImportError(msg) from e
# The best implementation: fast but requires the system libpq installed
if not impl or impl == "c":
try:
from psycopg_c import pq as module # type: ignore
except Exception as e:
handle_error("c", e)
# Second best implementation: fast and stand-alone
if not module and (not impl or impl == "binary"):
try:
from psycopg_binary import pq as module # type: ignore
except Exception as e:
handle_error("binary", e)
# Pure Python implementation, slow and requires the system libpq installed.
if not module and (not impl or impl == "python"):
try:
from . import pq_ctypes as module # type: ignore[assignment]
except Exception as e:
handle_error("python", e)
if module:
__impl__ = module.__impl__
version = module.version
PGconn = module.PGconn
PGresult = module.PGresult
Conninfo = module.Conninfo
Escaping = module.Escaping
PGcancel = module.PGcancel
__build_version__ = module.__build_version__
elif impl:
raise ImportError(f"requested psycopg implementation '{impl}' unknown")
else:
sattempts = "\n".join(f"- {attempt}" for attempt in attempts)
raise ImportError(
f"""\
no pq wrapper available.
Attempts made:
{sattempts}"""
)
import_from_libpq()
__all__ = (
"ConnStatus",
"PipelineStatus",
"PollingStatus",
"TransactionStatus",
"ExecStatus",
"Ping",
"DiagnosticField",
"Format",
"Trace",
"PGconn",
"PGnotify",
"Conninfo",
"PGresAttDesc",
"error_message",
"ConninfoOption",
"version",
)

View File

@ -0,0 +1,106 @@
"""
libpq debugging tools
These functionalities are exposed here for convenience, but are not part of
the public interface and are subject to change at any moment.
Suggested usage::
import logging
import psycopg
from psycopg import pq
from psycopg.pq._debug import PGconnDebug
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger("psycopg.debug")
logger.setLevel(logging.INFO)
assert pq.__impl__ == "python"
pq.PGconn = PGconnDebug
with psycopg.connect("") as conn:
conn.pgconn.trace(2)
conn.pgconn.set_trace_flags(
pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
...
"""
# Copyright (C) 2022 The Psycopg Team
import inspect
import logging
from typing import Any, Callable, Type, TypeVar, TYPE_CHECKING
from functools import wraps
from . import PGconn
from .misc import connection_summary
if TYPE_CHECKING:
from . import abc
Func = TypeVar("Func", bound=Callable[..., Any])
logger = logging.getLogger("psycopg.debug")
class PGconnDebug:
"""Wrapper for a PQconn logging all its access."""
_Self = TypeVar("_Self", bound="PGconnDebug")
_pgconn: "abc.PGconn"
def __init__(self, pgconn: "abc.PGconn"):
super().__setattr__("_pgconn", pgconn)
def __repr__(self) -> str:
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
info = connection_summary(self._pgconn)
return f"<{cls} {info} at 0x{id(self):x}>"
def __getattr__(self, attr: str) -> Any:
value = getattr(self._pgconn, attr)
if callable(value):
return debugging(value)
else:
logger.info("PGconn.%s -> %s", attr, value)
return value
def __setattr__(self, attr: str, value: Any) -> None:
setattr(self._pgconn, attr, value)
logger.info("PGconn.%s <- %s", attr, value)
@classmethod
def connect(cls: Type[_Self], conninfo: bytes) -> _Self:
return cls(debugging(PGconn.connect)(conninfo))
@classmethod
def connect_start(cls: Type[_Self], conninfo: bytes) -> _Self:
return cls(debugging(PGconn.connect_start)(conninfo))
@classmethod
def ping(cls, conninfo: bytes) -> int:
return debugging(PGconn.ping)(conninfo)
def debugging(f: Func) -> Func:
"""Wrap a function in order to log its arguments and return value on call."""
@wraps(f)
def debugging_(*args: Any, **kwargs: Any) -> Any:
reprs = []
for arg in args:
reprs.append(f"{arg!r}")
for k, v in kwargs.items():
reprs.append(f"{k}={v!r}")
logger.info("PGconn.%s(%s)", f.__name__, ", ".join(reprs))
rv = f(*args, **kwargs)
# Display the return value only if the function is declared to return
# something other than None.
ra = inspect.signature(f).return_annotation
if ra is not None or rv is not None:
logger.info(" <- %r", rv)
return rv
return debugging_ # type: ignore

View File

@ -0,0 +1,249 @@
"""
libpq enum definitions for psycopg
"""
# Copyright (C) 2020 The Psycopg Team
from enum import IntEnum, IntFlag, auto
class ConnStatus(IntEnum):
"""
Current status of the connection.
"""
__module__ = "psycopg.pq"
OK = 0
"""The connection is in a working state."""
BAD = auto()
"""The connection is closed."""
STARTED = auto()
MADE = auto()
AWAITING_RESPONSE = auto()
AUTH_OK = auto()
SETENV = auto()
SSL_STARTUP = auto()
NEEDED = auto()
CHECK_WRITABLE = auto()
CONSUME = auto()
GSS_STARTUP = auto()
CHECK_TARGET = auto()
CHECK_STANDBY = auto()
class PollingStatus(IntEnum):
"""
The status of the socket during a connection.
If ``READING`` or ``WRITING`` you may select before polling again.
"""
__module__ = "psycopg.pq"
FAILED = 0
"""Connection attempt failed."""
READING = auto()
"""Will have to wait before reading new data."""
WRITING = auto()
"""Will have to wait before writing new data."""
OK = auto()
"""Connection completed."""
ACTIVE = auto()
class ExecStatus(IntEnum):
"""
The status of a command.
"""
__module__ = "psycopg.pq"
EMPTY_QUERY = 0
"""The string sent to the server was empty."""
COMMAND_OK = auto()
"""Successful completion of a command returning no data."""
TUPLES_OK = auto()
"""
Successful completion of a command returning data (such as a SELECT or SHOW).
"""
COPY_OUT = auto()
"""Copy Out (from server) data transfer started."""
COPY_IN = auto()
"""Copy In (to server) data transfer started."""
BAD_RESPONSE = auto()
"""The server's response was not understood."""
NONFATAL_ERROR = auto()
"""A nonfatal error (a notice or warning) occurred."""
FATAL_ERROR = auto()
"""A fatal error occurred."""
COPY_BOTH = auto()
"""
Copy In/Out (to and from server) data transfer started.
This feature is currently used only for streaming replication, so this
status should not occur in ordinary applications.
"""
SINGLE_TUPLE = auto()
"""
The PGresult contains a single result tuple from the current command.
This status occurs only when single-row mode has been selected for the
query.
"""
PIPELINE_SYNC = auto()
"""
The PGresult represents a synchronization point in pipeline mode,
requested by PQpipelineSync.
This status occurs only when pipeline mode has been selected.
"""
PIPELINE_ABORTED = auto()
"""
The PGresult represents a pipeline that has received an error from the server.
PQgetResult must be called repeatedly, and each time it will return this
status code until the end of the current pipeline, at which point it will
return PGRES_PIPELINE_SYNC and normal processing can resume.
"""
class TransactionStatus(IntEnum):
"""
The transaction status of a connection.
"""
__module__ = "psycopg.pq"
IDLE = 0
"""Connection ready, no transaction active."""
ACTIVE = auto()
"""A command is in progress."""
INTRANS = auto()
"""Connection idle in an open transaction."""
INERROR = auto()
"""An error happened in the current transaction."""
UNKNOWN = auto()
"""Unknown connection state, broken connection."""
class Ping(IntEnum):
"""Response from a ping attempt."""
__module__ = "psycopg.pq"
OK = 0
"""
The server is running and appears to be accepting connections.
"""
REJECT = auto()
"""
The server is running but is in a state that disallows connections.
"""
NO_RESPONSE = auto()
"""
The server could not be contacted.
"""
NO_ATTEMPT = auto()
"""
No attempt was made to contact the server.
"""
class PipelineStatus(IntEnum):
"""Pipeline mode status of the libpq connection."""
__module__ = "psycopg.pq"
OFF = 0
"""
The libpq connection is *not* in pipeline mode.
"""
ON = auto()
"""
The libpq connection is in pipeline mode.
"""
ABORTED = auto()
"""
The libpq connection is in pipeline mode and an error occurred while
processing the current pipeline. The aborted flag is cleared when
PQgetResult returns a result of type PGRES_PIPELINE_SYNC.
"""
class DiagnosticField(IntEnum):
"""
Fields in an error report.
"""
__module__ = "psycopg.pq"
# from postgres_ext.h
SEVERITY = ord("S")
SEVERITY_NONLOCALIZED = ord("V")
SQLSTATE = ord("C")
MESSAGE_PRIMARY = ord("M")
MESSAGE_DETAIL = ord("D")
MESSAGE_HINT = ord("H")
STATEMENT_POSITION = ord("P")
INTERNAL_POSITION = ord("p")
INTERNAL_QUERY = ord("q")
CONTEXT = ord("W")
SCHEMA_NAME = ord("s")
TABLE_NAME = ord("t")
COLUMN_NAME = ord("c")
DATATYPE_NAME = ord("d")
CONSTRAINT_NAME = ord("n")
SOURCE_FILE = ord("F")
SOURCE_LINE = ord("L")
SOURCE_FUNCTION = ord("R")
class Format(IntEnum):
"""
Enum representing the format of a query argument or return value.
These values are only the ones managed by the libpq. `~psycopg` may also
support automatically-chosen values: see `psycopg.adapt.PyFormat`.
"""
__module__ = "psycopg.pq"
TEXT = 0
"""Text parameter."""
BINARY = 1
"""Binary parameter."""
class Trace(IntFlag):
"""
Enum to control tracing of the client/server communication.
"""
__module__ = "psycopg.pq"
SUPPRESS_TIMESTAMPS = 1
"""Do not include timestamps in messages."""
REGRESS_MODE = 2
"""Redact some fields, e.g. OIDs, from messages."""

View File

@ -0,0 +1,804 @@
"""
libpq access using ctypes
"""
# Copyright (C) 2020 The Psycopg Team
import sys
import ctypes
import ctypes.util
from ctypes import Structure, CFUNCTYPE, POINTER
from ctypes import c_char, c_char_p, c_int, c_size_t, c_ubyte, c_uint, c_void_p
from typing import List, Optional, Tuple
from .misc import find_libpq_full_path
from ..errors import NotSupportedError
libname = find_libpq_full_path()
if not libname:
raise ImportError("libpq library not found")
pq = ctypes.cdll.LoadLibrary(libname)
class FILE(Structure):
pass
FILE_ptr = POINTER(FILE)
if sys.platform == "linux":
libcname = ctypes.util.find_library("c")
assert libcname
libc = ctypes.cdll.LoadLibrary(libcname)
fdopen = libc.fdopen
fdopen.argtypes = (c_int, c_char_p)
fdopen.restype = FILE_ptr
# Get the libpq version to define what functions are available.
PQlibVersion = pq.PQlibVersion
PQlibVersion.argtypes = []
PQlibVersion.restype = c_int
libpq_version = PQlibVersion()
# libpq data types
Oid = c_uint
class PGconn_struct(Structure):
_fields_: List[Tuple[str, type]] = []
class PGresult_struct(Structure):
_fields_: List[Tuple[str, type]] = []
class PQconninfoOption_struct(Structure):
_fields_ = [
("keyword", c_char_p),
("envvar", c_char_p),
("compiled", c_char_p),
("val", c_char_p),
("label", c_char_p),
("dispchar", c_char_p),
("dispsize", c_int),
]
class PGnotify_struct(Structure):
_fields_ = [
("relname", c_char_p),
("be_pid", c_int),
("extra", c_char_p),
]
class PGcancel_struct(Structure):
_fields_: List[Tuple[str, type]] = []
class PGresAttDesc_struct(Structure):
_fields_ = [
("name", c_char_p),
("tableid", Oid),
("columnid", c_int),
("format", c_int),
("typid", Oid),
("typlen", c_int),
("atttypmod", c_int),
]
PGconn_ptr = POINTER(PGconn_struct)
PGresult_ptr = POINTER(PGresult_struct)
PQconninfoOption_ptr = POINTER(PQconninfoOption_struct)
PGnotify_ptr = POINTER(PGnotify_struct)
PGcancel_ptr = POINTER(PGcancel_struct)
PGresAttDesc_ptr = POINTER(PGresAttDesc_struct)
# Function definitions as explained in PostgreSQL 12 documentation
# 33.1. Database Connection Control Functions
# PQconnectdbParams: doesn't seem useful, won't wrap for now
PQconnectdb = pq.PQconnectdb
PQconnectdb.argtypes = [c_char_p]
PQconnectdb.restype = PGconn_ptr
# PQsetdbLogin: not useful
# PQsetdb: not useful
# PQconnectStartParams: not useful
PQconnectStart = pq.PQconnectStart
PQconnectStart.argtypes = [c_char_p]
PQconnectStart.restype = PGconn_ptr
PQconnectPoll = pq.PQconnectPoll
PQconnectPoll.argtypes = [PGconn_ptr]
PQconnectPoll.restype = c_int
PQconndefaults = pq.PQconndefaults
PQconndefaults.argtypes = []
PQconndefaults.restype = PQconninfoOption_ptr
PQconninfoFree = pq.PQconninfoFree
PQconninfoFree.argtypes = [PQconninfoOption_ptr]
PQconninfoFree.restype = None
PQconninfo = pq.PQconninfo
PQconninfo.argtypes = [PGconn_ptr]
PQconninfo.restype = PQconninfoOption_ptr
PQconninfoParse = pq.PQconninfoParse
PQconninfoParse.argtypes = [c_char_p, POINTER(c_char_p)]
PQconninfoParse.restype = PQconninfoOption_ptr
PQfinish = pq.PQfinish
PQfinish.argtypes = [PGconn_ptr]
PQfinish.restype = None
PQreset = pq.PQreset
PQreset.argtypes = [PGconn_ptr]
PQreset.restype = None
PQresetStart = pq.PQresetStart
PQresetStart.argtypes = [PGconn_ptr]
PQresetStart.restype = c_int
PQresetPoll = pq.PQresetPoll
PQresetPoll.argtypes = [PGconn_ptr]
PQresetPoll.restype = c_int
PQping = pq.PQping
PQping.argtypes = [c_char_p]
PQping.restype = c_int
# 33.2. Connection Status Functions
PQdb = pq.PQdb
PQdb.argtypes = [PGconn_ptr]
PQdb.restype = c_char_p
PQuser = pq.PQuser
PQuser.argtypes = [PGconn_ptr]
PQuser.restype = c_char_p
PQpass = pq.PQpass
PQpass.argtypes = [PGconn_ptr]
PQpass.restype = c_char_p
PQhost = pq.PQhost
PQhost.argtypes = [PGconn_ptr]
PQhost.restype = c_char_p
_PQhostaddr = None
if libpq_version >= 120000:
_PQhostaddr = pq.PQhostaddr
_PQhostaddr.argtypes = [PGconn_ptr]
_PQhostaddr.restype = c_char_p
def PQhostaddr(pgconn: PGconn_struct) -> bytes:
if not _PQhostaddr:
raise NotSupportedError(
"PQhostaddr requires libpq from PostgreSQL 12,"
f" {libpq_version} available instead"
)
return _PQhostaddr(pgconn)
PQport = pq.PQport
PQport.argtypes = [PGconn_ptr]
PQport.restype = c_char_p
PQtty = pq.PQtty
PQtty.argtypes = [PGconn_ptr]
PQtty.restype = c_char_p
PQoptions = pq.PQoptions
PQoptions.argtypes = [PGconn_ptr]
PQoptions.restype = c_char_p
PQstatus = pq.PQstatus
PQstatus.argtypes = [PGconn_ptr]
PQstatus.restype = c_int
PQtransactionStatus = pq.PQtransactionStatus
PQtransactionStatus.argtypes = [PGconn_ptr]
PQtransactionStatus.restype = c_int
PQparameterStatus = pq.PQparameterStatus
PQparameterStatus.argtypes = [PGconn_ptr, c_char_p]
PQparameterStatus.restype = c_char_p
PQprotocolVersion = pq.PQprotocolVersion
PQprotocolVersion.argtypes = [PGconn_ptr]
PQprotocolVersion.restype = c_int
PQserverVersion = pq.PQserverVersion
PQserverVersion.argtypes = [PGconn_ptr]
PQserverVersion.restype = c_int
PQerrorMessage = pq.PQerrorMessage
PQerrorMessage.argtypes = [PGconn_ptr]
PQerrorMessage.restype = c_char_p
PQsocket = pq.PQsocket
PQsocket.argtypes = [PGconn_ptr]
PQsocket.restype = c_int
PQbackendPID = pq.PQbackendPID
PQbackendPID.argtypes = [PGconn_ptr]
PQbackendPID.restype = c_int
PQconnectionNeedsPassword = pq.PQconnectionNeedsPassword
PQconnectionNeedsPassword.argtypes = [PGconn_ptr]
PQconnectionNeedsPassword.restype = c_int
PQconnectionUsedPassword = pq.PQconnectionUsedPassword
PQconnectionUsedPassword.argtypes = [PGconn_ptr]
PQconnectionUsedPassword.restype = c_int
PQsslInUse = pq.PQsslInUse
PQsslInUse.argtypes = [PGconn_ptr]
PQsslInUse.restype = c_int
# TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl
# 33.3. Command Execution Functions
PQexec = pq.PQexec
PQexec.argtypes = [PGconn_ptr, c_char_p]
PQexec.restype = PGresult_ptr
PQexecParams = pq.PQexecParams
PQexecParams.argtypes = [
PGconn_ptr,
c_char_p,
c_int,
POINTER(Oid),
POINTER(c_char_p),
POINTER(c_int),
POINTER(c_int),
c_int,
]
PQexecParams.restype = PGresult_ptr
PQprepare = pq.PQprepare
PQprepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
PQprepare.restype = PGresult_ptr
PQexecPrepared = pq.PQexecPrepared
PQexecPrepared.argtypes = [
PGconn_ptr,
c_char_p,
c_int,
POINTER(c_char_p),
POINTER(c_int),
POINTER(c_int),
c_int,
]
PQexecPrepared.restype = PGresult_ptr
PQdescribePrepared = pq.PQdescribePrepared
PQdescribePrepared.argtypes = [PGconn_ptr, c_char_p]
PQdescribePrepared.restype = PGresult_ptr
PQdescribePortal = pq.PQdescribePortal
PQdescribePortal.argtypes = [PGconn_ptr, c_char_p]
PQdescribePortal.restype = PGresult_ptr
PQresultStatus = pq.PQresultStatus
PQresultStatus.argtypes = [PGresult_ptr]
PQresultStatus.restype = c_int
# PQresStatus: not needed, we have pretty enums
PQresultErrorMessage = pq.PQresultErrorMessage
PQresultErrorMessage.argtypes = [PGresult_ptr]
PQresultErrorMessage.restype = c_char_p
# TODO: PQresultVerboseErrorMessage
PQresultErrorField = pq.PQresultErrorField
PQresultErrorField.argtypes = [PGresult_ptr, c_int]
PQresultErrorField.restype = c_char_p
PQclear = pq.PQclear
PQclear.argtypes = [PGresult_ptr]
PQclear.restype = None
# 33.3.2. Retrieving Query Result Information
PQntuples = pq.PQntuples
PQntuples.argtypes = [PGresult_ptr]
PQntuples.restype = c_int
PQnfields = pq.PQnfields
PQnfields.argtypes = [PGresult_ptr]
PQnfields.restype = c_int
PQfname = pq.PQfname
PQfname.argtypes = [PGresult_ptr, c_int]
PQfname.restype = c_char_p
# PQfnumber: useless and hard to use
PQftable = pq.PQftable
PQftable.argtypes = [PGresult_ptr, c_int]
PQftable.restype = Oid
PQftablecol = pq.PQftablecol
PQftablecol.argtypes = [PGresult_ptr, c_int]
PQftablecol.restype = c_int
PQfformat = pq.PQfformat
PQfformat.argtypes = [PGresult_ptr, c_int]
PQfformat.restype = c_int
PQftype = pq.PQftype
PQftype.argtypes = [PGresult_ptr, c_int]
PQftype.restype = Oid
PQfmod = pq.PQfmod
PQfmod.argtypes = [PGresult_ptr, c_int]
PQfmod.restype = c_int
PQfsize = pq.PQfsize
PQfsize.argtypes = [PGresult_ptr, c_int]
PQfsize.restype = c_int
PQbinaryTuples = pq.PQbinaryTuples
PQbinaryTuples.argtypes = [PGresult_ptr]
PQbinaryTuples.restype = c_int
PQgetvalue = pq.PQgetvalue
PQgetvalue.argtypes = [PGresult_ptr, c_int, c_int]
PQgetvalue.restype = POINTER(c_char) # not a null-terminated string
PQgetisnull = pq.PQgetisnull
PQgetisnull.argtypes = [PGresult_ptr, c_int, c_int]
PQgetisnull.restype = c_int
PQgetlength = pq.PQgetlength
PQgetlength.argtypes = [PGresult_ptr, c_int, c_int]
PQgetlength.restype = c_int
PQnparams = pq.PQnparams
PQnparams.argtypes = [PGresult_ptr]
PQnparams.restype = c_int
PQparamtype = pq.PQparamtype
PQparamtype.argtypes = [PGresult_ptr, c_int]
PQparamtype.restype = Oid
# PQprint: pretty useless
# 33.3.3. Retrieving Other Result Information
PQcmdStatus = pq.PQcmdStatus
PQcmdStatus.argtypes = [PGresult_ptr]
PQcmdStatus.restype = c_char_p
PQcmdTuples = pq.PQcmdTuples
PQcmdTuples.argtypes = [PGresult_ptr]
PQcmdTuples.restype = c_char_p
PQoidValue = pq.PQoidValue
PQoidValue.argtypes = [PGresult_ptr]
PQoidValue.restype = Oid
# 33.3.4. Escaping Strings for Inclusion in SQL Commands
PQescapeLiteral = pq.PQescapeLiteral
PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
PQescapeLiteral.restype = POINTER(c_char)
PQescapeIdentifier = pq.PQescapeIdentifier
PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t]
PQescapeIdentifier.restype = POINTER(c_char)
PQescapeStringConn = pq.PQescapeStringConn
# TODO: raises "wrong type" error
# PQescapeStringConn.argtypes = [
# PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int)
# ]
PQescapeStringConn.restype = c_size_t
PQescapeString = pq.PQescapeString
# TODO: raises "wrong type" error
# PQescapeString.argtypes = [c_char_p, c_char_p, c_size_t]
PQescapeString.restype = c_size_t
PQescapeByteaConn = pq.PQescapeByteaConn
PQescapeByteaConn.argtypes = [
PGconn_ptr,
POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
c_size_t,
POINTER(c_size_t),
]
PQescapeByteaConn.restype = POINTER(c_ubyte)
PQescapeBytea = pq.PQescapeBytea
PQescapeBytea.argtypes = [
POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
c_size_t,
POINTER(c_size_t),
]
PQescapeBytea.restype = POINTER(c_ubyte)
PQunescapeBytea = pq.PQunescapeBytea
PQunescapeBytea.argtypes = [
POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
POINTER(c_size_t),
]
PQunescapeBytea.restype = POINTER(c_ubyte)
# 33.4. Asynchronous Command Processing
PQsendQuery = pq.PQsendQuery
PQsendQuery.argtypes = [PGconn_ptr, c_char_p]
PQsendQuery.restype = c_int
PQsendQueryParams = pq.PQsendQueryParams
PQsendQueryParams.argtypes = [
PGconn_ptr,
c_char_p,
c_int,
POINTER(Oid),
POINTER(c_char_p),
POINTER(c_int),
POINTER(c_int),
c_int,
]
PQsendQueryParams.restype = c_int
PQsendPrepare = pq.PQsendPrepare
PQsendPrepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
PQsendPrepare.restype = c_int
PQsendQueryPrepared = pq.PQsendQueryPrepared
PQsendQueryPrepared.argtypes = [
PGconn_ptr,
c_char_p,
c_int,
POINTER(c_char_p),
POINTER(c_int),
POINTER(c_int),
c_int,
]
PQsendQueryPrepared.restype = c_int
PQsendDescribePrepared = pq.PQsendDescribePrepared
PQsendDescribePrepared.argtypes = [PGconn_ptr, c_char_p]
PQsendDescribePrepared.restype = c_int
PQsendDescribePortal = pq.PQsendDescribePortal
PQsendDescribePortal.argtypes = [PGconn_ptr, c_char_p]
PQsendDescribePortal.restype = c_int
PQgetResult = pq.PQgetResult
PQgetResult.argtypes = [PGconn_ptr]
PQgetResult.restype = PGresult_ptr
PQconsumeInput = pq.PQconsumeInput
PQconsumeInput.argtypes = [PGconn_ptr]
PQconsumeInput.restype = c_int
PQisBusy = pq.PQisBusy
PQisBusy.argtypes = [PGconn_ptr]
PQisBusy.restype = c_int
PQsetnonblocking = pq.PQsetnonblocking
PQsetnonblocking.argtypes = [PGconn_ptr, c_int]
PQsetnonblocking.restype = c_int
PQisnonblocking = pq.PQisnonblocking
PQisnonblocking.argtypes = [PGconn_ptr]
PQisnonblocking.restype = c_int
PQflush = pq.PQflush
PQflush.argtypes = [PGconn_ptr]
PQflush.restype = c_int
# 33.5. Retrieving Query Results Row-by-Row
PQsetSingleRowMode = pq.PQsetSingleRowMode
PQsetSingleRowMode.argtypes = [PGconn_ptr]
PQsetSingleRowMode.restype = c_int
# 33.6. Canceling Queries in Progress
PQgetCancel = pq.PQgetCancel
PQgetCancel.argtypes = [PGconn_ptr]
PQgetCancel.restype = PGcancel_ptr
PQfreeCancel = pq.PQfreeCancel
PQfreeCancel.argtypes = [PGcancel_ptr]
PQfreeCancel.restype = None
PQcancel = pq.PQcancel
# TODO: raises "wrong type" error
# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int]
PQcancel.restype = c_int
# 33.8. Asynchronous Notification
PQnotifies = pq.PQnotifies
PQnotifies.argtypes = [PGconn_ptr]
PQnotifies.restype = PGnotify_ptr
# 33.9. Functions Associated with the COPY Command
PQputCopyData = pq.PQputCopyData
PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int]
PQputCopyData.restype = c_int
PQputCopyEnd = pq.PQputCopyEnd
PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p]
PQputCopyEnd.restype = c_int
PQgetCopyData = pq.PQgetCopyData
PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int]
PQgetCopyData.restype = c_int
# 33.10. Control Functions
PQtrace = pq.PQtrace
PQtrace.argtypes = [PGconn_ptr, FILE_ptr]
PQtrace.restype = None
_PQsetTraceFlags = None
if libpq_version >= 140000:
_PQsetTraceFlags = pq.PQsetTraceFlags
_PQsetTraceFlags.argtypes = [PGconn_ptr, c_int]
_PQsetTraceFlags.restype = None
def PQsetTraceFlags(pgconn: PGconn_struct, flags: int) -> None:
if not _PQsetTraceFlags:
raise NotSupportedError(
"PQsetTraceFlags requires libpq from PostgreSQL 14,"
f" {libpq_version} available instead"
)
_PQsetTraceFlags(pgconn, flags)
PQuntrace = pq.PQuntrace
PQuntrace.argtypes = [PGconn_ptr]
PQuntrace.restype = None
# 33.11. Miscellaneous Functions
PQfreemem = pq.PQfreemem
PQfreemem.argtypes = [c_void_p]
PQfreemem.restype = None
if libpq_version >= 100000:
_PQencryptPasswordConn = pq.PQencryptPasswordConn
_PQencryptPasswordConn.argtypes = [
PGconn_ptr,
c_char_p,
c_char_p,
c_char_p,
]
_PQencryptPasswordConn.restype = POINTER(c_char)
def PQencryptPasswordConn(
pgconn: PGconn_struct, passwd: bytes, user: bytes, algorithm: bytes
) -> Optional[bytes]:
if not _PQencryptPasswordConn:
raise NotSupportedError(
"PQencryptPasswordConn requires libpq from PostgreSQL 10,"
f" {libpq_version} available instead"
)
return _PQencryptPasswordConn(pgconn, passwd, user, algorithm)
PQmakeEmptyPGresult = pq.PQmakeEmptyPGresult
PQmakeEmptyPGresult.argtypes = [PGconn_ptr, c_int]
PQmakeEmptyPGresult.restype = PGresult_ptr
PQsetResultAttrs = pq.PQsetResultAttrs
PQsetResultAttrs.argtypes = [PGresult_ptr, c_int, PGresAttDesc_ptr]
PQsetResultAttrs.restype = c_int
# 33.12. Notice Processing
PQnoticeReceiver = CFUNCTYPE(None, c_void_p, PGresult_ptr)
PQsetNoticeReceiver = pq.PQsetNoticeReceiver
PQsetNoticeReceiver.argtypes = [PGconn_ptr, PQnoticeReceiver, c_void_p]
PQsetNoticeReceiver.restype = PQnoticeReceiver
# 34.5 Pipeline Mode
_PQpipelineStatus = None
_PQenterPipelineMode = None
_PQexitPipelineMode = None
_PQpipelineSync = None
_PQsendFlushRequest = None
if libpq_version >= 140000:
_PQpipelineStatus = pq.PQpipelineStatus
_PQpipelineStatus.argtypes = [PGconn_ptr]
_PQpipelineStatus.restype = c_int
_PQenterPipelineMode = pq.PQenterPipelineMode
_PQenterPipelineMode.argtypes = [PGconn_ptr]
_PQenterPipelineMode.restype = c_int
_PQexitPipelineMode = pq.PQexitPipelineMode
_PQexitPipelineMode.argtypes = [PGconn_ptr]
_PQexitPipelineMode.restype = c_int
_PQpipelineSync = pq.PQpipelineSync
_PQpipelineSync.argtypes = [PGconn_ptr]
_PQpipelineSync.restype = c_int
_PQsendFlushRequest = pq.PQsendFlushRequest
_PQsendFlushRequest.argtypes = [PGconn_ptr]
_PQsendFlushRequest.restype = c_int
def PQpipelineStatus(pgconn: PGconn_struct) -> int:
if not _PQpipelineStatus:
raise NotSupportedError(
"PQpipelineStatus requires libpq from PostgreSQL 14,"
f" {libpq_version} available instead"
)
return _PQpipelineStatus(pgconn)
def PQenterPipelineMode(pgconn: PGconn_struct) -> int:
if not _PQenterPipelineMode:
raise NotSupportedError(
"PQenterPipelineMode requires libpq from PostgreSQL 14,"
f" {libpq_version} available instead"
)
return _PQenterPipelineMode(pgconn)
def PQexitPipelineMode(pgconn: PGconn_struct) -> int:
if not _PQexitPipelineMode:
raise NotSupportedError(
"PQexitPipelineMode requires libpq from PostgreSQL 14,"
f" {libpq_version} available instead"
)
return _PQexitPipelineMode(pgconn)
def PQpipelineSync(pgconn: PGconn_struct) -> int:
if not _PQpipelineSync:
raise NotSupportedError(
"PQpipelineSync requires libpq from PostgreSQL 14,"
f" {libpq_version} available instead"
)
return _PQpipelineSync(pgconn)
def PQsendFlushRequest(pgconn: PGconn_struct) -> int:
if not _PQsendFlushRequest:
raise NotSupportedError(
"PQsendFlushRequest requires libpq from PostgreSQL 14,"
f" {libpq_version} available instead"
)
return _PQsendFlushRequest(pgconn)
# 33.18. SSL Support
PQinitOpenSSL = pq.PQinitOpenSSL
PQinitOpenSSL.argtypes = [c_int, c_int]
PQinitOpenSSL.restype = None
def generate_stub() -> None:
import re
from ctypes import _CFuncPtr # type: ignore
def type2str(fname, narg, t):
if t is None:
return "None"
elif t is c_void_p:
return "Any"
elif t is c_int or t is c_uint or t is c_size_t:
return "int"
elif t is c_char_p or t.__name__ == "LP_c_char":
if narg is not None:
return "bytes"
else:
return "Optional[bytes]"
elif t.__name__ in (
"LP_PGconn_struct",
"LP_PGresult_struct",
"LP_PGcancel_struct",
):
if narg is not None:
return f"Optional[{t.__name__[3:]}]"
else:
return t.__name__[3:]
elif t.__name__ in ("LP_PQconninfoOption_struct",):
return f"Sequence[{t.__name__[3:]}]"
elif t.__name__ in (
"LP_c_ubyte",
"LP_c_char_p",
"LP_c_int",
"LP_c_uint",
"LP_c_ulong",
"LP_FILE",
):
return f"_Pointer[{t.__name__[3:]}]"
else:
assert False, f"can't deal with {t} in {fname}"
fn = __file__ + "i"
with open(fn) as f:
lines = f.read().splitlines()
istart, iend = (
i
for i, line in enumerate(lines)
if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line)
)
known = {
line[4:].split("(", 1)[0] for line in lines[:istart] if line.startswith("def ")
}
signatures = []
for name, obj in globals().items():
if name in known:
continue
if not isinstance(obj, _CFuncPtr):
continue
params = []
for i, t in enumerate(obj.argtypes):
params.append(f"arg{i + 1}: {type2str(name, i, t)}")
resname = type2str(name, None, obj.restype)
signatures.append(f"def {name}({', '.join(params)}) -> {resname}: ...")
lines[istart + 1 : iend] = signatures
with open(fn, "w") as f:
f.write("\n".join(lines))
f.write("\n")
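# Illustrative sketch, not part of the original module: how a caller might use
# the version-gated wrappers defined above. PQpipelineStatus raises
# NotSupportedError when the loaded libpq is older than PostgreSQL 14; `pgconn`
# is assumed to be a valid connection handle obtained elsewhere.
def _example_pipeline_probe(pgconn: PGconn_struct) -> Optional[int]:
    try:
        return PQpipelineStatus(pgconn)
    except NotSupportedError:
        # libpq < 14: pipeline mode is not available
        return None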
if __name__ == "__main__":
generate_stub()

View File

@ -0,0 +1,384 @@
"""
Protocol objects to represent objects exposed by different pq implementations.
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, Callable, List, Optional, Sequence, Tuple
from typing import Union, TYPE_CHECKING
from typing_extensions import TypeAlias
from ._enums import Format, Trace
from .._compat import Protocol
if TYPE_CHECKING:
from .misc import PGnotify, ConninfoOption, PGresAttDesc
# An object implementing the buffer protocol (ish)
Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
class PGconn(Protocol):
notice_handler: Optional[Callable[["PGresult"], None]]
notify_handler: Optional[Callable[["PGnotify"], None]]
@classmethod
def connect(cls, conninfo: bytes) -> "PGconn":
...
@classmethod
def connect_start(cls, conninfo: bytes) -> "PGconn":
...
def connect_poll(self) -> int:
...
def finish(self) -> None:
...
@property
def info(self) -> List["ConninfoOption"]:
...
def reset(self) -> None:
...
def reset_start(self) -> None:
...
def reset_poll(self) -> int:
...
@classmethod
def ping(cls, conninfo: bytes) -> int:
...
@property
def db(self) -> bytes:
...
@property
def user(self) -> bytes:
...
@property
def password(self) -> bytes:
...
@property
def host(self) -> bytes:
...
@property
def hostaddr(self) -> bytes:
...
@property
def port(self) -> bytes:
...
@property
def tty(self) -> bytes:
...
@property
def options(self) -> bytes:
...
@property
def status(self) -> int:
...
@property
def transaction_status(self) -> int:
...
def parameter_status(self, name: bytes) -> Optional[bytes]:
...
@property
def error_message(self) -> bytes:
...
@property
def server_version(self) -> int:
...
@property
def socket(self) -> int:
...
@property
def backend_pid(self) -> int:
...
@property
def needs_password(self) -> bool:
...
@property
def used_password(self) -> bool:
...
@property
def ssl_in_use(self) -> bool:
...
def exec_(self, command: bytes) -> "PGresult":
...
def send_query(self, command: bytes) -> None:
...
def exec_params(
self,
command: bytes,
param_values: Optional[Sequence[Optional[Buffer]]],
param_types: Optional[Sequence[int]] = None,
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
) -> "PGresult":
...
def send_query_params(
self,
command: bytes,
param_values: Optional[Sequence[Optional[Buffer]]],
param_types: Optional[Sequence[int]] = None,
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
) -> None:
...
def send_prepare(
self,
name: bytes,
command: bytes,
param_types: Optional[Sequence[int]] = None,
) -> None:
...
def send_query_prepared(
self,
name: bytes,
param_values: Optional[Sequence[Optional[Buffer]]],
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
) -> None:
...
def prepare(
self,
name: bytes,
command: bytes,
param_types: Optional[Sequence[int]] = None,
) -> "PGresult":
...
def exec_prepared(
self,
name: bytes,
param_values: Optional[Sequence[Buffer]],
param_formats: Optional[Sequence[int]] = None,
result_format: int = 0,
) -> "PGresult":
...
def describe_prepared(self, name: bytes) -> "PGresult":
...
def send_describe_prepared(self, name: bytes) -> None:
...
def describe_portal(self, name: bytes) -> "PGresult":
...
def send_describe_portal(self, name: bytes) -> None:
...
def get_result(self) -> Optional["PGresult"]:
...
def consume_input(self) -> None:
...
def is_busy(self) -> int:
...
@property
def nonblocking(self) -> int:
...
@nonblocking.setter
def nonblocking(self, arg: int) -> None:
...
def flush(self) -> int:
...
def set_single_row_mode(self) -> None:
...
def get_cancel(self) -> "PGcancel":
...
def notifies(self) -> Optional["PGnotify"]:
...
def put_copy_data(self, buffer: Buffer) -> int:
...
def put_copy_end(self, error: Optional[bytes] = None) -> int:
...
def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
...
def trace(self, fileno: int) -> None:
...
def set_trace_flags(self, flags: Trace) -> None:
...
def untrace(self) -> None:
...
def encrypt_password(
self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
) -> bytes:
...
def make_empty_result(self, exec_status: int) -> "PGresult":
...
@property
def pipeline_status(self) -> int:
...
def enter_pipeline_mode(self) -> None:
...
def exit_pipeline_mode(self) -> None:
...
def pipeline_sync(self) -> None:
...
def send_flush_request(self) -> None:
...
class PGresult(Protocol):
def clear(self) -> None:
...
@property
def status(self) -> int:
...
@property
def error_message(self) -> bytes:
...
def error_field(self, fieldcode: int) -> Optional[bytes]:
...
@property
def ntuples(self) -> int:
...
@property
def nfields(self) -> int:
...
def fname(self, column_number: int) -> Optional[bytes]:
...
def ftable(self, column_number: int) -> int:
...
def ftablecol(self, column_number: int) -> int:
...
def fformat(self, column_number: int) -> int:
...
def ftype(self, column_number: int) -> int:
...
def fmod(self, column_number: int) -> int:
...
def fsize(self, column_number: int) -> int:
...
@property
def binary_tuples(self) -> int:
...
def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
...
@property
def nparams(self) -> int:
...
def param_type(self, param_number: int) -> int:
...
@property
def command_status(self) -> Optional[bytes]:
...
@property
def command_tuples(self) -> Optional[int]:
...
@property
def oid_value(self) -> int:
...
def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None:
...
class PGcancel(Protocol):
def free(self) -> None:
...
def cancel(self) -> None:
...
class Conninfo(Protocol):
@classmethod
def get_defaults(cls) -> List["ConninfoOption"]:
...
@classmethod
def parse(cls, conninfo: bytes) -> List["ConninfoOption"]:
...
@classmethod
def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]:
...
class Escaping(Protocol):
def __init__(self, conn: Optional[PGconn] = None):
...
def escape_literal(self, data: Buffer) -> bytes:
...
def escape_identifier(self, data: Buffer) -> bytes:
...
def escape_string(self, data: Buffer) -> bytes:
...
def escape_bytea(self, data: Buffer) -> bytes:
...
def unescape_bytea(self, data: Buffer) -> bytes:
...
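# Illustrative sketch, not part of the original module: because PGconn is a
# Protocol, helpers can be typed against it and will accept either the ctypes
# or the C implementation. The version arithmetic assumes PostgreSQL 10+.
def _example_server_version(pgconn: PGconn) -> str:
    v = pgconn.server_version  # e.g. 140005
    return f"{v // 10000}.{v % 100}"  # -> "14.5"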

View File

@ -0,0 +1,146 @@
"""
Various functionalities to make it easier to work with the libpq.
"""
# Copyright (C) 2020 The Psycopg Team
import os
import sys
import logging
import ctypes.util
from typing import cast, NamedTuple, Optional, Union
from .abc import PGconn, PGresult
from ._enums import ConnStatus, TransactionStatus, PipelineStatus
from .._compat import cache
from .._encodings import pgconn_encoding
logger = logging.getLogger("psycopg.pq")
OK = ConnStatus.OK
class PGnotify(NamedTuple):
relname: bytes
be_pid: int
extra: bytes
class ConninfoOption(NamedTuple):
keyword: bytes
envvar: Optional[bytes]
compiled: Optional[bytes]
val: Optional[bytes]
label: bytes
dispchar: bytes
dispsize: int
class PGresAttDesc(NamedTuple):
name: bytes
tableid: int
columnid: int
format: int
typid: int
typlen: int
atttypmod: int
@cache
def find_libpq_full_path() -> Optional[str]:
if sys.platform == "win32":
libname = ctypes.util.find_library("libpq.dll")
elif sys.platform == "darwin":
libname = ctypes.util.find_library("libpq.dylib")
# (hopefully) temporary hack: libpq not in a standard place
# https://github.com/orgs/Homebrew/discussions/3595
# If pg_config is available and agrees, let's use its indications.
if not libname:
try:
import subprocess as sp
libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode()
libname = os.path.join(libdir, "libpq.dylib")
if not os.path.exists(libname):
libname = None
except Exception as ex:
logger.debug("couldn't use pg_config to find libpq: %s", ex)
else:
libname = ctypes.util.find_library("pq")
return libname
def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str:
"""
Return an error message from a `PGconn` or `PGresult`.
The return value is a `!str` (unlike pq data which is usually `!bytes`):
use the connection encoding if available, otherwise the `!encoding`
parameter as a fallback for decoding. Don't raise exceptions on decoding
errors.
"""
bmsg: bytes
if hasattr(obj, "error_field"):
# obj is a PGresult
obj = cast(PGresult, obj)
bmsg = obj.error_message
# strip severity and whitespaces
if bmsg:
bmsg = bmsg.split(b":", 1)[-1].strip()
elif hasattr(obj, "error_message"):
# obj is a PGconn
if obj.status == OK:
encoding = pgconn_encoding(obj)
bmsg = obj.error_message
# strip severity and whitespaces
if bmsg:
bmsg = bmsg.split(b":", 1)[-1].strip()
else:
raise TypeError(f"PGconn or PGresult expected, got {type(obj).__name__}")
if bmsg:
msg = bmsg.decode(encoding, "replace")
else:
msg = "no details available"
return msg
def connection_summary(pgconn: PGconn) -> str:
"""
Return summary information on a connection.
Useful for __repr__
"""
parts = []
if pgconn.status == OK:
# Put together the [STATUS]
status = TransactionStatus(pgconn.transaction_status).name
if pgconn.pipeline_status:
status += f", pipeline={PipelineStatus(pgconn.pipeline_status).name}"
# Put together the (CONNECTION)
if not pgconn.host.startswith(b"/"):
parts.append(("host", pgconn.host.decode()))
if pgconn.port != b"5432":
parts.append(("port", pgconn.port.decode()))
if pgconn.user != pgconn.db:
parts.append(("user", pgconn.user.decode()))
parts.append(("database", pgconn.db.decode()))
else:
status = ConnStatus(pgconn.status).name
sparts = " ".join("%s=%s" % part for part in parts)
if sparts:
sparts = f" ({sparts})"
return f"[{status}]{sparts}"

File diff suppressed because it is too large

View File

@ -0,0 +1,255 @@
"""
psycopg row factories
"""
# Copyright (C) 2021 The Psycopg Team
import functools
from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn
from typing import TYPE_CHECKING, Sequence, Tuple, Type, TypeVar
from collections import namedtuple
from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from ._compat import Protocol
from ._encodings import _as_python_identifier
if TYPE_CHECKING:
from .cursor import BaseCursor, Cursor
from .cursor_async import AsyncCursor
from psycopg.pq.abc import PGresult
COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK
SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
T = TypeVar("T", covariant=True)
# Row factories
Row = TypeVar("Row", covariant=True)
class RowMaker(Protocol[Row]):
"""
Callable protocol taking a sequence of values and returning an object.
The sequence of values is what is returned from a database query, already
adapted to the right Python types. The return value is the object that your
program would like to receive: by default (`tuple_row()`) it is a simple
tuple, but it may be any type of object.
Typically, `!RowMaker` functions are returned by `RowFactory`.
"""
def __call__(self, __values: Sequence[Any]) -> Row:
...
class RowFactory(Protocol[Row]):
"""
Callable protocol taking a `~psycopg.Cursor` and returning a `RowMaker`.
A `!RowFactory` is typically called when a `!Cursor` receives a result.
This way it can inspect the cursor state (for instance the
`~psycopg.Cursor.description` attribute) and help a `!RowMaker` to create
a complete object.
For instance the `dict_row()` `!RowFactory` uses the names of the columns to
define the dictionary keys and returns a `!RowMaker` function which would
use the values to create a dictionary for each record.
"""
def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]:
...
class AsyncRowFactory(Protocol[Row]):
"""
Like `RowFactory`, taking an async cursor as argument.
"""
def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]:
...
class BaseRowFactory(Protocol[Row]):
"""
Like `RowFactory`, taking either type of cursor as argument.
"""
def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]:
...
TupleRow: TypeAlias = Tuple[Any, ...]
"""
An alias for the type returned by `tuple_row()` (i.e. a tuple of any content).
"""
DictRow: TypeAlias = Dict[str, Any]
"""
An alias for the type returned by `dict_row()`
A `!DictRow` is a dictionary with keys as string and any value returned by the
database.
"""
def tuple_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[TupleRow]":
r"""Row factory to represent rows as simple tuples.
This is the default factory, used when `~psycopg.Connection.connect()` or
`~psycopg.Connection.cursor()` are called without a `!row_factory`
parameter.
"""
# Implementation detail: make sure this is the tuple type itself, not an
# equivalent function, because the C code fast-paths on it.
return tuple
def dict_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[DictRow]":
"""Row factory to represent rows as dictionaries.
The dictionary keys are taken from the column names of the returned columns.
"""
names = _get_names(cursor)
if names is None:
return no_result
def dict_row_(values: Sequence[Any]) -> Dict[str, Any]:
return dict(zip(names, values))
return dict_row_
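# Illustrative sketch, not part of the original module: dict_row used as a
# cursor row factory. The connection string is an assumption made for the
# example.
def _example_dict_row() -> None:
    import psycopg

    with psycopg.connect("dbname=test") as conn:  # assumed DSN
        with conn.cursor(row_factory=dict_row) as cur:
            cur.execute("SELECT 1 AS a, 'x' AS b")
            print(cur.fetchone())  # {'a': 1, 'b': 'x'}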
def namedtuple_row(
cursor: "BaseCursor[Any, Any]",
) -> "RowMaker[NamedTuple]":
"""Row factory to represent rows as `~collections.namedtuple`.
The field names are taken from the column names of the returned columns,
with some mangling to deal with invalid names.
"""
res = cursor.pgresult
if not res:
return no_result
nfields = _get_nfields(res)
if nfields is None:
return no_result
nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields)))
return nt._make
@functools.lru_cache(512)
def _make_nt(enc: str, *names: bytes) -> Type[NamedTuple]:
snames = tuple(_as_python_identifier(n.decode(enc)) for n in names)
return namedtuple("Row", snames) # type: ignore[return-value]
def class_row(cls: Type[T]) -> BaseRowFactory[T]:
r"""Generate a row factory to represent rows as instances of the class `!cls`.
The class must support every output column name as a keyword parameter.
:param cls: The class to return for each row. It must support the fields
returned by the query as keyword arguments.
:rtype: `!Callable[[Cursor],` `RowMaker`\[~T]]
"""
def class_row_(cursor: "BaseCursor[Any, Any]") -> "RowMaker[T]":
names = _get_names(cursor)
if names is None:
return no_result
def class_row__(values: Sequence[Any]) -> T:
return cls(**dict(zip(names, values)))
return class_row__
return class_row_
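# Illustrative sketch, not part of the original module: class_row with a
# dataclass whose attribute names match the query columns; Person and the DSN
# are assumptions made for the example.
def _example_class_row() -> None:
    from dataclasses import dataclass

    import psycopg

    @dataclass
    class Person:
        id: int
        name: str

    with psycopg.connect("dbname=test") as conn:  # assumed DSN
        cur = conn.cursor(row_factory=class_row(Person))
        cur.execute("SELECT 1 AS id, 'Ada' AS name")
        print(cur.fetchone())  # Person(id=1, name='Ada')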
def args_row(func: Callable[..., T]) -> BaseRowFactory[T]:
"""Generate a row factory calling `!func` with positional parameters for every row.
:param func: The function to call for each row. It must support the fields
returned by the query as positional arguments.
"""
def args_row_(cur: "BaseCursor[Any, T]") -> "RowMaker[T]":
def args_row__(values: Sequence[Any]) -> T:
return func(*values)
return args_row__
return args_row_
def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]:
"""Generate a row factory calling `!func` with keyword parameters for every row.
:param func: The function to call for each row. It must support the fields
returned by the query as keyword arguments.
"""
def kwargs_row_(cursor: "BaseCursor[Any, T]") -> "RowMaker[T]":
names = _get_names(cursor)
if names is None:
return no_result
def kwargs_row__(values: Sequence[Any]) -> T:
return func(**dict(zip(names, values)))
return kwargs_row__
return kwargs_row_
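# Illustrative sketch, not part of the original module: kwargs_row passes each
# row's columns to the given callable as keyword arguments; make_point and the
# DSN are assumptions made for the example.
def _example_kwargs_row() -> None:
    import psycopg

    def make_point(x: int, y: int) -> Tuple[int, int]:
        return (x, y)

    with psycopg.connect("dbname=test") as conn:  # assumed DSN
        cur = conn.cursor(row_factory=kwargs_row(make_point))
        cur.execute("SELECT 1 AS x, 2 AS y")
        print(cur.fetchone())  # (1, 2)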
def no_result(values: Sequence[Any]) -> NoReturn:
"""A `RowMaker` that always fail.
It can be used as return value for a `RowFactory` called with no result.
Note that the `!RowFactory` *will* be called with no result, but the
resulting `!RowMaker` never should.
"""
raise e.InterfaceError("the cursor doesn't have a result")
def _get_names(cursor: "BaseCursor[Any, Any]") -> Optional[List[str]]:
res = cursor.pgresult
if not res:
return None
nfields = _get_nfields(res)
if nfields is None:
return None
enc = cursor._encoding
return [
res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr]
]
def _get_nfields(res: "PGresult") -> Optional[int]:
"""
Return the number of columns in a result if it returns tuples, else None.
Take into account the special case of results with zero columns.
"""
nfields = res.nfields
if (
res.status == TUPLES_OK
or res.status == SINGLE_TUPLE
# "describe" in named cursors
or (res.status == COMMAND_OK and nfields)
):
return nfields
else:
return None

View File

@ -0,0 +1,478 @@
"""
psycopg server-side cursor objects.
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, AsyncIterator, List, Iterable, Iterator
from typing import Optional, TypeVar, TYPE_CHECKING, overload
from warnings import warn
from . import pq
from . import sql
from . import errors as e
from .abc import ConnectionType, Query, Params, PQGen
from .rows import Row, RowFactory, AsyncRowFactory
from .cursor import BaseCursor, Cursor
from .generators import execute
from .cursor_async import AsyncCursor
if TYPE_CHECKING:
from .connection import Connection
from .connection_async import AsyncConnection
DEFAULT_ITERSIZE = 100
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK
IDLE = pq.TransactionStatus.IDLE
INTRANS = pq.TransactionStatus.INTRANS
class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
"""Mixin to add ServerCursor behaviour and implementation a BaseCursor."""
__slots__ = "_name _scrollable _withhold _described itersize _format".split()
def __init__(
self,
name: str,
scrollable: Optional[bool],
withhold: bool,
):
self._name = name
self._scrollable = scrollable
self._withhold = withhold
self._described = False
self.itersize: int = DEFAULT_ITERSIZE
self._format = TEXT
def __repr__(self) -> str:
# Insert the name as the second word
parts = super().__repr__().split(None, 1)
parts.insert(1, f"{self._name!r}")
return " ".join(parts)
@property
def name(self) -> str:
"""The name of the cursor."""
return self._name
@property
def scrollable(self) -> Optional[bool]:
"""
Whether the cursor is scrollable or not.
If `!None` leave the choice to the server. Use `!True` if you want to
use `scroll()` on the cursor.
"""
return self._scrollable
@property
def withhold(self) -> bool:
"""
Whether the cursor can be used after the creating transaction has committed.
"""
return self._withhold
@property
def rownumber(self) -> Optional[int]:
"""Index of the next row to fetch in the current result.
`!None` if there is no result to fetch.
"""
res = self.pgresult
# command_status is empty if the result comes from
# describe_portal, which means that we have just executed the DECLARE,
# so we can assume we are at the first row.
tuples = res and (res.status == TUPLES_OK or res.command_status == b"")
return self._pos if tuples else None
def _declare_gen(
self,
query: Query,
params: Optional[Params] = None,
binary: Optional[bool] = None,
) -> PQGen[None]:
"""Generator implementing `ServerCursor.execute()`."""
query = self._make_declare_statement(query)
# If the cursor is being reused, the previous one must be closed.
if self._described:
yield from self._close_gen()
self._described = False
yield from self._start_query(query)
pgq = self._convert_query(query, params)
self._execute_send(pgq, force_extended=True)
results = yield from execute(self._conn.pgconn)
if results[-1].status != COMMAND_OK:
self._raise_for_result(results[-1])
# Set the format, which will be used by describe and fetch operations
if binary is None:
self._format = self.format
else:
self._format = BINARY if binary else TEXT
# The above result only returned COMMAND_OK. Get the cursor shape
yield from self._describe_gen()
def _describe_gen(self) -> PQGen[None]:
self._pgconn.send_describe_portal(self._name.encode(self._encoding))
results = yield from execute(self._pgconn)
self._check_results(results)
self._results = results
self._select_current_result(0, format=self._format)
self._described = True
def _close_gen(self) -> PQGen[None]:
ts = self._conn.pgconn.transaction_status
# if the connection is not in a sane state, don't even try
if ts != IDLE and ts != INTRANS:
return
# If we are IDLE, a WITHOUT HOLD cursor will surely have gone already.
if not self._withhold and ts == IDLE:
return
# if we didn't declare the cursor ourselves we still have to close it
# but we must make sure it exists.
if not self._described:
query = sql.SQL(
"SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}"
).format(sql.Literal(self._name))
res = yield from self._conn._exec_command(query)
# res would be None only in pipeline mode, which is unsupported here.
assert res is not None
if res.ntuples == 0:
return
query = sql.SQL("CLOSE {}").format(sql.Identifier(self._name))
yield from self._conn._exec_command(query)
def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]:
if self.closed:
raise e.InterfaceError("the cursor is closed")
# If we are stealing the cursor, make sure we know its shape
if not self._described:
yield from self._start_query()
yield from self._describe_gen()
query = sql.SQL("FETCH FORWARD {} FROM {}").format(
sql.SQL("ALL") if num is None else sql.Literal(num),
sql.Identifier(self._name),
)
res = yield from self._conn._exec_command(query, result_format=self._format)
# res would be None only in pipeline mode, which is unsupported here.
assert res is not None
self.pgresult = res
self._tx.set_pgresult(res, set_loaders=False)
return self._tx.load_rows(0, res.ntuples, self._make_row)
def _scroll_gen(self, value: int, mode: str) -> PQGen[None]:
if mode not in ("relative", "absolute"):
raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
query = sql.SQL("MOVE{} {} FROM {}").format(
sql.SQL(" ABSOLUTE" if mode == "absolute" else ""),
sql.Literal(value),
sql.Identifier(self._name),
)
yield from self._conn._exec_command(query)
def _make_declare_statement(self, query: Query) -> sql.Composed:
if isinstance(query, bytes):
query = query.decode(self._encoding)
if not isinstance(query, sql.Composable):
query = sql.SQL(query)
parts = [
sql.SQL("DECLARE"),
sql.Identifier(self._name),
]
if self._scrollable is not None:
parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL"))
parts.append(sql.SQL("CURSOR"))
if self._withhold:
parts.append(sql.SQL("WITH HOLD"))
parts.append(sql.SQL("FOR"))
parts.append(query)
return sql.SQL(" ").join(parts)
class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
__module__ = "psycopg"
__slots__ = ()
_Self = TypeVar("_Self", bound="ServerCursor[Any]")
@overload
def __init__(
self: "ServerCursor[Row]",
connection: "Connection[Row]",
name: str,
*,
scrollable: Optional[bool] = None,
withhold: bool = False,
):
...
@overload
def __init__(
self: "ServerCursor[Row]",
connection: "Connection[Any]",
name: str,
*,
row_factory: RowFactory[Row],
scrollable: Optional[bool] = None,
withhold: bool = False,
):
...
def __init__(
self,
connection: "Connection[Any]",
name: str,
*,
row_factory: Optional[RowFactory[Row]] = None,
scrollable: Optional[bool] = None,
withhold: bool = False,
):
Cursor.__init__(
self, connection, row_factory=row_factory or connection.row_factory
)
ServerCursorMixin.__init__(self, name, scrollable, withhold)
def __del__(self) -> None:
if not self.closed:
warn(
f"the server-side cursor {self} was deleted while still open."
" Please use 'with' or '.close()' to close the cursor properly",
ResourceWarning,
)
def close(self) -> None:
"""
Close the current cursor and free associated resources.
"""
with self._conn.lock:
if self.closed:
return
if not self._conn.closed:
self._conn.wait(self._close_gen())
super().close()
def execute(
self: _Self,
query: Query,
params: Optional[Params] = None,
*,
binary: Optional[bool] = None,
**kwargs: Any,
) -> _Self:
"""
Open a cursor to execute a query to the database.
"""
if kwargs:
raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
if self._pgconn.pipeline_status:
raise e.NotSupportedError(
"server-side cursors not supported in pipeline mode"
)
try:
with self._conn.lock:
self._conn.wait(self._declare_gen(query, params, binary))
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
return self
def executemany(
self,
query: Query,
params_seq: Iterable[Params],
*,
returning: bool = True,
) -> None:
"""Method not implemented for server-side cursors."""
raise e.NotSupportedError("executemany not supported on server-side cursors")
def fetchone(self) -> Optional[Row]:
with self._conn.lock:
recs = self._conn.wait(self._fetch_gen(1))
if recs:
self._pos += 1
return recs[0]
else:
return None
def fetchmany(self, size: int = 0) -> List[Row]:
if not size:
size = self.arraysize
with self._conn.lock:
recs = self._conn.wait(self._fetch_gen(size))
self._pos += len(recs)
return recs
def fetchall(self) -> List[Row]:
with self._conn.lock:
recs = self._conn.wait(self._fetch_gen(None))
self._pos += len(recs)
return recs
def __iter__(self) -> Iterator[Row]:
while True:
with self._conn.lock:
recs = self._conn.wait(self._fetch_gen(self.itersize))
for rec in recs:
self._pos += 1
yield rec
if len(recs) < self.itersize:
break
def scroll(self, value: int, mode: str = "relative") -> None:
with self._conn.lock:
self._conn.wait(self._scroll_gen(value, mode))
# Postgres doesn't have a reliable way to report a cursor out of bound
if mode == "relative":
self._pos += value
else:
self._pos = value
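# Illustrative sketch, not part of the original module: passing a name to
# Connection.cursor() creates a ServerCursor, which fetches rows in batches of
# `itersize` while iterating. The DSN is an assumption made for the example.
def _example_server_cursor() -> None:
    import psycopg

    with psycopg.connect("dbname=test") as conn:  # assumed DSN
        with conn.cursor(name="big_scan") as cur:
            cur.itersize = 500
            cur.execute("SELECT generate_series(1, 10000)")
            for (n,) in cur:
                pass  # rows are fetched 500 at a time behind the scenes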
class AsyncServerCursor(
ServerCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
):
__module__ = "psycopg"
__slots__ = ()
_Self = TypeVar("_Self", bound="AsyncServerCursor[Any]")
@overload
def __init__(
self: "AsyncServerCursor[Row]",
connection: "AsyncConnection[Row]",
name: str,
*,
scrollable: Optional[bool] = None,
withhold: bool = False,
):
...
@overload
def __init__(
self: "AsyncServerCursor[Row]",
connection: "AsyncConnection[Any]",
name: str,
*,
row_factory: AsyncRowFactory[Row],
scrollable: Optional[bool] = None,
withhold: bool = False,
):
...
def __init__(
self,
connection: "AsyncConnection[Any]",
name: str,
*,
row_factory: Optional[AsyncRowFactory[Row]] = None,
scrollable: Optional[bool] = None,
withhold: bool = False,
):
AsyncCursor.__init__(
self, connection, row_factory=row_factory or connection.row_factory
)
ServerCursorMixin.__init__(self, name, scrollable, withhold)
def __del__(self) -> None:
if not self.closed:
warn(
f"the server-side cursor {self} was deleted while still open."
" Please use 'with' or '.close()' to close the cursor properly",
ResourceWarning,
)
async def close(self) -> None:
async with self._conn.lock:
if self.closed:
return
if not self._conn.closed:
await self._conn.wait(self._close_gen())
await super().close()
async def execute(
self: _Self,
query: Query,
params: Optional[Params] = None,
*,
binary: Optional[bool] = None,
**kwargs: Any,
) -> _Self:
if kwargs:
raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
if self._pgconn.pipeline_status:
raise e.NotSupportedError(
"server-side cursors not supported in pipeline mode"
)
try:
async with self._conn.lock:
await self._conn.wait(self._declare_gen(query, params, binary))
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
return self
async def executemany(
self,
query: Query,
params_seq: Iterable[Params],
*,
returning: bool = True,
) -> None:
raise e.NotSupportedError("executemany not supported on server-side cursors")
async def fetchone(self) -> Optional[Row]:
async with self._conn.lock:
recs = await self._conn.wait(self._fetch_gen(1))
if recs:
self._pos += 1
return recs[0]
else:
return None
async def fetchmany(self, size: int = 0) -> List[Row]:
if not size:
size = self.arraysize
async with self._conn.lock:
recs = await self._conn.wait(self._fetch_gen(size))
self._pos += len(recs)
return recs
async def fetchall(self) -> List[Row]:
async with self._conn.lock:
recs = await self._conn.wait(self._fetch_gen(None))
self._pos += len(recs)
return recs
async def __aiter__(self) -> AsyncIterator[Row]:
while True:
async with self._conn.lock:
recs = await self._conn.wait(self._fetch_gen(self.itersize))
for rec in recs:
self._pos += 1
yield rec
if len(recs) < self.itersize:
break
async def scroll(self, value: int, mode: str = "relative") -> None:
async with self._conn.lock:
await self._conn.wait(self._scroll_gen(value, mode))

View File

@ -0,0 +1,467 @@
"""
SQL composition utility module
"""
# Copyright (C) 2020 The Psycopg Team
import codecs
import string
from abc import ABC, abstractmethod
from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
from .pq import Escaping
from .abc import AdaptContext
from .adapt import Transformer, PyFormat
from ._compat import LiteralString
from ._encodings import conn_encoding
def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
"""
Adapt a Python object to a quoted SQL string.
Use this function only if you absolutely want to convert a Python string to
an SQL quoted literal to use, e.g., to generate batch SQL, and you won't have
a connection available when you will need to use it.
This function is relatively inefficient, because it doesn't cache the
adaptation rules. If you pass a `!context` you can customise the adaptation
rules used; otherwise only global rules are used.
"""
return Literal(obj).as_string(context)
class Composable(ABC):
"""
Abstract base class for objects that can be used to compose an SQL string.
`!Composable` objects can be passed directly to
`~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
`~psycopg.Cursor.copy()` in place of the query string.
`!Composable` objects can be joined using the ``+`` operator: the result
will be a `Composed` instance containing the objects joined. The operator
``*`` is also supported with an integer argument: the result is a
`!Composed` instance containing the left argument repeated as many times as
requested.
"""
def __init__(self, obj: Any):
self._obj = obj
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._obj!r})"
@abstractmethod
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
"""
Return the value of the object as bytes.
:param context: the context to evaluate the object into.
:type context: `connection` or `cursor`
The method is automatically invoked by `~psycopg.Cursor.execute()`,
`~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a
`!Composable` is passed instead of the query string.
"""
raise NotImplementedError
def as_string(self, context: Optional[AdaptContext]) -> str:
"""
Return the value of the object as string.
:param context: the context to evaluate the string into.
:type context: `connection` or `cursor`
"""
conn = context.connection if context else None
enc = conn_encoding(conn)
b = self.as_bytes(context)
if isinstance(b, bytes):
return b.decode(enc)
else:
# buffer object
return codecs.lookup(enc).decode(b)[0]
def __add__(self, other: "Composable") -> "Composed":
if isinstance(other, Composed):
return Composed([self]) + other
if isinstance(other, Composable):
return Composed([self]) + Composed([other])
else:
return NotImplemented
def __mul__(self, n: int) -> "Composed":
return Composed([self] * n)
def __eq__(self, other: Any) -> bool:
return type(self) is type(other) and self._obj == other._obj
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
class Composed(Composable):
"""
A `Composable` object made of a sequence of `!Composable`.
The object is usually created using `!Composable` operators and methods.
However it is possible to create a `!Composed` directly specifying a
sequence of objects as arguments: if they are not `!Composable` they will
be wrapped in a `Literal`.
Example::
>>> comp = sql.Composed(
... [sql.SQL("INSERT INTO "), sql.Identifier("table")])
>>> print(comp.as_string(conn))
INSERT INTO "table"
`!Composed` objects are iterable (so they can be used in `SQL.join` for
instance).
"""
_obj: List[Composable]
def __init__(self, seq: Sequence[Any]):
seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq]
super().__init__(seq)
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
return b"".join(obj.as_bytes(context) for obj in self._obj)
def __iter__(self) -> Iterator[Composable]:
return iter(self._obj)
def __add__(self, other: Composable) -> "Composed":
if isinstance(other, Composed):
return Composed(self._obj + other._obj)
if isinstance(other, Composable):
return Composed(self._obj + [other])
else:
return NotImplemented
def join(self, joiner: Union["SQL", LiteralString]) -> "Composed":
"""
Return a new `!Composed` interposing the `!joiner` with the `!Composed` items.
The `!joiner` must be a `SQL` or a string which will be interpreted as
an `SQL`.
Example::
>>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed
>>> print(fields.join(', ').as_string(conn))
"foo", "bar"
"""
if isinstance(joiner, str):
joiner = SQL(joiner)
elif not isinstance(joiner, SQL):
raise TypeError(
"Composed.join() argument must be strings or SQL,"
f" got {joiner!r} instead"
)
return joiner.join(self._obj)
class SQL(Composable):
"""
A `Composable` representing a snippet of SQL statement.
`!SQL` exposes `join()` and `format()` methods useful to create a template
where to merge variable parts of a query (for instance field or table
names).
The `!obj` string doesn't undergo any form of escaping, so it is not
suitable to represent variable identifiers or values: you should only use
it to pass constant strings representing templates or snippets of SQL
statements; use other objects such as `Identifier` or `Literal` to
represent variable parts.
Example::
>>> query = sql.SQL("SELECT {0} FROM {1}").format(
... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
... sql.Identifier('table'))
>>> print(query.as_string(conn))
SELECT "foo", "bar" FROM "table"
"""
_obj: LiteralString
_formatter = string.Formatter()
def __init__(self, obj: LiteralString):
super().__init__(obj)
if not isinstance(obj, str):
raise TypeError(f"SQL values must be strings, got {obj!r} instead")
def as_string(self, context: Optional[AdaptContext]) -> str:
return self._obj
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
enc = "utf-8"
if context:
enc = conn_encoding(context.connection)
return self._obj.encode(enc)
def format(self, *args: Any, **kwargs: Any) -> Composed:
"""
Merge `Composable` objects into a template.
:param args: parameters to replace to numbered (``{0}``, ``{1}``) or
auto-numbered (``{}``) placeholders
:param kwargs: parameters to replace to named (``{name}``) placeholders
:return: the union of the `!SQL` string with placeholders replaced
:rtype: `Composed`
The method is similar to the Python `str.format()` method: the string
template supports auto-numbered (``{}``), numbered (``{0}``,
``{1}``...), and named placeholders (``{name}``), with positional
arguments replacing the numbered placeholders and keywords replacing
the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``)
are not supported.
If a `!Composable` object is passed to the template it will be merged
according to its `as_string()` method. If any other Python object is
passed, it will be wrapped in a `Literal` object and so escaped
according to SQL rules.
Example::
>>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s")
... .format(sql.Identifier('people'), sql.Identifier('id'))
... .as_string(conn))
SELECT * FROM "people" WHERE "id" = %s
>>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}")
... .format(tbl=sql.Identifier('people'), name="O'Rourke"))
... .as_string(conn))
SELECT * FROM "people" WHERE name = 'O''Rourke'
"""
rv: List[Composable] = []
autonum: Optional[int] = 0
# TODO: this is probably not the right way to whitelist pre
# pyre complains. Will wait for mypy to complain too to fix.
pre: LiteralString
for pre, name, spec, conv in self._formatter.parse(self._obj):
if spec:
raise ValueError("no format specification supported by SQL")
if conv:
raise ValueError("no format conversion supported by SQL")
if pre:
rv.append(SQL(pre))
if name is None:
continue
if name.isdigit():
if autonum:
raise ValueError(
"cannot switch from automatic field numbering to manual"
)
rv.append(args[int(name)])
autonum = None
elif not name:
if autonum is None:
raise ValueError(
"cannot switch from manual field numbering to automatic"
)
rv.append(args[autonum])
autonum += 1
else:
rv.append(kwargs[name])
return Composed(rv)
def join(self, seq: Iterable[Composable]) -> Composed:
"""
Join a sequence of `Composable`.
:param seq: the elements to join.
:type seq: iterable of `!Composable`
Use the `!SQL` object's string to separate the elements in `!seq`.
Note that `Composed` objects are iterable too, so they can be used as
argument for this method.
Example::
>>> snip = sql.SQL(', ').join(
... sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
>>> print(snip.as_string(conn))
"foo", "bar", "baz"
"""
rv = []
it = iter(seq)
try:
rv.append(next(it))
except StopIteration:
pass
else:
for i in it:
rv.append(self)
rv.append(i)
return Composed(rv)
class Identifier(Composable):
"""
A `Composable` representing an SQL identifier or a dot-separated sequence.
Identifiers usually represent names of database objects, such as tables or
fields. PostgreSQL identifiers follow `different rules`__ than SQL string
literals for escaping (e.g. they use double quotes instead of single).
.. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \
SQL-SYNTAX-IDENTIFIERS
Example::
>>> t1 = sql.Identifier("foo")
>>> t2 = sql.Identifier("ba'r")
>>> t3 = sql.Identifier('ba"z')
>>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
"foo", "ba'r", "ba""z"
Multiple strings can be passed to the object to represent a qualified name,
i.e. a dot-separated sequence of identifiers.
Example::
>>> query = sql.SQL("SELECT {} FROM {}").format(
... sql.Identifier("table", "field"),
... sql.Identifier("schema", "table"))
>>> print(query.as_string(conn))
SELECT "table"."field" FROM "schema"."table"
"""
_obj: Sequence[str]
def __init__(self, *strings: str):
# init super() now to make the __repr__ not explode in case of error
super().__init__(strings)
if not strings:
raise TypeError("Identifier cannot be empty")
for s in strings:
if not isinstance(s, str):
raise TypeError(
f"SQL identifier parts must be strings, got {s!r} instead"
)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
conn = context.connection if context else None
if not conn:
raise ValueError("a connection is necessary for Identifier")
esc = Escaping(conn.pgconn)
enc = conn_encoding(conn)
escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
return b".".join(escs)
class Literal(Composable):
"""
A `Composable` representing an SQL value to include in a query.
Usually you will want to include placeholders in the query and pass values
as `~cursor.execute()` arguments. If however you really really need to
include a literal value in the query you can use this object.
The string returned by `!as_string()` follows the normal :ref:`adaptation
rules <types-adaptation>` for Python objects.
Example::
>>> s1 = sql.Literal("fo'o")
>>> s2 = sql.Literal(42)
>>> s3 = sql.Literal(date(2000, 1, 1))
>>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
'fo''o', 42, '2000-01-01'::date
"""
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
tx = Transformer.from_context(context)
return tx.as_literal(self._obj)
class Placeholder(Composable):
"""A `Composable` representing a placeholder for query parameters.
If the name is specified, generate a named placeholder (e.g. ``%(name)s``,
``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``,
``%b``).
The object is useful to generate SQL queries with a variable number of
arguments.
Examples::
>>> names = ['foo', 'bar', 'baz']
>>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
... sql.SQL(', ').join(map(sql.Identifier, names)),
... sql.SQL(', ').join(sql.Placeholder() * len(names)))
>>> print(q1.as_string(conn))
INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s)
>>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
... sql.SQL(', ').join(map(sql.Identifier, names)),
... sql.SQL(', ').join(map(sql.Placeholder, names)))
>>> print(q2.as_string(conn))
INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s)
"""
def __init__(self, name: str = "", format: Union[str, PyFormat] = PyFormat.AUTO):
super().__init__(name)
if not isinstance(name, str):
raise TypeError(f"expected string as name, got {name!r}")
if ")" in name:
raise ValueError(f"invalid name: {name!r}")
if type(format) is str:
format = PyFormat(format)
if not isinstance(format, PyFormat):
raise TypeError(
f"expected PyFormat as format, got {type(format).__name__!r}"
)
self._format: PyFormat = format
def __repr__(self) -> str:
parts = []
if self._obj:
parts.append(repr(self._obj))
if self._format is not PyFormat.AUTO:
parts.append(f"format={self._format.name}")
return f"{self.__class__.__name__}({', '.join(parts)})"
def as_string(self, context: Optional[AdaptContext]) -> str:
code = self._format.value
return f"%({self._obj}){code}" if self._obj else f"%{code}"
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
conn = context.connection if context else None
enc = conn_encoding(conn)
return self.as_string(context).encode(enc)
# Literals
NULL = SQL("NULL")
DEFAULT = SQL("DEFAULT")
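# Illustrative sketch, not part of the original module: composing a query from
# the objects defined above. `conn` is assumed to be an open psycopg
# connection, used only to render the final string.
def _example_compose(conn: Any) -> str:
    query = SQL("SELECT {fields} FROM {tbl} WHERE id = {ph}").format(
        fields=SQL(", ").join(map(Identifier, ["id", "name"])),
        tbl=Identifier("people"),
        ph=Placeholder(),
    )
    # -> 'SELECT "id", "name" FROM "people" WHERE id = %s'
    return query.as_string(conn)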

View File

@ -0,0 +1,291 @@
"""
Transaction context managers returned by Connection.transaction()
"""
# Copyright (C) 2020 The Psycopg Team
import logging
from types import TracebackType
from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING
from . import pq
from . import sql
from . import errors as e
from .abc import ConnectionType, PQGen
from .pq.misc import connection_summary
if TYPE_CHECKING:
from typing import Any
from .connection import Connection
from .connection_async import AsyncConnection
IDLE = pq.TransactionStatus.IDLE
OK = pq.ConnStatus.OK
logger = logging.getLogger(__name__)
class Rollback(Exception):
"""
Exit the current `Transaction` context immediately and roll back any changes
made within this context.
If a transaction context is specified in the constructor, roll back
enclosing transaction contexts up to and including the one specified.
"""
__module__ = "psycopg"
def __init__(
self,
transaction: Union["Transaction", "AsyncTransaction", None] = None,
):
self.transaction = transaction
def __repr__(self) -> str:
return f"{self.__class__.__qualname__}({self.transaction!r})"
class OutOfOrderTransactionNesting(e.ProgrammingError):
"""Out-of-order transaction nesting detected"""
class BaseTransaction(Generic[ConnectionType]):
def __init__(
self,
connection: ConnectionType,
savepoint_name: Optional[str] = None,
force_rollback: bool = False,
):
self._conn = connection
self.pgconn = self._conn.pgconn
self._savepoint_name = savepoint_name or ""
self.force_rollback = force_rollback
self._entered = self._exited = False
self._outer_transaction = False
self._stack_index = -1
@property
def savepoint_name(self) -> Optional[str]:
"""
The name of the savepoint; `!None` if handling the main transaction.
"""
# Yes, it may change on __enter__. No, I don't care, because the
# un-entered state is outside the public interface.
return self._savepoint_name
def __repr__(self) -> str:
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
info = connection_summary(self.pgconn)
if not self._entered:
status = "inactive"
elif not self._exited:
status = "active"
else:
status = "terminated"
sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"
def _enter_gen(self) -> PQGen[None]:
if self._entered:
raise TypeError("transaction blocks can be used only once")
self._entered = True
self._push_savepoint()
for command in self._get_enter_commands():
yield from self._conn._exec_command(command)
def _exit_gen(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> PQGen[bool]:
if not exc_val and not self.force_rollback:
yield from self._commit_gen()
return False
else:
# try to rollback, but if there are problems (connection in a bad
# state) just warn without clobbering the exception bubbling up.
try:
return (yield from self._rollback_gen(exc_val))
except OutOfOrderTransactionNesting:
# Clobber an exception raised in the block with the exception
# caused by the out-of-order transaction detected, to keep the
# behaviour consistent with _commit_gen and to make sure the
# user fixes this condition, which is unrelated to operational
# errors that might arise in the block.
raise
except Exception as exc2:
logger.warning("error ignored in rollback of %s: %s", self, exc2)
return False
def _commit_gen(self) -> PQGen[None]:
ex = self._pop_savepoint("commit")
self._exited = True
if ex:
raise ex
for command in self._get_commit_commands():
yield from self._conn._exec_command(command)
def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
if isinstance(exc_val, Rollback):
logger.debug(f"{self._conn}: Explicit rollback from: ", exc_info=True)
ex = self._pop_savepoint("rollback")
self._exited = True
if ex:
raise ex
for command in self._get_rollback_commands():
yield from self._conn._exec_command(command)
if isinstance(exc_val, Rollback):
if not exc_val.transaction or exc_val.transaction is self:
return True # Swallow the exception
return False
def _get_enter_commands(self) -> Iterator[bytes]:
if self._outer_transaction:
yield self._conn._get_tx_start_command()
if self._savepoint_name:
yield (
sql.SQL("SAVEPOINT {}")
.format(sql.Identifier(self._savepoint_name))
.as_bytes(self._conn)
)
def _get_commit_commands(self) -> Iterator[bytes]:
if self._savepoint_name and not self._outer_transaction:
yield (
sql.SQL("RELEASE {}")
.format(sql.Identifier(self._savepoint_name))
.as_bytes(self._conn)
)
if self._outer_transaction:
assert not self._conn._num_transactions
yield b"COMMIT"
def _get_rollback_commands(self) -> Iterator[bytes]:
if self._savepoint_name and not self._outer_transaction:
yield (
sql.SQL("ROLLBACK TO {n}")
.format(n=sql.Identifier(self._savepoint_name))
.as_bytes(self._conn)
)
yield (
sql.SQL("RELEASE {n}")
.format(n=sql.Identifier(self._savepoint_name))
.as_bytes(self._conn)
)
if self._outer_transaction:
assert not self._conn._num_transactions
yield b"ROLLBACK"
# Also clear the prepared statements cache.
if self._conn._prepared.clear():
yield from self._conn._prepared.get_maintenance_commands()
def _push_savepoint(self) -> None:
"""
Push the transaction on the connection transactions stack.
Also set the internal state of the object and verify consistency.
"""
self._outer_transaction = self.pgconn.transaction_status == IDLE
if self._outer_transaction:
# outer transaction: if no name it's only a begin, else
# there will be an additional savepoint
assert not self._conn._num_transactions
else:
# inner transaction: it always has a name
if not self._savepoint_name:
self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}"
self._stack_index = self._conn._num_transactions
self._conn._num_transactions += 1
def _pop_savepoint(self, action: str) -> Optional[Exception]:
"""
Pop the transaction from the connection transactions stack.
Also verify the state consistency.
"""
self._conn._num_transactions -= 1
if self._conn._num_transactions == self._stack_index:
return None
return OutOfOrderTransactionNesting(
f"transaction {action} at the wrong nesting level: {self}"
)
class Transaction(BaseTransaction["Connection[Any]"]):
"""
Returned by `Connection.transaction()` to handle a transaction block.
"""
__module__ = "psycopg"
_Self = TypeVar("_Self", bound="Transaction")
@property
def connection(self) -> "Connection[Any]":
"""The connection the object is managing."""
return self._conn
def __enter__(self: _Self) -> _Self:
with self._conn.lock:
self._conn.wait(self._enter_gen())
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
if self.pgconn.status == OK:
with self._conn.lock:
return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
else:
return False
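# Illustrative sketch, not part of the original module: raising Rollback inside
# a nested transaction block undoes that block's changes without propagating an
# exception to the caller. The DSN and table are assumptions for the example.
def _example_rollback() -> None:
    import psycopg

    with psycopg.connect("dbname=test") as conn:  # assumed DSN
        with conn.transaction():
            conn.execute("CREATE TEMP TABLE t (x int)")
            with conn.transaction():
                conn.execute("INSERT INTO t VALUES (1)")
                raise Rollback()  # only the inner INSERT is rolled back
            # the outer transaction continues and commits on exit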
class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
"""
Returned by `AsyncConnection.transaction()` to handle a transaction block.
"""
__module__ = "psycopg"
_Self = TypeVar("_Self", bound="AsyncTransaction")
@property
def connection(self) -> "AsyncConnection[Any]":
return self._conn
async def __aenter__(self: _Self) -> _Self:
async with self._conn.lock:
await self._conn.wait(self._enter_gen())
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
if self.pgconn.status == OK:
async with self._conn.lock:
return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
else:
return False

View File

@ -0,0 +1,11 @@
"""
psycopg types package
"""
# Copyright (C) 2020 The Psycopg Team
from .. import _typeinfo
# Exposed here
TypeInfo = _typeinfo.TypeInfo
TypesRegistry = _typeinfo.TypesRegistry

View File

@ -0,0 +1,469 @@
"""
Adapters for arrays
"""
# Copyright (C) 2020 The Psycopg Team
import re
import struct
from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type
from .. import pq
from .. import errors as e
from .. import postgres
from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, Loader, Transformer
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
from .._compat import cache, prod
from .._struct import pack_len, unpack_len
from .._cmodule import _psycopg
from ..postgres import TEXT_OID, INVALID_OID
from .._typeinfo import TypeInfo
_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid
_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from)
_struct_dim = struct.Struct("!II") # dim, lower bound
_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from)
TEXT_ARRAY_OID = postgres.types["text"].array_oid
PY_TEXT = PyFormat.TEXT
PQ_BINARY = pq.Format.BINARY
class BaseListDumper(RecursiveDumper):
element_oid = 0
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
if cls is NoneType:
cls = list
super().__init__(cls, context)
self.sub_dumper: Optional[Dumper] = None
if self.element_oid and context:
sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format)
self.sub_dumper = sdclass(NoneType, context)
def _find_list_element(self, L: List[Any], format: PyFormat) -> Any:
"""
Find the first non-null element of a possibly nested list
"""
items = list(self._flatiter(L, set()))
types = {type(item): item for item in items}
if not types:
return None
if len(types) == 1:
t, v = types.popitem()
else:
# More than one type in the list. It might still be good, as long
# as they dump with the same oid (e.g. IPv4Network, IPv6Network).
dumpers = [self._tx.get_dumper(item, format) for item in types.values()]
oids = set(d.oid for d in dumpers)
if len(oids) == 1:
t, v = types.popitem()
else:
raise e.DataError(
"cannot dump lists of mixed types;"
f" got: {', '.join(sorted(t.__name__ for t in types))}"
)
# Checking for precise type. If the type is a subclass (e.g. Int4)
# we assume the user knows what type they are passing.
if t is not int:
return v
# If we got an int, let's see what the biggest one is, in order to
# choose the smallest OID and allow Postgres to do the right cast.
imax: int = max(items)
imin: int = min(items)
if imin >= 0:
return imax
else:
return max(imax, -imin - 1)
def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
if id(L) in seen:
raise e.DataError("cannot dump a recursive list")
seen.add(id(L))
for item in L:
if type(item) is list:
yield from self._flatiter(item, seen)
elif item is not None:
yield item
return None
def _get_base_type_info(self, base_oid: int) -> TypeInfo:
"""
Return info about the base type.
Return text info as fallback.
"""
if base_oid:
info = self._tx.adapters.types.get(base_oid)
if info:
return info
return self._tx.adapters.types["text"]
class ListDumper(BaseListDumper):
delimiter = b","
def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
if self.oid:
return self.cls
item = self._find_list_element(obj, format)
if item is None:
return self.cls
sd = self._tx.get_dumper(item, format)
return (self.cls, sd.get_key(item, format))
def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
# If we have an oid we don't need to upgrade
if self.oid:
return self
item = self._find_list_element(obj, format)
if item is None:
# Empty lists can only be dumped as text if the type is unknown.
return self
sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
dumper = type(self)(self.cls, self._tx)
dumper.sub_dumper = sd
# We consider an array of unknowns as unknown, so we can dump empty
# lists or lists containing only None elements.
if sd.oid != INVALID_OID:
info = self._get_base_type_info(sd.oid)
dumper.oid = info.array_oid or TEXT_ARRAY_OID
dumper.delimiter = info.delimiter.encode()
else:
dumper.oid = INVALID_OID
return dumper
# Double quotes and backslashes embedded in element values will be
# backslash-escaped.
_re_esc = re.compile(rb'(["\\])')
def dump(self, obj: List[Any]) -> bytes:
tokens: List[Buffer] = []
needs_quotes = _get_needs_quotes_regexp(self.delimiter).search
def dump_list(obj: List[Any]) -> None:
if not obj:
tokens.append(b"{}")
return
tokens.append(b"{")
for item in obj:
if isinstance(item, list):
dump_list(item)
elif item is not None:
ad = self._dump_item(item)
if needs_quotes(ad):
if not isinstance(ad, bytes):
ad = bytes(ad)
ad = b'"' + self._re_esc.sub(rb"\\\1", ad) + b'"'
tokens.append(ad)
else:
tokens.append(b"NULL")
tokens.append(self.delimiter)
tokens[-1] = b"}"
dump_list(obj)
return b"".join(tokens)
def _dump_item(self, item: Any) -> Buffer:
if self.sub_dumper:
return self.sub_dumper.dump(item)
else:
return self._tx.get_dumper(item, PY_TEXT).dump(item)
@cache
def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]:
"""Return a regexp to recognise when a value needs quotes
from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
The array output routine will put double quotes around element values if
they are empty strings, contain curly braces, delimiter characters,
double quotes, backslashes, or white space, or match the word NULL.
"""
return re.compile(
rb"""(?xi)
^$ # the empty string
| ["{}%s\\\s] # or a char to escape
| ^null$ # or the word NULL
"""
% delimiter
)
class ListBinaryDumper(BaseListDumper):
format = pq.Format.BINARY
def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
if self.oid:
return self.cls
item = self._find_list_element(obj, format)
if item is None:
return (self.cls,)
sd = self._tx.get_dumper(item, format)
return (self.cls, sd.get_key(item, format))
def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
# If we have an oid we don't need to upgrade
if self.oid:
return self
item = self._find_list_element(obj, format)
if item is None:
return ListDumper(self.cls, self._tx)
sd = self._tx.get_dumper(item, format.from_pq(self.format))
dumper = type(self)(self.cls, self._tx)
dumper.sub_dumper = sd
info = self._get_base_type_info(sd.oid)
dumper.oid = info.array_oid or TEXT_ARRAY_OID
return dumper
def dump(self, obj: List[Any]) -> bytes:
# Postgres won't take unknown for element oid: fall back on text
sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID
if not obj:
return _pack_head(0, 0, sub_oid)
data: List[Buffer] = [b"", b""] # placeholders to avoid a resize
dims: List[int] = []
hasnull = 0
def calc_dims(L: List[Any]) -> None:
if isinstance(L, self.cls):
if not L:
raise e.DataError("lists cannot contain empty lists")
dims.append(len(L))
calc_dims(L[0])
calc_dims(obj)
def dump_list(L: List[Any], dim: int) -> None:
nonlocal hasnull
if len(L) != dims[dim]:
raise e.DataError("nested lists have inconsistent lengths")
if dim == len(dims) - 1:
for item in L:
if item is not None:
# If we get here, the sub_dumper must have been set
ad = self.sub_dumper.dump(item) # type: ignore[union-attr]
data.append(pack_len(len(ad)))
data.append(ad)
else:
hasnull = 1
data.append(b"\xff\xff\xff\xff")
else:
for item in L:
if not isinstance(item, self.cls):
raise e.DataError("nested lists have inconsistent depths")
dump_list(item, dim + 1) # type: ignore
dump_list(obj, 0)
data[0] = _pack_head(len(dims), hasnull, sub_oid)
data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
return b"".join(data)
class ArrayLoader(RecursiveLoader):
delimiter = b","
base_oid: int
def load(self, data: Buffer) -> List[Any]:
loader = self._tx.get_loader(self.base_oid, self.format)
return _load_text(data, loader, self.delimiter)
class ArrayBinaryLoader(RecursiveLoader):
format = pq.Format.BINARY
def load(self, data: Buffer) -> List[Any]:
return _load_binary(data, self._tx)
def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
if not info.array_oid:
raise ValueError(f"the type info {info} doesn't describe an array")
adapters = context.adapters if context else postgres.adapters
loader = _make_loader(info.name, info.oid, info.delimiter)
adapters.register_loader(info.array_oid, loader)
# No need to make a new loader because the binary datum has all the info.
loader = getattr(_psycopg, "ArrayBinaryLoader", ArrayBinaryLoader)
adapters.register_loader(info.array_oid, loader)
dumper = _make_dumper(info.name, info.oid, info.array_oid, info.delimiter)
adapters.register_dumper(None, dumper)
dumper = _make_binary_dumper(info.name, info.oid, info.array_oid)
adapters.register_dumper(None, dumper)
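# Usage sketch (illustrative): register_array() is normally reached through
# TypeInfo.register() once a type's info, including its array_oid, has been
# fetched. Calling it directly would look like this, assuming an open
# connection `conn` and an extension type such as hstore:
#
#     from psycopg.types import TypeInfo
#     info = TypeInfo.fetch(conn, "hstore")
#     register_array(info, conn)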
# Cache all dynamically-generated types to avoid leaks in case the types
# cannot be GC'd.
@cache
def _make_loader(name: str, oid: int, delimiter: str) -> Type[Loader]:
# Note: caching this function is really needed because, if the C extension
# is available, the resulting type cannot be GC'd, so calling
# register_array() in a loop results in a leak. See #647.
base = getattr(_psycopg, "ArrayLoader", ArrayLoader)
attribs = {"base_oid": oid, "delimiter": delimiter.encode()}
return type(f"{name.title()}{base.__name__}", (base,), attribs)
@cache
def _make_dumper(
name: str, oid: int, array_oid: int, delimiter: str
) -> Type[BaseListDumper]:
attribs = {"oid": array_oid, "element_oid": oid, "delimiter": delimiter.encode()}
return type(f"{name.title()}ListDumper", (ListDumper,), attribs)
@cache
def _make_binary_dumper(name: str, oid: int, array_oid: int) -> Type[BaseListDumper]:
attribs = {"oid": array_oid, "element_oid": oid}
return type(f"{name.title()}ListBinaryDumper", (ListBinaryDumper,), attribs)
def register_default_adapters(context: AdaptContext) -> None:
# The text dumper is more flexible as it can handle lists of mixed type,
# so register it later.
context.adapters.register_dumper(list, ListBinaryDumper)
context.adapters.register_dumper(list, ListDumper)
def register_all_arrays(context: AdaptContext) -> None:
"""
Associate the array oid of all the types in Loader.globals.
This function is designed to be called once at import time, after having
registered all the base loaders.
"""
for t in context.adapters.types:
if t.array_oid:
t.register(context)
def _load_text(
data: Buffer,
loader: Loader,
delimiter: bytes = b",",
__re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"),
) -> List[Any]:
rv = None
stack: List[Any] = []
a: List[Any] = []
rv = a
load = loader.load
# Remove the dimensions information prefix (``[...]=``)
if data and data[0] == b"["[0]:
if isinstance(data, memoryview):
data = bytes(data)
idx = data.find(b"=")
if idx == -1:
raise e.DataError("malformed array: no '=' after dimension information")
data = data[idx + 1 :]
re_parse = _get_array_parse_regexp(delimiter)
for m in re_parse.finditer(data):
t = m.group(1)
if t == b"{":
if stack:
stack[-1].append(a)
stack.append(a)
a = []
elif t == b"}":
if not stack:
raise e.DataError("malformed array: unexpected '}'")
rv = stack.pop()
else:
if not stack:
wat = t[:10].decode("utf8", "replace") + ("..." if len(t) > 10 else "")
raise e.DataError(f"malformed array: unexpected '{wat}'")
if t == b"NULL":
v = None
else:
if t.startswith(b'"'):
t = __re_unescape.sub(rb"\1", t[1:-1])
v = load(t)
stack[-1].append(v)
assert rv is not None
return rv
@cache
def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
"""
Return a regexp to tokenize an array representation into item and brackets
"""
return re.compile(
rb"""(?xi)
( [{}] # open or closed bracket
| " (?: [^"\\] | \\. )* " # or a quoted string
| [^"{}%s\\]+ # or an unquoted non-empty string
) ,?
"""
% delimiter
)
def _load_binary(data: Buffer, tx: Transformer) -> List[Any]:
ndims, hasnull, oid = _unpack_head(data)
load = tx.get_loader(oid, PQ_BINARY).load
if not ndims:
return []
p = 12 + 8 * ndims
dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)]
nelems = prod(dims)
out: List[Any] = [None] * nelems
for i in range(nelems):
size = unpack_len(data, p)[0]
p += 4
if size == -1:
continue
out[i] = load(data[p : p + size])
p += size
# for ndims > 1 we have to aggregate the array into sub-arrays
for dim in dims[-1:0:-1]:
out = [out[i : i + dim] for i in range(0, len(out), dim)]
return out
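# Parsing sketch for _load_text() above (illustrative): with a loader whose
# load() simply decodes bytes, an input such as b'{1,NULL,"a b"}' yields
# ['1', None, 'a b']; quoted tokens are unescaped, NULL becomes None, and each
# '{' / '}' opens or closes one nesting level on the stack.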


@ -0,0 +1,48 @@
"""
Adapters for booleans.
"""
# Copyright (C) 2020 The Psycopg Team
from .. import postgres
from ..pq import Format
from ..abc import AdaptContext
from ..adapt import Buffer, Dumper, Loader
class BoolDumper(Dumper):
oid = postgres.types["bool"].oid
def dump(self, obj: bool) -> bytes:
return b"t" if obj else b"f"
def quote(self, obj: bool) -> bytes:
return b"true" if obj else b"false"
class BoolBinaryDumper(Dumper):
format = Format.BINARY
oid = postgres.types["bool"].oid
def dump(self, obj: bool) -> bytes:
return b"\x01" if obj else b"\x00"
class BoolLoader(Loader):
def load(self, data: Buffer) -> bool:
return data == b"t"
class BoolBinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> bool:
return data != b"\x00"
def register_default_adapters(context: AdaptContext) -> None:
adapters = context.adapters
adapters.register_dumper(bool, BoolDumper)
adapters.register_dumper(bool, BoolBinaryDumper)
adapters.register_loader("bool", BoolLoader)
adapters.register_loader("bool", BoolBinaryLoader)


@ -0,0 +1,328 @@
"""
Support for composite types adaptation.
"""
# Copyright (C) 2020 The Psycopg Team
import re
import struct
from collections import namedtuple
from typing import Any, Callable, cast, Dict, Iterator, List, Optional
from typing import NamedTuple, Sequence, Tuple, Type
from .. import pq
from .. import abc
from .. import postgres
from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader, Dumper
from .._compat import cache
from .._struct import pack_len, unpack_len
from ..postgres import TEXT_OID
from .._typeinfo import CompositeInfo as CompositeInfo # exported here
from .._encodings import _as_python_identifier
_struct_oidlen = struct.Struct("!Ii")
_pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack)
_unpack_oidlen = cast(
Callable[[abc.Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from
)
class SequenceDumper(RecursiveDumper):
def _dump_sequence(
self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes
) -> bytes:
if not obj:
return start + end
parts: List[abc.Buffer] = [start]
for item in obj:
if item is None:
parts.append(sep)
continue
dumper = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
ad = dumper.dump(item)
if not ad:
ad = b'""'
elif self._re_needs_quotes.search(ad):
ad = b'"' + self._re_esc.sub(rb"\1\1", ad) + b'"'
parts.append(ad)
parts.append(sep)
parts[-1] = end
return b"".join(parts)
_re_needs_quotes = re.compile(rb'[",\\\s()]')
_re_esc = re.compile(rb"([\\\"])")
class TupleDumper(SequenceDumper):
# Should be this, but it doesn't work
# oid = postgres_types["record"].oid
def dump(self, obj: Tuple[Any, ...]) -> bytes:
return self._dump_sequence(obj, b"(", b")", b",")
class TupleBinaryDumper(Dumper):
format = pq.Format.BINARY
# Subclasses must set this info
_field_types: Tuple[int, ...]
def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
super().__init__(cls, context)
# Note: this class is not a RecursiveDumper because it would use the
# same Transformer of the context, which would confuse dump_sequence()
# in case the composite contains another composite. Make sure to use
# a separate Transformer instance instead.
self._tx = Transformer(context)
self._tx.set_dumper_types(self._field_types, self.format)
nfields = len(self._field_types)
self._formats = (PyFormat.from_pq(self.format),) * nfields
def dump(self, obj: Tuple[Any, ...]) -> bytearray:
out = bytearray(pack_len(len(obj)))
adapted = self._tx.dump_sequence(obj, self._formats)
for i in range(len(obj)):
b = adapted[i]
oid = self._field_types[i]
if b is not None:
out += _pack_oidlen(oid, len(b))
out += b
else:
out += _pack_oidlen(oid, -1)
return out
class BaseCompositeLoader(Loader):
def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
super().__init__(oid, context)
self._tx = Transformer(context)
def _parse_record(self, data: abc.Buffer) -> Iterator[Optional[bytes]]:
"""
Split a non-empty representation of a composite type into components.
Terminators shouldn't be used in `!data` (so that both record and range
representations can be parsed).
"""
for m in self._re_tokenize.finditer(data):
if m.group(1):
yield None
elif m.group(2) is not None:
yield self._re_undouble.sub(rb"\1", m.group(2))
else:
yield m.group(3)
# If the final group ended in `,` there is a final NULL in the record
# that the regexp couldn't parse.
if m and m.group().endswith(b","):
yield None
_re_tokenize = re.compile(
rb"""(?x)
(,) # an empty token, representing NULL
| " ((?: [^"] | "")*) " ,? # or a quoted string
| ([^",)]+) ,? # or an unquoted string
"""
)
_re_undouble = re.compile(rb'(["\\])\1')
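# Tokenizer sketch (informal): for a record rendered as (42,"say ""hi""",)
# RecordLoader below strips the parentheses and _parse_record() yields
# b'42', b'say "hi"' and finally None for the trailing empty field; doubled
# quotes and backslashes inside quoted tokens are undoubled by _re_undouble.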
class RecordLoader(BaseCompositeLoader):
def load(self, data: abc.Buffer) -> Tuple[Any, ...]:
if data == b"()":
return ()
cast = self._tx.get_loader(TEXT_OID, self.format).load
return tuple(
cast(token) if token is not None else None
for token in self._parse_record(data[1:-1])
)
class RecordBinaryLoader(Loader):
format = pq.Format.BINARY
def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
super().__init__(oid, context)
self._ctx = context
# Cache a transformer for each sequence of oid found.
# Usually there will be only one, but if there is more than one
# row in the same query (in different columns, or even in different
# records), oids might differ and we'd need separate transformers.
self._txs: Dict[Tuple[int, ...], abc.Transformer] = {}
def load(self, data: abc.Buffer) -> Tuple[Any, ...]:
nfields = unpack_len(data, 0)[0]
offset = 4
oids = []
record = []
for _ in range(nfields):
oid, length = _unpack_oidlen(data, offset)
offset += 8
record.append(data[offset : offset + length] if length != -1 else None)
oids.append(oid)
if length >= 0:
offset += length
key = tuple(oids)
try:
tx = self._txs[key]
except KeyError:
tx = self._txs[key] = Transformer(self._ctx)
tx.set_loader_types(oids, self.format)
return tx.load_sequence(tuple(record))
class CompositeLoader(RecordLoader):
factory: Callable[..., Any]
fields_types: List[int]
_types_set = False
def load(self, data: abc.Buffer) -> Any:
if not self._types_set:
self._config_types(data)
self._types_set = True
if data == b"()":
return type(self).factory()
return type(self).factory(
*self._tx.load_sequence(tuple(self._parse_record(data[1:-1])))
)
def _config_types(self, data: abc.Buffer) -> None:
self._tx.set_loader_types(self.fields_types, self.format)
class CompositeBinaryLoader(RecordBinaryLoader):
format = pq.Format.BINARY
factory: Callable[..., Any]
def load(self, data: abc.Buffer) -> Any:
r = super().load(data)
return type(self).factory(*r)
def register_composite(
info: CompositeInfo,
context: Optional[abc.AdaptContext] = None,
factory: Optional[Callable[..., Any]] = None,
) -> None:
"""Register the adapters to load and dump a composite type.
:param info: The object with the information about the composite to register.
:param context: The context where to register the adapters. If `!None`,
register it globally.
:param factory: Callable to convert the sequence of attributes read from
the composite into a Python object.
.. note::
Registering the adapters doesn't affect objects already created, even
if they are children of the registered context. For instance,
registering the adapter globally doesn't affect already existing
connections.
"""
# A friendly error warning instead of an AttributeError in case fetch()
# failed and it wasn't noticed.
if not info:
raise TypeError("no info passed. Is the requested composite available?")
# Register arrays and type info
info.register(context)
if not factory:
factory = _nt_from_info(info)
adapters = context.adapters if context else postgres.adapters
# generate and register a customized text loader
loader: Type[BaseCompositeLoader]
loader = _make_loader(info.name, tuple(info.field_types), factory)
adapters.register_loader(info.oid, loader)
# generate and register a customized binary loader
loader = _make_binary_loader(info.name, factory)
adapters.register_loader(info.oid, loader)
# If the factory is a type, create and register dumpers for it
if isinstance(factory, type):
dumper: Type[Dumper]
dumper = _make_binary_dumper(info.name, info.oid, tuple(info.field_types))
adapters.register_dumper(factory, dumper)
# Default to the text dumper because it is more flexible
dumper = _make_dumper(info.name, info.oid)
adapters.register_dumper(factory, dumper)
info.python_type = factory
def register_default_adapters(context: abc.AdaptContext) -> None:
adapters = context.adapters
adapters.register_dumper(tuple, TupleDumper)
adapters.register_loader("record", RecordLoader)
adapters.register_loader("record", RecordBinaryLoader)
def _nt_from_info(info: CompositeInfo) -> Type[NamedTuple]:
name = _as_python_identifier(info.name)
fields = tuple(_as_python_identifier(n) for n in info.field_names)
return _make_nt(name, fields)
# Cache all dynamically-generated types to avoid leaks in case the types
# cannot be GC'd.
@cache
def _make_nt(name: str, fields: Tuple[str, ...]) -> Type[NamedTuple]:
return namedtuple(name, fields) # type: ignore[return-value]
@cache
def _make_loader(
name: str, types: Tuple[int, ...], factory: Callable[..., Any]
) -> Type[BaseCompositeLoader]:
return type(
f"{name.title()}Loader",
(CompositeLoader,),
{"factory": factory, "fields_types": list(types)},
)
@cache
def _make_binary_loader(
name: str, factory: Callable[..., Any]
) -> Type[BaseCompositeLoader]:
return type(
f"{name.title()}BinaryLoader", (CompositeBinaryLoader,), {"factory": factory}
)
@cache
def _make_dumper(name: str, oid: int) -> Type[TupleDumper]:
return type(f"{name.title()}Dumper", (TupleDumper,), {"oid": oid})
@cache
def _make_binary_dumper(
name: str, oid: int, field_types: Tuple[int, ...]
) -> Type[TupleBinaryDumper]:
return type(
f"{name.title()}BinaryDumper",
(TupleBinaryDumper,),
{"oid": oid, "_field_types": field_types},
)
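# Typical usage (a sketch, assuming an open connection `conn` and a composite
# created with e.g. CREATE TYPE card AS (value int, suit text)):
#
#     from psycopg.types.composite import CompositeInfo, register_composite
#     info = CompositeInfo.fetch(conn, "card")
#     register_composite(info, conn)
#     # "card" values now load as a named tuple built from the type's fields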


@ -0,0 +1,741 @@
"""
Adapters for date/time types.
"""
# Copyright (C) 2020 The Psycopg Team
import re
import struct
from datetime import date, datetime, time, timedelta, timezone
from typing import Any, Callable, cast, Optional, Tuple, TYPE_CHECKING
from .. import postgres
from ..pq import Format
from .._tz import get_tzinfo
from ..abc import AdaptContext, DumperKey
from ..adapt import Buffer, Dumper, Loader, PyFormat
from ..errors import InterfaceError, DataError
from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8
if TYPE_CHECKING:
from ..connection import BaseConnection
_struct_timetz = struct.Struct("!qi") # microseconds, sec tz offset
_pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack)
_unpack_timetz = cast(Callable[[Buffer], Tuple[int, int]], _struct_timetz.unpack)
_struct_interval = struct.Struct("!qii") # microseconds, days, months
_pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack)
_unpack_interval = cast(
Callable[[Buffer], Tuple[int, int, int]], _struct_interval.unpack
)
utc = timezone.utc
_pg_date_epoch_days = date(2000, 1, 1).toordinal()
_pg_datetime_epoch = datetime(2000, 1, 1)
_pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=utc)
_py_date_min_days = date.min.toordinal()
class DateDumper(Dumper):
oid = postgres.types["date"].oid
def dump(self, obj: date) -> bytes:
# NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
# the YYYY-MM-DD is always understood correctly.
return str(obj).encode()
class DateBinaryDumper(Dumper):
format = Format.BINARY
oid = postgres.types["date"].oid
def dump(self, obj: date) -> bytes:
days = obj.toordinal() - _pg_date_epoch_days
return pack_int4(days)
class _BaseTimeDumper(Dumper):
def get_key(self, obj: time, format: PyFormat) -> DumperKey:
# Use (cls,) to report the need to upgrade to a dumper for timetz (the
# Frankenstein of the data types).
if not obj.tzinfo:
return self.cls
else:
return (self.cls,)
def upgrade(self, obj: time, format: PyFormat) -> Dumper:
raise NotImplementedError
def _get_offset(self, obj: time) -> timedelta:
offset = obj.utcoffset()
if offset is None:
raise DataError(
f"cannot calculate the offset of tzinfo '{obj.tzinfo}' without a date"
)
return offset
class _BaseTimeTextDumper(_BaseTimeDumper):
def dump(self, obj: time) -> bytes:
return str(obj).encode()
class TimeDumper(_BaseTimeTextDumper):
oid = postgres.types["time"].oid
def upgrade(self, obj: time, format: PyFormat) -> Dumper:
if not obj.tzinfo:
return self
else:
return TimeTzDumper(self.cls)
class TimeTzDumper(_BaseTimeTextDumper):
oid = postgres.types["timetz"].oid
def dump(self, obj: time) -> bytes:
self._get_offset(obj)
return super().dump(obj)
class TimeBinaryDumper(_BaseTimeDumper):
format = Format.BINARY
oid = postgres.types["time"].oid
def dump(self, obj: time) -> bytes:
us = obj.microsecond + 1_000_000 * (
obj.second + 60 * (obj.minute + 60 * obj.hour)
)
return pack_int8(us)
def upgrade(self, obj: time, format: PyFormat) -> Dumper:
if not obj.tzinfo:
return self
else:
return TimeTzBinaryDumper(self.cls)
class TimeTzBinaryDumper(_BaseTimeDumper):
format = Format.BINARY
oid = postgres.types["timetz"].oid
def dump(self, obj: time) -> bytes:
us = obj.microsecond + 1_000_000 * (
obj.second + 60 * (obj.minute + 60 * obj.hour)
)
off = self._get_offset(obj)
return _pack_timetz(us, -int(off.total_seconds()))
class _BaseDatetimeDumper(Dumper):
def get_key(self, obj: datetime, format: PyFormat) -> DumperKey:
# Use (cls,) to report the need to upgrade (downgrade, actually) to a
# dumper for naive timestamp.
if obj.tzinfo:
return self.cls
else:
return (self.cls,)
def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
raise NotImplementedError
class _BaseDatetimeTextDumper(_BaseDatetimeDumper):
def dump(self, obj: datetime) -> bytes:
# NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
# the YYYY-MM-DD is always understood correctly.
return str(obj).encode()
class DatetimeDumper(_BaseDatetimeTextDumper):
oid = postgres.types["timestamptz"].oid
def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
if obj.tzinfo:
return self
else:
return DatetimeNoTzDumper(self.cls)
class DatetimeNoTzDumper(_BaseDatetimeTextDumper):
oid = postgres.types["timestamp"].oid
class DatetimeBinaryDumper(_BaseDatetimeDumper):
format = Format.BINARY
oid = postgres.types["timestamptz"].oid
def dump(self, obj: datetime) -> bytes:
delta = obj - _pg_datetimetz_epoch
micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
return pack_int8(micros)
def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
if obj.tzinfo:
return self
else:
return DatetimeNoTzBinaryDumper(self.cls)
class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper):
format = Format.BINARY
oid = postgres.types["timestamp"].oid
def dump(self, obj: datetime) -> bytes:
delta = obj - _pg_datetime_epoch
micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
return pack_int8(micros)
class TimedeltaDumper(Dumper):
oid = postgres.types["interval"].oid
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
super().__init__(cls, context)
if self.connection:
if (
self.connection.pgconn.parameter_status(b"IntervalStyle")
== b"sql_standard"
):
setattr(self, "dump", self._dump_sql)
def dump(self, obj: timedelta) -> bytes:
# The comma is parsed ok by PostgreSQL but it's not documented
# and it seems brittle to rely on it. CRDB doesn't consume it well.
return str(obj).encode().replace(b",", b"")
def _dump_sql(self, obj: timedelta) -> bytes:
# sql_standard format needs explicit signs
# otherwise -1 day 1 sec will mean -1 sec
return b"%+d day %+d second %+d microsecond" % (
obj.days,
obj.seconds,
obj.microseconds,
)
class TimedeltaBinaryDumper(Dumper):
format = Format.BINARY
oid = postgres.types["interval"].oid
def dump(self, obj: timedelta) -> bytes:
micros = 1_000_000 * obj.seconds + obj.microseconds
return _pack_interval(micros, obj.days, 0)
class DateLoader(Loader):
_ORDER_YMD = 0
_ORDER_DMY = 1
_ORDER_MDY = 2
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
ds = _get_datestyle(self.connection)
if ds.startswith(b"I"): # ISO
self._order = self._ORDER_YMD
elif ds.startswith(b"G"): # German
self._order = self._ORDER_DMY
elif ds.startswith(b"S") or ds.startswith(b"P"): # SQL or Postgres
self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
else:
raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
def load(self, data: Buffer) -> date:
if self._order == self._ORDER_YMD:
ye = data[:4]
mo = data[5:7]
da = data[8:]
elif self._order == self._ORDER_DMY:
da = data[:2]
mo = data[3:5]
ye = data[6:]
else:
mo = data[:2]
da = data[3:5]
ye = data[6:]
try:
return date(int(ye), int(mo), int(da))
except ValueError as ex:
s = bytes(data).decode("utf8", "replace")
if s == "infinity" or (s and len(s.split()[0]) > 10):
raise DataError(f"date too large (after year 10K): {s!r}") from None
elif s == "-infinity" or "BC" in s:
raise DataError(f"date too small (before year 1): {s!r}") from None
else:
raise DataError(f"can't parse date {s!r}: {ex}") from None
class DateBinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> date:
days = unpack_int4(data)[0] + _pg_date_epoch_days
try:
return date.fromordinal(days)
except (ValueError, OverflowError):
if days < _py_date_min_days:
raise DataError("date too small (before year 1)") from None
else:
raise DataError("date too large (after year 10K)") from None
class TimeLoader(Loader):
_re_format = re.compile(rb"^(\d+):(\d+):(\d+)(?:\.(\d+))?")
def load(self, data: Buffer) -> time:
m = self._re_format.match(data)
if not m:
s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse time {s!r}")
ho, mi, se, fr = m.groups()
# Pad the fraction of second to get micros
if fr:
us = int(fr)
if len(fr) < 6:
us *= _uspad[len(fr)]
else:
us = 0
try:
return time(int(ho), int(mi), int(se), us)
except ValueError as e:
s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse time {s!r}: {e}") from None
class TimeBinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> time:
val = unpack_int8(data)[0]
val, us = divmod(val, 1_000_000)
val, s = divmod(val, 60)
h, m = divmod(val, 60)
try:
return time(h, m, s, us)
except ValueError:
raise DataError(f"time not supported by Python: hour={h}") from None
class TimetzLoader(Loader):
_re_format = re.compile(
rb"""(?ix)
^
(\d+) : (\d+) : (\d+) (?: \. (\d+) )? # Time and micros
([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone
$
"""
)
def load(self, data: Buffer) -> time:
m = self._re_format.match(data)
if not m:
s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse timetz {s!r}")
ho, mi, se, fr, sgn, oh, om, os = m.groups()
# Pad the fraction of second to get the micros
if fr:
us = int(fr)
if len(fr) < 6:
us *= _uspad[len(fr)]
else:
us = 0
# Calculate timezone
off = 60 * 60 * int(oh)
if om:
off += 60 * int(om)
if os:
off += int(os)
tz = timezone(timedelta(0, off if sgn == b"+" else -off))
try:
return time(int(ho), int(mi), int(se), us, tz)
except ValueError as e:
s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse timetz {s!r}: {e}") from None
class TimetzBinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> time:
val, off = _unpack_timetz(data)
val, us = divmod(val, 1_000_000)
val, s = divmod(val, 60)
h, m = divmod(val, 60)
try:
return time(h, m, s, us, timezone(timedelta(seconds=-off)))
except ValueError:
raise DataError(f"time not supported by Python: hour={h}") from None
class TimestampLoader(Loader):
_re_format = re.compile(
rb"""(?ix)
^
(\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date
(?: T | [^a-z0-9] ) # Separator, including T
(\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
(?: \.(\d+) )? # Micros
$
"""
)
_re_format_pg = re.compile(
rb"""(?ix)
^
[a-z]+ [^a-z0-9] # DoW, separator
(\d+|[a-z]+) [^a-z0-9] # Month or day
(\d+|[a-z]+) [^a-z0-9] # Month or day
(\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
(?: \.(\d+) )? # Micros
[^a-z0-9] (\d+) # Year
$
"""
)
_ORDER_YMD = 0
_ORDER_DMY = 1
_ORDER_MDY = 2
_ORDER_PGDM = 3
_ORDER_PGMD = 4
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
ds = _get_datestyle(self.connection)
if ds.startswith(b"I"): # ISO
self._order = self._ORDER_YMD
elif ds.startswith(b"G"): # German
self._order = self._ORDER_DMY
elif ds.startswith(b"S"): # SQL
self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
elif ds.startswith(b"P"): # Postgres
self._order = self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD
self._re_format = self._re_format_pg
else:
raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
def load(self, data: Buffer) -> datetime:
m = self._re_format.match(data)
if not m:
raise _get_timestamp_load_error(self.connection, data) from None
if self._order == self._ORDER_YMD:
ye, mo, da, ho, mi, se, fr = m.groups()
imo = int(mo)
elif self._order == self._ORDER_DMY:
da, mo, ye, ho, mi, se, fr = m.groups()
imo = int(mo)
elif self._order == self._ORDER_MDY:
mo, da, ye, ho, mi, se, fr = m.groups()
imo = int(mo)
else:
if self._order == self._ORDER_PGDM:
da, mo, ho, mi, se, fr, ye = m.groups()
else:
mo, da, ho, mi, se, fr, ye = m.groups()
try:
imo = _month_abbr[mo]
except KeyError:
s = mo.decode("utf8", "replace")
raise DataError(f"can't parse month: {s!r}") from None
# Pad the fraction of second to get the micros
if fr:
us = int(fr)
if len(fr) < 6:
us *= _uspad[len(fr)]
else:
us = 0
try:
return datetime(int(ye), imo, int(da), int(ho), int(mi), int(se), us)
except ValueError as ex:
raise _get_timestamp_load_error(self.connection, data, ex) from None
class TimestampBinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> datetime:
micros = unpack_int8(data)[0]
try:
return _pg_datetime_epoch + timedelta(microseconds=micros)
except OverflowError:
if micros <= 0:
raise DataError("timestamp too small (before year 1)") from None
else:
raise DataError("timestamp too large (after year 10K)") from None
class TimestamptzLoader(Loader):
_re_format = re.compile(
rb"""(?ix)
^
(\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date
(?: T | [^a-z0-9] ) # Separator, including T
(\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
(?: \.(\d+) )? # Micros
([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone
$
"""
)
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
ds = _get_datestyle(self.connection)
if not ds.startswith(b"I"): # not ISO
setattr(self, "load", self._load_notimpl)
def load(self, data: Buffer) -> datetime:
m = self._re_format.match(data)
if not m:
raise _get_timestamp_load_error(self.connection, data) from None
ye, mo, da, ho, mi, se, fr, sgn, oh, om, os = m.groups()
# Pad the fraction of second to get the micros
if fr:
us = int(fr)
if len(fr) < 6:
us *= _uspad[len(fr)]
else:
us = 0
# Calculate timezone offset
soff = 60 * 60 * int(oh)
if om:
soff += 60 * int(om)
if os:
soff += int(os)
tzoff = timedelta(0, soff if sgn == b"+" else -soff)
# The return value is a datetime with the timezone of the connection
# (in order to be consistent with the binary loader, which is the only
# thing it can return). So create a temporary datetime object, in utc,
# shift it by the offset parsed from the timestamp, and then move it to
# the connection timezone.
dt = None
ex: Exception
try:
dt = datetime(int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc)
return (dt - tzoff).astimezone(self._timezone)
except OverflowError as e:
# If we have created the temporary 'dt' it means that we have a
# datetime close to max, the shift pushed it past max, overflowing.
# In this case return the datetime in a fixed offset timezone.
if dt is not None:
return dt.replace(tzinfo=timezone(tzoff))
else:
ex = e
except ValueError as e:
ex = e
raise _get_timestamp_load_error(self.connection, data, ex) from None
def _load_notimpl(self, data: Buffer) -> datetime:
s = bytes(data).decode("utf8", "replace")
ds = _get_datestyle(self.connection).decode("ascii")
raise NotImplementedError(
f"can't parse timestamptz with DateStyle {ds!r}: {s!r}"
)
class TimestamptzBinaryLoader(Loader):
format = Format.BINARY
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
def load(self, data: Buffer) -> datetime:
micros = unpack_int8(data)[0]
try:
ts = _pg_datetimetz_epoch + timedelta(microseconds=micros)
return ts.astimezone(self._timezone)
except OverflowError:
# If we were asked about a timestamp which would overflow in UTC,
# but not in the desired timezone (e.g. datetime.max at Chicago
# timezone) we can still save the day by shifting the value by the
# timezone offset and then replacing the timezone.
if self._timezone:
utcoff = self._timezone.utcoffset(
datetime.min if micros < 0 else datetime.max
)
if utcoff:
usoff = 1_000_000 * int(utcoff.total_seconds())
try:
ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff)
except OverflowError:
pass # will raise downstream
else:
return ts.replace(tzinfo=self._timezone)
if micros <= 0:
raise DataError("timestamp too small (before year 1)") from None
else:
raise DataError("timestamp too large (after year 10K)") from None
class IntervalLoader(Loader):
_re_interval = re.compile(
rb"""
(?: ([-+]?\d+) \s+ years? \s* )? # Years
(?: ([-+]?\d+) \s+ mons? \s* )? # Months
(?: ([-+]?\d+) \s+ days? \s* )? # Days
(?: ([-+])? (\d+) : (\d+) : (\d+ (?:\.\d+)?) # Time
)?
""",
re.VERBOSE,
)
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
if self.connection:
ints = self.connection.pgconn.parameter_status(b"IntervalStyle")
if ints != b"postgres":
setattr(self, "load", self._load_notimpl)
def load(self, data: Buffer) -> timedelta:
m = self._re_interval.match(data)
if not m:
s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse interval {s!r}")
ye, mo, da, sgn, ho, mi, se = m.groups()
days = 0
seconds = 0.0
if ye:
days += 365 * int(ye)
if mo:
days += 30 * int(mo)
if da:
days += int(da)
if ho:
seconds = 3600 * int(ho) + 60 * int(mi) + float(se)
if sgn == b"-":
seconds = -seconds
try:
return timedelta(days=days, seconds=seconds)
except OverflowError as e:
s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse interval {s!r}: {e}") from None
def _load_notimpl(self, data: Buffer) -> timedelta:
s = bytes(data).decode("utf8", "replace")
ints = (
self.connection
and self.connection.pgconn.parameter_status(b"IntervalStyle")
or b"unknown"
).decode("utf8", "replace")
raise NotImplementedError(
f"can't parse interval with IntervalStyle {ints}: {s!r}"
)
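# Parsing sketch for the text loader above (postgres IntervalStyle): years and
# months are approximated as 365 and 30 days respectively, e.g.
#
#     b"1 year 2 mons 3 days 04:05:06"
#     -> timedelta(days=365 + 60 + 3, seconds=4 * 3600 + 5 * 60 + 6)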
class IntervalBinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> timedelta:
micros, days, months = _unpack_interval(data)
if months > 0:
years, months = divmod(months, 12)
days = days + 30 * months + 365 * years
elif months < 0:
years, months = divmod(-months, 12)
days = days - 30 * months - 365 * years
try:
return timedelta(days=days, microseconds=micros)
except OverflowError as e:
raise DataError(f"can't parse interval: {e}") from None
def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes:
if conn:
ds = conn.pgconn.parameter_status(b"DateStyle")
if ds:
return ds
return b"ISO, DMY"
def _get_timestamp_load_error(
conn: Optional["BaseConnection[Any]"], data: Buffer, ex: Optional[Exception] = None
) -> Exception:
s = bytes(data).decode("utf8", "replace")
def is_overflow(s: str) -> bool:
if not s:
return False
ds = _get_datestyle(conn)
if not ds.startswith(b"P"): # Postgres
return len(s.split()[0]) > 10 # date is first token
else:
return len(s.split()[-1]) > 4 # year is last token
if s == "-infinity" or s.endswith("BC"):
return DataError("timestamp too small (before year 1): {s!r}")
elif s == "infinity" or is_overflow(s):
return DataError(f"timestamp too large (after year 10K): {s!r}")
else:
return DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}")
_month_abbr = {
n: i
for i, n in enumerate(b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1)
}
# Pad to get microseconds from a fraction of seconds
_uspad = [0, 100_000, 10_000, 1_000, 100, 10, 1]
def register_default_adapters(context: AdaptContext) -> None:
adapters = context.adapters
adapters.register_dumper("datetime.date", DateDumper)
adapters.register_dumper("datetime.date", DateBinaryDumper)
# first register dumpers for 'timetz' oid, then the proper ones on time type.
adapters.register_dumper("datetime.time", TimeTzDumper)
adapters.register_dumper("datetime.time", TimeTzBinaryDumper)
adapters.register_dumper("datetime.time", TimeDumper)
adapters.register_dumper("datetime.time", TimeBinaryDumper)
# first register dumpers for 'timestamp' oid, then the proper ones
# on the datetime type.
adapters.register_dumper("datetime.datetime", DatetimeNoTzDumper)
adapters.register_dumper("datetime.datetime", DatetimeNoTzBinaryDumper)
adapters.register_dumper("datetime.datetime", DatetimeDumper)
adapters.register_dumper("datetime.datetime", DatetimeBinaryDumper)
adapters.register_dumper("datetime.timedelta", TimedeltaDumper)
adapters.register_dumper("datetime.timedelta", TimedeltaBinaryDumper)
adapters.register_loader("date", DateLoader)
adapters.register_loader("date", DateBinaryLoader)
adapters.register_loader("time", TimeLoader)
adapters.register_loader("time", TimeBinaryLoader)
adapters.register_loader("timetz", TimetzLoader)
adapters.register_loader("timetz", TimetzBinaryLoader)
adapters.register_loader("timestamp", TimestampLoader)
adapters.register_loader("timestamp", TimestampBinaryLoader)
adapters.register_loader("timestamptz", TimestamptzLoader)
adapters.register_loader("timestamptz", TimestamptzBinaryLoader)
adapters.register_loader("interval", IntervalLoader)
adapters.register_loader("interval", IntervalBinaryLoader)


@ -0,0 +1,220 @@
"""
Adapters for the enum type.
"""
from enum import Enum
from typing import Dict, Generic, Optional, Mapping, Sequence
from typing import Tuple, Type, TypeVar, Union, cast
from typing_extensions import TypeAlias
from .. import postgres
from .. import errors as e
from ..pq import Format
from ..abc import AdaptContext
from ..adapt import Buffer, Dumper, Loader
from .._compat import cache
from .._encodings import conn_encoding
from .._typeinfo import EnumInfo as EnumInfo # exported here
E = TypeVar("E", bound=Enum)
EnumDumpMap: TypeAlias = Dict[E, bytes]
EnumLoadMap: TypeAlias = Dict[bytes, E]
EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None]
# Hashable versions
_HEnumDumpMap: TypeAlias = Tuple[Tuple[E, bytes], ...]
_HEnumLoadMap: TypeAlias = Tuple[Tuple[bytes, E], ...]
TEXT = Format.TEXT
BINARY = Format.BINARY
class _BaseEnumLoader(Loader, Generic[E]):
"""
Loader for a specific Enum class
"""
enum: Type[E]
_load_map: EnumLoadMap[E]
def load(self, data: Buffer) -> E:
if not isinstance(data, bytes):
data = bytes(data)
try:
return self._load_map[data]
except KeyError:
enc = conn_encoding(self.connection)
label = data.decode(enc, "replace")
raise e.DataError(
f"bad member for enum {self.enum.__qualname__}: {label!r}"
)
class _BaseEnumDumper(Dumper, Generic[E]):
"""
Dumper for a specific Enum class
"""
enum: Type[E]
_dump_map: EnumDumpMap[E]
def dump(self, value: E) -> Buffer:
return self._dump_map[value]
class EnumDumper(Dumper):
"""
Dumper for a generic Enum class
"""
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
super().__init__(cls, context)
self._encoding = conn_encoding(self.connection)
def dump(self, value: E) -> Buffer:
return value.name.encode(self._encoding)
class EnumBinaryDumper(EnumDumper):
format = BINARY
def register_enum(
info: EnumInfo,
context: Optional[AdaptContext] = None,
enum: Optional[Type[E]] = None,
*,
mapping: EnumMapping[E] = None,
) -> None:
"""Register the adapters to load and dump a enum type.
:param info: The object with the information about the enum to register.
:param context: The context where to register the adapters. If `!None`,
register it globally.
:param enum: Python enum type matching the PostgreSQL one. If `!None`,
a new enum will be generated and exposed as `EnumInfo.enum`.
:param mapping: Override the mapping between `!enum` members and `!info`
labels.
"""
if not info:
raise TypeError("no info passed. Is the requested enum available?")
if enum is None:
enum = cast(Type[E], _make_enum(info.name, tuple(info.labels)))
info.enum = enum
adapters = context.adapters if context else postgres.adapters
info.register(context)
load_map = _make_load_map(info, enum, mapping, context)
loader = _make_loader(info.name, info.enum, load_map)
adapters.register_loader(info.oid, loader)
loader = _make_binary_loader(info.name, info.enum, load_map)
adapters.register_loader(info.oid, loader)
dump_map = _make_dump_map(info, enum, mapping, context)
dumper = _make_dumper(info.enum, info.oid, dump_map)
adapters.register_dumper(info.enum, dumper)
dumper = _make_binary_dumper(info.enum, info.oid, dump_map)
adapters.register_dumper(info.enum, dumper)
# Cache all dynamically-generated types to avoid leaks in case the types
# cannot be GC'd.
@cache
def _make_enum(name: str, labels: Tuple[str, ...]) -> Enum:
return Enum(name.title(), labels, module=__name__)
@cache
def _make_loader(
name: str, enum: Type[Enum], load_map: _HEnumLoadMap[E]
) -> Type[_BaseEnumLoader[E]]:
attribs = {"enum": enum, "_load_map": dict(load_map)}
return type(f"{name.title()}Loader", (_BaseEnumLoader,), attribs)
@cache
def _make_binary_loader(
name: str, enum: Type[Enum], load_map: _HEnumLoadMap[E]
) -> Type[_BaseEnumLoader[E]]:
attribs = {"enum": enum, "_load_map": dict(load_map), "format": BINARY}
return type(f"{name.title()}BinaryLoader", (_BaseEnumLoader,), attribs)
@cache
def _make_dumper(
enum: Type[Enum], oid: int, dump_map: _HEnumDumpMap[E]
) -> Type[_BaseEnumDumper[E]]:
attribs = {"enum": enum, "oid": oid, "_dump_map": dict(dump_map)}
return type(f"{enum.__name__}Dumper", (_BaseEnumDumper,), attribs)
@cache
def _make_binary_dumper(
enum: Type[Enum], oid: int, dump_map: _HEnumDumpMap[E]
) -> Type[_BaseEnumDumper[E]]:
attribs = {"enum": enum, "oid": oid, "_dump_map": dict(dump_map), "format": BINARY}
return type(f"{enum.__name__}BinaryDumper", (_BaseEnumDumper,), attribs)
def _make_load_map(
info: EnumInfo,
enum: Type[E],
mapping: EnumMapping[E],
context: Optional[AdaptContext],
) -> _HEnumLoadMap[E]:
enc = conn_encoding(context.connection if context else None)
rv = []
for label in info.labels:
try:
member = enum[label]
except KeyError:
# tolerate a missing enum, assuming it won't be used. If it is we
# will get a DataError on fetch.
pass
else:
rv.append((label.encode(enc), member))
if mapping:
if isinstance(mapping, Mapping):
mapping = list(mapping.items())
for member, label in mapping:
rv.append((label.encode(enc), member))
return tuple(rv)
def _make_dump_map(
info: EnumInfo,
enum: Type[E],
mapping: EnumMapping[E],
context: Optional[AdaptContext],
) -> _HEnumDumpMap[E]:
enc = conn_encoding(context.connection if context else None)
rv = []
for member in enum:
rv.append((member, member.name.encode(enc)))
if mapping:
if isinstance(mapping, Mapping):
mapping = list(mapping.items())
for member, label in mapping:
rv.append((member, label.encode(enc)))
return tuple(rv)
def register_default_adapters(context: AdaptContext) -> None:
context.adapters.register_dumper(Enum, EnumBinaryDumper)
context.adapters.register_dumper(Enum, EnumDumper)
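# Typical usage (a sketch, assuming CREATE TYPE mood AS ENUM ('sad', 'ok',
# 'happy') and an open connection `conn`):
#
#     from enum import Enum
#     from psycopg.types.enum import EnumInfo, register_enum
#
#     class Mood(Enum):
#         SAD = "sad"
#         OK = "ok"
#         HAPPY = "happy"
#
#     info = EnumInfo.fetch(conn, "mood")
#     register_enum(info, conn, Mood, mapping={m: m.value for m in Mood})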


@ -0,0 +1,146 @@
"""
Dict to hstore adaptation
"""
# Copyright (C) 2021 The Psycopg Team
import re
from typing import Dict, List, Optional, Type
from typing_extensions import TypeAlias
from .. import errors as e
from .. import postgres
from ..abc import Buffer, AdaptContext
from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
from .._compat import cache
from ..postgres import TEXT_OID
from .._typeinfo import TypeInfo
_re_escape = re.compile(r'(["\\])')
_re_unescape = re.compile(r"\\(.)")
_re_hstore = re.compile(
r"""
# hstore key:
# a string of normal or escaped chars
"((?: [^"\\] | \\. )*)"
\s*=>\s* # hstore value
(?:
NULL # the value can be null - not caught
# or a quoted string like the key
| "((?: [^"\\] | \\. )*)"
)
(?:\s*,\s*|$) # pairs separated by comma or end of string.
""",
re.VERBOSE,
)
Hstore: TypeAlias = Dict[str, Optional[str]]
class BaseHstoreDumper(RecursiveDumper):
def dump(self, obj: Hstore) -> Buffer:
if not obj:
return b""
tokens: List[str] = []
def add_token(s: str) -> None:
tokens.append('"')
tokens.append(_re_escape.sub(r"\\\1", s))
tokens.append('"')
for k, v in obj.items():
if not isinstance(k, str):
raise e.DataError("hstore keys can only be strings")
add_token(k)
tokens.append("=>")
if v is None:
tokens.append("NULL")
elif not isinstance(v, str):
raise e.DataError("hstore keys can only be strings")
else:
add_token(v)
tokens.append(",")
del tokens[-1]
data = "".join(tokens)
dumper = self._tx.get_dumper(data, PyFormat.TEXT)
return dumper.dump(data)
class HstoreLoader(RecursiveLoader):
def load(self, data: Buffer) -> Hstore:
loader = self._tx.get_loader(TEXT_OID, self.format)
s: str = loader.load(data)
rv: Hstore = {}
start = 0
for m in _re_hstore.finditer(s):
if m is None or m.start() != start:
raise e.DataError(f"error parsing hstore pair at char {start}")
k = _re_unescape.sub(r"\1", m.group(1))
v = m.group(2)
if v is not None:
v = _re_unescape.sub(r"\1", v)
rv[k] = v
start = m.end()
if start < len(s):
raise e.DataError(f"error parsing hstore: unparsed data after char {start}")
return rv
def register_hstore(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
"""Register the adapters to load and dump hstore.
:param info: The object with the information about the hstore type.
:param context: The context where to register the adapters. If `!None`,
register it globally.
.. note::
Registering the adapters doesn't affect objects already created, even
if they are children of the registered context. For instance,
registering the adapter globally doesn't affect already existing
connections.
"""
# A friendly error warning instead of an AttributeError in case fetch()
# failed and it wasn't noticed.
if not info:
raise TypeError("no info passed. Is the 'hstore' extension loaded?")
# Register arrays and type info
info.register(context)
adapters = context.adapters if context else postgres.adapters
# Generate and register a customized text dumper
adapters.register_dumper(dict, _make_hstore_dumper(info.oid))
# register the text loader on the oid
adapters.register_loader(info.oid, HstoreLoader)
# Cache all dynamically-generated types to avoid leaks in case the types
# cannot be GC'd.
@cache
def _make_hstore_dumper(oid_in: int) -> Type[BaseHstoreDumper]:
"""
Return an hstore dumper class configured using `oid_in`.
Avoid creating new classes if the configured oid is the same.
"""
class HstoreDumper(BaseHstoreDumper):
oid = oid_in
return HstoreDumper
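# Typical usage (a sketch, assuming the hstore extension is installed and an
# open connection `conn`):
#
#     from psycopg.types import TypeInfo
#     from psycopg.types.hstore import register_hstore
#
#     info = TypeInfo.fetch(conn, "hstore")
#     register_hstore(info, conn)
#     # dict parameters are now dumped as hstore and hstore columns load back
#     # as Dict[str, Optional[str]]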


@ -0,0 +1,247 @@
"""
Adapters for JSON types.
"""
# Copyright (C) 2020 The Psycopg Team
import json
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from .. import abc
from .. import errors as e
from .. import postgres
from ..pq import Format
from ..adapt import Buffer, Dumper, Loader, PyFormat, AdaptersMap
from ..errors import DataError
from .._compat import cache
JsonDumpsFunction = Callable[[Any], Union[str, bytes]]
JsonLoadsFunction = Callable[[Union[str, bytes]], Any]
def set_json_dumps(
dumps: JsonDumpsFunction, context: Optional[abc.AdaptContext] = None
) -> None:
"""
Set the JSON serialisation function to store JSON objects in the database.
:param dumps: The dump function to use.
:type dumps: `!Callable[[Any], str]`
:param context: Where to use the `!dumps` function. If not specified, use it
globally.
:type context: `~psycopg.Connection` or `~psycopg.Cursor`
By default dumping JSON uses the builtin `json.dumps`. You can override
it to use a different JSON library or to use customised arguments.
If the `Json` wrapper specified a `!dumps` function, it is used in
preference to the one set by this function.
"""
if context is None:
# If changing load function globally, just change the default on the
# global class
_JsonDumper._dumps = dumps
else:
adapters = context.adapters
# If the scope is smaller than global, create subclasses and register
# them in the appropriate scope.
grid = [
(Json, PyFormat.BINARY),
(Json, PyFormat.TEXT),
(Jsonb, PyFormat.BINARY),
(Jsonb, PyFormat.TEXT),
]
for wrapper, format in grid:
base = _get_current_dumper(adapters, wrapper, format)
dumper = _make_dumper(base, dumps)
adapters.register_dumper(wrapper, dumper)
def set_json_loads(
loads: JsonLoadsFunction, context: Optional[abc.AdaptContext] = None
) -> None:
"""
Set the JSON parsing function to fetch JSON objects from the database.
:param loads: The load function to use.
:type loads: `!Callable[[bytes], Any]`
:param context: Where to use the `!loads` function. If not specified, use
it globally.
:type context: `~psycopg.Connection` or `~psycopg.Cursor`
By default loading JSON uses the builtin `json.loads`. You can override
it to use a different JSON library or to use customised arguments.
"""
if context is None:
# If changing load function globally, just change the default on the
# global class
_JsonLoader._loads = loads
else:
# If the scope is smaller than global, create subclasses and register
# them in the appropriate scope.
grid = [
("json", JsonLoader),
("json", JsonBinaryLoader),
("jsonb", JsonbLoader),
("jsonb", JsonbBinaryLoader),
]
for tname, base in grid:
loader = _make_loader(base, loads)
context.adapters.register_loader(tname, loader)
# Cache all dynamically-generated types to avoid leaks in case the types
# cannot be GC'd.
@cache
def _make_dumper(base: Type[abc.Dumper], dumps: JsonDumpsFunction) -> Type[abc.Dumper]:
name = base.__name__
if not name.startswith("Custom"):
name = f"Custom{name}"
return type(name, (base,), {"_dumps": dumps})
@cache
def _make_loader(base: Type[Loader], loads: JsonLoadsFunction) -> Type[Loader]:
name = base.__name__
if not name.startswith("Custom"):
name = f"Custom{name}"
return type(name, (base,), {"_loads": loads})
class _JsonWrapper:
__slots__ = ("obj", "dumps")
def __init__(self, obj: Any, dumps: Optional[JsonDumpsFunction] = None):
self.obj = obj
self.dumps = dumps
def __repr__(self) -> str:
sobj = repr(self.obj)
if len(sobj) > 40:
sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
return f"{self.__class__.__name__}({sobj})"
class Json(_JsonWrapper):
__slots__ = ()
class Jsonb(_JsonWrapper):
__slots__ = ()
class _JsonDumper(Dumper):
# The globally used JSON dumps() function. It can be changed globally (by
# set_json_dumps) or by a subclass.
_dumps: JsonDumpsFunction = json.dumps
def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
super().__init__(cls, context)
self.dumps = self.__class__._dumps
def dump(self, obj: Any) -> bytes:
if isinstance(obj, _JsonWrapper):
dumps = obj.dumps or self.dumps
obj = obj.obj
else:
dumps = self.dumps
data = dumps(obj)
if isinstance(data, str):
return data.encode()
return data
class JsonDumper(_JsonDumper):
oid = postgres.types["json"].oid
class JsonBinaryDumper(_JsonDumper):
format = Format.BINARY
oid = postgres.types["json"].oid
class JsonbDumper(_JsonDumper):
oid = postgres.types["jsonb"].oid
class JsonbBinaryDumper(_JsonDumper):
format = Format.BINARY
oid = postgres.types["jsonb"].oid
def dump(self, obj: Any) -> bytes:
return b"\x01" + super().dump(obj)
class _JsonLoader(Loader):
# The globally used JSON loads() function. It can be changed globally (by
# set_json_loads) or by a subclass.
_loads: JsonLoadsFunction = json.loads
def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
super().__init__(oid, context)
self.loads = self.__class__._loads
def load(self, data: Buffer) -> Any:
# json.loads() cannot work on memoryview.
if not isinstance(data, bytes):
data = bytes(data)
return self.loads(data)
class JsonLoader(_JsonLoader):
pass
class JsonbLoader(_JsonLoader):
pass
class JsonBinaryLoader(_JsonLoader):
format = Format.BINARY
class JsonbBinaryLoader(_JsonLoader):
format = Format.BINARY
def load(self, data: Buffer) -> Any:
if data and data[0] != 1:
raise DataError("unknown jsonb binary format: {data[0]}")
data = data[1:]
if not isinstance(data, bytes):
data = bytes(data)
return self.loads(data)
def _get_current_dumper(
adapters: AdaptersMap, cls: type, format: PyFormat
) -> Type[abc.Dumper]:
try:
return adapters.get_dumper(cls, format)
except e.ProgrammingError:
return _default_dumpers[cls, format]
_default_dumpers: Dict[Tuple[Type[_JsonWrapper], PyFormat], Type[Dumper]] = {
(Json, PyFormat.BINARY): JsonBinaryDumper,
(Json, PyFormat.TEXT): JsonDumper,
(Jsonb, PyFormat.BINARY): JsonbBinaryDumper,
(Jsonb, PyFormat.TEXT): JsonbDumper,
}
def register_default_adapters(context: abc.AdaptContext) -> None:
adapters = context.adapters
# Currently json binary format is nothing different than text, maybe with
# an extra memcopy we can avoid.
adapters.register_dumper(Json, JsonBinaryDumper)
adapters.register_dumper(Json, JsonDumper)
adapters.register_dumper(Jsonb, JsonbBinaryDumper)
adapters.register_dumper(Jsonb, JsonbDumper)
adapters.register_loader("json", JsonLoader)
adapters.register_loader("jsonb", JsonbLoader)
adapters.register_loader("json", JsonBinaryLoader)
adapters.register_loader("jsonb", JsonbBinaryLoader)


@ -0,0 +1,521 @@
"""
Support for multirange types adaptation.
"""
# Copyright (C) 2021 The Psycopg Team
from decimal import Decimal
from typing import Any, Generic, List, Iterable
from typing import MutableSequence, Optional, Type, Union, overload
from datetime import date, datetime
from .. import errors as e
from .. import postgres
from ..pq import Format
from ..abc import AdaptContext, Buffer, Dumper, DumperKey
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
from .._compat import cache
from .._struct import pack_len, unpack_len
from ..postgres import INVALID_OID, TEXT_OID
from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here
from .range import Range, T, load_range_text, load_range_binary
from .range import dump_range_text, dump_range_binary, fail_dump
class Multirange(MutableSequence[Range[T]]):
"""Python representation for a PostgreSQL multirange type.
:param items: Sequence of ranges to initialise the object.
"""
def __init__(self, items: Iterable[Range[T]] = ()):
self._ranges: List[Range[T]] = list(map(self._check_type, items))
def _check_type(self, item: Any) -> Range[Any]:
if not isinstance(item, Range):
raise TypeError(
f"Multirange is a sequence of Range, got {type(item).__name__}"
)
return item
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._ranges!r})"
def __str__(self) -> str:
return f"{{{', '.join(map(str, self._ranges))}}}"
@overload
def __getitem__(self, index: int) -> Range[T]:
...
@overload
def __getitem__(self, index: slice) -> "Multirange[T]":
...
def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
if isinstance(index, int):
return self._ranges[index]
else:
return Multirange(self._ranges[index])
def __len__(self) -> int:
return len(self._ranges)
@overload
def __setitem__(self, index: int, value: Range[T]) -> None:
...
@overload
def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
...
def __setitem__(
self,
index: Union[int, slice],
value: Union[Range[T], Iterable[Range[T]]],
) -> None:
if isinstance(index, int):
self._check_type(value)
self._ranges[index] = self._check_type(value)
elif not isinstance(value, Iterable):
raise TypeError("can only assign an iterable")
else:
value = map(self._check_type, value)
self._ranges[index] = value
def __delitem__(self, index: Union[int, slice]) -> None:
del self._ranges[index]
def insert(self, index: int, value: Range[T]) -> None:
self._ranges.insert(index, self._check_type(value))
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Multirange):
return False
return self._ranges == other._ranges
# Order is arbitrary but consistent
def __lt__(self, other: Any) -> bool:
if not isinstance(other, Multirange):
return NotImplemented
return self._ranges < other._ranges
def __le__(self, other: Any) -> bool:
return self == other or self < other # type: ignore
def __gt__(self, other: Any) -> bool:
if not isinstance(other, Multirange):
return NotImplemented
return self._ranges > other._ranges
def __ge__(self, other: Any) -> bool:
return self == other or self > other # type: ignore
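# Usage sketch (illustrative, assuming int-bounded ranges):
#
#     mr = Multirange([Range(2, 4), Range(6, 8)])
#     len(mr)   # -> 2
#     mr[0]     # -> Range(2, 4)
#     mr[1:]    # -> Multirange([Range(6, 8)])
#     mr.insert(0, Range(0, 1))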
# Subclasses to specify a specific subtype. Usually not needed
class Int4Multirange(Multirange[int]):
pass
class Int8Multirange(Multirange[int]):
pass
class NumericMultirange(Multirange[Decimal]):
pass
class DateMultirange(Multirange[date]):
pass
class TimestampMultirange(Multirange[datetime]):
pass
class TimestamptzMultirange(Multirange[datetime]):
pass
class BaseMultirangeDumper(RecursiveDumper):
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
super().__init__(cls, context)
self.sub_dumper: Optional[Dumper] = None
self._adapt_format = PyFormat.from_pq(self.format)
def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey:
# If we are a subclass whose oid is specified we don't need upgrade
if self.cls is not Multirange:
return self.cls
item = self._get_item(obj)
if item is not None:
sd = self._tx.get_dumper(item, self._adapt_format)
return (self.cls, sd.get_key(item, format))
else:
return (self.cls,)
def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
# If we are a subclass whose oid is specified we don't need upgrade
if self.cls is not Multirange:
return self
item = self._get_item(obj)
if item is None:
return MultirangeDumper(self.cls)
dumper: BaseMultirangeDumper
if type(item) is int:
# postgres won't cast int4range -> int8range so we must use
# text format and unknown oid here
sd = self._tx.get_dumper(item, PyFormat.TEXT)
dumper = MultirangeDumper(self.cls, self._tx)
dumper.sub_dumper = sd
dumper.oid = INVALID_OID
return dumper
sd = self._tx.get_dumper(item, format)
dumper = type(self)(self.cls, self._tx)
dumper.sub_dumper = sd
if sd.oid == INVALID_OID and isinstance(item, str):
# Work around the normal mapping where text is dumped as unknown
dumper.oid = self._get_multirange_oid(TEXT_OID)
else:
dumper.oid = self._get_multirange_oid(sd.oid)
return dumper
def _get_item(self, obj: Multirange[Any]) -> Any:
"""
Return a member representative of the multirange
"""
for r in obj:
if r.lower is not None:
return r.lower
if r.upper is not None:
return r.upper
return None
def _get_multirange_oid(self, sub_oid: int) -> int:
"""
Return the oid of the range from the oid of its elements.
"""
info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid)
return info.oid if info else INVALID_OID
class MultirangeDumper(BaseMultirangeDumper):
"""
Dumper for multirange types.
The dumper can upgrade to one specific for a different range type.
"""
def dump(self, obj: Multirange[Any]) -> Buffer:
if not obj:
return b"{}"
item = self._get_item(obj)
if item is not None:
dump = self._tx.get_dumper(item, self._adapt_format).dump
else:
dump = fail_dump
out: List[Buffer] = [b"{"]
for r in obj:
out.append(dump_range_text(r, dump))
out.append(b",")
out[-1] = b"}"
return b"".join(out)
class MultirangeBinaryDumper(BaseMultirangeDumper):
format = Format.BINARY
def dump(self, obj: Multirange[Any]) -> Buffer:
item = self._get_item(obj)
if item is not None:
dump = self._tx.get_dumper(item, self._adapt_format).dump
else:
dump = fail_dump
out: List[Buffer] = [pack_len(len(obj))]
for r in obj:
data = dump_range_binary(r, dump)
out.append(pack_len(len(data)))
out.append(data)
return b"".join(out)
class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
subtype_oid: int
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
class MultirangeLoader(BaseMultirangeLoader[T]):
def load(self, data: Buffer) -> Multirange[T]:
if not data or data[0] != _START_INT:
raise e.DataError(
"malformed multirange starting with"
f" {bytes(data[:1]).decode('utf8', 'replace')}"
)
out = Multirange[T]()
if data == b"{}":
return out
pos = 1
data = data[pos:]
try:
while True:
r, pos = load_range_text(data, self._load)
out.append(r)
sep = data[pos] # can raise IndexError
if sep == _SEP_INT:
data = data[pos + 1 :]
continue
elif sep == _END_INT:
if len(data) == pos + 1:
return out
else:
raise e.DataError(
"malformed multirange: data after closing brace"
)
else:
raise e.DataError(
f"malformed multirange: found unexpected {chr(sep)}"
)
except IndexError:
raise e.DataError("malformed multirange: separator missing")
return out
_SEP_INT = ord(",")
_START_INT = ord("{")
_END_INT = ord("}")
class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
format = Format.BINARY
def load(self, data: Buffer) -> Multirange[T]:
nelems = unpack_len(data, 0)[0]
pos = 4
out = Multirange[T]()
for i in range(nelems):
length = unpack_len(data, pos)[0]
pos += 4
out.append(load_range_binary(data[pos : pos + length], self._load))
pos += length
if pos != len(data):
raise e.DataError("unexpected trailing data in multirange")
return out
def register_multirange(
info: MultirangeInfo, context: Optional[AdaptContext] = None
) -> None:
"""Register the adapters to load and dump a multirange type.
:param info: The object with the information about the range to register.
:param context: The context where to register the adapters. If `!None`,
register it globally.
Register loaders so that loading data of this type will result in a `Range`
with bounds parsed as the right subtype.
.. note::
Registering the adapters doesn't affect objects already created, even
if they are children of the registered context. For instance,
registering the adapter globally doesn't affect already existing
connections.
"""
# A friendly error warning instead of an AttributeError in case fetch()
# failed and it wasn't noticed.
if not info:
raise TypeError("no info passed. Is the requested multirange available?")
# Register arrays and type info
info.register(context)
adapters = context.adapters if context else postgres.adapters
# generate and register a customized text loader
loader: Type[BaseMultirangeLoader[Any]]
loader = _make_loader(info.name, info.subtype_oid)
adapters.register_loader(info.oid, loader)
# generate and register a customized binary loader
loader = _make_binary_loader(info.name, info.subtype_oid)
adapters.register_loader(info.oid, loader)
# Cache all dynamically-generated types to avoid leaks in case the types
# cannot be GC'd.
@cache
def _make_loader(name: str, oid: int) -> Type[MultirangeLoader[Any]]:
return type(f"{name.title()}Loader", (MultirangeLoader,), {"subtype_oid": oid})
@cache
def _make_binary_loader(name: str, oid: int) -> Type[MultirangeBinaryLoader[Any]]:
return type(
f"{name.title()}BinaryLoader", (MultirangeBinaryLoader,), {"subtype_oid": oid}
)
# Text dumpers for builtin multirange types wrappers
# These are registered on specific subtypes so that the upgrade mechanism
# doesn't kick in.
class Int4MultirangeDumper(MultirangeDumper):
oid = postgres.types["int4multirange"].oid
class Int8MultirangeDumper(MultirangeDumper):
oid = postgres.types["int8multirange"].oid
class NumericMultirangeDumper(MultirangeDumper):
oid = postgres.types["nummultirange"].oid
class DateMultirangeDumper(MultirangeDumper):
oid = postgres.types["datemultirange"].oid
class TimestampMultirangeDumper(MultirangeDumper):
oid = postgres.types["tsmultirange"].oid
class TimestamptzMultirangeDumper(MultirangeDumper):
oid = postgres.types["tstzmultirange"].oid
# Binary dumpers for builtin multirange types wrappers
# These are registered on specific subtypes so that the upgrade mechanism
# doesn't kick in.
class Int4MultirangeBinaryDumper(MultirangeBinaryDumper):
oid = postgres.types["int4multirange"].oid
class Int8MultirangeBinaryDumper(MultirangeBinaryDumper):
oid = postgres.types["int8multirange"].oid
class NumericMultirangeBinaryDumper(MultirangeBinaryDumper):
oid = postgres.types["nummultirange"].oid
class DateMultirangeBinaryDumper(MultirangeBinaryDumper):
oid = postgres.types["datemultirange"].oid
class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper):
oid = postgres.types["tsmultirange"].oid
class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper):
oid = postgres.types["tstzmultirange"].oid
# Text loaders for builtin multirange types
class Int4MultirangeLoader(MultirangeLoader[int]):
subtype_oid = postgres.types["int4"].oid
class Int8MultirangeLoader(MultirangeLoader[int]):
subtype_oid = postgres.types["int8"].oid
class NumericMultirangeLoader(MultirangeLoader[Decimal]):
subtype_oid = postgres.types["numeric"].oid
class DateMultirangeLoader(MultirangeLoader[date]):
subtype_oid = postgres.types["date"].oid
class TimestampMultirangeLoader(MultirangeLoader[datetime]):
subtype_oid = postgres.types["timestamp"].oid
class TimestampTZMultirangeLoader(MultirangeLoader[datetime]):
subtype_oid = postgres.types["timestamptz"].oid
# Binary loaders for builtin multirange types
class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
subtype_oid = postgres.types["int4"].oid
class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
subtype_oid = postgres.types["int8"].oid
class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]):
subtype_oid = postgres.types["numeric"].oid
class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]):
subtype_oid = postgres.types["date"].oid
class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
subtype_oid = postgres.types["timestamp"].oid
class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
subtype_oid = postgres.types["timestamptz"].oid
def register_default_adapters(context: AdaptContext) -> None:
adapters = context.adapters
adapters.register_dumper(Multirange, MultirangeBinaryDumper)
adapters.register_dumper(Multirange, MultirangeDumper)
adapters.register_dumper(Int4Multirange, Int4MultirangeDumper)
adapters.register_dumper(Int8Multirange, Int8MultirangeDumper)
adapters.register_dumper(NumericMultirange, NumericMultirangeDumper)
adapters.register_dumper(DateMultirange, DateMultirangeDumper)
adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper)
adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeDumper)
adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper)
adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper)
adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper)
adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper)
adapters.register_dumper(TimestampMultirange, TimestampMultirangeBinaryDumper)
adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeBinaryDumper)
adapters.register_loader("int4multirange", Int4MultirangeLoader)
adapters.register_loader("int8multirange", Int8MultirangeLoader)
adapters.register_loader("nummultirange", NumericMultirangeLoader)
adapters.register_loader("datemultirange", DateMultirangeLoader)
adapters.register_loader("tsmultirange", TimestampMultirangeLoader)
adapters.register_loader("tstzmultirange", TimestampTZMultirangeLoader)
adapters.register_loader("int4multirange", Int4MultirangeBinaryLoader)
adapters.register_loader("int8multirange", Int8MultirangeBinaryLoader)
adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader)
adapters.register_loader("datemultirange", DateMultirangeBinaryLoader)
adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader)
adapters.register_loader("tstzmultirange", TimestampTZMultirangeBinaryLoader)

View File

@ -0,0 +1,201 @@
"""
Adapters for network types.
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Callable, Optional, Type, Union, TYPE_CHECKING
from typing_extensions import TypeAlias
from .. import postgres
from ..pq import Format
from ..abc import AdaptContext
from ..adapt import Buffer, Dumper, Loader
if TYPE_CHECKING:
import ipaddress
Address: TypeAlias = Union["ipaddress.IPv4Address", "ipaddress.IPv6Address"]
Interface: TypeAlias = Union["ipaddress.IPv4Interface", "ipaddress.IPv6Interface"]
Network: TypeAlias = Union["ipaddress.IPv4Network", "ipaddress.IPv6Network"]
# These objects will be imported lazily
ip_address: Callable[[str], Address] = None # type: ignore[assignment]
ip_interface: Callable[[str], Interface] = None # type: ignore[assignment]
ip_network: Callable[[str], Network] = None # type: ignore[assignment]
IPv4Address: "Type[ipaddress.IPv4Address]" = None # type: ignore[assignment]
IPv6Address: "Type[ipaddress.IPv6Address]" = None # type: ignore[assignment]
IPv4Interface: "Type[ipaddress.IPv4Interface]" = None # type: ignore[assignment]
IPv6Interface: "Type[ipaddress.IPv6Interface]" = None # type: ignore[assignment]
IPv4Network: "Type[ipaddress.IPv4Network]" = None # type: ignore[assignment]
IPv6Network: "Type[ipaddress.IPv6Network]" = None # type: ignore[assignment]
PGSQL_AF_INET = 2
PGSQL_AF_INET6 = 3
IPV4_PREFIXLEN = 32
IPV6_PREFIXLEN = 128
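# Wire-format note (explanatory sketch, inferred from the dumpers and loaders
# below): binary inet/cidr values are a 4-byte header -- address family,
# prefix length, an is_cidr flag, and the packed address length -- followed by
# the packed address itself. For example the interface 192.0.2.1/24 starts
# with bytes((PGSQL_AF_INET, 24, 0, 4)) and ends with its 4 packed bytes.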
class _LazyIpaddress:
def _ensure_module(self) -> None:
global ip_address, ip_interface, ip_network
global IPv4Address, IPv6Address, IPv4Interface, IPv6Interface
global IPv4Network, IPv6Network
if ip_address is None:
from ipaddress import ip_address, ip_interface, ip_network
from ipaddress import IPv4Address, IPv6Address
from ipaddress import IPv4Interface, IPv6Interface
from ipaddress import IPv4Network, IPv6Network
class InterfaceDumper(Dumper):
oid = postgres.types["inet"].oid
def dump(self, obj: Interface) -> bytes:
return str(obj).encode()
class NetworkDumper(Dumper):
oid = postgres.types["cidr"].oid
def dump(self, obj: Network) -> bytes:
return str(obj).encode()
class _AIBinaryDumper(Dumper):
format = Format.BINARY
oid = postgres.types["inet"].oid
class AddressBinaryDumper(_AIBinaryDumper):
def dump(self, obj: Address) -> bytes:
packed = obj.packed
family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
head = bytes((family, obj.max_prefixlen, 0, len(packed)))
return head + packed
class InterfaceBinaryDumper(_AIBinaryDumper):
def dump(self, obj: Interface) -> bytes:
packed = obj.packed
family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
head = bytes((family, obj.network.prefixlen, 0, len(packed)))
return head + packed
class InetBinaryDumper(_AIBinaryDumper, _LazyIpaddress):
"""Either an address or an interface to inet
Used when looking up by oid.
"""
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
super().__init__(cls, context)
self._ensure_module()
def dump(self, obj: Union[Address, Interface]) -> bytes:
packed = obj.packed
family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
if isinstance(obj, (IPv4Interface, IPv6Interface)):
prefixlen = obj.network.prefixlen
else:
prefixlen = obj.max_prefixlen
head = bytes((family, prefixlen, 0, len(packed)))
return head + packed
class NetworkBinaryDumper(Dumper):
format = Format.BINARY
oid = postgres.types["cidr"].oid
def dump(self, obj: Network) -> bytes:
packed = obj.network_address.packed
family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
head = bytes((family, obj.prefixlen, 1, len(packed)))
return head + packed
class _LazyIpaddressLoader(Loader, _LazyIpaddress):
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
self._ensure_module()
class InetLoader(_LazyIpaddressLoader):
def load(self, data: Buffer) -> Union[Address, Interface]:
if isinstance(data, memoryview):
data = bytes(data)
if b"/" in data:
return ip_interface(data.decode())
else:
return ip_address(data.decode())
class InetBinaryLoader(_LazyIpaddressLoader):
format = Format.BINARY
def load(self, data: Buffer) -> Union[Address, Interface]:
if isinstance(data, memoryview):
data = bytes(data)
prefix = data[1]
packed = data[4:]
if data[0] == PGSQL_AF_INET:
if prefix == IPV4_PREFIXLEN:
return IPv4Address(packed)
else:
return IPv4Interface((packed, prefix))
else:
if prefix == IPV6_PREFIXLEN:
return IPv6Address(packed)
else:
return IPv6Interface((packed, prefix))
class CidrLoader(_LazyIpaddressLoader):
def load(self, data: Buffer) -> Network:
if isinstance(data, memoryview):
data = bytes(data)
return ip_network(data.decode())
class CidrBinaryLoader(_LazyIpaddressLoader):
format = Format.BINARY
def load(self, data: Buffer) -> Network:
if isinstance(data, memoryview):
data = bytes(data)
prefix = data[1]
packed = data[4:]
if data[0] == PGSQL_AF_INET:
return IPv4Network((packed, prefix))
else:
return IPv6Network((packed, prefix))
def register_default_adapters(context: AdaptContext) -> None:
adapters = context.adapters
adapters.register_dumper("ipaddress.IPv4Address", InterfaceDumper)
adapters.register_dumper("ipaddress.IPv6Address", InterfaceDumper)
adapters.register_dumper("ipaddress.IPv4Interface", InterfaceDumper)
adapters.register_dumper("ipaddress.IPv6Interface", InterfaceDumper)
adapters.register_dumper("ipaddress.IPv4Network", NetworkDumper)
adapters.register_dumper("ipaddress.IPv6Network", NetworkDumper)
adapters.register_dumper("ipaddress.IPv4Address", AddressBinaryDumper)
adapters.register_dumper("ipaddress.IPv6Address", AddressBinaryDumper)
adapters.register_dumper("ipaddress.IPv4Interface", InterfaceBinaryDumper)
adapters.register_dumper("ipaddress.IPv6Interface", InterfaceBinaryDumper)
adapters.register_dumper("ipaddress.IPv4Network", NetworkBinaryDumper)
adapters.register_dumper("ipaddress.IPv6Network", NetworkBinaryDumper)
adapters.register_dumper(None, InetBinaryDumper)
adapters.register_loader("inet", InetLoader)
adapters.register_loader("inet", InetBinaryLoader)
adapters.register_loader("cidr", CidrLoader)
adapters.register_loader("cidr", CidrBinaryLoader)

View File

@ -0,0 +1,25 @@
"""
Adapters for None.
"""
# Copyright (C) 2020 The Psycopg Team
from ..abc import AdaptContext, NoneType
from ..adapt import Dumper
class NoneDumper(Dumper):
"""
Not a complete dumper: dump() is not usable (NULL is passed to Postgres in
other ways), but quote() is implemented, so it can be used in sql composition.
"""
def dump(self, obj: None) -> bytes:
raise NotImplementedError("NULL is passed to Postgres in other ways")
def quote(self, obj: None) -> bytes:
return b"NULL"
def register_default_adapters(context: AdaptContext) -> None:
context.adapters.register_dumper(NoneType, NoneDumper)

View File

@ -0,0 +1,495 @@
"""
Adapters for numeric types.
"""
# Copyright (C) 2020 The Psycopg Team
import struct
from math import log
from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast
from decimal import Decimal, DefaultContext, Context
from .. import postgres
from .. import errors as e
from ..pq import Format
from ..abc import AdaptContext
from ..adapt import Buffer, Dumper, Loader, PyFormat
from .._struct import pack_int2, pack_uint2, unpack_int2
from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4
from .._struct import pack_int8, unpack_int8
from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8
# Exposed here
from .._wrappers import (
Int2 as Int2,
Int4 as Int4,
Int8 as Int8,
IntNumeric as IntNumeric,
Oid as Oid,
Float4 as Float4,
Float8 as Float8,
)
class _IntDumper(Dumper):
def dump(self, obj: Any) -> Buffer:
t = type(obj)
if t is not int:
# Convert to int in order to dump IntEnum correctly
if issubclass(t, int):
obj = int(obj)
else:
raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
return str(obj).encode()
def quote(self, obj: Any) -> Buffer:
value = self.dump(obj)
return value if obj >= 0 else b" " + value
class _SpecialValuesDumper(Dumper):
_special: Dict[bytes, bytes] = {}
def dump(self, obj: Any) -> bytes:
return str(obj).encode()
def quote(self, obj: Any) -> bytes:
value = self.dump(obj)
if value in self._special:
return self._special[value]
return value if obj >= 0 else b" " + value
class FloatDumper(_SpecialValuesDumper):
oid = postgres.types["float8"].oid
_special = {
b"inf": b"'Infinity'::float8",
b"-inf": b"'-Infinity'::float8",
b"nan": b"'NaN'::float8",
}
class Float4Dumper(FloatDumper):
oid = postgres.types["float4"].oid
class FloatBinaryDumper(Dumper):
format = Format.BINARY
oid = postgres.types["float8"].oid
def dump(self, obj: float) -> bytes:
return pack_float8(obj)
class Float4BinaryDumper(FloatBinaryDumper):
oid = postgres.types["float4"].oid
def dump(self, obj: float) -> bytes:
return pack_float4(obj)
class DecimalDumper(_SpecialValuesDumper):
oid = postgres.types["numeric"].oid
def dump(self, obj: Decimal) -> bytes:
if obj.is_nan():
# cover NaN and sNaN
return b"NaN"
else:
return str(obj).encode()
_special = {
b"Infinity": b"'Infinity'::numeric",
b"-Infinity": b"'-Infinity'::numeric",
b"NaN": b"'NaN'::numeric",
}
class Int2Dumper(_IntDumper):
oid = postgres.types["int2"].oid
class Int4Dumper(_IntDumper):
oid = postgres.types["int4"].oid
class Int8Dumper(_IntDumper):
oid = postgres.types["int8"].oid
class IntNumericDumper(_IntDumper):
oid = postgres.types["numeric"].oid
class OidDumper(_IntDumper):
oid = postgres.types["oid"].oid
class IntDumper(Dumper):
def dump(self, obj: Any) -> bytes:
raise TypeError(
f"{type(self).__name__} is a dispatcher to other dumpers:"
" dump() is not supposed to be called"
)
def get_key(self, obj: int, format: PyFormat) -> type:
return self.upgrade(obj, format).cls
_int2_dumper = Int2Dumper(Int2)
_int4_dumper = Int4Dumper(Int4)
_int8_dumper = Int8Dumper(Int8)
_int_numeric_dumper = IntNumericDumper(IntNumeric)
def upgrade(self, obj: int, format: PyFormat) -> Dumper:
if -(2**31) <= obj < 2**31:
if -(2**15) <= obj < 2**15:
return self._int2_dumper
else:
return self._int4_dumper
else:
if -(2**63) <= obj < 2**63:
return self._int8_dumper
else:
return self._int_numeric_dumper
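# Dispatch sketch (illustrative, not part of the original module): upgrade()
# picks the narrowest dumper that fits the value.
def _example_int_dispatch() -> None:
    d = IntDumper(int)
    assert d.upgrade(1, PyFormat.TEXT) is IntDumper._int2_dumper
    assert d.upgrade(100_000, PyFormat.TEXT) is IntDumper._int4_dumper
    assert d.upgrade(2**40, PyFormat.TEXT) is IntDumper._int8_dumper
    assert d.upgrade(2**70, PyFormat.TEXT) is IntDumper._int_numeric_dumper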
class Int2BinaryDumper(Int2Dumper):
format = Format.BINARY
def dump(self, obj: int) -> bytes:
return pack_int2(obj)
class Int4BinaryDumper(Int4Dumper):
format = Format.BINARY
def dump(self, obj: int) -> bytes:
return pack_int4(obj)
class Int8BinaryDumper(Int8Dumper):
format = Format.BINARY
def dump(self, obj: int) -> bytes:
return pack_int8(obj)
# Number of PostgreSQL base-10000 "digits" required per bit of the number
# (log(2) / log(10_000), about 0.075), used below to size numeric buffers.
BIT_PER_PGDIGIT = log(2) / log(10_000)
class IntNumericBinaryDumper(IntNumericDumper):
format = Format.BINARY
def dump(self, obj: int) -> Buffer:
return dump_int_to_numeric_binary(obj)
class OidBinaryDumper(OidDumper):
format = Format.BINARY
def dump(self, obj: int) -> bytes:
return pack_uint4(obj)
class IntBinaryDumper(IntDumper):
format = Format.BINARY
_int2_dumper = Int2BinaryDumper(Int2)
_int4_dumper = Int4BinaryDumper(Int4)
_int8_dumper = Int8BinaryDumper(Int8)
_int_numeric_dumper = IntNumericBinaryDumper(IntNumeric)
class IntLoader(Loader):
def load(self, data: Buffer) -> int:
# it supports bytes directly
return int(data)
class Int2BinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> int:
return unpack_int2(data)[0]
class Int4BinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> int:
return unpack_int4(data)[0]
class Int8BinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> int:
return unpack_int8(data)[0]
class OidBinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> int:
return unpack_uint4(data)[0]
class FloatLoader(Loader):
def load(self, data: Buffer) -> float:
# it supports bytes directly
return float(data)
class Float4BinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> float:
return unpack_float4(data)[0]
class Float8BinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> float:
return unpack_float8(data)[0]
class NumericLoader(Loader):
def load(self, data: Buffer) -> Decimal:
if isinstance(data, memoryview):
data = bytes(data)
return Decimal(data.decode())
DEC_DIGITS = 4 # decimal digits per Postgres "digit"
NUMERIC_POS = 0x0000
NUMERIC_NEG = 0x4000
NUMERIC_NAN = 0xC000
NUMERIC_PINF = 0xD000
NUMERIC_NINF = 0xF000
_decimal_special = {
NUMERIC_NAN: Decimal("NaN"),
NUMERIC_PINF: Decimal("Infinity"),
NUMERIC_NINF: Decimal("-Infinity"),
}
class _ContextMap(DefaultDict[int, Context]):
"""
Cache for decimal contexts to use when the precision requires it.
Note: if the default context is used (prec=28) you can get an invalid
operation or a rounding to 0:
- Decimal(1000).shift(24) = Decimal('1000000000000000000000000000')
- Decimal(1000).shift(25) = Decimal('0')
- Decimal(1000).shift(30) raises InvalidOperation
"""
def __missing__(self, key: int) -> Context:
val = Context(prec=key)
self[key] = val
return val
_contexts = _ContextMap()
for i in range(DefaultContext.prec):
_contexts[i] = DefaultContext
_unpack_numeric_head = cast(
Callable[[Buffer], Tuple[int, int, int, int]],
struct.Struct("!HhHH").unpack_from,
)
_pack_numeric_head = cast(
Callable[[int, int, int, int], bytes],
struct.Struct("!HhHH").pack,
)
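# Header sketch (illustrative): the "!HhHH" head above packs ndigits, weight,
# sign and dscale. For example Decimal("123.45") travels as ndigits=2,
# weight=0, sign=NUMERIC_POS, dscale=2, followed by the base-10000 digits
# (123, 4500).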
class NumericBinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> Decimal:
ndigits, weight, sign, dscale = _unpack_numeric_head(data)
if sign == NUMERIC_POS or sign == NUMERIC_NEG:
val = 0
for i in range(8, len(data), 2):
val = val * 10_000 + data[i] * 0x100 + data[i + 1]
shift = dscale - (ndigits - weight - 1) * DEC_DIGITS
ctx = _contexts[(weight + 2) * DEC_DIGITS + dscale]
return (
Decimal(val if sign == NUMERIC_POS else -val)
.scaleb(-dscale, ctx)
.shift(shift, ctx)
)
else:
try:
return _decimal_special[sign]
except KeyError:
raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None
NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0)
NUMERIC_PINF_BIN = _pack_numeric_head(0, 0, NUMERIC_PINF, 0)
NUMERIC_NINF_BIN = _pack_numeric_head(0, 0, NUMERIC_NINF, 0)
class DecimalBinaryDumper(Dumper):
format = Format.BINARY
oid = postgres.types["numeric"].oid
def dump(self, obj: Decimal) -> Buffer:
return dump_decimal_to_numeric_binary(obj)
class NumericDumper(DecimalDumper):
def dump(self, obj: Union[Decimal, int]) -> bytes:
if isinstance(obj, int):
return str(obj).encode()
else:
return super().dump(obj)
class NumericBinaryDumper(Dumper):
format = Format.BINARY
oid = postgres.types["numeric"].oid
def dump(self, obj: Union[Decimal, int]) -> Buffer:
if isinstance(obj, int):
return dump_int_to_numeric_binary(obj)
else:
return dump_decimal_to_numeric_binary(obj)
def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]:
sign, digits, exp = obj.as_tuple()
if exp == "n" or exp == "N":
return NUMERIC_NAN_BIN
elif exp == "F":
return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN
# Weights of py digits into a pg digit according to their positions.
# Starting with an index wi != 0 is equivalent to prepending 0's to
# the digits tuple, but without really changing it.
weights = (1000, 100, 10, 1)
wi = 0
ndigits = nzdigits = len(digits)
# Find the last nonzero digit
while nzdigits > 0 and digits[nzdigits - 1] == 0:
nzdigits -= 1
if exp <= 0:
dscale = -exp
else:
dscale = 0
# align the py digits to the pg digits if there's some py exponent
ndigits += exp % DEC_DIGITS
if not nzdigits:
return _pack_numeric_head(0, 0, NUMERIC_POS, dscale)
# Equivalent of 0-padding left to align the py digits to the pg digits
# but without changing the digits tuple.
mod = (ndigits - dscale) % DEC_DIGITS
if mod:
wi = DEC_DIGITS - mod
ndigits += wi
tmp = nzdigits + wi
out = bytearray(
_pack_numeric_head(
tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1), # ndigits
(ndigits + exp) // DEC_DIGITS - 1, # weight
NUMERIC_NEG if sign else NUMERIC_POS, # sign
dscale,
)
)
pgdigit = 0
for i in range(nzdigits):
pgdigit += weights[wi] * digits[i]
wi += 1
if wi >= DEC_DIGITS:
out += pack_uint2(pgdigit)
pgdigit = wi = 0
if pgdigit:
out += pack_uint2(pgdigit)
return out
def dump_int_to_numeric_binary(obj: int) -> bytearray:
ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1
out = bytearray(b"\x00\x00" * (ndigits + 4))
if obj < 0:
sign = NUMERIC_NEG
obj = -obj
else:
sign = NUMERIC_POS
out[:8] = _pack_numeric_head(ndigits, ndigits - 1, sign, 0)
i = 8 + (ndigits - 1) * 2
while obj:
rem = obj % 10_000
obj //= 10_000
out[i : i + 2] = pack_uint2(rem)
i -= 2
return out
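# Quick sanity sketch (illustrative, not part of the original module): 2**63
# needs five base-10000 digits, matching the estimate used above.
def _example_int_to_numeric_size() -> None:
    n = 2**63
    assert int(n.bit_length() * BIT_PER_PGDIGIT) + 1 == 5
    assert len(dump_int_to_numeric_binary(n)) == 8 + 5 * 2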
def register_default_adapters(context: AdaptContext) -> None:
adapters = context.adapters
adapters.register_dumper(int, IntDumper)
adapters.register_dumper(int, IntBinaryDumper)
adapters.register_dumper(float, FloatDumper)
adapters.register_dumper(float, FloatBinaryDumper)
adapters.register_dumper(Int2, Int2Dumper)
adapters.register_dumper(Int4, Int4Dumper)
adapters.register_dumper(Int8, Int8Dumper)
adapters.register_dumper(IntNumeric, IntNumericDumper)
adapters.register_dumper(Oid, OidDumper)
# The binary dumper is currently some 30% slower, so default to text
# (see tests/scripts/testdec.py for a rough benchmark)
# Also, must be after IntNumericDumper
adapters.register_dumper("decimal.Decimal", DecimalBinaryDumper)
adapters.register_dumper("decimal.Decimal", DecimalDumper)
# Used only when looked up by oid; can take both int and Decimal as input
adapters.register_dumper(None, NumericBinaryDumper)
adapters.register_dumper(None, NumericDumper)
adapters.register_dumper(Float4, Float4Dumper)
adapters.register_dumper(Float8, FloatDumper)
adapters.register_dumper(Int2, Int2BinaryDumper)
adapters.register_dumper(Int4, Int4BinaryDumper)
adapters.register_dumper(Int8, Int8BinaryDumper)
adapters.register_dumper(Oid, OidBinaryDumper)
adapters.register_dumper(Float4, Float4BinaryDumper)
adapters.register_dumper(Float8, FloatBinaryDumper)
adapters.register_loader("int2", IntLoader)
adapters.register_loader("int4", IntLoader)
adapters.register_loader("int8", IntLoader)
adapters.register_loader("oid", IntLoader)
adapters.register_loader("int2", Int2BinaryLoader)
adapters.register_loader("int4", Int4BinaryLoader)
adapters.register_loader("int8", Int8BinaryLoader)
adapters.register_loader("oid", OidBinaryLoader)
adapters.register_loader("float4", FloatLoader)
adapters.register_loader("float8", FloatLoader)
adapters.register_loader("float4", Float4BinaryLoader)
adapters.register_loader("float8", Float8BinaryLoader)
adapters.register_loader("numeric", NumericLoader)
adapters.register_loader("numeric", NumericBinaryLoader)

View File

@ -0,0 +1,708 @@
"""
Support for range types adaptation.
"""
# Copyright (C) 2020 The Psycopg Team
import re
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Type, Tuple
from typing import cast
from decimal import Decimal
from datetime import date, datetime
from .. import errors as e
from .. import postgres
from ..pq import Format
from ..abc import AdaptContext, Buffer, Dumper, DumperKey
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
from .._compat import cache
from .._struct import pack_len, unpack_len
from ..postgres import INVALID_OID, TEXT_OID
from .._typeinfo import RangeInfo as RangeInfo # exported here
RANGE_EMPTY = 0x01 # range is empty
RANGE_LB_INC = 0x02 # lower bound is inclusive
RANGE_UB_INC = 0x04 # upper bound is inclusive
RANGE_LB_INF = 0x08 # lower bound is -infinity
RANGE_UB_INF = 0x10 # upper bound is +infinity
_EMPTY_HEAD = bytes([RANGE_EMPTY])
T = TypeVar("T")
class Range(Generic[T]):
"""Python representation for a PostgreSQL range type.
:param lower: lower bound for the range. `!None` means unbound
:param upper: upper bound for the range. `!None` means unbound
:param bounds: one of the literal strings ``()``, ``[)``, ``(]``, ``[]``,
representing whether the lower or upper bounds are included
:param empty: if `!True`, the range is empty
"""
__slots__ = ("_lower", "_upper", "_bounds")
def __init__(
self,
lower: Optional[T] = None,
upper: Optional[T] = None,
bounds: str = "[)",
empty: bool = False,
):
if not empty:
if bounds not in ("[)", "(]", "()", "[]"):
raise ValueError("bound flags not valid: %r" % bounds)
self._lower = lower
self._upper = upper
# Make bounds consistent with infs
if lower is None and bounds[0] == "[":
bounds = "(" + bounds[1]
if upper is None and bounds[1] == "]":
bounds = bounds[0] + ")"
self._bounds = bounds
else:
self._lower = self._upper = None
self._bounds = ""
def __repr__(self) -> str:
if self._bounds:
args = f"{self._lower!r}, {self._upper!r}, {self._bounds!r}"
else:
args = "empty=True"
return f"{self.__class__.__name__}({args})"
def __str__(self) -> str:
if not self._bounds:
return "empty"
items = [
self._bounds[0],
str(self._lower),
", ",
str(self._upper),
self._bounds[1],
]
return "".join(items)
@property
def lower(self) -> Optional[T]:
"""The lower bound of the range. `!None` if empty or unbound."""
return self._lower
@property
def upper(self) -> Optional[T]:
"""The upper bound of the range. `!None` if empty or unbound."""
return self._upper
@property
def bounds(self) -> str:
"""The bounds string (two characters from '[', '(', ']', ')')."""
return self._bounds
@property
def isempty(self) -> bool:
"""`!True` if the range is empty."""
return not self._bounds
@property
def lower_inf(self) -> bool:
"""`!True` if the range doesn't have a lower bound."""
if not self._bounds:
return False
return self._lower is None
@property
def upper_inf(self) -> bool:
"""`!True` if the range doesn't have an upper bound."""
if not self._bounds:
return False
return self._upper is None
@property
def lower_inc(self) -> bool:
"""`!True` if the lower bound is included in the range."""
if not self._bounds or self._lower is None:
return False
return self._bounds[0] == "["
@property
def upper_inc(self) -> bool:
"""`!True` if the upper bound is included in the range."""
if not self._bounds or self._upper is None:
return False
return self._bounds[1] == "]"
def __contains__(self, x: T) -> bool:
if not self._bounds:
return False
if self._lower is not None:
if self._bounds[0] == "[":
# It doesn't seem that Python has an ABC for ordered types.
if x < self._lower: # type: ignore[operator]
return False
else:
if x <= self._lower: # type: ignore[operator]
return False
if self._upper is not None:
if self._bounds[1] == "]":
if x > self._upper: # type: ignore[operator]
return False
else:
if x >= self._upper: # type: ignore[operator]
return False
return True
def __bool__(self) -> bool:
return bool(self._bounds)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Range):
return False
return (
self._lower == other._lower
and self._upper == other._upper
and self._bounds == other._bounds
)
def __hash__(self) -> int:
return hash((self._lower, self._upper, self._bounds))
# as the postgres docs describe for the server-side stuff,
# ordering is rather arbitrary, but will remain stable
# and consistent.
def __lt__(self, other: Any) -> bool:
if not isinstance(other, Range):
return NotImplemented
for attr in ("_lower", "_upper", "_bounds"):
self_value = getattr(self, attr)
other_value = getattr(other, attr)
if self_value == other_value:
pass
elif self_value is None:
return True
elif other_value is None:
return False
else:
return cast(bool, self_value < other_value)
return False
def __le__(self, other: Any) -> bool:
return self == other or self < other # type: ignore
def __gt__(self, other: Any) -> bool:
if isinstance(other, Range):
return other < self
else:
return NotImplemented
def __ge__(self, other: Any) -> bool:
return self == other or self > other # type: ignore
def __getstate__(self) -> Dict[str, Any]:
return {
slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot)
}
def __setstate__(self, state: Dict[str, Any]) -> None:
for slot, value in state.items():
setattr(self, slot, value)
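# A minimal usage sketch (illustrative, not part of the original module).
def _example_range_usage() -> None:
    r = Range(10, 20)  # default bounds are "[)"
    assert 10 in r and 20 not in r
    assert str(r) == "[10, 20)"
    # an unbound side is normalized to an exclusive bound
    assert Range(None, 5, "[]").bounds == "(]"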
# Subclasses to specify a specific subtype. Usually not needed: only needed
# in binary copy, where switching to text is not an option.
class Int4Range(Range[int]):
pass
class Int8Range(Range[int]):
pass
class NumericRange(Range[Decimal]):
pass
class DateRange(Range[date]):
pass
class TimestampRange(Range[datetime]):
pass
class TimestamptzRange(Range[datetime]):
pass
class BaseRangeDumper(RecursiveDumper):
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
super().__init__(cls, context)
self.sub_dumper: Optional[Dumper] = None
self._adapt_format = PyFormat.from_pq(self.format)
def get_key(self, obj: Range[Any], format: PyFormat) -> DumperKey:
# If we are a subclass whose oid is specified we don't need upgrade
if self.cls is not Range:
return self.cls
item = self._get_item(obj)
if item is not None:
sd = self._tx.get_dumper(item, self._adapt_format)
return (self.cls, sd.get_key(item, format))
else:
return (self.cls,)
def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper":
# If we are a subclass whose oid is specified we don't need upgrade
if self.cls is not Range:
return self
item = self._get_item(obj)
if item is None:
return RangeDumper(self.cls)
dumper: BaseRangeDumper
if type(item) is int:
# postgres won't cast int4range -> int8range so we must use
# text format and unknown oid here
sd = self._tx.get_dumper(item, PyFormat.TEXT)
dumper = RangeDumper(self.cls, self._tx)
dumper.sub_dumper = sd
dumper.oid = INVALID_OID
return dumper
sd = self._tx.get_dumper(item, format)
dumper = type(self)(self.cls, self._tx)
dumper.sub_dumper = sd
if sd.oid == INVALID_OID and isinstance(item, str):
# Work around the normal mapping where text is dumped as unknown
dumper.oid = self._get_range_oid(TEXT_OID)
else:
dumper.oid = self._get_range_oid(sd.oid)
return dumper
def _get_item(self, obj: Range[Any]) -> Any:
"""
Return a member representative of the range
"""
rv = obj.lower
return rv if rv is not None else obj.upper
def _get_range_oid(self, sub_oid: int) -> int:
"""
Return the oid of the range from the oid of its elements.
"""
info = self._tx.adapters.types.get_by_subtype(RangeInfo, sub_oid)
return info.oid if info else INVALID_OID
class RangeDumper(BaseRangeDumper):
"""
Dumper for range types.
The dumper can upgrade to one specific for a different range type.
"""
def dump(self, obj: Range[Any]) -> Buffer:
item = self._get_item(obj)
if item is not None:
dump = self._tx.get_dumper(item, self._adapt_format).dump
else:
dump = fail_dump
return dump_range_text(obj, dump)
def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
if obj.isempty:
return b"empty"
parts: List[Buffer] = [b"[" if obj.lower_inc else b"("]
def dump_item(item: Any) -> Buffer:
ad = dump(item)
if not ad:
return b'""'
elif _re_needs_quotes.search(ad):
return b'"' + _re_esc.sub(rb"\1\1", ad) + b'"'
else:
return ad
if obj.lower is not None:
parts.append(dump_item(obj.lower))
parts.append(b",")
if obj.upper is not None:
parts.append(dump_item(obj.upper))
parts.append(b"]" if obj.upper_inc else b")")
return b"".join(parts)
_re_needs_quotes = re.compile(rb'[",\\\s()\[\]]')
_re_esc = re.compile(rb"([\\\"])")
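# Text-format sketch (illustrative, not part of the original module): dumping
# an int range through a trivial dump callable.
def _example_dump_range_text() -> None:
    assert dump_range_text(Range(10, 20), lambda x: str(x).encode()) == b"[10,20)"
    assert dump_range_text(Range(empty=True), lambda x: b"") == b"empty"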
class RangeBinaryDumper(BaseRangeDumper):
format = Format.BINARY
def dump(self, obj: Range[Any]) -> Buffer:
item = self._get_item(obj)
if item is not None:
dump = self._tx.get_dumper(item, self._adapt_format).dump
else:
dump = fail_dump
return dump_range_binary(obj, dump)
def dump_range_binary(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
if not obj:
return _EMPTY_HEAD
out = bytearray([0]) # will replace the head later
head = 0
if obj.lower_inc:
head |= RANGE_LB_INC
if obj.upper_inc:
head |= RANGE_UB_INC
if obj.lower is not None:
data = dump(obj.lower)
out += pack_len(len(data))
out += data
else:
head |= RANGE_LB_INF
if obj.upper is not None:
data = dump(obj.upper)
out += pack_len(len(data))
out += data
else:
head |= RANGE_UB_INF
out[0] = head
return out
def fail_dump(obj: Any) -> Buffer:
raise e.InternalError("trying to dump a range element without information")
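# Binary-format sketch (illustrative): the first byte is a flag set; e.g.
# Range(10, 20, "[)") sets only RANGE_LB_INC, Range(None, 5) sets RANGE_LB_INF,
# and each bound that is present follows as a length-prefixed dumped value.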
class BaseRangeLoader(RecursiveLoader, Generic[T]):
"""Generic loader for a range.
Subclasses must specify the oid of the subtype and the class to load.
"""
subtype_oid: int
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
class RangeLoader(BaseRangeLoader[T]):
def load(self, data: Buffer) -> Range[T]:
return load_range_text(data, self._load)[0]
def load_range_text(
data: Buffer, load: Callable[[Buffer], Any]
) -> Tuple[Range[Any], int]:
if data == b"empty":
return Range(empty=True), 5
m = _re_range.match(data)
if m is None:
raise e.DataError(
f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'"
)
lower = None
item = m.group(3)
if item is None:
item = m.group(2)
if item is not None:
lower = load(_re_undouble.sub(rb"\1", item))
else:
lower = load(item)
upper = None
item = m.group(5)
if item is None:
item = m.group(4)
if item is not None:
upper = load(_re_undouble.sub(rb"\1", item))
else:
upper = load(item)
bounds = (m.group(1) + m.group(6)).decode()
return Range(lower, upper, bounds), m.end()
_re_range = re.compile(
rb"""
( \(|\[ ) # lower bound flag
(?: # lower bound:
" ( (?: [^"] | "")* ) " # - a quoted string
| ( [^",]+ ) # - or an unquoted string
)? # - or empty (not caught)
,
(?: # upper bound:
" ( (?: [^"] | "")* ) " # - a quoted string
| ( [^"\)\]]+ ) # - or an unquoted string
)? # - or empty (not caught)
( \)|\] ) # upper bound flag
""",
re.VERBOSE,
)
_re_undouble = re.compile(rb'(["\\])\1')
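# Parsing sketch (illustrative, not part of the original module): loading a
# textual range with int() as the subtype loader.
def _example_load_range_text() -> None:
    r, end = load_range_text(b"[10,20)", lambda b: int(b))
    assert r == Range(10, 20, "[)") and end == 7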
class RangeBinaryLoader(BaseRangeLoader[T]):
format = Format.BINARY
def load(self, data: Buffer) -> Range[T]:
return load_range_binary(data, self._load)
def load_range_binary(data: Buffer, load: Callable[[Buffer], Any]) -> Range[Any]:
head = data[0]
if head & RANGE_EMPTY:
return Range(empty=True)
lb = "[" if head & RANGE_LB_INC else "("
ub = "]" if head & RANGE_UB_INC else ")"
pos = 1 # after the head
if head & RANGE_LB_INF:
min = None
else:
length = unpack_len(data, pos)[0]
pos += 4
min = load(data[pos : pos + length])
pos += length
if head & RANGE_UB_INF:
max = None
else:
length = unpack_len(data, pos)[0]
pos += 4
max = load(data[pos : pos + length])
pos += length
return Range(min, max, lb + ub)
def register_range(info: RangeInfo, context: Optional[AdaptContext] = None) -> None:
"""Register the adapters to load and dump a range type.
:param info: The object with the information about the range to register.
:param context: The context where to register the adapters. If `!None`,
register it globally.
Register loaders so that loading data of this type will result in a `Range`
with bounds parsed as the right subtype.
.. note::
Registering the adapters doesn't affect objects already created, even
if they are children of the registered context. For instance,
registering the adapter globally doesn't affect already existing
connections.
"""
# A friendly error warning instead of an AttributeError in case fetch()
# failed and it wasn't noticed.
if not info:
raise TypeError("no info passed. Is the requested range available?")
# Register arrays and type info
info.register(context)
adapters = context.adapters if context else postgres.adapters
# generate and register a customized text loader
loader: Type[BaseRangeLoader[Any]]
loader = _make_loader(info.name, info.subtype_oid)
adapters.register_loader(info.oid, loader)
# generate and register a customized binary loader
loader = _make_binary_loader(info.name, info.subtype_oid)
adapters.register_loader(info.oid, loader)
# Cache all dynamically-generated types to avoid leaks in case the types
# cannot be GC'd.
@cache
def _make_loader(name: str, oid: int) -> Type[RangeLoader[Any]]:
return type(f"{name.title()}Loader", (RangeLoader,), {"subtype_oid": oid})
@cache
def _make_binary_loader(name: str, oid: int) -> Type[RangeBinaryLoader[Any]]:
return type(
f"{name.title()}BinaryLoader", (RangeBinaryLoader,), {"subtype_oid": oid}
)
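# Usage sketch (illustrative; "conn" is an assumed open connection and
# "strrange" an assumed range type created e.g. with
# "CREATE TYPE strrange AS RANGE (subtype = text)"):
#
#     info = RangeInfo.fetch(conn, "strrange")
#     register_range(info, conn)
#     cur = conn.execute("SELECT '[a,z)'::strrange")
#     assert isinstance(cur.fetchone()[0], Range)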
# Text dumpers for builtin range types wrappers
# These are registered on specific subtypes so that the upgrade mechanism
# doesn't kick in.
class Int4RangeDumper(RangeDumper):
oid = postgres.types["int4range"].oid
class Int8RangeDumper(RangeDumper):
oid = postgres.types["int8range"].oid
class NumericRangeDumper(RangeDumper):
oid = postgres.types["numrange"].oid
class DateRangeDumper(RangeDumper):
oid = postgres.types["daterange"].oid
class TimestampRangeDumper(RangeDumper):
oid = postgres.types["tsrange"].oid
class TimestamptzRangeDumper(RangeDumper):
oid = postgres.types["tstzrange"].oid
# Binary dumpers for builtin range types wrappers
# These are registered on specific subtypes so that the upgrade mechanism
# doesn't kick in.
class Int4RangeBinaryDumper(RangeBinaryDumper):
oid = postgres.types["int4range"].oid
class Int8RangeBinaryDumper(RangeBinaryDumper):
oid = postgres.types["int8range"].oid
class NumericRangeBinaryDumper(RangeBinaryDumper):
oid = postgres.types["numrange"].oid
class DateRangeBinaryDumper(RangeBinaryDumper):
oid = postgres.types["daterange"].oid
class TimestampRangeBinaryDumper(RangeBinaryDumper):
oid = postgres.types["tsrange"].oid
class TimestamptzRangeBinaryDumper(RangeBinaryDumper):
oid = postgres.types["tstzrange"].oid
# Text loaders for builtin range types
class Int4RangeLoader(RangeLoader[int]):
subtype_oid = postgres.types["int4"].oid
class Int8RangeLoader(RangeLoader[int]):
subtype_oid = postgres.types["int8"].oid
class NumericRangeLoader(RangeLoader[Decimal]):
subtype_oid = postgres.types["numeric"].oid
class DateRangeLoader(RangeLoader[date]):
subtype_oid = postgres.types["date"].oid
class TimestampRangeLoader(RangeLoader[datetime]):
subtype_oid = postgres.types["timestamp"].oid
class TimestampTZRangeLoader(RangeLoader[datetime]):
subtype_oid = postgres.types["timestamptz"].oid
# Binary loaders for builtin range types
class Int4RangeBinaryLoader(RangeBinaryLoader[int]):
subtype_oid = postgres.types["int4"].oid
class Int8RangeBinaryLoader(RangeBinaryLoader[int]):
subtype_oid = postgres.types["int8"].oid
class NumericRangeBinaryLoader(RangeBinaryLoader[Decimal]):
subtype_oid = postgres.types["numeric"].oid
class DateRangeBinaryLoader(RangeBinaryLoader[date]):
subtype_oid = postgres.types["date"].oid
class TimestampRangeBinaryLoader(RangeBinaryLoader[datetime]):
subtype_oid = postgres.types["timestamp"].oid
class TimestampTZRangeBinaryLoader(RangeBinaryLoader[datetime]):
subtype_oid = postgres.types["timestamptz"].oid
def register_default_adapters(context: AdaptContext) -> None:
adapters = context.adapters
adapters.register_dumper(Range, RangeBinaryDumper)
adapters.register_dumper(Range, RangeDumper)
adapters.register_dumper(Int4Range, Int4RangeDumper)
adapters.register_dumper(Int8Range, Int8RangeDumper)
adapters.register_dumper(NumericRange, NumericRangeDumper)
adapters.register_dumper(DateRange, DateRangeDumper)
adapters.register_dumper(TimestampRange, TimestampRangeDumper)
adapters.register_dumper(TimestamptzRange, TimestamptzRangeDumper)
adapters.register_dumper(Int4Range, Int4RangeBinaryDumper)
adapters.register_dumper(Int8Range, Int8RangeBinaryDumper)
adapters.register_dumper(NumericRange, NumericRangeBinaryDumper)
adapters.register_dumper(DateRange, DateRangeBinaryDumper)
adapters.register_dumper(TimestampRange, TimestampRangeBinaryDumper)
adapters.register_dumper(TimestamptzRange, TimestamptzRangeBinaryDumper)
adapters.register_loader("int4range", Int4RangeLoader)
adapters.register_loader("int8range", Int8RangeLoader)
adapters.register_loader("numrange", NumericRangeLoader)
adapters.register_loader("daterange", DateRangeLoader)
adapters.register_loader("tsrange", TimestampRangeLoader)
adapters.register_loader("tstzrange", TimestampTZRangeLoader)
adapters.register_loader("int4range", Int4RangeBinaryLoader)
adapters.register_loader("int8range", Int8RangeBinaryLoader)
adapters.register_loader("numrange", NumericRangeBinaryLoader)
adapters.register_loader("daterange", DateRangeBinaryLoader)
adapters.register_loader("tsrange", TimestampRangeBinaryLoader)
adapters.register_loader("tstzrange", TimestampTZRangeBinaryLoader)

View File

@ -0,0 +1,90 @@
"""
Adapters for PostGIS geometries
"""
from typing import Optional, Type
from .. import postgres
from ..abc import AdaptContext, Buffer
from ..adapt import Dumper, Loader
from ..pq import Format
from .._compat import cache
from .._typeinfo import TypeInfo
try:
from shapely.wkb import loads, dumps
from shapely.geometry.base import BaseGeometry
except ImportError:
raise ImportError(
"The module psycopg.types.shapely requires the package 'Shapely'"
" to be installed"
)
class GeometryBinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> "BaseGeometry":
if not isinstance(data, bytes):
data = bytes(data)
return loads(data)
class GeometryLoader(Loader):
def load(self, data: Buffer) -> "BaseGeometry":
# the text representation is a hex string encoding the binary (E)WKB
if isinstance(data, memoryview):
data = bytes(data)
return loads(data.decode(), hex=True)
class BaseGeometryBinaryDumper(Dumper):
format = Format.BINARY
def dump(self, obj: "BaseGeometry") -> bytes:
return dumps(obj) # type: ignore
class BaseGeometryDumper(Dumper):
def dump(self, obj: "BaseGeometry") -> bytes:
return dumps(obj, hex=True).encode() # type: ignore
def register_shapely(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
"""Register Shapely dumper and loaders."""
# A friendly error warning instead of an AttributeError in case fetch()
# failed and it wasn't noticed.
if not info:
raise TypeError("no info passed. Is the 'postgis' extension loaded?")
info.register(context)
adapters = context.adapters if context else postgres.adapters
adapters.register_loader(info.oid, GeometryBinaryLoader)
adapters.register_loader(info.oid, GeometryLoader)
# Default binary dump
adapters.register_dumper(BaseGeometry, _make_dumper(info.oid))
adapters.register_dumper(BaseGeometry, _make_binary_dumper(info.oid))
# Cache all dynamically-generated types to avoid leaks in case the types
# cannot be GC'd.
@cache
def _make_dumper(oid_in: int) -> Type[BaseGeometryDumper]:
class GeometryDumper(BaseGeometryDumper):
oid = oid_in
return GeometryDumper
@cache
def _make_binary_dumper(oid_in: int) -> Type[BaseGeometryBinaryDumper]:
class GeometryBinaryDumper(BaseGeometryBinaryDumper):
oid = oid_in
return GeometryBinaryDumper
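# Usage sketch (illustrative; assumes PostGIS is installed and "conn" is an
# open connection):
#
#     from psycopg.types import TypeInfo
#     info = TypeInfo.fetch(conn, "geometry")
#     register_shapely(info, conn)
#     cur = conn.execute("SELECT 'POINT(1 2)'::geometry")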

View File

@ -0,0 +1,229 @@
"""
Adapters for textual types.
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Optional, Union, TYPE_CHECKING
from .. import postgres
from ..pq import Format, Escaping
from ..abc import AdaptContext
from ..adapt import Buffer, Dumper, Loader
from ..errors import DataError
from .._encodings import conn_encoding
if TYPE_CHECKING:
from ..pq.abc import Escaping as EscapingProto
class _BaseStrDumper(Dumper):
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
super().__init__(cls, context)
enc = conn_encoding(self.connection)
self._encoding = enc if enc != "ascii" else "utf-8"
class _StrBinaryDumper(_BaseStrDumper):
"""
Base class to dump Python strings to a Postgres text type, in binary format.
Subclasses shall specify the oids of real types (text, varchar, name...).
"""
format = Format.BINARY
def dump(self, obj: str) -> bytes:
# the server will raise a DataError subclass if the string contains 0x00
return obj.encode(self._encoding)
class _StrDumper(_BaseStrDumper):
"""
Base class to dump Python strings to a Postgres text type, in text format.
Subclasses shall specify the oids of real types (text, varchar, name...).
"""
def dump(self, obj: str) -> bytes:
if "\x00" in obj:
raise DataError("PostgreSQL text fields cannot contain NUL (0x00) bytes")
else:
return obj.encode(self._encoding)
# The next are concrete dumpers, each one specifying the oid they dump to.
class StrBinaryDumper(_StrBinaryDumper):
oid = postgres.types["text"].oid
class StrBinaryDumperVarchar(_StrBinaryDumper):
oid = postgres.types["varchar"].oid
class StrBinaryDumperName(_StrBinaryDumper):
oid = postgres.types["name"].oid
class StrDumper(_StrDumper):
"""
Dumper for strings in text format to the text oid.
Note that this dumper is not used by default because the type is too strict
and PostgreSQL would require an explicit cast to everything that is not a
text field. However, it is useful where the unknown oid is ambiguous and the
text oid is required, for instance with variadic functions.
"""
oid = postgres.types["text"].oid
class StrDumperVarchar(_StrDumper):
oid = postgres.types["varchar"].oid
class StrDumperName(_StrDumper):
oid = postgres.types["name"].oid
class StrDumperUnknown(_StrDumper):
"""
Dumper for strings in text format to the unknown oid.
This dumper is the default dumper for strings and makes it possible to use
Python strings to represent almost every data type. In a few places, however,
the unknown oid is not accepted (for instance in variadic functions such as
'concat()'). In that case either a cast on the placeholder ('%s::text') or
the StrDumper above should be used.
"""
pass
class TextLoader(Loader):
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
enc = conn_encoding(self.connection)
self._encoding = enc if enc != "ascii" else ""
def load(self, data: Buffer) -> Union[bytes, str]:
if self._encoding:
if isinstance(data, memoryview):
data = bytes(data)
return data.decode(self._encoding)
else:
# return bytes for SQL_ASCII db
if not isinstance(data, bytes):
data = bytes(data)
return data
class TextBinaryLoader(TextLoader):
format = Format.BINARY
class BytesDumper(Dumper):
oid = postgres.types["bytea"].oid
_qprefix = b""
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
super().__init__(cls, context)
self._esc = Escaping(self.connection.pgconn if self.connection else None)
def dump(self, obj: Buffer) -> Buffer:
return self._esc.escape_bytea(obj)
def quote(self, obj: Buffer) -> bytes:
escaped = self.dump(obj)
# We cannot use the base quoting because escape_bytea already returns
# the quoted content. If standard_conforming_strings is off it escapes the
# backslashes in the output, otherwise it doesn't, but either way it doesn't
# tell us which quotes to use.
if self.connection:
if not self._qprefix:
scs = self.connection.pgconn.parameter_status(
b"standard_conforming_strings"
)
self._qprefix = b"'" if scs == b"on" else b" E'"
return self._qprefix + escaped + b"'"
# We don't have a connection, so someone is using us to generate a file
# to use off-line or something like that. With PQescapeBytea, as with its
# string counterpart, it is not predictable whether backslashes will be
# escaped.
rv: bytes = b" E'" + escaped + b"'"
if self._esc.escape_bytea(b"\x00") == b"\\000":
rv = rv.replace(b"\\", b"\\\\")
return rv
class BytesBinaryDumper(Dumper):
format = Format.BINARY
oid = postgres.types["bytea"].oid
def dump(self, obj: Buffer) -> Buffer:
return obj
class ByteaLoader(Loader):
_escaping: "EscapingProto"
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
if not hasattr(self.__class__, "_escaping"):
self.__class__._escaping = Escaping()
def load(self, data: Buffer) -> bytes:
return self._escaping.unescape_bytea(data)
class ByteaBinaryLoader(Loader):
format = Format.BINARY
def load(self, data: Buffer) -> Buffer:
return data
def register_default_adapters(context: AdaptContext) -> None:
adapters = context.adapters
# NOTE: the order the dumpers are registered is relevant. The last one
# registered becomes the default for each type. Usually, binary is the
# default dumper. For text we use the text dumper as default because it
# plays the role of unknown, and it can be cast automatically to other
# types. However, before that, we register dumpers with the 'text', 'varchar'
# and 'name' oids, which will be used when a text dumper is looked up by oid.
adapters.register_dumper(str, StrBinaryDumperName)
adapters.register_dumper(str, StrBinaryDumperVarchar)
adapters.register_dumper(str, StrBinaryDumper)
adapters.register_dumper(str, StrDumperName)
adapters.register_dumper(str, StrDumperVarchar)
adapters.register_dumper(str, StrDumper)
adapters.register_dumper(str, StrDumperUnknown)
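# As a result, a plain str parameter defaults to StrDumperUnknown in text
# format and to StrBinaryDumper (the text oid) in binary format; the other
# variants above are only selected when a dumper is looked up by a specific
# oid.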
adapters.register_loader(postgres.INVALID_OID, TextLoader)
adapters.register_loader("bpchar", TextLoader)
adapters.register_loader("name", TextLoader)
adapters.register_loader("text", TextLoader)
adapters.register_loader("varchar", TextLoader)
adapters.register_loader('"char"', TextLoader)
adapters.register_loader("bpchar", TextBinaryLoader)
adapters.register_loader("name", TextBinaryLoader)
adapters.register_loader("text", TextBinaryLoader)
adapters.register_loader("varchar", TextBinaryLoader)
adapters.register_loader('"char"', TextBinaryLoader)
adapters.register_dumper(bytes, BytesDumper)
adapters.register_dumper(bytearray, BytesDumper)
adapters.register_dumper(memoryview, BytesDumper)
adapters.register_dumper(bytes, BytesBinaryDumper)
adapters.register_dumper(bytearray, BytesBinaryDumper)
adapters.register_dumper(memoryview, BytesBinaryDumper)
adapters.register_loader("bytea", ByteaLoader)
adapters.register_loader(postgres.INVALID_OID, ByteaBinaryLoader)
adapters.register_loader("bytea", ByteaBinaryLoader)

View File

@ -0,0 +1,62 @@
"""
Adapters for the UUID type.
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Callable, Optional, TYPE_CHECKING
from .. import postgres
from ..pq import Format
from ..abc import AdaptContext
from ..adapt import Buffer, Dumper, Loader
if TYPE_CHECKING:
import uuid
# Importing the uuid module is slow, so import it only on request.
UUID: Callable[..., "uuid.UUID"] = None # type: ignore[assignment]
class UUIDDumper(Dumper):
oid = postgres.types["uuid"].oid
def dump(self, obj: "uuid.UUID") -> bytes:
return obj.hex.encode()
class UUIDBinaryDumper(UUIDDumper):
format = Format.BINARY
def dump(self, obj: "uuid.UUID") -> bytes:
return obj.bytes
class UUIDLoader(Loader):
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
global UUID
if UUID is None:
from uuid import UUID
def load(self, data: Buffer) -> "uuid.UUID":
if isinstance(data, memoryview):
data = bytes(data)
return UUID(data.decode())
class UUIDBinaryLoader(UUIDLoader):
format = Format.BINARY
def load(self, data: Buffer) -> "uuid.UUID":
if isinstance(data, memoryview):
data = bytes(data)
return UUID(bytes=data)
def register_default_adapters(context: AdaptContext) -> None:
adapters = context.adapters
adapters.register_dumper("uuid.UUID", UUIDDumper)
adapters.register_dumper("uuid.UUID", UUIDBinaryDumper)
adapters.register_loader("uuid", UUIDLoader)
adapters.register_loader("uuid", UUIDBinaryLoader)

View File

@ -0,0 +1,14 @@
"""
psycopg distribution version file.
"""
# Copyright (C) 2020 The Psycopg Team
# Use a versioning scheme as defined in
# https://www.python.org/dev/peps/pep-0440/
# STOP AND READ! if you change:
__version__ = "3.1.13"
# also change:
# - `docs/news.rst` to declare this as the current version or an unreleased one
# - `psycopg_c/psycopg_c/version.py` to the same version.

View File

@ -0,0 +1,393 @@
"""
Code concerned with waiting in different contexts (blocking, async, etc).
These functions are designed to consume the generators returned by the
`generators` module function and to return their final value.
"""
# Copyright (C) 2020 The Psycopg Team
import os
import sys
import select
import selectors
from typing import Optional
from asyncio import get_event_loop, wait_for, Event, TimeoutError
from selectors import DefaultSelector
from . import errors as e
from .abc import RV, PQGen, PQGenConn, WaitFunc
from ._enums import Wait as Wait, Ready as Ready # re-exported
from ._cmodule import _psycopg
WAIT_R = Wait.R
WAIT_W = Wait.W
WAIT_RW = Wait.RW
READY_R = Ready.R
READY_W = Ready.W
READY_RW = Ready.RW
def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
"""
Wait for a generator using the best strategy available.
:param gen: a generator performing database operations and yielding
`Ready` values when it would block.
:param fileno: the file descriptor to wait on.
:param timeout: timeout (in seconds) to check for other interrupt, e.g.
to allow Ctrl-C.
:type timeout: float
:return: whatever `!gen` returns on completion.
Consume `!gen`, scheduling `fileno` for completion when it is reported to
block. Once ready again send the ready state back to `!gen`.
"""
try:
s = next(gen)
with DefaultSelector() as sel:
while True:
sel.register(fileno, s)
rlist = None
while not rlist:
rlist = sel.select(timeout=timeout)
sel.unregister(fileno)
# note: this line should require a cast, but mypy doesn't complain
ready: Ready = rlist[0][1]
assert s & ready
s = gen.send(ready)
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
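To make the generator protocol concrete, a toy sketch (illustrative only; `recv_when_ready` is an invented helper, not psycopg internals): the generator yields `Wait` states, receives `Ready` states back through `send()`, and its return value is recovered from `StopIteration`.

import socket

from psycopg.waiting import Ready, Wait, wait_selector

def recv_when_ready(sock: socket.socket):
    # Toy PQGen-style generator: ask to be resumed when readable, then read.
    ready = yield Wait.R
    assert ready & Ready.R
    return sock.recv(1024)

a, b = socket.socketpair()
b.send(b"ping")
# wait_selector drives the generator until it returns.
assert wait_selector(recv_when_ready(a), a.fileno()) == b"ping"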
def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
"""
Wait for a connection generator using the best strategy available.
:param gen: a generator performing database operations and yielding
(fd, `Ready`) pairs when it would block.
:param timeout: timeout (in seconds) used to periodically check for other
interrupts, e.g. to allow Ctrl-C. If zero or None, wait indefinitely.
:type timeout: float
:return: whatever `!gen` returns on completion.
Behave like `wait()`, but take the fileno to wait on from the generator
itself; it might change during processing.
"""
try:
fileno, s = next(gen)
if not timeout:
timeout = None
with DefaultSelector() as sel:
while True:
sel.register(fileno, s)
rlist = sel.select(timeout=timeout)
sel.unregister(fileno)
if not rlist:
raise e.ConnectionTimeout("connection timeout expired")
ready: Ready = rlist[0][1] # type: ignore[assignment]
fileno, s = gen.send(ready)
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
async def wait_async(
gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
) -> RV:
"""
Coroutine waiting for a generator to complete.
:param gen: a generator performing database operations and yielding
`Ready` values when it would block.
:param fileno: the file descriptor to wait on.
:return: whatever `!gen` returns on completion.
Behave like `wait()`, but expose an `asyncio` interface.
"""
# Use an event to block and restart after the fd state changes.
# Not sure this is the best implementation but it's a start.
ev = Event()
loop = get_event_loop()
ready: Ready
s: Wait
def wakeup(state: Ready) -> None:
nonlocal ready
ready |= state # type: ignore[assignment]
ev.set()
try:
s = next(gen)
while True:
reader = s & WAIT_R
writer = s & WAIT_W
if not reader and not writer:
raise e.InternalError(f"bad poll status: {s}")
ev.clear()
ready = 0 # type: ignore[assignment]
if reader:
loop.add_reader(fileno, wakeup, READY_R)
if writer:
loop.add_writer(fileno, wakeup, READY_W)
try:
if timeout is None:
await ev.wait()
else:
try:
await wait_for(ev.wait(), timeout)
except TimeoutError:
pass
finally:
if reader:
loop.remove_reader(fileno)
if writer:
loop.remove_writer(fileno)
s = gen.send(ready)
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
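The same toy protocol can also be driven by the coroutine above; a minimal sketch (illustrative only; it assumes a Unix selector event loop, since it relies on `add_reader`).

import asyncio
import socket

from psycopg.waiting import Wait, wait_async

def recv_when_ready(sock: socket.socket):
    yield Wait.R  # resumed once the descriptor is readable
    return sock.recv(1024)

async def main() -> bytes:
    a, b = socket.socketpair()
    b.send(b"ping")
    return await wait_async(recv_when_ready(a), a.fileno())

assert asyncio.run(main()) == b"ping"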
async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
"""
Coroutine waiting for a connection generator to complete.
:param gen: a generator performing database operations and yielding
(fd, `Ready`) pairs when it would block.
:param timeout: timeout (in seconds) used to periodically check for other
interrupts, e.g. to allow Ctrl-C. If zero or None, wait indefinitely.
:return: whatever `!gen` returns on completion.
Behave like `wait()`, but take the fileno to wait on from the generator
itself; it might change during processing.
"""
# Use an event to block and restart after the fd state changes.
# Not sure this is the best implementation but it's a start.
ev = Event()
loop = get_event_loop()
ready: Ready
s: Wait
def wakeup(state: Ready) -> None:
nonlocal ready
ready = state
ev.set()
try:
fileno, s = next(gen)
if not timeout:
timeout = None
while True:
reader = s & WAIT_R
writer = s & WAIT_W
if not reader and not writer:
raise e.InternalError(f"bad poll status: {s}")
ev.clear()
ready = 0 # type: ignore[assignment]
if reader:
loop.add_reader(fileno, wakeup, READY_R)
if writer:
loop.add_writer(fileno, wakeup, READY_W)
try:
await wait_for(ev.wait(), timeout)
finally:
if reader:
loop.remove_reader(fileno)
if writer:
loop.remove_writer(fileno)
fileno, s = gen.send(ready)
except TimeoutError:
raise e.ConnectionTimeout("connection timeout expired")
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
# Specialised implementation of wait functions.
def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
"""
Wait for a generator using select where supported.
BUG: on Linux, can't select on FD >= 1024. On Windows it's fine.
"""
try:
s = next(gen)
empty = ()
fnlist = (fileno,)
while True:
rl, wl, xl = select.select(
fnlist if s & WAIT_R else empty,
fnlist if s & WAIT_W else empty,
fnlist,
timeout,
)
ready = 0
if rl:
ready = READY_R
if wl:
ready |= READY_W
if not ready:
continue
# assert s & ready
s = gen.send(ready) # type: ignore
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
if hasattr(selectors, "EpollSelector"):
_epoll_evmasks = {
WAIT_R: select.EPOLLONESHOT | select.EPOLLIN | select.EPOLLERR,
WAIT_W: select.EPOLLONESHOT | select.EPOLLOUT | select.EPOLLERR,
WAIT_RW: select.EPOLLONESHOT
| (select.EPOLLIN | select.EPOLLOUT | select.EPOLLERR),
}
else:
_epoll_evmasks = {}
def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
"""
Wait for a generator using epoll where supported.
Parameters are like for `wait()`. If it is detected that the best selector
strategy is `epoll` then this function will be used instead of `wait`.
See also: https://linux.die.net/man/2/epoll_ctl
BUG: if the connection FD is closed, `epoll.poll()` hangs. Same for
EpollSelector. For this reason, wait_poll() is currently preferable.
To reproduce the bug:
export PSYCOPG_WAIT_FUNC=wait_epoll
pytest tests/test_concurrency.py::test_concurrent_close
"""
try:
s = next(gen)
if timeout is None or timeout < 0:
timeout = 0
else:
timeout = int(timeout * 1000.0)
with select.epoll() as epoll:
evmask = _epoll_evmasks[s]
epoll.register(fileno, evmask)
while True:
fileevs = None
while not fileevs:
fileevs = epoll.poll(timeout)
ev = fileevs[0][1]
ready = 0
if ev & ~select.EPOLLOUT:
ready = READY_R
if ev & ~select.EPOLLIN:
ready |= READY_W
# assert s & ready
s = gen.send(ready)
evmask = _epoll_evmasks[s]
epoll.modify(fileno, evmask)
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
if hasattr(selectors, "PollSelector"):
_poll_evmasks = {
WAIT_R: select.POLLIN,
WAIT_W: select.POLLOUT,
WAIT_RW: select.POLLIN | select.POLLOUT,
}
else:
_poll_evmasks = {}
def wait_poll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
"""
Wait for a generator using poll where supported.
Parameters are like for `wait()`.
"""
try:
s = next(gen)
if timeout is None or timeout < 0:
timeout = 0
else:
timeout = int(timeout * 1000.0)
poll = select.poll()
evmask = _poll_evmasks[s]
poll.register(fileno, evmask)
while True:
fileevs = None
while not fileevs:
fileevs = poll.poll(timeout)
ev = fileevs[0][1]
ready = 0
if ev & ~select.POLLOUT:
ready = READY_R
if ev & ~select.POLLIN:
ready |= READY_W
# assert s & ready
s = gen.send(ready)
evmask = _poll_evmasks[s]
poll.modify(fileno, evmask)
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
if _psycopg:
wait_c = _psycopg.wait_c
# Choose the best wait strategy for the platform.
#
# the selectors objects have a generic interface but come with some overhead,
# so we also offer more finely tuned implementations.
wait: WaitFunc
# Allow the user to choose a specific function for testing
if "PSYCOPG_WAIT_FUNC" in os.environ:
fname = os.environ["PSYCOPG_WAIT_FUNC"]
if not fname.startswith("wait_") or fname not in globals():
raise ImportError(
"PSYCOPG_WAIT_FUNC should be the name of an available wait function;"
f" got {fname!r}"
)
wait = globals()[fname]
# On Windows, for the moment, avoid using wait_c, because it was reported to
# use excessive CPU (see #645).
# TODO: investigate why.
elif _psycopg and sys.platform != "win32":
wait = wait_c
elif selectors.DefaultSelector is getattr(selectors, "SelectSelector", None):
# On Windows, SelectSelector should be the default.
wait = wait_select
elif hasattr(selectors, "PollSelector"):
# On linux, EpollSelector is the default. However, it hangs if the fd is
# closed while polling.
wait = wait_poll
else:
wait = wait_selector
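Because the strategy is resolved at import time, overriding it via PSYCOPG_WAIT_FUNC has to happen before psycopg is first imported; a small sketch (assuming a platform where poll is available):

import os

os.environ["PSYCOPG_WAIT_FUNC"] = "wait_poll"  # must be set before importing psycopg

import psycopg  # noqa: E402  (imported after the environment tweak on purpose)
from psycopg import waiting

assert waiting.wait is waiting.wait_poll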