docker setup
This commit is contained in:
110
srcs/.venv/lib/python3.11/site-packages/psycopg/__init__.py
Normal file
110
srcs/.venv/lib/python3.11/site-packages/psycopg/__init__.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
psycopg -- PostgreSQL database adapter for Python
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import logging
|
||||
|
||||
from . import pq # noqa: F401 import early to stabilize side effects
|
||||
from . import types
|
||||
from . import postgres
|
||||
from ._tpc import Xid
|
||||
from .copy import Copy, AsyncCopy
|
||||
from ._enums import IsolationLevel
|
||||
from .cursor import Cursor
|
||||
from .errors import Warning, Error, InterfaceError, DatabaseError
|
||||
from .errors import DataError, OperationalError, IntegrityError
|
||||
from .errors import InternalError, ProgrammingError, NotSupportedError
|
||||
from ._column import Column
|
||||
from .conninfo import ConnectionInfo
|
||||
from ._pipeline import Pipeline, AsyncPipeline
|
||||
from .connection import BaseConnection, Connection, Notify
|
||||
from .transaction import Rollback, Transaction, AsyncTransaction
|
||||
from .cursor_async import AsyncCursor
|
||||
from .server_cursor import AsyncServerCursor, ServerCursor
|
||||
from .client_cursor import AsyncClientCursor, ClientCursor
|
||||
from .connection_async import AsyncConnection
|
||||
|
||||
from . import dbapi20
|
||||
from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
|
||||
from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks
|
||||
from .dbapi20 import Timestamp, TimestampFromTicks
|
||||
|
||||
from .version import __version__ as __version__ # noqa: F401
|
||||
|
||||
# Set the logger to a quiet default, can be enabled if needed
|
||||
logger = logging.getLogger("psycopg")
|
||||
if logger.level == logging.NOTSET:
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
# DBAPI compliance
|
||||
connect = Connection.connect
|
||||
apilevel = "2.0"
|
||||
threadsafety = 2
|
||||
paramstyle = "pyformat"
|
||||
|
||||
# register default adapters for PostgreSQL
|
||||
adapters = postgres.adapters # exposed by the package
|
||||
postgres.register_default_adapters(adapters)
|
||||
|
||||
# After the default ones, because these can deal with the bytea oid better
|
||||
dbapi20.register_dbapi20_adapters(adapters)
|
||||
|
||||
# Must come after all the types have been registered
|
||||
types.array.register_all_arrays(adapters)
|
||||
|
||||
# Note: defining the exported methods helps both Sphynx in documenting that
|
||||
# this is the canonical place to obtain them and should be used by MyPy too,
|
||||
# so that function signatures are consistent with the documentation.
|
||||
__all__ = [
|
||||
"AsyncClientCursor",
|
||||
"AsyncConnection",
|
||||
"AsyncCopy",
|
||||
"AsyncCursor",
|
||||
"AsyncPipeline",
|
||||
"AsyncServerCursor",
|
||||
"AsyncTransaction",
|
||||
"BaseConnection",
|
||||
"ClientCursor",
|
||||
"Column",
|
||||
"Connection",
|
||||
"ConnectionInfo",
|
||||
"Copy",
|
||||
"Cursor",
|
||||
"IsolationLevel",
|
||||
"Notify",
|
||||
"Pipeline",
|
||||
"Rollback",
|
||||
"ServerCursor",
|
||||
"Transaction",
|
||||
"Xid",
|
||||
# DBAPI exports
|
||||
"connect",
|
||||
"apilevel",
|
||||
"threadsafety",
|
||||
"paramstyle",
|
||||
"Warning",
|
||||
"Error",
|
||||
"InterfaceError",
|
||||
"DatabaseError",
|
||||
"DataError",
|
||||
"OperationalError",
|
||||
"IntegrityError",
|
||||
"InternalError",
|
||||
"ProgrammingError",
|
||||
"NotSupportedError",
|
||||
# DBAPI type constructors and singletons
|
||||
"Binary",
|
||||
"Date",
|
||||
"DateFromTicks",
|
||||
"Time",
|
||||
"TimeFromTicks",
|
||||
"Timestamp",
|
||||
"TimestampFromTicks",
|
||||
"BINARY",
|
||||
"DATETIME",
|
||||
"NUMBER",
|
||||
"ROWID",
|
||||
"STRING",
|
||||
]
|
||||
296
srcs/.venv/lib/python3.11/site-packages/psycopg/_adapters_map.py
Normal file
296
srcs/.venv/lib/python3.11/site-packages/psycopg/_adapters_map.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
Mapping from types/oids to Dumpers/Loaders
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
||||
from typing import cast, TYPE_CHECKING
|
||||
|
||||
from . import pq
|
||||
from . import errors as e
|
||||
from .abc import Dumper, Loader
|
||||
from ._enums import PyFormat as PyFormat
|
||||
from ._cmodule import _psycopg
|
||||
from ._typeinfo import TypesRegistry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .connection import BaseConnection
|
||||
|
||||
RV = TypeVar("RV")
|
||||
|
||||
|
||||
class AdaptersMap:
|
||||
r"""
|
||||
Establish how types should be converted between Python and PostgreSQL in
|
||||
an `~psycopg.abc.AdaptContext`.
|
||||
|
||||
`!AdaptersMap` maps Python types to `~psycopg.adapt.Dumper` classes to
|
||||
define how Python types are converted to PostgreSQL, and maps OIDs to
|
||||
`~psycopg.adapt.Loader` classes to establish how query results are
|
||||
converted to Python.
|
||||
|
||||
Every `!AdaptContext` object has an underlying `!AdaptersMap` defining how
|
||||
types are converted in that context, exposed as the
|
||||
`~psycopg.abc.AdaptContext.adapters` attribute: changing such map allows
|
||||
to customise adaptation in a context without changing separated contexts.
|
||||
|
||||
When a context is created from another context (for instance when a
|
||||
`~psycopg.Cursor` is created from a `~psycopg.Connection`), the parent's
|
||||
`!adapters` are used as template for the child's `!adapters`, so that every
|
||||
cursor created from the same connection use the connection's types
|
||||
configuration, but separate connections have independent mappings.
|
||||
|
||||
Once created, `!AdaptersMap` are independent. This means that objects
|
||||
already created are not affected if a wider scope (e.g. the global one) is
|
||||
changed.
|
||||
|
||||
The connections adapters are initialised using a global `!AdptersMap`
|
||||
template, exposed as `psycopg.adapters`: changing such mapping allows to
|
||||
customise the type mapping for every connections created afterwards.
|
||||
|
||||
The object can start empty or copy from another object of the same class.
|
||||
Copies are copy-on-write: if the maps are updated make a copy. This way
|
||||
extending e.g. global map by a connection or a connection map from a cursor
|
||||
is cheap: a copy is only made on customisation.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.adapt"
|
||||
|
||||
types: TypesRegistry
|
||||
|
||||
_dumpers: Dict[PyFormat, Dict[Union[type, str], Type[Dumper]]]
|
||||
_dumpers_by_oid: List[Dict[int, Type[Dumper]]]
|
||||
_loaders: List[Dict[int, Type[Loader]]]
|
||||
|
||||
# Record if a dumper or loader has an optimised version.
|
||||
_optimised: Dict[type, type] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
template: Optional["AdaptersMap"] = None,
|
||||
types: Optional[TypesRegistry] = None,
|
||||
):
|
||||
if template:
|
||||
self._dumpers = template._dumpers.copy()
|
||||
self._own_dumpers = _dumpers_shared.copy()
|
||||
template._own_dumpers = _dumpers_shared.copy()
|
||||
|
||||
self._dumpers_by_oid = template._dumpers_by_oid[:]
|
||||
self._own_dumpers_by_oid = [False, False]
|
||||
template._own_dumpers_by_oid = [False, False]
|
||||
|
||||
self._loaders = template._loaders[:]
|
||||
self._own_loaders = [False, False]
|
||||
template._own_loaders = [False, False]
|
||||
|
||||
self.types = TypesRegistry(template.types)
|
||||
|
||||
else:
|
||||
self._dumpers = {fmt: {} for fmt in PyFormat}
|
||||
self._own_dumpers = _dumpers_owned.copy()
|
||||
|
||||
self._dumpers_by_oid = [{}, {}]
|
||||
self._own_dumpers_by_oid = [True, True]
|
||||
|
||||
self._loaders = [{}, {}]
|
||||
self._own_loaders = [True, True]
|
||||
|
||||
self.types = types or TypesRegistry()
|
||||
|
||||
# implement the AdaptContext protocol too
|
||||
@property
|
||||
def adapters(self) -> "AdaptersMap":
|
||||
return self
|
||||
|
||||
@property
|
||||
def connection(self) -> Optional["BaseConnection[Any]"]:
|
||||
return None
|
||||
|
||||
def register_dumper(
|
||||
self, cls: Union[type, str, None], dumper: Type[Dumper]
|
||||
) -> None:
|
||||
"""
|
||||
Configure the context to use `!dumper` to convert objects of type `!cls`.
|
||||
|
||||
If two dumpers with different `~Dumper.format` are registered for the
|
||||
same type, the last one registered will be chosen when the query
|
||||
doesn't specify a format (i.e. when the value is used with a ``%s``
|
||||
"`~PyFormat.AUTO`" placeholder).
|
||||
|
||||
:param cls: The type to manage.
|
||||
:param dumper: The dumper to register for `!cls`.
|
||||
|
||||
If `!cls` is specified as string it will be lazy-loaded, so that it
|
||||
will be possible to register it without importing it before. In this
|
||||
case it should be the fully qualified name of the object (e.g.
|
||||
``"uuid.UUID"``).
|
||||
|
||||
If `!cls` is None, only use the dumper when looking up using
|
||||
`get_dumper_by_oid()`, which happens when we know the Postgres type to
|
||||
adapt to, but not the Python type that will be adapted (e.g. in COPY
|
||||
after using `~psycopg.Copy.set_types()`).
|
||||
|
||||
"""
|
||||
if not (cls is None or isinstance(cls, (str, type))):
|
||||
raise TypeError(
|
||||
f"dumpers should be registered on classes, got {cls} instead"
|
||||
)
|
||||
|
||||
if _psycopg:
|
||||
dumper = self._get_optimised(dumper)
|
||||
|
||||
# Register the dumper both as its format and as auto
|
||||
# so that the last dumper registered is used in auto (%s) format
|
||||
if cls:
|
||||
for fmt in (PyFormat.from_pq(dumper.format), PyFormat.AUTO):
|
||||
if not self._own_dumpers[fmt]:
|
||||
self._dumpers[fmt] = self._dumpers[fmt].copy()
|
||||
self._own_dumpers[fmt] = True
|
||||
|
||||
self._dumpers[fmt][cls] = dumper
|
||||
|
||||
# Register the dumper by oid, if the oid of the dumper is fixed
|
||||
if dumper.oid:
|
||||
if not self._own_dumpers_by_oid[dumper.format]:
|
||||
self._dumpers_by_oid[dumper.format] = self._dumpers_by_oid[
|
||||
dumper.format
|
||||
].copy()
|
||||
self._own_dumpers_by_oid[dumper.format] = True
|
||||
|
||||
self._dumpers_by_oid[dumper.format][dumper.oid] = dumper
|
||||
|
||||
def register_loader(self, oid: Union[int, str], loader: Type["Loader"]) -> None:
|
||||
"""
|
||||
Configure the context to use `!loader` to convert data of oid `!oid`.
|
||||
|
||||
:param oid: The PostgreSQL OID or type name to manage.
|
||||
:param loader: The loar to register for `!oid`.
|
||||
|
||||
If `oid` is specified as string, it refers to a type name, which is
|
||||
looked up in the `types` registry. `
|
||||
|
||||
"""
|
||||
if isinstance(oid, str):
|
||||
oid = self.types[oid].oid
|
||||
if not isinstance(oid, int):
|
||||
raise TypeError(f"loaders should be registered on oid, got {oid} instead")
|
||||
|
||||
if _psycopg:
|
||||
loader = self._get_optimised(loader)
|
||||
|
||||
fmt = loader.format
|
||||
if not self._own_loaders[fmt]:
|
||||
self._loaders[fmt] = self._loaders[fmt].copy()
|
||||
self._own_loaders[fmt] = True
|
||||
|
||||
self._loaders[fmt][oid] = loader
|
||||
|
||||
def get_dumper(self, cls: type, format: PyFormat) -> Type["Dumper"]:
|
||||
"""
|
||||
Return the dumper class for the given type and format.
|
||||
|
||||
Raise `~psycopg.ProgrammingError` if a class is not available.
|
||||
|
||||
:param cls: The class to adapt.
|
||||
:param format: The format to dump to. If `~psycopg.adapt.PyFormat.AUTO`,
|
||||
use the last one of the dumpers registered on `!cls`.
|
||||
"""
|
||||
try:
|
||||
# Fast path: the class has a known dumper.
|
||||
return self._dumpers[format][cls]
|
||||
except KeyError:
|
||||
if format not in self._dumpers:
|
||||
raise ValueError(f"bad dumper format: {format}")
|
||||
|
||||
# If the KeyError was caused by cls missing from dmap, let's
|
||||
# look for different cases.
|
||||
dmap = self._dumpers[format]
|
||||
|
||||
# Look for the right class, including looking at superclasses
|
||||
for scls in cls.__mro__:
|
||||
if scls in dmap:
|
||||
return dmap[scls]
|
||||
|
||||
# If the adapter is not found, look for its name as a string
|
||||
fqn = scls.__module__ + "." + scls.__qualname__
|
||||
if fqn in dmap:
|
||||
# Replace the class name with the class itself
|
||||
d = dmap[scls] = dmap.pop(fqn)
|
||||
return d
|
||||
|
||||
format = PyFormat(format)
|
||||
raise e.ProgrammingError(
|
||||
f"cannot adapt type {cls.__name__!r} using placeholder '%{format.value}'"
|
||||
f" (format: {format.name})"
|
||||
)
|
||||
|
||||
def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Type["Dumper"]:
|
||||
"""
|
||||
Return the dumper class for the given oid and format.
|
||||
|
||||
Raise `~psycopg.ProgrammingError` if a class is not available.
|
||||
|
||||
:param oid: The oid of the type to dump to.
|
||||
:param format: The format to dump to.
|
||||
"""
|
||||
try:
|
||||
dmap = self._dumpers_by_oid[format]
|
||||
except KeyError:
|
||||
raise ValueError(f"bad dumper format: {format}")
|
||||
|
||||
try:
|
||||
return dmap[oid]
|
||||
except KeyError:
|
||||
info = self.types.get(oid)
|
||||
if info:
|
||||
msg = (
|
||||
f"cannot find a dumper for type {info.name} (oid {oid})"
|
||||
f" format {pq.Format(format).name}"
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
f"cannot find a dumper for unknown type with oid {oid}"
|
||||
f" format {pq.Format(format).name}"
|
||||
)
|
||||
raise e.ProgrammingError(msg)
|
||||
|
||||
def get_loader(self, oid: int, format: pq.Format) -> Optional[Type["Loader"]]:
|
||||
"""
|
||||
Return the loader class for the given oid and format.
|
||||
|
||||
Return `!None` if not found.
|
||||
|
||||
:param oid: The oid of the type to load.
|
||||
:param format: The format to load from.
|
||||
"""
|
||||
return self._loaders[format].get(oid)
|
||||
|
||||
@classmethod
|
||||
def _get_optimised(self, cls: Type[RV]) -> Type[RV]:
|
||||
"""Return the optimised version of a Dumper or Loader class.
|
||||
|
||||
Return the input class itself if there is no optimised version.
|
||||
"""
|
||||
try:
|
||||
return self._optimised[cls]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Check if the class comes from psycopg.types and there is a class
|
||||
# with the same name in psycopg_c._psycopg.
|
||||
from psycopg import types
|
||||
|
||||
if cls.__module__.startswith(types.__name__):
|
||||
new = cast(Type[RV], getattr(_psycopg, cls.__name__, None))
|
||||
if new:
|
||||
self._optimised[cls] = new
|
||||
return new
|
||||
|
||||
self._optimised[cls] = cls
|
||||
return cls
|
||||
|
||||
|
||||
# Micro-optimization: copying these objects is faster than creating new dicts
|
||||
_dumpers_owned = dict.fromkeys(PyFormat, True)
|
||||
_dumpers_shared = dict.fromkeys(PyFormat, False)
|
||||
24
srcs/.venv/lib/python3.11/site-packages/psycopg/_cmodule.py
Normal file
24
srcs/.venv/lib/python3.11/site-packages/psycopg/_cmodule.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Simplify access to the _psycopg module
|
||||
"""
|
||||
|
||||
# Copyright (C) 2021 The Psycopg Team
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from . import pq
|
||||
|
||||
__version__: Optional[str] = None
|
||||
|
||||
# Note: "c" must the first attempt so that mypy associates the variable the
|
||||
# right module interface. It will not result Optional, but hey.
|
||||
if pq.__impl__ == "c":
|
||||
from psycopg_c import _psycopg as _psycopg
|
||||
from psycopg_c import __version__ as __version__ # noqa: F401
|
||||
elif pq.__impl__ == "binary":
|
||||
from psycopg_binary import _psycopg as _psycopg # type: ignore
|
||||
from psycopg_binary import __version__ as __version__ # type: ignore # noqa: F401
|
||||
elif pq.__impl__ == "python":
|
||||
_psycopg = None # type: ignore
|
||||
else:
|
||||
raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}")
|
||||
142
srcs/.venv/lib/python3.11/site-packages/psycopg/_column.py
Normal file
142
srcs/.venv/lib/python3.11/site-packages/psycopg/_column.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
The Column object in Cursor.description
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING
|
||||
from operator import attrgetter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .cursor import BaseCursor
|
||||
|
||||
|
||||
class ColumnData(NamedTuple):
|
||||
ftype: int
|
||||
fmod: int
|
||||
fsize: int
|
||||
|
||||
|
||||
class Column(Sequence[Any]):
|
||||
__module__ = "psycopg"
|
||||
|
||||
def __init__(self, cursor: "BaseCursor[Any, Any]", index: int):
|
||||
res = cursor.pgresult
|
||||
assert res
|
||||
|
||||
fname = res.fname(index)
|
||||
if fname:
|
||||
self._name = fname.decode(cursor._encoding)
|
||||
else:
|
||||
# COPY_OUT results have columns but no name
|
||||
self._name = f"column_{index + 1}"
|
||||
|
||||
self._data = ColumnData(
|
||||
ftype=res.ftype(index),
|
||||
fmod=res.fmod(index),
|
||||
fsize=res.fsize(index),
|
||||
)
|
||||
self._type = cursor.adapters.types.get(self._data.ftype)
|
||||
|
||||
_attrs = tuple(
|
||||
attrgetter(attr)
|
||||
for attr in """
|
||||
name type_code display_size internal_size precision scale null_ok
|
||||
""".split()
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Column {self.name!r},"
|
||||
f" type: {self._type_display()} (oid: {self.type_code})>"
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return 7
|
||||
|
||||
def _type_display(self) -> str:
|
||||
parts = []
|
||||
parts.append(self._type.name if self._type else str(self.type_code))
|
||||
|
||||
mod1 = self.precision
|
||||
if mod1 is None:
|
||||
mod1 = self.display_size
|
||||
if mod1:
|
||||
parts.append(f"({mod1}")
|
||||
if self.scale:
|
||||
parts.append(f", {self.scale}")
|
||||
parts.append(")")
|
||||
|
||||
if self._type and self.type_code == self._type.array_oid:
|
||||
parts.append("[]")
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
def __getitem__(self, index: Any) -> Any:
|
||||
if isinstance(index, slice):
|
||||
return tuple(getter(self) for getter in self._attrs[index])
|
||||
else:
|
||||
return self._attrs[index](self)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the column."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def type_code(self) -> int:
|
||||
"""The numeric OID of the column."""
|
||||
return self._data.ftype
|
||||
|
||||
@property
|
||||
def display_size(self) -> Optional[int]:
|
||||
"""The field size, for :sql:`varchar(n)`, None otherwise."""
|
||||
if not self._type:
|
||||
return None
|
||||
|
||||
if self._type.name in ("varchar", "char"):
|
||||
fmod = self._data.fmod
|
||||
if fmod >= 0:
|
||||
return fmod - 4
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def internal_size(self) -> Optional[int]:
|
||||
"""The internal field size for fixed-size types, None otherwise."""
|
||||
fsize = self._data.fsize
|
||||
return fsize if fsize >= 0 else None
|
||||
|
||||
@property
|
||||
def precision(self) -> Optional[int]:
|
||||
"""The number of digits for fixed precision types."""
|
||||
if not self._type:
|
||||
return None
|
||||
|
||||
dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval")
|
||||
if self._type.name == "numeric":
|
||||
fmod = self._data.fmod
|
||||
if fmod >= 0:
|
||||
return fmod >> 16
|
||||
|
||||
elif self._type.name in dttypes:
|
||||
fmod = self._data.fmod
|
||||
if fmod >= 0:
|
||||
return fmod & 0xFFFF
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def scale(self) -> Optional[int]:
|
||||
"""The number of digits after the decimal point if available."""
|
||||
if self._type and self._type.name == "numeric":
|
||||
fmod = self._data.fmod - 4
|
||||
if fmod >= 0:
|
||||
return fmod & 0xFFFF
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def null_ok(self) -> Optional[bool]:
|
||||
"""Always `!None`"""
|
||||
return None
|
||||
72
srcs/.venv/lib/python3.11/site-packages/psycopg/_compat.py
Normal file
72
srcs/.venv/lib/python3.11/site-packages/psycopg/_compat.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
compatibility functions for different Python versions
|
||||
"""
|
||||
|
||||
# Copyright (C) 2021 The Psycopg Team
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
from typing import Any, Awaitable, Generator, Optional, Sequence, Union, TypeVar
|
||||
|
||||
# NOTE: TypeAlias cannot be exported by this module, as pyright special-cases it.
|
||||
# For this raisin it must be imported directly from typing_extension where used.
|
||||
# See https://github.com/microsoft/pyright/issues/4197
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Protocol
|
||||
else:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
T = TypeVar("T")
|
||||
FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]]
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
create_task = asyncio.create_task
|
||||
from math import prod
|
||||
|
||||
else:
|
||||
|
||||
def create_task(
|
||||
coro: FutureT[T], name: Optional[str] = None
|
||||
) -> "asyncio.Future[T]":
|
||||
return asyncio.create_task(coro)
|
||||
|
||||
from functools import reduce
|
||||
|
||||
def prod(seq: Sequence[int]) -> int:
|
||||
return reduce(int.__mul__, seq, 1)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from zoneinfo import ZoneInfo
|
||||
from functools import cache
|
||||
from collections import Counter, deque as Deque
|
||||
else:
|
||||
from typing import Counter, Deque
|
||||
from functools import lru_cache
|
||||
from backports.zoneinfo import ZoneInfo
|
||||
|
||||
cache = lru_cache(maxsize=None)
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeGuard
|
||||
else:
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import LiteralString
|
||||
else:
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
__all__ = [
|
||||
"Counter",
|
||||
"Deque",
|
||||
"LiteralString",
|
||||
"Protocol",
|
||||
"TypeGuard",
|
||||
"ZoneInfo",
|
||||
"cache",
|
||||
"create_task",
|
||||
"prod",
|
||||
]
|
||||
252
srcs/.venv/lib/python3.11/site-packages/psycopg/_dns.py
Normal file
252
srcs/.venv/lib/python3.11/site-packages/psycopg/_dns.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# type: ignore # dnspython is currently optional and mypy fails if missing
|
||||
"""
|
||||
DNS query support
|
||||
"""
|
||||
|
||||
# Copyright (C) 2021 The Psycopg Team
|
||||
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from random import randint
|
||||
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
from collections import defaultdict
|
||||
|
||||
try:
|
||||
from dns.resolver import Resolver, Cache
|
||||
from dns.asyncresolver import Resolver as AsyncResolver
|
||||
from dns.exception import DNSException
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"the module psycopg._dns requires the package 'dnspython' installed"
|
||||
)
|
||||
|
||||
from . import errors as e
|
||||
from . import conninfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dns.rdtypes.IN.SRV import SRV
|
||||
|
||||
resolver = Resolver()
|
||||
resolver.cache = Cache()
|
||||
|
||||
async_resolver = AsyncResolver()
|
||||
async_resolver.cache = Cache()
|
||||
|
||||
|
||||
async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Perform async DNS lookup of the hosts and return a new params dict.
|
||||
|
||||
.. deprecated:: 3.1
|
||||
The use of this function is not necessary anymore, because
|
||||
`psycopg.AsyncConnection.connect()` performs non-blocking name
|
||||
resolution automatically.
|
||||
"""
|
||||
warnings.warn(
|
||||
"from psycopg 3.1, resolve_hostaddr_async() is not needed anymore",
|
||||
DeprecationWarning,
|
||||
)
|
||||
hosts: list[str] = []
|
||||
hostaddrs: list[str] = []
|
||||
ports: list[str] = []
|
||||
|
||||
for attempt in conninfo._split_attempts(conninfo._inject_defaults(params)):
|
||||
try:
|
||||
async for a2 in conninfo._split_attempts_and_resolve(attempt):
|
||||
hosts.append(a2["host"])
|
||||
hostaddrs.append(a2["hostaddr"])
|
||||
if "port" in params:
|
||||
ports.append(a2["port"])
|
||||
except OSError as ex:
|
||||
last_exc = ex
|
||||
|
||||
if params.get("host") and not hosts:
|
||||
# We couldn't resolve anything
|
||||
raise e.OperationalError(str(last_exc))
|
||||
|
||||
out = params.copy()
|
||||
shosts = ",".join(hosts)
|
||||
if shosts:
|
||||
out["host"] = shosts
|
||||
shostaddrs = ",".join(hostaddrs)
|
||||
if shostaddrs:
|
||||
out["hostaddr"] = shostaddrs
|
||||
sports = ",".join(ports)
|
||||
if ports:
|
||||
out["port"] = sports
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Apply SRV DNS lookup as defined in :RFC:`2782`."""
|
||||
return Rfc2782Resolver().resolve(params)
|
||||
|
||||
|
||||
async def resolve_srv_async(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Async equivalent of `resolve_srv()`."""
|
||||
return await Rfc2782Resolver().resolve_async(params)
|
||||
|
||||
|
||||
class HostPort(NamedTuple):
|
||||
host: str
|
||||
port: str
|
||||
totry: bool = False
|
||||
target: Optional[str] = None
|
||||
|
||||
|
||||
class Rfc2782Resolver:
|
||||
"""Implement SRV RR Resolution as per RFC 2782
|
||||
|
||||
The class is organised to minimise code duplication between the sync and
|
||||
the async paths.
|
||||
"""
|
||||
|
||||
re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)")
|
||||
|
||||
def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Update the parameters host and port after SRV lookup."""
|
||||
attempts = self._get_attempts(params)
|
||||
if not attempts:
|
||||
return params
|
||||
|
||||
hps = []
|
||||
for hp in attempts:
|
||||
if hp.totry:
|
||||
hps.extend(self._resolve_srv(hp))
|
||||
else:
|
||||
hps.append(hp)
|
||||
|
||||
return self._return_params(params, hps)
|
||||
|
||||
async def resolve_async(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Update the parameters host and port after SRV lookup."""
|
||||
attempts = self._get_attempts(params)
|
||||
if not attempts:
|
||||
return params
|
||||
|
||||
hps = []
|
||||
for hp in attempts:
|
||||
if hp.totry:
|
||||
hps.extend(await self._resolve_srv_async(hp))
|
||||
else:
|
||||
hps.append(hp)
|
||||
|
||||
return self._return_params(params, hps)
|
||||
|
||||
def _get_attempts(self, params: Dict[str, Any]) -> List[HostPort]:
|
||||
"""
|
||||
Return the list of host, and for each host if SRV lookup must be tried.
|
||||
|
||||
Return an empty list if no lookup is requested.
|
||||
"""
|
||||
# If hostaddr is defined don't do any resolution.
|
||||
if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")):
|
||||
return []
|
||||
|
||||
host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
|
||||
hosts_in = host_arg.split(",")
|
||||
port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
|
||||
ports_in = port_arg.split(",")
|
||||
|
||||
if len(ports_in) == 1:
|
||||
# If only one port is specified, it applies to all the hosts.
|
||||
ports_in *= len(hosts_in)
|
||||
if len(ports_in) != len(hosts_in):
|
||||
# ProgrammingError would have been more appropriate, but this is
|
||||
# what the raise if the libpq fails connect in the same case.
|
||||
raise e.OperationalError(
|
||||
f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
|
||||
)
|
||||
|
||||
out = []
|
||||
srv_found = False
|
||||
for host, port in zip(hosts_in, ports_in):
|
||||
m = self.re_srv_rr.match(host)
|
||||
if m or port.lower() == "srv":
|
||||
srv_found = True
|
||||
target = m.group("target") if m else None
|
||||
hp = HostPort(host=host, port=port, totry=True, target=target)
|
||||
else:
|
||||
hp = HostPort(host=host, port=port)
|
||||
out.append(hp)
|
||||
|
||||
return out if srv_found else []
|
||||
|
||||
def _resolve_srv(self, hp: HostPort) -> List[HostPort]:
|
||||
try:
|
||||
ans = resolver.resolve(hp.host, "SRV")
|
||||
except DNSException:
|
||||
ans = ()
|
||||
return self._get_solved_entries(hp, ans)
|
||||
|
||||
async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]:
|
||||
try:
|
||||
ans = await async_resolver.resolve(hp.host, "SRV")
|
||||
except DNSException:
|
||||
ans = ()
|
||||
return self._get_solved_entries(hp, ans)
|
||||
|
||||
def _get_solved_entries(
|
||||
self, hp: HostPort, entries: "Sequence[SRV]"
|
||||
) -> List[HostPort]:
|
||||
if not entries:
|
||||
# No SRV entry found. Delegate the libpq a QNAME=target lookup
|
||||
if hp.target and hp.port.lower() != "srv":
|
||||
return [HostPort(host=hp.target, port=hp.port)]
|
||||
else:
|
||||
return []
|
||||
|
||||
# If there is precisely one SRV RR, and its Target is "." (the root
|
||||
# domain), abort.
|
||||
if len(entries) == 1 and str(entries[0].target) == ".":
|
||||
return []
|
||||
|
||||
return [
|
||||
HostPort(host=str(entry.target).rstrip("."), port=str(entry.port))
|
||||
for entry in self.sort_rfc2782(entries)
|
||||
]
|
||||
|
||||
def _return_params(
|
||||
self, params: Dict[str, Any], hps: List[HostPort]
|
||||
) -> Dict[str, Any]:
|
||||
if not hps:
|
||||
# Nothing found, we ended up with an empty list
|
||||
raise e.OperationalError("no host found after SRV RR lookup")
|
||||
|
||||
out = params.copy()
|
||||
out["host"] = ",".join(hp.host for hp in hps)
|
||||
out["port"] = ",".join(str(hp.port) for hp in hps)
|
||||
return out
|
||||
|
||||
def sort_rfc2782(self, ans: "Sequence[SRV]") -> "List[SRV]":
|
||||
"""
|
||||
Implement the priority/weight ordering defined in RFC 2782.
|
||||
"""
|
||||
# Divide the entries by priority:
|
||||
priorities: DefaultDict[int, "List[SRV]"] = defaultdict(list)
|
||||
out: "List[SRV]" = []
|
||||
for entry in ans:
|
||||
priorities[entry.priority].append(entry)
|
||||
|
||||
for pri, entries in sorted(priorities.items()):
|
||||
if len(entries) == 1:
|
||||
out.append(entries[0])
|
||||
continue
|
||||
|
||||
entries.sort(key=lambda ent: ent.weight)
|
||||
total_weight = sum(ent.weight for ent in entries)
|
||||
while entries:
|
||||
r = randint(0, total_weight)
|
||||
csum = 0
|
||||
for i, ent in enumerate(entries):
|
||||
csum += ent.weight
|
||||
if csum >= r:
|
||||
break
|
||||
out.append(ent)
|
||||
total_weight -= ent.weight
|
||||
del entries[i]
|
||||
|
||||
return out
|
||||
170
srcs/.venv/lib/python3.11/site-packages/psycopg/_encodings.py
Normal file
170
srcs/.venv/lib/python3.11/site-packages/psycopg/_encodings.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Mappings between PostgreSQL and Python encodings.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import re
|
||||
import string
|
||||
import codecs
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from .pq._enums import ConnStatus
|
||||
from .errors import NotSupportedError
|
||||
from ._compat import cache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pq.abc import PGconn
|
||||
from .connection import BaseConnection
|
||||
|
||||
OK = ConnStatus.OK
|
||||
|
||||
|
||||
_py_codecs = {
|
||||
"BIG5": "big5",
|
||||
"EUC_CN": "gb2312",
|
||||
"EUC_JIS_2004": "euc_jis_2004",
|
||||
"EUC_JP": "euc_jp",
|
||||
"EUC_KR": "euc_kr",
|
||||
# "EUC_TW": not available in Python
|
||||
"GB18030": "gb18030",
|
||||
"GBK": "gbk",
|
||||
"ISO_8859_5": "iso8859-5",
|
||||
"ISO_8859_6": "iso8859-6",
|
||||
"ISO_8859_7": "iso8859-7",
|
||||
"ISO_8859_8": "iso8859-8",
|
||||
"JOHAB": "johab",
|
||||
"KOI8R": "koi8-r",
|
||||
"KOI8U": "koi8-u",
|
||||
"LATIN1": "iso8859-1",
|
||||
"LATIN10": "iso8859-16",
|
||||
"LATIN2": "iso8859-2",
|
||||
"LATIN3": "iso8859-3",
|
||||
"LATIN4": "iso8859-4",
|
||||
"LATIN5": "iso8859-9",
|
||||
"LATIN6": "iso8859-10",
|
||||
"LATIN7": "iso8859-13",
|
||||
"LATIN8": "iso8859-14",
|
||||
"LATIN9": "iso8859-15",
|
||||
# "MULE_INTERNAL": not available in Python
|
||||
"SHIFT_JIS_2004": "shift_jis_2004",
|
||||
"SJIS": "shift_jis",
|
||||
# this actually means no encoding, see PostgreSQL docs
|
||||
# it is special-cased by the text loader.
|
||||
"SQL_ASCII": "ascii",
|
||||
"UHC": "cp949",
|
||||
"UTF8": "utf-8",
|
||||
"WIN1250": "cp1250",
|
||||
"WIN1251": "cp1251",
|
||||
"WIN1252": "cp1252",
|
||||
"WIN1253": "cp1253",
|
||||
"WIN1254": "cp1254",
|
||||
"WIN1255": "cp1255",
|
||||
"WIN1256": "cp1256",
|
||||
"WIN1257": "cp1257",
|
||||
"WIN1258": "cp1258",
|
||||
"WIN866": "cp866",
|
||||
"WIN874": "cp874",
|
||||
}
|
||||
|
||||
py_codecs: Dict[bytes, str] = {}
|
||||
py_codecs.update((k.encode(), v) for k, v in _py_codecs.items())
|
||||
|
||||
# Add an alias without underscore, for lenient lookups
|
||||
py_codecs.update(
|
||||
(k.replace("_", "").encode(), v) for k, v in _py_codecs.items() if "_" in k
|
||||
)
|
||||
|
||||
pg_codecs = {v: k.encode() for k, v in _py_codecs.items()}
|
||||
|
||||
|
||||
def conn_encoding(conn: "Optional[BaseConnection[Any]]") -> str:
|
||||
"""
|
||||
Return the Python encoding name of a psycopg connection.
|
||||
|
||||
Default to utf8 if the connection has no encoding info.
|
||||
"""
|
||||
if not conn or conn.closed:
|
||||
return "utf-8"
|
||||
|
||||
pgenc = conn.pgconn.parameter_status(b"client_encoding") or b"UTF8"
|
||||
return pg2pyenc(pgenc)
|
||||
|
||||
|
||||
def pgconn_encoding(pgconn: "PGconn") -> str:
|
||||
"""
|
||||
Return the Python encoding name of a libpq connection.
|
||||
|
||||
Default to utf8 if the connection has no encoding info.
|
||||
"""
|
||||
if pgconn.status != OK:
|
||||
return "utf-8"
|
||||
|
||||
pgenc = pgconn.parameter_status(b"client_encoding") or b"UTF8"
|
||||
return pg2pyenc(pgenc)
|
||||
|
||||
|
||||
def conninfo_encoding(conninfo: str) -> str:
|
||||
"""
|
||||
Return the Python encoding name passed in a conninfo string. Default to utf8.
|
||||
|
||||
Because the input is likely to come from the user and not normalised by the
|
||||
server, be somewhat lenient (non-case-sensitive lookup, ignore noise chars).
|
||||
"""
|
||||
from .conninfo import conninfo_to_dict
|
||||
|
||||
params = conninfo_to_dict(conninfo)
|
||||
pgenc = params.get("client_encoding")
|
||||
if pgenc:
|
||||
try:
|
||||
return pg2pyenc(pgenc.encode())
|
||||
except NotSupportedError:
|
||||
pass
|
||||
|
||||
return "utf-8"
|
||||
|
||||
|
||||
@cache
|
||||
def py2pgenc(name: str) -> bytes:
|
||||
"""Convert a Python encoding name to PostgreSQL encoding name.
|
||||
|
||||
Raise LookupError if the Python encoding is unknown.
|
||||
"""
|
||||
return pg_codecs[codecs.lookup(name).name]
|
||||
|
||||
|
||||
@cache
|
||||
def pg2pyenc(name: bytes) -> str:
|
||||
"""Convert a PostgreSQL encoding name to Python encoding name.
|
||||
|
||||
Raise NotSupportedError if the PostgreSQL encoding is not supported by
|
||||
Python.
|
||||
"""
|
||||
try:
|
||||
return py_codecs[name.replace(b"-", b"").replace(b"_", b"").upper()]
|
||||
except KeyError:
|
||||
sname = name.decode("utf8", "replace")
|
||||
raise NotSupportedError(f"codec not available in Python: {sname!r}")
|
||||
|
||||
|
||||
def _as_python_identifier(s: str, prefix: str = "f") -> str:
|
||||
"""
|
||||
Reduce a string to a valid Python identifier.
|
||||
|
||||
Replace all non-valid chars with '_' and prefix the value with `!prefix` if
|
||||
the first letter is an '_'.
|
||||
"""
|
||||
if not s.isidentifier():
|
||||
if s[0] in "1234567890":
|
||||
s = prefix + s
|
||||
if not s.isidentifier():
|
||||
s = _re_clean.sub("_", s)
|
||||
# namedtuple fields cannot start with underscore. So...
|
||||
if s[0] == "_":
|
||||
s = prefix + s
|
||||
return s
|
||||
|
||||
|
||||
_re_clean = re.compile(
|
||||
f"[^{string.ascii_lowercase}{string.ascii_uppercase}{string.digits}_]"
|
||||
)
|
||||
79
srcs/.venv/lib/python3.11/site-packages/psycopg/_enums.py
Normal file
79
srcs/.venv/lib/python3.11/site-packages/psycopg/_enums.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Enum values for psycopg
|
||||
|
||||
These values are defined by us and are not necessarily dependent on
|
||||
libpq-defined enums.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
from selectors import EVENT_READ, EVENT_WRITE
|
||||
|
||||
from . import pq
|
||||
|
||||
|
||||
class Wait(IntEnum):
|
||||
R = EVENT_READ
|
||||
W = EVENT_WRITE
|
||||
RW = EVENT_READ | EVENT_WRITE
|
||||
|
||||
|
||||
class Ready(IntEnum):
|
||||
R = EVENT_READ
|
||||
W = EVENT_WRITE
|
||||
RW = EVENT_READ | EVENT_WRITE
|
||||
|
||||
|
||||
class PyFormat(str, Enum):
|
||||
"""
|
||||
Enum representing the format wanted for a query argument.
|
||||
|
||||
The value `AUTO` allows psycopg to choose the best format for a certain
|
||||
parameter.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.adapt"
|
||||
|
||||
AUTO = "s"
|
||||
"""Automatically chosen (``%s`` placeholder)."""
|
||||
TEXT = "t"
|
||||
"""Text parameter (``%t`` placeholder)."""
|
||||
BINARY = "b"
|
||||
"""Binary parameter (``%b`` placeholder)."""
|
||||
|
||||
@classmethod
|
||||
def from_pq(cls, fmt: pq.Format) -> "PyFormat":
|
||||
return _pg2py[fmt]
|
||||
|
||||
@classmethod
|
||||
def as_pq(cls, fmt: "PyFormat") -> pq.Format:
|
||||
return _py2pg[fmt]
|
||||
|
||||
|
||||
class IsolationLevel(IntEnum):
|
||||
"""
|
||||
Enum representing the isolation level for a transaction.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg"
|
||||
|
||||
READ_UNCOMMITTED = 1
|
||||
""":sql:`READ UNCOMMITTED` isolation level."""
|
||||
READ_COMMITTED = 2
|
||||
""":sql:`READ COMMITTED` isolation level."""
|
||||
REPEATABLE_READ = 3
|
||||
""":sql:`REPEATABLE READ` isolation level."""
|
||||
SERIALIZABLE = 4
|
||||
""":sql:`SERIALIZABLE` isolation level."""
|
||||
|
||||
|
||||
_py2pg = {
|
||||
PyFormat.TEXT: pq.Format.TEXT,
|
||||
PyFormat.BINARY: pq.Format.BINARY,
|
||||
}
|
||||
|
||||
_pg2py = {
|
||||
pq.Format.TEXT: PyFormat.TEXT,
|
||||
pq.Format.BINARY: PyFormat.BINARY,
|
||||
}
|
||||
298
srcs/.venv/lib/python3.11/site-packages/psycopg/_pipeline.py
Normal file
298
srcs/.venv/lib/python3.11/site-packages/psycopg/_pipeline.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""
|
||||
commands pipeline management
|
||||
"""
|
||||
|
||||
# Copyright (C) 2021 The Psycopg Team
|
||||
|
||||
import logging
|
||||
from types import TracebackType
|
||||
from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from . import pq
|
||||
from . import errors as e
|
||||
from .abc import PipelineCommand, PQGen
|
||||
from ._compat import Deque
|
||||
from .pq.misc import connection_summary
|
||||
from ._encodings import pgconn_encoding
|
||||
from ._preparing import Key, Prepare
|
||||
from .generators import pipeline_communicate, fetch_many, send
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pq.abc import PGresult
|
||||
from .cursor import BaseCursor
|
||||
from .connection import BaseConnection, Connection
|
||||
from .connection_async import AsyncConnection
|
||||
|
||||
|
||||
PendingResult: TypeAlias = Union[
|
||||
None, Tuple["BaseCursor[Any, Any]", Optional[Tuple[Key, Prepare, bytes]]]
|
||||
]
|
||||
|
||||
FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
|
||||
PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
|
||||
BAD = pq.ConnStatus.BAD
|
||||
|
||||
ACTIVE = pq.TransactionStatus.ACTIVE
|
||||
|
||||
logger = logging.getLogger("psycopg")
|
||||
|
||||
|
||||
class BasePipeline:
|
||||
command_queue: Deque[PipelineCommand]
|
||||
result_queue: Deque[PendingResult]
|
||||
_is_supported: Optional[bool] = None
|
||||
|
||||
def __init__(self, conn: "BaseConnection[Any]") -> None:
|
||||
self._conn = conn
|
||||
self.pgconn = conn.pgconn
|
||||
self.command_queue = Deque[PipelineCommand]()
|
||||
self.result_queue = Deque[PendingResult]()
|
||||
self.level = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
|
||||
info = connection_summary(self._conn.pgconn)
|
||||
return f"<{cls} {info} at 0x{id(self):x}>"
|
||||
|
||||
@property
|
||||
def status(self) -> pq.PipelineStatus:
|
||||
return pq.PipelineStatus(self.pgconn.pipeline_status)
|
||||
|
||||
@classmethod
|
||||
def is_supported(cls) -> bool:
|
||||
"""Return `!True` if the psycopg libpq wrapper supports pipeline mode."""
|
||||
if BasePipeline._is_supported is None:
|
||||
BasePipeline._is_supported = not cls._not_supported_reason()
|
||||
return BasePipeline._is_supported
|
||||
|
||||
@classmethod
|
||||
def _not_supported_reason(cls) -> str:
|
||||
"""Return the reason why the pipeline mode is not supported.
|
||||
|
||||
Return an empty string if pipeline mode is supported.
|
||||
"""
|
||||
# Support only depends on the libpq functions available in the pq
|
||||
# wrapper, not on the database version.
|
||||
if pq.version() < 140000:
|
||||
return (
|
||||
f"libpq too old {pq.version()};"
|
||||
" v14 or greater required for pipeline mode"
|
||||
)
|
||||
|
||||
if pq.__build_version__ < 140000:
|
||||
return (
|
||||
f"libpq too old: module built for {pq.__build_version__};"
|
||||
" v14 or greater required for pipeline mode"
|
||||
)
|
||||
|
||||
return ""
|
||||
|
||||
def _enter_gen(self) -> PQGen[None]:
|
||||
if not self.is_supported():
|
||||
raise e.NotSupportedError(
|
||||
f"pipeline mode not supported: {self._not_supported_reason()}"
|
||||
)
|
||||
if self.level == 0:
|
||||
self.pgconn.enter_pipeline_mode()
|
||||
elif self.command_queue or self.pgconn.transaction_status == ACTIVE:
|
||||
# Nested pipeline case.
|
||||
# Transaction might be ACTIVE when the pipeline uses an "implicit
|
||||
# transaction", typically in autocommit mode. But when entering a
|
||||
# Psycopg transaction(), we expect the IDLE state. By sync()-ing,
|
||||
# we make sure all previous commands are completed and the
|
||||
# transaction gets back to IDLE.
|
||||
yield from self._sync_gen()
|
||||
self.level += 1
|
||||
|
||||
def _exit(self, exc: Optional[BaseException]) -> None:
|
||||
self.level -= 1
|
||||
if self.level == 0 and self.pgconn.status != BAD:
|
||||
try:
|
||||
self.pgconn.exit_pipeline_mode()
|
||||
except e.OperationalError as exc2:
|
||||
# Notice that this error might be pretty irrecoverable. It
|
||||
# happens on COPY, for instance: even if sync succeeds, exiting
|
||||
# fails with "cannot exit pipeline mode with uncollected results"
|
||||
if exc:
|
||||
logger.warning("error ignored exiting %r: %s", self, exc2)
|
||||
else:
|
||||
raise exc2.with_traceback(None)
|
||||
|
||||
def _sync_gen(self) -> PQGen[None]:
|
||||
self._enqueue_sync()
|
||||
yield from self._communicate_gen()
|
||||
yield from self._fetch_gen(flush=False)
|
||||
|
||||
def _exit_gen(self) -> PQGen[None]:
|
||||
"""
|
||||
Exit current pipeline by sending a Sync and fetch back all remaining results.
|
||||
"""
|
||||
try:
|
||||
self._enqueue_sync()
|
||||
yield from self._communicate_gen()
|
||||
finally:
|
||||
# No need to force flush since we emitted a sync just before.
|
||||
yield from self._fetch_gen(flush=False)
|
||||
|
||||
def _communicate_gen(self) -> PQGen[None]:
|
||||
"""Communicate with pipeline to send commands and possibly fetch
|
||||
results, which are then processed.
|
||||
"""
|
||||
fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
|
||||
exception = None
|
||||
for results in fetched:
|
||||
queued = self.result_queue.popleft()
|
||||
try:
|
||||
self._process_results(queued, results)
|
||||
except e.Error as exc:
|
||||
if exception is None:
|
||||
exception = exc
|
||||
if exception is not None:
|
||||
raise exception
|
||||
|
||||
def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
|
||||
"""Fetch available results from the connection and process them with
|
||||
pipeline queued items.
|
||||
|
||||
If 'flush' is True, a PQsendFlushRequest() is issued in order to make
|
||||
sure results can be fetched. Otherwise, the caller may emit a
|
||||
PQpipelineSync() call to ensure the output buffer gets flushed before
|
||||
fetching.
|
||||
"""
|
||||
if not self.result_queue:
|
||||
return
|
||||
|
||||
if flush:
|
||||
self.pgconn.send_flush_request()
|
||||
yield from send(self.pgconn)
|
||||
|
||||
exception = None
|
||||
while self.result_queue:
|
||||
results = yield from fetch_many(self.pgconn)
|
||||
if not results:
|
||||
# No more results to fetch, but there may still be pending
|
||||
# commands.
|
||||
break
|
||||
queued = self.result_queue.popleft()
|
||||
try:
|
||||
self._process_results(queued, results)
|
||||
except e.Error as exc:
|
||||
if exception is None:
|
||||
exception = exc
|
||||
if exception is not None:
|
||||
raise exception
|
||||
|
||||
def _process_results(
|
||||
self, queued: PendingResult, results: List["PGresult"]
|
||||
) -> None:
|
||||
"""Process a results set fetched from the current pipeline.
|
||||
|
||||
This matches 'results' with its respective element in the pipeline
|
||||
queue. For commands (None value in the pipeline queue), results are
|
||||
checked directly. For prepare statement creation requests, update the
|
||||
cache. Otherwise, results are attached to their respective cursor.
|
||||
"""
|
||||
if queued is None:
|
||||
(result,) = results
|
||||
if result.status == FATAL_ERROR:
|
||||
raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
|
||||
elif result.status == PIPELINE_ABORTED:
|
||||
raise e.PipelineAborted("pipeline aborted")
|
||||
else:
|
||||
cursor, prepinfo = queued
|
||||
if prepinfo:
|
||||
key, prep, name = prepinfo
|
||||
# Update the prepare state of the query.
|
||||
cursor._conn._prepared.validate(key, prep, name, results)
|
||||
cursor._set_results_from_pipeline(results)
|
||||
|
||||
def _enqueue_sync(self) -> None:
|
||||
"""Enqueue a PQpipelineSync() command."""
|
||||
self.command_queue.append(self.pgconn.pipeline_sync)
|
||||
self.result_queue.append(None)
|
||||
|
||||
|
||||
class Pipeline(BasePipeline):
|
||||
"""Handler for connection in pipeline mode."""
|
||||
|
||||
__module__ = "psycopg"
|
||||
_conn: "Connection[Any]"
|
||||
_Self = TypeVar("_Self", bound="Pipeline")
|
||||
|
||||
def __init__(self, conn: "Connection[Any]") -> None:
|
||||
super().__init__(conn)
|
||||
|
||||
def sync(self) -> None:
|
||||
"""Sync the pipeline, send any pending command and receive and process
|
||||
all available results.
|
||||
"""
|
||||
try:
|
||||
with self._conn.lock:
|
||||
self._conn.wait(self._sync_gen())
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
|
||||
def __enter__(self: _Self) -> _Self:
|
||||
with self._conn.lock:
|
||||
self._conn.wait(self._enter_gen())
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
try:
|
||||
with self._conn.lock:
|
||||
self._conn.wait(self._exit_gen())
|
||||
except Exception as exc2:
|
||||
# Don't clobber an exception raised in the block with this one
|
||||
if exc_val:
|
||||
logger.warning("error ignored terminating %r: %s", self, exc2)
|
||||
else:
|
||||
raise exc2.with_traceback(None)
|
||||
finally:
|
||||
self._exit(exc_val)
|
||||
|
||||
|
||||
class AsyncPipeline(BasePipeline):
|
||||
"""Handler for async connection in pipeline mode."""
|
||||
|
||||
__module__ = "psycopg"
|
||||
_conn: "AsyncConnection[Any]"
|
||||
_Self = TypeVar("_Self", bound="AsyncPipeline")
|
||||
|
||||
def __init__(self, conn: "AsyncConnection[Any]") -> None:
|
||||
super().__init__(conn)
|
||||
|
||||
async def sync(self) -> None:
|
||||
try:
|
||||
async with self._conn.lock:
|
||||
await self._conn.wait(self._sync_gen())
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
|
||||
async def __aenter__(self: _Self) -> _Self:
|
||||
async with self._conn.lock:
|
||||
await self._conn.wait(self._enter_gen())
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
try:
|
||||
async with self._conn.lock:
|
||||
await self._conn.wait(self._exit_gen())
|
||||
except Exception as exc2:
|
||||
# Don't clobber an exception raised in the block with this one
|
||||
if exc_val:
|
||||
logger.warning("error ignored terminating %r: %s", self, exc2)
|
||||
else:
|
||||
raise exc2.with_traceback(None)
|
||||
finally:
|
||||
self._exit(exc_val)
|
||||
194
srcs/.venv/lib/python3.11/site-packages/psycopg/_preparing.py
Normal file
194
srcs/.venv/lib/python3.11/site-packages/psycopg/_preparing.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Support for prepared statements
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from enum import IntEnum, auto
|
||||
from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
|
||||
from collections import OrderedDict
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from . import pq
|
||||
from ._compat import Deque
|
||||
from ._queries import PostgresQuery
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pq.abc import PGresult
|
||||
|
||||
Key: TypeAlias = Tuple[bytes, Tuple[int, ...]]
|
||||
|
||||
COMMAND_OK = pq.ExecStatus.COMMAND_OK
|
||||
TUPLES_OK = pq.ExecStatus.TUPLES_OK
|
||||
|
||||
|
||||
class Prepare(IntEnum):
|
||||
NO = auto()
|
||||
YES = auto()
|
||||
SHOULD = auto()
|
||||
|
||||
|
||||
class PrepareManager:
|
||||
# Number of times a query is executed before it is prepared.
|
||||
prepare_threshold: Optional[int] = 5
|
||||
|
||||
# Maximum number of prepared statements on the connection.
|
||||
prepared_max: int = 100
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Map (query, types) to the number of times the query was seen.
|
||||
self._counts: OrderedDict[Key, int] = OrderedDict()
|
||||
|
||||
# Map (query, types) to the name of the statement if prepared.
|
||||
self._names: OrderedDict[Key, bytes] = OrderedDict()
|
||||
|
||||
# Counter to generate prepared statements names
|
||||
self._prepared_idx = 0
|
||||
|
||||
self._maint_commands = Deque[bytes]()
|
||||
|
||||
@staticmethod
|
||||
def key(query: PostgresQuery) -> Key:
|
||||
return (query.query, query.types)
|
||||
|
||||
def get(
|
||||
self, query: PostgresQuery, prepare: Optional[bool] = None
|
||||
) -> Tuple[Prepare, bytes]:
|
||||
"""
|
||||
Check if a query is prepared, tell back whether to prepare it.
|
||||
"""
|
||||
if prepare is False or self.prepare_threshold is None:
|
||||
# The user doesn't want this query to be prepared
|
||||
return Prepare.NO, b""
|
||||
|
||||
key = self.key(query)
|
||||
name = self._names.get(key)
|
||||
if name:
|
||||
# The query was already prepared in this session
|
||||
return Prepare.YES, name
|
||||
|
||||
count = self._counts.get(key, 0)
|
||||
if count >= self.prepare_threshold or prepare:
|
||||
# The query has been executed enough times and needs to be prepared
|
||||
name = f"_pg3_{self._prepared_idx}".encode()
|
||||
self._prepared_idx += 1
|
||||
return Prepare.SHOULD, name
|
||||
else:
|
||||
# The query is not to be prepared yet
|
||||
return Prepare.NO, b""
|
||||
|
||||
def _should_discard(self, prep: Prepare, results: Sequence["PGresult"]) -> bool:
|
||||
"""Check if we need to discard our entire state: it should happen on
|
||||
rollback or on dropping objects, because the same object may get
|
||||
recreated and postgres would fail internal lookups.
|
||||
"""
|
||||
if self._names or prep == Prepare.SHOULD:
|
||||
for result in results:
|
||||
if result.status != COMMAND_OK:
|
||||
continue
|
||||
cmdstat = result.command_status
|
||||
if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"):
|
||||
return self.clear()
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _check_results(results: Sequence["PGresult"]) -> bool:
|
||||
"""Return False if 'results' are invalid for prepared statement cache."""
|
||||
if len(results) != 1:
|
||||
# We cannot prepare a multiple statement
|
||||
return False
|
||||
|
||||
status = results[0].status
|
||||
if COMMAND_OK != status != TUPLES_OK:
|
||||
# We don't prepare failed queries or other weird results
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _rotate(self) -> None:
|
||||
"""Evict an old value from the cache.
|
||||
|
||||
If it was prepared, deallocate it. Do it only once: if the cache was
|
||||
resized, deallocate gradually.
|
||||
"""
|
||||
if len(self._counts) > self.prepared_max:
|
||||
self._counts.popitem(last=False)
|
||||
|
||||
if len(self._names) > self.prepared_max:
|
||||
name = self._names.popitem(last=False)[1]
|
||||
self._maint_commands.append(b"DEALLOCATE " + name)
|
||||
|
||||
def maybe_add_to_cache(
|
||||
self, query: PostgresQuery, prep: Prepare, name: bytes
|
||||
) -> Optional[Key]:
|
||||
"""Handle 'query' for possible addition to the cache.
|
||||
|
||||
If a new entry has been added, return its key. Return None otherwise
|
||||
(meaning the query is already in cache or cache is not enabled).
|
||||
"""
|
||||
# don't do anything if prepared statements are disabled
|
||||
if self.prepare_threshold is None:
|
||||
return None
|
||||
|
||||
key = self.key(query)
|
||||
if key in self._counts:
|
||||
if prep is Prepare.SHOULD:
|
||||
del self._counts[key]
|
||||
self._names[key] = name
|
||||
else:
|
||||
self._counts[key] += 1
|
||||
self._counts.move_to_end(key)
|
||||
return None
|
||||
|
||||
elif key in self._names:
|
||||
self._names.move_to_end(key)
|
||||
return None
|
||||
|
||||
else:
|
||||
if prep is Prepare.SHOULD:
|
||||
self._names[key] = name
|
||||
else:
|
||||
self._counts[key] = 1
|
||||
return key
|
||||
|
||||
def validate(
|
||||
self,
|
||||
key: Key,
|
||||
prep: Prepare,
|
||||
name: bytes,
|
||||
results: Sequence["PGresult"],
|
||||
) -> None:
|
||||
"""Validate cached entry with 'key' by checking query 'results'.
|
||||
|
||||
Possibly record a command to perform maintenance on database side.
|
||||
"""
|
||||
if self._should_discard(prep, results):
|
||||
return
|
||||
|
||||
if not self._check_results(results):
|
||||
self._names.pop(key, None)
|
||||
self._counts.pop(key, None)
|
||||
else:
|
||||
self._rotate()
|
||||
|
||||
def clear(self) -> bool:
|
||||
"""Clear the cache of the maintenance commands.
|
||||
|
||||
Clear the internal state and prepare a command to clear the state of
|
||||
the server.
|
||||
"""
|
||||
self._counts.clear()
|
||||
if self._names:
|
||||
self._names.clear()
|
||||
self._maint_commands.clear()
|
||||
self._maint_commands.append(b"DEALLOCATE ALL")
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def get_maintenance_commands(self) -> Iterator[bytes]:
|
||||
"""
|
||||
Iterate over the commands needed to align the server state to our state
|
||||
"""
|
||||
while self._maint_commands:
|
||||
yield self._maint_commands.popleft()
|
||||
415
srcs/.venv/lib/python3.11/site-packages/psycopg/_queries.py
Normal file
415
srcs/.venv/lib/python3.11/site-packages/psycopg/_queries.py
Normal file
@@ -0,0 +1,415 @@
|
||||
"""
|
||||
Utility module to manipulate queries
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Mapping, Match, NamedTuple, Optional
|
||||
from typing import Sequence, Tuple, Union, TYPE_CHECKING
|
||||
from functools import lru_cache
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from . import pq
|
||||
from . import errors as e
|
||||
from .sql import Composable
|
||||
from .abc import Buffer, Query, Params
|
||||
from ._enums import PyFormat
|
||||
from ._encodings import conn_encoding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abc import Transformer
|
||||
|
||||
MAX_CACHED_STATEMENT_LENGTH = 4096
|
||||
MAX_CACHED_STATEMENT_PARAMS = 50
|
||||
|
||||
|
||||
class QueryPart(NamedTuple):
|
||||
pre: bytes
|
||||
item: Union[int, str]
|
||||
format: PyFormat
|
||||
|
||||
|
||||
class PostgresQuery:
|
||||
"""
|
||||
Helper to convert a Python query and parameters into Postgres format.
|
||||
"""
|
||||
|
||||
__slots__ = """
|
||||
query params types formats
|
||||
_tx _want_formats _parts _encoding _order
|
||||
""".split()
|
||||
|
||||
def __init__(self, transformer: "Transformer"):
|
||||
self._tx = transformer
|
||||
|
||||
self.params: Optional[Sequence[Optional[Buffer]]] = None
|
||||
# these are tuples so they can be used as keys e.g. in prepared stmts
|
||||
self.types: Tuple[int, ...] = ()
|
||||
|
||||
# The format requested by the user and the ones to really pass Postgres
|
||||
self._want_formats: Optional[List[PyFormat]] = None
|
||||
self.formats: Optional[Sequence[pq.Format]] = None
|
||||
|
||||
self._encoding = conn_encoding(transformer.connection)
|
||||
self._parts: List[QueryPart]
|
||||
self.query = b""
|
||||
self._order: Optional[List[str]] = None
|
||||
|
||||
def convert(self, query: Query, vars: Optional[Params]) -> None:
|
||||
"""
|
||||
Set up the query and parameters to convert.
|
||||
|
||||
The results of this function can be obtained accessing the object
|
||||
attributes (`query`, `params`, `types`, `formats`).
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
bquery = query.encode(self._encoding)
|
||||
elif isinstance(query, Composable):
|
||||
bquery = query.as_bytes(self._tx)
|
||||
else:
|
||||
bquery = query
|
||||
|
||||
if vars is not None:
|
||||
# Avoid caching queries extremely long or with a huge number of
|
||||
# parameters. They are usually generated by ORMs and have poor
|
||||
# cacheablility (e.g. INSERT ... VALUES (...), (...) with varying
|
||||
# numbers of tuples.
|
||||
# see https://github.com/psycopg/psycopg/discussions/628
|
||||
if (
|
||||
len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
|
||||
and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
|
||||
):
|
||||
f: _Query2Pg = _query2pg
|
||||
else:
|
||||
f = _query2pg_nocache
|
||||
|
||||
(self.query, self._want_formats, self._order, self._parts) = f(
|
||||
bquery, self._encoding
|
||||
)
|
||||
else:
|
||||
self.query = bquery
|
||||
self._want_formats = self._order = None
|
||||
|
||||
self.dump(vars)
|
||||
|
||||
def dump(self, vars: Optional[Params]) -> None:
|
||||
"""
|
||||
Process a new set of variables on the query processed by `convert()`.
|
||||
|
||||
This method updates `params` and `types`.
|
||||
"""
|
||||
if vars is not None:
|
||||
params = _validate_and_reorder_params(self._parts, vars, self._order)
|
||||
assert self._want_formats is not None
|
||||
self.params = self._tx.dump_sequence(params, self._want_formats)
|
||||
self.types = self._tx.types or ()
|
||||
self.formats = self._tx.formats
|
||||
else:
|
||||
self.params = None
|
||||
self.types = ()
|
||||
self.formats = None
|
||||
|
||||
|
||||
# The type of the _query2pg() and _query2pg_nocache() methods
|
||||
_Query2Pg: TypeAlias = Callable[
|
||||
[bytes, str], Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]
|
||||
]
|
||||
|
||||
|
||||
def _query2pg_nocache(
|
||||
query: bytes, encoding: str
|
||||
) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]:
|
||||
"""
|
||||
Convert Python query and params into something Postgres understands.
|
||||
|
||||
- Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
|
||||
format (``$1``, ``$2``)
|
||||
- placeholders can be %s, %t, or %b (auto, text or binary)
|
||||
- return ``query`` (bytes), ``formats`` (list of formats) ``order``
|
||||
(sequence of names used in the query, in the position they appear)
|
||||
``parts`` (splits of queries and placeholders).
|
||||
"""
|
||||
parts = _split_query(query, encoding)
|
||||
order: Optional[List[str]] = None
|
||||
chunks: List[bytes] = []
|
||||
formats = []
|
||||
|
||||
if isinstance(parts[0].item, int):
|
||||
for part in parts[:-1]:
|
||||
assert isinstance(part.item, int)
|
||||
chunks.append(part.pre)
|
||||
chunks.append(b"$%d" % (part.item + 1))
|
||||
formats.append(part.format)
|
||||
|
||||
elif isinstance(parts[0].item, str):
|
||||
seen: Dict[str, Tuple[bytes, PyFormat]] = {}
|
||||
order = []
|
||||
for part in parts[:-1]:
|
||||
assert isinstance(part.item, str)
|
||||
chunks.append(part.pre)
|
||||
if part.item not in seen:
|
||||
ph = b"$%d" % (len(seen) + 1)
|
||||
seen[part.item] = (ph, part.format)
|
||||
order.append(part.item)
|
||||
chunks.append(ph)
|
||||
formats.append(part.format)
|
||||
else:
|
||||
if seen[part.item][1] != part.format:
|
||||
raise e.ProgrammingError(
|
||||
f"placeholder '{part.item}' cannot have different formats"
|
||||
)
|
||||
chunks.append(seen[part.item][0])
|
||||
|
||||
# last part
|
||||
chunks.append(parts[-1].pre)
|
||||
|
||||
return b"".join(chunks), formats, order, parts
|
||||
|
||||
|
||||
# Note: the cache size is 128 items, but someone has reported throwing ~12k
|
||||
# queries (of type `INSERT ... VALUES (...), (...)` with a varying amount of
|
||||
# records), and the resulting cache size is >100Mb. So, we will avoid to cache
|
||||
# large queries or queries with a large number of params. See
|
||||
# https://github.com/sqlalchemy/sqlalchemy/discussions/10270
|
||||
_query2pg = lru_cache()(_query2pg_nocache)
|
||||
|
||||
|
||||
class PostgresClientQuery(PostgresQuery):
|
||||
"""
|
||||
PostgresQuery subclass merging query and arguments client-side.
|
||||
"""
|
||||
|
||||
__slots__ = ("template",)
|
||||
|
||||
def convert(self, query: Query, vars: Optional[Params]) -> None:
|
||||
"""
|
||||
Set up the query and parameters to convert.
|
||||
|
||||
The results of this function can be obtained accessing the object
|
||||
attributes (`query`, `params`, `types`, `formats`).
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
bquery = query.encode(self._encoding)
|
||||
elif isinstance(query, Composable):
|
||||
bquery = query.as_bytes(self._tx)
|
||||
else:
|
||||
bquery = query
|
||||
|
||||
if vars is not None:
|
||||
if (
|
||||
len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
|
||||
and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
|
||||
):
|
||||
f: _Query2PgClient = _query2pg_client
|
||||
else:
|
||||
f = _query2pg_client_nocache
|
||||
|
||||
(self.template, self._order, self._parts) = f(bquery, self._encoding)
|
||||
else:
|
||||
self.query = bquery
|
||||
self._order = None
|
||||
|
||||
self.dump(vars)
|
||||
|
||||
def dump(self, vars: Optional[Params]) -> None:
|
||||
"""
|
||||
Process a new set of variables on the query processed by `convert()`.
|
||||
|
||||
This method updates `params` and `types`.
|
||||
"""
|
||||
if vars is not None:
|
||||
params = _validate_and_reorder_params(self._parts, vars, self._order)
|
||||
self.params = tuple(
|
||||
self._tx.as_literal(p) if p is not None else b"NULL" for p in params
|
||||
)
|
||||
self.query = self.template % self.params
|
||||
else:
|
||||
self.params = None
|
||||
|
||||
|
||||
_Query2PgClient: TypeAlias = Callable[
|
||||
[bytes, str], Tuple[bytes, Optional[List[str]], List[QueryPart]]
|
||||
]
|
||||
|
||||
|
||||
def _query2pg_client_nocache(
|
||||
query: bytes, encoding: str
|
||||
) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]:
|
||||
"""
|
||||
Convert Python query and params into a template to perform client-side binding
|
||||
"""
|
||||
parts = _split_query(query, encoding, collapse_double_percent=False)
|
||||
order: Optional[List[str]] = None
|
||||
chunks: List[bytes] = []
|
||||
|
||||
if isinstance(parts[0].item, int):
|
||||
for part in parts[:-1]:
|
||||
assert isinstance(part.item, int)
|
||||
chunks.append(part.pre)
|
||||
chunks.append(b"%s")
|
||||
|
||||
elif isinstance(parts[0].item, str):
|
||||
seen: Dict[str, Tuple[bytes, PyFormat]] = {}
|
||||
order = []
|
||||
for part in parts[:-1]:
|
||||
assert isinstance(part.item, str)
|
||||
chunks.append(part.pre)
|
||||
if part.item not in seen:
|
||||
ph = b"%s"
|
||||
seen[part.item] = (ph, part.format)
|
||||
order.append(part.item)
|
||||
chunks.append(ph)
|
||||
else:
|
||||
chunks.append(seen[part.item][0])
|
||||
order.append(part.item)
|
||||
|
||||
# last part
|
||||
chunks.append(parts[-1].pre)
|
||||
|
||||
return b"".join(chunks), order, parts
|
||||
|
||||
|
||||
_query2pg_client = lru_cache()(_query2pg_client_nocache)
|
||||
|
||||
|
||||
def _validate_and_reorder_params(
|
||||
parts: List[QueryPart], vars: Params, order: Optional[List[str]]
|
||||
) -> Sequence[Any]:
|
||||
"""
|
||||
Verify the compatibility between a query and a set of params.
|
||||
"""
|
||||
# Try concrete types, then abstract types
|
||||
t = type(vars)
|
||||
if t is list or t is tuple:
|
||||
sequence = True
|
||||
elif t is dict:
|
||||
sequence = False
|
||||
elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
|
||||
sequence = True
|
||||
elif isinstance(vars, Mapping):
|
||||
sequence = False
|
||||
else:
|
||||
raise TypeError(
|
||||
"query parameters should be a sequence or a mapping,"
|
||||
f" got {type(vars).__name__}"
|
||||
)
|
||||
|
||||
if sequence:
|
||||
if len(vars) != len(parts) - 1:
|
||||
raise e.ProgrammingError(
|
||||
f"the query has {len(parts) - 1} placeholders but"
|
||||
f" {len(vars)} parameters were passed"
|
||||
)
|
||||
if vars and not isinstance(parts[0].item, int):
|
||||
raise TypeError("named placeholders require a mapping of parameters")
|
||||
return vars # type: ignore[return-value]
|
||||
|
||||
else:
|
||||
if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
|
||||
raise TypeError(
|
||||
"positional placeholders (%s) require a sequence of parameters"
|
||||
)
|
||||
try:
|
||||
return [vars[item] for item in order or ()] # type: ignore[call-overload]
|
||||
except KeyError:
|
||||
raise e.ProgrammingError(
|
||||
"query parameter missing:"
|
||||
f" {', '.join(sorted(i for i in order or () if i not in vars))}"
|
||||
)
|
||||
|
||||
|
||||
_re_placeholder = re.compile(
|
||||
rb"""(?x)
|
||||
% # a literal %
|
||||
(?:
|
||||
(?:
|
||||
\( ([^)]+) \) # or a name in (braces)
|
||||
. # followed by a format
|
||||
)
|
||||
|
|
||||
(?:.) # or any char, really
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _split_query(
|
||||
query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True
|
||||
) -> List[QueryPart]:
|
||||
parts: List[Tuple[bytes, Optional[Match[bytes]]]] = []
|
||||
cur = 0
|
||||
|
||||
# pairs [(fragment, match], with the last match None
|
||||
m = None
|
||||
for m in _re_placeholder.finditer(query):
|
||||
pre = query[cur : m.span(0)[0]]
|
||||
parts.append((pre, m))
|
||||
cur = m.span(0)[1]
|
||||
if m:
|
||||
parts.append((query[cur:], None))
|
||||
else:
|
||||
parts.append((query, None))
|
||||
|
||||
rv = []
|
||||
|
||||
# drop the "%%", validate
|
||||
i = 0
|
||||
phtype = None
|
||||
while i < len(parts):
|
||||
pre, m = parts[i]
|
||||
if m is None:
|
||||
# last part
|
||||
rv.append(QueryPart(pre, 0, PyFormat.AUTO))
|
||||
break
|
||||
|
||||
ph = m.group(0)
|
||||
if ph == b"%%":
|
||||
# unescape '%%' to '%' if necessary, then merge the parts
|
||||
if collapse_double_percent:
|
||||
ph = b"%"
|
||||
pre1, m1 = parts[i + 1]
|
||||
parts[i + 1] = (pre + ph + pre1, m1)
|
||||
del parts[i]
|
||||
continue
|
||||
|
||||
if ph == b"%(":
|
||||
raise e.ProgrammingError(
|
||||
"incomplete placeholder:"
|
||||
f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'"
|
||||
)
|
||||
elif ph == b"% ":
|
||||
# explicit messasge for a typical error
|
||||
raise e.ProgrammingError(
|
||||
"incomplete placeholder: '%'; if you want to use '%' as an"
|
||||
" operator you can double it up, i.e. use '%%'"
|
||||
)
|
||||
elif ph[-1:] not in b"sbt":
|
||||
raise e.ProgrammingError(
|
||||
"only '%s', '%b', '%t' are allowed as placeholders, got"
|
||||
f" '{m.group(0).decode(encoding)}'"
|
||||
)
|
||||
|
||||
# Index or name
|
||||
item: Union[int, str]
|
||||
item = m.group(1).decode(encoding) if m.group(1) else i
|
||||
|
||||
if not phtype:
|
||||
phtype = type(item)
|
||||
elif phtype is not type(item):
|
||||
raise e.ProgrammingError(
|
||||
"positional and named placeholders cannot be mixed"
|
||||
)
|
||||
|
||||
format = _ph_to_fmt[ph[-1:]]
|
||||
rv.append(QueryPart(pre, item, format))
|
||||
i += 1
|
||||
|
||||
return rv
|
||||
|
||||
|
||||
_ph_to_fmt = {
|
||||
b"s": PyFormat.AUTO,
|
||||
b"t": PyFormat.TEXT,
|
||||
b"b": PyFormat.BINARY,
|
||||
}
|
||||
57
srcs/.venv/lib/python3.11/site-packages/psycopg/_struct.py
Normal file
57
srcs/.venv/lib/python3.11/site-packages/psycopg/_struct.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Utility functions to deal with binary structs.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import struct
|
||||
from typing import Callable, cast, Optional, Tuple
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .abc import Buffer
|
||||
from . import errors as e
|
||||
from ._compat import Protocol
|
||||
|
||||
PackInt: TypeAlias = Callable[[int], bytes]
|
||||
UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]]
|
||||
PackFloat: TypeAlias = Callable[[float], bytes]
|
||||
UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]]
|
||||
|
||||
|
||||
class UnpackLen(Protocol):
|
||||
def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]:
|
||||
...
|
||||
|
||||
|
||||
pack_int2 = cast(PackInt, struct.Struct("!h").pack)
|
||||
pack_uint2 = cast(PackInt, struct.Struct("!H").pack)
|
||||
pack_int4 = cast(PackInt, struct.Struct("!i").pack)
|
||||
pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
|
||||
pack_int8 = cast(PackInt, struct.Struct("!q").pack)
|
||||
pack_float4 = cast(PackFloat, struct.Struct("!f").pack)
|
||||
pack_float8 = cast(PackFloat, struct.Struct("!d").pack)
|
||||
|
||||
unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack)
|
||||
unpack_uint2 = cast(UnpackInt, struct.Struct("!H").unpack)
|
||||
unpack_int4 = cast(UnpackInt, struct.Struct("!i").unpack)
|
||||
unpack_uint4 = cast(UnpackInt, struct.Struct("!I").unpack)
|
||||
unpack_int8 = cast(UnpackInt, struct.Struct("!q").unpack)
|
||||
unpack_float4 = cast(UnpackFloat, struct.Struct("!f").unpack)
|
||||
unpack_float8 = cast(UnpackFloat, struct.Struct("!d").unpack)
|
||||
|
||||
_struct_len = struct.Struct("!i")
|
||||
pack_len = cast(Callable[[int], bytes], _struct_len.pack)
|
||||
unpack_len = cast(UnpackLen, _struct_len.unpack_from)
|
||||
|
||||
|
||||
def pack_float4_bug_304(x: float) -> bytes:
|
||||
raise e.InterfaceError(
|
||||
"cannot dump Float4: Python affected by bug #304. Note that the psycopg-c"
|
||||
" and psycopg-binary packages are not affected by this issue."
|
||||
" See https://github.com/psycopg/psycopg/issues/304"
|
||||
)
|
||||
|
||||
|
||||
# If issue #304 is detected, raise an error instead of dumping wrong data.
|
||||
if struct.Struct("!f").pack(1.0) != bytes.fromhex("3f800000"):
|
||||
pack_float4 = pack_float4_bug_304
|
||||
116
srcs/.venv/lib/python3.11/site-packages/psycopg/_tpc.py
Normal file
116
srcs/.venv/lib/python3.11/site-packages/psycopg/_tpc.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
psycopg two-phase commit support
|
||||
"""
|
||||
|
||||
# Copyright (C) 2021 The Psycopg Team
|
||||
|
||||
import re
|
||||
import datetime as dt
|
||||
from base64 import b64encode, b64decode
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass, replace
|
||||
|
||||
_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Xid:
|
||||
"""A two-phase commit transaction identifier.
|
||||
|
||||
The object can also be unpacked as a 3-item tuple (`format_id`, `gtrid`,
|
||||
`bqual`).
|
||||
|
||||
"""
|
||||
|
||||
format_id: Optional[int]
|
||||
gtrid: str
|
||||
bqual: Optional[str]
|
||||
prepared: Optional[dt.datetime] = None
|
||||
owner: Optional[str] = None
|
||||
database: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, s: str) -> "Xid":
|
||||
"""Try to parse an XA triple from the string.
|
||||
|
||||
This may fail for several reasons. In such case return an unparsed Xid.
|
||||
"""
|
||||
try:
|
||||
return cls._parse_string(s)
|
||||
except Exception:
|
||||
return Xid(None, s, None)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._as_tid()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return 3
|
||||
|
||||
def __getitem__(self, index: int) -> Union[int, str, None]:
|
||||
return (self.format_id, self.gtrid, self.bqual)[index]
|
||||
|
||||
@classmethod
|
||||
def _parse_string(cls, s: str) -> "Xid":
|
||||
m = _re_xid.match(s)
|
||||
if not m:
|
||||
raise ValueError("bad Xid format")
|
||||
|
||||
format_id = int(m.group(1))
|
||||
gtrid = b64decode(m.group(2)).decode()
|
||||
bqual = b64decode(m.group(3)).decode()
|
||||
return cls.from_parts(format_id, gtrid, bqual)
|
||||
|
||||
@classmethod
|
||||
def from_parts(
|
||||
cls, format_id: Optional[int], gtrid: str, bqual: Optional[str]
|
||||
) -> "Xid":
|
||||
if format_id is not None:
|
||||
if bqual is None:
|
||||
raise TypeError("if format_id is specified, bqual must be too")
|
||||
if not 0 <= format_id < 0x80000000:
|
||||
raise ValueError("format_id must be a non-negative 32-bit integer")
|
||||
if len(bqual) > 64:
|
||||
raise ValueError("bqual must be not longer than 64 chars")
|
||||
if len(gtrid) > 64:
|
||||
raise ValueError("gtrid must be not longer than 64 chars")
|
||||
|
||||
elif bqual is None:
|
||||
raise TypeError("if format_id is None, bqual must be None too")
|
||||
|
||||
return Xid(format_id, gtrid, bqual)
|
||||
|
||||
def _as_tid(self) -> str:
|
||||
"""
|
||||
Return the PostgreSQL transaction_id for this XA xid.
|
||||
|
||||
PostgreSQL wants just a string, while the DBAPI supports the XA
|
||||
standard and thus a triple. We use the same conversion algorithm
|
||||
implemented by JDBC in order to allow some form of interoperation.
|
||||
|
||||
see also: the pgjdbc implementation
|
||||
http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/
|
||||
postgresql/xa/RecoveredXid.java?rev=1.2
|
||||
"""
|
||||
if self.format_id is None or self.bqual is None:
|
||||
# Unparsed xid: return the gtrid.
|
||||
return self.gtrid
|
||||
|
||||
# XA xid: mash together the components.
|
||||
egtrid = b64encode(self.gtrid.encode()).decode()
|
||||
ebqual = b64encode(self.bqual.encode()).decode()
|
||||
|
||||
return f"{self.format_id}_{egtrid}_{ebqual}"
|
||||
|
||||
@classmethod
|
||||
def _get_recover_query(cls) -> str:
|
||||
return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts"
|
||||
|
||||
@classmethod
|
||||
def _from_record(
|
||||
cls, gid: str, prepared: dt.datetime, owner: str, database: str
|
||||
) -> "Xid":
|
||||
xid = Xid.from_string(gid)
|
||||
return replace(xid, prepared=prepared, owner=owner, database=database)
|
||||
|
||||
|
||||
Xid.__module__ = "psycopg"
|
||||
354
srcs/.venv/lib/python3.11/site-packages/psycopg/_transform.py
Normal file
354
srcs/.venv/lib/python3.11/site-packages/psycopg/_transform.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
Helper object to transform values between Python and PostgreSQL
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import DefaultDict, TYPE_CHECKING
|
||||
from collections import defaultdict
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from . import pq
|
||||
from . import postgres
|
||||
from . import errors as e
|
||||
from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType
|
||||
from .rows import Row, RowMaker
|
||||
from .postgres import INVALID_OID, TEXT_OID
|
||||
from ._encodings import pgconn_encoding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abc import Dumper, Loader
|
||||
from .adapt import AdaptersMap
|
||||
from .pq.abc import PGresult
|
||||
from .connection import BaseConnection
|
||||
|
||||
DumperCache: TypeAlias = Dict[DumperKey, "Dumper"]
|
||||
OidDumperCache: TypeAlias = Dict[int, "Dumper"]
|
||||
LoaderCache: TypeAlias = Dict[int, "Loader"]
|
||||
|
||||
TEXT = pq.Format.TEXT
|
||||
PY_TEXT = PyFormat.TEXT
|
||||
|
||||
|
||||
class Transformer(AdaptContext):
|
||||
"""
|
||||
An object that can adapt efficiently between Python and PostgreSQL.
|
||||
|
||||
The life cycle of the object is the query, so it is assumed that attributes
|
||||
such as the server version or the connection encoding will not change. The
|
||||
object have its state so adapting several values of the same type can be
|
||||
optimised.
|
||||
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.adapt"
|
||||
|
||||
__slots__ = """
|
||||
types formats
|
||||
_conn _adapters _pgresult _dumpers _loaders _encoding _none_oid
|
||||
_oid_dumpers _oid_types _row_dumpers _row_loaders
|
||||
""".split()
|
||||
|
||||
types: Optional[Tuple[int, ...]]
|
||||
formats: Optional[List[pq.Format]]
|
||||
|
||||
_adapters: "AdaptersMap"
|
||||
_pgresult: Optional["PGresult"]
|
||||
_none_oid: int
|
||||
|
||||
def __init__(self, context: Optional[AdaptContext] = None):
|
||||
self._pgresult = self.types = self.formats = None
|
||||
|
||||
# WARNING: don't store context, or you'll create a loop with the Cursor
|
||||
if context:
|
||||
self._adapters = context.adapters
|
||||
self._conn = context.connection
|
||||
else:
|
||||
self._adapters = postgres.adapters
|
||||
self._conn = None
|
||||
|
||||
# mapping fmt, class -> Dumper instance
|
||||
self._dumpers: DefaultDict[PyFormat, DumperCache]
|
||||
self._dumpers = defaultdict(dict)
|
||||
|
||||
# mapping fmt, oid -> Dumper instance
|
||||
# Not often used, so create it only if needed.
|
||||
self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]]
|
||||
self._oid_dumpers = None
|
||||
|
||||
# mapping fmt, oid -> Loader instance
|
||||
self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {})
|
||||
|
||||
self._row_dumpers: Optional[List["Dumper"]] = None
|
||||
|
||||
# sequence of load functions from value to python
|
||||
# the length of the result columns
|
||||
self._row_loaders: List[LoadFunc] = []
|
||||
|
||||
# mapping oid -> type sql representation
|
||||
self._oid_types: Dict[int, bytes] = {}
|
||||
|
||||
self._encoding = ""
|
||||
|
||||
@classmethod
|
||||
def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
|
||||
"""
|
||||
Return a Transformer from an AdaptContext.
|
||||
|
||||
If the context is a Transformer instance, just return it.
|
||||
"""
|
||||
if isinstance(context, Transformer):
|
||||
return context
|
||||
else:
|
||||
return cls(context)
|
||||
|
||||
@property
|
||||
def connection(self) -> Optional["BaseConnection[Any]"]:
|
||||
return self._conn
|
||||
|
||||
@property
|
||||
def encoding(self) -> str:
|
||||
if not self._encoding:
|
||||
conn = self.connection
|
||||
self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8"
|
||||
return self._encoding
|
||||
|
||||
@property
|
||||
def adapters(self) -> "AdaptersMap":
|
||||
return self._adapters
|
||||
|
||||
@property
|
||||
def pgresult(self) -> Optional["PGresult"]:
|
||||
return self._pgresult
|
||||
|
||||
def set_pgresult(
|
||||
self,
|
||||
result: Optional["PGresult"],
|
||||
*,
|
||||
set_loaders: bool = True,
|
||||
format: Optional[pq.Format] = None,
|
||||
) -> None:
|
||||
self._pgresult = result
|
||||
|
||||
if not result:
|
||||
self._nfields = self._ntuples = 0
|
||||
if set_loaders:
|
||||
self._row_loaders = []
|
||||
return
|
||||
|
||||
self._ntuples = result.ntuples
|
||||
nf = self._nfields = result.nfields
|
||||
|
||||
if not set_loaders:
|
||||
return
|
||||
|
||||
if not nf:
|
||||
self._row_loaders = []
|
||||
return
|
||||
|
||||
fmt: pq.Format
|
||||
fmt = result.fformat(0) if format is None else format # type: ignore
|
||||
self._row_loaders = [
|
||||
self.get_loader(result.ftype(i), fmt).load for i in range(nf)
|
||||
]
|
||||
|
||||
def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
|
||||
self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types]
|
||||
self.types = tuple(types)
|
||||
self.formats = [format] * len(types)
|
||||
|
||||
def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
|
||||
self._row_loaders = [self.get_loader(oid, format).load for oid in types]
|
||||
|
||||
def dump_sequence(
|
||||
self, params: Sequence[Any], formats: Sequence[PyFormat]
|
||||
) -> Sequence[Optional[Buffer]]:
|
||||
nparams = len(params)
|
||||
out: List[Optional[Buffer]] = [None] * nparams
|
||||
|
||||
# If we have dumpers, it means set_dumper_types had been called, in
|
||||
# which case self.types and self.formats are set to sequences of the
|
||||
# right size.
|
||||
if self._row_dumpers:
|
||||
for i in range(nparams):
|
||||
param = params[i]
|
||||
if param is not None:
|
||||
out[i] = self._row_dumpers[i].dump(param)
|
||||
return out
|
||||
|
||||
types = [self._get_none_oid()] * nparams
|
||||
pqformats = [TEXT] * nparams
|
||||
|
||||
for i in range(nparams):
|
||||
param = params[i]
|
||||
if param is None:
|
||||
continue
|
||||
dumper = self.get_dumper(param, formats[i])
|
||||
out[i] = dumper.dump(param)
|
||||
types[i] = dumper.oid
|
||||
pqformats[i] = dumper.format
|
||||
|
||||
self.types = tuple(types)
|
||||
self.formats = pqformats
|
||||
|
||||
return out
|
||||
|
||||
def as_literal(self, obj: Any) -> bytes:
|
||||
dumper = self.get_dumper(obj, PY_TEXT)
|
||||
rv = dumper.quote(obj)
|
||||
# If the result is quoted, and the oid not unknown or text,
|
||||
# add an explicit type cast.
|
||||
# Check the last char because the first one might be 'E'.
|
||||
oid = dumper.oid
|
||||
if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID:
|
||||
try:
|
||||
type_sql = self._oid_types[oid]
|
||||
except KeyError:
|
||||
ti = self.adapters.types.get(oid)
|
||||
if ti:
|
||||
if oid < 8192:
|
||||
# builtin: prefer "timestamptz" to "timestamp with time zone"
|
||||
type_sql = ti.name.encode(self.encoding)
|
||||
else:
|
||||
type_sql = ti.regtype.encode(self.encoding)
|
||||
if oid == ti.array_oid:
|
||||
type_sql += b"[]"
|
||||
else:
|
||||
type_sql = b""
|
||||
self._oid_types[oid] = type_sql
|
||||
|
||||
if type_sql:
|
||||
rv = b"%s::%s" % (rv, type_sql)
|
||||
|
||||
if not isinstance(rv, bytes):
|
||||
rv = bytes(rv)
|
||||
return rv
|
||||
|
||||
def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
|
||||
"""
|
||||
Return a Dumper instance to dump `!obj`.
|
||||
"""
|
||||
# Normally, the type of the object dictates how to dump it
|
||||
key = type(obj)
|
||||
|
||||
# Reuse an existing Dumper class for objects of the same type
|
||||
cache = self._dumpers[format]
|
||||
try:
|
||||
dumper = cache[key]
|
||||
except KeyError:
|
||||
# If it's the first time we see this type, look for a dumper
|
||||
# configured for it.
|
||||
try:
|
||||
dcls = self.adapters.get_dumper(key, format)
|
||||
except e.ProgrammingError as ex:
|
||||
raise ex from None
|
||||
else:
|
||||
cache[key] = dumper = dcls(key, self)
|
||||
|
||||
# Check if the dumper requires an upgrade to handle this specific value
|
||||
key1 = dumper.get_key(obj, format)
|
||||
if key1 is key:
|
||||
return dumper
|
||||
|
||||
# If it does, ask the dumper to create its own upgraded version
|
||||
try:
|
||||
return cache[key1]
|
||||
except KeyError:
|
||||
dumper = cache[key1] = dumper.upgrade(obj, format)
|
||||
return dumper
|
||||
|
||||
def _get_none_oid(self) -> int:
|
||||
try:
|
||||
return self._none_oid
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid
|
||||
except KeyError:
|
||||
raise e.InterfaceError("None dumper not found")
|
||||
|
||||
return rv
|
||||
|
||||
def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper":
|
||||
"""
|
||||
Return a Dumper to dump an object to the type with given oid.
|
||||
"""
|
||||
if not self._oid_dumpers:
|
||||
self._oid_dumpers = ({}, {})
|
||||
|
||||
# Reuse an existing Dumper class for objects of the same type
|
||||
cache = self._oid_dumpers[format]
|
||||
try:
|
||||
return cache[oid]
|
||||
except KeyError:
|
||||
# If it's the first time we see this type, look for a dumper
|
||||
# configured for it.
|
||||
dcls = self.adapters.get_dumper_by_oid(oid, format)
|
||||
cache[oid] = dumper = dcls(NoneType, self)
|
||||
|
||||
return dumper
|
||||
|
||||
def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]:
|
||||
res = self._pgresult
|
||||
if not res:
|
||||
raise e.InterfaceError("result not set")
|
||||
|
||||
if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
|
||||
raise e.InterfaceError(
|
||||
f"rows must be included between 0 and {self._ntuples}"
|
||||
)
|
||||
|
||||
records = []
|
||||
for row in range(row0, row1):
|
||||
record: List[Any] = [None] * self._nfields
|
||||
for col in range(self._nfields):
|
||||
val = res.get_value(row, col)
|
||||
if val is not None:
|
||||
record[col] = self._row_loaders[col](val)
|
||||
records.append(make_row(record))
|
||||
|
||||
return records
|
||||
|
||||
def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]:
|
||||
res = self._pgresult
|
||||
if not res:
|
||||
return None
|
||||
|
||||
if not 0 <= row < self._ntuples:
|
||||
return None
|
||||
|
||||
record: List[Any] = [None] * self._nfields
|
||||
for col in range(self._nfields):
|
||||
val = res.get_value(row, col)
|
||||
if val is not None:
|
||||
record[col] = self._row_loaders[col](val)
|
||||
|
||||
return make_row(record)
|
||||
|
||||
def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
|
||||
if len(self._row_loaders) != len(record):
|
||||
raise e.ProgrammingError(
|
||||
f"cannot load sequence of {len(record)} items:"
|
||||
f" {len(self._row_loaders)} loaders registered"
|
||||
)
|
||||
|
||||
return tuple(
|
||||
(self._row_loaders[i](val) if val is not None else None)
|
||||
for i, val in enumerate(record)
|
||||
)
|
||||
|
||||
def get_loader(self, oid: int, format: pq.Format) -> "Loader":
|
||||
try:
|
||||
return self._loaders[format][oid]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
loader_cls = self._adapters.get_loader(oid, format)
|
||||
if not loader_cls:
|
||||
loader_cls = self._adapters.get_loader(INVALID_OID, format)
|
||||
if not loader_cls:
|
||||
raise e.InterfaceError("unknown oid loader not found")
|
||||
loader = self._loaders[format][oid] = loader_cls(oid, self)
|
||||
return loader
|
||||
500
srcs/.venv/lib/python3.11/site-packages/psycopg/_typeinfo.py
Normal file
500
srcs/.venv/lib/python3.11/site-packages/psycopg/_typeinfo.py
Normal file
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
Information about PostgreSQL types
|
||||
|
||||
These types allow to read information from the system catalog and provide
|
||||
information to the adapters if needed.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Iterator, Optional, overload
|
||||
from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from . import errors as e
|
||||
from .abc import AdaptContext, Query
|
||||
from .rows import dict_row
|
||||
from ._encodings import conn_encoding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .connection import BaseConnection, Connection
|
||||
from .connection_async import AsyncConnection
|
||||
from .sql import Identifier, SQL
|
||||
|
||||
T = TypeVar("T", bound="TypeInfo")
|
||||
RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]]
|
||||
|
||||
|
||||
class TypeInfo:
|
||||
"""
|
||||
Hold information about a PostgreSQL base type.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.types"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
oid: int,
|
||||
array_oid: int,
|
||||
*,
|
||||
regtype: str = "",
|
||||
delimiter: str = ",",
|
||||
):
|
||||
self.name = name
|
||||
self.oid = oid
|
||||
self.array_oid = array_oid
|
||||
self.regtype = regtype or name
|
||||
self.delimiter = delimiter
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<{self.__class__.__qualname__}:"
|
||||
f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>"
|
||||
)
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def fetch(
|
||||
cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"]
|
||||
) -> Optional[T]:
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def fetch(
|
||||
cls: Type[T], conn: "AsyncConnection[Any]", name: Union[str, "Identifier"]
|
||||
) -> Optional[T]:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def fetch(
|
||||
cls: Type[T], conn: "BaseConnection[Any]", name: Union[str, "Identifier"]
|
||||
) -> Any:
|
||||
"""Query a system catalog to read information about a type."""
|
||||
from .sql import Composable
|
||||
from .connection import Connection
|
||||
from .connection_async import AsyncConnection
|
||||
|
||||
if isinstance(name, Composable):
|
||||
name = name.as_string(conn)
|
||||
|
||||
if isinstance(conn, Connection):
|
||||
return cls._fetch(conn, name)
|
||||
elif isinstance(conn, AsyncConnection):
|
||||
return cls._fetch_async(conn, name)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"expected Connection or AsyncConnection, got {type(conn).__name__}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _fetch(cls: Type[T], conn: "Connection[Any]", name: str) -> Optional[T]:
|
||||
# This might result in a nested transaction. What we want is to leave
|
||||
# the function with the connection in the state we found (either idle
|
||||
# or intrans)
|
||||
try:
|
||||
with conn.transaction():
|
||||
if conn_encoding(conn) == "ascii":
|
||||
conn.execute("set local client_encoding to utf8")
|
||||
with conn.cursor(row_factory=dict_row) as cur:
|
||||
cur.execute(cls._get_info_query(conn), {"name": name})
|
||||
recs = cur.fetchall()
|
||||
except e.UndefinedObject:
|
||||
return None
|
||||
|
||||
return cls._from_records(name, recs)
|
||||
|
||||
@classmethod
|
||||
async def _fetch_async(
|
||||
cls: Type[T], conn: "AsyncConnection[Any]", name: str
|
||||
) -> Optional[T]:
|
||||
try:
|
||||
async with conn.transaction():
|
||||
if conn_encoding(conn) == "ascii":
|
||||
await conn.execute("set local client_encoding to utf8")
|
||||
async with conn.cursor(row_factory=dict_row) as cur:
|
||||
await cur.execute(cls._get_info_query(conn), {"name": name})
|
||||
recs = await cur.fetchall()
|
||||
except e.UndefinedObject:
|
||||
return None
|
||||
|
||||
return cls._from_records(name, recs)
|
||||
|
||||
@classmethod
|
||||
def _from_records(
|
||||
cls: Type[T], name: str, recs: Sequence[Dict[str, Any]]
|
||||
) -> Optional[T]:
|
||||
if len(recs) == 1:
|
||||
return cls(**recs[0])
|
||||
elif not recs:
|
||||
return None
|
||||
else:
|
||||
raise e.ProgrammingError(f"found {len(recs)} different types named {name}")
|
||||
|
||||
def register(self, context: Optional[AdaptContext] = None) -> None:
|
||||
"""
|
||||
Register the type information, globally or in the specified `!context`.
|
||||
"""
|
||||
if context:
|
||||
types = context.adapters.types
|
||||
else:
|
||||
from . import postgres
|
||||
|
||||
types = postgres.types
|
||||
|
||||
types.add(self)
|
||||
|
||||
if self.array_oid:
|
||||
from .types.array import register_array
|
||||
|
||||
register_array(self, context)
|
||||
|
||||
@classmethod
|
||||
def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
|
||||
from .sql import SQL
|
||||
|
||||
return SQL(
|
||||
"""\
|
||||
SELECT
|
||||
typname AS name, oid, typarray AS array_oid,
|
||||
oid::regtype::text AS regtype, typdelim AS delimiter
|
||||
FROM pg_type t
|
||||
WHERE t.oid = {regtype}
|
||||
ORDER BY t.oid
|
||||
"""
|
||||
).format(regtype=cls._to_regtype(conn))
|
||||
|
||||
@classmethod
|
||||
def _has_to_regtype_function(cls, conn: "BaseConnection[Any]") -> bool:
|
||||
# to_regtype() introduced in PostgreSQL 9.4 and CockroachDB 22.2
|
||||
info = conn.info
|
||||
if info.vendor == "PostgreSQL":
|
||||
return info.server_version >= 90400
|
||||
elif info.vendor == "CockroachDB":
|
||||
return info.server_version >= 220200
|
||||
else:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _to_regtype(cls, conn: "BaseConnection[Any]") -> "SQL":
|
||||
# `to_regtype()` returns the type oid or NULL, unlike the :: operator,
|
||||
# which returns the type or raises an exception, which requires
|
||||
# a transaction rollback and leaves traces in the server logs.
|
||||
|
||||
from .sql import SQL
|
||||
|
||||
if cls._has_to_regtype_function(conn):
|
||||
return SQL("to_regtype(%(name)s)")
|
||||
else:
|
||||
return SQL("%(name)s::regtype")
|
||||
|
||||
def _added(self, registry: "TypesRegistry") -> None:
|
||||
"""Method called by the `!registry` when the object is added there."""
|
||||
pass
|
||||
|
||||
|
||||
class RangeInfo(TypeInfo):
|
||||
"""Manage information about a range type."""
|
||||
|
||||
__module__ = "psycopg.types.range"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
oid: int,
|
||||
array_oid: int,
|
||||
*,
|
||||
regtype: str = "",
|
||||
subtype_oid: int,
|
||||
):
|
||||
super().__init__(name, oid, array_oid, regtype=regtype)
|
||||
self.subtype_oid = subtype_oid
|
||||
|
||||
@classmethod
|
||||
def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
|
||||
from .sql import SQL
|
||||
|
||||
return SQL(
|
||||
"""\
|
||||
SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
|
||||
t.oid::regtype::text AS regtype,
|
||||
r.rngsubtype AS subtype_oid
|
||||
FROM pg_type t
|
||||
JOIN pg_range r ON t.oid = r.rngtypid
|
||||
WHERE t.oid = {regtype}
|
||||
"""
|
||||
).format(regtype=cls._to_regtype(conn))
|
||||
|
||||
def _added(self, registry: "TypesRegistry") -> None:
|
||||
# Map ranges subtypes to info
|
||||
registry._registry[RangeInfo, self.subtype_oid] = self
|
||||
|
||||
|
||||
class MultirangeInfo(TypeInfo):
|
||||
"""Manage information about a multirange type."""
|
||||
|
||||
__module__ = "psycopg.types.multirange"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
oid: int,
|
||||
array_oid: int,
|
||||
*,
|
||||
regtype: str = "",
|
||||
range_oid: int,
|
||||
subtype_oid: int,
|
||||
):
|
||||
super().__init__(name, oid, array_oid, regtype=regtype)
|
||||
self.range_oid = range_oid
|
||||
self.subtype_oid = subtype_oid
|
||||
|
||||
@classmethod
|
||||
def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
|
||||
from .sql import SQL
|
||||
|
||||
if conn.info.server_version < 140000:
|
||||
raise e.NotSupportedError(
|
||||
"multirange types are only available from PostgreSQL 14"
|
||||
)
|
||||
|
||||
return SQL(
|
||||
"""\
|
||||
SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
|
||||
t.oid::regtype::text AS regtype,
|
||||
r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid
|
||||
FROM pg_type t
|
||||
JOIN pg_range r ON t.oid = r.rngmultitypid
|
||||
WHERE t.oid = {regtype}
|
||||
"""
|
||||
).format(regtype=cls._to_regtype(conn))
|
||||
|
||||
def _added(self, registry: "TypesRegistry") -> None:
|
||||
# Map multiranges ranges and subtypes to info
|
||||
registry._registry[MultirangeInfo, self.range_oid] = self
|
||||
registry._registry[MultirangeInfo, self.subtype_oid] = self
|
||||
|
||||
|
||||
class CompositeInfo(TypeInfo):
|
||||
"""Manage information about a composite type."""
|
||||
|
||||
__module__ = "psycopg.types.composite"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
oid: int,
|
||||
array_oid: int,
|
||||
*,
|
||||
regtype: str = "",
|
||||
field_names: Sequence[str],
|
||||
field_types: Sequence[int],
|
||||
):
|
||||
super().__init__(name, oid, array_oid, regtype=regtype)
|
||||
self.field_names = field_names
|
||||
self.field_types = field_types
|
||||
# Will be set by register() if the `factory` is a type
|
||||
self.python_type: Optional[type] = None
|
||||
|
||||
@classmethod
|
||||
def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
|
||||
from .sql import SQL
|
||||
|
||||
return SQL(
|
||||
"""\
|
||||
SELECT
|
||||
t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
|
||||
t.oid::regtype::text AS regtype,
|
||||
coalesce(a.fnames, '{{}}') AS field_names,
|
||||
coalesce(a.ftypes, '{{}}') AS field_types
|
||||
FROM pg_type t
|
||||
LEFT JOIN (
|
||||
SELECT
|
||||
attrelid,
|
||||
array_agg(attname) AS fnames,
|
||||
array_agg(atttypid) AS ftypes
|
||||
FROM (
|
||||
SELECT a.attrelid, a.attname, a.atttypid
|
||||
FROM pg_attribute a
|
||||
JOIN pg_type t ON t.typrelid = a.attrelid
|
||||
WHERE t.oid = {regtype}
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
ORDER BY a.attnum
|
||||
) x
|
||||
GROUP BY attrelid
|
||||
) a ON a.attrelid = t.typrelid
|
||||
WHERE t.oid = {regtype}
|
||||
"""
|
||||
).format(regtype=cls._to_regtype(conn))
|
||||
|
||||
|
||||
class EnumInfo(TypeInfo):
|
||||
"""Manage information about an enum type."""
|
||||
|
||||
__module__ = "psycopg.types.enum"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
oid: int,
|
||||
array_oid: int,
|
||||
labels: Sequence[str],
|
||||
):
|
||||
super().__init__(name, oid, array_oid)
|
||||
self.labels = labels
|
||||
# Will be set by register_enum()
|
||||
self.enum: Optional[Type[Enum]] = None
|
||||
|
||||
@classmethod
|
||||
def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
|
||||
from .sql import SQL
|
||||
|
||||
return SQL(
|
||||
"""\
|
||||
SELECT name, oid, array_oid, array_agg(label) AS labels
|
||||
FROM (
|
||||
SELECT
|
||||
t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
|
||||
e.enumlabel AS label
|
||||
FROM pg_type t
|
||||
LEFT JOIN pg_enum e
|
||||
ON e.enumtypid = t.oid
|
||||
WHERE t.oid = {regtype}
|
||||
ORDER BY e.enumsortorder
|
||||
) x
|
||||
GROUP BY name, oid, array_oid
|
||||
"""
|
||||
).format(regtype=cls._to_regtype(conn))
|
||||
|
||||
|
||||
class TypesRegistry:
|
||||
"""
|
||||
Container for the information about types in a database.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.types"
|
||||
|
||||
def __init__(self, template: Optional["TypesRegistry"] = None):
|
||||
self._registry: Dict[RegistryKey, TypeInfo]
|
||||
|
||||
# Make a shallow copy: it will become a proper copy if the registry
|
||||
# is edited.
|
||||
if template:
|
||||
self._registry = template._registry
|
||||
self._own_state = False
|
||||
template._own_state = False
|
||||
else:
|
||||
self.clear()
|
||||
|
||||
def clear(self) -> None:
|
||||
self._registry = {}
|
||||
self._own_state = True
|
||||
|
||||
def add(self, info: TypeInfo) -> None:
|
||||
self._ensure_own_state()
|
||||
if info.oid:
|
||||
self._registry[info.oid] = info
|
||||
if info.array_oid:
|
||||
self._registry[info.array_oid] = info
|
||||
self._registry[info.name] = info
|
||||
|
||||
if info.regtype and info.regtype not in self._registry:
|
||||
self._registry[info.regtype] = info
|
||||
|
||||
# Allow info to customise further their relation with the registry
|
||||
info._added(self)
|
||||
|
||||
def __iter__(self) -> Iterator[TypeInfo]:
|
||||
seen = set()
|
||||
for t in self._registry.values():
|
||||
if id(t) not in seen:
|
||||
seen.add(id(t))
|
||||
yield t
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: Union[str, int]) -> TypeInfo:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: Tuple[Type[T], int]) -> T:
|
||||
...
|
||||
|
||||
def __getitem__(self, key: RegistryKey) -> TypeInfo:
|
||||
"""
|
||||
Return info about a type, specified by name or oid
|
||||
|
||||
:param key: the name or oid of the type to look for.
|
||||
|
||||
Raise KeyError if not found.
|
||||
"""
|
||||
if isinstance(key, str):
|
||||
if key.endswith("[]"):
|
||||
key = key[:-2]
|
||||
elif not isinstance(key, (int, tuple)):
|
||||
raise TypeError(f"the key must be an oid or a name, got {type(key)}")
|
||||
try:
|
||||
return self._registry[key]
|
||||
except KeyError:
|
||||
raise KeyError(f"couldn't find the type {key!r} in the types registry")
|
||||
|
||||
@overload
|
||||
def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self, key: Tuple[Type[T], int]) -> Optional[T]:
|
||||
...
|
||||
|
||||
def get(self, key: RegistryKey) -> Optional[TypeInfo]:
|
||||
"""
|
||||
Return info about a type, specified by name or oid
|
||||
|
||||
:param key: the name or oid of the type to look for.
|
||||
|
||||
Unlike `__getitem__`, return None if not found.
|
||||
"""
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def get_oid(self, name: str) -> int:
|
||||
"""
|
||||
Return the oid of a PostgreSQL type by name.
|
||||
|
||||
:param key: the name of the type to look for.
|
||||
|
||||
Return the array oid if the type ends with "``[]``"
|
||||
|
||||
Raise KeyError if the name is unknown.
|
||||
"""
|
||||
t = self[name]
|
||||
if name.endswith("[]"):
|
||||
return t.array_oid
|
||||
else:
|
||||
return t.oid
|
||||
|
||||
def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]:
|
||||
"""
|
||||
Return info about a `TypeInfo` subclass by its element name or oid.
|
||||
|
||||
:param cls: the subtype of `!TypeInfo` to look for. Currently
|
||||
supported are `~psycopg.types.range.RangeInfo` and
|
||||
`~psycopg.types.multirange.MultirangeInfo`.
|
||||
:param subtype: The name or OID of the subtype of the element to look for.
|
||||
:return: The `!TypeInfo` object of class `!cls` whose subtype is
|
||||
`!subtype`. `!None` if the element or its range are not found.
|
||||
"""
|
||||
try:
|
||||
info = self[subtype]
|
||||
except KeyError:
|
||||
return None
|
||||
return self.get((cls, info.oid))
|
||||
|
||||
def _ensure_own_state(self) -> None:
|
||||
# Time to write! so, copy.
|
||||
if not self._own_state:
|
||||
self._registry = self._registry.copy()
|
||||
self._own_state = True
|
||||
44
srcs/.venv/lib/python3.11/site-packages/psycopg/_tz.py
Normal file
44
srcs/.venv/lib/python3.11/site-packages/psycopg/_tz.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
Timezone utility functions.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional, Union
|
||||
from datetime import timezone, tzinfo
|
||||
|
||||
from .pq.abc import PGconn
|
||||
from ._compat import ZoneInfo
|
||||
|
||||
logger = logging.getLogger("psycopg")
|
||||
|
||||
_timezones: Dict[Union[None, bytes], tzinfo] = {
|
||||
None: timezone.utc,
|
||||
b"UTC": timezone.utc,
|
||||
}
|
||||
|
||||
|
||||
def get_tzinfo(pgconn: Optional[PGconn]) -> tzinfo:
|
||||
"""Return the Python timezone info of the connection's timezone."""
|
||||
tzname = pgconn.parameter_status(b"TimeZone") if pgconn else None
|
||||
try:
|
||||
return _timezones[tzname]
|
||||
except KeyError:
|
||||
sname = tzname.decode() if tzname else "UTC"
|
||||
try:
|
||||
zi: tzinfo = ZoneInfo(sname)
|
||||
except (KeyError, OSError):
|
||||
logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname)
|
||||
zi = timezone.utc
|
||||
except Exception as ex:
|
||||
logger.warning(
|
||||
"error handling PostgreSQL timezone: %r; will use UTC (%s - %s)",
|
||||
sname,
|
||||
type(ex).__name__,
|
||||
ex,
|
||||
)
|
||||
zi = timezone.utc
|
||||
|
||||
_timezones[tzname] = zi
|
||||
return zi
|
||||
137
srcs/.venv/lib/python3.11/site-packages/psycopg/_wrappers.py
Normal file
137
srcs/.venv/lib/python3.11/site-packages/psycopg/_wrappers.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Wrappers for numeric types.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
# Wrappers to force numbers to be cast as specific PostgreSQL types
|
||||
|
||||
# These types are implemented here but exposed by `psycopg.types.numeric`.
|
||||
# They are defined here to avoid a circular import.
|
||||
_MODULE = "psycopg.types.numeric"
|
||||
|
||||
|
||||
class Int2(int):
|
||||
"""
|
||||
Force dumping a Python `!int` as a PostgreSQL :sql:`smallint/int2`.
|
||||
"""
|
||||
|
||||
__module__ = _MODULE
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, arg: int) -> "Int2":
|
||||
return super().__new__(cls, arg)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__repr__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({super().__repr__()})"
|
||||
|
||||
|
||||
class Int4(int):
|
||||
"""
|
||||
Force dumping a Python `!int` as a PostgreSQL :sql:`integer/int4`.
|
||||
"""
|
||||
|
||||
__module__ = _MODULE
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, arg: int) -> "Int4":
|
||||
return super().__new__(cls, arg)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__repr__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({super().__repr__()})"
|
||||
|
||||
|
||||
class Int8(int):
|
||||
"""
|
||||
Force dumping a Python `!int` as a PostgreSQL :sql:`bigint/int8`.
|
||||
"""
|
||||
|
||||
__module__ = _MODULE
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, arg: int) -> "Int8":
|
||||
return super().__new__(cls, arg)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__repr__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({super().__repr__()})"
|
||||
|
||||
|
||||
class IntNumeric(int):
|
||||
"""
|
||||
Force dumping a Python `!int` as a PostgreSQL :sql:`numeric/decimal`.
|
||||
"""
|
||||
|
||||
__module__ = _MODULE
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, arg: int) -> "IntNumeric":
|
||||
return super().__new__(cls, arg)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__repr__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({super().__repr__()})"
|
||||
|
||||
|
||||
class Float4(float):
|
||||
"""
|
||||
Force dumping a Python `!float` as a PostgreSQL :sql:`float4/real`.
|
||||
"""
|
||||
|
||||
__module__ = _MODULE
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, arg: float) -> "Float4":
|
||||
return super().__new__(cls, arg)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__repr__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({super().__repr__()})"
|
||||
|
||||
|
||||
class Float8(float):
|
||||
"""
|
||||
Force dumping a Python `!float` as a PostgreSQL :sql:`float8/double precision`.
|
||||
"""
|
||||
|
||||
__module__ = _MODULE
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, arg: float) -> "Float8":
|
||||
return super().__new__(cls, arg)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__repr__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({super().__repr__()})"
|
||||
|
||||
|
||||
class Oid(int):
|
||||
"""
|
||||
Force dumping a Python `!int` as a PostgreSQL :sql:`oid`.
|
||||
"""
|
||||
|
||||
__module__ = _MODULE
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, arg: int) -> "Oid":
|
||||
return super().__new__(cls, arg)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__repr__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({super().__repr__()})"
|
||||
265
srcs/.venv/lib/python3.11/site-packages/psycopg/abc.py
Normal file
265
srcs/.venv/lib/python3.11/site-packages/psycopg/abc.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Protocol objects representing different implementations of the same classes.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from typing import Any, Callable, Generator, Mapping
|
||||
from typing import List, Optional, Sequence, Tuple, TypeVar, Union
|
||||
from typing import TYPE_CHECKING
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from . import pq
|
||||
from ._enums import PyFormat as PyFormat
|
||||
from ._compat import Protocol, LiteralString
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import sql
|
||||
from .rows import Row, RowMaker
|
||||
from .pq.abc import PGresult
|
||||
from .waiting import Wait, Ready
|
||||
from .connection import BaseConnection
|
||||
from ._adapters_map import AdaptersMap
|
||||
|
||||
NoneType: type = type(None)
|
||||
|
||||
# An object implementing the buffer protocol
|
||||
Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
|
||||
|
||||
Query: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"]
|
||||
Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]]
|
||||
ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
|
||||
PipelineCommand: TypeAlias = Callable[[], None]
|
||||
DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]]
|
||||
|
||||
# Waiting protocol types
|
||||
|
||||
RV = TypeVar("RV")
|
||||
|
||||
PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV]
|
||||
"""Generator for processes where the connection file number can change.
|
||||
|
||||
This can happen in connection and reset, but not in normal querying.
|
||||
"""
|
||||
|
||||
PQGen: TypeAlias = Generator["Wait", "Ready", RV]
|
||||
"""Generator for processes where the connection file number won't change.
|
||||
"""
|
||||
|
||||
|
||||
class WaitFunc(Protocol):
|
||||
"""
|
||||
Wait on the connection which generated `PQgen` and return its final result.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
|
||||
) -> RV:
|
||||
...
|
||||
|
||||
|
||||
# Adaptation types
|
||||
|
||||
DumpFunc: TypeAlias = Callable[[Any], Buffer]
|
||||
LoadFunc: TypeAlias = Callable[[Buffer], Any]
|
||||
|
||||
|
||||
class AdaptContext(Protocol):
|
||||
"""
|
||||
A context describing how types are adapted.
|
||||
|
||||
Example of `~AdaptContext` are `~psycopg.Connection`, `~psycopg.Cursor`,
|
||||
`~psycopg.adapt.Transformer`, `~psycopg.adapt.AdaptersMap`.
|
||||
|
||||
Note that this is a `~typing.Protocol`, so objects implementing
|
||||
`!AdaptContext` don't need to explicitly inherit from this class.
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
def adapters(self) -> "AdaptersMap":
|
||||
"""The adapters configuration that this object uses."""
|
||||
...
|
||||
|
||||
@property
|
||||
def connection(self) -> Optional["BaseConnection[Any]"]:
|
||||
"""The connection used by this object, if available.
|
||||
|
||||
:rtype: `~psycopg.Connection` or `~psycopg.AsyncConnection` or `!None`
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Dumper(Protocol):
|
||||
"""
|
||||
Convert Python objects of type `!cls` to PostgreSQL representation.
|
||||
"""
|
||||
|
||||
format: pq.Format
|
||||
"""
|
||||
The format that this class `dump()` method produces,
|
||||
`~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
|
||||
|
||||
This is a class attribute.
|
||||
"""
|
||||
|
||||
oid: int
|
||||
"""The oid to pass to the server, if known; 0 otherwise (class attribute)."""
|
||||
|
||||
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
|
||||
...
|
||||
|
||||
def dump(self, obj: Any) -> Buffer:
|
||||
"""Convert the object `!obj` to PostgreSQL representation.
|
||||
|
||||
:param obj: the object to convert.
|
||||
"""
|
||||
...
|
||||
|
||||
def quote(self, obj: Any) -> Buffer:
|
||||
"""Convert the object `!obj` to escaped representation.
|
||||
|
||||
:param obj: the object to convert.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_key(self, obj: Any, format: PyFormat) -> DumperKey:
|
||||
"""Return an alternative key to upgrade the dumper to represent `!obj`.
|
||||
|
||||
:param obj: The object to convert
|
||||
:param format: The format to convert to
|
||||
|
||||
Normally the type of the object is all it takes to define how to dump
|
||||
the object to the database. For instance, a Python `~datetime.date` can
|
||||
be simply converted into a PostgreSQL :sql:`date`.
|
||||
|
||||
In a few cases, just the type is not enough. For example:
|
||||
|
||||
- A Python `~datetime.datetime` could be represented as a
|
||||
:sql:`timestamptz` or a :sql:`timestamp`, according to whether it
|
||||
specifies a `!tzinfo` or not.
|
||||
|
||||
- A Python int could be stored as several Postgres types: int2, int4,
|
||||
int8, numeric. If a type too small is used, it may result in an
|
||||
overflow. If a type too large is used, PostgreSQL may not want to
|
||||
cast it to a smaller type.
|
||||
|
||||
- Python lists should be dumped according to the type they contain to
|
||||
convert them to e.g. array of strings, array of ints (and which
|
||||
size of int?...)
|
||||
|
||||
In these cases, a dumper can implement `!get_key()` and return a new
|
||||
class, or sequence of classes, that can be used to identify the same
|
||||
dumper again. If the mechanism is not needed, the method should return
|
||||
the same `!cls` object passed in the constructor.
|
||||
|
||||
If a dumper implements `get_key()` it should also implement
|
||||
`upgrade()`.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
|
||||
"""Return a new dumper to manage `!obj`.
|
||||
|
||||
:param obj: The object to convert
|
||||
:param format: The format to convert to
|
||||
|
||||
Once `Transformer.get_dumper()` has been notified by `get_key()` that
|
||||
this Dumper class cannot handle `!obj` itself, it will invoke
|
||||
`!upgrade()`, which should return a new `Dumper` instance, which will
|
||||
be reused for every objects for which `!get_key()` returns the same
|
||||
result.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Loader(Protocol):
|
||||
"""
|
||||
Convert PostgreSQL values with type OID `!oid` to Python objects.
|
||||
"""
|
||||
|
||||
format: pq.Format
|
||||
"""
|
||||
The format that this class `load()` method can convert,
|
||||
`~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
|
||||
|
||||
This is a class attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
...
|
||||
|
||||
def load(self, data: Buffer) -> Any:
|
||||
"""
|
||||
Convert the data returned by the database into a Python object.
|
||||
|
||||
:param data: the data to convert.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Transformer(Protocol):
|
||||
types: Optional[Tuple[int, ...]]
|
||||
formats: Optional[List[pq.Format]]
|
||||
|
||||
def __init__(self, context: Optional[AdaptContext] = None):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
|
||||
...
|
||||
|
||||
@property
|
||||
def connection(self) -> Optional["BaseConnection[Any]"]:
|
||||
...
|
||||
|
||||
@property
|
||||
def encoding(self) -> str:
|
||||
...
|
||||
|
||||
@property
|
||||
def adapters(self) -> "AdaptersMap":
|
||||
...
|
||||
|
||||
@property
|
||||
def pgresult(self) -> Optional["PGresult"]:
|
||||
...
|
||||
|
||||
def set_pgresult(
|
||||
self,
|
||||
result: Optional["PGresult"],
|
||||
*,
|
||||
set_loaders: bool = True,
|
||||
format: Optional[pq.Format] = None
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
|
||||
...
|
||||
|
||||
def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
|
||||
...
|
||||
|
||||
def dump_sequence(
|
||||
self, params: Sequence[Any], formats: Sequence[PyFormat]
|
||||
) -> Sequence[Optional[Buffer]]:
|
||||
...
|
||||
|
||||
def as_literal(self, obj: Any) -> bytes:
|
||||
...
|
||||
|
||||
def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
|
||||
...
|
||||
|
||||
def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]:
|
||||
...
|
||||
|
||||
def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]:
|
||||
...
|
||||
|
||||
def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
|
||||
...
|
||||
|
||||
def get_loader(self, oid: int, format: pq.Format) -> Loader:
|
||||
...
|
||||
162
srcs/.venv/lib/python3.11/site-packages/psycopg/adapt.py
Normal file
162
srcs/.venv/lib/python3.11/site-packages/psycopg/adapt.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Entry point into the adaptation system.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Type, TYPE_CHECKING
|
||||
|
||||
from . import pq, abc
|
||||
from . import _adapters_map
|
||||
from ._enums import PyFormat as PyFormat
|
||||
from ._cmodule import _psycopg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .connection import BaseConnection
|
||||
|
||||
AdaptersMap = _adapters_map.AdaptersMap
|
||||
Buffer = abc.Buffer
|
||||
|
||||
ORD_BS = ord("\\")
|
||||
|
||||
|
||||
class Dumper(abc.Dumper, ABC):
|
||||
"""
|
||||
Convert Python object of the type `!cls` to PostgreSQL representation.
|
||||
"""
|
||||
|
||||
oid: int = 0
|
||||
"""The oid to pass to the server, if known."""
|
||||
|
||||
format: pq.Format = pq.Format.TEXT
|
||||
"""The format of the data dumped."""
|
||||
|
||||
def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
|
||||
self.cls = cls
|
||||
self.connection: Optional["BaseConnection[Any]"] = (
|
||||
context.connection if context else None
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<{type(self).__module__}.{type(self).__qualname__}"
|
||||
f" (oid={self.oid}) at 0x{id(self):x}>"
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def dump(self, obj: Any) -> Buffer:
|
||||
...
|
||||
|
||||
def quote(self, obj: Any) -> Buffer:
|
||||
"""
|
||||
By default return the `dump()` value quoted and sanitised, so
|
||||
that the result can be used to build a SQL string. This works well
|
||||
for most types and you won't likely have to implement this method in a
|
||||
subclass.
|
||||
"""
|
||||
value = self.dump(obj)
|
||||
|
||||
if self.connection:
|
||||
esc = pq.Escaping(self.connection.pgconn)
|
||||
# escaping and quoting
|
||||
return esc.escape_literal(value)
|
||||
|
||||
# This path is taken when quote is asked without a connection,
|
||||
# usually it means by psycopg.sql.quote() or by
|
||||
# 'Composible.as_string(None)'. Most often than not this is done by
|
||||
# someone generating a SQL file to consume elsewhere.
|
||||
|
||||
# No quoting, only quote escaping, random bs escaping. See further.
|
||||
esc = pq.Escaping()
|
||||
out = esc.escape_string(value)
|
||||
|
||||
# b"\\" in memoryview doesn't work so search for the ascii value
|
||||
if ORD_BS not in out:
|
||||
# If the string has no backslash, the result is correct and we
|
||||
# don't need to bother with standard_conforming_strings.
|
||||
return b"'" + out + b"'"
|
||||
|
||||
# The libpq has a crazy behaviour: PQescapeString uses the last
|
||||
# standard_conforming_strings setting seen on a connection. This
|
||||
# means that backslashes might be escaped or might not.
|
||||
#
|
||||
# A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
|
||||
# if scs is off, '\\' raises a warning and '\' is an error.
|
||||
#
|
||||
# Check what the libpq does, and if it doesn't escape the backslash
|
||||
# let's do it on our own. Never mind the race condition.
|
||||
rv: bytes = b" E'" + out + b"'"
|
||||
if esc.escape_string(b"\\") == b"\\":
|
||||
rv = rv.replace(b"\\", b"\\\\")
|
||||
return rv
|
||||
|
||||
def get_key(self, obj: Any, format: PyFormat) -> abc.DumperKey:
|
||||
"""
|
||||
Implementation of the `~psycopg.abc.Dumper.get_key()` member of the
|
||||
`~psycopg.abc.Dumper` protocol. Look at its definition for details.
|
||||
|
||||
This implementation returns the `!cls` passed in the constructor.
|
||||
Subclasses needing to specialise the PostgreSQL type according to the
|
||||
*value* of the object dumped (not only according to to its type)
|
||||
should override this class.
|
||||
|
||||
"""
|
||||
return self.cls
|
||||
|
||||
def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
|
||||
"""
|
||||
Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the
|
||||
`~psycopg.abc.Dumper` protocol. Look at its definition for details.
|
||||
|
||||
This implementation just returns `!self`. If a subclass implements
|
||||
`get_key()` it should probably override `!upgrade()` too.
|
||||
"""
|
||||
return self
|
||||
|
||||
|
||||
class Loader(abc.Loader, ABC):
|
||||
"""
|
||||
Convert PostgreSQL values with type OID `!oid` to Python objects.
|
||||
"""
|
||||
|
||||
format: pq.Format = pq.Format.TEXT
|
||||
"""The format of the data loaded."""
|
||||
|
||||
def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
|
||||
self.oid = oid
|
||||
self.connection: Optional["BaseConnection[Any]"] = (
|
||||
context.connection if context else None
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def load(self, data: Buffer) -> Any:
|
||||
"""Convert a PostgreSQL value to a Python object."""
|
||||
...
|
||||
|
||||
|
||||
Transformer: Type["abc.Transformer"]
|
||||
|
||||
# Override it with fast object if available
|
||||
if _psycopg:
|
||||
Transformer = _psycopg.Transformer
|
||||
else:
|
||||
from . import _transform
|
||||
|
||||
Transformer = _transform.Transformer
|
||||
|
||||
|
||||
class RecursiveDumper(Dumper):
|
||||
"""Dumper with a transformer to help dumping recursive types."""
|
||||
|
||||
def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
|
||||
super().__init__(cls, context)
|
||||
self._tx = Transformer.from_context(context)
|
||||
|
||||
|
||||
class RecursiveLoader(Loader):
|
||||
"""Loader with a transformer to help loading recursive types."""
|
||||
|
||||
def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
self._tx = Transformer.from_context(context)
|
||||
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
psycopg client-side binding cursors
|
||||
"""
|
||||
|
||||
# Copyright (C) 2022 The Psycopg Team
|
||||
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
from functools import partial
|
||||
|
||||
from ._queries import PostgresQuery, PostgresClientQuery
|
||||
|
||||
from . import pq
|
||||
from . import adapt
|
||||
from . import errors as e
|
||||
from .abc import ConnectionType, Query, Params
|
||||
from .rows import Row
|
||||
from .cursor import BaseCursor, Cursor
|
||||
from ._preparing import Prepare
|
||||
from .cursor_async import AsyncCursor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any # noqa: F401
|
||||
from .connection import Connection # noqa: F401
|
||||
from .connection_async import AsyncConnection # noqa: F401
|
||||
|
||||
TEXT = pq.Format.TEXT
|
||||
BINARY = pq.Format.BINARY
|
||||
|
||||
|
||||
class ClientCursorMixin(BaseCursor[ConnectionType, Row]):
|
||||
def mogrify(self, query: Query, params: Optional[Params] = None) -> str:
|
||||
"""
|
||||
Return the query and parameters merged.
|
||||
|
||||
Parameters are adapted and merged to the query the same way that
|
||||
`!execute()` would do.
|
||||
|
||||
"""
|
||||
self._tx = adapt.Transformer(self)
|
||||
pgq = self._convert_query(query, params)
|
||||
return pgq.query.decode(self._tx.encoding)
|
||||
|
||||
def _execute_send(
|
||||
self,
|
||||
query: PostgresQuery,
|
||||
*,
|
||||
force_extended: bool = False,
|
||||
binary: Optional[bool] = None,
|
||||
) -> None:
|
||||
if binary is None:
|
||||
fmt = self.format
|
||||
else:
|
||||
fmt = BINARY if binary else TEXT
|
||||
|
||||
if fmt == BINARY:
|
||||
raise e.NotSupportedError(
|
||||
"client-side cursors don't support binary results"
|
||||
)
|
||||
|
||||
self._query = query
|
||||
|
||||
if self._conn._pipeline:
|
||||
# In pipeline mode always use PQsendQueryParams - see #314
|
||||
# Multiple statements in the same query are not allowed anyway.
|
||||
self._conn._pipeline.command_queue.append(
|
||||
partial(self._pgconn.send_query_params, query.query, None)
|
||||
)
|
||||
elif force_extended:
|
||||
self._pgconn.send_query_params(query.query, None)
|
||||
else:
|
||||
# If we can, let's use simple query protocol,
|
||||
# as it can execute more than one statement in a single query.
|
||||
self._pgconn.send_query(query.query)
|
||||
|
||||
def _convert_query(
|
||||
self, query: Query, params: Optional[Params] = None
|
||||
) -> PostgresQuery:
|
||||
pgq = PostgresClientQuery(self._tx)
|
||||
pgq.convert(query, params)
|
||||
return pgq
|
||||
|
||||
def _get_prepared(
|
||||
self, pgq: PostgresQuery, prepare: Optional[bool] = None
|
||||
) -> Tuple[Prepare, bytes]:
|
||||
return (Prepare.NO, b"")
|
||||
|
||||
|
||||
class ClientCursor(ClientCursorMixin["Connection[Any]", Row], Cursor[Row]):
|
||||
__module__ = "psycopg"
|
||||
|
||||
|
||||
class AsyncClientCursor(
|
||||
ClientCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
|
||||
):
|
||||
__module__ = "psycopg"
|
||||
1054
srcs/.venv/lib/python3.11/site-packages/psycopg/connection.py
Normal file
1054
srcs/.venv/lib/python3.11/site-packages/psycopg/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
psycopg async connection objects
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import logging
|
||||
from types import TracebackType
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
|
||||
from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from . import pq
|
||||
from . import errors as e
|
||||
from . import waiting
|
||||
from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
|
||||
from ._tpc import Xid
|
||||
from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
|
||||
from .adapt import AdaptersMap
|
||||
from ._enums import IsolationLevel
|
||||
from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts_async
|
||||
from ._pipeline import AsyncPipeline
|
||||
from ._encodings import pgconn_encoding
|
||||
from .connection import BaseConnection, CursorRow, Notify
|
||||
from .generators import notifies
|
||||
from .transaction import AsyncTransaction
|
||||
from .cursor_async import AsyncCursor
|
||||
from .server_cursor import AsyncServerCursor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pq.abc import PGconn
|
||||
|
||||
TEXT = pq.Format.TEXT
|
||||
BINARY = pq.Format.BINARY
|
||||
|
||||
IDLE = pq.TransactionStatus.IDLE
|
||||
INTRANS = pq.TransactionStatus.INTRANS
|
||||
|
||||
logger = logging.getLogger("psycopg")
|
||||
|
||||
|
||||
class AsyncConnection(BaseConnection[Row]):
|
||||
"""
|
||||
Asynchronous wrapper for a connection to the database.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg"
|
||||
|
||||
cursor_factory: Type[AsyncCursor[Row]]
|
||||
server_cursor_factory: Type[AsyncServerCursor[Row]]
|
||||
row_factory: AsyncRowFactory[Row]
|
||||
_pipeline: Optional[AsyncPipeline]
|
||||
_Self = TypeVar("_Self", bound="AsyncConnection[Any]")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pgconn: "PGconn",
|
||||
row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row),
|
||||
):
|
||||
super().__init__(pgconn)
|
||||
self.row_factory = row_factory
|
||||
self.lock = asyncio.Lock()
|
||||
self.cursor_factory = AsyncCursor
|
||||
self.server_cursor_factory = AsyncServerCursor
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def connect(
|
||||
cls,
|
||||
conninfo: str = "",
|
||||
*,
|
||||
autocommit: bool = False,
|
||||
prepare_threshold: Optional[int] = 5,
|
||||
row_factory: AsyncRowFactory[Row],
|
||||
cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
|
||||
context: Optional[AdaptContext] = None,
|
||||
**kwargs: Union[None, int, str],
|
||||
) -> "AsyncConnection[Row]":
|
||||
# TODO: returned type should be _Self. See #308.
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def connect(
|
||||
cls,
|
||||
conninfo: str = "",
|
||||
*,
|
||||
autocommit: bool = False,
|
||||
prepare_threshold: Optional[int] = 5,
|
||||
cursor_factory: Optional[Type[AsyncCursor[Any]]] = None,
|
||||
context: Optional[AdaptContext] = None,
|
||||
**kwargs: Union[None, int, str],
|
||||
) -> "AsyncConnection[TupleRow]":
|
||||
...
|
||||
|
||||
@classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004
|
||||
async def connect(
|
||||
cls,
|
||||
conninfo: str = "",
|
||||
*,
|
||||
autocommit: bool = False,
|
||||
prepare_threshold: Optional[int] = 5,
|
||||
context: Optional[AdaptContext] = None,
|
||||
row_factory: Optional[AsyncRowFactory[Row]] = None,
|
||||
cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "AsyncConnection[Any]":
|
||||
if sys.platform == "win32":
|
||||
loop = asyncio.get_running_loop()
|
||||
if isinstance(loop, asyncio.ProactorEventLoop):
|
||||
raise e.InterfaceError(
|
||||
"Psycopg cannot use the 'ProactorEventLoop' to run in async"
|
||||
" mode. Please use a compatible event loop, for instance by"
|
||||
" setting 'asyncio.set_event_loop_policy"
|
||||
"(WindowsSelectorEventLoopPolicy())'"
|
||||
)
|
||||
|
||||
params = await cls._get_connection_params(conninfo, **kwargs)
|
||||
timeout = int(params["connect_timeout"])
|
||||
rv = None
|
||||
async for attempt in conninfo_attempts_async(params):
|
||||
try:
|
||||
conninfo = make_conninfo(**attempt)
|
||||
rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
|
||||
break
|
||||
except e._NO_TRACEBACK as ex:
|
||||
last_ex = ex
|
||||
|
||||
if not rv:
|
||||
assert last_ex
|
||||
raise last_ex.with_traceback(None)
|
||||
|
||||
rv._autocommit = bool(autocommit)
|
||||
if row_factory:
|
||||
rv.row_factory = row_factory
|
||||
if cursor_factory:
|
||||
rv.cursor_factory = cursor_factory
|
||||
if context:
|
||||
rv._adapters = AdaptersMap(context.adapters)
|
||||
rv.prepare_threshold = prepare_threshold
|
||||
return rv
|
||||
|
||||
async def __aenter__(self: _Self) -> _Self:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
if self.closed:
|
||||
return
|
||||
|
||||
if exc_type:
|
||||
# try to rollback, but if there are problems (connection in a bad
|
||||
# state) just warn without clobbering the exception bubbling up.
|
||||
try:
|
||||
await self.rollback()
|
||||
except Exception as exc2:
|
||||
logger.warning(
|
||||
"error ignored in rollback on %s: %s",
|
||||
self,
|
||||
exc2,
|
||||
)
|
||||
else:
|
||||
await self.commit()
|
||||
|
||||
# Close the connection only if it doesn't belong to a pool.
|
||||
if not getattr(self, "_pool", None):
|
||||
await self.close()
|
||||
|
||||
@classmethod
|
||||
async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
|
||||
"""Manipulate connection parameters before connecting."""
|
||||
params = conninfo_to_dict(conninfo, **kwargs)
|
||||
|
||||
# Make sure there is an usable connect_timeout
|
||||
if "connect_timeout" in params:
|
||||
params["connect_timeout"] = int(params["connect_timeout"])
|
||||
else:
|
||||
# The sync connect function will stop on the default socket timeout
|
||||
# Because in async connection mode we need to enforce the timeout
|
||||
# ourselves, we need a finite value.
|
||||
params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
|
||||
|
||||
return params
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.closed:
|
||||
return
|
||||
self._closed = True
|
||||
|
||||
# TODO: maybe send a cancel on close, if the connection is ACTIVE?
|
||||
|
||||
self.pgconn.finish()
|
||||
|
||||
@overload
|
||||
def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cursor(
|
||||
self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
|
||||
) -> AsyncCursor[CursorRow]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cursor(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
binary: bool = False,
|
||||
scrollable: Optional[bool] = None,
|
||||
withhold: bool = False,
|
||||
) -> AsyncServerCursor[Row]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cursor(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
binary: bool = False,
|
||||
row_factory: AsyncRowFactory[CursorRow],
|
||||
scrollable: Optional[bool] = None,
|
||||
withhold: bool = False,
|
||||
) -> AsyncServerCursor[CursorRow]:
|
||||
...
|
||||
|
||||
def cursor(
|
||||
self,
|
||||
name: str = "",
|
||||
*,
|
||||
binary: bool = False,
|
||||
row_factory: Optional[AsyncRowFactory[Any]] = None,
|
||||
scrollable: Optional[bool] = None,
|
||||
withhold: bool = False,
|
||||
) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]:
|
||||
"""
|
||||
Return a new `AsyncCursor` to send commands and queries to the connection.
|
||||
"""
|
||||
self._check_connection_ok()
|
||||
|
||||
if not row_factory:
|
||||
row_factory = self.row_factory
|
||||
|
||||
cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]]
|
||||
if name:
|
||||
cur = self.server_cursor_factory(
|
||||
self,
|
||||
name=name,
|
||||
row_factory=row_factory,
|
||||
scrollable=scrollable,
|
||||
withhold=withhold,
|
||||
)
|
||||
else:
|
||||
cur = self.cursor_factory(self, row_factory=row_factory)
|
||||
|
||||
if binary:
|
||||
cur.format = BINARY
|
||||
|
||||
return cur
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
query: Query,
|
||||
params: Optional[Params] = None,
|
||||
*,
|
||||
prepare: Optional[bool] = None,
|
||||
binary: bool = False,
|
||||
) -> AsyncCursor[Row]:
|
||||
try:
|
||||
cur = self.cursor()
|
||||
if binary:
|
||||
cur.format = BINARY
|
||||
|
||||
return await cur.execute(query, params, prepare=prepare)
|
||||
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
|
||||
async def commit(self) -> None:
|
||||
async with self.lock:
|
||||
await self.wait(self._commit_gen())
|
||||
|
||||
async def rollback(self) -> None:
|
||||
async with self.lock:
|
||||
await self.wait(self._rollback_gen())
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(
|
||||
self,
|
||||
savepoint_name: Optional[str] = None,
|
||||
force_rollback: bool = False,
|
||||
) -> AsyncIterator[AsyncTransaction]:
|
||||
"""
|
||||
Start a context block with a new transaction or nested transaction.
|
||||
|
||||
:rtype: AsyncTransaction
|
||||
"""
|
||||
tx = AsyncTransaction(self, savepoint_name, force_rollback)
|
||||
if self._pipeline:
|
||||
async with self.pipeline(), tx, self.pipeline():
|
||||
yield tx
|
||||
else:
|
||||
async with tx:
|
||||
yield tx
|
||||
|
||||
async def notifies(self) -> AsyncGenerator[Notify, None]:
|
||||
while True:
|
||||
async with self.lock:
|
||||
try:
|
||||
ns = await self.wait(notifies(self.pgconn))
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
enc = pgconn_encoding(self.pgconn)
|
||||
for pgn in ns:
|
||||
n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
|
||||
yield n
|
||||
|
||||
@asynccontextmanager
|
||||
async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
|
||||
"""Context manager to switch the connection into pipeline mode."""
|
||||
async with self.lock:
|
||||
self._check_connection_ok()
|
||||
|
||||
pipeline = self._pipeline
|
||||
if pipeline is None:
|
||||
# WARNING: reference loop, broken ahead.
|
||||
pipeline = self._pipeline = AsyncPipeline(self)
|
||||
|
||||
try:
|
||||
async with pipeline:
|
||||
yield pipeline
|
||||
finally:
|
||||
if pipeline.level == 0:
|
||||
async with self.lock:
|
||||
assert pipeline is self._pipeline
|
||||
self._pipeline = None
|
||||
|
||||
async def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
|
||||
try:
|
||||
return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
|
||||
except (asyncio.CancelledError, KeyboardInterrupt):
|
||||
# On Ctrl-C, try to cancel the query in the server, otherwise
|
||||
# the connection will remain stuck in ACTIVE state.
|
||||
self._try_cancel(self.pgconn)
|
||||
try:
|
||||
await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
|
||||
except e.QueryCanceled:
|
||||
pass # as expected
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
|
||||
return await waiting.wait_conn_async(gen, timeout)
|
||||
|
||||
def _set_autocommit(self, value: bool) -> None:
|
||||
self._no_set_async("autocommit")
|
||||
|
||||
async def set_autocommit(self, value: bool) -> None:
|
||||
"""Async version of the `~Connection.autocommit` setter."""
|
||||
async with self.lock:
|
||||
await self.wait(self._set_autocommit_gen(value))
|
||||
|
||||
def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
|
||||
self._no_set_async("isolation_level")
|
||||
|
||||
async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
|
||||
"""Async version of the `~Connection.isolation_level` setter."""
|
||||
async with self.lock:
|
||||
await self.wait(self._set_isolation_level_gen(value))
|
||||
|
||||
def _set_read_only(self, value: Optional[bool]) -> None:
|
||||
self._no_set_async("read_only")
|
||||
|
||||
async def set_read_only(self, value: Optional[bool]) -> None:
|
||||
"""Async version of the `~Connection.read_only` setter."""
|
||||
async with self.lock:
|
||||
await self.wait(self._set_read_only_gen(value))
|
||||
|
||||
def _set_deferrable(self, value: Optional[bool]) -> None:
|
||||
self._no_set_async("deferrable")
|
||||
|
||||
async def set_deferrable(self, value: Optional[bool]) -> None:
|
||||
"""Async version of the `~Connection.deferrable` setter."""
|
||||
async with self.lock:
|
||||
await self.wait(self._set_deferrable_gen(value))
|
||||
|
||||
def _no_set_async(self, attribute: str) -> None:
|
||||
raise AttributeError(
|
||||
f"'the {attribute!r} property is read-only on async connections:"
|
||||
f" please use 'await .set_{attribute}()' instead."
|
||||
)
|
||||
|
||||
async def tpc_begin(self, xid: Union[Xid, str]) -> None:
|
||||
async with self.lock:
|
||||
await self.wait(self._tpc_begin_gen(xid))
|
||||
|
||||
async def tpc_prepare(self) -> None:
|
||||
try:
|
||||
async with self.lock:
|
||||
await self.wait(self._tpc_prepare_gen())
|
||||
except e.ObjectNotInPrerequisiteState as ex:
|
||||
raise e.NotSupportedError(str(ex)) from None
|
||||
|
||||
async def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
|
||||
async with self.lock:
|
||||
await self.wait(self._tpc_finish_gen("commit", xid))
|
||||
|
||||
async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
|
||||
async with self.lock:
|
||||
await self.wait(self._tpc_finish_gen("rollback", xid))
|
||||
|
||||
async def tpc_recover(self) -> List[Xid]:
|
||||
self._check_tpc()
|
||||
status = self.info.transaction_status
|
||||
async with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
|
||||
await cur.execute(Xid._get_recover_query())
|
||||
res = await cur.fetchall()
|
||||
|
||||
if status == IDLE and self.info.transaction_status == INTRANS:
|
||||
await self.rollback()
|
||||
|
||||
return res
|
||||
478
srcs/.venv/lib/python3.11/site-packages/psycopg/conninfo.py
Normal file
478
srcs/.venv/lib/python3.11/site-packages/psycopg/conninfo.py
Normal file
@@ -0,0 +1,478 @@
|
||||
"""
|
||||
Functions to manipulate conninfo strings
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import asyncio
|
||||
from typing import Any, Iterator, AsyncIterator
|
||||
from random import shuffle
|
||||
from pathlib import Path
|
||||
from datetime import tzinfo
|
||||
from functools import lru_cache
|
||||
from ipaddress import ip_address
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from . import pq
|
||||
from . import errors as e
|
||||
from ._tz import get_tzinfo
|
||||
from ._compat import cache
|
||||
from ._encodings import pgconn_encoding
|
||||
|
||||
ConnDict: TypeAlias = "dict[str, Any]"
|
||||
|
||||
|
||||
def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
|
||||
"""
|
||||
Merge a string and keyword params into a single conninfo string.
|
||||
|
||||
:param conninfo: A `connection string`__ as accepted by PostgreSQL.
|
||||
:param kwargs: Parameters overriding the ones specified in `!conninfo`.
|
||||
:return: A connection string valid for PostgreSQL, with the `!kwargs`
|
||||
parameters merged.
|
||||
|
||||
Raise `~psycopg.ProgrammingError` if the input doesn't make a valid
|
||||
conninfo string.
|
||||
|
||||
.. __: https://www.postgresql.org/docs/current/libpq-connect.html
|
||||
#LIBPQ-CONNSTRING
|
||||
"""
|
||||
if not conninfo and not kwargs:
|
||||
return ""
|
||||
|
||||
# If no kwarg specified don't mung the conninfo but check if it's correct.
|
||||
# Make sure to return a string, not a subtype, to avoid making Liskov sad.
|
||||
if not kwargs:
|
||||
_parse_conninfo(conninfo)
|
||||
return str(conninfo)
|
||||
|
||||
# Override the conninfo with the parameters
|
||||
# Drop the None arguments
|
||||
kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
|
||||
|
||||
if conninfo:
|
||||
tmp = conninfo_to_dict(conninfo)
|
||||
tmp.update(kwargs)
|
||||
kwargs = tmp
|
||||
|
||||
conninfo = " ".join(f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items())
|
||||
|
||||
# Verify the result is valid
|
||||
_parse_conninfo(conninfo)
|
||||
|
||||
return conninfo
|
||||
|
||||
|
||||
def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict:
|
||||
"""
|
||||
Convert the `!conninfo` string into a dictionary of parameters.
|
||||
|
||||
:param conninfo: A `connection string`__ as accepted by PostgreSQL.
|
||||
:param kwargs: Parameters overriding the ones specified in `!conninfo`.
|
||||
:return: Dictionary with the parameters parsed from `!conninfo` and
|
||||
`!kwargs`.
|
||||
|
||||
Raise `~psycopg.ProgrammingError` if `!conninfo` is not a a valid connection
|
||||
string.
|
||||
|
||||
.. __: https://www.postgresql.org/docs/current/libpq-connect.html
|
||||
#LIBPQ-CONNSTRING
|
||||
"""
|
||||
opts = _parse_conninfo(conninfo)
|
||||
rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None}
|
||||
for k, v in kwargs.items():
|
||||
if v is not None:
|
||||
rv[k] = v
|
||||
return rv
|
||||
|
||||
|
||||
def _parse_conninfo(conninfo: str) -> list[pq.ConninfoOption]:
|
||||
"""
|
||||
Verify that `!conninfo` is a valid connection string.
|
||||
|
||||
Raise ProgrammingError if the string is not valid.
|
||||
|
||||
Return the result of pq.Conninfo.parse() on success.
|
||||
"""
|
||||
try:
|
||||
return pq.Conninfo.parse(conninfo.encode())
|
||||
except e.OperationalError as ex:
|
||||
raise e.ProgrammingError(str(ex)) from None
|
||||
|
||||
|
||||
re_escape = re.compile(r"([\\'])")
|
||||
re_space = re.compile(r"\s")
|
||||
|
||||
|
||||
def _param_escape(s: str) -> str:
|
||||
"""
|
||||
Apply the escaping rule required by PQconnectdb
|
||||
"""
|
||||
if not s:
|
||||
return "''"
|
||||
|
||||
s = re_escape.sub(r"\\\1", s)
|
||||
if re_space.search(s):
|
||||
s = "'" + s + "'"
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class ConnectionInfo:
|
||||
"""Allow access to information about the connection."""
|
||||
|
||||
__module__ = "psycopg"
|
||||
|
||||
def __init__(self, pgconn: pq.abc.PGconn):
|
||||
self.pgconn = pgconn
|
||||
|
||||
@property
|
||||
def vendor(self) -> str:
|
||||
"""A string representing the database vendor connected to."""
|
||||
return "PostgreSQL"
|
||||
|
||||
@property
|
||||
def host(self) -> str:
|
||||
"""The server host name of the active connection. See :pq:`PQhost()`."""
|
||||
return self._get_pgconn_attr("host")
|
||||
|
||||
@property
|
||||
def hostaddr(self) -> str:
|
||||
"""The server IP address of the connection. See :pq:`PQhostaddr()`."""
|
||||
return self._get_pgconn_attr("hostaddr")
|
||||
|
||||
@property
|
||||
def port(self) -> int:
|
||||
"""The port of the active connection. See :pq:`PQport()`."""
|
||||
return int(self._get_pgconn_attr("port"))
|
||||
|
||||
@property
|
||||
def dbname(self) -> str:
|
||||
"""The database name of the connection. See :pq:`PQdb()`."""
|
||||
return self._get_pgconn_attr("db")
|
||||
|
||||
@property
|
||||
def user(self) -> str:
|
||||
"""The user name of the connection. See :pq:`PQuser()`."""
|
||||
return self._get_pgconn_attr("user")
|
||||
|
||||
@property
|
||||
def password(self) -> str:
|
||||
"""The password of the connection. See :pq:`PQpass()`."""
|
||||
return self._get_pgconn_attr("password")
|
||||
|
||||
@property
|
||||
def options(self) -> str:
|
||||
"""
|
||||
The command-line options passed in the connection request.
|
||||
See :pq:`PQoptions`.
|
||||
"""
|
||||
return self._get_pgconn_attr("options")
|
||||
|
||||
def get_parameters(self) -> dict[str, str]:
|
||||
"""Return the connection parameters values.
|
||||
|
||||
Return all the parameters set to a non-default value, which might come
|
||||
either from the connection string and parameters passed to
|
||||
`~Connection.connect()` or from environment variables. The password
|
||||
is never returned (you can read it using the `password` attribute).
|
||||
"""
|
||||
pyenc = self.encoding
|
||||
|
||||
# Get the known defaults to avoid reporting them
|
||||
defaults = {
|
||||
i.keyword: i.compiled
|
||||
for i in pq.Conninfo.get_defaults()
|
||||
if i.compiled is not None
|
||||
}
|
||||
# Not returned by the libq. Bug? Bet we're using SSH.
|
||||
defaults.setdefault(b"channel_binding", b"prefer")
|
||||
defaults[b"passfile"] = str(Path.home() / ".pgpass").encode()
|
||||
|
||||
return {
|
||||
i.keyword.decode(pyenc): i.val.decode(pyenc)
|
||||
for i in self.pgconn.info
|
||||
if i.val is not None
|
||||
and i.keyword != b"password"
|
||||
and i.val != defaults.get(i.keyword)
|
||||
}
|
||||
|
||||
@property
|
||||
def dsn(self) -> str:
|
||||
"""Return the connection string to connect to the database.
|
||||
|
||||
The string contains all the parameters set to a non-default value,
|
||||
which might come either from the connection string and parameters
|
||||
passed to `~Connection.connect()` or from environment variables. The
|
||||
password is never returned (you can read it using the `password`
|
||||
attribute).
|
||||
"""
|
||||
return make_conninfo(**self.get_parameters())
|
||||
|
||||
@property
|
||||
def status(self) -> pq.ConnStatus:
|
||||
"""The status of the connection. See :pq:`PQstatus()`."""
|
||||
return pq.ConnStatus(self.pgconn.status)
|
||||
|
||||
@property
|
||||
def transaction_status(self) -> pq.TransactionStatus:
|
||||
"""
|
||||
The current in-transaction status of the session.
|
||||
See :pq:`PQtransactionStatus()`.
|
||||
"""
|
||||
return pq.TransactionStatus(self.pgconn.transaction_status)
|
||||
|
||||
@property
|
||||
def pipeline_status(self) -> pq.PipelineStatus:
|
||||
"""
|
||||
The current pipeline status of the client.
|
||||
See :pq:`PQpipelineStatus()`.
|
||||
"""
|
||||
return pq.PipelineStatus(self.pgconn.pipeline_status)
|
||||
|
||||
def parameter_status(self, param_name: str) -> str | None:
|
||||
"""
|
||||
Return a parameter setting of the connection.
|
||||
|
||||
Return `None` is the parameter is unknown.
|
||||
"""
|
||||
res = self.pgconn.parameter_status(param_name.encode(self.encoding))
|
||||
return res.decode(self.encoding) if res is not None else None
|
||||
|
||||
@property
|
||||
def server_version(self) -> int:
|
||||
"""
|
||||
An integer representing the server version. See :pq:`PQserverVersion()`.
|
||||
"""
|
||||
return self.pgconn.server_version
|
||||
|
||||
@property
|
||||
def backend_pid(self) -> int:
|
||||
"""
|
||||
The process ID (PID) of the backend process handling this connection.
|
||||
See :pq:`PQbackendPID()`.
|
||||
"""
|
||||
return self.pgconn.backend_pid
|
||||
|
||||
@property
|
||||
def error_message(self) -> str:
|
||||
"""
|
||||
The error message most recently generated by an operation on the connection.
|
||||
See :pq:`PQerrorMessage()`.
|
||||
"""
|
||||
return self._get_pgconn_attr("error_message")
|
||||
|
||||
@property
|
||||
def timezone(self) -> tzinfo:
|
||||
"""The Python timezone info of the connection's timezone."""
|
||||
return get_tzinfo(self.pgconn)
|
||||
|
||||
@property
|
||||
def encoding(self) -> str:
|
||||
"""The Python codec name of the connection's client encoding."""
|
||||
return pgconn_encoding(self.pgconn)
|
||||
|
||||
def _get_pgconn_attr(self, name: str) -> str:
|
||||
value: bytes = getattr(self.pgconn, name)
|
||||
return value.decode(self.encoding)
|
||||
|
||||
|
||||
def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
|
||||
"""Split a set of connection params on the single attempts to perforn.
|
||||
|
||||
A connection param can perform more than one attempt more than one ``host``
|
||||
is provided.
|
||||
|
||||
Because the libpq async function doesn't honour the timeout, we need to
|
||||
reimplement the repeated attempts.
|
||||
"""
|
||||
if params.get("load_balance_hosts", "disable") == "random":
|
||||
attempts = list(_split_attempts(_inject_defaults(params)))
|
||||
shuffle(attempts)
|
||||
for attempt in attempts:
|
||||
yield attempt
|
||||
else:
|
||||
for attempt in _split_attempts(_inject_defaults(params)):
|
||||
yield attempt
|
||||
|
||||
|
||||
async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
|
||||
"""Split a set of connection params on the single attempts to perforn.
|
||||
|
||||
A connection param can perform more than one attempt more than one ``host``
|
||||
is provided.
|
||||
|
||||
Also perform async resolution of the hostname into hostaddr in order to
|
||||
avoid blocking. Because a host can resolve to more than one address, this
|
||||
can lead to yield more attempts too. Raise `OperationalError` if no host
|
||||
could be resolved.
|
||||
|
||||
Because the libpq async function doesn't honour the timeout, we need to
|
||||
reimplement the repeated attempts.
|
||||
"""
|
||||
yielded = False
|
||||
last_exc = None
|
||||
for attempt in _split_attempts(_inject_defaults(params)):
|
||||
try:
|
||||
async for a2 in _split_attempts_and_resolve(attempt):
|
||||
yielded = True
|
||||
yield a2
|
||||
except OSError as ex:
|
||||
last_exc = ex
|
||||
|
||||
if not yielded:
|
||||
assert last_exc
|
||||
# We couldn't resolve anything
|
||||
raise e.OperationalError(str(last_exc))
|
||||
|
||||
|
||||
def _inject_defaults(params: ConnDict) -> ConnDict:
|
||||
"""
|
||||
Add defaults to a dictionary of parameters.
|
||||
|
||||
This avoids the need to look up for env vars at various stages during
|
||||
processing.
|
||||
|
||||
Note that a port is always specified. 5432 likely comes from here.
|
||||
|
||||
The `host`, `hostaddr`, `port` will be always set to a string.
|
||||
"""
|
||||
defaults = _conn_defaults()
|
||||
out = params.copy()
|
||||
|
||||
def inject(name: str, envvar: str) -> None:
|
||||
value = out.get(name)
|
||||
if not value:
|
||||
out[name] = os.environ.get(envvar, defaults[name])
|
||||
else:
|
||||
out[name] = str(value)
|
||||
|
||||
inject("host", "PGHOST")
|
||||
inject("hostaddr", "PGHOSTADDR")
|
||||
inject("port", "PGPORT")
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
|
||||
"""
|
||||
Split connection parameters with a sequence of hosts into separate attempts.
|
||||
|
||||
Assume that `host`, `hostaddr`, `port` are always present and a string (as
|
||||
emitted from `_inject_defaults()`).
|
||||
"""
|
||||
|
||||
def split_val(key: str) -> list[str]:
|
||||
# Assume all keys are present and strings.
|
||||
val: str = params[key]
|
||||
return val.split(",") if val else []
|
||||
|
||||
hosts = split_val("host")
|
||||
hostaddrs = split_val("hostaddr")
|
||||
ports = split_val("port")
|
||||
|
||||
if hosts and hostaddrs and len(hosts) != len(hostaddrs):
|
||||
raise e.OperationalError(
|
||||
f"could not match {len(hosts)} host names"
|
||||
f" with {len(hostaddrs)} hostaddr values"
|
||||
)
|
||||
|
||||
nhosts = max(len(hosts), len(hostaddrs))
|
||||
|
||||
if 1 < len(ports) != nhosts:
|
||||
raise e.OperationalError(
|
||||
f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
|
||||
)
|
||||
elif len(ports) == 1:
|
||||
ports *= nhosts
|
||||
|
||||
# A single attempt to make
|
||||
if nhosts <= 1:
|
||||
yield params
|
||||
return
|
||||
|
||||
# Now all lists are either empty or have the same length
|
||||
for i in range(nhosts):
|
||||
attempt = params.copy()
|
||||
if hosts:
|
||||
attempt["host"] = hosts[i]
|
||||
if hostaddrs:
|
||||
attempt["hostaddr"] = hostaddrs[i]
|
||||
if ports:
|
||||
attempt["port"] = ports[i]
|
||||
yield attempt
|
||||
|
||||
|
||||
async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDict]:
|
||||
"""
|
||||
Perform async DNS lookup of the hosts and return a new params dict.
|
||||
|
||||
:param params: The input parameters, for instance as returned by
|
||||
`~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
|
||||
a single entry for host, hostaddr, port and doesn't check for env vars
|
||||
because it is designed to further process the input of _split_attempts()
|
||||
|
||||
If a ``host`` param is present but not ``hostname``, resolve the host
|
||||
addresses dynamically.
|
||||
|
||||
The function may change the input ``host``, ``hostname``, ``port`` to allow
|
||||
connecting without further DNS lookups.
|
||||
|
||||
Raise `~psycopg.OperationalError` if resolution fails.
|
||||
"""
|
||||
host = params["host"]
|
||||
if not host or host.startswith("/") or host[1:2] == ":":
|
||||
# Local path, or no host to resolve
|
||||
yield params
|
||||
return
|
||||
|
||||
hostaddr = params["hostaddr"]
|
||||
if hostaddr:
|
||||
# Already resolved
|
||||
yield params
|
||||
return
|
||||
|
||||
if is_ip_address(host):
|
||||
# If the host is already an ip address don't try to resolve it
|
||||
params["hostaddr"] = host
|
||||
yield params
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
port = params["port"]
|
||||
ans = await loop.getaddrinfo(
|
||||
host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
|
||||
)
|
||||
|
||||
attempt = params.copy()
|
||||
for item in ans:
|
||||
attempt["hostaddr"] = item[4][0]
|
||||
yield attempt
|
||||
|
||||
|
||||
@cache
|
||||
def _conn_defaults() -> dict[str, str]:
|
||||
"""
|
||||
Return a dictionary of defaults for connection strings parameters.
|
||||
"""
|
||||
defs = pq.Conninfo.get_defaults()
|
||||
return {
|
||||
d.keyword.decode(): d.compiled.decode() if d.compiled is not None else ""
|
||||
for d in defs
|
||||
}
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def is_ip_address(s: str) -> bool:
|
||||
"""Return True if the string represent a valid ip address."""
|
||||
try:
|
||||
ip_address(s)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
919
srcs/.venv/lib/python3.11/site-packages/psycopg/copy.py
Normal file
919
srcs/.venv/lib/python3.11/site-packages/psycopg/copy.py
Normal file
@@ -0,0 +1,919 @@
|
||||
"""
|
||||
psycopg copy support
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import re
|
||||
import queue
|
||||
import struct
|
||||
import asyncio
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from types import TracebackType
|
||||
from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO
|
||||
from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
|
||||
|
||||
from . import pq
|
||||
from . import adapt
|
||||
from . import errors as e
|
||||
from .abc import Buffer, ConnectionType, PQGen, Transformer
|
||||
from ._compat import create_task
|
||||
from .pq.misc import connection_summary
|
||||
from ._cmodule import _psycopg
|
||||
from ._encodings import pgconn_encoding
|
||||
from .generators import copy_from, copy_to, copy_end
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .cursor import BaseCursor, Cursor
|
||||
from .cursor_async import AsyncCursor
|
||||
from .connection import Connection # noqa: F401
|
||||
from .connection_async import AsyncConnection # noqa: F401
|
||||
|
||||
PY_TEXT = adapt.PyFormat.TEXT
|
||||
PY_BINARY = adapt.PyFormat.BINARY
|
||||
|
||||
TEXT = pq.Format.TEXT
|
||||
BINARY = pq.Format.BINARY
|
||||
|
||||
COPY_IN = pq.ExecStatus.COPY_IN
|
||||
COPY_OUT = pq.ExecStatus.COPY_OUT
|
||||
|
||||
ACTIVE = pq.TransactionStatus.ACTIVE
|
||||
|
||||
# Size of data to accumulate before sending it down the network. We fill a
|
||||
# buffer this size field by field, and when it passes the threshold size
|
||||
# we ship it, so it may end up being bigger than this.
|
||||
BUFFER_SIZE = 32 * 1024
|
||||
|
||||
# Maximum data size we want to queue to send to the libpq copy. Sending a
|
||||
# buffer too big to be handled can cause an infinite loop in the libpq
|
||||
# (#255) so we want to split it in more digestable chunks.
|
||||
MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
|
||||
# Note: making this buffer too large, e.g.
|
||||
# MAX_BUFFER_SIZE = 1024 * 1024
|
||||
# makes operations *way* slower! Probably triggering some quadraticity
|
||||
# in the libpq memory management and data sending.
|
||||
|
||||
# Max size of the write queue of buffers. More than that copy will block
|
||||
# Each buffer should be around BUFFER_SIZE size.
|
||||
QUEUE_SIZE = 1024
|
||||
|
||||
|
||||
class BaseCopy(Generic[ConnectionType]):
|
||||
"""
|
||||
Base implementation for the copy user interface.
|
||||
|
||||
Two subclasses expose real methods with the sync/async differences.
|
||||
|
||||
The difference between the text and binary format is managed by two
|
||||
different `Formatter` subclasses.
|
||||
|
||||
Writing (the I/O part) is implemented in the subclasses by a `Writer` or
|
||||
`AsyncWriter` instance. Normally writing implies sending copy data to a
|
||||
database, but a different writer might be chosen, e.g. to stream data into
|
||||
a file for later use.
|
||||
"""
|
||||
|
||||
_Self = TypeVar("_Self", bound="BaseCopy[Any]")
|
||||
|
||||
formatter: "Formatter"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cursor: "BaseCursor[ConnectionType, Any]",
|
||||
*,
|
||||
binary: Optional[bool] = None,
|
||||
):
|
||||
self.cursor = cursor
|
||||
self.connection = cursor.connection
|
||||
self._pgconn = self.connection.pgconn
|
||||
|
||||
result = cursor.pgresult
|
||||
if result:
|
||||
self._direction = result.status
|
||||
if self._direction != COPY_IN and self._direction != COPY_OUT:
|
||||
raise e.ProgrammingError(
|
||||
"the cursor should have performed a COPY operation;"
|
||||
f" its status is {pq.ExecStatus(self._direction).name} instead"
|
||||
)
|
||||
else:
|
||||
self._direction = COPY_IN
|
||||
|
||||
if binary is None:
|
||||
binary = bool(result and result.binary_tuples)
|
||||
|
||||
tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
|
||||
if binary:
|
||||
self.formatter = BinaryFormatter(tx)
|
||||
else:
|
||||
self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
|
||||
|
||||
self._finished = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
|
||||
info = connection_summary(self._pgconn)
|
||||
return f"<{cls} {info} at 0x{id(self):x}>"
|
||||
|
||||
def _enter(self) -> None:
|
||||
if self._finished:
|
||||
raise TypeError("copy blocks can be used only once")
|
||||
|
||||
def set_types(self, types: Sequence[Union[int, str]]) -> None:
|
||||
"""
|
||||
Set the types expected in a COPY operation.
|
||||
|
||||
The types must be specified as a sequence of oid or PostgreSQL type
|
||||
names (e.g. ``int4``, ``timestamptz[]``).
|
||||
|
||||
This operation overcomes the lack of metadata returned by PostgreSQL
|
||||
when a COPY operation begins:
|
||||
|
||||
- On :sql:`COPY TO`, `!set_types()` allows to specify what types the
|
||||
operation returns. If `!set_types()` is not used, the data will be
|
||||
returned as unparsed strings or bytes instead of Python objects.
|
||||
|
||||
- On :sql:`COPY FROM`, `!set_types()` allows to choose what type the
|
||||
database expects. This is especially useful in binary copy, because
|
||||
PostgreSQL will apply no cast rule.
|
||||
|
||||
"""
|
||||
registry = self.cursor.adapters.types
|
||||
oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
|
||||
|
||||
if self._direction == COPY_IN:
|
||||
self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
|
||||
else:
|
||||
self.formatter.transformer.set_loader_types(oids, self.formatter.format)
|
||||
|
||||
# High level copy protocol generators (state change of the Copy object)
|
||||
|
||||
def _read_gen(self) -> PQGen[Buffer]:
|
||||
if self._finished:
|
||||
return memoryview(b"")
|
||||
|
||||
res = yield from copy_from(self._pgconn)
|
||||
if isinstance(res, memoryview):
|
||||
return res
|
||||
|
||||
# res is the final PGresult
|
||||
self._finished = True
|
||||
|
||||
# This result is a COMMAND_OK which has info about the number of rows
|
||||
# returned, but not about the columns, which is instead an information
|
||||
# that was received on the COPY_OUT result at the beginning of COPY.
|
||||
# So, don't replace the results in the cursor, just update the rowcount.
|
||||
nrows = res.command_tuples
|
||||
self.cursor._rowcount = nrows if nrows is not None else -1
|
||||
return memoryview(b"")
|
||||
|
||||
def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
|
||||
data = yield from self._read_gen()
|
||||
if not data:
|
||||
return None
|
||||
|
||||
row = self.formatter.parse_row(data)
|
||||
if row is None:
|
||||
# Get the final result to finish the copy operation
|
||||
yield from self._read_gen()
|
||||
self._finished = True
|
||||
return None
|
||||
|
||||
return row
|
||||
|
||||
def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
|
||||
if not exc:
|
||||
return
|
||||
|
||||
if self._pgconn.transaction_status != ACTIVE:
|
||||
# The server has already finished to send copy data. The connection
|
||||
# is already in a good state.
|
||||
return
|
||||
|
||||
# Throw a cancel to the server, then consume the rest of the copy data
|
||||
# (which might or might not have been already transferred entirely to
|
||||
# the client, so we won't necessary see the exception associated with
|
||||
# canceling).
|
||||
self.connection.cancel()
|
||||
try:
|
||||
while (yield from self._read_gen()):
|
||||
pass
|
||||
except e.QueryCanceled:
|
||||
pass
|
||||
|
||||
|
||||
class Copy(BaseCopy["Connection[Any]"]):
|
||||
"""Manage a :sql:`COPY` operation.
|
||||
|
||||
:param cursor: the cursor where the operation is performed.
|
||||
:param binary: if `!True`, write binary format.
|
||||
:param writer: the object to write to destination. If not specified, write
|
||||
to the `!cursor` connection.
|
||||
|
||||
Choosing `!binary` is not necessary if the cursor has executed a
|
||||
:sql:`COPY` operation, because the operation result describes the format
|
||||
too. The parameter is useful when a `!Copy` object is created manually and
|
||||
no operation is performed on the cursor, such as when using ``writer=``\\
|
||||
`~psycopg.copy.FileWriter`.
|
||||
|
||||
"""
|
||||
|
||||
__module__ = "psycopg"
|
||||
|
||||
writer: "Writer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cursor: "Cursor[Any]",
|
||||
*,
|
||||
binary: Optional[bool] = None,
|
||||
writer: Optional["Writer"] = None,
|
||||
):
|
||||
super().__init__(cursor, binary=binary)
|
||||
if not writer:
|
||||
writer = LibpqWriter(cursor)
|
||||
|
||||
self.writer = writer
|
||||
self._write = writer.write
|
||||
|
||||
def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
|
||||
self._enter()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.finish(exc_val)
|
||||
|
||||
# End user sync interface
|
||||
|
||||
def __iter__(self) -> Iterator[Buffer]:
|
||||
"""Implement block-by-block iteration on :sql:`COPY TO`."""
|
||||
while True:
|
||||
data = self.read()
|
||||
if not data:
|
||||
break
|
||||
yield data
|
||||
|
||||
def read(self) -> Buffer:
|
||||
"""
|
||||
Read an unparsed row after a :sql:`COPY TO` operation.
|
||||
|
||||
Return an empty string when the data is finished.
|
||||
"""
|
||||
return self.connection.wait(self._read_gen())
|
||||
|
||||
def rows(self) -> Iterator[Tuple[Any, ...]]:
|
||||
"""
|
||||
Iterate on the result of a :sql:`COPY TO` operation record by record.
|
||||
|
||||
Note that the records returned will be tuples of unparsed strings or
|
||||
bytes, unless data types are specified using `set_types()`.
|
||||
"""
|
||||
while True:
|
||||
record = self.read_row()
|
||||
if record is None:
|
||||
break
|
||||
yield record
|
||||
|
||||
def read_row(self) -> Optional[Tuple[Any, ...]]:
|
||||
"""
|
||||
Read a parsed row of data from a table after a :sql:`COPY TO` operation.
|
||||
|
||||
Return `!None` when the data is finished.
|
||||
|
||||
Note that the records returned will be tuples of unparsed strings or
|
||||
bytes, unless data types are specified using `set_types()`.
|
||||
"""
|
||||
return self.connection.wait(self._read_row_gen())
|
||||
|
||||
def write(self, buffer: Union[Buffer, str]) -> None:
|
||||
"""
|
||||
Write a block of data to a table after a :sql:`COPY FROM` operation.
|
||||
|
||||
If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
|
||||
text mode it can be either `!bytes` or `!str`.
|
||||
"""
|
||||
data = self.formatter.write(buffer)
|
||||
if data:
|
||||
self._write(data)
|
||||
|
||||
def write_row(self, row: Sequence[Any]) -> None:
|
||||
"""Write a record to a table after a :sql:`COPY FROM` operation."""
|
||||
data = self.formatter.write_row(row)
|
||||
if data:
|
||||
self._write(data)
|
||||
|
||||
def finish(self, exc: Optional[BaseException]) -> None:
|
||||
"""Terminate the copy operation and free the resources allocated.
|
||||
|
||||
You shouldn't need to call this function yourself: it is usually called
|
||||
by exit. It is available if, despite what is documented, you end up
|
||||
using the `Copy` object outside a block.
|
||||
"""
|
||||
if self._direction == COPY_IN:
|
||||
data = self.formatter.end()
|
||||
if data:
|
||||
self._write(data)
|
||||
self.writer.finish(exc)
|
||||
self._finished = True
|
||||
else:
|
||||
self.connection.wait(self._end_copy_out_gen(exc))
|
||||
|
||||
|
||||
class Writer(ABC):
|
||||
"""
|
||||
A class to write copy data somewhere.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def write(self, data: Buffer) -> None:
|
||||
"""
|
||||
Write some data to destination.
|
||||
"""
|
||||
...
|
||||
|
||||
def finish(self, exc: Optional[BaseException] = None) -> None:
|
||||
"""
|
||||
Called when write operations are finished.
|
||||
|
||||
If operations finished with an error, it will be passed to ``exc``.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LibpqWriter(Writer):
|
||||
"""
|
||||
A `Writer` to write copy data to a Postgres database.
|
||||
"""
|
||||
|
||||
def __init__(self, cursor: "Cursor[Any]"):
|
||||
self.cursor = cursor
|
||||
self.connection = cursor.connection
|
||||
self._pgconn = self.connection.pgconn
|
||||
|
||||
def write(self, data: Buffer) -> None:
|
||||
if len(data) <= MAX_BUFFER_SIZE:
|
||||
# Most used path: we don't need to split the buffer in smaller
|
||||
# bits, so don't make a copy.
|
||||
self.connection.wait(copy_to(self._pgconn, data))
|
||||
else:
|
||||
# Copy a buffer too large in chunks to avoid causing a memory
|
||||
# error in the libpq, which may cause an infinite loop (#255).
|
||||
for i in range(0, len(data), MAX_BUFFER_SIZE):
|
||||
self.connection.wait(
|
||||
copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
|
||||
)
|
||||
|
||||
def finish(self, exc: Optional[BaseException] = None) -> None:
|
||||
bmsg: Optional[bytes]
|
||||
if exc:
|
||||
msg = f"error from Python: {type(exc).__qualname__} - {exc}"
|
||||
bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
|
||||
else:
|
||||
bmsg = None
|
||||
|
||||
try:
|
||||
res = self.connection.wait(copy_end(self._pgconn, bmsg))
|
||||
# The QueryCanceled is expected if we sent an exception message to
|
||||
# pgconn.put_copy_end(). The Python exception that generated that
|
||||
# cancelling is more important, so don't clobber it.
|
||||
except e.QueryCanceled:
|
||||
if not bmsg:
|
||||
raise
|
||||
else:
|
||||
self.cursor._results = [res]
|
||||
|
||||
|
||||
class QueuedLibpqWriter(LibpqWriter):
|
||||
"""
|
||||
A writer using a buffer to queue data to write to a Postgres database.
|
||||
|
||||
`write()` returns immediately, so that the main thread can be CPU-bound
|
||||
formatting messages, while a worker thread can be IO-bound waiting to write
|
||||
on the connection.
|
||||
"""
|
||||
|
||||
def __init__(self, cursor: "Cursor[Any]"):
|
||||
super().__init__(cursor)
|
||||
|
||||
self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE)
|
||||
self._worker: Optional[threading.Thread] = None
|
||||
self._worker_error: Optional[BaseException] = None
|
||||
|
||||
def worker(self) -> None:
|
||||
"""Push data to the server when available from the copy queue.
|
||||
|
||||
Terminate reading when the queue receives a false-y value, or in case
|
||||
of error.
|
||||
|
||||
The function is designed to be run in a separate thread.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
data = self._queue.get(block=True, timeout=24 * 60 * 60)
|
||||
if not data:
|
||||
break
|
||||
self.connection.wait(copy_to(self._pgconn, data))
|
||||
except BaseException as ex:
|
||||
# Propagate the error to the main thread.
|
||||
self._worker_error = ex
|
||||
|
||||
def write(self, data: Buffer) -> None:
|
||||
if not self._worker:
|
||||
# warning: reference loop, broken by _write_end
|
||||
self._worker = threading.Thread(target=self.worker)
|
||||
self._worker.daemon = True
|
||||
self._worker.start()
|
||||
|
||||
# If the worker thread raies an exception, re-raise it to the caller.
|
||||
if self._worker_error:
|
||||
raise self._worker_error
|
||||
|
||||
if len(data) <= MAX_BUFFER_SIZE:
|
||||
# Most used path: we don't need to split the buffer in smaller
|
||||
# bits, so don't make a copy.
|
||||
self._queue.put(data)
|
||||
else:
|
||||
# Copy a buffer too large in chunks to avoid causing a memory
|
||||
# error in the libpq, which may cause an infinite loop (#255).
|
||||
for i in range(0, len(data), MAX_BUFFER_SIZE):
|
||||
self._queue.put(data[i : i + MAX_BUFFER_SIZE])
|
||||
|
||||
def finish(self, exc: Optional[BaseException] = None) -> None:
|
||||
self._queue.put(b"")
|
||||
|
||||
if self._worker:
|
||||
self._worker.join()
|
||||
self._worker = None # break the loop
|
||||
|
||||
# Check if the worker thread raised any exception before terminating.
|
||||
if self._worker_error:
|
||||
raise self._worker_error
|
||||
|
||||
super().finish(exc)
|
||||
|
||||
|
||||
class FileWriter(Writer):
|
||||
"""
|
||||
A `Writer` to write copy data to a file-like object.
|
||||
|
||||
:param file: the file where to write copy data. It must be open for writing
|
||||
in binary mode.
|
||||
"""
|
||||
|
||||
def __init__(self, file: IO[bytes]):
|
||||
self.file = file
|
||||
|
||||
def write(self, data: Buffer) -> None:
|
||||
self.file.write(data)
|
||||
|
||||
|
||||
class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
|
||||
"""Manage an asynchronous :sql:`COPY` operation."""
|
||||
|
||||
__module__ = "psycopg"
|
||||
|
||||
writer: "AsyncWriter"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cursor: "AsyncCursor[Any]",
|
||||
*,
|
||||
binary: Optional[bool] = None,
|
||||
writer: Optional["AsyncWriter"] = None,
|
||||
):
|
||||
super().__init__(cursor, binary=binary)
|
||||
|
||||
if not writer:
|
||||
writer = AsyncLibpqWriter(cursor)
|
||||
|
||||
self.writer = writer
|
||||
self._write = writer.write
|
||||
|
||||
async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
|
||||
self._enter()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
await self.finish(exc_val)
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[Buffer]:
|
||||
while True:
|
||||
data = await self.read()
|
||||
if not data:
|
||||
break
|
||||
yield data
|
||||
|
||||
async def read(self) -> Buffer:
|
||||
return await self.connection.wait(self._read_gen())
|
||||
|
||||
async def rows(self) -> AsyncIterator[Tuple[Any, ...]]:
|
||||
while True:
|
||||
record = await self.read_row()
|
||||
if record is None:
|
||||
break
|
||||
yield record
|
||||
|
||||
async def read_row(self) -> Optional[Tuple[Any, ...]]:
|
||||
return await self.connection.wait(self._read_row_gen())
|
||||
|
||||
async def write(self, buffer: Union[Buffer, str]) -> None:
|
||||
data = self.formatter.write(buffer)
|
||||
if data:
|
||||
await self._write(data)
|
||||
|
||||
async def write_row(self, row: Sequence[Any]) -> None:
|
||||
data = self.formatter.write_row(row)
|
||||
if data:
|
||||
await self._write(data)
|
||||
|
||||
async def finish(self, exc: Optional[BaseException]) -> None:
|
||||
if self._direction == COPY_IN:
|
||||
data = self.formatter.end()
|
||||
if data:
|
||||
await self._write(data)
|
||||
await self.writer.finish(exc)
|
||||
self._finished = True
|
||||
else:
|
||||
await self.connection.wait(self._end_copy_out_gen(exc))
|
||||
|
||||
|
||||
class AsyncWriter(ABC):
|
||||
"""
|
||||
A class to write copy data somewhere (for async connections).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def write(self, data: Buffer) -> None:
|
||||
...
|
||||
|
||||
async def finish(self, exc: Optional[BaseException] = None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncLibpqWriter(AsyncWriter):
|
||||
"""
|
||||
An `AsyncWriter` to write copy data to a Postgres database.
|
||||
"""
|
||||
|
||||
def __init__(self, cursor: "AsyncCursor[Any]"):
|
||||
self.cursor = cursor
|
||||
self.connection = cursor.connection
|
||||
self._pgconn = self.connection.pgconn
|
||||
|
||||
async def write(self, data: Buffer) -> None:
|
||||
if len(data) <= MAX_BUFFER_SIZE:
|
||||
# Most used path: we don't need to split the buffer in smaller
|
||||
# bits, so don't make a copy.
|
||||
await self.connection.wait(copy_to(self._pgconn, data))
|
||||
else:
|
||||
# Copy a buffer too large in chunks to avoid causing a memory
|
||||
# error in the libpq, which may cause an infinite loop (#255).
|
||||
for i in range(0, len(data), MAX_BUFFER_SIZE):
|
||||
await self.connection.wait(
|
||||
copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
|
||||
)
|
||||
|
||||
async def finish(self, exc: Optional[BaseException] = None) -> None:
|
||||
bmsg: Optional[bytes]
|
||||
if exc:
|
||||
msg = f"error from Python: {type(exc).__qualname__} - {exc}"
|
||||
bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
|
||||
else:
|
||||
bmsg = None
|
||||
|
||||
try:
|
||||
res = await self.connection.wait(copy_end(self._pgconn, bmsg))
|
||||
# The QueryCanceled is expected if we sent an exception message to
|
||||
# pgconn.put_copy_end(). The Python exception that generated that
|
||||
# cancelling is more important, so don't clobber it.
|
||||
except e.QueryCanceled:
|
||||
if not bmsg:
|
||||
raise
|
||||
else:
|
||||
self.cursor._results = [res]
|
||||
|
||||
|
||||
class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
|
||||
"""
|
||||
An `AsyncWriter` using a buffer to queue data to write.
|
||||
|
||||
`write()` returns immediately, so that the main thread can be CPU-bound
|
||||
formatting messages, while a worker thread can be IO-bound waiting to write
|
||||
on the connection.
|
||||
"""
|
||||
|
||||
def __init__(self, cursor: "AsyncCursor[Any]"):
|
||||
super().__init__(cursor)
|
||||
|
||||
self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE)
|
||||
self._worker: Optional[asyncio.Future[None]] = None
|
||||
|
||||
async def worker(self) -> None:
|
||||
"""Push data to the server when available from the copy queue.
|
||||
|
||||
Terminate reading when the queue receives a false-y value.
|
||||
|
||||
The function is designed to be run in a separate task.
|
||||
"""
|
||||
while True:
|
||||
data = await self._queue.get()
|
||||
if not data:
|
||||
break
|
||||
await self.connection.wait(copy_to(self._pgconn, data))
|
||||
|
||||
async def write(self, data: Buffer) -> None:
|
||||
if not self._worker:
|
||||
self._worker = create_task(self.worker())
|
||||
|
||||
if len(data) <= MAX_BUFFER_SIZE:
|
||||
# Most used path: we don't need to split the buffer in smaller
|
||||
# bits, so don't make a copy.
|
||||
await self._queue.put(data)
|
||||
else:
|
||||
# Copy a buffer too large in chunks to avoid causing a memory
|
||||
# error in the libpq, which may cause an infinite loop (#255).
|
||||
for i in range(0, len(data), MAX_BUFFER_SIZE):
|
||||
await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
|
||||
|
||||
async def finish(self, exc: Optional[BaseException] = None) -> None:
|
||||
await self._queue.put(b"")
|
||||
|
||||
if self._worker:
|
||||
await asyncio.gather(self._worker)
|
||||
self._worker = None # break reference loops if any
|
||||
|
||||
await super().finish(exc)
|
||||
|
||||
|
||||
class Formatter(ABC):
|
||||
"""
|
||||
A class which understand a copy format (text, binary).
|
||||
"""
|
||||
|
||||
format: pq.Format
|
||||
|
||||
def __init__(self, transformer: Transformer):
|
||||
self.transformer = transformer
|
||||
self._write_buffer = bytearray()
|
||||
self._row_mode = False # true if the user is using write_row()
|
||||
|
||||
@abstractmethod
|
||||
def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def write(self, buffer: Union[Buffer, str]) -> Buffer:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def write_row(self, row: Sequence[Any]) -> Buffer:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def end(self) -> Buffer:
|
||||
...
|
||||
|
||||
|
||||
class TextFormatter(Formatter):
|
||||
format = TEXT
|
||||
|
||||
def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
|
||||
super().__init__(transformer)
|
||||
self._encoding = encoding
|
||||
|
||||
def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
|
||||
if data:
|
||||
return parse_row_text(data, self.transformer)
|
||||
else:
|
||||
return None
|
||||
|
||||
def write(self, buffer: Union[Buffer, str]) -> Buffer:
|
||||
data = self._ensure_bytes(buffer)
|
||||
self._signature_sent = True
|
||||
return data
|
||||
|
||||
def write_row(self, row: Sequence[Any]) -> Buffer:
|
||||
# Note down that we are writing in row mode: it means we will have
|
||||
# to take care of the end-of-copy marker too
|
||||
self._row_mode = True
|
||||
|
||||
format_row_text(row, self.transformer, self._write_buffer)
|
||||
if len(self._write_buffer) > BUFFER_SIZE:
|
||||
buffer, self._write_buffer = self._write_buffer, bytearray()
|
||||
return buffer
|
||||
else:
|
||||
return b""
|
||||
|
||||
def end(self) -> Buffer:
|
||||
buffer, self._write_buffer = self._write_buffer, bytearray()
|
||||
return buffer
|
||||
|
||||
def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
|
||||
if isinstance(data, str):
|
||||
return data.encode(self._encoding)
|
||||
else:
|
||||
# Assume, for simplicity, that the user is not passing stupid
|
||||
# things to the write function. If that's the case, things
|
||||
# will fail downstream.
|
||||
return data
|
||||
|
||||
|
||||
class BinaryFormatter(Formatter):
|
||||
format = BINARY
|
||||
|
||||
def __init__(self, transformer: Transformer):
|
||||
super().__init__(transformer)
|
||||
self._signature_sent = False
|
||||
|
||||
def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
|
||||
if not self._signature_sent:
|
||||
if data[: len(_binary_signature)] != _binary_signature:
|
||||
raise e.DataError(
|
||||
"binary copy doesn't start with the expected signature"
|
||||
)
|
||||
self._signature_sent = True
|
||||
data = data[len(_binary_signature) :]
|
||||
|
||||
elif data == _binary_trailer:
|
||||
return None
|
||||
|
||||
return parse_row_binary(data, self.transformer)
|
||||
|
||||
def write(self, buffer: Union[Buffer, str]) -> Buffer:
|
||||
data = self._ensure_bytes(buffer)
|
||||
self._signature_sent = True
|
||||
return data
|
||||
|
||||
def write_row(self, row: Sequence[Any]) -> Buffer:
|
||||
# Note down that we are writing in row mode: it means we will have
|
||||
# to take care of the end-of-copy marker too
|
||||
self._row_mode = True
|
||||
|
||||
if not self._signature_sent:
|
||||
self._write_buffer += _binary_signature
|
||||
self._signature_sent = True
|
||||
|
||||
format_row_binary(row, self.transformer, self._write_buffer)
|
||||
if len(self._write_buffer) > BUFFER_SIZE:
|
||||
buffer, self._write_buffer = self._write_buffer, bytearray()
|
||||
return buffer
|
||||
else:
|
||||
return b""
|
||||
|
||||
def end(self) -> Buffer:
|
||||
# If we have sent no data we need to send the signature
|
||||
# and the trailer
|
||||
if not self._signature_sent:
|
||||
self._write_buffer += _binary_signature
|
||||
self._write_buffer += _binary_trailer
|
||||
|
||||
elif self._row_mode:
|
||||
# if we have sent data already, we have sent the signature
|
||||
# too (either with the first row, or we assume that in
|
||||
# block mode the signature is included).
|
||||
# Write the trailer only if we are sending rows (with the
|
||||
# assumption that who is copying binary data is sending the
|
||||
# whole format).
|
||||
self._write_buffer += _binary_trailer
|
||||
|
||||
buffer, self._write_buffer = self._write_buffer, bytearray()
|
||||
return buffer
|
||||
|
||||
def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
|
||||
if isinstance(data, str):
|
||||
raise TypeError("cannot copy str data in binary mode: use bytes instead")
|
||||
else:
|
||||
# Assume, for simplicity, that the user is not passing stupid
|
||||
# things to the write function. If that's the case, things
|
||||
# will fail downstream.
|
||||
return data
|
||||
|
||||
|
||||
def _format_row_text(
|
||||
row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
|
||||
) -> bytearray:
|
||||
"""Convert a row of objects to the data to send for copy."""
|
||||
if out is None:
|
||||
out = bytearray()
|
||||
|
||||
if not row:
|
||||
out += b"\n"
|
||||
return out
|
||||
|
||||
for item in row:
|
||||
if item is not None:
|
||||
dumper = tx.get_dumper(item, PY_TEXT)
|
||||
b = dumper.dump(item)
|
||||
out += _dump_re.sub(_dump_sub, b)
|
||||
else:
|
||||
out += rb"\N"
|
||||
out += b"\t"
|
||||
|
||||
out[-1:] = b"\n"
|
||||
return out
|
||||
|
||||
|
||||
def _format_row_binary(
|
||||
row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
|
||||
) -> bytearray:
|
||||
"""Convert a row of objects to the data to send for binary copy."""
|
||||
if out is None:
|
||||
out = bytearray()
|
||||
|
||||
out += _pack_int2(len(row))
|
||||
adapted = tx.dump_sequence(row, [PY_BINARY] * len(row))
|
||||
for b in adapted:
|
||||
if b is not None:
|
||||
out += _pack_int4(len(b))
|
||||
out += b
|
||||
else:
|
||||
out += _binary_null
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
|
||||
if not isinstance(data, bytes):
|
||||
data = bytes(data)
|
||||
fields = data.split(b"\t")
|
||||
fields[-1] = fields[-1][:-1] # drop \n
|
||||
row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
|
||||
return tx.load_sequence(row)
|
||||
|
||||
|
||||
def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
|
||||
row: List[Optional[Buffer]] = []
|
||||
nfields = _unpack_int2(data, 0)[0]
|
||||
pos = 2
|
||||
for i in range(nfields):
|
||||
length = _unpack_int4(data, pos)[0]
|
||||
pos += 4
|
||||
if length >= 0:
|
||||
row.append(data[pos : pos + length])
|
||||
pos += length
|
||||
else:
|
||||
row.append(None)
|
||||
|
||||
return tx.load_sequence(row)
|
||||
|
||||
|
||||
_pack_int2 = struct.Struct("!h").pack
|
||||
_pack_int4 = struct.Struct("!i").pack
|
||||
_unpack_int2 = struct.Struct("!h").unpack_from
|
||||
_unpack_int4 = struct.Struct("!i").unpack_from
|
||||
|
||||
_binary_signature = (
|
||||
b"PGCOPY\n\xff\r\n\0" # Signature
|
||||
b"\x00\x00\x00\x00" # flags
|
||||
b"\x00\x00\x00\x00" # extra length
|
||||
)
|
||||
_binary_trailer = b"\xff\xff"
|
||||
_binary_null = b"\xff\xff\xff\xff"
|
||||
|
||||
_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
|
||||
_dump_repl = {
|
||||
b"\b": b"\\b",
|
||||
b"\t": b"\\t",
|
||||
b"\n": b"\\n",
|
||||
b"\v": b"\\v",
|
||||
b"\f": b"\\f",
|
||||
b"\r": b"\\r",
|
||||
b"\\": b"\\\\",
|
||||
}
|
||||
|
||||
|
||||
def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
|
||||
return __map[m.group(0)]
|
||||
|
||||
|
||||
_load_re = re.compile(b"\\\\[btnvfr\\\\]")
|
||||
_load_repl = {v: k for k, v in _dump_repl.items()}
|
||||
|
||||
|
||||
def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
|
||||
return __map[m.group(0)]
|
||||
|
||||
|
||||
# Override functions with fast versions if available
|
||||
if _psycopg:
|
||||
format_row_text = _psycopg.format_row_text
|
||||
format_row_binary = _psycopg.format_row_binary
|
||||
parse_row_text = _psycopg.parse_row_text
|
||||
parse_row_binary = _psycopg.parse_row_binary
|
||||
|
||||
else:
|
||||
format_row_text = _format_row_text
|
||||
format_row_binary = _format_row_binary
|
||||
parse_row_text = _parse_row_text
|
||||
parse_row_binary = _parse_row_binary
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
CockroachDB support package.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2022 The Psycopg Team
|
||||
|
||||
from . import _types
|
||||
from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo
|
||||
|
||||
adapters = _types.adapters # exposed by the package
|
||||
connect = CrdbConnection.connect
|
||||
|
||||
_types.register_crdb_adapters(adapters)
|
||||
|
||||
__all__ = [
|
||||
"AsyncCrdbConnection",
|
||||
"CrdbConnection",
|
||||
"CrdbConnectionInfo",
|
||||
]
|
||||
163
srcs/.venv/lib/python3.11/site-packages/psycopg/crdb/_types.py
Normal file
163
srcs/.venv/lib/python3.11/site-packages/psycopg/crdb/_types.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Types configuration specific for CockroachDB.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2022 The Psycopg Team
|
||||
|
||||
from enum import Enum
|
||||
from .._typeinfo import TypeInfo, TypesRegistry
|
||||
|
||||
from ..abc import AdaptContext, NoneType
|
||||
from ..postgres import TEXT_OID
|
||||
from .._adapters_map import AdaptersMap
|
||||
from ..types.enum import EnumDumper, EnumBinaryDumper
|
||||
from ..types.none import NoneDumper
|
||||
|
||||
types = TypesRegistry()
|
||||
|
||||
# Global adapter maps with PostgreSQL types configuration
|
||||
adapters = AdaptersMap(types=types)
|
||||
|
||||
|
||||
class CrdbEnumDumper(EnumDumper):
|
||||
oid = TEXT_OID
|
||||
|
||||
|
||||
class CrdbEnumBinaryDumper(EnumBinaryDumper):
|
||||
oid = TEXT_OID
|
||||
|
||||
|
||||
class CrdbNoneDumper(NoneDumper):
|
||||
oid = TEXT_OID
|
||||
|
||||
|
||||
def register_postgres_adapters(context: AdaptContext) -> None:
|
||||
# Same adapters used by PostgreSQL, or a good starting point for customization
|
||||
|
||||
from ..types import array, bool, composite, datetime
|
||||
from ..types import numeric, string, uuid
|
||||
|
||||
array.register_default_adapters(context)
|
||||
bool.register_default_adapters(context)
|
||||
composite.register_default_adapters(context)
|
||||
datetime.register_default_adapters(context)
|
||||
numeric.register_default_adapters(context)
|
||||
string.register_default_adapters(context)
|
||||
uuid.register_default_adapters(context)
|
||||
|
||||
|
||||
def register_crdb_adapters(context: AdaptContext) -> None:
|
||||
from .. import dbapi20
|
||||
from ..types import array
|
||||
|
||||
register_postgres_adapters(context)
|
||||
|
||||
# String must come after enum to map text oid -> string dumper
|
||||
register_crdb_enum_adapters(context)
|
||||
register_crdb_string_adapters(context)
|
||||
register_crdb_json_adapters(context)
|
||||
register_crdb_net_adapters(context)
|
||||
register_crdb_none_adapters(context)
|
||||
|
||||
dbapi20.register_dbapi20_adapters(adapters)
|
||||
|
||||
array.register_all_arrays(adapters)
|
||||
|
||||
|
||||
def register_crdb_string_adapters(context: AdaptContext) -> None:
|
||||
from ..types import string
|
||||
|
||||
# Dump strings with text oid instead of unknown.
|
||||
# Unlike PostgreSQL, CRDB seems able to cast text to most types.
|
||||
context.adapters.register_dumper(str, string.StrDumper)
|
||||
context.adapters.register_dumper(str, string.StrBinaryDumper)
|
||||
|
||||
|
||||
def register_crdb_enum_adapters(context: AdaptContext) -> None:
|
||||
context.adapters.register_dumper(Enum, CrdbEnumBinaryDumper)
|
||||
context.adapters.register_dumper(Enum, CrdbEnumDumper)
|
||||
|
||||
|
||||
def register_crdb_json_adapters(context: AdaptContext) -> None:
|
||||
from ..types import json
|
||||
|
||||
adapters = context.adapters
|
||||
|
||||
# CRDB doesn't have json/jsonb: both names map to the jsonb oid
|
||||
adapters.register_dumper(json.Json, json.JsonbBinaryDumper)
|
||||
adapters.register_dumper(json.Json, json.JsonbDumper)
|
||||
|
||||
adapters.register_dumper(json.Jsonb, json.JsonbBinaryDumper)
|
||||
adapters.register_dumper(json.Jsonb, json.JsonbDumper)
|
||||
|
||||
adapters.register_loader("json", json.JsonLoader)
|
||||
adapters.register_loader("jsonb", json.JsonbLoader)
|
||||
adapters.register_loader("json", json.JsonBinaryLoader)
|
||||
adapters.register_loader("jsonb", json.JsonbBinaryLoader)
|
||||
|
||||
|
||||
def register_crdb_net_adapters(context: AdaptContext) -> None:
|
||||
from ..types import net
|
||||
|
||||
adapters = context.adapters
|
||||
|
||||
adapters.register_dumper("ipaddress.IPv4Address", net.InterfaceDumper)
|
||||
adapters.register_dumper("ipaddress.IPv6Address", net.InterfaceDumper)
|
||||
adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceDumper)
|
||||
adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceDumper)
|
||||
adapters.register_dumper("ipaddress.IPv4Address", net.AddressBinaryDumper)
|
||||
adapters.register_dumper("ipaddress.IPv6Address", net.AddressBinaryDumper)
|
||||
adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceBinaryDumper)
|
||||
adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceBinaryDumper)
|
||||
adapters.register_dumper(None, net.InetBinaryDumper)
|
||||
adapters.register_loader("inet", net.InetLoader)
|
||||
adapters.register_loader("inet", net.InetBinaryLoader)
|
||||
|
||||
|
||||
def register_crdb_none_adapters(context: AdaptContext) -> None:
|
||||
context.adapters.register_dumper(NoneType, CrdbNoneDumper)
|
||||
|
||||
|
||||
for t in [
|
||||
TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb.
|
||||
TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8
|
||||
TypeInfo('"char"', 18, 1002), # special case, not generated
|
||||
# autogenerated: start
|
||||
# Generated from CockroachDB 22.1.0
|
||||
TypeInfo("bit", 1560, 1561),
|
||||
TypeInfo("bool", 16, 1000, regtype="boolean"),
|
||||
TypeInfo("bpchar", 1042, 1014, regtype="character"),
|
||||
TypeInfo("bytea", 17, 1001),
|
||||
TypeInfo("date", 1082, 1182),
|
||||
TypeInfo("float4", 700, 1021, regtype="real"),
|
||||
TypeInfo("float8", 701, 1022, regtype="double precision"),
|
||||
TypeInfo("inet", 869, 1041),
|
||||
TypeInfo("int2", 21, 1005, regtype="smallint"),
|
||||
TypeInfo("int2vector", 22, 1006),
|
||||
TypeInfo("int4", 23, 1007),
|
||||
TypeInfo("int8", 20, 1016, regtype="bigint"),
|
||||
TypeInfo("interval", 1186, 1187),
|
||||
TypeInfo("jsonb", 3802, 3807),
|
||||
TypeInfo("name", 19, 1003),
|
||||
TypeInfo("numeric", 1700, 1231),
|
||||
TypeInfo("oid", 26, 1028),
|
||||
TypeInfo("oidvector", 30, 1013),
|
||||
TypeInfo("record", 2249, 2287),
|
||||
TypeInfo("regclass", 2205, 2210),
|
||||
TypeInfo("regnamespace", 4089, 4090),
|
||||
TypeInfo("regproc", 24, 1008),
|
||||
TypeInfo("regprocedure", 2202, 2207),
|
||||
TypeInfo("regrole", 4096, 4097),
|
||||
TypeInfo("regtype", 2206, 2211),
|
||||
TypeInfo("text", 25, 1009),
|
||||
TypeInfo("time", 1083, 1183, regtype="time without time zone"),
|
||||
TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
|
||||
TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
|
||||
TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
|
||||
TypeInfo("unknown", 705, 0),
|
||||
TypeInfo("uuid", 2950, 2951),
|
||||
TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
|
||||
TypeInfo("varchar", 1043, 1015, regtype="character varying"),
|
||||
# autogenerated: end
|
||||
]:
|
||||
types.add(t)
|
||||
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
CockroachDB-specific connections.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2022 The Psycopg Team
|
||||
|
||||
import re
|
||||
from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
|
||||
|
||||
from .. import errors as e
|
||||
from ..abc import AdaptContext
|
||||
from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow
|
||||
from ..conninfo import ConnectionInfo
|
||||
from ..connection import Connection
|
||||
from .._adapters_map import AdaptersMap
|
||||
from ..connection_async import AsyncConnection
|
||||
from ._types import adapters
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pq.abc import PGconn
|
||||
from ..cursor import Cursor
|
||||
from ..cursor_async import AsyncCursor
|
||||
|
||||
|
||||
class _CrdbConnectionMixin:
|
||||
_adapters: Optional[AdaptersMap]
|
||||
pgconn: "PGconn"
|
||||
|
||||
@classmethod
|
||||
def is_crdb(
|
||||
cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"]
|
||||
) -> bool:
|
||||
"""
|
||||
Return `!True` if the server connected to `!conn` is CockroachDB.
|
||||
"""
|
||||
if isinstance(conn, (Connection, AsyncConnection)):
|
||||
conn = conn.pgconn
|
||||
|
||||
return bool(conn.parameter_status(b"crdb_version"))
|
||||
|
||||
@property
|
||||
def adapters(self) -> AdaptersMap:
|
||||
if not self._adapters:
|
||||
# By default, use CockroachDB adapters map
|
||||
self._adapters = AdaptersMap(adapters)
|
||||
|
||||
return self._adapters
|
||||
|
||||
@property
|
||||
def info(self) -> "CrdbConnectionInfo":
|
||||
return CrdbConnectionInfo(self.pgconn)
|
||||
|
||||
def _check_tpc(self) -> None:
|
||||
if self.is_crdb(self.pgconn):
|
||||
raise e.NotSupportedError("CockroachDB doesn't support prepared statements")
|
||||
|
||||
|
||||
class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
|
||||
"""
|
||||
Wrapper for a connection to a CockroachDB database.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.crdb"
|
||||
|
||||
# TODO: this method shouldn't require re-definition if the base class
|
||||
# implements a generic self.
|
||||
# https://github.com/psycopg/psycopg/issues/308
|
||||
@overload
|
||||
@classmethod
|
||||
def connect(
|
||||
cls,
|
||||
conninfo: str = "",
|
||||
*,
|
||||
autocommit: bool = False,
|
||||
row_factory: RowFactory[Row],
|
||||
prepare_threshold: Optional[int] = 5,
|
||||
cursor_factory: "Optional[Type[Cursor[Row]]]" = None,
|
||||
context: Optional[AdaptContext] = None,
|
||||
**kwargs: Union[None, int, str],
|
||||
) -> "CrdbConnection[Row]":
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def connect(
|
||||
cls,
|
||||
conninfo: str = "",
|
||||
*,
|
||||
autocommit: bool = False,
|
||||
prepare_threshold: Optional[int] = 5,
|
||||
cursor_factory: "Optional[Type[Cursor[Any]]]" = None,
|
||||
context: Optional[AdaptContext] = None,
|
||||
**kwargs: Union[None, int, str],
|
||||
) -> "CrdbConnection[TupleRow]":
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]":
|
||||
"""
|
||||
Connect to a database server and return a new `CrdbConnection` instance.
|
||||
"""
|
||||
return super().connect(conninfo, **kwargs) # type: ignore[return-value]
|
||||
|
||||
|
||||
class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
|
||||
"""
|
||||
Wrapper for an async connection to a CockroachDB database.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.crdb"
|
||||
|
||||
# TODO: this method shouldn't require re-definition if the base class
|
||||
# implements a generic self.
|
||||
# https://github.com/psycopg/psycopg/issues/308
|
||||
@overload
|
||||
@classmethod
|
||||
async def connect(
|
||||
cls,
|
||||
conninfo: str = "",
|
||||
*,
|
||||
autocommit: bool = False,
|
||||
prepare_threshold: Optional[int] = 5,
|
||||
row_factory: AsyncRowFactory[Row],
|
||||
cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None,
|
||||
context: Optional[AdaptContext] = None,
|
||||
**kwargs: Union[None, int, str],
|
||||
) -> "AsyncCrdbConnection[Row]":
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def connect(
|
||||
cls,
|
||||
conninfo: str = "",
|
||||
*,
|
||||
autocommit: bool = False,
|
||||
prepare_threshold: Optional[int] = 5,
|
||||
cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None,
|
||||
context: Optional[AdaptContext] = None,
|
||||
**kwargs: Union[None, int, str],
|
||||
) -> "AsyncCrdbConnection[TupleRow]":
|
||||
...
|
||||
|
||||
@classmethod
|
||||
async def connect(
|
||||
cls, conninfo: str = "", **kwargs: Any
|
||||
) -> "AsyncCrdbConnection[Any]":
|
||||
return await super().connect(conninfo, **kwargs) # type: ignore [no-any-return]
|
||||
|
||||
|
||||
class CrdbConnectionInfo(ConnectionInfo):
|
||||
"""
|
||||
`~psycopg.ConnectionInfo` subclass to get info about a CockroachDB database.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.crdb"
|
||||
|
||||
@property
|
||||
def vendor(self) -> str:
|
||||
return "CockroachDB"
|
||||
|
||||
@property
|
||||
def server_version(self) -> int:
|
||||
"""
|
||||
Return the CockroachDB server version connected.
|
||||
|
||||
Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210).
|
||||
"""
|
||||
sver = self.parameter_status("crdb_version")
|
||||
if not sver:
|
||||
raise e.InternalError("'crdb_version' parameter status not set")
|
||||
|
||||
ver = self.parse_crdb_version(sver)
|
||||
if ver is None:
|
||||
raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")
|
||||
|
||||
return ver
|
||||
|
||||
@classmethod
|
||||
def parse_crdb_version(self, sver: str) -> Optional[int]:
|
||||
m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
|
||||
if not m:
|
||||
return None
|
||||
|
||||
return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
|
||||
929
srcs/.venv/lib/python3.11/site-packages/psycopg/cursor.py
Normal file
929
srcs/.venv/lib/python3.11/site-packages/psycopg/cursor.py
Normal file
@@ -0,0 +1,929 @@
|
||||
"""
|
||||
psycopg cursor objects
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from functools import partial
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, Iterable, Iterator, List
|
||||
from typing import Optional, NoReturn, Sequence, Tuple, Type, TypeVar
|
||||
from typing import overload, TYPE_CHECKING
|
||||
from warnings import warn
|
||||
from contextlib import contextmanager
|
||||
|
||||
from . import pq
|
||||
from . import adapt
|
||||
from . import errors as e
|
||||
from .abc import ConnectionType, Query, Params, PQGen
|
||||
from .copy import Copy, Writer as CopyWriter
|
||||
from .rows import Row, RowMaker, RowFactory
|
||||
from ._column import Column
|
||||
from .pq.misc import connection_summary
|
||||
from ._queries import PostgresQuery, PostgresClientQuery
|
||||
from ._pipeline import Pipeline
|
||||
from ._encodings import pgconn_encoding
|
||||
from ._preparing import Prepare
|
||||
from .generators import execute, fetch, send
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abc import Transformer
|
||||
from .pq.abc import PGconn, PGresult
|
||||
from .connection import Connection
|
||||
|
||||
TEXT = pq.Format.TEXT
|
||||
BINARY = pq.Format.BINARY
|
||||
|
||||
EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY
|
||||
COMMAND_OK = pq.ExecStatus.COMMAND_OK
|
||||
TUPLES_OK = pq.ExecStatus.TUPLES_OK
|
||||
COPY_OUT = pq.ExecStatus.COPY_OUT
|
||||
COPY_IN = pq.ExecStatus.COPY_IN
|
||||
COPY_BOTH = pq.ExecStatus.COPY_BOTH
|
||||
FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
|
||||
SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
|
||||
PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
|
||||
|
||||
ACTIVE = pq.TransactionStatus.ACTIVE
|
||||
|
||||
|
||||
class BaseCursor(Generic[ConnectionType, Row]):
|
||||
__slots__ = """
|
||||
_conn format _adapters arraysize _closed _results pgresult _pos
|
||||
_iresult _rowcount _query _tx _last_query _row_factory _make_row
|
||||
_pgconn _execmany_returning
|
||||
__weakref__
|
||||
""".split()
|
||||
|
||||
ExecStatus = pq.ExecStatus
|
||||
|
||||
_tx: "Transformer"
|
||||
_make_row: RowMaker[Row]
|
||||
_pgconn: "PGconn"
|
||||
|
||||
def __init__(self, connection: ConnectionType):
|
||||
self._conn = connection
|
||||
self.format = TEXT
|
||||
self._pgconn = connection.pgconn
|
||||
self._adapters = adapt.AdaptersMap(connection.adapters)
|
||||
self.arraysize = 1
|
||||
self._closed = False
|
||||
self._last_query: Optional[Query] = None
|
||||
self._reset()
|
||||
|
||||
def _reset(self, reset_query: bool = True) -> None:
|
||||
self._results: List["PGresult"] = []
|
||||
self.pgresult: Optional["PGresult"] = None
|
||||
self._pos = 0
|
||||
self._iresult = 0
|
||||
self._rowcount = -1
|
||||
self._query: Optional[PostgresQuery]
|
||||
# None if executemany() not executing, True/False according to returning state
|
||||
self._execmany_returning: Optional[bool] = None
|
||||
if reset_query:
|
||||
self._query = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
|
||||
info = connection_summary(self._pgconn)
|
||||
if self._closed:
|
||||
status = "closed"
|
||||
elif self.pgresult:
|
||||
status = pq.ExecStatus(self.pgresult.status).name
|
||||
else:
|
||||
status = "no result"
|
||||
return f"<{cls} [{status}] {info} at 0x{id(self):x}>"
|
||||
|
||||
@property
|
||||
def connection(self) -> ConnectionType:
|
||||
"""The connection this cursor is using."""
|
||||
return self._conn
|
||||
|
||||
@property
|
||||
def adapters(self) -> adapt.AdaptersMap:
|
||||
return self._adapters
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
"""`True` if the cursor is closed."""
|
||||
return self._closed
|
||||
|
||||
@property
|
||||
def description(self) -> Optional[List[Column]]:
|
||||
"""
|
||||
A list of `Column` objects describing the current resultset.
|
||||
|
||||
`!None` if the current resultset didn't return tuples.
|
||||
"""
|
||||
res = self.pgresult
|
||||
|
||||
# We return columns if we have nfields, but also if we don't but
|
||||
# the query said we got tuples (mostly to handle the super useful
|
||||
# query "SELECT ;"
|
||||
if res and (
|
||||
res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE
|
||||
):
|
||||
return [Column(self, i) for i in range(res.nfields)]
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def rowcount(self) -> int:
|
||||
"""Number of records affected by the precedent operation."""
|
||||
return self._rowcount
|
||||
|
||||
@property
|
||||
def rownumber(self) -> Optional[int]:
|
||||
"""Index of the next row to fetch in the current result.
|
||||
|
||||
`!None` if there is no result to fetch.
|
||||
"""
|
||||
tuples = self.pgresult and self.pgresult.status == TUPLES_OK
|
||||
return self._pos if tuples else None
|
||||
|
||||
def setinputsizes(self, sizes: Sequence[Any]) -> None:
|
||||
# no-op
|
||||
pass
|
||||
|
||||
def setoutputsize(self, size: Any, column: Optional[int] = None) -> None:
|
||||
# no-op
|
||||
pass
|
||||
|
||||
def nextset(self) -> Optional[bool]:
|
||||
"""
|
||||
Move to the result set of the next query executed through `executemany()`
|
||||
or to the next result set if `execute()` returned more than one.
|
||||
|
||||
Return `!True` if a new result is available, which will be the one
|
||||
methods `!fetch*()` will operate on.
|
||||
"""
|
||||
# Raise a warning if people is calling nextset() in pipeline mode
|
||||
# after a sequence of execute() in pipeline mode. Pipeline accumulating
|
||||
# execute() results in the cursor is an unintended difference w.r.t.
|
||||
# non-pipeline mode.
|
||||
if self._execmany_returning is None and self._conn._pipeline:
|
||||
warn(
|
||||
"using nextset() in pipeline mode for several execute() is"
|
||||
" deprecated and will be dropped in 3.2; please use different"
|
||||
" cursors to receive more than one result",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
if self._iresult < len(self._results) - 1:
|
||||
self._select_current_result(self._iresult + 1)
|
||||
return True
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def statusmessage(self) -> Optional[str]:
|
||||
"""
|
||||
The command status tag from the last SQL command executed.
|
||||
|
||||
`!None` if the cursor doesn't have a result available.
|
||||
"""
|
||||
msg = self.pgresult.command_status if self.pgresult else None
|
||||
return msg.decode() if msg else None
|
||||
|
||||
def _make_row_maker(self) -> RowMaker[Row]:
|
||||
raise NotImplementedError
|
||||
|
||||
#
|
||||
# Generators for the high level operations on the cursor
|
||||
#
|
||||
# Like for sync/async connections, these are implemented as generators
|
||||
# so that different concurrency strategies (threads,asyncio) can use their
|
||||
# own way of waiting (or better, `connection.wait()`).
|
||||
#
|
||||
|
||||
def _execute_gen(
|
||||
self,
|
||||
query: Query,
|
||||
params: Optional[Params] = None,
|
||||
*,
|
||||
prepare: Optional[bool] = None,
|
||||
binary: Optional[bool] = None,
|
||||
) -> PQGen[None]:
|
||||
"""Generator implementing `Cursor.execute()`."""
|
||||
yield from self._start_query(query)
|
||||
pgq = self._convert_query(query, params)
|
||||
results = yield from self._maybe_prepare_gen(
|
||||
pgq, prepare=prepare, binary=binary
|
||||
)
|
||||
if self._conn._pipeline:
|
||||
yield from self._conn._pipeline._communicate_gen()
|
||||
else:
|
||||
assert results is not None
|
||||
self._check_results(results)
|
||||
self._results = results
|
||||
self._select_current_result(0)
|
||||
|
||||
self._last_query = query
|
||||
|
||||
for cmd in self._conn._prepared.get_maintenance_commands():
|
||||
yield from self._conn._exec_command(cmd)
|
||||
|
||||
def _executemany_gen_pipeline(
|
||||
self, query: Query, params_seq: Iterable[Params], returning: bool
|
||||
) -> PQGen[None]:
|
||||
"""
|
||||
Generator implementing `Cursor.executemany()` with pipelines available.
|
||||
"""
|
||||
pipeline = self._conn._pipeline
|
||||
assert pipeline
|
||||
|
||||
yield from self._start_query(query)
|
||||
if not returning:
|
||||
self._rowcount = 0
|
||||
|
||||
assert self._execmany_returning is None
|
||||
self._execmany_returning = returning
|
||||
|
||||
first = True
|
||||
for params in params_seq:
|
||||
if first:
|
||||
pgq = self._convert_query(query, params)
|
||||
self._query = pgq
|
||||
first = False
|
||||
else:
|
||||
pgq.dump(params)
|
||||
|
||||
yield from self._maybe_prepare_gen(pgq, prepare=True)
|
||||
yield from pipeline._communicate_gen()
|
||||
|
||||
self._last_query = query
|
||||
|
||||
if returning:
|
||||
yield from pipeline._fetch_gen(flush=True)
|
||||
|
||||
for cmd in self._conn._prepared.get_maintenance_commands():
|
||||
yield from self._conn._exec_command(cmd)
|
||||
|
||||
def _executemany_gen_no_pipeline(
|
||||
self, query: Query, params_seq: Iterable[Params], returning: bool
|
||||
) -> PQGen[None]:
|
||||
"""
|
||||
Generator implementing `Cursor.executemany()` with pipelines not available.
|
||||
"""
|
||||
yield from self._start_query(query)
|
||||
if not returning:
|
||||
self._rowcount = 0
|
||||
first = True
|
||||
for params in params_seq:
|
||||
if first:
|
||||
pgq = self._convert_query(query, params)
|
||||
self._query = pgq
|
||||
first = False
|
||||
else:
|
||||
pgq.dump(params)
|
||||
|
||||
results = yield from self._maybe_prepare_gen(pgq, prepare=True)
|
||||
assert results is not None
|
||||
self._check_results(results)
|
||||
if returning:
|
||||
self._results.extend(results)
|
||||
else:
|
||||
# In non-returning case, set rowcount to the cumulated number
|
||||
# of rows of executed queries.
|
||||
for res in results:
|
||||
self._rowcount += res.command_tuples or 0
|
||||
|
||||
if self._results:
|
||||
self._select_current_result(0)
|
||||
|
||||
self._last_query = query
|
||||
|
||||
for cmd in self._conn._prepared.get_maintenance_commands():
|
||||
yield from self._conn._exec_command(cmd)
|
||||
|
||||
def _maybe_prepare_gen(
|
||||
self,
|
||||
pgq: PostgresQuery,
|
||||
*,
|
||||
prepare: Optional[bool] = None,
|
||||
binary: Optional[bool] = None,
|
||||
) -> PQGen[Optional[List["PGresult"]]]:
|
||||
# Check if the query is prepared or needs preparing
|
||||
prep, name = self._get_prepared(pgq, prepare)
|
||||
if prep is Prepare.NO:
|
||||
# The query must be executed without preparing
|
||||
self._execute_send(pgq, binary=binary)
|
||||
else:
|
||||
# If the query is not already prepared, prepare it.
|
||||
if prep is Prepare.SHOULD:
|
||||
self._send_prepare(name, pgq)
|
||||
if not self._conn._pipeline:
|
||||
(result,) = yield from execute(self._pgconn)
|
||||
if result.status == FATAL_ERROR:
|
||||
raise e.error_from_result(result, encoding=self._encoding)
|
||||
# Then execute it.
|
||||
self._send_query_prepared(name, pgq, binary=binary)
|
||||
|
||||
# Update the prepare state of the query.
|
||||
# If an operation requires to flush our prepared statements cache,
|
||||
# it will be added to the maintenance commands to execute later.
|
||||
key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)
|
||||
|
||||
if self._conn._pipeline:
|
||||
queued = None
|
||||
if key is not None:
|
||||
queued = (key, prep, name)
|
||||
self._conn._pipeline.result_queue.append((self, queued))
|
||||
return None
|
||||
|
||||
# run the query
|
||||
results = yield from execute(self._pgconn)
|
||||
|
||||
if key is not None:
|
||||
self._conn._prepared.validate(key, prep, name, results)
|
||||
|
||||
return results
|
||||
|
||||
def _get_prepared(
|
||||
self, pgq: PostgresQuery, prepare: Optional[bool] = None
|
||||
) -> Tuple[Prepare, bytes]:
|
||||
return self._conn._prepared.get(pgq, prepare)
|
||||
|
||||
def _stream_send_gen(
|
||||
self,
|
||||
query: Query,
|
||||
params: Optional[Params] = None,
|
||||
*,
|
||||
binary: Optional[bool] = None,
|
||||
) -> PQGen[None]:
|
||||
"""Generator to send the query for `Cursor.stream()`."""
|
||||
yield from self._start_query(query)
|
||||
pgq = self._convert_query(query, params)
|
||||
self._execute_send(pgq, binary=binary, force_extended=True)
|
||||
self._pgconn.set_single_row_mode()
|
||||
self._last_query = query
|
||||
yield from send(self._pgconn)
|
||||
|
||||
def _stream_fetchone_gen(self, first: bool) -> PQGen[Optional["PGresult"]]:
|
||||
res = yield from fetch(self._pgconn)
|
||||
if res is None:
|
||||
return None
|
||||
|
||||
status = res.status
|
||||
if status == SINGLE_TUPLE:
|
||||
self.pgresult = res
|
||||
self._tx.set_pgresult(res, set_loaders=first)
|
||||
if first:
|
||||
self._make_row = self._make_row_maker()
|
||||
return res
|
||||
|
||||
elif status == TUPLES_OK or status == COMMAND_OK:
|
||||
# End of single row results
|
||||
while res:
|
||||
res = yield from fetch(self._pgconn)
|
||||
if status != TUPLES_OK:
|
||||
raise e.ProgrammingError(
|
||||
"the operation in stream() didn't produce a result"
|
||||
)
|
||||
return None
|
||||
|
||||
else:
|
||||
# Errors, unexpected values
|
||||
return self._raise_for_result(res)
|
||||
|
||||
def _start_query(self, query: Optional[Query] = None) -> PQGen[None]:
|
||||
"""Generator to start the processing of a query.
|
||||
|
||||
It is implemented as generator because it may send additional queries,
|
||||
such as `begin`.
|
||||
"""
|
||||
if self.closed:
|
||||
raise e.InterfaceError("the cursor is closed")
|
||||
|
||||
self._reset()
|
||||
if not self._last_query or (self._last_query is not query):
|
||||
self._last_query = None
|
||||
self._tx = adapt.Transformer(self)
|
||||
yield from self._conn._start_query()
|
||||
|
||||
def _start_copy_gen(
|
||||
self, statement: Query, params: Optional[Params] = None
|
||||
) -> PQGen[None]:
|
||||
"""Generator implementing sending a command for `Cursor.copy()."""
|
||||
|
||||
# The connection gets in an unrecoverable state if we attempt COPY in
|
||||
# pipeline mode. Forbid it explicitly.
|
||||
if self._conn._pipeline:
|
||||
raise e.NotSupportedError("COPY cannot be used in pipeline mode")
|
||||
|
||||
yield from self._start_query()
|
||||
|
||||
# Merge the params client-side
|
||||
if params:
|
||||
pgq = PostgresClientQuery(self._tx)
|
||||
pgq.convert(statement, params)
|
||||
statement = pgq.query
|
||||
|
||||
query = self._convert_query(statement)
|
||||
|
||||
self._execute_send(query, binary=False)
|
||||
results = yield from execute(self._pgconn)
|
||||
if len(results) != 1:
|
||||
raise e.ProgrammingError("COPY cannot be mixed with other operations")
|
||||
|
||||
self._check_copy_result(results[0])
|
||||
self._results = results
|
||||
self._select_current_result(0)
|
||||
|
||||
def _execute_send(
|
||||
self,
|
||||
query: PostgresQuery,
|
||||
*,
|
||||
force_extended: bool = False,
|
||||
binary: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Implement part of execute() before waiting common to sync and async.
|
||||
|
||||
This is not a generator, but a normal non-blocking function.
|
||||
"""
|
||||
if binary is None:
|
||||
fmt = self.format
|
||||
else:
|
||||
fmt = BINARY if binary else TEXT
|
||||
|
||||
self._query = query
|
||||
|
||||
if self._conn._pipeline:
|
||||
# In pipeline mode always use PQsendQueryParams - see #314
|
||||
# Multiple statements in the same query are not allowed anyway.
|
||||
self._conn._pipeline.command_queue.append(
|
||||
partial(
|
||||
self._pgconn.send_query_params,
|
||||
query.query,
|
||||
query.params,
|
||||
param_formats=query.formats,
|
||||
param_types=query.types,
|
||||
result_format=fmt,
|
||||
)
|
||||
)
|
||||
elif force_extended or query.params or fmt == BINARY:
|
||||
self._pgconn.send_query_params(
|
||||
query.query,
|
||||
query.params,
|
||||
param_formats=query.formats,
|
||||
param_types=query.types,
|
||||
result_format=fmt,
|
||||
)
|
||||
else:
|
||||
# If we can, let's use simple query protocol,
|
||||
# as it can execute more than one statement in a single query.
|
||||
self._pgconn.send_query(query.query)
|
||||
|
||||
def _convert_query(
|
||||
self, query: Query, params: Optional[Params] = None
|
||||
) -> PostgresQuery:
|
||||
pgq = PostgresQuery(self._tx)
|
||||
pgq.convert(query, params)
|
||||
return pgq
|
||||
|
||||
def _check_results(self, results: List["PGresult"]) -> None:
|
||||
"""
|
||||
Verify that the results of a query are valid.
|
||||
|
||||
Verify that the query returned at least one result and that they all
|
||||
represent a valid result from the database.
|
||||
"""
|
||||
if not results:
|
||||
raise e.InternalError("got no result from the query")
|
||||
|
||||
for res in results:
|
||||
status = res.status
|
||||
if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY:
|
||||
self._raise_for_result(res)
|
||||
|
||||
def _raise_for_result(self, result: "PGresult") -> NoReturn:
|
||||
"""
|
||||
Raise an appropriate error message for an unexpected database result
|
||||
"""
|
||||
status = result.status
|
||||
if status == FATAL_ERROR:
|
||||
raise e.error_from_result(result, encoding=self._encoding)
|
||||
elif status == PIPELINE_ABORTED:
|
||||
raise e.PipelineAborted("pipeline aborted")
|
||||
elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
|
||||
raise e.ProgrammingError(
|
||||
"COPY cannot be used with this method; use copy() instead"
|
||||
)
|
||||
else:
|
||||
raise e.InternalError(
|
||||
"unexpected result status from query:" f" {pq.ExecStatus(status).name}"
|
||||
)
|
||||
|
||||
def _select_current_result(
|
||||
self, i: int, format: Optional[pq.Format] = None
|
||||
) -> None:
|
||||
"""
|
||||
Select one of the results in the cursor as the active one.
|
||||
"""
|
||||
self._iresult = i
|
||||
res = self.pgresult = self._results[i]
|
||||
|
||||
# Note: the only reason to override format is to correctly set
|
||||
# binary loaders on server-side cursors, because send_describe_portal
|
||||
# only returns a text result.
|
||||
self._tx.set_pgresult(res, format=format)
|
||||
|
||||
self._pos = 0
|
||||
|
||||
if res.status == TUPLES_OK:
|
||||
self._rowcount = self.pgresult.ntuples
|
||||
|
||||
# COPY_OUT has never info about nrows. We need such result for the
|
||||
# columns in order to return a `description`, but not overwrite the
|
||||
# cursor rowcount (which was set by the Copy object).
|
||||
elif res.status != COPY_OUT:
|
||||
nrows = self.pgresult.command_tuples
|
||||
self._rowcount = nrows if nrows is not None else -1
|
||||
|
||||
self._make_row = self._make_row_maker()
|
||||
|
||||
def _set_results_from_pipeline(self, results: List["PGresult"]) -> None:
|
||||
self._check_results(results)
|
||||
first_batch = not self._results
|
||||
|
||||
if self._execmany_returning is None:
|
||||
# Received from execute()
|
||||
self._results.extend(results)
|
||||
if first_batch:
|
||||
self._select_current_result(0)
|
||||
|
||||
else:
|
||||
# Received from executemany()
|
||||
if self._execmany_returning:
|
||||
self._results.extend(results)
|
||||
if first_batch:
|
||||
self._select_current_result(0)
|
||||
else:
|
||||
# In non-returning case, set rowcount to the cumulated number of
|
||||
# rows of executed queries.
|
||||
for res in results:
|
||||
self._rowcount += res.command_tuples or 0
|
||||
|
||||
def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
|
||||
if self._conn._pipeline:
|
||||
self._conn._pipeline.command_queue.append(
|
||||
partial(
|
||||
self._pgconn.send_prepare,
|
||||
name,
|
||||
query.query,
|
||||
param_types=query.types,
|
||||
)
|
||||
)
|
||||
self._conn._pipeline.result_queue.append(None)
|
||||
else:
|
||||
self._pgconn.send_prepare(name, query.query, param_types=query.types)
|
||||
|
||||
def _send_query_prepared(
|
||||
self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None
|
||||
) -> None:
|
||||
if binary is None:
|
||||
fmt = self.format
|
||||
else:
|
||||
fmt = BINARY if binary else TEXT
|
||||
|
||||
if self._conn._pipeline:
|
||||
self._conn._pipeline.command_queue.append(
|
||||
partial(
|
||||
self._pgconn.send_query_prepared,
|
||||
name,
|
||||
pgq.params,
|
||||
param_formats=pgq.formats,
|
||||
result_format=fmt,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._pgconn.send_query_prepared(
|
||||
name, pgq.params, param_formats=pgq.formats, result_format=fmt
|
||||
)
|
||||
|
||||
def _check_result_for_fetch(self) -> None:
|
||||
if self.closed:
|
||||
raise e.InterfaceError("the cursor is closed")
|
||||
res = self.pgresult
|
||||
if not res:
|
||||
raise e.ProgrammingError("no result available")
|
||||
|
||||
status = res.status
|
||||
if status == TUPLES_OK:
|
||||
return
|
||||
elif status == FATAL_ERROR:
|
||||
raise e.error_from_result(res, encoding=self._encoding)
|
||||
elif status == PIPELINE_ABORTED:
|
||||
raise e.PipelineAborted("pipeline aborted")
|
||||
else:
|
||||
raise e.ProgrammingError("the last operation didn't produce a result")
|
||||
|
||||
def _check_copy_result(self, result: "PGresult") -> None:
|
||||
"""
|
||||
Check that the value returned in a copy() operation is a legit COPY.
|
||||
"""
|
||||
status = result.status
|
||||
if status == COPY_IN or status == COPY_OUT:
|
||||
return
|
||||
elif status == FATAL_ERROR:
|
||||
raise e.error_from_result(result, encoding=self._encoding)
|
||||
else:
|
||||
raise e.ProgrammingError(
|
||||
"copy() should be used only with COPY ... TO STDOUT or COPY ..."
|
||||
f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
|
||||
)
|
||||
|
||||
def _scroll(self, value: int, mode: str) -> None:
|
||||
self._check_result_for_fetch()
|
||||
assert self.pgresult
|
||||
if mode == "relative":
|
||||
newpos = self._pos + value
|
||||
elif mode == "absolute":
|
||||
newpos = value
|
||||
else:
|
||||
raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
|
||||
if not 0 <= newpos < self.pgresult.ntuples:
|
||||
raise IndexError("position out of bound")
|
||||
self._pos = newpos
|
||||
|
||||
def _close(self) -> None:
|
||||
"""Non-blocking part of closing. Common to sync/async."""
|
||||
# Don't reset the query because it may be useful to investigate after
|
||||
# an error.
|
||||
self._reset(reset_query=False)
|
||||
self._closed = True
|
||||
|
||||
@property
|
||||
def _encoding(self) -> str:
|
||||
return pgconn_encoding(self._pgconn)
|
||||
|
||||
|
||||
class Cursor(BaseCursor["Connection[Any]", Row]):
|
||||
__module__ = "psycopg"
|
||||
__slots__ = ()
|
||||
_Self = TypeVar("_Self", bound="Cursor[Any]")
|
||||
|
||||
@overload
|
||||
def __init__(self: "Cursor[Row]", connection: "Connection[Row]"):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self: "Cursor[Row]",
|
||||
connection: "Connection[Any]",
|
||||
*,
|
||||
row_factory: RowFactory[Row],
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: "Connection[Any]",
|
||||
*,
|
||||
row_factory: Optional[RowFactory[Row]] = None,
|
||||
):
|
||||
super().__init__(connection)
|
||||
self._row_factory = row_factory or connection.row_factory
|
||||
|
||||
def __enter__(self: _Self) -> _Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the current cursor and free associated resources.
|
||||
"""
|
||||
self._close()
|
||||
|
||||
@property
|
||||
def row_factory(self) -> RowFactory[Row]:
|
||||
"""Writable attribute to control how result rows are formed."""
|
||||
return self._row_factory
|
||||
|
||||
@row_factory.setter
|
||||
def row_factory(self, row_factory: RowFactory[Row]) -> None:
|
||||
self._row_factory = row_factory
|
||||
if self.pgresult:
|
||||
self._make_row = row_factory(self)
|
||||
|
||||
def _make_row_maker(self) -> RowMaker[Row]:
|
||||
return self._row_factory(self)
|
||||
|
||||
def execute(
|
||||
self: _Self,
|
||||
query: Query,
|
||||
params: Optional[Params] = None,
|
||||
*,
|
||||
prepare: Optional[bool] = None,
|
||||
binary: Optional[bool] = None,
|
||||
) -> _Self:
|
||||
"""
|
||||
Execute a query or command to the database.
|
||||
"""
|
||||
try:
|
||||
with self._conn.lock:
|
||||
self._conn.wait(
|
||||
self._execute_gen(query, params, prepare=prepare, binary=binary)
|
||||
)
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
return self
|
||||
|
||||
def executemany(
|
||||
self,
|
||||
query: Query,
|
||||
params_seq: Iterable[Params],
|
||||
*,
|
||||
returning: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Execute the same command with a sequence of input data.
|
||||
"""
|
||||
try:
|
||||
if Pipeline.is_supported():
|
||||
# If there is already a pipeline, ride it, in order to avoid
|
||||
# sending unnecessary Sync.
|
||||
with self._conn.lock:
|
||||
p = self._conn._pipeline
|
||||
if p:
|
||||
self._conn.wait(
|
||||
self._executemany_gen_pipeline(query, params_seq, returning)
|
||||
)
|
||||
# Otherwise, make a new one
|
||||
if not p:
|
||||
with self._conn.pipeline(), self._conn.lock:
|
||||
self._conn.wait(
|
||||
self._executemany_gen_pipeline(query, params_seq, returning)
|
||||
)
|
||||
else:
|
||||
with self._conn.lock:
|
||||
self._conn.wait(
|
||||
self._executemany_gen_no_pipeline(query, params_seq, returning)
|
||||
)
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
query: Query,
|
||||
params: Optional[Params] = None,
|
||||
*,
|
||||
binary: Optional[bool] = None,
|
||||
) -> Iterator[Row]:
|
||||
"""
|
||||
Iterate row-by-row on a result from the database.
|
||||
"""
|
||||
if self._pgconn.pipeline_status:
|
||||
raise e.ProgrammingError("stream() cannot be used in pipeline mode")
|
||||
|
||||
with self._conn.lock:
|
||||
try:
|
||||
self._conn.wait(self._stream_send_gen(query, params, binary=binary))
|
||||
first = True
|
||||
while self._conn.wait(self._stream_fetchone_gen(first)):
|
||||
# We know that, if we got a result, it has a single row.
|
||||
rec: Row = self._tx.load_row(0, self._make_row) # type: ignore
|
||||
yield rec
|
||||
first = False
|
||||
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
|
||||
finally:
|
||||
if self._pgconn.transaction_status == ACTIVE:
|
||||
# Try to cancel the query, then consume the results
|
||||
# already received.
|
||||
self._conn.cancel()
|
||||
try:
|
||||
while self._conn.wait(self._stream_fetchone_gen(first=False)):
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try to get out of ACTIVE state. Just do a single attempt, which
|
||||
# should work to recover from an error or query cancelled.
|
||||
try:
|
||||
self._conn.wait(self._stream_fetchone_gen(first=False))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def fetchone(self) -> Optional[Row]:
|
||||
"""
|
||||
Return the next record from the current recordset.
|
||||
|
||||
Return `!None` the recordset is finished.
|
||||
|
||||
:rtype: Optional[Row], with Row defined by `row_factory`
|
||||
"""
|
||||
self._fetch_pipeline()
|
||||
self._check_result_for_fetch()
|
||||
record = self._tx.load_row(self._pos, self._make_row)
|
||||
if record is not None:
|
||||
self._pos += 1
|
||||
return record
|
||||
|
||||
def fetchmany(self, size: int = 0) -> List[Row]:
|
||||
"""
|
||||
Return the next `!size` records from the current recordset.
|
||||
|
||||
`!size` default to `!self.arraysize` if not specified.
|
||||
|
||||
:rtype: Sequence[Row], with Row defined by `row_factory`
|
||||
"""
|
||||
self._fetch_pipeline()
|
||||
self._check_result_for_fetch()
|
||||
assert self.pgresult
|
||||
|
||||
if not size:
|
||||
size = self.arraysize
|
||||
records = self._tx.load_rows(
|
||||
self._pos,
|
||||
min(self._pos + size, self.pgresult.ntuples),
|
||||
self._make_row,
|
||||
)
|
||||
self._pos += len(records)
|
||||
return records
|
||||
|
||||
def fetchall(self) -> List[Row]:
|
||||
"""
|
||||
Return all the remaining records from the current recordset.
|
||||
|
||||
:rtype: Sequence[Row], with Row defined by `row_factory`
|
||||
"""
|
||||
self._fetch_pipeline()
|
||||
self._check_result_for_fetch()
|
||||
assert self.pgresult
|
||||
records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
|
||||
self._pos = self.pgresult.ntuples
|
||||
return records
|
||||
|
||||
def __iter__(self) -> Iterator[Row]:
|
||||
self._fetch_pipeline()
|
||||
self._check_result_for_fetch()
|
||||
|
||||
def load(pos: int) -> Optional[Row]:
|
||||
return self._tx.load_row(pos, self._make_row)
|
||||
|
||||
while True:
|
||||
row = load(self._pos)
|
||||
if row is None:
|
||||
break
|
||||
self._pos += 1
|
||||
yield row
|
||||
|
||||
def scroll(self, value: int, mode: str = "relative") -> None:
|
||||
"""
|
||||
Move the cursor in the result set to a new position according to mode.
|
||||
|
||||
If `!mode` is ``'relative'`` (default), `!value` is taken as offset to
|
||||
the current position in the result set; if set to ``'absolute'``,
|
||||
`!value` states an absolute target position.
|
||||
|
||||
Raise `!IndexError` in case a scroll operation would leave the result
|
||||
set. In this case the position will not change.
|
||||
"""
|
||||
self._fetch_pipeline()
|
||||
self._scroll(value, mode)
|
||||
|
||||
@contextmanager
|
||||
def copy(
|
||||
self,
|
||||
statement: Query,
|
||||
params: Optional[Params] = None,
|
||||
*,
|
||||
writer: Optional[CopyWriter] = None,
|
||||
) -> Iterator[Copy]:
|
||||
"""
|
||||
Initiate a :sql:`COPY` operation and return an object to manage it.
|
||||
|
||||
:rtype: Copy
|
||||
"""
|
||||
try:
|
||||
with self._conn.lock:
|
||||
self._conn.wait(self._start_copy_gen(statement, params))
|
||||
|
||||
with Copy(self, writer=writer) as copy:
|
||||
yield copy
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
|
||||
# If a fresher result has been set on the cursor by the Copy object,
|
||||
# read its properties (especially rowcount).
|
||||
self._select_current_result(0)
|
||||
|
||||
def _fetch_pipeline(self) -> None:
|
||||
if (
|
||||
self._execmany_returning is not False
|
||||
and not self.pgresult
|
||||
and self._conn._pipeline
|
||||
):
|
||||
with self._conn.lock:
|
||||
self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
|
||||
251
srcs/.venv/lib/python3.11/site-packages/psycopg/cursor_async.py
Normal file
251
srcs/.venv/lib/python3.11/site-packages/psycopg/cursor_async.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
psycopg async cursor objects
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from types import TracebackType
|
||||
from typing import Any, AsyncIterator, Iterable, List
|
||||
from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from . import pq
|
||||
from . import errors as e
|
||||
from .abc import Query, Params
|
||||
from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter
|
||||
from .rows import Row, RowMaker, AsyncRowFactory
|
||||
from .cursor import BaseCursor
|
||||
from ._pipeline import Pipeline
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .connection_async import AsyncConnection
|
||||
|
||||
ACTIVE = pq.TransactionStatus.ACTIVE
|
||||
|
||||
|
||||
class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
|
||||
__module__ = "psycopg"
|
||||
__slots__ = ()
|
||||
_Self = TypeVar("_Self", bound="AsyncCursor[Any]")
|
||||
|
||||
@overload
|
||||
def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self: "AsyncCursor[Row]",
|
||||
connection: "AsyncConnection[Any]",
|
||||
*,
|
||||
row_factory: AsyncRowFactory[Row],
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: "AsyncConnection[Any]",
|
||||
*,
|
||||
row_factory: Optional[AsyncRowFactory[Row]] = None,
|
||||
):
|
||||
super().__init__(connection)
|
||||
self._row_factory = row_factory or connection.row_factory
|
||||
|
||||
async def __aenter__(self: _Self) -> _Self:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
self._close()
|
||||
|
||||
@property
|
||||
def row_factory(self) -> AsyncRowFactory[Row]:
|
||||
return self._row_factory
|
||||
|
||||
@row_factory.setter
|
||||
def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None:
|
||||
self._row_factory = row_factory
|
||||
if self.pgresult:
|
||||
self._make_row = row_factory(self)
|
||||
|
||||
def _make_row_maker(self) -> RowMaker[Row]:
|
||||
return self._row_factory(self)
|
||||
|
||||
async def execute(
|
||||
self: _Self,
|
||||
query: Query,
|
||||
params: Optional[Params] = None,
|
||||
*,
|
||||
prepare: Optional[bool] = None,
|
||||
binary: Optional[bool] = None,
|
||||
) -> _Self:
|
||||
try:
|
||||
async with self._conn.lock:
|
||||
await self._conn.wait(
|
||||
self._execute_gen(query, params, prepare=prepare, binary=binary)
|
||||
)
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
return self
|
||||
|
||||
async def executemany(
|
||||
self,
|
||||
query: Query,
|
||||
params_seq: Iterable[Params],
|
||||
*,
|
||||
returning: bool = False,
|
||||
) -> None:
|
||||
try:
|
||||
if Pipeline.is_supported():
|
||||
# If there is already a pipeline, ride it, in order to avoid
|
||||
# sending unnecessary Sync.
|
||||
async with self._conn.lock:
|
||||
p = self._conn._pipeline
|
||||
if p:
|
||||
await self._conn.wait(
|
||||
self._executemany_gen_pipeline(query, params_seq, returning)
|
||||
)
|
||||
# Otherwise, make a new one
|
||||
if not p:
|
||||
async with self._conn.pipeline(), self._conn.lock:
|
||||
await self._conn.wait(
|
||||
self._executemany_gen_pipeline(query, params_seq, returning)
|
||||
)
|
||||
else:
|
||||
async with self._conn.lock:
|
||||
await self._conn.wait(
|
||||
self._executemany_gen_no_pipeline(query, params_seq, returning)
|
||||
)
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
query: Query,
|
||||
params: Optional[Params] = None,
|
||||
*,
|
||||
binary: Optional[bool] = None,
|
||||
) -> AsyncIterator[Row]:
|
||||
if self._pgconn.pipeline_status:
|
||||
raise e.ProgrammingError("stream() cannot be used in pipeline mode")
|
||||
|
||||
async with self._conn.lock:
|
||||
try:
|
||||
await self._conn.wait(
|
||||
self._stream_send_gen(query, params, binary=binary)
|
||||
)
|
||||
first = True
|
||||
while await self._conn.wait(self._stream_fetchone_gen(first)):
|
||||
# We know that, if we got a result, it has a single row.
|
||||
rec: Row = self._tx.load_row(0, self._make_row) # type: ignore
|
||||
yield rec
|
||||
first = False
|
||||
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
|
||||
finally:
|
||||
if self._pgconn.transaction_status == ACTIVE:
|
||||
# Try to cancel the query, then consume the results
|
||||
# already received.
|
||||
self._conn.cancel()
|
||||
try:
|
||||
while await self._conn.wait(
|
||||
self._stream_fetchone_gen(first=False)
|
||||
):
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try to get out of ACTIVE state. Just do a single attempt, which
|
||||
# should work to recover from an error or query cancelled.
|
||||
try:
|
||||
await self._conn.wait(self._stream_fetchone_gen(first=False))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def fetchone(self) -> Optional[Row]:
|
||||
await self._fetch_pipeline()
|
||||
self._check_result_for_fetch()
|
||||
record = self._tx.load_row(self._pos, self._make_row)
|
||||
if record is not None:
|
||||
self._pos += 1
|
||||
return record
|
||||
|
||||
async def fetchmany(self, size: int = 0) -> List[Row]:
|
||||
await self._fetch_pipeline()
|
||||
self._check_result_for_fetch()
|
||||
assert self.pgresult
|
||||
|
||||
if not size:
|
||||
size = self.arraysize
|
||||
records = self._tx.load_rows(
|
||||
self._pos,
|
||||
min(self._pos + size, self.pgresult.ntuples),
|
||||
self._make_row,
|
||||
)
|
||||
self._pos += len(records)
|
||||
return records
|
||||
|
||||
async def fetchall(self) -> List[Row]:
|
||||
await self._fetch_pipeline()
|
||||
self._check_result_for_fetch()
|
||||
assert self.pgresult
|
||||
records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
|
||||
self._pos = self.pgresult.ntuples
|
||||
return records
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[Row]:
|
||||
await self._fetch_pipeline()
|
||||
self._check_result_for_fetch()
|
||||
|
||||
def load(pos: int) -> Optional[Row]:
|
||||
return self._tx.load_row(pos, self._make_row)
|
||||
|
||||
while True:
|
||||
row = load(self._pos)
|
||||
if row is None:
|
||||
break
|
||||
self._pos += 1
|
||||
yield row
|
||||
|
||||
async def scroll(self, value: int, mode: str = "relative") -> None:
|
||||
await self._fetch_pipeline()
|
||||
self._scroll(value, mode)
|
||||
|
||||
@asynccontextmanager
|
||||
async def copy(
|
||||
self,
|
||||
statement: Query,
|
||||
params: Optional[Params] = None,
|
||||
*,
|
||||
writer: Optional[AsyncCopyWriter] = None,
|
||||
) -> AsyncIterator[AsyncCopy]:
|
||||
"""
|
||||
:rtype: AsyncCopy
|
||||
"""
|
||||
try:
|
||||
async with self._conn.lock:
|
||||
await self._conn.wait(self._start_copy_gen(statement, params))
|
||||
|
||||
async with AsyncCopy(self, writer=writer) as copy:
|
||||
yield copy
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
|
||||
self._select_current_result(0)
|
||||
|
||||
async def _fetch_pipeline(self) -> None:
|
||||
if (
|
||||
self._execmany_returning is not False
|
||||
and not self.pgresult
|
||||
and self._conn._pipeline
|
||||
):
|
||||
async with self._conn.lock:
|
||||
await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
|
||||
112
srcs/.venv/lib/python3.11/site-packages/psycopg/dbapi20.py
Normal file
112
srcs/.venv/lib/python3.11/site-packages/psycopg/dbapi20.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Compatibility objects with DBAPI 2.0
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import time
|
||||
import datetime as dt
|
||||
from math import floor
|
||||
from typing import Any, Sequence, Union
|
||||
|
||||
from . import postgres
|
||||
from .abc import AdaptContext, Buffer
|
||||
from .types.string import BytesDumper, BytesBinaryDumper
|
||||
|
||||
|
||||
class DBAPITypeObject:
|
||||
def __init__(self, name: str, type_names: Sequence[str]):
|
||||
self.name = name
|
||||
self.values = tuple(postgres.types[n].oid for n in type_names)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"psycopg.{self.name}"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, int):
|
||||
return other in self.values
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
if isinstance(other, int):
|
||||
return other not in self.values
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
BINARY = DBAPITypeObject("BINARY", ("bytea",))
|
||||
DATETIME = DBAPITypeObject(
|
||||
"DATETIME", "timestamp timestamptz date time timetz interval".split()
|
||||
)
|
||||
NUMBER = DBAPITypeObject("NUMBER", "int2 int4 int8 float4 float8 numeric".split())
|
||||
ROWID = DBAPITypeObject("ROWID", ("oid",))
|
||||
STRING = DBAPITypeObject("STRING", "text varchar bpchar".split())
|
||||
|
||||
|
||||
class Binary:
|
||||
def __init__(self, obj: Any):
|
||||
self.obj = obj
|
||||
|
||||
def __repr__(self) -> str:
|
||||
sobj = repr(self.obj)
|
||||
if len(sobj) > 40:
|
||||
sobj = f"{sobj[:35]} ... ({len(sobj)} byteschars)"
|
||||
return f"{self.__class__.__name__}({sobj})"
|
||||
|
||||
|
||||
class BinaryBinaryDumper(BytesBinaryDumper):
|
||||
def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
|
||||
if isinstance(obj, Binary):
|
||||
return super().dump(obj.obj)
|
||||
else:
|
||||
return super().dump(obj)
|
||||
|
||||
|
||||
class BinaryTextDumper(BytesDumper):
|
||||
def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
|
||||
if isinstance(obj, Binary):
|
||||
return super().dump(obj.obj)
|
||||
else:
|
||||
return super().dump(obj)
|
||||
|
||||
|
||||
def Date(year: int, month: int, day: int) -> dt.date:
|
||||
return dt.date(year, month, day)
|
||||
|
||||
|
||||
def DateFromTicks(ticks: float) -> dt.date:
|
||||
return TimestampFromTicks(ticks).date()
|
||||
|
||||
|
||||
def Time(hour: int, minute: int, second: int) -> dt.time:
|
||||
return dt.time(hour, minute, second)
|
||||
|
||||
|
||||
def TimeFromTicks(ticks: float) -> dt.time:
|
||||
return TimestampFromTicks(ticks).time()
|
||||
|
||||
|
||||
def Timestamp(
|
||||
year: int, month: int, day: int, hour: int, minute: int, second: int
|
||||
) -> dt.datetime:
|
||||
return dt.datetime(year, month, day, hour, minute, second)
|
||||
|
||||
|
||||
def TimestampFromTicks(ticks: float) -> dt.datetime:
|
||||
secs = floor(ticks)
|
||||
frac = ticks - secs
|
||||
t = time.localtime(ticks)
|
||||
tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff))
|
||||
rv = dt.datetime(*t[:6], round(frac * 1_000_000), tzinfo=tzinfo)
|
||||
return rv
|
||||
|
||||
|
||||
def register_dbapi20_adapters(context: AdaptContext) -> None:
|
||||
adapters = context.adapters
|
||||
adapters.register_dumper(Binary, BinaryTextDumper)
|
||||
adapters.register_dumper(Binary, BinaryBinaryDumper)
|
||||
|
||||
# Make them also the default dumpers when dumping by bytea oid
|
||||
adapters.register_dumper(None, BinaryTextDumper)
|
||||
adapters.register_dumper(None, BinaryBinaryDumper)
|
||||
1727
srcs/.venv/lib/python3.11/site-packages/psycopg/errors.py
Normal file
1727
srcs/.venv/lib/python3.11/site-packages/psycopg/errors.py
Normal file
File diff suppressed because it is too large
Load Diff
333
srcs/.venv/lib/python3.11/site-packages/psycopg/generators.py
Normal file
333
srcs/.venv/lib/python3.11/site-packages/psycopg/generators.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Generators implementing communication protocols with the libpq
|
||||
|
||||
Certain operations (connection, querying) are an interleave of libpq calls and
|
||||
waiting for the socket to be ready. This module contains the code to execute
|
||||
the operations, yielding a polling state whenever there is to wait. The
|
||||
functions in the `waiting` module are the ones who wait more or less
|
||||
cooperatively for the socket to be ready and make these generators continue.
|
||||
|
||||
All these generators yield pairs (fileno, `Wait`) whenever an operation would
|
||||
block. The generator can be restarted sending the appropriate `Ready` state
|
||||
when the file descriptor is ready.
|
||||
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from . import pq
|
||||
from . import errors as e
|
||||
from .abc import Buffer, PipelineCommand, PQGen, PQGenConn
|
||||
from .pq.abc import PGconn, PGresult
|
||||
from .waiting import Wait, Ready
|
||||
from ._compat import Deque
|
||||
from ._cmodule import _psycopg
|
||||
from ._encodings import pgconn_encoding, conninfo_encoding
|
||||
|
||||
OK = pq.ConnStatus.OK
|
||||
BAD = pq.ConnStatus.BAD
|
||||
|
||||
POLL_OK = pq.PollingStatus.OK
|
||||
POLL_READING = pq.PollingStatus.READING
|
||||
POLL_WRITING = pq.PollingStatus.WRITING
|
||||
POLL_FAILED = pq.PollingStatus.FAILED
|
||||
|
||||
COMMAND_OK = pq.ExecStatus.COMMAND_OK
|
||||
COPY_OUT = pq.ExecStatus.COPY_OUT
|
||||
COPY_IN = pq.ExecStatus.COPY_IN
|
||||
COPY_BOTH = pq.ExecStatus.COPY_BOTH
|
||||
PIPELINE_SYNC = pq.ExecStatus.PIPELINE_SYNC
|
||||
|
||||
WAIT_R = Wait.R
|
||||
WAIT_W = Wait.W
|
||||
WAIT_RW = Wait.RW
|
||||
READY_R = Ready.R
|
||||
READY_W = Ready.W
|
||||
READY_RW = Ready.RW
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _connect(conninfo: str) -> PQGenConn[PGconn]:
|
||||
"""
|
||||
Generator to create a database connection without blocking.
|
||||
|
||||
"""
|
||||
conn = pq.PGconn.connect_start(conninfo.encode())
|
||||
while True:
|
||||
if conn.status == BAD:
|
||||
encoding = conninfo_encoding(conninfo)
|
||||
raise e.OperationalError(
|
||||
f"connection is bad: {pq.error_message(conn, encoding=encoding)}",
|
||||
pgconn=conn,
|
||||
)
|
||||
|
||||
status = conn.connect_poll()
|
||||
if status == POLL_OK:
|
||||
break
|
||||
elif status == POLL_READING:
|
||||
yield conn.socket, WAIT_R
|
||||
elif status == POLL_WRITING:
|
||||
yield conn.socket, WAIT_W
|
||||
elif status == POLL_FAILED:
|
||||
encoding = conninfo_encoding(conninfo)
|
||||
raise e.OperationalError(
|
||||
f"connection failed: {pq.error_message(conn, encoding=encoding)}",
|
||||
pgconn=e.finish_pgconn(conn),
|
||||
)
|
||||
else:
|
||||
raise e.InternalError(
|
||||
f"unexpected poll status: {status}", pgconn=e.finish_pgconn(conn)
|
||||
)
|
||||
|
||||
conn.nonblocking = 1
|
||||
return conn
|
||||
|
||||
|
||||
def _execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
|
||||
"""
|
||||
Generator sending a query and returning results without blocking.
|
||||
|
||||
The query must have already been sent using `pgconn.send_query()` or
|
||||
similar. Flush the query and then return the result using nonblocking
|
||||
functions.
|
||||
|
||||
Return the list of results returned by the database (whether success
|
||||
or error).
|
||||
"""
|
||||
yield from _send(pgconn)
|
||||
rv = yield from _fetch_many(pgconn)
|
||||
return rv
|
||||
|
||||
|
||||
def _send(pgconn: PGconn) -> PQGen[None]:
|
||||
"""
|
||||
Generator to send a query to the server without blocking.
|
||||
|
||||
The query must have already been sent using `pgconn.send_query()` or
|
||||
similar. Flush the query and then return the result using nonblocking
|
||||
functions.
|
||||
|
||||
After this generator has finished you may want to cycle using `fetch()`
|
||||
to retrieve the results available.
|
||||
"""
|
||||
while True:
|
||||
f = pgconn.flush()
|
||||
if f == 0:
|
||||
break
|
||||
|
||||
ready = yield WAIT_RW
|
||||
if ready & READY_R:
|
||||
# This call may read notifies: they will be saved in the
|
||||
# PGconn buffer and passed to Python later, in `fetch()`.
|
||||
pgconn.consume_input()
|
||||
|
||||
|
||||
def _fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
|
||||
"""
|
||||
Generator retrieving results from the database without blocking.
|
||||
|
||||
The query must have already been sent to the server, so pgconn.flush() has
|
||||
already returned 0.
|
||||
|
||||
Return the list of results returned by the database (whether success
|
||||
or error).
|
||||
"""
|
||||
results: List[PGresult] = []
|
||||
while True:
|
||||
res = yield from _fetch(pgconn)
|
||||
if not res:
|
||||
break
|
||||
|
||||
results.append(res)
|
||||
status = res.status
|
||||
if status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
|
||||
# After entering copy mode the libpq will create a phony result
|
||||
# for every request so let's break the endless loop.
|
||||
break
|
||||
|
||||
if status == PIPELINE_SYNC:
|
||||
# PIPELINE_SYNC is not followed by a NULL, but we return it alone
|
||||
# similarly to other result sets.
|
||||
assert len(results) == 1, results
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]:
|
||||
"""
|
||||
Generator retrieving a single result from the database without blocking.
|
||||
|
||||
The query must have already been sent to the server, so pgconn.flush() has
|
||||
already returned 0.
|
||||
|
||||
Return a result from the database (whether success or error).
|
||||
"""
|
||||
if pgconn.is_busy():
|
||||
yield WAIT_R
|
||||
while True:
|
||||
pgconn.consume_input()
|
||||
if not pgconn.is_busy():
|
||||
break
|
||||
yield WAIT_R
|
||||
|
||||
_consume_notifies(pgconn)
|
||||
|
||||
return pgconn.get_result()
|
||||
|
||||
|
||||
def _pipeline_communicate(
|
||||
pgconn: PGconn, commands: Deque[PipelineCommand]
|
||||
) -> PQGen[List[List[PGresult]]]:
|
||||
"""Generator to send queries from a connection in pipeline mode while also
|
||||
receiving results.
|
||||
|
||||
Return a list results, including single PIPELINE_SYNC elements.
|
||||
"""
|
||||
results = []
|
||||
|
||||
while True:
|
||||
ready = yield WAIT_RW
|
||||
|
||||
if ready & READY_R:
|
||||
pgconn.consume_input()
|
||||
_consume_notifies(pgconn)
|
||||
|
||||
res: List[PGresult] = []
|
||||
while not pgconn.is_busy():
|
||||
r = pgconn.get_result()
|
||||
if r is None:
|
||||
if not res:
|
||||
break
|
||||
results.append(res)
|
||||
res = []
|
||||
else:
|
||||
status = r.status
|
||||
if status == PIPELINE_SYNC:
|
||||
assert not res
|
||||
results.append([r])
|
||||
elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
|
||||
# This shouldn't happen, but insisting hard enough, it will.
|
||||
# For instance, in test_executemany_badquery(), with the COPY
|
||||
# statement and the AsyncClientCursor, which disables
|
||||
# prepared statements).
|
||||
# Bail out from the resulting infinite loop.
|
||||
raise e.NotSupportedError(
|
||||
"COPY cannot be used in pipeline mode"
|
||||
)
|
||||
else:
|
||||
res.append(r)
|
||||
|
||||
if ready & READY_W:
|
||||
pgconn.flush()
|
||||
if not commands:
|
||||
break
|
||||
commands.popleft()()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _consume_notifies(pgconn: PGconn) -> None:
|
||||
# Consume notifies
|
||||
while True:
|
||||
n = pgconn.notifies()
|
||||
if not n:
|
||||
break
|
||||
if pgconn.notify_handler:
|
||||
pgconn.notify_handler(n)
|
||||
|
||||
|
||||
def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
|
||||
yield WAIT_R
|
||||
pgconn.consume_input()
|
||||
|
||||
ns = []
|
||||
while True:
|
||||
n = pgconn.notifies()
|
||||
if n:
|
||||
ns.append(n)
|
||||
else:
|
||||
break
|
||||
|
||||
return ns
|
||||
|
||||
|
||||
def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
|
||||
while True:
|
||||
nbytes, data = pgconn.get_copy_data(1)
|
||||
if nbytes != 0:
|
||||
break
|
||||
|
||||
# would block
|
||||
yield WAIT_R
|
||||
pgconn.consume_input()
|
||||
|
||||
if nbytes > 0:
|
||||
# some data
|
||||
return data
|
||||
|
||||
# Retrieve the final result of copy
|
||||
results = yield from _fetch_many(pgconn)
|
||||
if len(results) > 1:
|
||||
# TODO: too brutal? Copy worked.
|
||||
raise e.ProgrammingError("you cannot mix COPY with other operations")
|
||||
result = results[0]
|
||||
if result.status != COMMAND_OK:
|
||||
encoding = pgconn_encoding(pgconn)
|
||||
raise e.error_from_result(result, encoding=encoding)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]:
|
||||
# Retry enqueuing data until successful.
|
||||
#
|
||||
# WARNING! This can cause an infinite loop if the buffer is too large. (see
|
||||
# ticket #255). We avoid it in the Copy object by splitting a large buffer
|
||||
# into smaller ones. We prefer to do it there instead of here in order to
|
||||
# do it upstream the queue decoupling the writer task from the producer one.
|
||||
while pgconn.put_copy_data(buffer) == 0:
|
||||
yield WAIT_W
|
||||
|
||||
|
||||
def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
|
||||
# Retry enqueuing end copy message until successful
|
||||
while pgconn.put_copy_end(error) == 0:
|
||||
yield WAIT_W
|
||||
|
||||
# Repeat until it the message is flushed to the server
|
||||
while True:
|
||||
yield WAIT_W
|
||||
f = pgconn.flush()
|
||||
if f == 0:
|
||||
break
|
||||
|
||||
# Retrieve the final result of copy
|
||||
(result,) = yield from _fetch_many(pgconn)
|
||||
if result.status != COMMAND_OK:
|
||||
encoding = pgconn_encoding(pgconn)
|
||||
raise e.error_from_result(result, encoding=encoding)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Override functions with fast versions if available
|
||||
if _psycopg:
|
||||
connect = _psycopg.connect
|
||||
execute = _psycopg.execute
|
||||
send = _psycopg.send
|
||||
fetch_many = _psycopg.fetch_many
|
||||
fetch = _psycopg.fetch
|
||||
pipeline_communicate = _psycopg.pipeline_communicate
|
||||
|
||||
else:
|
||||
connect = _connect
|
||||
execute = _execute
|
||||
send = _send
|
||||
fetch_many = _fetch_many
|
||||
fetch = _fetch
|
||||
pipeline_communicate = _pipeline_communicate
|
||||
124
srcs/.venv/lib/python3.11/site-packages/psycopg/postgres.py
Normal file
124
srcs/.venv/lib/python3.11/site-packages/psycopg/postgres.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Types configuration specific to PostgreSQL.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from ._typeinfo import TypeInfo, RangeInfo, MultirangeInfo, TypesRegistry
|
||||
from .abc import AdaptContext
|
||||
from ._adapters_map import AdaptersMap
|
||||
|
||||
# Global objects with PostgreSQL builtins and globally registered user types.
|
||||
types = TypesRegistry()
|
||||
|
||||
# Global adapter maps with PostgreSQL types configuration
|
||||
adapters = AdaptersMap(types=types)
|
||||
|
||||
# Use tools/update_oids.py to update this data.
|
||||
for t in [
|
||||
TypeInfo('"char"', 18, 1002),
|
||||
# autogenerated: start
|
||||
# Generated from PostgreSQL 16.0
|
||||
TypeInfo("aclitem", 1033, 1034),
|
||||
TypeInfo("bit", 1560, 1561),
|
||||
TypeInfo("bool", 16, 1000, regtype="boolean"),
|
||||
TypeInfo("box", 603, 1020, delimiter=";"),
|
||||
TypeInfo("bpchar", 1042, 1014, regtype="character"),
|
||||
TypeInfo("bytea", 17, 1001),
|
||||
TypeInfo("cid", 29, 1012),
|
||||
TypeInfo("cidr", 650, 651),
|
||||
TypeInfo("circle", 718, 719),
|
||||
TypeInfo("date", 1082, 1182),
|
||||
TypeInfo("float4", 700, 1021, regtype="real"),
|
||||
TypeInfo("float8", 701, 1022, regtype="double precision"),
|
||||
TypeInfo("gtsvector", 3642, 3644),
|
||||
TypeInfo("inet", 869, 1041),
|
||||
TypeInfo("int2", 21, 1005, regtype="smallint"),
|
||||
TypeInfo("int2vector", 22, 1006),
|
||||
TypeInfo("int4", 23, 1007, regtype="integer"),
|
||||
TypeInfo("int8", 20, 1016, regtype="bigint"),
|
||||
TypeInfo("interval", 1186, 1187),
|
||||
TypeInfo("json", 114, 199),
|
||||
TypeInfo("jsonb", 3802, 3807),
|
||||
TypeInfo("jsonpath", 4072, 4073),
|
||||
TypeInfo("line", 628, 629),
|
||||
TypeInfo("lseg", 601, 1018),
|
||||
TypeInfo("macaddr", 829, 1040),
|
||||
TypeInfo("macaddr8", 774, 775),
|
||||
TypeInfo("money", 790, 791),
|
||||
TypeInfo("name", 19, 1003),
|
||||
TypeInfo("numeric", 1700, 1231),
|
||||
TypeInfo("oid", 26, 1028),
|
||||
TypeInfo("oidvector", 30, 1013),
|
||||
TypeInfo("path", 602, 1019),
|
||||
TypeInfo("pg_lsn", 3220, 3221),
|
||||
TypeInfo("point", 600, 1017),
|
||||
TypeInfo("polygon", 604, 1027),
|
||||
TypeInfo("record", 2249, 2287),
|
||||
TypeInfo("refcursor", 1790, 2201),
|
||||
TypeInfo("regclass", 2205, 2210),
|
||||
TypeInfo("regcollation", 4191, 4192),
|
||||
TypeInfo("regconfig", 3734, 3735),
|
||||
TypeInfo("regdictionary", 3769, 3770),
|
||||
TypeInfo("regnamespace", 4089, 4090),
|
||||
TypeInfo("regoper", 2203, 2208),
|
||||
TypeInfo("regoperator", 2204, 2209),
|
||||
TypeInfo("regproc", 24, 1008),
|
||||
TypeInfo("regprocedure", 2202, 2207),
|
||||
TypeInfo("regrole", 4096, 4097),
|
||||
TypeInfo("regtype", 2206, 2211),
|
||||
TypeInfo("text", 25, 1009),
|
||||
TypeInfo("tid", 27, 1010),
|
||||
TypeInfo("time", 1083, 1183, regtype="time without time zone"),
|
||||
TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
|
||||
TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
|
||||
TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
|
||||
TypeInfo("tsquery", 3615, 3645),
|
||||
TypeInfo("tsvector", 3614, 3643),
|
||||
TypeInfo("txid_snapshot", 2970, 2949),
|
||||
TypeInfo("uuid", 2950, 2951),
|
||||
TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
|
||||
TypeInfo("varchar", 1043, 1015, regtype="character varying"),
|
||||
TypeInfo("xid", 28, 1011),
|
||||
TypeInfo("xid8", 5069, 271),
|
||||
TypeInfo("xml", 142, 143),
|
||||
RangeInfo("daterange", 3912, 3913, subtype_oid=1082),
|
||||
RangeInfo("int4range", 3904, 3905, subtype_oid=23),
|
||||
RangeInfo("int8range", 3926, 3927, subtype_oid=20),
|
||||
RangeInfo("numrange", 3906, 3907, subtype_oid=1700),
|
||||
RangeInfo("tsrange", 3908, 3909, subtype_oid=1114),
|
||||
RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184),
|
||||
MultirangeInfo("datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082),
|
||||
MultirangeInfo("int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23),
|
||||
MultirangeInfo("int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20),
|
||||
MultirangeInfo("nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700),
|
||||
MultirangeInfo("tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114),
|
||||
MultirangeInfo("tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184),
|
||||
# autogenerated: end
|
||||
]:
|
||||
types.add(t)
|
||||
|
||||
|
||||
# A few oids used a bit everywhere
|
||||
INVALID_OID = 0
|
||||
TEXT_OID = types["text"].oid
|
||||
TEXT_ARRAY_OID = types["text"].array_oid
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
from .types import array, bool, composite, datetime, enum, json, multirange
|
||||
from .types import net, none, numeric, range, string, uuid
|
||||
|
||||
array.register_default_adapters(context)
|
||||
bool.register_default_adapters(context)
|
||||
composite.register_default_adapters(context)
|
||||
datetime.register_default_adapters(context)
|
||||
enum.register_default_adapters(context)
|
||||
json.register_default_adapters(context)
|
||||
multirange.register_default_adapters(context)
|
||||
net.register_default_adapters(context)
|
||||
none.register_default_adapters(context)
|
||||
numeric.register_default_adapters(context)
|
||||
range.register_default_adapters(context)
|
||||
string.register_default_adapters(context)
|
||||
uuid.register_default_adapters(context)
|
||||
133
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/__init__.py
Normal file
133
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/__init__.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
psycopg libpq wrapper
|
||||
|
||||
This package exposes the libpq functionalities as Python objects and functions.
|
||||
|
||||
The real implementation (the binding to the C library) is
|
||||
implementation-dependant but all the implementations share the same interface.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Callable, List, Type
|
||||
|
||||
from . import abc
|
||||
from .misc import ConninfoOption, PGnotify, PGresAttDesc
|
||||
from .misc import error_message
|
||||
from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format, Trace
|
||||
from ._enums import Ping, PipelineStatus, PollingStatus, TransactionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__impl__: str
|
||||
"""The currently loaded implementation of the `!psycopg.pq` package.
|
||||
|
||||
Possible values include ``python``, ``c``, ``binary``.
|
||||
"""
|
||||
|
||||
__build_version__: int
|
||||
"""The libpq version the C package was built with.
|
||||
|
||||
A number in the same format of `~psycopg.ConnectionInfo.server_version`
|
||||
representing the libpq used to build the speedup module (``c``, ``binary``) if
|
||||
available.
|
||||
|
||||
Certain features might not be available if the built version is too old.
|
||||
"""
|
||||
|
||||
version: Callable[[], int]
|
||||
PGconn: Type[abc.PGconn]
|
||||
PGresult: Type[abc.PGresult]
|
||||
Conninfo: Type[abc.Conninfo]
|
||||
Escaping: Type[abc.Escaping]
|
||||
PGcancel: Type[abc.PGcancel]
|
||||
|
||||
|
||||
def import_from_libpq() -> None:
|
||||
"""
|
||||
Import pq objects implementation from the best libpq wrapper available.
|
||||
|
||||
If an implementation is requested try to import only it, otherwise
|
||||
try to import the best implementation available.
|
||||
"""
|
||||
# import these names into the module on success as side effect
|
||||
global __impl__, version, __build_version__
|
||||
global PGconn, PGresult, Conninfo, Escaping, PGcancel
|
||||
|
||||
impl = os.environ.get("PSYCOPG_IMPL", "").lower()
|
||||
module = None
|
||||
attempts: List[str] = []
|
||||
|
||||
def handle_error(name: str, e: Exception) -> None:
|
||||
if not impl:
|
||||
msg = f"couldn't import psycopg '{name}' implementation: {e}"
|
||||
logger.debug(msg)
|
||||
attempts.append(msg)
|
||||
else:
|
||||
msg = f"couldn't import requested psycopg '{name}' implementation: {e}"
|
||||
raise ImportError(msg) from e
|
||||
|
||||
# The best implementation: fast but requires the system libpq installed
|
||||
if not impl or impl == "c":
|
||||
try:
|
||||
from psycopg_c import pq as module # type: ignore
|
||||
except Exception as e:
|
||||
handle_error("c", e)
|
||||
|
||||
# Second best implementation: fast and stand-alone
|
||||
if not module and (not impl or impl == "binary"):
|
||||
try:
|
||||
from psycopg_binary import pq as module # type: ignore
|
||||
except Exception as e:
|
||||
handle_error("binary", e)
|
||||
|
||||
# Pure Python implementation, slow and requires the system libpq installed.
|
||||
if not module and (not impl or impl == "python"):
|
||||
try:
|
||||
from . import pq_ctypes as module # type: ignore[assignment]
|
||||
except Exception as e:
|
||||
handle_error("python", e)
|
||||
|
||||
if module:
|
||||
__impl__ = module.__impl__
|
||||
version = module.version
|
||||
PGconn = module.PGconn
|
||||
PGresult = module.PGresult
|
||||
Conninfo = module.Conninfo
|
||||
Escaping = module.Escaping
|
||||
PGcancel = module.PGcancel
|
||||
__build_version__ = module.__build_version__
|
||||
elif impl:
|
||||
raise ImportError(f"requested psycopg implementation '{impl}' unknown")
|
||||
else:
|
||||
sattempts = "\n".join(f"- {attempt}" for attempt in attempts)
|
||||
raise ImportError(
|
||||
f"""\
|
||||
no pq wrapper available.
|
||||
Attempts made:
|
||||
{sattempts}"""
|
||||
)
|
||||
|
||||
|
||||
import_from_libpq()
|
||||
|
||||
__all__ = (
|
||||
"ConnStatus",
|
||||
"PipelineStatus",
|
||||
"PollingStatus",
|
||||
"TransactionStatus",
|
||||
"ExecStatus",
|
||||
"Ping",
|
||||
"DiagnosticField",
|
||||
"Format",
|
||||
"Trace",
|
||||
"PGconn",
|
||||
"PGnotify",
|
||||
"Conninfo",
|
||||
"PGresAttDesc",
|
||||
"error_message",
|
||||
"ConninfoOption",
|
||||
"version",
|
||||
)
|
||||
106
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/_debug.py
Normal file
106
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/_debug.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
libpq debugging tools
|
||||
|
||||
These functionalities are exposed here for convenience, but are not part of
|
||||
the public interface and are subject to change at any moment.
|
||||
|
||||
Suggested usage::
|
||||
|
||||
import logging
|
||||
import psycopg
|
||||
from psycopg import pq
|
||||
from psycopg.pq._debug import PGconnDebug
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
logger = logging.getLogger("psycopg.debug")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
assert pq.__impl__ == "python"
|
||||
pq.PGconn = PGconnDebug
|
||||
|
||||
with psycopg.connect("") as conn:
|
||||
conn.pgconn.trace(2)
|
||||
conn.pgconn.set_trace_flags(
|
||||
pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
# Copyright (C) 2022 The Psycopg Team
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Callable, Type, TypeVar, TYPE_CHECKING
|
||||
from functools import wraps
|
||||
|
||||
from . import PGconn
|
||||
from .misc import connection_summary
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import abc
|
||||
|
||||
Func = TypeVar("Func", bound=Callable[..., Any])
|
||||
|
||||
logger = logging.getLogger("psycopg.debug")
|
||||
|
||||
|
||||
class PGconnDebug:
|
||||
"""Wrapper for a PQconn logging all its access."""
|
||||
|
||||
_Self = TypeVar("_Self", bound="PGconnDebug")
|
||||
_pgconn: "abc.PGconn"
|
||||
|
||||
def __init__(self, pgconn: "abc.PGconn"):
|
||||
super().__setattr__("_pgconn", pgconn)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
|
||||
info = connection_summary(self._pgconn)
|
||||
return f"<{cls} {info} at 0x{id(self):x}>"
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
value = getattr(self._pgconn, attr)
|
||||
if callable(value):
|
||||
return debugging(value)
|
||||
else:
|
||||
logger.info("PGconn.%s -> %s", attr, value)
|
||||
return value
|
||||
|
||||
def __setattr__(self, attr: str, value: Any) -> None:
|
||||
setattr(self._pgconn, attr, value)
|
||||
logger.info("PGconn.%s <- %s", attr, value)
|
||||
|
||||
@classmethod
|
||||
def connect(cls: Type[_Self], conninfo: bytes) -> _Self:
|
||||
return cls(debugging(PGconn.connect)(conninfo))
|
||||
|
||||
@classmethod
|
||||
def connect_start(cls: Type[_Self], conninfo: bytes) -> _Self:
|
||||
return cls(debugging(PGconn.connect_start)(conninfo))
|
||||
|
||||
@classmethod
|
||||
def ping(self, conninfo: bytes) -> int:
|
||||
return debugging(PGconn.ping)(conninfo)
|
||||
|
||||
|
||||
def debugging(f: Func) -> Func:
|
||||
"""Wrap a function in order to log its arguments and return value on call."""
|
||||
|
||||
@wraps(f)
|
||||
def debugging_(*args: Any, **kwargs: Any) -> Any:
|
||||
reprs = []
|
||||
for arg in args:
|
||||
reprs.append(f"{arg!r}")
|
||||
for k, v in kwargs.items():
|
||||
reprs.append(f"{k}={v!r}")
|
||||
|
||||
logger.info("PGconn.%s(%s)", f.__name__, ", ".join(reprs))
|
||||
rv = f(*args, **kwargs)
|
||||
# Display the return value only if the function is declared to return
|
||||
# something else than None.
|
||||
ra = inspect.signature(f).return_annotation
|
||||
if ra is not None or rv is not None:
|
||||
logger.info(" <- %r", rv)
|
||||
return rv
|
||||
|
||||
return debugging_ # type: ignore
|
||||
249
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/_enums.py
Normal file
249
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/_enums.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""
|
||||
libpq enum definitions for psycopg
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from enum import IntEnum, IntFlag, auto
|
||||
|
||||
|
||||
class ConnStatus(IntEnum):
|
||||
"""
|
||||
Current status of the connection.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.pq"
|
||||
|
||||
OK = 0
|
||||
"""The connection is in a working state."""
|
||||
BAD = auto()
|
||||
"""The connection is closed."""
|
||||
|
||||
STARTED = auto()
|
||||
MADE = auto()
|
||||
AWAITING_RESPONSE = auto()
|
||||
AUTH_OK = auto()
|
||||
SETENV = auto()
|
||||
SSL_STARTUP = auto()
|
||||
NEEDED = auto()
|
||||
CHECK_WRITABLE = auto()
|
||||
CONSUME = auto()
|
||||
GSS_STARTUP = auto()
|
||||
CHECK_TARGET = auto()
|
||||
CHECK_STANDBY = auto()
|
||||
|
||||
|
||||
class PollingStatus(IntEnum):
|
||||
"""
|
||||
The status of the socket during a connection.
|
||||
|
||||
If ``READING`` or ``WRITING`` you may select before polling again.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.pq"
|
||||
|
||||
FAILED = 0
|
||||
"""Connection attempt failed."""
|
||||
READING = auto()
|
||||
"""Will have to wait before reading new data."""
|
||||
WRITING = auto()
|
||||
"""Will have to wait before writing new data."""
|
||||
OK = auto()
|
||||
"""Connection completed."""
|
||||
|
||||
ACTIVE = auto()
|
||||
|
||||
|
||||
class ExecStatus(IntEnum):
|
||||
"""
|
||||
The status of a command.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.pq"
|
||||
|
||||
EMPTY_QUERY = 0
|
||||
"""The string sent to the server was empty."""
|
||||
|
||||
COMMAND_OK = auto()
|
||||
"""Successful completion of a command returning no data."""
|
||||
|
||||
TUPLES_OK = auto()
|
||||
"""
|
||||
Successful completion of a command returning data (such as a SELECT or SHOW).
|
||||
"""
|
||||
|
||||
COPY_OUT = auto()
|
||||
"""Copy Out (from server) data transfer started."""
|
||||
|
||||
COPY_IN = auto()
|
||||
"""Copy In (to server) data transfer started."""
|
||||
|
||||
BAD_RESPONSE = auto()
|
||||
"""The server's response was not understood."""
|
||||
|
||||
NONFATAL_ERROR = auto()
|
||||
"""A nonfatal error (a notice or warning) occurred."""
|
||||
|
||||
FATAL_ERROR = auto()
|
||||
"""A fatal error occurred."""
|
||||
|
||||
COPY_BOTH = auto()
|
||||
"""
|
||||
Copy In/Out (to and from server) data transfer started.
|
||||
|
||||
This feature is currently used only for streaming replication, so this
|
||||
status should not occur in ordinary applications.
|
||||
"""
|
||||
|
||||
SINGLE_TUPLE = auto()
|
||||
"""
|
||||
The PGresult contains a single result tuple from the current command.
|
||||
|
||||
This status occurs only when single-row mode has been selected for the
|
||||
query.
|
||||
"""
|
||||
|
||||
PIPELINE_SYNC = auto()
|
||||
"""
|
||||
The PGresult represents a synchronization point in pipeline mode,
|
||||
requested by PQpipelineSync.
|
||||
|
||||
This status occurs only when pipeline mode has been selected.
|
||||
"""
|
||||
|
||||
PIPELINE_ABORTED = auto()
|
||||
"""
|
||||
The PGresult represents a pipeline that has received an error from the server.
|
||||
|
||||
PQgetResult must be called repeatedly, and each time it will return this
|
||||
status code until the end of the current pipeline, at which point it will
|
||||
return PGRES_PIPELINE_SYNC and normal processing can resume.
|
||||
"""
|
||||
|
||||
|
||||
class TransactionStatus(IntEnum):
|
||||
"""
|
||||
The transaction status of a connection.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.pq"
|
||||
|
||||
IDLE = 0
|
||||
"""Connection ready, no transaction active."""
|
||||
|
||||
ACTIVE = auto()
|
||||
"""A command is in progress."""
|
||||
|
||||
INTRANS = auto()
|
||||
"""Connection idle in an open transaction."""
|
||||
|
||||
INERROR = auto()
|
||||
"""An error happened in the current transaction."""
|
||||
|
||||
UNKNOWN = auto()
|
||||
"""Unknown connection state, broken connection."""
|
||||
|
||||
|
||||
class Ping(IntEnum):
|
||||
"""Response from a ping attempt."""
|
||||
|
||||
__module__ = "psycopg.pq"
|
||||
|
||||
OK = 0
|
||||
"""
|
||||
The server is running and appears to be accepting connections.
|
||||
"""
|
||||
|
||||
REJECT = auto()
|
||||
"""
|
||||
The server is running but is in a state that disallows connections.
|
||||
"""
|
||||
|
||||
NO_RESPONSE = auto()
|
||||
"""
|
||||
The server could not be contacted.
|
||||
"""
|
||||
|
||||
NO_ATTEMPT = auto()
|
||||
"""
|
||||
No attempt was made to contact the server.
|
||||
"""
|
||||
|
||||
|
||||
class PipelineStatus(IntEnum):
|
||||
"""Pipeline mode status of the libpq connection."""
|
||||
|
||||
__module__ = "psycopg.pq"
|
||||
|
||||
OFF = 0
|
||||
"""
|
||||
The libpq connection is *not* in pipeline mode.
|
||||
"""
|
||||
ON = auto()
|
||||
"""
|
||||
The libpq connection is in pipeline mode.
|
||||
"""
|
||||
ABORTED = auto()
|
||||
"""
|
||||
The libpq connection is in pipeline mode and an error occurred while
|
||||
processing the current pipeline. The aborted flag is cleared when
|
||||
PQgetResult returns a result of type PGRES_PIPELINE_SYNC.
|
||||
"""
|
||||
|
||||
|
||||
class DiagnosticField(IntEnum):
|
||||
"""
|
||||
Fields in an error report.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.pq"
|
||||
|
||||
# from postgres_ext.h
|
||||
SEVERITY = ord("S")
|
||||
SEVERITY_NONLOCALIZED = ord("V")
|
||||
SQLSTATE = ord("C")
|
||||
MESSAGE_PRIMARY = ord("M")
|
||||
MESSAGE_DETAIL = ord("D")
|
||||
MESSAGE_HINT = ord("H")
|
||||
STATEMENT_POSITION = ord("P")
|
||||
INTERNAL_POSITION = ord("p")
|
||||
INTERNAL_QUERY = ord("q")
|
||||
CONTEXT = ord("W")
|
||||
SCHEMA_NAME = ord("s")
|
||||
TABLE_NAME = ord("t")
|
||||
COLUMN_NAME = ord("c")
|
||||
DATATYPE_NAME = ord("d")
|
||||
CONSTRAINT_NAME = ord("n")
|
||||
SOURCE_FILE = ord("F")
|
||||
SOURCE_LINE = ord("L")
|
||||
SOURCE_FUNCTION = ord("R")
|
||||
|
||||
|
||||
class Format(IntEnum):
|
||||
"""
|
||||
Enum representing the format of a query argument or return value.
|
||||
|
||||
These values are only the ones managed by the libpq. `~psycopg` may also
|
||||
support automatically-chosen values: see `psycopg.adapt.PyFormat`.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.pq"
|
||||
|
||||
TEXT = 0
|
||||
"""Text parameter."""
|
||||
BINARY = 1
|
||||
"""Binary parameter."""
|
||||
|
||||
|
||||
class Trace(IntFlag):
|
||||
"""
|
||||
Enum to control tracing of the client/server communication.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg.pq"
|
||||
|
||||
SUPPRESS_TIMESTAMPS = 1
|
||||
"""Do not include timestamps in messages."""
|
||||
|
||||
REGRESS_MODE = 2
|
||||
"""Redact some fields, e.g. OIDs, from messages."""
|
||||
804
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/_pq_ctypes.py
Normal file
804
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/_pq_ctypes.py
Normal file
@@ -0,0 +1,804 @@
|
||||
"""
|
||||
libpq access using ctypes
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import sys
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
from ctypes import Structure, CFUNCTYPE, POINTER
|
||||
from ctypes import c_char, c_char_p, c_int, c_size_t, c_ubyte, c_uint, c_void_p
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from .misc import find_libpq_full_path
|
||||
from ..errors import NotSupportedError
|
||||
|
||||
libname = find_libpq_full_path()
|
||||
if not libname:
|
||||
raise ImportError("libpq library not found")
|
||||
|
||||
pq = ctypes.cdll.LoadLibrary(libname)
|
||||
|
||||
|
||||
class FILE(Structure):
|
||||
pass
|
||||
|
||||
|
||||
FILE_ptr = POINTER(FILE)
|
||||
|
||||
if sys.platform == "linux":
|
||||
libcname = ctypes.util.find_library("c")
|
||||
assert libcname
|
||||
libc = ctypes.cdll.LoadLibrary(libcname)
|
||||
|
||||
fdopen = libc.fdopen
|
||||
fdopen.argtypes = (c_int, c_char_p)
|
||||
fdopen.restype = FILE_ptr
|
||||
|
||||
|
||||
# Get the libpq version to define what functions are available.
|
||||
|
||||
PQlibVersion = pq.PQlibVersion
|
||||
PQlibVersion.argtypes = []
|
||||
PQlibVersion.restype = c_int
|
||||
|
||||
libpq_version = PQlibVersion()
|
||||
|
||||
|
||||
# libpq data types
|
||||
|
||||
|
||||
Oid = c_uint
|
||||
|
||||
|
||||
class PGconn_struct(Structure):
|
||||
_fields_: List[Tuple[str, type]] = []
|
||||
|
||||
|
||||
class PGresult_struct(Structure):
|
||||
_fields_: List[Tuple[str, type]] = []
|
||||
|
||||
|
||||
class PQconninfoOption_struct(Structure):
|
||||
_fields_ = [
|
||||
("keyword", c_char_p),
|
||||
("envvar", c_char_p),
|
||||
("compiled", c_char_p),
|
||||
("val", c_char_p),
|
||||
("label", c_char_p),
|
||||
("dispchar", c_char_p),
|
||||
("dispsize", c_int),
|
||||
]
|
||||
|
||||
|
||||
class PGnotify_struct(Structure):
|
||||
_fields_ = [
|
||||
("relname", c_char_p),
|
||||
("be_pid", c_int),
|
||||
("extra", c_char_p),
|
||||
]
|
||||
|
||||
|
||||
class PGcancel_struct(Structure):
|
||||
_fields_: List[Tuple[str, type]] = []
|
||||
|
||||
|
||||
class PGresAttDesc_struct(Structure):
|
||||
_fields_ = [
|
||||
("name", c_char_p),
|
||||
("tableid", Oid),
|
||||
("columnid", c_int),
|
||||
("format", c_int),
|
||||
("typid", Oid),
|
||||
("typlen", c_int),
|
||||
("atttypmod", c_int),
|
||||
]
|
||||
|
||||
|
||||
PGconn_ptr = POINTER(PGconn_struct)
|
||||
PGresult_ptr = POINTER(PGresult_struct)
|
||||
PQconninfoOption_ptr = POINTER(PQconninfoOption_struct)
|
||||
PGnotify_ptr = POINTER(PGnotify_struct)
|
||||
PGcancel_ptr = POINTER(PGcancel_struct)
|
||||
PGresAttDesc_ptr = POINTER(PGresAttDesc_struct)
|
||||
|
||||
|
||||
# Function definitions as explained in PostgreSQL 12 documentation
|
||||
|
||||
# 33.1. Database Connection Control Functions
|
||||
|
||||
# PQconnectdbParams: doesn't seem useful, won't wrap for now
|
||||
|
||||
PQconnectdb = pq.PQconnectdb
|
||||
PQconnectdb.argtypes = [c_char_p]
|
||||
PQconnectdb.restype = PGconn_ptr
|
||||
|
||||
# PQsetdbLogin: not useful
|
||||
# PQsetdb: not useful
|
||||
|
||||
# PQconnectStartParams: not useful
|
||||
|
||||
PQconnectStart = pq.PQconnectStart
|
||||
PQconnectStart.argtypes = [c_char_p]
|
||||
PQconnectStart.restype = PGconn_ptr
|
||||
|
||||
PQconnectPoll = pq.PQconnectPoll
|
||||
PQconnectPoll.argtypes = [PGconn_ptr]
|
||||
PQconnectPoll.restype = c_int
|
||||
|
||||
PQconndefaults = pq.PQconndefaults
|
||||
PQconndefaults.argtypes = []
|
||||
PQconndefaults.restype = PQconninfoOption_ptr
|
||||
|
||||
PQconninfoFree = pq.PQconninfoFree
|
||||
PQconninfoFree.argtypes = [PQconninfoOption_ptr]
|
||||
PQconninfoFree.restype = None
|
||||
|
||||
PQconninfo = pq.PQconninfo
|
||||
PQconninfo.argtypes = [PGconn_ptr]
|
||||
PQconninfo.restype = PQconninfoOption_ptr
|
||||
|
||||
PQconninfoParse = pq.PQconninfoParse
|
||||
PQconninfoParse.argtypes = [c_char_p, POINTER(c_char_p)]
|
||||
PQconninfoParse.restype = PQconninfoOption_ptr
|
||||
|
||||
PQfinish = pq.PQfinish
|
||||
PQfinish.argtypes = [PGconn_ptr]
|
||||
PQfinish.restype = None
|
||||
|
||||
PQreset = pq.PQreset
|
||||
PQreset.argtypes = [PGconn_ptr]
|
||||
PQreset.restype = None
|
||||
|
||||
PQresetStart = pq.PQresetStart
|
||||
PQresetStart.argtypes = [PGconn_ptr]
|
||||
PQresetStart.restype = c_int
|
||||
|
||||
PQresetPoll = pq.PQresetPoll
|
||||
PQresetPoll.argtypes = [PGconn_ptr]
|
||||
PQresetPoll.restype = c_int
|
||||
|
||||
PQping = pq.PQping
|
||||
PQping.argtypes = [c_char_p]
|
||||
PQping.restype = c_int
|
||||
|
||||
|
||||
# 33.2. Connection Status Functions
|
||||
|
||||
PQdb = pq.PQdb
|
||||
PQdb.argtypes = [PGconn_ptr]
|
||||
PQdb.restype = c_char_p
|
||||
|
||||
PQuser = pq.PQuser
|
||||
PQuser.argtypes = [PGconn_ptr]
|
||||
PQuser.restype = c_char_p
|
||||
|
||||
PQpass = pq.PQpass
|
||||
PQpass.argtypes = [PGconn_ptr]
|
||||
PQpass.restype = c_char_p
|
||||
|
||||
PQhost = pq.PQhost
|
||||
PQhost.argtypes = [PGconn_ptr]
|
||||
PQhost.restype = c_char_p
|
||||
|
||||
_PQhostaddr = None
|
||||
|
||||
if libpq_version >= 120000:
|
||||
_PQhostaddr = pq.PQhostaddr
|
||||
_PQhostaddr.argtypes = [PGconn_ptr]
|
||||
_PQhostaddr.restype = c_char_p
|
||||
|
||||
|
||||
def PQhostaddr(pgconn: PGconn_struct) -> bytes:
|
||||
if not _PQhostaddr:
|
||||
raise NotSupportedError(
|
||||
"PQhostaddr requires libpq from PostgreSQL 12,"
|
||||
f" {libpq_version} available instead"
|
||||
)
|
||||
|
||||
return _PQhostaddr(pgconn)
|
||||
|
||||
|
||||
PQport = pq.PQport
|
||||
PQport.argtypes = [PGconn_ptr]
|
||||
PQport.restype = c_char_p
|
||||
|
||||
PQtty = pq.PQtty
|
||||
PQtty.argtypes = [PGconn_ptr]
|
||||
PQtty.restype = c_char_p
|
||||
|
||||
PQoptions = pq.PQoptions
|
||||
PQoptions.argtypes = [PGconn_ptr]
|
||||
PQoptions.restype = c_char_p
|
||||
|
||||
PQstatus = pq.PQstatus
|
||||
PQstatus.argtypes = [PGconn_ptr]
|
||||
PQstatus.restype = c_int
|
||||
|
||||
PQtransactionStatus = pq.PQtransactionStatus
|
||||
PQtransactionStatus.argtypes = [PGconn_ptr]
|
||||
PQtransactionStatus.restype = c_int
|
||||
|
||||
PQparameterStatus = pq.PQparameterStatus
|
||||
PQparameterStatus.argtypes = [PGconn_ptr, c_char_p]
|
||||
PQparameterStatus.restype = c_char_p
|
||||
|
||||
PQprotocolVersion = pq.PQprotocolVersion
|
||||
PQprotocolVersion.argtypes = [PGconn_ptr]
|
||||
PQprotocolVersion.restype = c_int
|
||||
|
||||
PQserverVersion = pq.PQserverVersion
|
||||
PQserverVersion.argtypes = [PGconn_ptr]
|
||||
PQserverVersion.restype = c_int
|
||||
|
||||
PQerrorMessage = pq.PQerrorMessage
|
||||
PQerrorMessage.argtypes = [PGconn_ptr]
|
||||
PQerrorMessage.restype = c_char_p
|
||||
|
||||
PQsocket = pq.PQsocket
|
||||
PQsocket.argtypes = [PGconn_ptr]
|
||||
PQsocket.restype = c_int
|
||||
|
||||
PQbackendPID = pq.PQbackendPID
|
||||
PQbackendPID.argtypes = [PGconn_ptr]
|
||||
PQbackendPID.restype = c_int
|
||||
|
||||
PQconnectionNeedsPassword = pq.PQconnectionNeedsPassword
|
||||
PQconnectionNeedsPassword.argtypes = [PGconn_ptr]
|
||||
PQconnectionNeedsPassword.restype = c_int
|
||||
|
||||
PQconnectionUsedPassword = pq.PQconnectionUsedPassword
|
||||
PQconnectionUsedPassword.argtypes = [PGconn_ptr]
|
||||
PQconnectionUsedPassword.restype = c_int
|
||||
|
||||
PQsslInUse = pq.PQsslInUse
|
||||
PQsslInUse.argtypes = [PGconn_ptr]
|
||||
PQsslInUse.restype = c_int
|
||||
|
||||
# TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl
|
||||
|
||||
|
||||
# 33.3. Command Execution Functions
|
||||
|
||||
PQexec = pq.PQexec
|
||||
PQexec.argtypes = [PGconn_ptr, c_char_p]
|
||||
PQexec.restype = PGresult_ptr
|
||||
|
||||
PQexecParams = pq.PQexecParams
|
||||
PQexecParams.argtypes = [
|
||||
PGconn_ptr,
|
||||
c_char_p,
|
||||
c_int,
|
||||
POINTER(Oid),
|
||||
POINTER(c_char_p),
|
||||
POINTER(c_int),
|
||||
POINTER(c_int),
|
||||
c_int,
|
||||
]
|
||||
PQexecParams.restype = PGresult_ptr
|
||||
|
||||
PQprepare = pq.PQprepare
|
||||
PQprepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
|
||||
PQprepare.restype = PGresult_ptr
|
||||
|
||||
PQexecPrepared = pq.PQexecPrepared
|
||||
PQexecPrepared.argtypes = [
|
||||
PGconn_ptr,
|
||||
c_char_p,
|
||||
c_int,
|
||||
POINTER(c_char_p),
|
||||
POINTER(c_int),
|
||||
POINTER(c_int),
|
||||
c_int,
|
||||
]
|
||||
PQexecPrepared.restype = PGresult_ptr
|
||||
|
||||
PQdescribePrepared = pq.PQdescribePrepared
|
||||
PQdescribePrepared.argtypes = [PGconn_ptr, c_char_p]
|
||||
PQdescribePrepared.restype = PGresult_ptr
|
||||
|
||||
PQdescribePortal = pq.PQdescribePortal
|
||||
PQdescribePortal.argtypes = [PGconn_ptr, c_char_p]
|
||||
PQdescribePortal.restype = PGresult_ptr
|
||||
|
||||
PQresultStatus = pq.PQresultStatus
|
||||
PQresultStatus.argtypes = [PGresult_ptr]
|
||||
PQresultStatus.restype = c_int
|
||||
|
||||
# PQresStatus: not needed, we have pretty enums
|
||||
|
||||
PQresultErrorMessage = pq.PQresultErrorMessage
|
||||
PQresultErrorMessage.argtypes = [PGresult_ptr]
|
||||
PQresultErrorMessage.restype = c_char_p
|
||||
|
||||
# TODO: PQresultVerboseErrorMessage
|
||||
|
||||
PQresultErrorField = pq.PQresultErrorField
|
||||
PQresultErrorField.argtypes = [PGresult_ptr, c_int]
|
||||
PQresultErrorField.restype = c_char_p
|
||||
|
||||
PQclear = pq.PQclear
|
||||
PQclear.argtypes = [PGresult_ptr]
|
||||
PQclear.restype = None
|
||||
|
||||
|
||||
# 33.3.2. Retrieving Query Result Information
|
||||
|
||||
PQntuples = pq.PQntuples
|
||||
PQntuples.argtypes = [PGresult_ptr]
|
||||
PQntuples.restype = c_int
|
||||
|
||||
PQnfields = pq.PQnfields
|
||||
PQnfields.argtypes = [PGresult_ptr]
|
||||
PQnfields.restype = c_int
|
||||
|
||||
PQfname = pq.PQfname
|
||||
PQfname.argtypes = [PGresult_ptr, c_int]
|
||||
PQfname.restype = c_char_p
|
||||
|
||||
# PQfnumber: useless and hard to use
|
||||
|
||||
PQftable = pq.PQftable
|
||||
PQftable.argtypes = [PGresult_ptr, c_int]
|
||||
PQftable.restype = Oid
|
||||
|
||||
PQftablecol = pq.PQftablecol
|
||||
PQftablecol.argtypes = [PGresult_ptr, c_int]
|
||||
PQftablecol.restype = c_int
|
||||
|
||||
PQfformat = pq.PQfformat
|
||||
PQfformat.argtypes = [PGresult_ptr, c_int]
|
||||
PQfformat.restype = c_int
|
||||
|
||||
PQftype = pq.PQftype
|
||||
PQftype.argtypes = [PGresult_ptr, c_int]
|
||||
PQftype.restype = Oid
|
||||
|
||||
PQfmod = pq.PQfmod
|
||||
PQfmod.argtypes = [PGresult_ptr, c_int]
|
||||
PQfmod.restype = c_int
|
||||
|
||||
PQfsize = pq.PQfsize
|
||||
PQfsize.argtypes = [PGresult_ptr, c_int]
|
||||
PQfsize.restype = c_int
|
||||
|
||||
PQbinaryTuples = pq.PQbinaryTuples
|
||||
PQbinaryTuples.argtypes = [PGresult_ptr]
|
||||
PQbinaryTuples.restype = c_int
|
||||
|
||||
PQgetvalue = pq.PQgetvalue
|
||||
PQgetvalue.argtypes = [PGresult_ptr, c_int, c_int]
|
||||
PQgetvalue.restype = POINTER(c_char) # not a null-terminated string
|
||||
|
||||
PQgetisnull = pq.PQgetisnull
|
||||
PQgetisnull.argtypes = [PGresult_ptr, c_int, c_int]
|
||||
PQgetisnull.restype = c_int
|
||||
|
||||
PQgetlength = pq.PQgetlength
|
||||
PQgetlength.argtypes = [PGresult_ptr, c_int, c_int]
|
||||
PQgetlength.restype = c_int
|
||||
|
||||
PQnparams = pq.PQnparams
|
||||
PQnparams.argtypes = [PGresult_ptr]
|
||||
PQnparams.restype = c_int
|
||||
|
||||
PQparamtype = pq.PQparamtype
|
||||
PQparamtype.argtypes = [PGresult_ptr, c_int]
|
||||
PQparamtype.restype = Oid
|
||||
|
||||
# PQprint: pretty useless
|
||||
|
||||
# 33.3.3. Retrieving Other Result Information
|
||||
|
||||
PQcmdStatus = pq.PQcmdStatus
|
||||
PQcmdStatus.argtypes = [PGresult_ptr]
|
||||
PQcmdStatus.restype = c_char_p
|
||||
|
||||
PQcmdTuples = pq.PQcmdTuples
|
||||
PQcmdTuples.argtypes = [PGresult_ptr]
|
||||
PQcmdTuples.restype = c_char_p
|
||||
|
||||
PQoidValue = pq.PQoidValue
|
||||
PQoidValue.argtypes = [PGresult_ptr]
|
||||
PQoidValue.restype = Oid
|
||||
|
||||
|
||||
# 33.3.4. Escaping Strings for Inclusion in SQL Commands
|
||||
|
||||
PQescapeLiteral = pq.PQescapeLiteral
|
||||
PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
|
||||
PQescapeLiteral.restype = POINTER(c_char)
|
||||
|
||||
PQescapeIdentifier = pq.PQescapeIdentifier
|
||||
PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t]
|
||||
PQescapeIdentifier.restype = POINTER(c_char)
|
||||
|
||||
PQescapeStringConn = pq.PQescapeStringConn
|
||||
# TODO: raises "wrong type" error
|
||||
# PQescapeStringConn.argtypes = [
|
||||
# PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int)
|
||||
# ]
|
||||
PQescapeStringConn.restype = c_size_t
|
||||
|
||||
PQescapeString = pq.PQescapeString
|
||||
# TODO: raises "wrong type" error
|
||||
# PQescapeString.argtypes = [c_char_p, c_char_p, c_size_t]
|
||||
PQescapeString.restype = c_size_t
|
||||
|
||||
PQescapeByteaConn = pq.PQescapeByteaConn
|
||||
PQescapeByteaConn.argtypes = [
|
||||
PGconn_ptr,
|
||||
POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
|
||||
c_size_t,
|
||||
POINTER(c_size_t),
|
||||
]
|
||||
PQescapeByteaConn.restype = POINTER(c_ubyte)
|
||||
|
||||
PQescapeBytea = pq.PQescapeBytea
|
||||
PQescapeBytea.argtypes = [
|
||||
POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
|
||||
c_size_t,
|
||||
POINTER(c_size_t),
|
||||
]
|
||||
PQescapeBytea.restype = POINTER(c_ubyte)
|
||||
|
||||
|
||||
PQunescapeBytea = pq.PQunescapeBytea
|
||||
PQunescapeBytea.argtypes = [
|
||||
POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
|
||||
POINTER(c_size_t),
|
||||
]
|
||||
PQunescapeBytea.restype = POINTER(c_ubyte)
|
||||
|
||||
|
||||
# 33.4. Asynchronous Command Processing
|
||||
|
||||
PQsendQuery = pq.PQsendQuery
|
||||
PQsendQuery.argtypes = [PGconn_ptr, c_char_p]
|
||||
PQsendQuery.restype = c_int
|
||||
|
||||
PQsendQueryParams = pq.PQsendQueryParams
|
||||
PQsendQueryParams.argtypes = [
|
||||
PGconn_ptr,
|
||||
c_char_p,
|
||||
c_int,
|
||||
POINTER(Oid),
|
||||
POINTER(c_char_p),
|
||||
POINTER(c_int),
|
||||
POINTER(c_int),
|
||||
c_int,
|
||||
]
|
||||
PQsendQueryParams.restype = c_int
|
||||
|
||||
PQsendPrepare = pq.PQsendPrepare
|
||||
PQsendPrepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
|
||||
PQsendPrepare.restype = c_int
|
||||
|
||||
PQsendQueryPrepared = pq.PQsendQueryPrepared
|
||||
PQsendQueryPrepared.argtypes = [
|
||||
PGconn_ptr,
|
||||
c_char_p,
|
||||
c_int,
|
||||
POINTER(c_char_p),
|
||||
POINTER(c_int),
|
||||
POINTER(c_int),
|
||||
c_int,
|
||||
]
|
||||
PQsendQueryPrepared.restype = c_int
|
||||
|
||||
PQsendDescribePrepared = pq.PQsendDescribePrepared
|
||||
PQsendDescribePrepared.argtypes = [PGconn_ptr, c_char_p]
|
||||
PQsendDescribePrepared.restype = c_int
|
||||
|
||||
PQsendDescribePortal = pq.PQsendDescribePortal
|
||||
PQsendDescribePortal.argtypes = [PGconn_ptr, c_char_p]
|
||||
PQsendDescribePortal.restype = c_int
|
||||
|
||||
PQgetResult = pq.PQgetResult
|
||||
PQgetResult.argtypes = [PGconn_ptr]
|
||||
PQgetResult.restype = PGresult_ptr
|
||||
|
||||
PQconsumeInput = pq.PQconsumeInput
|
||||
PQconsumeInput.argtypes = [PGconn_ptr]
|
||||
PQconsumeInput.restype = c_int
|
||||
|
||||
PQisBusy = pq.PQisBusy
|
||||
PQisBusy.argtypes = [PGconn_ptr]
|
||||
PQisBusy.restype = c_int
|
||||
|
||||
PQsetnonblocking = pq.PQsetnonblocking
|
||||
PQsetnonblocking.argtypes = [PGconn_ptr, c_int]
|
||||
PQsetnonblocking.restype = c_int
|
||||
|
||||
PQisnonblocking = pq.PQisnonblocking
|
||||
PQisnonblocking.argtypes = [PGconn_ptr]
|
||||
PQisnonblocking.restype = c_int
|
||||
|
||||
PQflush = pq.PQflush
|
||||
PQflush.argtypes = [PGconn_ptr]
|
||||
PQflush.restype = c_int
|
||||
|
||||
|
||||
# 33.5. Retrieving Query Results Row-by-Row
|
||||
PQsetSingleRowMode = pq.PQsetSingleRowMode
|
||||
PQsetSingleRowMode.argtypes = [PGconn_ptr]
|
||||
PQsetSingleRowMode.restype = c_int
|
||||
|
||||
|
||||
# 33.6. Canceling Queries in Progress
|
||||
|
||||
PQgetCancel = pq.PQgetCancel
|
||||
PQgetCancel.argtypes = [PGconn_ptr]
|
||||
PQgetCancel.restype = PGcancel_ptr
|
||||
|
||||
PQfreeCancel = pq.PQfreeCancel
|
||||
PQfreeCancel.argtypes = [PGcancel_ptr]
|
||||
PQfreeCancel.restype = None
|
||||
|
||||
PQcancel = pq.PQcancel
|
||||
# TODO: raises "wrong type" error
|
||||
# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int]
|
||||
PQcancel.restype = c_int
|
||||
|
||||
|
||||
# 33.8. Asynchronous Notification
|
||||
|
||||
PQnotifies = pq.PQnotifies
|
||||
PQnotifies.argtypes = [PGconn_ptr]
|
||||
PQnotifies.restype = PGnotify_ptr
|
||||
|
||||
|
||||
# 33.9. Functions Associated with the COPY Command
|
||||
|
||||
PQputCopyData = pq.PQputCopyData
|
||||
PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int]
|
||||
PQputCopyData.restype = c_int
|
||||
|
||||
PQputCopyEnd = pq.PQputCopyEnd
|
||||
PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p]
|
||||
PQputCopyEnd.restype = c_int
|
||||
|
||||
PQgetCopyData = pq.PQgetCopyData
|
||||
PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int]
|
||||
PQgetCopyData.restype = c_int
|
||||
|
||||
|
||||
# 33.10. Control Functions
|
||||
|
||||
PQtrace = pq.PQtrace
|
||||
PQtrace.argtypes = [PGconn_ptr, FILE_ptr]
|
||||
PQtrace.restype = None
|
||||
|
||||
_PQsetTraceFlags = None
|
||||
|
||||
if libpq_version >= 140000:
|
||||
_PQsetTraceFlags = pq.PQsetTraceFlags
|
||||
_PQsetTraceFlags.argtypes = [PGconn_ptr, c_int]
|
||||
_PQsetTraceFlags.restype = None
|
||||
|
||||
|
||||
def PQsetTraceFlags(pgconn: PGconn_struct, flags: int) -> None:
|
||||
if not _PQsetTraceFlags:
|
||||
raise NotSupportedError(
|
||||
"PQsetTraceFlags requires libpq from PostgreSQL 14,"
|
||||
f" {libpq_version} available instead"
|
||||
)
|
||||
|
||||
_PQsetTraceFlags(pgconn, flags)
|
||||
|
||||
|
||||
PQuntrace = pq.PQuntrace
|
||||
PQuntrace.argtypes = [PGconn_ptr]
|
||||
PQuntrace.restype = None
|
||||
|
||||
# 33.11. Miscellaneous Functions
|
||||
|
||||
PQfreemem = pq.PQfreemem
|
||||
PQfreemem.argtypes = [c_void_p]
|
||||
PQfreemem.restype = None
|
||||
|
||||
if libpq_version >= 100000:
|
||||
_PQencryptPasswordConn = pq.PQencryptPasswordConn
|
||||
_PQencryptPasswordConn.argtypes = [
|
||||
PGconn_ptr,
|
||||
c_char_p,
|
||||
c_char_p,
|
||||
c_char_p,
|
||||
]
|
||||
_PQencryptPasswordConn.restype = POINTER(c_char)
|
||||
|
||||
|
||||
def PQencryptPasswordConn(
|
||||
pgconn: PGconn_struct, passwd: bytes, user: bytes, algorithm: bytes
|
||||
) -> Optional[bytes]:
|
||||
if not _PQencryptPasswordConn:
|
||||
raise NotSupportedError(
|
||||
"PQencryptPasswordConn requires libpq from PostgreSQL 10,"
|
||||
f" {libpq_version} available instead"
|
||||
)
|
||||
|
||||
return _PQencryptPasswordConn(pgconn, passwd, user, algorithm)
|
||||
|
||||
|
||||
PQmakeEmptyPGresult = pq.PQmakeEmptyPGresult
|
||||
PQmakeEmptyPGresult.argtypes = [PGconn_ptr, c_int]
|
||||
PQmakeEmptyPGresult.restype = PGresult_ptr
|
||||
|
||||
PQsetResultAttrs = pq.PQsetResultAttrs
|
||||
PQsetResultAttrs.argtypes = [PGresult_ptr, c_int, PGresAttDesc_ptr]
|
||||
PQsetResultAttrs.restype = c_int
|
||||
|
||||
|
||||
# 33.12. Notice Processing
|
||||
|
||||
PQnoticeReceiver = CFUNCTYPE(None, c_void_p, PGresult_ptr)
|
||||
|
||||
PQsetNoticeReceiver = pq.PQsetNoticeReceiver
|
||||
PQsetNoticeReceiver.argtypes = [PGconn_ptr, PQnoticeReceiver, c_void_p]
|
||||
PQsetNoticeReceiver.restype = PQnoticeReceiver
|
||||
|
||||
# 34.5 Pipeline Mode
|
||||
|
||||
_PQpipelineStatus = None
|
||||
_PQenterPipelineMode = None
|
||||
_PQexitPipelineMode = None
|
||||
_PQpipelineSync = None
|
||||
_PQsendFlushRequest = None
|
||||
|
||||
if libpq_version >= 140000:
|
||||
_PQpipelineStatus = pq.PQpipelineStatus
|
||||
_PQpipelineStatus.argtypes = [PGconn_ptr]
|
||||
_PQpipelineStatus.restype = c_int
|
||||
|
||||
_PQenterPipelineMode = pq.PQenterPipelineMode
|
||||
_PQenterPipelineMode.argtypes = [PGconn_ptr]
|
||||
_PQenterPipelineMode.restype = c_int
|
||||
|
||||
_PQexitPipelineMode = pq.PQexitPipelineMode
|
||||
_PQexitPipelineMode.argtypes = [PGconn_ptr]
|
||||
_PQexitPipelineMode.restype = c_int
|
||||
|
||||
_PQpipelineSync = pq.PQpipelineSync
|
||||
_PQpipelineSync.argtypes = [PGconn_ptr]
|
||||
_PQpipelineSync.restype = c_int
|
||||
|
||||
_PQsendFlushRequest = pq.PQsendFlushRequest
|
||||
_PQsendFlushRequest.argtypes = [PGconn_ptr]
|
||||
_PQsendFlushRequest.restype = c_int
|
||||
|
||||
|
||||
def PQpipelineStatus(pgconn: PGconn_struct) -> int:
|
||||
if not _PQpipelineStatus:
|
||||
raise NotSupportedError(
|
||||
"PQpipelineStatus requires libpq from PostgreSQL 14,"
|
||||
f" {libpq_version} available instead"
|
||||
)
|
||||
return _PQpipelineStatus(pgconn)
|
||||
|
||||
|
||||
def PQenterPipelineMode(pgconn: PGconn_struct) -> int:
|
||||
if not _PQenterPipelineMode:
|
||||
raise NotSupportedError(
|
||||
"PQenterPipelineMode requires libpq from PostgreSQL 14,"
|
||||
f" {libpq_version} available instead"
|
||||
)
|
||||
return _PQenterPipelineMode(pgconn)
|
||||
|
||||
|
||||
def PQexitPipelineMode(pgconn: PGconn_struct) -> int:
|
||||
if not _PQexitPipelineMode:
|
||||
raise NotSupportedError(
|
||||
"PQexitPipelineMode requires libpq from PostgreSQL 14,"
|
||||
f" {libpq_version} available instead"
|
||||
)
|
||||
return _PQexitPipelineMode(pgconn)
|
||||
|
||||
|
||||
def PQpipelineSync(pgconn: PGconn_struct) -> int:
|
||||
if not _PQpipelineSync:
|
||||
raise NotSupportedError(
|
||||
"PQpipelineSync requires libpq from PostgreSQL 14,"
|
||||
f" {libpq_version} available instead"
|
||||
)
|
||||
return _PQpipelineSync(pgconn)
|
||||
|
||||
|
||||
def PQsendFlushRequest(pgconn: PGconn_struct) -> int:
|
||||
if not _PQsendFlushRequest:
|
||||
raise NotSupportedError(
|
||||
"PQsendFlushRequest requires libpq from PostgreSQL 14,"
|
||||
f" {libpq_version} available instead"
|
||||
)
|
||||
return _PQsendFlushRequest(pgconn)
|
||||
|
||||
|
||||
# 33.18. SSL Support
|
||||
|
||||
PQinitOpenSSL = pq.PQinitOpenSSL
|
||||
PQinitOpenSSL.argtypes = [c_int, c_int]
|
||||
PQinitOpenSSL.restype = None
|
||||
|
||||
|
||||
def generate_stub() -> None:
|
||||
import re
|
||||
from ctypes import _CFuncPtr # type: ignore
|
||||
|
||||
def type2str(fname, narg, t):
|
||||
if t is None:
|
||||
return "None"
|
||||
elif t is c_void_p:
|
||||
return "Any"
|
||||
elif t is c_int or t is c_uint or t is c_size_t:
|
||||
return "int"
|
||||
elif t is c_char_p or t.__name__ == "LP_c_char":
|
||||
if narg is not None:
|
||||
return "bytes"
|
||||
else:
|
||||
return "Optional[bytes]"
|
||||
|
||||
elif t.__name__ in (
|
||||
"LP_PGconn_struct",
|
||||
"LP_PGresult_struct",
|
||||
"LP_PGcancel_struct",
|
||||
):
|
||||
if narg is not None:
|
||||
return f"Optional[{t.__name__[3:]}]"
|
||||
else:
|
||||
return t.__name__[3:]
|
||||
|
||||
elif t.__name__ in ("LP_PQconninfoOption_struct",):
|
||||
return f"Sequence[{t.__name__[3:]}]"
|
||||
|
||||
elif t.__name__ in (
|
||||
"LP_c_ubyte",
|
||||
"LP_c_char_p",
|
||||
"LP_c_int",
|
||||
"LP_c_uint",
|
||||
"LP_c_ulong",
|
||||
"LP_FILE",
|
||||
):
|
||||
return f"_Pointer[{t.__name__[3:]}]"
|
||||
|
||||
else:
|
||||
assert False, f"can't deal with {t} in {fname}"
|
||||
|
||||
fn = __file__ + "i"
|
||||
with open(fn) as f:
|
||||
lines = f.read().splitlines()
|
||||
|
||||
istart, iend = (
|
||||
i
|
||||
for i, line in enumerate(lines)
|
||||
if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line)
|
||||
)
|
||||
|
||||
known = {
|
||||
line[4:].split("(", 1)[0] for line in lines[:istart] if line.startswith("def ")
|
||||
}
|
||||
|
||||
signatures = []
|
||||
|
||||
for name, obj in globals().items():
|
||||
if name in known:
|
||||
continue
|
||||
if not isinstance(obj, _CFuncPtr):
|
||||
continue
|
||||
|
||||
params = []
|
||||
for i, t in enumerate(obj.argtypes):
|
||||
params.append(f"arg{i + 1}: {type2str(name, i, t)}")
|
||||
|
||||
resname = type2str(name, None, obj.restype)
|
||||
|
||||
signatures.append(f"def {name}({', '.join(params)}) -> {resname}: ...")
|
||||
|
||||
lines[istart + 1 : iend] = signatures
|
||||
|
||||
with open(fn, "w") as f:
|
||||
f.write("\n".join(lines))
|
||||
f.write("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_stub()
|
||||
384
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/abc.py
Normal file
384
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/abc.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Protocol objects to represent objects exposed by different pq implementations.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple
|
||||
from typing import Union, TYPE_CHECKING
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ._enums import Format, Trace
|
||||
from .._compat import Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .misc import PGnotify, ConninfoOption, PGresAttDesc
|
||||
|
||||
# An object implementing the buffer protocol (ish)
|
||||
Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
|
||||
|
||||
|
||||
class PGconn(Protocol):
|
||||
notice_handler: Optional[Callable[["PGresult"], None]]
|
||||
notify_handler: Optional[Callable[["PGnotify"], None]]
|
||||
|
||||
@classmethod
|
||||
def connect(cls, conninfo: bytes) -> "PGconn":
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def connect_start(cls, conninfo: bytes) -> "PGconn":
|
||||
...
|
||||
|
||||
def connect_poll(self) -> int:
|
||||
...
|
||||
|
||||
def finish(self) -> None:
|
||||
...
|
||||
|
||||
@property
|
||||
def info(self) -> List["ConninfoOption"]:
|
||||
...
|
||||
|
||||
def reset(self) -> None:
|
||||
...
|
||||
|
||||
def reset_start(self) -> None:
|
||||
...
|
||||
|
||||
def reset_poll(self) -> int:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def ping(self, conninfo: bytes) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
def db(self) -> bytes:
|
||||
...
|
||||
|
||||
@property
|
||||
def user(self) -> bytes:
|
||||
...
|
||||
|
||||
@property
|
||||
def password(self) -> bytes:
|
||||
...
|
||||
|
||||
@property
|
||||
def host(self) -> bytes:
|
||||
...
|
||||
|
||||
@property
|
||||
def hostaddr(self) -> bytes:
|
||||
...
|
||||
|
||||
@property
|
||||
def port(self) -> bytes:
|
||||
...
|
||||
|
||||
@property
|
||||
def tty(self) -> bytes:
|
||||
...
|
||||
|
||||
@property
|
||||
def options(self) -> bytes:
|
||||
...
|
||||
|
||||
@property
|
||||
def status(self) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
def transaction_status(self) -> int:
|
||||
...
|
||||
|
||||
def parameter_status(self, name: bytes) -> Optional[bytes]:
|
||||
...
|
||||
|
||||
@property
|
||||
def error_message(self) -> bytes:
|
||||
...
|
||||
|
||||
@property
|
||||
def server_version(self) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
def socket(self) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
def backend_pid(self) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
def needs_password(self) -> bool:
|
||||
...
|
||||
|
||||
@property
|
||||
def used_password(self) -> bool:
|
||||
...
|
||||
|
||||
@property
|
||||
def ssl_in_use(self) -> bool:
|
||||
...
|
||||
|
||||
def exec_(self, command: bytes) -> "PGresult":
|
||||
...
|
||||
|
||||
def send_query(self, command: bytes) -> None:
|
||||
...
|
||||
|
||||
def exec_params(
|
||||
self,
|
||||
command: bytes,
|
||||
param_values: Optional[Sequence[Optional[Buffer]]],
|
||||
param_types: Optional[Sequence[int]] = None,
|
||||
param_formats: Optional[Sequence[int]] = None,
|
||||
result_format: int = Format.TEXT,
|
||||
) -> "PGresult":
|
||||
...
|
||||
|
||||
def send_query_params(
|
||||
self,
|
||||
command: bytes,
|
||||
param_values: Optional[Sequence[Optional[Buffer]]],
|
||||
param_types: Optional[Sequence[int]] = None,
|
||||
param_formats: Optional[Sequence[int]] = None,
|
||||
result_format: int = Format.TEXT,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def send_prepare(
|
||||
self,
|
||||
name: bytes,
|
||||
command: bytes,
|
||||
param_types: Optional[Sequence[int]] = None,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def send_query_prepared(
|
||||
self,
|
||||
name: bytes,
|
||||
param_values: Optional[Sequence[Optional[Buffer]]],
|
||||
param_formats: Optional[Sequence[int]] = None,
|
||||
result_format: int = Format.TEXT,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
name: bytes,
|
||||
command: bytes,
|
||||
param_types: Optional[Sequence[int]] = None,
|
||||
) -> "PGresult":
|
||||
...
|
||||
|
||||
def exec_prepared(
|
||||
self,
|
||||
name: bytes,
|
||||
param_values: Optional[Sequence[Buffer]],
|
||||
param_formats: Optional[Sequence[int]] = None,
|
||||
result_format: int = 0,
|
||||
) -> "PGresult":
|
||||
...
|
||||
|
||||
def describe_prepared(self, name: bytes) -> "PGresult":
|
||||
...
|
||||
|
||||
def send_describe_prepared(self, name: bytes) -> None:
|
||||
...
|
||||
|
||||
def describe_portal(self, name: bytes) -> "PGresult":
|
||||
...
|
||||
|
||||
def send_describe_portal(self, name: bytes) -> None:
|
||||
...
|
||||
|
||||
def get_result(self) -> Optional["PGresult"]:
|
||||
...
|
||||
|
||||
def consume_input(self) -> None:
|
||||
...
|
||||
|
||||
def is_busy(self) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
def nonblocking(self) -> int:
|
||||
...
|
||||
|
||||
@nonblocking.setter
|
||||
def nonblocking(self, arg: int) -> None:
|
||||
...
|
||||
|
||||
def flush(self) -> int:
|
||||
...
|
||||
|
||||
def set_single_row_mode(self) -> None:
|
||||
...
|
||||
|
||||
def get_cancel(self) -> "PGcancel":
|
||||
...
|
||||
|
||||
def notifies(self) -> Optional["PGnotify"]:
|
||||
...
|
||||
|
||||
def put_copy_data(self, buffer: Buffer) -> int:
|
||||
...
|
||||
|
||||
def put_copy_end(self, error: Optional[bytes] = None) -> int:
|
||||
...
|
||||
|
||||
def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
|
||||
...
|
||||
|
||||
def trace(self, fileno: int) -> None:
|
||||
...
|
||||
|
||||
def set_trace_flags(self, flags: Trace) -> None:
|
||||
...
|
||||
|
||||
def untrace(self) -> None:
|
||||
...
|
||||
|
||||
def encrypt_password(
|
||||
self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
|
||||
) -> bytes:
|
||||
...
|
||||
|
||||
def make_empty_result(self, exec_status: int) -> "PGresult":
|
||||
...
|
||||
|
||||
@property
|
||||
def pipeline_status(self) -> int:
|
||||
...
|
||||
|
||||
def enter_pipeline_mode(self) -> None:
|
||||
...
|
||||
|
||||
def exit_pipeline_mode(self) -> None:
|
||||
...
|
||||
|
||||
def pipeline_sync(self) -> None:
|
||||
...
|
||||
|
||||
def send_flush_request(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
class PGresult(Protocol):
|
||||
def clear(self) -> None:
|
||||
...
|
||||
|
||||
@property
|
||||
def status(self) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
def error_message(self) -> bytes:
|
||||
...
|
||||
|
||||
def error_field(self, fieldcode: int) -> Optional[bytes]:
|
||||
...
|
||||
|
||||
@property
|
||||
def ntuples(self) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
def nfields(self) -> int:
|
||||
...
|
||||
|
||||
def fname(self, column_number: int) -> Optional[bytes]:
|
||||
...
|
||||
|
||||
def ftable(self, column_number: int) -> int:
|
||||
...
|
||||
|
||||
def ftablecol(self, column_number: int) -> int:
|
||||
...
|
||||
|
||||
def fformat(self, column_number: int) -> int:
|
||||
...
|
||||
|
||||
def ftype(self, column_number: int) -> int:
|
||||
...
|
||||
|
||||
def fmod(self, column_number: int) -> int:
|
||||
...
|
||||
|
||||
def fsize(self, column_number: int) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
def binary_tuples(self) -> int:
|
||||
...
|
||||
|
||||
def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
|
||||
...
|
||||
|
||||
@property
|
||||
def nparams(self) -> int:
|
||||
...
|
||||
|
||||
def param_type(self, param_number: int) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
def command_status(self) -> Optional[bytes]:
|
||||
...
|
||||
|
||||
@property
|
||||
def command_tuples(self) -> Optional[int]:
|
||||
...
|
||||
|
||||
@property
|
||||
def oid_value(self) -> int:
|
||||
...
|
||||
|
||||
def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None:
|
||||
...
|
||||
|
||||
|
||||
class PGcancel(Protocol):
|
||||
def free(self) -> None:
|
||||
...
|
||||
|
||||
def cancel(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
class Conninfo(Protocol):
|
||||
@classmethod
|
||||
def get_defaults(cls) -> List["ConninfoOption"]:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def parse(cls, conninfo: bytes) -> List["ConninfoOption"]:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]:
|
||||
...
|
||||
|
||||
|
||||
class Escaping(Protocol):
|
||||
def __init__(self, conn: Optional[PGconn] = None):
|
||||
...
|
||||
|
||||
def escape_literal(self, data: Buffer) -> bytes:
|
||||
...
|
||||
|
||||
def escape_identifier(self, data: Buffer) -> bytes:
|
||||
...
|
||||
|
||||
def escape_string(self, data: Buffer) -> bytes:
|
||||
...
|
||||
|
||||
def escape_bytea(self, data: Buffer) -> bytes:
|
||||
...
|
||||
|
||||
def unescape_bytea(self, data: Buffer) -> bytes:
|
||||
...
|
||||
146
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/misc.py
Normal file
146
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/misc.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Various functionalities to make easier to work with the libpq.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import ctypes.util
|
||||
from typing import cast, NamedTuple, Optional, Union
|
||||
|
||||
from .abc import PGconn, PGresult
|
||||
from ._enums import ConnStatus, TransactionStatus, PipelineStatus
|
||||
from .._compat import cache
|
||||
from .._encodings import pgconn_encoding
|
||||
|
||||
logger = logging.getLogger("psycopg.pq")
|
||||
|
||||
OK = ConnStatus.OK
|
||||
|
||||
|
||||
class PGnotify(NamedTuple):
|
||||
relname: bytes
|
||||
be_pid: int
|
||||
extra: bytes
|
||||
|
||||
|
||||
class ConninfoOption(NamedTuple):
|
||||
keyword: bytes
|
||||
envvar: Optional[bytes]
|
||||
compiled: Optional[bytes]
|
||||
val: Optional[bytes]
|
||||
label: bytes
|
||||
dispchar: bytes
|
||||
dispsize: int
|
||||
|
||||
|
||||
class PGresAttDesc(NamedTuple):
|
||||
name: bytes
|
||||
tableid: int
|
||||
columnid: int
|
||||
format: int
|
||||
typid: int
|
||||
typlen: int
|
||||
atttypmod: int
|
||||
|
||||
|
||||
@cache
|
||||
def find_libpq_full_path() -> Optional[str]:
|
||||
if sys.platform == "win32":
|
||||
libname = ctypes.util.find_library("libpq.dll")
|
||||
|
||||
elif sys.platform == "darwin":
|
||||
libname = ctypes.util.find_library("libpq.dylib")
|
||||
# (hopefully) temporary hack: libpq not in a standard place
|
||||
# https://github.com/orgs/Homebrew/discussions/3595
|
||||
# If pg_config is available and agrees, let's use its indications.
|
||||
if not libname:
|
||||
try:
|
||||
import subprocess as sp
|
||||
|
||||
libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode()
|
||||
libname = os.path.join(libdir, "libpq.dylib")
|
||||
if not os.path.exists(libname):
|
||||
libname = None
|
||||
except Exception as ex:
|
||||
logger.debug("couldn't use pg_config to find libpq: %s", ex)
|
||||
|
||||
else:
|
||||
libname = ctypes.util.find_library("pq")
|
||||
|
||||
return libname
|
||||
|
||||
|
||||
def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str:
|
||||
"""
|
||||
Return an error message from a `PGconn` or `PGresult`.
|
||||
|
||||
The return value is a `!str` (unlike pq data which is usually `!bytes`):
|
||||
use the connection encoding if available, otherwise the `!encoding`
|
||||
parameter as a fallback for decoding. Don't raise exceptions on decoding
|
||||
errors.
|
||||
|
||||
"""
|
||||
bmsg: bytes
|
||||
|
||||
if hasattr(obj, "error_field"):
|
||||
# obj is a PGresult
|
||||
obj = cast(PGresult, obj)
|
||||
bmsg = obj.error_message
|
||||
|
||||
# strip severity and whitespaces
|
||||
if bmsg:
|
||||
bmsg = bmsg.split(b":", 1)[-1].strip()
|
||||
|
||||
elif hasattr(obj, "error_message"):
|
||||
# obj is a PGconn
|
||||
if obj.status == OK:
|
||||
encoding = pgconn_encoding(obj)
|
||||
bmsg = obj.error_message
|
||||
|
||||
# strip severity and whitespaces
|
||||
if bmsg:
|
||||
bmsg = bmsg.split(b":", 1)[-1].strip()
|
||||
|
||||
else:
|
||||
raise TypeError(f"PGconn or PGresult expected, got {type(obj).__name__}")
|
||||
|
||||
if bmsg:
|
||||
msg = bmsg.decode(encoding, "replace")
|
||||
else:
|
||||
msg = "no details available"
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def connection_summary(pgconn: PGconn) -> str:
|
||||
"""
|
||||
Return summary information on a connection.
|
||||
|
||||
Useful for __repr__
|
||||
"""
|
||||
parts = []
|
||||
if pgconn.status == OK:
|
||||
# Put together the [STATUS]
|
||||
status = TransactionStatus(pgconn.transaction_status).name
|
||||
if pgconn.pipeline_status:
|
||||
status += f", pipeline={PipelineStatus(pgconn.pipeline_status).name}"
|
||||
|
||||
# Put together the (CONNECTION)
|
||||
if not pgconn.host.startswith(b"/"):
|
||||
parts.append(("host", pgconn.host.decode()))
|
||||
if pgconn.port != b"5432":
|
||||
parts.append(("port", pgconn.port.decode()))
|
||||
if pgconn.user != pgconn.db:
|
||||
parts.append(("user", pgconn.user.decode()))
|
||||
parts.append(("database", pgconn.db.decode()))
|
||||
|
||||
else:
|
||||
status = ConnStatus(pgconn.status).name
|
||||
|
||||
sparts = " ".join("%s=%s" % part for part in parts)
|
||||
if sparts:
|
||||
sparts = f" ({sparts})"
|
||||
return f"[{status}]{sparts}"
|
||||
1089
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/pq_ctypes.py
Normal file
1089
srcs/.venv/lib/python3.11/site-packages/psycopg/pq/pq_ctypes.py
Normal file
File diff suppressed because it is too large
Load Diff
255
srcs/.venv/lib/python3.11/site-packages/psycopg/rows.py
Normal file
255
srcs/.venv/lib/python3.11/site-packages/psycopg/rows.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""
|
||||
psycopg row factories
|
||||
"""
|
||||
|
||||
# Copyright (C) 2021 The Psycopg Team
|
||||
|
||||
import functools
|
||||
from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn
|
||||
from typing import TYPE_CHECKING, Sequence, Tuple, Type, TypeVar
|
||||
from collections import namedtuple
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from . import pq
|
||||
from . import errors as e
|
||||
from ._compat import Protocol
|
||||
from ._encodings import _as_python_identifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .cursor import BaseCursor, Cursor
|
||||
from .cursor_async import AsyncCursor
|
||||
from psycopg.pq.abc import PGresult
|
||||
|
||||
COMMAND_OK = pq.ExecStatus.COMMAND_OK
|
||||
TUPLES_OK = pq.ExecStatus.TUPLES_OK
|
||||
SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
|
||||
|
||||
T = TypeVar("T", covariant=True)
|
||||
|
||||
# Row factories
|
||||
|
||||
Row = TypeVar("Row", covariant=True)
|
||||
|
||||
|
||||
class RowMaker(Protocol[Row]):
|
||||
"""
|
||||
Callable protocol taking a sequence of value and returning an object.
|
||||
|
||||
The sequence of value is what is returned from a database query, already
|
||||
adapted to the right Python types. The return value is the object that your
|
||||
program would like to receive: by default (`tuple_row()`) it is a simple
|
||||
tuple, but it may be any type of object.
|
||||
|
||||
Typically, `!RowMaker` functions are returned by `RowFactory`.
|
||||
"""
|
||||
|
||||
def __call__(self, __values: Sequence[Any]) -> Row:
|
||||
...
|
||||
|
||||
|
||||
class RowFactory(Protocol[Row]):
|
||||
"""
|
||||
Callable protocol taking a `~psycopg.Cursor` and returning a `RowMaker`.
|
||||
|
||||
A `!RowFactory` is typically called when a `!Cursor` receives a result.
|
||||
This way it can inspect the cursor state (for instance the
|
||||
`~psycopg.Cursor.description` attribute) and help a `!RowMaker` to create
|
||||
a complete object.
|
||||
|
||||
For instance the `dict_row()` `!RowFactory` uses the names of the column to
|
||||
define the dictionary key and returns a `!RowMaker` function which would
|
||||
use the values to create a dictionary for each record.
|
||||
"""
|
||||
|
||||
def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]:
|
||||
...
|
||||
|
||||
|
||||
class AsyncRowFactory(Protocol[Row]):
|
||||
"""
|
||||
Like `RowFactory`, taking an async cursor as argument.
|
||||
"""
|
||||
|
||||
def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]:
|
||||
...
|
||||
|
||||
|
||||
class BaseRowFactory(Protocol[Row]):
|
||||
"""
|
||||
Like `RowFactory`, taking either type of cursor as argument.
|
||||
"""
|
||||
|
||||
def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]:
|
||||
...
|
||||
|
||||
|
||||
TupleRow: TypeAlias = Tuple[Any, ...]
|
||||
"""
|
||||
An alias for the type returned by `tuple_row()` (i.e. a tuple of any content).
|
||||
"""
|
||||
|
||||
|
||||
DictRow: TypeAlias = Dict[str, Any]
|
||||
"""
|
||||
An alias for the type returned by `dict_row()`
|
||||
|
||||
A `!DictRow` is a dictionary with keys as string and any value returned by the
|
||||
database.
|
||||
"""
|
||||
|
||||
|
||||
def tuple_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[TupleRow]":
|
||||
r"""Row factory to represent rows as simple tuples.
|
||||
|
||||
This is the default factory, used when `~psycopg.Connection.connect()` or
|
||||
`~psycopg.Connection.cursor()` are called without a `!row_factory`
|
||||
parameter.
|
||||
|
||||
"""
|
||||
# Implementation detail: make sure this is the tuple type itself, not an
|
||||
# equivalent function, because the C code fast-paths on it.
|
||||
return tuple
|
||||
|
||||
|
||||
def dict_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[DictRow]":
|
||||
"""Row factory to represent rows as dictionaries.
|
||||
|
||||
The dictionary keys are taken from the column names of the returned columns.
|
||||
"""
|
||||
names = _get_names(cursor)
|
||||
if names is None:
|
||||
return no_result
|
||||
|
||||
def dict_row_(values: Sequence[Any]) -> Dict[str, Any]:
|
||||
return dict(zip(names, values))
|
||||
|
||||
return dict_row_
|
||||
|
||||
|
||||
def namedtuple_row(
|
||||
cursor: "BaseCursor[Any, Any]",
|
||||
) -> "RowMaker[NamedTuple]":
|
||||
"""Row factory to represent rows as `~collections.namedtuple`.
|
||||
|
||||
The field names are taken from the column names of the returned columns,
|
||||
with some mangling to deal with invalid names.
|
||||
"""
|
||||
res = cursor.pgresult
|
||||
if not res:
|
||||
return no_result
|
||||
|
||||
nfields = _get_nfields(res)
|
||||
if nfields is None:
|
||||
return no_result
|
||||
|
||||
nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields)))
|
||||
return nt._make
|
||||
|
||||
|
||||
@functools.lru_cache(512)
|
||||
def _make_nt(enc: str, *names: bytes) -> Type[NamedTuple]:
|
||||
snames = tuple(_as_python_identifier(n.decode(enc)) for n in names)
|
||||
return namedtuple("Row", snames) # type: ignore[return-value]
|
||||
|
||||
|
||||
def class_row(cls: Type[T]) -> BaseRowFactory[T]:
|
||||
r"""Generate a row factory to represent rows as instances of the class `!cls`.
|
||||
|
||||
The class must support every output column name as a keyword parameter.
|
||||
|
||||
:param cls: The class to return for each row. It must support the fields
|
||||
returned by the query as keyword arguments.
|
||||
:rtype: `!Callable[[Cursor],` `RowMaker`\[~T]]
|
||||
"""
|
||||
|
||||
def class_row_(cursor: "BaseCursor[Any, Any]") -> "RowMaker[T]":
|
||||
names = _get_names(cursor)
|
||||
if names is None:
|
||||
return no_result
|
||||
|
||||
def class_row__(values: Sequence[Any]) -> T:
|
||||
return cls(**dict(zip(names, values)))
|
||||
|
||||
return class_row__
|
||||
|
||||
return class_row_
|
||||
|
||||
|
||||
def args_row(func: Callable[..., T]) -> BaseRowFactory[T]:
|
||||
"""Generate a row factory calling `!func` with positional parameters for every row.
|
||||
|
||||
:param func: The function to call for each row. It must support the fields
|
||||
returned by the query as positional arguments.
|
||||
"""
|
||||
|
||||
def args_row_(cur: "BaseCursor[Any, T]") -> "RowMaker[T]":
|
||||
def args_row__(values: Sequence[Any]) -> T:
|
||||
return func(*values)
|
||||
|
||||
return args_row__
|
||||
|
||||
return args_row_
|
||||
|
||||
|
||||
def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]:
|
||||
"""Generate a row factory calling `!func` with keyword parameters for every row.
|
||||
|
||||
:param func: The function to call for each row. It must support the fields
|
||||
returned by the query as keyword arguments.
|
||||
"""
|
||||
|
||||
def kwargs_row_(cursor: "BaseCursor[Any, T]") -> "RowMaker[T]":
|
||||
names = _get_names(cursor)
|
||||
if names is None:
|
||||
return no_result
|
||||
|
||||
def kwargs_row__(values: Sequence[Any]) -> T:
|
||||
return func(**dict(zip(names, values)))
|
||||
|
||||
return kwargs_row__
|
||||
|
||||
return kwargs_row_
|
||||
|
||||
|
||||
def no_result(values: Sequence[Any]) -> NoReturn:
|
||||
"""A `RowMaker` that always fail.
|
||||
|
||||
It can be used as return value for a `RowFactory` called with no result.
|
||||
Note that the `!RowFactory` *will* be called with no result, but the
|
||||
resulting `!RowMaker` never should.
|
||||
"""
|
||||
raise e.InterfaceError("the cursor doesn't have a result")
|
||||
|
||||
|
||||
def _get_names(cursor: "BaseCursor[Any, Any]") -> Optional[List[str]]:
|
||||
res = cursor.pgresult
|
||||
if not res:
|
||||
return None
|
||||
|
||||
nfields = _get_nfields(res)
|
||||
if nfields is None:
|
||||
return None
|
||||
|
||||
enc = cursor._encoding
|
||||
return [
|
||||
res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr]
|
||||
]
|
||||
|
||||
|
||||
def _get_nfields(res: "PGresult") -> Optional[int]:
|
||||
"""
|
||||
Return the number of columns in a result, if it returns tuples else None
|
||||
|
||||
Take into account the special case of results with zero columns.
|
||||
"""
|
||||
nfields = res.nfields
|
||||
|
||||
if (
|
||||
res.status == TUPLES_OK
|
||||
or res.status == SINGLE_TUPLE
|
||||
# "describe" in named cursors
|
||||
or (res.status == COMMAND_OK and nfields)
|
||||
):
|
||||
return nfields
|
||||
else:
|
||||
return None
|
||||
478
srcs/.venv/lib/python3.11/site-packages/psycopg/server_cursor.py
Normal file
478
srcs/.venv/lib/python3.11/site-packages/psycopg/server_cursor.py
Normal file
@@ -0,0 +1,478 @@
|
||||
"""
|
||||
psycopg server-side cursor objects.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from typing import Any, AsyncIterator, List, Iterable, Iterator
|
||||
from typing import Optional, TypeVar, TYPE_CHECKING, overload
|
||||
from warnings import warn
|
||||
|
||||
from . import pq
|
||||
from . import sql
|
||||
from . import errors as e
|
||||
from .abc import ConnectionType, Query, Params, PQGen
|
||||
from .rows import Row, RowFactory, AsyncRowFactory
|
||||
from .cursor import BaseCursor, Cursor
|
||||
from .generators import execute
|
||||
from .cursor_async import AsyncCursor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .connection import Connection
|
||||
from .connection_async import AsyncConnection
|
||||
|
||||
DEFAULT_ITERSIZE = 100
|
||||
|
||||
TEXT = pq.Format.TEXT
|
||||
BINARY = pq.Format.BINARY
|
||||
|
||||
COMMAND_OK = pq.ExecStatus.COMMAND_OK
|
||||
TUPLES_OK = pq.ExecStatus.TUPLES_OK
|
||||
|
||||
IDLE = pq.TransactionStatus.IDLE
|
||||
INTRANS = pq.TransactionStatus.INTRANS
|
||||
|
||||
|
||||
class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
|
||||
"""Mixin to add ServerCursor behaviour and implementation a BaseCursor."""
|
||||
|
||||
__slots__ = "_name _scrollable _withhold _described itersize _format".split()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
scrollable: Optional[bool],
|
||||
withhold: bool,
|
||||
):
|
||||
self._name = name
|
||||
self._scrollable = scrollable
|
||||
self._withhold = withhold
|
||||
self._described = False
|
||||
self.itersize: int = DEFAULT_ITERSIZE
|
||||
self._format = TEXT
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Insert the name as the second word
|
||||
parts = super().__repr__().split(None, 1)
|
||||
parts.insert(1, f"{self._name!r}")
|
||||
return " ".join(parts)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the cursor."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def scrollable(self) -> Optional[bool]:
|
||||
"""
|
||||
Whether the cursor is scrollable or not.
|
||||
|
||||
If `!None` leave the choice to the server. Use `!True` if you want to
|
||||
use `scroll()` on the cursor.
|
||||
"""
|
||||
return self._scrollable
|
||||
|
||||
@property
|
||||
def withhold(self) -> bool:
|
||||
"""
|
||||
If the cursor can be used after the creating transaction has committed.
|
||||
"""
|
||||
return self._withhold
|
||||
|
||||
@property
|
||||
def rownumber(self) -> Optional[int]:
|
||||
"""Index of the next row to fetch in the current result.
|
||||
|
||||
`!None` if there is no result to fetch.
|
||||
"""
|
||||
res = self.pgresult
|
||||
# command_status is empty if the result comes from
|
||||
# describe_portal, which means that we have just executed the DECLARE,
|
||||
# so we can assume we are at the first row.
|
||||
tuples = res and (res.status == TUPLES_OK or res.command_status == b"")
|
||||
return self._pos if tuples else None
|
||||
|
||||
def _declare_gen(
|
||||
self,
|
||||
query: Query,
|
||||
params: Optional[Params] = None,
|
||||
binary: Optional[bool] = None,
|
||||
) -> PQGen[None]:
|
||||
"""Generator implementing `ServerCursor.execute()`."""
|
||||
|
||||
query = self._make_declare_statement(query)
|
||||
|
||||
# If the cursor is being reused, the previous one must be closed.
|
||||
if self._described:
|
||||
yield from self._close_gen()
|
||||
self._described = False
|
||||
|
||||
yield from self._start_query(query)
|
||||
pgq = self._convert_query(query, params)
|
||||
self._execute_send(pgq, force_extended=True)
|
||||
results = yield from execute(self._conn.pgconn)
|
||||
if results[-1].status != COMMAND_OK:
|
||||
self._raise_for_result(results[-1])
|
||||
|
||||
# Set the format, which will be used by describe and fetch operations
|
||||
if binary is None:
|
||||
self._format = self.format
|
||||
else:
|
||||
self._format = BINARY if binary else TEXT
|
||||
|
||||
# The above result only returned COMMAND_OK. Get the cursor shape
|
||||
yield from self._describe_gen()
|
||||
|
||||
def _describe_gen(self) -> PQGen[None]:
|
||||
self._pgconn.send_describe_portal(self._name.encode(self._encoding))
|
||||
results = yield from execute(self._pgconn)
|
||||
self._check_results(results)
|
||||
self._results = results
|
||||
self._select_current_result(0, format=self._format)
|
||||
self._described = True
|
||||
|
||||
def _close_gen(self) -> PQGen[None]:
|
||||
ts = self._conn.pgconn.transaction_status
|
||||
|
||||
# if the connection is not in a sane state, don't even try
|
||||
if ts != IDLE and ts != INTRANS:
|
||||
return
|
||||
|
||||
# If we are IDLE, a WITHOUT HOLD cursor will surely have gone already.
|
||||
if not self._withhold and ts == IDLE:
|
||||
return
|
||||
|
||||
# if we didn't declare the cursor ourselves we still have to close it
|
||||
# but we must make sure it exists.
|
||||
if not self._described:
|
||||
query = sql.SQL(
|
||||
"SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}"
|
||||
).format(sql.Literal(self._name))
|
||||
res = yield from self._conn._exec_command(query)
|
||||
# pipeline mode otherwise, unsupported here.
|
||||
assert res is not None
|
||||
if res.ntuples == 0:
|
||||
return
|
||||
|
||||
query = sql.SQL("CLOSE {}").format(sql.Identifier(self._name))
|
||||
yield from self._conn._exec_command(query)
|
||||
|
||||
def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]:
|
||||
if self.closed:
|
||||
raise e.InterfaceError("the cursor is closed")
|
||||
# If we are stealing the cursor, make sure we know its shape
|
||||
if not self._described:
|
||||
yield from self._start_query()
|
||||
yield from self._describe_gen()
|
||||
|
||||
query = sql.SQL("FETCH FORWARD {} FROM {}").format(
|
||||
sql.SQL("ALL") if num is None else sql.Literal(num),
|
||||
sql.Identifier(self._name),
|
||||
)
|
||||
res = yield from self._conn._exec_command(query, result_format=self._format)
|
||||
# pipeline mode otherwise, unsupported here.
|
||||
assert res is not None
|
||||
|
||||
self.pgresult = res
|
||||
self._tx.set_pgresult(res, set_loaders=False)
|
||||
return self._tx.load_rows(0, res.ntuples, self._make_row)
|
||||
|
||||
def _scroll_gen(self, value: int, mode: str) -> PQGen[None]:
|
||||
if mode not in ("relative", "absolute"):
|
||||
raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
|
||||
query = sql.SQL("MOVE{} {} FROM {}").format(
|
||||
sql.SQL(" ABSOLUTE" if mode == "absolute" else ""),
|
||||
sql.Literal(value),
|
||||
sql.Identifier(self._name),
|
||||
)
|
||||
yield from self._conn._exec_command(query)
|
||||
|
||||
def _make_declare_statement(self, query: Query) -> sql.Composed:
|
||||
if isinstance(query, bytes):
|
||||
query = query.decode(self._encoding)
|
||||
if not isinstance(query, sql.Composable):
|
||||
query = sql.SQL(query)
|
||||
|
||||
parts = [
|
||||
sql.SQL("DECLARE"),
|
||||
sql.Identifier(self._name),
|
||||
]
|
||||
if self._scrollable is not None:
|
||||
parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL"))
|
||||
parts.append(sql.SQL("CURSOR"))
|
||||
if self._withhold:
|
||||
parts.append(sql.SQL("WITH HOLD"))
|
||||
parts.append(sql.SQL("FOR"))
|
||||
parts.append(query)
|
||||
|
||||
return sql.SQL(" ").join(parts)
|
||||
|
||||
|
||||
class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
|
||||
__module__ = "psycopg"
|
||||
__slots__ = ()
|
||||
_Self = TypeVar("_Self", bound="ServerCursor[Any]")
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self: "ServerCursor[Row]",
|
||||
connection: "Connection[Row]",
|
||||
name: str,
|
||||
*,
|
||||
scrollable: Optional[bool] = None,
|
||||
withhold: bool = False,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self: "ServerCursor[Row]",
|
||||
connection: "Connection[Any]",
|
||||
name: str,
|
||||
*,
|
||||
row_factory: RowFactory[Row],
|
||||
scrollable: Optional[bool] = None,
|
||||
withhold: bool = False,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: "Connection[Any]",
|
||||
name: str,
|
||||
*,
|
||||
row_factory: Optional[RowFactory[Row]] = None,
|
||||
scrollable: Optional[bool] = None,
|
||||
withhold: bool = False,
|
||||
):
|
||||
Cursor.__init__(
|
||||
self, connection, row_factory=row_factory or connection.row_factory
|
||||
)
|
||||
ServerCursorMixin.__init__(self, name, scrollable, withhold)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if not self.closed:
|
||||
warn(
|
||||
f"the server-side cursor {self} was deleted while still open."
|
||||
" Please use 'with' or '.close()' to close the cursor properly",
|
||||
ResourceWarning,
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the current cursor and free associated resources.
|
||||
"""
|
||||
with self._conn.lock:
|
||||
if self.closed:
|
||||
return
|
||||
if not self._conn.closed:
|
||||
self._conn.wait(self._close_gen())
|
||||
super().close()
|
||||
|
||||
def execute(
|
||||
self: _Self,
|
||||
query: Query,
|
||||
params: Optional[Params] = None,
|
||||
*,
|
||||
binary: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> _Self:
|
||||
"""
|
||||
Open a cursor to execute a query to the database.
|
||||
"""
|
||||
if kwargs:
|
||||
raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
|
||||
if self._pgconn.pipeline_status:
|
||||
raise e.NotSupportedError(
|
||||
"server-side cursors not supported in pipeline mode"
|
||||
)
|
||||
|
||||
try:
|
||||
with self._conn.lock:
|
||||
self._conn.wait(self._declare_gen(query, params, binary))
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
|
||||
return self
|
||||
|
||||
def executemany(
|
||||
self,
|
||||
query: Query,
|
||||
params_seq: Iterable[Params],
|
||||
*,
|
||||
returning: bool = True,
|
||||
) -> None:
|
||||
"""Method not implemented for server-side cursors."""
|
||||
raise e.NotSupportedError("executemany not supported on server-side cursors")
|
||||
|
||||
def fetchone(self) -> Optional[Row]:
|
||||
with self._conn.lock:
|
||||
recs = self._conn.wait(self._fetch_gen(1))
|
||||
if recs:
|
||||
self._pos += 1
|
||||
return recs[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def fetchmany(self, size: int = 0) -> List[Row]:
|
||||
if not size:
|
||||
size = self.arraysize
|
||||
with self._conn.lock:
|
||||
recs = self._conn.wait(self._fetch_gen(size))
|
||||
self._pos += len(recs)
|
||||
return recs
|
||||
|
||||
def fetchall(self) -> List[Row]:
|
||||
with self._conn.lock:
|
||||
recs = self._conn.wait(self._fetch_gen(None))
|
||||
self._pos += len(recs)
|
||||
return recs
|
||||
|
||||
def __iter__(self) -> Iterator[Row]:
|
||||
while True:
|
||||
with self._conn.lock:
|
||||
recs = self._conn.wait(self._fetch_gen(self.itersize))
|
||||
for rec in recs:
|
||||
self._pos += 1
|
||||
yield rec
|
||||
if len(recs) < self.itersize:
|
||||
break
|
||||
|
||||
def scroll(self, value: int, mode: str = "relative") -> None:
|
||||
with self._conn.lock:
|
||||
self._conn.wait(self._scroll_gen(value, mode))
|
||||
# Postgres doesn't have a reliable way to report a cursor out of bound
|
||||
if mode == "relative":
|
||||
self._pos += value
|
||||
else:
|
||||
self._pos = value
|
||||
|
||||
|
||||
class AsyncServerCursor(
|
||||
ServerCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
|
||||
):
|
||||
__module__ = "psycopg"
|
||||
__slots__ = ()
|
||||
_Self = TypeVar("_Self", bound="AsyncServerCursor[Any]")
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self: "AsyncServerCursor[Row]",
|
||||
connection: "AsyncConnection[Row]",
|
||||
name: str,
|
||||
*,
|
||||
scrollable: Optional[bool] = None,
|
||||
withhold: bool = False,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self: "AsyncServerCursor[Row]",
|
||||
connection: "AsyncConnection[Any]",
|
||||
name: str,
|
||||
*,
|
||||
row_factory: AsyncRowFactory[Row],
|
||||
scrollable: Optional[bool] = None,
|
||||
withhold: bool = False,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: "AsyncConnection[Any]",
|
||||
name: str,
|
||||
*,
|
||||
row_factory: Optional[AsyncRowFactory[Row]] = None,
|
||||
scrollable: Optional[bool] = None,
|
||||
withhold: bool = False,
|
||||
):
|
||||
AsyncCursor.__init__(
|
||||
self, connection, row_factory=row_factory or connection.row_factory
|
||||
)
|
||||
ServerCursorMixin.__init__(self, name, scrollable, withhold)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if not self.closed:
|
||||
warn(
|
||||
f"the server-side cursor {self} was deleted while still open."
|
||||
" Please use 'with' or '.close()' to close the cursor properly",
|
||||
ResourceWarning,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
async with self._conn.lock:
|
||||
if self.closed:
|
||||
return
|
||||
if not self._conn.closed:
|
||||
await self._conn.wait(self._close_gen())
|
||||
await super().close()
|
||||
|
||||
async def execute(
|
||||
self: _Self,
|
||||
query: Query,
|
||||
params: Optional[Params] = None,
|
||||
*,
|
||||
binary: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> _Self:
|
||||
if kwargs:
|
||||
raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
|
||||
if self._pgconn.pipeline_status:
|
||||
raise e.NotSupportedError(
|
||||
"server-side cursors not supported in pipeline mode"
|
||||
)
|
||||
|
||||
try:
|
||||
async with self._conn.lock:
|
||||
await self._conn.wait(self._declare_gen(query, params, binary))
|
||||
except e._NO_TRACEBACK as ex:
|
||||
raise ex.with_traceback(None)
|
||||
|
||||
return self
|
||||
|
||||
async def executemany(
|
||||
self,
|
||||
query: Query,
|
||||
params_seq: Iterable[Params],
|
||||
*,
|
||||
returning: bool = True,
|
||||
) -> None:
|
||||
raise e.NotSupportedError("executemany not supported on server-side cursors")
|
||||
|
||||
async def fetchone(self) -> Optional[Row]:
|
||||
async with self._conn.lock:
|
||||
recs = await self._conn.wait(self._fetch_gen(1))
|
||||
if recs:
|
||||
self._pos += 1
|
||||
return recs[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
async def fetchmany(self, size: int = 0) -> List[Row]:
|
||||
if not size:
|
||||
size = self.arraysize
|
||||
async with self._conn.lock:
|
||||
recs = await self._conn.wait(self._fetch_gen(size))
|
||||
self._pos += len(recs)
|
||||
return recs
|
||||
|
||||
async def fetchall(self) -> List[Row]:
|
||||
async with self._conn.lock:
|
||||
recs = await self._conn.wait(self._fetch_gen(None))
|
||||
self._pos += len(recs)
|
||||
return recs
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[Row]:
|
||||
while True:
|
||||
async with self._conn.lock:
|
||||
recs = await self._conn.wait(self._fetch_gen(self.itersize))
|
||||
for rec in recs:
|
||||
self._pos += 1
|
||||
yield rec
|
||||
if len(recs) < self.itersize:
|
||||
break
|
||||
|
||||
async def scroll(self, value: int, mode: str = "relative") -> None:
|
||||
async with self._conn.lock:
|
||||
await self._conn.wait(self._scroll_gen(value, mode))
|
||||
467
srcs/.venv/lib/python3.11/site-packages/psycopg/sql.py
Normal file
467
srcs/.venv/lib/python3.11/site-packages/psycopg/sql.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
SQL composition utility module
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import codecs
|
||||
import string
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
|
||||
|
||||
from .pq import Escaping
|
||||
from .abc import AdaptContext
|
||||
from .adapt import Transformer, PyFormat
|
||||
from ._compat import LiteralString
|
||||
from ._encodings import conn_encoding
|
||||
|
||||
|
||||
def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
|
||||
"""
|
||||
Adapt a Python object to a quoted SQL string.
|
||||
|
||||
Use this function only if you absolutely want to convert a Python string to
|
||||
an SQL quoted literal to use e.g. to generate batch SQL and you won't have
|
||||
a connection available when you will need to use it.
|
||||
|
||||
This function is relatively inefficient, because it doesn't cache the
|
||||
adaptation rules. If you pass a `!context` you can adapt the adaptation
|
||||
rules used, otherwise only global rules are used.
|
||||
|
||||
"""
|
||||
return Literal(obj).as_string(context)
|
||||
|
||||
|
||||
class Composable(ABC):
|
||||
"""
|
||||
Abstract base class for objects that can be used to compose an SQL string.
|
||||
|
||||
`!Composable` objects can be passed directly to
|
||||
`~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
|
||||
`~psycopg.Cursor.copy()` in place of the query string.
|
||||
|
||||
`!Composable` objects can be joined using the ``+`` operator: the result
|
||||
will be a `Composed` instance containing the objects joined. The operator
|
||||
``*`` is also supported with an integer argument: the result is a
|
||||
`!Composed` instance containing the left argument repeated as many times as
|
||||
requested.
|
||||
"""
|
||||
|
||||
def __init__(self, obj: Any):
|
||||
self._obj = obj
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self._obj!r})"
|
||||
|
||||
@abstractmethod
|
||||
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
|
||||
"""
|
||||
Return the value of the object as bytes.
|
||||
|
||||
:param context: the context to evaluate the object into.
|
||||
:type context: `connection` or `cursor`
|
||||
|
||||
The method is automatically invoked by `~psycopg.Cursor.execute()`,
|
||||
`~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a
|
||||
`!Composable` is passed instead of the query string.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def as_string(self, context: Optional[AdaptContext]) -> str:
|
||||
"""
|
||||
Return the value of the object as string.
|
||||
|
||||
:param context: the context to evaluate the string into.
|
||||
:type context: `connection` or `cursor`
|
||||
|
||||
"""
|
||||
conn = context.connection if context else None
|
||||
enc = conn_encoding(conn)
|
||||
b = self.as_bytes(context)
|
||||
if isinstance(b, bytes):
|
||||
return b.decode(enc)
|
||||
else:
|
||||
# buffer object
|
||||
return codecs.lookup(enc).decode(b)[0]
|
||||
|
||||
def __add__(self, other: "Composable") -> "Composed":
|
||||
if isinstance(other, Composed):
|
||||
return Composed([self]) + other
|
||||
if isinstance(other, Composable):
|
||||
return Composed([self]) + Composed([other])
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __mul__(self, n: int) -> "Composed":
|
||||
return Composed([self] * n)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return type(self) is type(other) and self._obj == other._obj
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
|
||||
class Composed(Composable):
|
||||
"""
|
||||
A `Composable` object made of a sequence of `!Composable`.
|
||||
|
||||
The object is usually created using `!Composable` operators and methods.
|
||||
However it is possible to create a `!Composed` directly specifying a
|
||||
sequence of objects as arguments: if they are not `!Composable` they will
|
||||
be wrapped in a `Literal`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> comp = sql.Composed(
|
||||
... [sql.SQL("INSERT INTO "), sql.Identifier("table")])
|
||||
>>> print(comp.as_string(conn))
|
||||
INSERT INTO "table"
|
||||
|
||||
`!Composed` objects are iterable (so they can be used in `SQL.join` for
|
||||
instance).
|
||||
"""
|
||||
|
||||
_obj: List[Composable]
|
||||
|
||||
def __init__(self, seq: Sequence[Any]):
|
||||
seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq]
|
||||
super().__init__(seq)
|
||||
|
||||
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
|
||||
return b"".join(obj.as_bytes(context) for obj in self._obj)
|
||||
|
||||
def __iter__(self) -> Iterator[Composable]:
|
||||
return iter(self._obj)
|
||||
|
||||
def __add__(self, other: Composable) -> "Composed":
|
||||
if isinstance(other, Composed):
|
||||
return Composed(self._obj + other._obj)
|
||||
if isinstance(other, Composable):
|
||||
return Composed(self._obj + [other])
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def join(self, joiner: Union["SQL", LiteralString]) -> "Composed":
|
||||
"""
|
||||
Return a new `!Composed` interposing the `!joiner` with the `!Composed` items.
|
||||
|
||||
The `!joiner` must be a `SQL` or a string which will be interpreted as
|
||||
an `SQL`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed
|
||||
>>> print(fields.join(', ').as_string(conn))
|
||||
"foo", "bar"
|
||||
|
||||
"""
|
||||
if isinstance(joiner, str):
|
||||
joiner = SQL(joiner)
|
||||
elif not isinstance(joiner, SQL):
|
||||
raise TypeError(
|
||||
"Composed.join() argument must be strings or SQL,"
|
||||
f" got {joiner!r} instead"
|
||||
)
|
||||
|
||||
return joiner.join(self._obj)
|
||||
|
||||
|
||||
class SQL(Composable):
|
||||
"""
|
||||
A `Composable` representing a snippet of SQL statement.
|
||||
|
||||
`!SQL` exposes `join()` and `format()` methods useful to create a template
|
||||
where to merge variable parts of a query (for instance field or table
|
||||
names).
|
||||
|
||||
The `!obj` string doesn't undergo any form of escaping, so it is not
|
||||
suitable to represent variable identifiers or values: you should only use
|
||||
it to pass constant strings representing templates or snippets of SQL
|
||||
statements; use other objects such as `Identifier` or `Literal` to
|
||||
represent variable parts.
|
||||
|
||||
Example::
|
||||
|
||||
>>> query = sql.SQL("SELECT {0} FROM {1}").format(
|
||||
... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
|
||||
... sql.Identifier('table'))
|
||||
>>> print(query.as_string(conn))
|
||||
SELECT "foo", "bar" FROM "table"
|
||||
"""
|
||||
|
||||
_obj: LiteralString
|
||||
_formatter = string.Formatter()
|
||||
|
||||
def __init__(self, obj: LiteralString):
|
||||
super().__init__(obj)
|
||||
if not isinstance(obj, str):
|
||||
raise TypeError(f"SQL values must be strings, got {obj!r} instead")
|
||||
|
||||
def as_string(self, context: Optional[AdaptContext]) -> str:
|
||||
return self._obj
|
||||
|
||||
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
|
||||
enc = "utf-8"
|
||||
if context:
|
||||
enc = conn_encoding(context.connection)
|
||||
return self._obj.encode(enc)
|
||||
|
||||
def format(self, *args: Any, **kwargs: Any) -> Composed:
|
||||
"""
|
||||
Merge `Composable` objects into a template.
|
||||
|
||||
:param args: parameters to replace to numbered (``{0}``, ``{1}``) or
|
||||
auto-numbered (``{}``) placeholders
|
||||
:param kwargs: parameters to replace to named (``{name}``) placeholders
|
||||
:return: the union of the `!SQL` string with placeholders replaced
|
||||
:rtype: `Composed`
|
||||
|
||||
The method is similar to the Python `str.format()` method: the string
|
||||
template supports auto-numbered (``{}``), numbered (``{0}``,
|
||||
``{1}``...), and named placeholders (``{name}``), with positional
|
||||
arguments replacing the numbered placeholders and keywords replacing
|
||||
the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``)
|
||||
are not supported.
|
||||
|
||||
If a `!Composable` objects is passed to the template it will be merged
|
||||
according to its `as_string()` method. If any other Python object is
|
||||
passed, it will be wrapped in a `Literal` object and so escaped
|
||||
according to SQL rules.
|
||||
|
||||
Example::
|
||||
|
||||
>>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s")
|
||||
... .format(sql.Identifier('people'), sql.Identifier('id'))
|
||||
... .as_string(conn))
|
||||
SELECT * FROM "people" WHERE "id" = %s
|
||||
|
||||
>>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}")
|
||||
... .format(tbl=sql.Identifier('people'), name="O'Rourke"))
|
||||
... .as_string(conn))
|
||||
SELECT * FROM "people" WHERE name = 'O''Rourke'
|
||||
|
||||
"""
|
||||
rv: List[Composable] = []
|
||||
autonum: Optional[int] = 0
|
||||
# TODO: this is probably not the right way to whitelist pre
|
||||
# pyre complains. Will wait for mypy to complain too to fix.
|
||||
pre: LiteralString
|
||||
for pre, name, spec, conv in self._formatter.parse(self._obj):
|
||||
if spec:
|
||||
raise ValueError("no format specification supported by SQL")
|
||||
if conv:
|
||||
raise ValueError("no format conversion supported by SQL")
|
||||
if pre:
|
||||
rv.append(SQL(pre))
|
||||
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if name.isdigit():
|
||||
if autonum:
|
||||
raise ValueError(
|
||||
"cannot switch from automatic field numbering to manual"
|
||||
)
|
||||
rv.append(args[int(name)])
|
||||
autonum = None
|
||||
|
||||
elif not name:
|
||||
if autonum is None:
|
||||
raise ValueError(
|
||||
"cannot switch from manual field numbering to automatic"
|
||||
)
|
||||
rv.append(args[autonum])
|
||||
autonum += 1
|
||||
|
||||
else:
|
||||
rv.append(kwargs[name])
|
||||
|
||||
return Composed(rv)
|
||||
|
||||
def join(self, seq: Iterable[Composable]) -> Composed:
|
||||
"""
|
||||
Join a sequence of `Composable`.
|
||||
|
||||
:param seq: the elements to join.
|
||||
:type seq: iterable of `!Composable`
|
||||
|
||||
Use the `!SQL` object's string to separate the elements in `!seq`.
|
||||
Note that `Composed` objects are iterable too, so they can be used as
|
||||
argument for this method.
|
||||
|
||||
Example::
|
||||
|
||||
>>> snip = sql.SQL(', ').join(
|
||||
... sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
|
||||
>>> print(snip.as_string(conn))
|
||||
"foo", "bar", "baz"
|
||||
"""
|
||||
rv = []
|
||||
it = iter(seq)
|
||||
try:
|
||||
rv.append(next(it))
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
for i in it:
|
||||
rv.append(self)
|
||||
rv.append(i)
|
||||
|
||||
return Composed(rv)
|
||||
|
||||
|
||||
class Identifier(Composable):
|
||||
"""
|
||||
A `Composable` representing an SQL identifier or a dot-separated sequence.
|
||||
|
||||
Identifiers usually represent names of database objects, such as tables or
|
||||
fields. PostgreSQL identifiers follow `different rules`__ than SQL string
|
||||
literals for escaping (e.g. they use double quotes instead of single).
|
||||
|
||||
.. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \
|
||||
SQL-SYNTAX-IDENTIFIERS
|
||||
|
||||
Example::
|
||||
|
||||
>>> t1 = sql.Identifier("foo")
|
||||
>>> t2 = sql.Identifier("ba'r")
|
||||
>>> t3 = sql.Identifier('ba"z')
|
||||
>>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
|
||||
"foo", "ba'r", "ba""z"
|
||||
|
||||
Multiple strings can be passed to the object to represent a qualified name,
|
||||
i.e. a dot-separated sequence of identifiers.
|
||||
|
||||
Example::
|
||||
|
||||
>>> query = sql.SQL("SELECT {} FROM {}").format(
|
||||
... sql.Identifier("table", "field"),
|
||||
... sql.Identifier("schema", "table"))
|
||||
>>> print(query.as_string(conn))
|
||||
SELECT "table"."field" FROM "schema"."table"
|
||||
|
||||
"""
|
||||
|
||||
_obj: Sequence[str]
|
||||
|
||||
def __init__(self, *strings: str):
|
||||
# init super() now to make the __repr__ not explode in case of error
|
||||
super().__init__(strings)
|
||||
|
||||
if not strings:
|
||||
raise TypeError("Identifier cannot be empty")
|
||||
|
||||
for s in strings:
|
||||
if not isinstance(s, str):
|
||||
raise TypeError(
|
||||
f"SQL identifier parts must be strings, got {s!r} instead"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
|
||||
|
||||
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
|
||||
conn = context.connection if context else None
|
||||
if not conn:
|
||||
raise ValueError("a connection is necessary for Identifier")
|
||||
esc = Escaping(conn.pgconn)
|
||||
enc = conn_encoding(conn)
|
||||
escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
|
||||
return b".".join(escs)
|
||||
|
||||
|
||||
class Literal(Composable):
|
||||
"""
|
||||
A `Composable` representing an SQL value to include in a query.
|
||||
|
||||
Usually you will want to include placeholders in the query and pass values
|
||||
as `~cursor.execute()` arguments. If however you really really need to
|
||||
include a literal value in the query you can use this object.
|
||||
|
||||
The string returned by `!as_string()` follows the normal :ref:`adaptation
|
||||
rules <types-adaptation>` for Python objects.
|
||||
|
||||
Example::
|
||||
|
||||
>>> s1 = sql.Literal("fo'o")
|
||||
>>> s2 = sql.Literal(42)
|
||||
>>> s3 = sql.Literal(date(2000, 1, 1))
|
||||
>>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
|
||||
'fo''o', 42, '2000-01-01'::date
|
||||
|
||||
"""
|
||||
|
||||
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
|
||||
tx = Transformer.from_context(context)
|
||||
return tx.as_literal(self._obj)
|
||||
|
||||
|
||||
class Placeholder(Composable):
|
||||
"""A `Composable` representing a placeholder for query parameters.
|
||||
|
||||
If the name is specified, generate a named placeholder (e.g. ``%(name)s``,
|
||||
``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``,
|
||||
``%b``).
|
||||
|
||||
The object is useful to generate SQL queries with a variable number of
|
||||
arguments.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> names = ['foo', 'bar', 'baz']
|
||||
|
||||
>>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
|
||||
... sql.SQL(', ').join(map(sql.Identifier, names)),
|
||||
... sql.SQL(', ').join(sql.Placeholder() * len(names)))
|
||||
>>> print(q1.as_string(conn))
|
||||
INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s)
|
||||
|
||||
>>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
|
||||
... sql.SQL(', ').join(map(sql.Identifier, names)),
|
||||
... sql.SQL(', ').join(map(sql.Placeholder, names)))
|
||||
>>> print(q2.as_string(conn))
|
||||
INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = "", format: Union[str, PyFormat] = PyFormat.AUTO):
|
||||
super().__init__(name)
|
||||
if not isinstance(name, str):
|
||||
raise TypeError(f"expected string as name, got {name!r}")
|
||||
|
||||
if ")" in name:
|
||||
raise ValueError(f"invalid name: {name!r}")
|
||||
|
||||
if type(format) is str:
|
||||
format = PyFormat(format)
|
||||
if not isinstance(format, PyFormat):
|
||||
raise TypeError(
|
||||
f"expected PyFormat as format, got {type(format).__name__!r}"
|
||||
)
|
||||
|
||||
self._format: PyFormat = format
|
||||
|
||||
def __repr__(self) -> str:
|
||||
parts = []
|
||||
if self._obj:
|
||||
parts.append(repr(self._obj))
|
||||
if self._format is not PyFormat.AUTO:
|
||||
parts.append(f"format={self._format.name}")
|
||||
|
||||
return f"{self.__class__.__name__}({', '.join(parts)})"
|
||||
|
||||
def as_string(self, context: Optional[AdaptContext]) -> str:
|
||||
code = self._format.value
|
||||
return f"%({self._obj}){code}" if self._obj else f"%{code}"
|
||||
|
||||
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
|
||||
conn = context.connection if context else None
|
||||
enc = conn_encoding(conn)
|
||||
return self.as_string(context).encode(enc)
|
||||
|
||||
|
||||
# Literals
|
||||
NULL = SQL("NULL")
|
||||
DEFAULT = SQL("DEFAULT")
|
||||
291
srcs/.venv/lib/python3.11/site-packages/psycopg/transaction.py
Normal file
291
srcs/.venv/lib/python3.11/site-packages/psycopg/transaction.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
Transaction context managers returned by Connection.transaction()
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import logging
|
||||
|
||||
from types import TracebackType
|
||||
from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING
|
||||
|
||||
from . import pq
|
||||
from . import sql
|
||||
from . import errors as e
|
||||
from .abc import ConnectionType, PQGen
|
||||
from .pq.misc import connection_summary
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
from .connection import Connection
|
||||
from .connection_async import AsyncConnection
|
||||
|
||||
IDLE = pq.TransactionStatus.IDLE
|
||||
|
||||
OK = pq.ConnStatus.OK
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Rollback(Exception):
|
||||
"""
|
||||
Exit the current `Transaction` context immediately and rollback any changes
|
||||
made within this context.
|
||||
|
||||
If a transaction context is specified in the constructor, rollback
|
||||
enclosing transactions contexts up to and including the one specified.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transaction: Union["Transaction", "AsyncTransaction", None] = None,
|
||||
):
|
||||
self.transaction = transaction
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__qualname__}({self.transaction!r})"
|
||||
|
||||
|
||||
class OutOfOrderTransactionNesting(e.ProgrammingError):
|
||||
"""Out-of-order transaction nesting detected"""
|
||||
|
||||
|
||||
class BaseTransaction(Generic[ConnectionType]):
|
||||
def __init__(
|
||||
self,
|
||||
connection: ConnectionType,
|
||||
savepoint_name: Optional[str] = None,
|
||||
force_rollback: bool = False,
|
||||
):
|
||||
self._conn = connection
|
||||
self.pgconn = self._conn.pgconn
|
||||
self._savepoint_name = savepoint_name or ""
|
||||
self.force_rollback = force_rollback
|
||||
self._entered = self._exited = False
|
||||
self._outer_transaction = False
|
||||
self._stack_index = -1
|
||||
|
||||
@property
|
||||
def savepoint_name(self) -> Optional[str]:
|
||||
"""
|
||||
The name of the savepoint; `!None` if handling the main transaction.
|
||||
"""
|
||||
# Yes, it may change on __enter__. No, I don't care, because the
|
||||
# un-entered state is outside the public interface.
|
||||
return self._savepoint_name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
|
||||
info = connection_summary(self.pgconn)
|
||||
if not self._entered:
|
||||
status = "inactive"
|
||||
elif not self._exited:
|
||||
status = "active"
|
||||
else:
|
||||
status = "terminated"
|
||||
|
||||
sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
|
||||
return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"
|
||||
|
||||
def _enter_gen(self) -> PQGen[None]:
|
||||
if self._entered:
|
||||
raise TypeError("transaction blocks can be used only once")
|
||||
self._entered = True
|
||||
|
||||
self._push_savepoint()
|
||||
for command in self._get_enter_commands():
|
||||
yield from self._conn._exec_command(command)
|
||||
|
||||
def _exit_gen(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> PQGen[bool]:
|
||||
if not exc_val and not self.force_rollback:
|
||||
yield from self._commit_gen()
|
||||
return False
|
||||
else:
|
||||
# try to rollback, but if there are problems (connection in a bad
|
||||
# state) just warn without clobbering the exception bubbling up.
|
||||
try:
|
||||
return (yield from self._rollback_gen(exc_val))
|
||||
except OutOfOrderTransactionNesting:
|
||||
# Clobber an exception happened in the block with the exception
|
||||
# caused by out-of-order transaction detected, so make the
|
||||
# behaviour consistent with _commit_gen and to make sure the
|
||||
# user fixes this condition, which is unrelated from
|
||||
# operational error that might arise in the block.
|
||||
raise
|
||||
except Exception as exc2:
|
||||
logger.warning("error ignored in rollback of %s: %s", self, exc2)
|
||||
return False
|
||||
|
||||
def _commit_gen(self) -> PQGen[None]:
|
||||
ex = self._pop_savepoint("commit")
|
||||
self._exited = True
|
||||
if ex:
|
||||
raise ex
|
||||
|
||||
for command in self._get_commit_commands():
|
||||
yield from self._conn._exec_command(command)
|
||||
|
||||
def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
|
||||
if isinstance(exc_val, Rollback):
|
||||
logger.debug(f"{self._conn}: Explicit rollback from: ", exc_info=True)
|
||||
|
||||
ex = self._pop_savepoint("rollback")
|
||||
self._exited = True
|
||||
if ex:
|
||||
raise ex
|
||||
|
||||
for command in self._get_rollback_commands():
|
||||
yield from self._conn._exec_command(command)
|
||||
|
||||
if isinstance(exc_val, Rollback):
|
||||
if not exc_val.transaction or exc_val.transaction is self:
|
||||
return True # Swallow the exception
|
||||
|
||||
return False
|
||||
|
||||
def _get_enter_commands(self) -> Iterator[bytes]:
|
||||
if self._outer_transaction:
|
||||
yield self._conn._get_tx_start_command()
|
||||
|
||||
if self._savepoint_name:
|
||||
yield (
|
||||
sql.SQL("SAVEPOINT {}")
|
||||
.format(sql.Identifier(self._savepoint_name))
|
||||
.as_bytes(self._conn)
|
||||
)
|
||||
|
||||
def _get_commit_commands(self) -> Iterator[bytes]:
|
||||
if self._savepoint_name and not self._outer_transaction:
|
||||
yield (
|
||||
sql.SQL("RELEASE {}")
|
||||
.format(sql.Identifier(self._savepoint_name))
|
||||
.as_bytes(self._conn)
|
||||
)
|
||||
|
||||
if self._outer_transaction:
|
||||
assert not self._conn._num_transactions
|
||||
yield b"COMMIT"
|
||||
|
||||
def _get_rollback_commands(self) -> Iterator[bytes]:
|
||||
if self._savepoint_name and not self._outer_transaction:
|
||||
yield (
|
||||
sql.SQL("ROLLBACK TO {n}")
|
||||
.format(n=sql.Identifier(self._savepoint_name))
|
||||
.as_bytes(self._conn)
|
||||
)
|
||||
yield (
|
||||
sql.SQL("RELEASE {n}")
|
||||
.format(n=sql.Identifier(self._savepoint_name))
|
||||
.as_bytes(self._conn)
|
||||
)
|
||||
|
||||
if self._outer_transaction:
|
||||
assert not self._conn._num_transactions
|
||||
yield b"ROLLBACK"
|
||||
|
||||
# Also clear the prepared statements cache.
|
||||
if self._conn._prepared.clear():
|
||||
yield from self._conn._prepared.get_maintenance_commands()
|
||||
|
||||
def _push_savepoint(self) -> None:
|
||||
"""
|
||||
Push the transaction on the connection transactions stack.
|
||||
|
||||
Also set the internal state of the object and verify consistency.
|
||||
"""
|
||||
self._outer_transaction = self.pgconn.transaction_status == IDLE
|
||||
if self._outer_transaction:
|
||||
# outer transaction: if no name it's only a begin, else
|
||||
# there will be an additional savepoint
|
||||
assert not self._conn._num_transactions
|
||||
else:
|
||||
# inner transaction: it always has a name
|
||||
if not self._savepoint_name:
|
||||
self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}"
|
||||
|
||||
self._stack_index = self._conn._num_transactions
|
||||
self._conn._num_transactions += 1
|
||||
|
||||
def _pop_savepoint(self, action: str) -> Optional[Exception]:
|
||||
"""
|
||||
Pop the transaction from the connection transactions stack.
|
||||
|
||||
Also verify the state consistency.
|
||||
"""
|
||||
self._conn._num_transactions -= 1
|
||||
if self._conn._num_transactions == self._stack_index:
|
||||
return None
|
||||
|
||||
return OutOfOrderTransactionNesting(
|
||||
f"transaction {action} at the wrong nesting level: {self}"
|
||||
)
|
||||
|
||||
|
||||
class Transaction(BaseTransaction["Connection[Any]"]):
|
||||
"""
|
||||
Returned by `Connection.transaction()` to handle a transaction block.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg"
|
||||
|
||||
_Self = TypeVar("_Self", bound="Transaction")
|
||||
|
||||
@property
|
||||
def connection(self) -> "Connection[Any]":
|
||||
"""The connection the object is managing."""
|
||||
return self._conn
|
||||
|
||||
def __enter__(self: _Self) -> _Self:
|
||||
with self._conn.lock:
|
||||
self._conn.wait(self._enter_gen())
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> bool:
|
||||
if self.pgconn.status == OK:
|
||||
with self._conn.lock:
|
||||
return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
|
||||
"""
|
||||
Returned by `AsyncConnection.transaction()` to handle a transaction block.
|
||||
"""
|
||||
|
||||
__module__ = "psycopg"
|
||||
|
||||
_Self = TypeVar("_Self", bound="AsyncTransaction")
|
||||
|
||||
@property
|
||||
def connection(self) -> "AsyncConnection[Any]":
|
||||
return self._conn
|
||||
|
||||
async def __aenter__(self: _Self) -> _Self:
|
||||
async with self._conn.lock:
|
||||
await self._conn.wait(self._enter_gen())
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> bool:
|
||||
if self.pgconn.status == OK:
|
||||
async with self._conn.lock:
|
||||
return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
|
||||
else:
|
||||
return False
|
||||
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
psycopg types package
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from .. import _typeinfo
|
||||
|
||||
# Exposed here
|
||||
TypeInfo = _typeinfo.TypeInfo
|
||||
TypesRegistry = _typeinfo.TypesRegistry
|
||||
469
srcs/.venv/lib/python3.11/site-packages/psycopg/types/array.py
Normal file
469
srcs/.venv/lib/python3.11/site-packages/psycopg/types/array.py
Normal file
@@ -0,0 +1,469 @@
|
||||
"""
|
||||
Adapters for arrays
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import re
|
||||
import struct
|
||||
from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type
|
||||
|
||||
from .. import pq
|
||||
from .. import errors as e
|
||||
from .. import postgres
|
||||
from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, Loader, Transformer
|
||||
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
|
||||
from .._compat import cache, prod
|
||||
from .._struct import pack_len, unpack_len
|
||||
from .._cmodule import _psycopg
|
||||
from ..postgres import TEXT_OID, INVALID_OID
|
||||
from .._typeinfo import TypeInfo
|
||||
|
||||
_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid
|
||||
_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
|
||||
_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from)
|
||||
_struct_dim = struct.Struct("!II") # dim, lower bound
|
||||
_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
|
||||
_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from)
|
||||
|
||||
TEXT_ARRAY_OID = postgres.types["text"].array_oid
|
||||
|
||||
PY_TEXT = PyFormat.TEXT
|
||||
PQ_BINARY = pq.Format.BINARY
|
||||
|
||||
|
||||
class BaseListDumper(RecursiveDumper):
|
||||
element_oid = 0
|
||||
|
||||
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
|
||||
if cls is NoneType:
|
||||
cls = list
|
||||
|
||||
super().__init__(cls, context)
|
||||
self.sub_dumper: Optional[Dumper] = None
|
||||
if self.element_oid and context:
|
||||
sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format)
|
||||
self.sub_dumper = sdclass(NoneType, context)
|
||||
|
||||
def _find_list_element(self, L: List[Any], format: PyFormat) -> Any:
|
||||
"""
|
||||
Find the first non-null element of an eventually nested list
|
||||
"""
|
||||
items = list(self._flatiter(L, set()))
|
||||
types = {type(item): item for item in items}
|
||||
if not types:
|
||||
return None
|
||||
|
||||
if len(types) == 1:
|
||||
t, v = types.popitem()
|
||||
else:
|
||||
# More than one type in the list. It might be still good, as long
|
||||
# as they dump with the same oid (e.g. IPv4Network, IPv6Network).
|
||||
dumpers = [self._tx.get_dumper(item, format) for item in types.values()]
|
||||
oids = set(d.oid for d in dumpers)
|
||||
if len(oids) == 1:
|
||||
t, v = types.popitem()
|
||||
else:
|
||||
raise e.DataError(
|
||||
"cannot dump lists of mixed types;"
|
||||
f" got: {', '.join(sorted(t.__name__ for t in types))}"
|
||||
)
|
||||
|
||||
# Checking for precise type. If the type is a subclass (e.g. Int4)
|
||||
# we assume the user knows what type they are passing.
|
||||
if t is not int:
|
||||
return v
|
||||
|
||||
# If we got an int, let's see what is the biggest one in order to
|
||||
# choose the smallest OID and allow Postgres to do the right cast.
|
||||
imax: int = max(items)
|
||||
imin: int = min(items)
|
||||
if imin >= 0:
|
||||
return imax
|
||||
else:
|
||||
return max(imax, -imin - 1)
|
||||
|
||||
def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
|
||||
if id(L) in seen:
|
||||
raise e.DataError("cannot dump a recursive list")
|
||||
|
||||
seen.add(id(L))
|
||||
|
||||
for item in L:
|
||||
if type(item) is list:
|
||||
yield from self._flatiter(item, seen)
|
||||
elif item is not None:
|
||||
yield item
|
||||
|
||||
return None
|
||||
|
||||
def _get_base_type_info(self, base_oid: int) -> TypeInfo:
|
||||
"""
|
||||
Return info about the base type.
|
||||
|
||||
Return text info as fallback.
|
||||
"""
|
||||
if base_oid:
|
||||
info = self._tx.adapters.types.get(base_oid)
|
||||
if info:
|
||||
return info
|
||||
|
||||
return self._tx.adapters.types["text"]
|
||||
|
||||
|
||||
class ListDumper(BaseListDumper):
|
||||
delimiter = b","
|
||||
|
||||
def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
|
||||
if self.oid:
|
||||
return self.cls
|
||||
|
||||
item = self._find_list_element(obj, format)
|
||||
if item is None:
|
||||
return self.cls
|
||||
|
||||
sd = self._tx.get_dumper(item, format)
|
||||
return (self.cls, sd.get_key(item, format))
|
||||
|
||||
def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
|
||||
# If we have an oid we don't need to upgrade
|
||||
if self.oid:
|
||||
return self
|
||||
|
||||
item = self._find_list_element(obj, format)
|
||||
if item is None:
|
||||
# Empty lists can only be dumped as text if the type is unknown.
|
||||
return self
|
||||
|
||||
sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
|
||||
dumper = type(self)(self.cls, self._tx)
|
||||
dumper.sub_dumper = sd
|
||||
|
||||
# We consider an array of unknowns as unknown, so we can dump empty
|
||||
# lists or lists containing only None elements.
|
||||
if sd.oid != INVALID_OID:
|
||||
info = self._get_base_type_info(sd.oid)
|
||||
dumper.oid = info.array_oid or TEXT_ARRAY_OID
|
||||
dumper.delimiter = info.delimiter.encode()
|
||||
else:
|
||||
dumper.oid = INVALID_OID
|
||||
|
||||
return dumper
|
||||
|
||||
# Double quotes and backslashes embedded in element values will be
|
||||
# backslash-escaped.
|
||||
_re_esc = re.compile(rb'(["\\])')
|
||||
|
||||
def dump(self, obj: List[Any]) -> bytes:
|
||||
tokens: List[Buffer] = []
|
||||
needs_quotes = _get_needs_quotes_regexp(self.delimiter).search
|
||||
|
||||
def dump_list(obj: List[Any]) -> None:
|
||||
if not obj:
|
||||
tokens.append(b"{}")
|
||||
return
|
||||
|
||||
tokens.append(b"{")
|
||||
for item in obj:
|
||||
if isinstance(item, list):
|
||||
dump_list(item)
|
||||
elif item is not None:
|
||||
ad = self._dump_item(item)
|
||||
if needs_quotes(ad):
|
||||
if not isinstance(ad, bytes):
|
||||
ad = bytes(ad)
|
||||
ad = b'"' + self._re_esc.sub(rb"\\\1", ad) + b'"'
|
||||
tokens.append(ad)
|
||||
else:
|
||||
tokens.append(b"NULL")
|
||||
|
||||
tokens.append(self.delimiter)
|
||||
|
||||
tokens[-1] = b"}"
|
||||
|
||||
dump_list(obj)
|
||||
|
||||
return b"".join(tokens)
|
||||
|
||||
def _dump_item(self, item: Any) -> Buffer:
|
||||
if self.sub_dumper:
|
||||
return self.sub_dumper.dump(item)
|
||||
else:
|
||||
return self._tx.get_dumper(item, PY_TEXT).dump(item)
|
||||
|
||||
|
||||
@cache
|
||||
def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]:
|
||||
"""Return a regexp to recognise when a value needs quotes
|
||||
|
||||
from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
|
||||
|
||||
The array output routine will put double quotes around element values if
|
||||
they are empty strings, contain curly braces, delimiter characters,
|
||||
double quotes, backslashes, or white space, or match the word NULL.
|
||||
"""
|
||||
return re.compile(
|
||||
rb"""(?xi)
|
||||
^$ # the empty string
|
||||
| ["{}%s\\\s] # or a char to escape
|
||||
| ^null$ # or the word NULL
|
||||
"""
|
||||
% delimiter
|
||||
)
|
||||
|
||||
|
||||
class ListBinaryDumper(BaseListDumper):
|
||||
format = pq.Format.BINARY
|
||||
|
||||
def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
|
||||
if self.oid:
|
||||
return self.cls
|
||||
|
||||
item = self._find_list_element(obj, format)
|
||||
if item is None:
|
||||
return (self.cls,)
|
||||
|
||||
sd = self._tx.get_dumper(item, format)
|
||||
return (self.cls, sd.get_key(item, format))
|
||||
|
||||
def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
|
||||
# If we have an oid we don't need to upgrade
|
||||
if self.oid:
|
||||
return self
|
||||
|
||||
item = self._find_list_element(obj, format)
|
||||
if item is None:
|
||||
return ListDumper(self.cls, self._tx)
|
||||
|
||||
sd = self._tx.get_dumper(item, format.from_pq(self.format))
|
||||
dumper = type(self)(self.cls, self._tx)
|
||||
dumper.sub_dumper = sd
|
||||
info = self._get_base_type_info(sd.oid)
|
||||
dumper.oid = info.array_oid or TEXT_ARRAY_OID
|
||||
|
||||
return dumper
|
||||
|
||||
def dump(self, obj: List[Any]) -> bytes:
|
||||
# Postgres won't take unknown for element oid: fall back on text
|
||||
sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID
|
||||
|
||||
if not obj:
|
||||
return _pack_head(0, 0, sub_oid)
|
||||
|
||||
data: List[Buffer] = [b"", b""] # placeholders to avoid a resize
|
||||
dims: List[int] = []
|
||||
hasnull = 0
|
||||
|
||||
def calc_dims(L: List[Any]) -> None:
|
||||
if isinstance(L, self.cls):
|
||||
if not L:
|
||||
raise e.DataError("lists cannot contain empty lists")
|
||||
dims.append(len(L))
|
||||
calc_dims(L[0])
|
||||
|
||||
calc_dims(obj)
|
||||
|
||||
def dump_list(L: List[Any], dim: int) -> None:
|
||||
nonlocal hasnull
|
||||
if len(L) != dims[dim]:
|
||||
raise e.DataError("nested lists have inconsistent lengths")
|
||||
|
||||
if dim == len(dims) - 1:
|
||||
for item in L:
|
||||
if item is not None:
|
||||
# If we get here, the sub_dumper must have been set
|
||||
ad = self.sub_dumper.dump(item) # type: ignore[union-attr]
|
||||
data.append(pack_len(len(ad)))
|
||||
data.append(ad)
|
||||
else:
|
||||
hasnull = 1
|
||||
data.append(b"\xff\xff\xff\xff")
|
||||
else:
|
||||
for item in L:
|
||||
if not isinstance(item, self.cls):
|
||||
raise e.DataError("nested lists have inconsistent depths")
|
||||
dump_list(item, dim + 1) # type: ignore
|
||||
|
||||
dump_list(obj, 0)
|
||||
|
||||
data[0] = _pack_head(len(dims), hasnull, sub_oid)
|
||||
data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
|
||||
return b"".join(data)
|
||||
|
||||
|
||||
class ArrayLoader(RecursiveLoader):
|
||||
delimiter = b","
|
||||
base_oid: int
|
||||
|
||||
def load(self, data: Buffer) -> List[Any]:
|
||||
loader = self._tx.get_loader(self.base_oid, self.format)
|
||||
return _load_text(data, loader, self.delimiter)
|
||||
|
||||
|
||||
class ArrayBinaryLoader(RecursiveLoader):
|
||||
format = pq.Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> List[Any]:
|
||||
return _load_binary(data, self._tx)
|
||||
|
||||
|
||||
def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
|
||||
if not info.array_oid:
|
||||
raise ValueError(f"the type info {info} doesn't describe an array")
|
||||
|
||||
adapters = context.adapters if context else postgres.adapters
|
||||
|
||||
loader = _make_loader(info.name, info.oid, info.delimiter)
|
||||
adapters.register_loader(info.array_oid, loader)
|
||||
|
||||
# No need to make a new loader because the binary datum has all the info.
|
||||
loader = getattr(_psycopg, "ArrayBinaryLoader", ArrayBinaryLoader)
|
||||
adapters.register_loader(info.array_oid, loader)
|
||||
|
||||
dumper = _make_dumper(info.name, info.oid, info.array_oid, info.delimiter)
|
||||
adapters.register_dumper(None, dumper)
|
||||
|
||||
dumper = _make_binary_dumper(info.name, info.oid, info.array_oid)
|
||||
adapters.register_dumper(None, dumper)
|
||||
|
||||
|
||||
# Cache all dynamically-generated types to avoid leaks in case the types
|
||||
# cannot be GC'd.
|
||||
|
||||
|
||||
@cache
|
||||
def _make_loader(name: str, oid: int, delimiter: str) -> Type[Loader]:
|
||||
# Note: caching this function is really needed because, if the C extension
|
||||
# is available, the resulting type cannot be GC'd, so calling
|
||||
# register_array() in a loop results in a leak. See #647.
|
||||
base = getattr(_psycopg, "ArrayLoader", ArrayLoader)
|
||||
attribs = {"base_oid": oid, "delimiter": delimiter.encode()}
|
||||
return type(f"{name.title()}{base.__name__}", (base,), attribs)
|
||||
|
||||
|
||||
@cache
|
||||
def _make_dumper(
|
||||
name: str, oid: int, array_oid: int, delimiter: str
|
||||
) -> Type[BaseListDumper]:
|
||||
attribs = {"oid": array_oid, "element_oid": oid, "delimiter": delimiter.encode()}
|
||||
return type(f"{name.title()}ListDumper", (ListDumper,), attribs)
|
||||
|
||||
|
||||
@cache
|
||||
def _make_binary_dumper(name: str, oid: int, array_oid: int) -> Type[BaseListDumper]:
|
||||
attribs = {"oid": array_oid, "element_oid": oid}
|
||||
return type(f"{name.title()}ListBinaryDumper", (ListBinaryDumper,), attribs)
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
# The text dumper is more flexible as it can handle lists of mixed type,
|
||||
# so register it later.
|
||||
context.adapters.register_dumper(list, ListBinaryDumper)
|
||||
context.adapters.register_dumper(list, ListDumper)
|
||||
|
||||
|
||||
def register_all_arrays(context: AdaptContext) -> None:
|
||||
"""
|
||||
Associate the array oid of all the types in Loader.globals.
|
||||
|
||||
This function is designed to be called once at import time, after having
|
||||
registered all the base loaders.
|
||||
"""
|
||||
for t in context.adapters.types:
|
||||
if t.array_oid:
|
||||
t.register(context)
|
||||
|
||||
|
||||
def _load_text(
|
||||
data: Buffer,
|
||||
loader: Loader,
|
||||
delimiter: bytes = b",",
|
||||
__re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"),
|
||||
) -> List[Any]:
|
||||
rv = None
|
||||
stack: List[Any] = []
|
||||
a: List[Any] = []
|
||||
rv = a
|
||||
load = loader.load
|
||||
|
||||
# Remove the dimensions information prefix (``[...]=``)
|
||||
if data and data[0] == b"["[0]:
|
||||
if isinstance(data, memoryview):
|
||||
data = bytes(data)
|
||||
idx = data.find(b"=")
|
||||
if idx == -1:
|
||||
raise e.DataError("malformed array: no '=' after dimension information")
|
||||
data = data[idx + 1 :]
|
||||
|
||||
re_parse = _get_array_parse_regexp(delimiter)
|
||||
for m in re_parse.finditer(data):
|
||||
t = m.group(1)
|
||||
if t == b"{":
|
||||
if stack:
|
||||
stack[-1].append(a)
|
||||
stack.append(a)
|
||||
a = []
|
||||
|
||||
elif t == b"}":
|
||||
if not stack:
|
||||
raise e.DataError("malformed array: unexpected '}'")
|
||||
rv = stack.pop()
|
||||
|
||||
else:
|
||||
if not stack:
|
||||
wat = t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else ""
|
||||
raise e.DataError(f"malformed array: unexpected '{wat}'")
|
||||
if t == b"NULL":
|
||||
v = None
|
||||
else:
|
||||
if t.startswith(b'"'):
|
||||
t = __re_unescape.sub(rb"\1", t[1:-1])
|
||||
v = load(t)
|
||||
|
||||
stack[-1].append(v)
|
||||
|
||||
assert rv is not None
|
||||
return rv
|
||||
|
||||
|
||||
@cache
|
||||
def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
|
||||
"""
|
||||
Return a regexp to tokenize an array representation into item and brackets
|
||||
"""
|
||||
return re.compile(
|
||||
rb"""(?xi)
|
||||
( [{}] # open or closed bracket
|
||||
| " (?: [^"\\] | \\. )* " # or a quoted string
|
||||
| [^"{}%s\\]+ # or an unquoted non-empty string
|
||||
) ,?
|
||||
"""
|
||||
% delimiter
|
||||
)
|
||||
|
||||
|
||||
def _load_binary(data: Buffer, tx: Transformer) -> List[Any]:
|
||||
ndims, hasnull, oid = _unpack_head(data)
|
||||
load = tx.get_loader(oid, PQ_BINARY).load
|
||||
|
||||
if not ndims:
|
||||
return []
|
||||
|
||||
p = 12 + 8 * ndims
|
||||
dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)]
|
||||
nelems = prod(dims)
|
||||
|
||||
out: List[Any] = [None] * nelems
|
||||
for i in range(nelems):
|
||||
size = unpack_len(data, p)[0]
|
||||
p += 4
|
||||
if size == -1:
|
||||
continue
|
||||
out[i] = load(data[p : p + size])
|
||||
p += size
|
||||
|
||||
# fon ndims > 1 we have to aggregate the array into sub-arrays
|
||||
for dim in dims[-1:0:-1]:
|
||||
out = [out[i : i + dim] for i in range(0, len(out), dim)]
|
||||
|
||||
return out
|
||||
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Adapters for booleans.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from .. import postgres
|
||||
from ..pq import Format
|
||||
from ..abc import AdaptContext
|
||||
from ..adapt import Buffer, Dumper, Loader
|
||||
|
||||
|
||||
class BoolDumper(Dumper):
|
||||
oid = postgres.types["bool"].oid
|
||||
|
||||
def dump(self, obj: bool) -> bytes:
|
||||
return b"t" if obj else b"f"
|
||||
|
||||
def quote(self, obj: bool) -> bytes:
|
||||
return b"true" if obj else b"false"
|
||||
|
||||
|
||||
class BoolBinaryDumper(Dumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["bool"].oid
|
||||
|
||||
def dump(self, obj: bool) -> bytes:
|
||||
return b"\x01" if obj else b"\x00"
|
||||
|
||||
|
||||
class BoolLoader(Loader):
|
||||
def load(self, data: Buffer) -> bool:
|
||||
return data == b"t"
|
||||
|
||||
|
||||
class BoolBinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> bool:
|
||||
return data != b"\x00"
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
adapters = context.adapters
|
||||
adapters.register_dumper(bool, BoolDumper)
|
||||
adapters.register_dumper(bool, BoolBinaryDumper)
|
||||
adapters.register_loader("bool", BoolLoader)
|
||||
adapters.register_loader("bool", BoolBinaryLoader)
|
||||
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
Support for composite types adaptation.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import re
|
||||
import struct
|
||||
from collections import namedtuple
|
||||
from typing import Any, Callable, cast, Dict, Iterator, List, Optional
|
||||
from typing import NamedTuple, Sequence, Tuple, Type
|
||||
|
||||
from .. import pq
|
||||
from .. import abc
|
||||
from .. import postgres
|
||||
from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader, Dumper
|
||||
from .._compat import cache
|
||||
from .._struct import pack_len, unpack_len
|
||||
from ..postgres import TEXT_OID
|
||||
from .._typeinfo import CompositeInfo as CompositeInfo # exported here
|
||||
from .._encodings import _as_python_identifier
|
||||
|
||||
_struct_oidlen = struct.Struct("!Ii")
|
||||
_pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack)
|
||||
_unpack_oidlen = cast(
|
||||
Callable[[abc.Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from
|
||||
)
|
||||
|
||||
|
||||
class SequenceDumper(RecursiveDumper):
|
||||
def _dump_sequence(
|
||||
self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes
|
||||
) -> bytes:
|
||||
if not obj:
|
||||
return start + end
|
||||
|
||||
parts: List[abc.Buffer] = [start]
|
||||
|
||||
for item in obj:
|
||||
if item is None:
|
||||
parts.append(sep)
|
||||
continue
|
||||
|
||||
dumper = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
|
||||
ad = dumper.dump(item)
|
||||
if not ad:
|
||||
ad = b'""'
|
||||
elif self._re_needs_quotes.search(ad):
|
||||
ad = b'"' + self._re_esc.sub(rb"\1\1", ad) + b'"'
|
||||
|
||||
parts.append(ad)
|
||||
parts.append(sep)
|
||||
|
||||
parts[-1] = end
|
||||
|
||||
return b"".join(parts)
|
||||
|
||||
_re_needs_quotes = re.compile(rb'[",\\\s()]')
|
||||
_re_esc = re.compile(rb"([\\\"])")
|
||||
|
||||
|
||||
class TupleDumper(SequenceDumper):
|
||||
# Should be this, but it doesn't work
|
||||
# oid = postgres_types["record"].oid
|
||||
|
||||
def dump(self, obj: Tuple[Any, ...]) -> bytes:
|
||||
return self._dump_sequence(obj, b"(", b")", b",")
|
||||
|
||||
|
||||
class TupleBinaryDumper(Dumper):
|
||||
format = pq.Format.BINARY
|
||||
|
||||
# Subclasses must set this info
|
||||
_field_types: Tuple[int, ...]
|
||||
|
||||
def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
|
||||
super().__init__(cls, context)
|
||||
|
||||
# Note: this class is not a RecursiveDumper because it would use the
|
||||
# same Transformer of the context, which would confuse dump_sequence()
|
||||
# in case the composite contains another composite. Make sure to use
|
||||
# a separate Transformer instance instead.
|
||||
self._tx = Transformer(context)
|
||||
self._tx.set_dumper_types(self._field_types, self.format)
|
||||
|
||||
nfields = len(self._field_types)
|
||||
self._formats = (PyFormat.from_pq(self.format),) * nfields
|
||||
|
||||
def dump(self, obj: Tuple[Any, ...]) -> bytearray:
|
||||
out = bytearray(pack_len(len(obj)))
|
||||
adapted = self._tx.dump_sequence(obj, self._formats)
|
||||
for i in range(len(obj)):
|
||||
b = adapted[i]
|
||||
oid = self._field_types[i]
|
||||
if b is not None:
|
||||
out += _pack_oidlen(oid, len(b))
|
||||
out += b
|
||||
else:
|
||||
out += _pack_oidlen(oid, -1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class BaseCompositeLoader(Loader):
|
||||
def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
self._tx = Transformer(context)
|
||||
|
||||
def _parse_record(self, data: abc.Buffer) -> Iterator[Optional[bytes]]:
|
||||
"""
|
||||
Split a non-empty representation of a composite type into components.
|
||||
|
||||
Terminators shouldn't be used in `!data` (so that both record and range
|
||||
representations can be parsed).
|
||||
"""
|
||||
for m in self._re_tokenize.finditer(data):
|
||||
if m.group(1):
|
||||
yield None
|
||||
elif m.group(2) is not None:
|
||||
yield self._re_undouble.sub(rb"\1", m.group(2))
|
||||
else:
|
||||
yield m.group(3)
|
||||
|
||||
# If the final group ended in `,` there is a final NULL in the record
|
||||
# that the regexp couldn't parse.
|
||||
if m and m.group().endswith(b","):
|
||||
yield None
|
||||
|
||||
_re_tokenize = re.compile(
|
||||
rb"""(?x)
|
||||
(,) # an empty token, representing NULL
|
||||
| " ((?: [^"] | "")*) " ,? # or a quoted string
|
||||
| ([^",)]+) ,? # or an unquoted string
|
||||
"""
|
||||
)
|
||||
|
||||
_re_undouble = re.compile(rb'(["\\])\1')
|
||||
|
||||
|
||||
class RecordLoader(BaseCompositeLoader):
|
||||
def load(self, data: abc.Buffer) -> Tuple[Any, ...]:
|
||||
if data == b"()":
|
||||
return ()
|
||||
|
||||
cast = self._tx.get_loader(TEXT_OID, self.format).load
|
||||
return tuple(
|
||||
cast(token) if token is not None else None
|
||||
for token in self._parse_record(data[1:-1])
|
||||
)
|
||||
|
||||
|
||||
class RecordBinaryLoader(Loader):
|
||||
format = pq.Format.BINARY
|
||||
|
||||
def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
self._ctx = context
|
||||
# Cache a transformer for each sequence of oid found.
|
||||
# Usually there will be only one, but if there is more than one
|
||||
# row in the same query (in different columns, or even in different
|
||||
# records), oids might differ and we'd need separate transformers.
|
||||
self._txs: Dict[Tuple[int, ...], abc.Transformer] = {}
|
||||
|
||||
def load(self, data: abc.Buffer) -> Tuple[Any, ...]:
|
||||
nfields = unpack_len(data, 0)[0]
|
||||
offset = 4
|
||||
oids = []
|
||||
record = []
|
||||
for _ in range(nfields):
|
||||
oid, length = _unpack_oidlen(data, offset)
|
||||
offset += 8
|
||||
record.append(data[offset : offset + length] if length != -1 else None)
|
||||
oids.append(oid)
|
||||
if length >= 0:
|
||||
offset += length
|
||||
|
||||
key = tuple(oids)
|
||||
try:
|
||||
tx = self._txs[key]
|
||||
except KeyError:
|
||||
tx = self._txs[key] = Transformer(self._ctx)
|
||||
tx.set_loader_types(oids, self.format)
|
||||
|
||||
return tx.load_sequence(tuple(record))
|
||||
|
||||
|
||||
class CompositeLoader(RecordLoader):
|
||||
factory: Callable[..., Any]
|
||||
fields_types: List[int]
|
||||
_types_set = False
|
||||
|
||||
def load(self, data: abc.Buffer) -> Any:
|
||||
if not self._types_set:
|
||||
self._config_types(data)
|
||||
self._types_set = True
|
||||
|
||||
if data == b"()":
|
||||
return type(self).factory()
|
||||
|
||||
return type(self).factory(
|
||||
*self._tx.load_sequence(tuple(self._parse_record(data[1:-1])))
|
||||
)
|
||||
|
||||
def _config_types(self, data: abc.Buffer) -> None:
|
||||
self._tx.set_loader_types(self.fields_types, self.format)
|
||||
|
||||
|
||||
class CompositeBinaryLoader(RecordBinaryLoader):
|
||||
format = pq.Format.BINARY
|
||||
factory: Callable[..., Any]
|
||||
|
||||
def load(self, data: abc.Buffer) -> Any:
|
||||
r = super().load(data)
|
||||
return type(self).factory(*r)
|
||||
|
||||
|
||||
def register_composite(
|
||||
info: CompositeInfo,
|
||||
context: Optional[abc.AdaptContext] = None,
|
||||
factory: Optional[Callable[..., Any]] = None,
|
||||
) -> None:
|
||||
"""Register the adapters to load and dump a composite type.
|
||||
|
||||
:param info: The object with the information about the composite to register.
|
||||
:param context: The context where to register the adapters. If `!None`,
|
||||
register it globally.
|
||||
:param factory: Callable to convert the sequence of attributes read from
|
||||
the composite into a Python object.
|
||||
|
||||
.. note::
|
||||
|
||||
Registering the adapters doesn't affect objects already created, even
|
||||
if they are children of the registered context. For instance,
|
||||
registering the adapter globally doesn't affect already existing
|
||||
connections.
|
||||
"""
|
||||
|
||||
# A friendly error warning instead of an AttributeError in case fetch()
|
||||
# failed and it wasn't noticed.
|
||||
if not info:
|
||||
raise TypeError("no info passed. Is the requested composite available?")
|
||||
|
||||
# Register arrays and type info
|
||||
info.register(context)
|
||||
|
||||
if not factory:
|
||||
factory = _nt_from_info(info)
|
||||
|
||||
adapters = context.adapters if context else postgres.adapters
|
||||
|
||||
# generate and register a customized text loader
|
||||
loader: Type[BaseCompositeLoader]
|
||||
loader = _make_loader(info.name, tuple(info.field_types), factory)
|
||||
adapters.register_loader(info.oid, loader)
|
||||
|
||||
# generate and register a customized binary loader
|
||||
loader = _make_binary_loader(info.name, factory)
|
||||
adapters.register_loader(info.oid, loader)
|
||||
|
||||
# If the factory is a type, create and register dumpers for it
|
||||
if isinstance(factory, type):
|
||||
dumper: Type[Dumper]
|
||||
dumper = _make_binary_dumper(info.name, info.oid, tuple(info.field_types))
|
||||
adapters.register_dumper(factory, dumper)
|
||||
|
||||
# Default to the text dumper because it is more flexible
|
||||
dumper = _make_dumper(info.name, info.oid)
|
||||
adapters.register_dumper(factory, dumper)
|
||||
|
||||
info.python_type = factory
|
||||
|
||||
|
||||
def register_default_adapters(context: abc.AdaptContext) -> None:
|
||||
adapters = context.adapters
|
||||
adapters.register_dumper(tuple, TupleDumper)
|
||||
adapters.register_loader("record", RecordLoader)
|
||||
adapters.register_loader("record", RecordBinaryLoader)
|
||||
|
||||
|
||||
def _nt_from_info(info: CompositeInfo) -> Type[NamedTuple]:
|
||||
name = _as_python_identifier(info.name)
|
||||
fields = tuple(_as_python_identifier(n) for n in info.field_names)
|
||||
return _make_nt(name, fields)
|
||||
|
||||
|
||||
# Cache all dynamically-generated types to avoid leaks in case the types
|
||||
# cannot be GC'd.
|
||||
|
||||
|
||||
@cache
|
||||
def _make_nt(name: str, fields: Tuple[str, ...]) -> Type[NamedTuple]:
|
||||
return namedtuple(name, fields) # type: ignore[return-value]
|
||||
|
||||
|
||||
@cache
|
||||
def _make_loader(
|
||||
name: str, types: Tuple[int, ...], factory: Callable[..., Any]
|
||||
) -> Type[BaseCompositeLoader]:
|
||||
return type(
|
||||
f"{name.title()}Loader",
|
||||
(CompositeLoader,),
|
||||
{"factory": factory, "fields_types": list(types)},
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def _make_binary_loader(
|
||||
name: str, factory: Callable[..., Any]
|
||||
) -> Type[BaseCompositeLoader]:
|
||||
return type(
|
||||
f"{name.title()}BinaryLoader", (CompositeBinaryLoader,), {"factory": factory}
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def _make_dumper(name: str, oid: int) -> Type[TupleDumper]:
|
||||
return type(f"{name.title()}Dumper", (TupleDumper,), {"oid": oid})
|
||||
|
||||
|
||||
@cache
|
||||
def _make_binary_dumper(
|
||||
name: str, oid: int, field_types: Tuple[int, ...]
|
||||
) -> Type[TupleBinaryDumper]:
|
||||
return type(
|
||||
f"{name.title()}BinaryDumper",
|
||||
(TupleBinaryDumper,),
|
||||
{"oid": oid, "_field_types": field_types},
|
||||
)
|
||||
@@ -0,0 +1,741 @@
|
||||
"""
|
||||
Adapters for date/time types.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import re
|
||||
import struct
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
from typing import Any, Callable, cast, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from .. import postgres
|
||||
from ..pq import Format
|
||||
from .._tz import get_tzinfo
|
||||
from ..abc import AdaptContext, DumperKey
|
||||
from ..adapt import Buffer, Dumper, Loader, PyFormat
|
||||
from ..errors import InterfaceError, DataError
|
||||
from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..connection import BaseConnection
|
||||
|
||||
_struct_timetz = struct.Struct("!qi") # microseconds, sec tz offset
|
||||
_pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack)
|
||||
_unpack_timetz = cast(Callable[[Buffer], Tuple[int, int]], _struct_timetz.unpack)
|
||||
|
||||
_struct_interval = struct.Struct("!qii") # microseconds, days, months
|
||||
_pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack)
|
||||
_unpack_interval = cast(
|
||||
Callable[[Buffer], Tuple[int, int, int]], _struct_interval.unpack
|
||||
)
|
||||
|
||||
utc = timezone.utc
|
||||
_pg_date_epoch_days = date(2000, 1, 1).toordinal()
|
||||
_pg_datetime_epoch = datetime(2000, 1, 1)
|
||||
_pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=utc)
|
||||
_py_date_min_days = date.min.toordinal()
|
||||
|
||||
|
||||
class DateDumper(Dumper):
|
||||
oid = postgres.types["date"].oid
|
||||
|
||||
def dump(self, obj: date) -> bytes:
|
||||
# NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
|
||||
# the YYYY-MM-DD is always understood correctly.
|
||||
return str(obj).encode()
|
||||
|
||||
|
||||
class DateBinaryDumper(Dumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["date"].oid
|
||||
|
||||
def dump(self, obj: date) -> bytes:
|
||||
days = obj.toordinal() - _pg_date_epoch_days
|
||||
return pack_int4(days)
|
||||
|
||||
|
||||
class _BaseTimeDumper(Dumper):
|
||||
def get_key(self, obj: time, format: PyFormat) -> DumperKey:
|
||||
# Use (cls,) to report the need to upgrade to a dumper for timetz (the
|
||||
# Frankenstein of the data types).
|
||||
if not obj.tzinfo:
|
||||
return self.cls
|
||||
else:
|
||||
return (self.cls,)
|
||||
|
||||
def upgrade(self, obj: time, format: PyFormat) -> Dumper:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_offset(self, obj: time) -> timedelta:
|
||||
offset = obj.utcoffset()
|
||||
if offset is None:
|
||||
raise DataError(
|
||||
f"cannot calculate the offset of tzinfo '{obj.tzinfo}' without a date"
|
||||
)
|
||||
return offset
|
||||
|
||||
|
||||
class _BaseTimeTextDumper(_BaseTimeDumper):
|
||||
def dump(self, obj: time) -> bytes:
|
||||
return str(obj).encode()
|
||||
|
||||
|
||||
class TimeDumper(_BaseTimeTextDumper):
|
||||
oid = postgres.types["time"].oid
|
||||
|
||||
def upgrade(self, obj: time, format: PyFormat) -> Dumper:
|
||||
if not obj.tzinfo:
|
||||
return self
|
||||
else:
|
||||
return TimeTzDumper(self.cls)
|
||||
|
||||
|
||||
class TimeTzDumper(_BaseTimeTextDumper):
|
||||
oid = postgres.types["timetz"].oid
|
||||
|
||||
def dump(self, obj: time) -> bytes:
|
||||
self._get_offset(obj)
|
||||
return super().dump(obj)
|
||||
|
||||
|
||||
class TimeBinaryDumper(_BaseTimeDumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["time"].oid
|
||||
|
||||
def dump(self, obj: time) -> bytes:
|
||||
us = obj.microsecond + 1_000_000 * (
|
||||
obj.second + 60 * (obj.minute + 60 * obj.hour)
|
||||
)
|
||||
return pack_int8(us)
|
||||
|
||||
def upgrade(self, obj: time, format: PyFormat) -> Dumper:
|
||||
if not obj.tzinfo:
|
||||
return self
|
||||
else:
|
||||
return TimeTzBinaryDumper(self.cls)
|
||||
|
||||
|
||||
class TimeTzBinaryDumper(_BaseTimeDumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["timetz"].oid
|
||||
|
||||
def dump(self, obj: time) -> bytes:
|
||||
us = obj.microsecond + 1_000_000 * (
|
||||
obj.second + 60 * (obj.minute + 60 * obj.hour)
|
||||
)
|
||||
off = self._get_offset(obj)
|
||||
return _pack_timetz(us, -int(off.total_seconds()))
|
||||
|
||||
|
||||
class _BaseDatetimeDumper(Dumper):
|
||||
def get_key(self, obj: datetime, format: PyFormat) -> DumperKey:
|
||||
# Use (cls,) to report the need to upgrade (downgrade, actually) to a
|
||||
# dumper for naive timestamp.
|
||||
if obj.tzinfo:
|
||||
return self.cls
|
||||
else:
|
||||
return (self.cls,)
|
||||
|
||||
def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _BaseDatetimeTextDumper(_BaseDatetimeDumper):
|
||||
def dump(self, obj: datetime) -> bytes:
|
||||
# NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
|
||||
# the YYYY-MM-DD is always understood correctly.
|
||||
return str(obj).encode()
|
||||
|
||||
|
||||
class DatetimeDumper(_BaseDatetimeTextDumper):
|
||||
oid = postgres.types["timestamptz"].oid
|
||||
|
||||
def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
|
||||
if obj.tzinfo:
|
||||
return self
|
||||
else:
|
||||
return DatetimeNoTzDumper(self.cls)
|
||||
|
||||
|
||||
class DatetimeNoTzDumper(_BaseDatetimeTextDumper):
|
||||
oid = postgres.types["timestamp"].oid
|
||||
|
||||
|
||||
class DatetimeBinaryDumper(_BaseDatetimeDumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["timestamptz"].oid
|
||||
|
||||
def dump(self, obj: datetime) -> bytes:
|
||||
delta = obj - _pg_datetimetz_epoch
|
||||
micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
|
||||
return pack_int8(micros)
|
||||
|
||||
def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
|
||||
if obj.tzinfo:
|
||||
return self
|
||||
else:
|
||||
return DatetimeNoTzBinaryDumper(self.cls)
|
||||
|
||||
|
||||
class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["timestamp"].oid
|
||||
|
||||
def dump(self, obj: datetime) -> bytes:
|
||||
delta = obj - _pg_datetime_epoch
|
||||
micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
|
||||
return pack_int8(micros)
|
||||
|
||||
|
||||
class TimedeltaDumper(Dumper):
|
||||
oid = postgres.types["interval"].oid
|
||||
|
||||
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
|
||||
super().__init__(cls, context)
|
||||
if self.connection:
|
||||
if (
|
||||
self.connection.pgconn.parameter_status(b"IntervalStyle")
|
||||
== b"sql_standard"
|
||||
):
|
||||
setattr(self, "dump", self._dump_sql)
|
||||
|
||||
def dump(self, obj: timedelta) -> bytes:
|
||||
# The comma is parsed ok by PostgreSQL but it's not documented
|
||||
# and it seems brittle to rely on it. CRDB doesn't consume it well.
|
||||
return str(obj).encode().replace(b",", b"")
|
||||
|
||||
def _dump_sql(self, obj: timedelta) -> bytes:
|
||||
# sql_standard format needs explicit signs
|
||||
# otherwise -1 day 1 sec will mean -1 sec
|
||||
return b"%+d day %+d second %+d microsecond" % (
|
||||
obj.days,
|
||||
obj.seconds,
|
||||
obj.microseconds,
|
||||
)
|
||||
|
||||
|
||||
class TimedeltaBinaryDumper(Dumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["interval"].oid
|
||||
|
||||
def dump(self, obj: timedelta) -> bytes:
|
||||
micros = 1_000_000 * obj.seconds + obj.microseconds
|
||||
return _pack_interval(micros, obj.days, 0)
|
||||
|
||||
|
||||
class DateLoader(Loader):
|
||||
_ORDER_YMD = 0
|
||||
_ORDER_DMY = 1
|
||||
_ORDER_MDY = 2
|
||||
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
ds = _get_datestyle(self.connection)
|
||||
if ds.startswith(b"I"): # ISO
|
||||
self._order = self._ORDER_YMD
|
||||
elif ds.startswith(b"G"): # German
|
||||
self._order = self._ORDER_DMY
|
||||
elif ds.startswith(b"S") or ds.startswith(b"P"): # SQL or Postgres
|
||||
self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
|
||||
else:
|
||||
raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
|
||||
|
||||
def load(self, data: Buffer) -> date:
|
||||
if self._order == self._ORDER_YMD:
|
||||
ye = data[:4]
|
||||
mo = data[5:7]
|
||||
da = data[8:]
|
||||
elif self._order == self._ORDER_DMY:
|
||||
da = data[:2]
|
||||
mo = data[3:5]
|
||||
ye = data[6:]
|
||||
else:
|
||||
mo = data[:2]
|
||||
da = data[3:5]
|
||||
ye = data[6:]
|
||||
|
||||
try:
|
||||
return date(int(ye), int(mo), int(da))
|
||||
except ValueError as ex:
|
||||
s = bytes(data).decode("utf8", "replace")
|
||||
if s == "infinity" or (s and len(s.split()[0]) > 10):
|
||||
raise DataError(f"date too large (after year 10K): {s!r}") from None
|
||||
elif s == "-infinity" or "BC" in s:
|
||||
raise DataError(f"date too small (before year 1): {s!r}") from None
|
||||
else:
|
||||
raise DataError(f"can't parse date {s!r}: {ex}") from None
|
||||
|
||||
|
||||
class DateBinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> date:
|
||||
days = unpack_int4(data)[0] + _pg_date_epoch_days
|
||||
try:
|
||||
return date.fromordinal(days)
|
||||
except (ValueError, OverflowError):
|
||||
if days < _py_date_min_days:
|
||||
raise DataError("date too small (before year 1)") from None
|
||||
else:
|
||||
raise DataError("date too large (after year 10K)") from None
|
||||
|
||||
|
||||
class TimeLoader(Loader):
|
||||
_re_format = re.compile(rb"^(\d+):(\d+):(\d+)(?:\.(\d+))?")
|
||||
|
||||
def load(self, data: Buffer) -> time:
|
||||
m = self._re_format.match(data)
|
||||
if not m:
|
||||
s = bytes(data).decode("utf8", "replace")
|
||||
raise DataError(f"can't parse time {s!r}")
|
||||
|
||||
ho, mi, se, fr = m.groups()
|
||||
|
||||
# Pad the fraction of second to get micros
|
||||
if fr:
|
||||
us = int(fr)
|
||||
if len(fr) < 6:
|
||||
us *= _uspad[len(fr)]
|
||||
else:
|
||||
us = 0
|
||||
|
||||
try:
|
||||
return time(int(ho), int(mi), int(se), us)
|
||||
except ValueError as e:
|
||||
s = bytes(data).decode("utf8", "replace")
|
||||
raise DataError(f"can't parse time {s!r}: {e}") from None
|
||||
|
||||
|
||||
class TimeBinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> time:
|
||||
val = unpack_int8(data)[0]
|
||||
val, us = divmod(val, 1_000_000)
|
||||
val, s = divmod(val, 60)
|
||||
h, m = divmod(val, 60)
|
||||
try:
|
||||
return time(h, m, s, us)
|
||||
except ValueError:
|
||||
raise DataError(f"time not supported by Python: hour={h}") from None
|
||||
|
||||
|
||||
class TimetzLoader(Loader):
|
||||
_re_format = re.compile(
|
||||
rb"""(?ix)
|
||||
^
|
||||
(\d+) : (\d+) : (\d+) (?: \. (\d+) )? # Time and micros
|
||||
([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone
|
||||
$
|
||||
"""
|
||||
)
|
||||
|
||||
def load(self, data: Buffer) -> time:
|
||||
m = self._re_format.match(data)
|
||||
if not m:
|
||||
s = bytes(data).decode("utf8", "replace")
|
||||
raise DataError(f"can't parse timetz {s!r}")
|
||||
|
||||
ho, mi, se, fr, sgn, oh, om, os = m.groups()
|
||||
|
||||
# Pad the fraction of second to get the micros
|
||||
if fr:
|
||||
us = int(fr)
|
||||
if len(fr) < 6:
|
||||
us *= _uspad[len(fr)]
|
||||
else:
|
||||
us = 0
|
||||
|
||||
# Calculate timezone
|
||||
off = 60 * 60 * int(oh)
|
||||
if om:
|
||||
off += 60 * int(om)
|
||||
if os:
|
||||
off += int(os)
|
||||
tz = timezone(timedelta(0, off if sgn == b"+" else -off))
|
||||
|
||||
try:
|
||||
return time(int(ho), int(mi), int(se), us, tz)
|
||||
except ValueError as e:
|
||||
s = bytes(data).decode("utf8", "replace")
|
||||
raise DataError(f"can't parse timetz {s!r}: {e}") from None
|
||||
|
||||
|
||||
class TimetzBinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> time:
|
||||
val, off = _unpack_timetz(data)
|
||||
|
||||
val, us = divmod(val, 1_000_000)
|
||||
val, s = divmod(val, 60)
|
||||
h, m = divmod(val, 60)
|
||||
|
||||
try:
|
||||
return time(h, m, s, us, timezone(timedelta(seconds=-off)))
|
||||
except ValueError:
|
||||
raise DataError(f"time not supported by Python: hour={h}") from None
|
||||
|
||||
|
||||
class TimestampLoader(Loader):
|
||||
_re_format = re.compile(
|
||||
rb"""(?ix)
|
||||
^
|
||||
(\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date
|
||||
(?: T | [^a-z0-9] ) # Separator, including T
|
||||
(\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
|
||||
(?: \.(\d+) )? # Micros
|
||||
$
|
||||
"""
|
||||
)
|
||||
_re_format_pg = re.compile(
|
||||
rb"""(?ix)
|
||||
^
|
||||
[a-z]+ [^a-z0-9] # DoW, separator
|
||||
(\d+|[a-z]+) [^a-z0-9] # Month or day
|
||||
(\d+|[a-z]+) [^a-z0-9] # Month or day
|
||||
(\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
|
||||
(?: \.(\d+) )? # Micros
|
||||
[^a-z0-9] (\d+) # Year
|
||||
$
|
||||
"""
|
||||
)
|
||||
|
||||
_ORDER_YMD = 0
|
||||
_ORDER_DMY = 1
|
||||
_ORDER_MDY = 2
|
||||
_ORDER_PGDM = 3
|
||||
_ORDER_PGMD = 4
|
||||
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
|
||||
ds = _get_datestyle(self.connection)
|
||||
if ds.startswith(b"I"): # ISO
|
||||
self._order = self._ORDER_YMD
|
||||
elif ds.startswith(b"G"): # German
|
||||
self._order = self._ORDER_DMY
|
||||
elif ds.startswith(b"S"): # SQL
|
||||
self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
|
||||
elif ds.startswith(b"P"): # Postgres
|
||||
self._order = self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD
|
||||
self._re_format = self._re_format_pg
|
||||
else:
|
||||
raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
|
||||
|
||||
def load(self, data: Buffer) -> datetime:
|
||||
m = self._re_format.match(data)
|
||||
if not m:
|
||||
raise _get_timestamp_load_error(self.connection, data) from None
|
||||
|
||||
if self._order == self._ORDER_YMD:
|
||||
ye, mo, da, ho, mi, se, fr = m.groups()
|
||||
imo = int(mo)
|
||||
elif self._order == self._ORDER_DMY:
|
||||
da, mo, ye, ho, mi, se, fr = m.groups()
|
||||
imo = int(mo)
|
||||
elif self._order == self._ORDER_MDY:
|
||||
mo, da, ye, ho, mi, se, fr = m.groups()
|
||||
imo = int(mo)
|
||||
else:
|
||||
if self._order == self._ORDER_PGDM:
|
||||
da, mo, ho, mi, se, fr, ye = m.groups()
|
||||
else:
|
||||
mo, da, ho, mi, se, fr, ye = m.groups()
|
||||
try:
|
||||
imo = _month_abbr[mo]
|
||||
except KeyError:
|
||||
s = mo.decode("utf8", "replace")
|
||||
raise DataError(f"can't parse month: {s!r}") from None
|
||||
|
||||
# Pad the fraction of second to get the micros
|
||||
if fr:
|
||||
us = int(fr)
|
||||
if len(fr) < 6:
|
||||
us *= _uspad[len(fr)]
|
||||
else:
|
||||
us = 0
|
||||
|
||||
try:
|
||||
return datetime(int(ye), imo, int(da), int(ho), int(mi), int(se), us)
|
||||
except ValueError as ex:
|
||||
raise _get_timestamp_load_error(self.connection, data, ex) from None
|
||||
|
||||
|
||||
class TimestampBinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> datetime:
|
||||
micros = unpack_int8(data)[0]
|
||||
try:
|
||||
return _pg_datetime_epoch + timedelta(microseconds=micros)
|
||||
except OverflowError:
|
||||
if micros <= 0:
|
||||
raise DataError("timestamp too small (before year 1)") from None
|
||||
else:
|
||||
raise DataError("timestamp too large (after year 10K)") from None
|
||||
|
||||
|
||||
class TimestamptzLoader(Loader):
|
||||
_re_format = re.compile(
|
||||
rb"""(?ix)
|
||||
^
|
||||
(\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date
|
||||
(?: T | [^a-z0-9] ) # Separator, including T
|
||||
(\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
|
||||
(?: \.(\d+) )? # Micros
|
||||
([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone
|
||||
$
|
||||
"""
|
||||
)
|
||||
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
|
||||
|
||||
ds = _get_datestyle(self.connection)
|
||||
if not ds.startswith(b"I"): # not ISO
|
||||
setattr(self, "load", self._load_notimpl)
|
||||
|
||||
def load(self, data: Buffer) -> datetime:
|
||||
m = self._re_format.match(data)
|
||||
if not m:
|
||||
raise _get_timestamp_load_error(self.connection, data) from None
|
||||
|
||||
ye, mo, da, ho, mi, se, fr, sgn, oh, om, os = m.groups()
|
||||
|
||||
# Pad the fraction of second to get the micros
|
||||
if fr:
|
||||
us = int(fr)
|
||||
if len(fr) < 6:
|
||||
us *= _uspad[len(fr)]
|
||||
else:
|
||||
us = 0
|
||||
|
||||
# Calculate timezone offset
|
||||
soff = 60 * 60 * int(oh)
|
||||
if om:
|
||||
soff += 60 * int(om)
|
||||
if os:
|
||||
soff += int(os)
|
||||
tzoff = timedelta(0, soff if sgn == b"+" else -soff)
|
||||
|
||||
# The return value is a datetime with the timezone of the connection
|
||||
# (in order to be consistent with the binary loader, which is the only
|
||||
# thing it can return). So create a temporary datetime object, in utc,
|
||||
# shift it by the offset parsed from the timestamp, and then move it to
|
||||
# the connection timezone.
|
||||
dt = None
|
||||
ex: Exception
|
||||
try:
|
||||
dt = datetime(int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc)
|
||||
return (dt - tzoff).astimezone(self._timezone)
|
||||
except OverflowError as e:
|
||||
# If we have created the temporary 'dt' it means that we have a
|
||||
# datetime close to max, the shift pushed it past max, overflowing.
|
||||
# In this case return the datetime in a fixed offset timezone.
|
||||
if dt is not None:
|
||||
return dt.replace(tzinfo=timezone(tzoff))
|
||||
else:
|
||||
ex = e
|
||||
except ValueError as e:
|
||||
ex = e
|
||||
|
||||
raise _get_timestamp_load_error(self.connection, data, ex) from None
|
||||
|
||||
def _load_notimpl(self, data: Buffer) -> datetime:
|
||||
s = bytes(data).decode("utf8", "replace")
|
||||
ds = _get_datestyle(self.connection).decode("ascii")
|
||||
raise NotImplementedError(
|
||||
f"can't parse timestamptz with DateStyle {ds!r}: {s!r}"
|
||||
)
|
||||
|
||||
|
||||
class TimestamptzBinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
|
||||
|
||||
def load(self, data: Buffer) -> datetime:
|
||||
micros = unpack_int8(data)[0]
|
||||
try:
|
||||
ts = _pg_datetimetz_epoch + timedelta(microseconds=micros)
|
||||
return ts.astimezone(self._timezone)
|
||||
except OverflowError:
|
||||
# If we were asked about a timestamp which would overflow in UTC,
|
||||
# but not in the desired timezone (e.g. datetime.max at Chicago
|
||||
# timezone) we can still save the day by shifting the value by the
|
||||
# timezone offset and then replacing the timezone.
|
||||
if self._timezone:
|
||||
utcoff = self._timezone.utcoffset(
|
||||
datetime.min if micros < 0 else datetime.max
|
||||
)
|
||||
if utcoff:
|
||||
usoff = 1_000_000 * int(utcoff.total_seconds())
|
||||
try:
|
||||
ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff)
|
||||
except OverflowError:
|
||||
pass # will raise downstream
|
||||
else:
|
||||
return ts.replace(tzinfo=self._timezone)
|
||||
|
||||
if micros <= 0:
|
||||
raise DataError("timestamp too small (before year 1)") from None
|
||||
else:
|
||||
raise DataError("timestamp too large (after year 10K)") from None
|
||||
|
||||
|
||||
class IntervalLoader(Loader):
|
||||
_re_interval = re.compile(
|
||||
rb"""
|
||||
(?: ([-+]?\d+) \s+ years? \s* )? # Years
|
||||
(?: ([-+]?\d+) \s+ mons? \s* )? # Months
|
||||
(?: ([-+]?\d+) \s+ days? \s* )? # Days
|
||||
(?: ([-+])? (\d+) : (\d+) : (\d+ (?:\.\d+)?) # Time
|
||||
)?
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
if self.connection:
|
||||
ints = self.connection.pgconn.parameter_status(b"IntervalStyle")
|
||||
if ints != b"postgres":
|
||||
setattr(self, "load", self._load_notimpl)
|
||||
|
||||
def load(self, data: Buffer) -> timedelta:
|
||||
m = self._re_interval.match(data)
|
||||
if not m:
|
||||
s = bytes(data).decode("utf8", "replace")
|
||||
raise DataError(f"can't parse interval {s!r}")
|
||||
|
||||
ye, mo, da, sgn, ho, mi, se = m.groups()
|
||||
days = 0
|
||||
seconds = 0.0
|
||||
|
||||
if ye:
|
||||
days += 365 * int(ye)
|
||||
if mo:
|
||||
days += 30 * int(mo)
|
||||
if da:
|
||||
days += int(da)
|
||||
|
||||
if ho:
|
||||
seconds = 3600 * int(ho) + 60 * int(mi) + float(se)
|
||||
if sgn == b"-":
|
||||
seconds = -seconds
|
||||
|
||||
try:
|
||||
return timedelta(days=days, seconds=seconds)
|
||||
except OverflowError as e:
|
||||
s = bytes(data).decode("utf8", "replace")
|
||||
raise DataError(f"can't parse interval {s!r}: {e}") from None
|
||||
|
||||
def _load_notimpl(self, data: Buffer) -> timedelta:
|
||||
s = bytes(data).decode("utf8", "replace")
|
||||
ints = (
|
||||
self.connection
|
||||
and self.connection.pgconn.parameter_status(b"IntervalStyle")
|
||||
or b"unknown"
|
||||
).decode("utf8", "replace")
|
||||
raise NotImplementedError(
|
||||
f"can't parse interval with IntervalStyle {ints}: {s!r}"
|
||||
)
|
||||
|
||||
|
||||
class IntervalBinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> timedelta:
|
||||
micros, days, months = _unpack_interval(data)
|
||||
if months > 0:
|
||||
years, months = divmod(months, 12)
|
||||
days = days + 30 * months + 365 * years
|
||||
elif months < 0:
|
||||
years, months = divmod(-months, 12)
|
||||
days = days - 30 * months - 365 * years
|
||||
|
||||
try:
|
||||
return timedelta(days=days, microseconds=micros)
|
||||
except OverflowError as e:
|
||||
raise DataError(f"can't parse interval: {e}") from None
|
||||
|
||||
|
||||
def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes:
|
||||
if conn:
|
||||
ds = conn.pgconn.parameter_status(b"DateStyle")
|
||||
if ds:
|
||||
return ds
|
||||
|
||||
return b"ISO, DMY"
|
||||
|
||||
|
||||
def _get_timestamp_load_error(
|
||||
conn: Optional["BaseConnection[Any]"], data: Buffer, ex: Optional[Exception] = None
|
||||
) -> Exception:
|
||||
s = bytes(data).decode("utf8", "replace")
|
||||
|
||||
def is_overflow(s: str) -> bool:
|
||||
if not s:
|
||||
return False
|
||||
|
||||
ds = _get_datestyle(conn)
|
||||
if not ds.startswith(b"P"): # Postgres
|
||||
return len(s.split()[0]) > 10 # date is first token
|
||||
else:
|
||||
return len(s.split()[-1]) > 4 # year is last token
|
||||
|
||||
if s == "-infinity" or s.endswith("BC"):
|
||||
return DataError("timestamp too small (before year 1): {s!r}")
|
||||
elif s == "infinity" or is_overflow(s):
|
||||
return DataError(f"timestamp too large (after year 10K): {s!r}")
|
||||
else:
|
||||
return DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}")
|
||||
|
||||
|
||||
_month_abbr = {
|
||||
n: i
|
||||
for i, n in enumerate(b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1)
|
||||
}
|
||||
|
||||
# Pad to get microseconds from a fraction of seconds
|
||||
_uspad = [0, 100_000, 10_000, 1_000, 100, 10, 1]
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
adapters = context.adapters
|
||||
adapters.register_dumper("datetime.date", DateDumper)
|
||||
adapters.register_dumper("datetime.date", DateBinaryDumper)
|
||||
|
||||
# first register dumpers for 'timetz' oid, then the proper ones on time type.
|
||||
adapters.register_dumper("datetime.time", TimeTzDumper)
|
||||
adapters.register_dumper("datetime.time", TimeTzBinaryDumper)
|
||||
adapters.register_dumper("datetime.time", TimeDumper)
|
||||
adapters.register_dumper("datetime.time", TimeBinaryDumper)
|
||||
|
||||
# first register dumpers for 'timestamp' oid, then the proper ones
|
||||
# on the datetime type.
|
||||
adapters.register_dumper("datetime.datetime", DatetimeNoTzDumper)
|
||||
adapters.register_dumper("datetime.datetime", DatetimeNoTzBinaryDumper)
|
||||
adapters.register_dumper("datetime.datetime", DatetimeDumper)
|
||||
adapters.register_dumper("datetime.datetime", DatetimeBinaryDumper)
|
||||
|
||||
adapters.register_dumper("datetime.timedelta", TimedeltaDumper)
|
||||
adapters.register_dumper("datetime.timedelta", TimedeltaBinaryDumper)
|
||||
|
||||
adapters.register_loader("date", DateLoader)
|
||||
adapters.register_loader("date", DateBinaryLoader)
|
||||
adapters.register_loader("time", TimeLoader)
|
||||
adapters.register_loader("time", TimeBinaryLoader)
|
||||
adapters.register_loader("timetz", TimetzLoader)
|
||||
adapters.register_loader("timetz", TimetzBinaryLoader)
|
||||
adapters.register_loader("timestamp", TimestampLoader)
|
||||
adapters.register_loader("timestamp", TimestampBinaryLoader)
|
||||
adapters.register_loader("timestamptz", TimestamptzLoader)
|
||||
adapters.register_loader("timestamptz", TimestamptzBinaryLoader)
|
||||
adapters.register_loader("interval", IntervalLoader)
|
||||
adapters.register_loader("interval", IntervalBinaryLoader)
|
||||
220
srcs/.venv/lib/python3.11/site-packages/psycopg/types/enum.py
Normal file
220
srcs/.venv/lib/python3.11/site-packages/psycopg/types/enum.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Adapters for the enum type.
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Dict, Generic, Optional, Mapping, Sequence
|
||||
from typing import Tuple, Type, TypeVar, Union, cast
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .. import postgres
|
||||
from .. import errors as e
|
||||
from ..pq import Format
|
||||
from ..abc import AdaptContext
|
||||
from ..adapt import Buffer, Dumper, Loader
|
||||
from .._compat import cache
|
||||
from .._encodings import conn_encoding
|
||||
from .._typeinfo import EnumInfo as EnumInfo # exported here
|
||||
|
||||
E = TypeVar("E", bound=Enum)
|
||||
|
||||
EnumDumpMap: TypeAlias = Dict[E, bytes]
|
||||
EnumLoadMap: TypeAlias = Dict[bytes, E]
|
||||
EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None]
|
||||
|
||||
# Hashable versions
|
||||
_HEnumDumpMap: TypeAlias = Tuple[Tuple[E, bytes], ...]
|
||||
_HEnumLoadMap: TypeAlias = Tuple[Tuple[bytes, E], ...]
|
||||
|
||||
TEXT = Format.TEXT
|
||||
BINARY = Format.BINARY
|
||||
|
||||
|
||||
class _BaseEnumLoader(Loader, Generic[E]):
|
||||
"""
|
||||
Loader for a specific Enum class
|
||||
"""
|
||||
|
||||
enum: Type[E]
|
||||
_load_map: EnumLoadMap[E]
|
||||
|
||||
def load(self, data: Buffer) -> E:
|
||||
if not isinstance(data, bytes):
|
||||
data = bytes(data)
|
||||
|
||||
try:
|
||||
return self._load_map[data]
|
||||
except KeyError:
|
||||
enc = conn_encoding(self.connection)
|
||||
label = data.decode(enc, "replace")
|
||||
raise e.DataError(
|
||||
f"bad member for enum {self.enum.__qualname__}: {label!r}"
|
||||
)
|
||||
|
||||
|
||||
class _BaseEnumDumper(Dumper, Generic[E]):
|
||||
"""
|
||||
Dumper for a specific Enum class
|
||||
"""
|
||||
|
||||
enum: Type[E]
|
||||
_dump_map: EnumDumpMap[E]
|
||||
|
||||
def dump(self, value: E) -> Buffer:
|
||||
return self._dump_map[value]
|
||||
|
||||
|
||||
class EnumDumper(Dumper):
|
||||
"""
|
||||
Dumper for a generic Enum class
|
||||
"""
|
||||
|
||||
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
|
||||
super().__init__(cls, context)
|
||||
self._encoding = conn_encoding(self.connection)
|
||||
|
||||
def dump(self, value: E) -> Buffer:
|
||||
return value.name.encode(self._encoding)
|
||||
|
||||
|
||||
class EnumBinaryDumper(EnumDumper):
|
||||
format = BINARY
|
||||
|
||||
|
||||
def register_enum(
|
||||
info: EnumInfo,
|
||||
context: Optional[AdaptContext] = None,
|
||||
enum: Optional[Type[E]] = None,
|
||||
*,
|
||||
mapping: EnumMapping[E] = None,
|
||||
) -> None:
|
||||
"""Register the adapters to load and dump a enum type.
|
||||
|
||||
:param info: The object with the information about the enum to register.
|
||||
:param context: The context where to register the adapters. If `!None`,
|
||||
register it globally.
|
||||
:param enum: Python enum type matching to the PostgreSQL one. If `!None`,
|
||||
a new enum will be generated and exposed as `EnumInfo.enum`.
|
||||
:param mapping: Override the mapping between `!enum` members and `!info`
|
||||
labels.
|
||||
"""
|
||||
|
||||
if not info:
|
||||
raise TypeError("no info passed. Is the requested enum available?")
|
||||
|
||||
if enum is None:
|
||||
enum = cast(Type[E], _make_enum(info.name, tuple(info.labels)))
|
||||
|
||||
info.enum = enum
|
||||
adapters = context.adapters if context else postgres.adapters
|
||||
info.register(context)
|
||||
|
||||
load_map = _make_load_map(info, enum, mapping, context)
|
||||
|
||||
loader = _make_loader(info.name, info.enum, load_map)
|
||||
adapters.register_loader(info.oid, loader)
|
||||
|
||||
loader = _make_binary_loader(info.name, info.enum, load_map)
|
||||
adapters.register_loader(info.oid, loader)
|
||||
|
||||
dump_map = _make_dump_map(info, enum, mapping, context)
|
||||
|
||||
dumper = _make_dumper(info.enum, info.oid, dump_map)
|
||||
adapters.register_dumper(info.enum, dumper)
|
||||
|
||||
dumper = _make_binary_dumper(info.enum, info.oid, dump_map)
|
||||
adapters.register_dumper(info.enum, dumper)
|
||||
|
||||
|
||||
# Cache all dynamically-generated types to avoid leaks in case the types
|
||||
# cannot be GC'd.
|
||||
|
||||
|
||||
@cache
|
||||
def _make_enum(name: str, labels: Tuple[str, ...]) -> Enum:
|
||||
return Enum(name.title(), labels, module=__name__)
|
||||
|
||||
|
||||
@cache
|
||||
def _make_loader(
|
||||
name: str, enum: Type[Enum], load_map: _HEnumLoadMap[E]
|
||||
) -> Type[_BaseEnumLoader[E]]:
|
||||
attribs = {"enum": enum, "_load_map": dict(load_map)}
|
||||
return type(f"{name.title()}Loader", (_BaseEnumLoader,), attribs)
|
||||
|
||||
|
||||
@cache
|
||||
def _make_binary_loader(
|
||||
name: str, enum: Type[Enum], load_map: _HEnumLoadMap[E]
|
||||
) -> Type[_BaseEnumLoader[E]]:
|
||||
attribs = {"enum": enum, "_load_map": dict(load_map), "format": BINARY}
|
||||
return type(f"{name.title()}BinaryLoader", (_BaseEnumLoader,), attribs)
|
||||
|
||||
|
||||
@cache
|
||||
def _make_dumper(
|
||||
enum: Type[Enum], oid: int, dump_map: _HEnumDumpMap[E]
|
||||
) -> Type[_BaseEnumDumper[E]]:
|
||||
attribs = {"enum": enum, "oid": oid, "_dump_map": dict(dump_map)}
|
||||
return type(f"{enum.__name__}Dumper", (_BaseEnumDumper,), attribs)
|
||||
|
||||
|
||||
@cache
|
||||
def _make_binary_dumper(
|
||||
enum: Type[Enum], oid: int, dump_map: _HEnumDumpMap[E]
|
||||
) -> Type[_BaseEnumDumper[E]]:
|
||||
attribs = {"enum": enum, "oid": oid, "_dump_map": dict(dump_map), "format": BINARY}
|
||||
return type(f"{enum.__name__}BinaryDumper", (_BaseEnumDumper,), attribs)
|
||||
|
||||
|
||||
def _make_load_map(
|
||||
info: EnumInfo,
|
||||
enum: Type[E],
|
||||
mapping: EnumMapping[E],
|
||||
context: Optional[AdaptContext],
|
||||
) -> _HEnumLoadMap[E]:
|
||||
enc = conn_encoding(context.connection if context else None)
|
||||
rv = []
|
||||
for label in info.labels:
|
||||
try:
|
||||
member = enum[label]
|
||||
except KeyError:
|
||||
# tolerate a missing enum, assuming it won't be used. If it is we
|
||||
# will get a DataError on fetch.
|
||||
pass
|
||||
else:
|
||||
rv.append((label.encode(enc), member))
|
||||
|
||||
if mapping:
|
||||
if isinstance(mapping, Mapping):
|
||||
mapping = list(mapping.items())
|
||||
|
||||
for member, label in mapping:
|
||||
rv.append((label.encode(enc), member))
|
||||
|
||||
return tuple(rv)
|
||||
|
||||
|
||||
def _make_dump_map(
|
||||
info: EnumInfo,
|
||||
enum: Type[E],
|
||||
mapping: EnumMapping[E],
|
||||
context: Optional[AdaptContext],
|
||||
) -> _HEnumDumpMap[E]:
|
||||
enc = conn_encoding(context.connection if context else None)
|
||||
rv = []
|
||||
for member in enum:
|
||||
rv.append((member, member.name.encode(enc)))
|
||||
|
||||
if mapping:
|
||||
if isinstance(mapping, Mapping):
|
||||
mapping = list(mapping.items())
|
||||
|
||||
for member, label in mapping:
|
||||
rv.append((member, label.encode(enc)))
|
||||
|
||||
return tuple(rv)
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
context.adapters.register_dumper(Enum, EnumBinaryDumper)
|
||||
context.adapters.register_dumper(Enum, EnumDumper)
|
||||
146
srcs/.venv/lib/python3.11/site-packages/psycopg/types/hstore.py
Normal file
146
srcs/.venv/lib/python3.11/site-packages/psycopg/types/hstore.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Dict to hstore adaptation
|
||||
"""
|
||||
|
||||
# Copyright (C) 2021 The Psycopg Team
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Optional, Type
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .. import errors as e
|
||||
from .. import postgres
|
||||
from ..abc import Buffer, AdaptContext
|
||||
from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
|
||||
from .._compat import cache
|
||||
from ..postgres import TEXT_OID
|
||||
from .._typeinfo import TypeInfo
|
||||
|
||||
_re_escape = re.compile(r'(["\\])')
|
||||
_re_unescape = re.compile(r"\\(.)")
|
||||
|
||||
_re_hstore = re.compile(
|
||||
r"""
|
||||
# hstore key:
|
||||
# a string of normal or escaped chars
|
||||
"((?: [^"\\] | \\. )*)"
|
||||
\s*=>\s* # hstore value
|
||||
(?:
|
||||
NULL # the value can be null - not caught
|
||||
# or a quoted string like the key
|
||||
| "((?: [^"\\] | \\. )*)"
|
||||
)
|
||||
(?:\s*,\s*|$) # pairs separated by comma or end of string.
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
|
||||
|
||||
Hstore: TypeAlias = Dict[str, Optional[str]]
|
||||
|
||||
|
||||
class BaseHstoreDumper(RecursiveDumper):
|
||||
def dump(self, obj: Hstore) -> Buffer:
|
||||
if not obj:
|
||||
return b""
|
||||
|
||||
tokens: List[str] = []
|
||||
|
||||
def add_token(s: str) -> None:
|
||||
tokens.append('"')
|
||||
tokens.append(_re_escape.sub(r"\\\1", s))
|
||||
tokens.append('"')
|
||||
|
||||
for k, v in obj.items():
|
||||
if not isinstance(k, str):
|
||||
raise e.DataError("hstore keys can only be strings")
|
||||
add_token(k)
|
||||
|
||||
tokens.append("=>")
|
||||
|
||||
if v is None:
|
||||
tokens.append("NULL")
|
||||
elif not isinstance(v, str):
|
||||
raise e.DataError("hstore keys can only be strings")
|
||||
else:
|
||||
add_token(v)
|
||||
|
||||
tokens.append(",")
|
||||
|
||||
del tokens[-1]
|
||||
data = "".join(tokens)
|
||||
dumper = self._tx.get_dumper(data, PyFormat.TEXT)
|
||||
return dumper.dump(data)
|
||||
|
||||
|
||||
class HstoreLoader(RecursiveLoader):
|
||||
def load(self, data: Buffer) -> Hstore:
|
||||
loader = self._tx.get_loader(TEXT_OID, self.format)
|
||||
s: str = loader.load(data)
|
||||
|
||||
rv: Hstore = {}
|
||||
start = 0
|
||||
for m in _re_hstore.finditer(s):
|
||||
if m is None or m.start() != start:
|
||||
raise e.DataError(f"error parsing hstore pair at char {start}")
|
||||
k = _re_unescape.sub(r"\1", m.group(1))
|
||||
v = m.group(2)
|
||||
if v is not None:
|
||||
v = _re_unescape.sub(r"\1", v)
|
||||
|
||||
rv[k] = v
|
||||
start = m.end()
|
||||
|
||||
if start < len(s):
|
||||
raise e.DataError(f"error parsing hstore: unparsed data after char {start}")
|
||||
|
||||
return rv
|
||||
|
||||
|
||||
def register_hstore(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
|
||||
"""Register the adapters to load and dump hstore.
|
||||
|
||||
:param info: The object with the information about the hstore type.
|
||||
:param context: The context where to register the adapters. If `!None`,
|
||||
register it globally.
|
||||
|
||||
.. note::
|
||||
|
||||
Registering the adapters doesn't affect objects already created, even
|
||||
if they are children of the registered context. For instance,
|
||||
registering the adapter globally doesn't affect already existing
|
||||
connections.
|
||||
"""
|
||||
# A friendly error warning instead of an AttributeError in case fetch()
|
||||
# failed and it wasn't noticed.
|
||||
if not info:
|
||||
raise TypeError("no info passed. Is the 'hstore' extension loaded?")
|
||||
|
||||
# Register arrays and type info
|
||||
info.register(context)
|
||||
|
||||
adapters = context.adapters if context else postgres.adapters
|
||||
|
||||
# Generate and register a customized text dumper
|
||||
adapters.register_dumper(dict, _make_hstore_dumper(info.oid))
|
||||
|
||||
# register the text loader on the oid
|
||||
adapters.register_loader(info.oid, HstoreLoader)
|
||||
|
||||
|
||||
# Cache all dynamically-generated types to avoid leaks in case the types
|
||||
# cannot be GC'd.
|
||||
|
||||
|
||||
@cache
|
||||
def _make_hstore_dumper(oid_in: int) -> Type[BaseHstoreDumper]:
|
||||
"""
|
||||
Return an hstore dumper class configured using `oid_in`.
|
||||
|
||||
Avoid to create new classes if the oid configured is the same.
|
||||
"""
|
||||
|
||||
class HstoreDumper(BaseHstoreDumper):
|
||||
oid = oid_in
|
||||
|
||||
return HstoreDumper
|
||||
247
srcs/.venv/lib/python3.11/site-packages/psycopg/types/json.py
Normal file
247
srcs/.venv/lib/python3.11/site-packages/psycopg/types/json.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Adapters for JSON types.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import json
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
from .. import abc
|
||||
from .. import errors as e
|
||||
from .. import postgres
|
||||
from ..pq import Format
|
||||
from ..adapt import Buffer, Dumper, Loader, PyFormat, AdaptersMap
|
||||
from ..errors import DataError
|
||||
from .._compat import cache
|
||||
|
||||
JsonDumpsFunction = Callable[[Any], Union[str, bytes]]
|
||||
JsonLoadsFunction = Callable[[Union[str, bytes]], Any]
|
||||
|
||||
|
||||
def set_json_dumps(
|
||||
dumps: JsonDumpsFunction, context: Optional[abc.AdaptContext] = None
|
||||
) -> None:
|
||||
"""
|
||||
Set the JSON serialisation function to store JSON objects in the database.
|
||||
|
||||
:param dumps: The dump function to use.
|
||||
:type dumps: `!Callable[[Any], str]`
|
||||
:param context: Where to use the `!dumps` function. If not specified, use it
|
||||
globally.
|
||||
:type context: `~psycopg.Connection` or `~psycopg.Cursor`
|
||||
|
||||
By default dumping JSON uses the builtin `json.dumps`. You can override
|
||||
it to use a different JSON library or to use customised arguments.
|
||||
|
||||
If the `Json` wrapper specified a `!dumps` function, use it in precedence
|
||||
of the one set by this function.
|
||||
"""
|
||||
if context is None:
|
||||
# If changing load function globally, just change the default on the
|
||||
# global class
|
||||
_JsonDumper._dumps = dumps
|
||||
else:
|
||||
adapters = context.adapters
|
||||
|
||||
# If the scope is smaller than global, create subclassess and register
|
||||
# them in the appropriate scope.
|
||||
grid = [
|
||||
(Json, PyFormat.BINARY),
|
||||
(Json, PyFormat.TEXT),
|
||||
(Jsonb, PyFormat.BINARY),
|
||||
(Jsonb, PyFormat.TEXT),
|
||||
]
|
||||
for wrapper, format in grid:
|
||||
base = _get_current_dumper(adapters, wrapper, format)
|
||||
dumper = _make_dumper(base, dumps)
|
||||
adapters.register_dumper(wrapper, dumper)
|
||||
|
||||
|
||||
def set_json_loads(
|
||||
loads: JsonLoadsFunction, context: Optional[abc.AdaptContext] = None
|
||||
) -> None:
|
||||
"""
|
||||
Set the JSON parsing function to fetch JSON objects from the database.
|
||||
|
||||
:param loads: The load function to use.
|
||||
:type loads: `!Callable[[bytes], Any]`
|
||||
:param context: Where to use the `!loads` function. If not specified, use
|
||||
it globally.
|
||||
:type context: `~psycopg.Connection` or `~psycopg.Cursor`
|
||||
|
||||
By default loading JSON uses the builtin `json.loads`. You can override
|
||||
it to use a different JSON library or to use customised arguments.
|
||||
"""
|
||||
if context is None:
|
||||
# If changing load function globally, just change the default on the
|
||||
# global class
|
||||
_JsonLoader._loads = loads
|
||||
else:
|
||||
# If the scope is smaller than global, create subclassess and register
|
||||
# them in the appropriate scope.
|
||||
grid = [
|
||||
("json", JsonLoader),
|
||||
("json", JsonBinaryLoader),
|
||||
("jsonb", JsonbLoader),
|
||||
("jsonb", JsonbBinaryLoader),
|
||||
]
|
||||
for tname, base in grid:
|
||||
loader = _make_loader(base, loads)
|
||||
context.adapters.register_loader(tname, loader)
|
||||
|
||||
|
||||
# Cache all dynamically-generated types to avoid leaks in case the types
|
||||
# cannot be GC'd.
|
||||
|
||||
|
||||
@cache
|
||||
def _make_dumper(base: Type[abc.Dumper], dumps: JsonDumpsFunction) -> Type[abc.Dumper]:
|
||||
name = base.__name__
|
||||
if not name.startswith("Custom"):
|
||||
name = f"Custom{name}"
|
||||
return type(name, (base,), {"_dumps": dumps})
|
||||
|
||||
|
||||
@cache
|
||||
def _make_loader(base: Type[Loader], loads: JsonLoadsFunction) -> Type[Loader]:
|
||||
name = base.__name__
|
||||
if not name.startswith("Custom"):
|
||||
name = f"Custom{name}"
|
||||
return type(name, (base,), {"_loads": loads})
|
||||
|
||||
|
||||
class _JsonWrapper:
|
||||
__slots__ = ("obj", "dumps")
|
||||
|
||||
def __init__(self, obj: Any, dumps: Optional[JsonDumpsFunction] = None):
|
||||
self.obj = obj
|
||||
self.dumps = dumps
|
||||
|
||||
def __repr__(self) -> str:
|
||||
sobj = repr(self.obj)
|
||||
if len(sobj) > 40:
|
||||
sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
|
||||
return f"{self.__class__.__name__}({sobj})"
|
||||
|
||||
|
||||
class Json(_JsonWrapper):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class Jsonb(_JsonWrapper):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class _JsonDumper(Dumper):
|
||||
# The globally used JSON dumps() function. It can be changed globally (by
|
||||
# set_json_dumps) or by a subclass.
|
||||
_dumps: JsonDumpsFunction = json.dumps
|
||||
|
||||
def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
|
||||
super().__init__(cls, context)
|
||||
self.dumps = self.__class__._dumps
|
||||
|
||||
def dump(self, obj: Any) -> bytes:
|
||||
if isinstance(obj, _JsonWrapper):
|
||||
dumps = obj.dumps or self.dumps
|
||||
obj = obj.obj
|
||||
else:
|
||||
dumps = self.dumps
|
||||
data = dumps(obj)
|
||||
if isinstance(data, str):
|
||||
return data.encode()
|
||||
return data
|
||||
|
||||
|
||||
class JsonDumper(_JsonDumper):
|
||||
oid = postgres.types["json"].oid
|
||||
|
||||
|
||||
class JsonBinaryDumper(_JsonDumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["json"].oid
|
||||
|
||||
|
||||
class JsonbDumper(_JsonDumper):
|
||||
oid = postgres.types["jsonb"].oid
|
||||
|
||||
|
||||
class JsonbBinaryDumper(_JsonDumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["jsonb"].oid
|
||||
|
||||
def dump(self, obj: Any) -> bytes:
|
||||
return b"\x01" + super().dump(obj)
|
||||
|
||||
|
||||
class _JsonLoader(Loader):
|
||||
# The globally used JSON loads() function. It can be changed globally (by
|
||||
# set_json_loads) or by a subclass.
|
||||
_loads: JsonLoadsFunction = json.loads
|
||||
|
||||
def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
self.loads = self.__class__._loads
|
||||
|
||||
def load(self, data: Buffer) -> Any:
|
||||
# json.loads() cannot work on memoryview.
|
||||
if not isinstance(data, bytes):
|
||||
data = bytes(data)
|
||||
return self.loads(data)
|
||||
|
||||
|
||||
class JsonLoader(_JsonLoader):
|
||||
pass
|
||||
|
||||
|
||||
class JsonbLoader(_JsonLoader):
|
||||
pass
|
||||
|
||||
|
||||
class JsonBinaryLoader(_JsonLoader):
|
||||
format = Format.BINARY
|
||||
|
||||
|
||||
class JsonbBinaryLoader(_JsonLoader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> Any:
|
||||
if data and data[0] != 1:
|
||||
raise DataError("unknown jsonb binary format: {data[0]}")
|
||||
data = data[1:]
|
||||
if not isinstance(data, bytes):
|
||||
data = bytes(data)
|
||||
return self.loads(data)
|
||||
|
||||
|
||||
def _get_current_dumper(
|
||||
adapters: AdaptersMap, cls: type, format: PyFormat
|
||||
) -> Type[abc.Dumper]:
|
||||
try:
|
||||
return adapters.get_dumper(cls, format)
|
||||
except e.ProgrammingError:
|
||||
return _default_dumpers[cls, format]
|
||||
|
||||
|
||||
_default_dumpers: Dict[Tuple[Type[_JsonWrapper], PyFormat], Type[Dumper]] = {
|
||||
(Json, PyFormat.BINARY): JsonBinaryDumper,
|
||||
(Json, PyFormat.TEXT): JsonDumper,
|
||||
(Jsonb, PyFormat.BINARY): JsonbBinaryDumper,
|
||||
(Jsonb, PyFormat.TEXT): JsonDumper,
|
||||
}
|
||||
|
||||
|
||||
def register_default_adapters(context: abc.AdaptContext) -> None:
|
||||
adapters = context.adapters
|
||||
|
||||
# Currently json binary format is nothing different than text, maybe with
|
||||
# an extra memcopy we can avoid.
|
||||
adapters.register_dumper(Json, JsonBinaryDumper)
|
||||
adapters.register_dumper(Json, JsonDumper)
|
||||
adapters.register_dumper(Jsonb, JsonbBinaryDumper)
|
||||
adapters.register_dumper(Jsonb, JsonbDumper)
|
||||
adapters.register_loader("json", JsonLoader)
|
||||
adapters.register_loader("jsonb", JsonbLoader)
|
||||
adapters.register_loader("json", JsonBinaryLoader)
|
||||
adapters.register_loader("jsonb", JsonbBinaryLoader)
|
||||
@@ -0,0 +1,521 @@
|
||||
"""
|
||||
Support for multirange types adaptation.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2021 The Psycopg Team
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any, Generic, List, Iterable
|
||||
from typing import MutableSequence, Optional, Type, Union, overload
|
||||
from datetime import date, datetime
|
||||
|
||||
from .. import errors as e
|
||||
from .. import postgres
|
||||
from ..pq import Format
|
||||
from ..abc import AdaptContext, Buffer, Dumper, DumperKey
|
||||
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
|
||||
from .._compat import cache
|
||||
from .._struct import pack_len, unpack_len
|
||||
from ..postgres import INVALID_OID, TEXT_OID
|
||||
from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here
|
||||
|
||||
from .range import Range, T, load_range_text, load_range_binary
|
||||
from .range import dump_range_text, dump_range_binary, fail_dump
|
||||
|
||||
|
||||
class Multirange(MutableSequence[Range[T]]):
|
||||
"""Python representation for a PostgreSQL multirange type.
|
||||
|
||||
:param items: Sequence of ranges to initialise the object.
|
||||
"""
|
||||
|
||||
def __init__(self, items: Iterable[Range[T]] = ()):
|
||||
self._ranges: List[Range[T]] = list(map(self._check_type, items))
|
||||
|
||||
def _check_type(self, item: Any) -> Range[Any]:
|
||||
if not isinstance(item, Range):
|
||||
raise TypeError(
|
||||
f"Multirange is a sequence of Range, got {type(item).__name__}"
|
||||
)
|
||||
return item
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self._ranges!r})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{{{', '.join(map(str, self._ranges))}}}"
|
||||
|
||||
@overload
|
||||
def __getitem__(self, index: int) -> Range[T]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> "Multirange[T]":
|
||||
...
|
||||
|
||||
def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
|
||||
if isinstance(index, int):
|
||||
return self._ranges[index]
|
||||
else:
|
||||
return Multirange(self._ranges[index])
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._ranges)
|
||||
|
||||
@overload
|
||||
def __setitem__(self, index: int, value: Range[T]) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
|
||||
...
|
||||
|
||||
def __setitem__(
|
||||
self,
|
||||
index: Union[int, slice],
|
||||
value: Union[Range[T], Iterable[Range[T]]],
|
||||
) -> None:
|
||||
if isinstance(index, int):
|
||||
self._check_type(value)
|
||||
self._ranges[index] = self._check_type(value)
|
||||
elif not isinstance(value, Iterable):
|
||||
raise TypeError("can only assign an iterable")
|
||||
else:
|
||||
value = map(self._check_type, value)
|
||||
self._ranges[index] = value
|
||||
|
||||
def __delitem__(self, index: Union[int, slice]) -> None:
|
||||
del self._ranges[index]
|
||||
|
||||
def insert(self, index: int, value: Range[T]) -> None:
|
||||
self._ranges.insert(index, self._check_type(value))
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Multirange):
|
||||
return False
|
||||
return self._ranges == other._ranges
|
||||
|
||||
# Order is arbitrary but consistent
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Multirange):
|
||||
return NotImplemented
|
||||
return self._ranges < other._ranges
|
||||
|
||||
def __le__(self, other: Any) -> bool:
|
||||
return self == other or self < other # type: ignore
|
||||
|
||||
def __gt__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Multirange):
|
||||
return NotImplemented
|
||||
return self._ranges > other._ranges
|
||||
|
||||
def __ge__(self, other: Any) -> bool:
|
||||
return self == other or self > other # type: ignore
|
||||
|
||||
|
||||
# Subclasses to specify a specific subtype. Usually not needed
|
||||
|
||||
|
||||
class Int4Multirange(Multirange[int]):
|
||||
pass
|
||||
|
||||
|
||||
class Int8Multirange(Multirange[int]):
|
||||
pass
|
||||
|
||||
|
||||
class NumericMultirange(Multirange[Decimal]):
|
||||
pass
|
||||
|
||||
|
||||
class DateMultirange(Multirange[date]):
|
||||
pass
|
||||
|
||||
|
||||
class TimestampMultirange(Multirange[datetime]):
|
||||
pass
|
||||
|
||||
|
||||
class TimestamptzMultirange(Multirange[datetime]):
|
||||
pass
|
||||
|
||||
|
||||
class BaseMultirangeDumper(RecursiveDumper):
|
||||
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
|
||||
super().__init__(cls, context)
|
||||
self.sub_dumper: Optional[Dumper] = None
|
||||
self._adapt_format = PyFormat.from_pq(self.format)
|
||||
|
||||
def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey:
|
||||
# If we are a subclass whose oid is specified we don't need upgrade
|
||||
if self.cls is not Multirange:
|
||||
return self.cls
|
||||
|
||||
item = self._get_item(obj)
|
||||
if item is not None:
|
||||
sd = self._tx.get_dumper(item, self._adapt_format)
|
||||
return (self.cls, sd.get_key(item, format))
|
||||
else:
|
||||
return (self.cls,)
|
||||
|
||||
def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
|
||||
# If we are a subclass whose oid is specified we don't need upgrade
|
||||
if self.cls is not Multirange:
|
||||
return self
|
||||
|
||||
item = self._get_item(obj)
|
||||
if item is None:
|
||||
return MultirangeDumper(self.cls)
|
||||
|
||||
dumper: BaseMultirangeDumper
|
||||
if type(item) is int:
|
||||
# postgres won't cast int4range -> int8range so we must use
|
||||
# text format and unknown oid here
|
||||
sd = self._tx.get_dumper(item, PyFormat.TEXT)
|
||||
dumper = MultirangeDumper(self.cls, self._tx)
|
||||
dumper.sub_dumper = sd
|
||||
dumper.oid = INVALID_OID
|
||||
return dumper
|
||||
|
||||
sd = self._tx.get_dumper(item, format)
|
||||
dumper = type(self)(self.cls, self._tx)
|
||||
dumper.sub_dumper = sd
|
||||
if sd.oid == INVALID_OID and isinstance(item, str):
|
||||
# Work around the normal mapping where text is dumped as unknown
|
||||
dumper.oid = self._get_multirange_oid(TEXT_OID)
|
||||
else:
|
||||
dumper.oid = self._get_multirange_oid(sd.oid)
|
||||
|
||||
return dumper
|
||||
|
||||
def _get_item(self, obj: Multirange[Any]) -> Any:
|
||||
"""
|
||||
Return a member representative of the multirange
|
||||
"""
|
||||
for r in obj:
|
||||
if r.lower is not None:
|
||||
return r.lower
|
||||
if r.upper is not None:
|
||||
return r.upper
|
||||
return None
|
||||
|
||||
def _get_multirange_oid(self, sub_oid: int) -> int:
|
||||
"""
|
||||
Return the oid of the range from the oid of its elements.
|
||||
"""
|
||||
info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid)
|
||||
return info.oid if info else INVALID_OID
|
||||
|
||||
|
||||
class MultirangeDumper(BaseMultirangeDumper):
|
||||
"""
|
||||
Dumper for multirange types.
|
||||
|
||||
The dumper can upgrade to one specific for a different range type.
|
||||
"""
|
||||
|
||||
def dump(self, obj: Multirange[Any]) -> Buffer:
|
||||
if not obj:
|
||||
return b"{}"
|
||||
|
||||
item = self._get_item(obj)
|
||||
if item is not None:
|
||||
dump = self._tx.get_dumper(item, self._adapt_format).dump
|
||||
else:
|
||||
dump = fail_dump
|
||||
|
||||
out: List[Buffer] = [b"{"]
|
||||
for r in obj:
|
||||
out.append(dump_range_text(r, dump))
|
||||
out.append(b",")
|
||||
out[-1] = b"}"
|
||||
return b"".join(out)
|
||||
|
||||
|
||||
class MultirangeBinaryDumper(BaseMultirangeDumper):
|
||||
format = Format.BINARY
|
||||
|
||||
def dump(self, obj: Multirange[Any]) -> Buffer:
|
||||
item = self._get_item(obj)
|
||||
if item is not None:
|
||||
dump = self._tx.get_dumper(item, self._adapt_format).dump
|
||||
else:
|
||||
dump = fail_dump
|
||||
|
||||
out: List[Buffer] = [pack_len(len(obj))]
|
||||
for r in obj:
|
||||
data = dump_range_binary(r, dump)
|
||||
out.append(pack_len(len(data)))
|
||||
out.append(data)
|
||||
return b"".join(out)
|
||||
|
||||
|
||||
class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
|
||||
subtype_oid: int
|
||||
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
|
||||
|
||||
|
||||
class MultirangeLoader(BaseMultirangeLoader[T]):
|
||||
def load(self, data: Buffer) -> Multirange[T]:
|
||||
if not data or data[0] != _START_INT:
|
||||
raise e.DataError(
|
||||
"malformed multirange starting with"
|
||||
f" {bytes(data[:1]).decode('utf8', 'replace')}"
|
||||
)
|
||||
|
||||
out = Multirange[T]()
|
||||
if data == b"{}":
|
||||
return out
|
||||
|
||||
pos = 1
|
||||
data = data[pos:]
|
||||
try:
|
||||
while True:
|
||||
r, pos = load_range_text(data, self._load)
|
||||
out.append(r)
|
||||
|
||||
sep = data[pos] # can raise IndexError
|
||||
if sep == _SEP_INT:
|
||||
data = data[pos + 1 :]
|
||||
continue
|
||||
elif sep == _END_INT:
|
||||
if len(data) == pos + 1:
|
||||
return out
|
||||
else:
|
||||
raise e.DataError(
|
||||
"malformed multirange: data after closing brace"
|
||||
)
|
||||
else:
|
||||
raise e.DataError(
|
||||
f"malformed multirange: found unexpected {chr(sep)}"
|
||||
)
|
||||
|
||||
except IndexError:
|
||||
raise e.DataError("malformed multirange: separator missing")
|
||||
|
||||
return out
|
||||
|
||||
|
||||
_SEP_INT = ord(",")
|
||||
_START_INT = ord("{")
|
||||
_END_INT = ord("}")
|
||||
|
||||
|
||||
class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> Multirange[T]:
|
||||
nelems = unpack_len(data, 0)[0]
|
||||
pos = 4
|
||||
out = Multirange[T]()
|
||||
for i in range(nelems):
|
||||
length = unpack_len(data, pos)[0]
|
||||
pos += 4
|
||||
out.append(load_range_binary(data[pos : pos + length], self._load))
|
||||
pos += length
|
||||
|
||||
if pos != len(data):
|
||||
raise e.DataError("unexpected trailing data in multirange")
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def register_multirange(
|
||||
info: MultirangeInfo, context: Optional[AdaptContext] = None
|
||||
) -> None:
|
||||
"""Register the adapters to load and dump a multirange type.
|
||||
|
||||
:param info: The object with the information about the range to register.
|
||||
:param context: The context where to register the adapters. If `!None`,
|
||||
register it globally.
|
||||
|
||||
Register loaders so that loading data of this type will result in a `Range`
|
||||
with bounds parsed as the right subtype.
|
||||
|
||||
.. note::
|
||||
|
||||
Registering the adapters doesn't affect objects already created, even
|
||||
if they are children of the registered context. For instance,
|
||||
registering the adapter globally doesn't affect already existing
|
||||
connections.
|
||||
"""
|
||||
# A friendly error warning instead of an AttributeError in case fetch()
|
||||
# failed and it wasn't noticed.
|
||||
if not info:
|
||||
raise TypeError("no info passed. Is the requested multirange available?")
|
||||
|
||||
# Register arrays and type info
|
||||
info.register(context)
|
||||
|
||||
adapters = context.adapters if context else postgres.adapters
|
||||
|
||||
# generate and register a customized text loader
|
||||
loader: Type[BaseMultirangeLoader[Any]]
|
||||
loader = _make_loader(info.name, info.subtype_oid)
|
||||
adapters.register_loader(info.oid, loader)
|
||||
|
||||
# generate and register a customized binary loader
|
||||
loader = _make_binary_loader(info.name, info.subtype_oid)
|
||||
adapters.register_loader(info.oid, loader)
|
||||
|
||||
|
||||
# Cache all dynamically-generated types to avoid leaks in case the types
|
||||
# cannot be GC'd.
|
||||
|
||||
|
||||
@cache
|
||||
def _make_loader(name: str, oid: int) -> Type[MultirangeLoader[Any]]:
|
||||
return type(f"{name.title()}Loader", (MultirangeLoader,), {"subtype_oid": oid})
|
||||
|
||||
|
||||
@cache
|
||||
def _make_binary_loader(name: str, oid: int) -> Type[MultirangeBinaryLoader[Any]]:
|
||||
return type(
|
||||
f"{name.title()}BinaryLoader", (MultirangeBinaryLoader,), {"subtype_oid": oid}
|
||||
)
|
||||
|
||||
|
||||
# Text dumpers for builtin multirange types wrappers
|
||||
# These are registered on specific subtypes so that the upgrade mechanism
|
||||
# doesn't kick in.
|
||||
|
||||
|
||||
class Int4MultirangeDumper(MultirangeDumper):
|
||||
oid = postgres.types["int4multirange"].oid
|
||||
|
||||
|
||||
class Int8MultirangeDumper(MultirangeDumper):
|
||||
oid = postgres.types["int8multirange"].oid
|
||||
|
||||
|
||||
class NumericMultirangeDumper(MultirangeDumper):
|
||||
oid = postgres.types["nummultirange"].oid
|
||||
|
||||
|
||||
class DateMultirangeDumper(MultirangeDumper):
|
||||
oid = postgres.types["datemultirange"].oid
|
||||
|
||||
|
||||
class TimestampMultirangeDumper(MultirangeDumper):
|
||||
oid = postgres.types["tsmultirange"].oid
|
||||
|
||||
|
||||
class TimestamptzMultirangeDumper(MultirangeDumper):
|
||||
oid = postgres.types["tstzmultirange"].oid
|
||||
|
||||
|
||||
# Binary dumpers for builtin multirange types wrappers
|
||||
# These are registered on specific subtypes so that the upgrade mechanism
|
||||
# doesn't kick in.
|
||||
|
||||
|
||||
class Int4MultirangeBinaryDumper(MultirangeBinaryDumper):
|
||||
oid = postgres.types["int4multirange"].oid
|
||||
|
||||
|
||||
class Int8MultirangeBinaryDumper(MultirangeBinaryDumper):
|
||||
oid = postgres.types["int8multirange"].oid
|
||||
|
||||
|
||||
class NumericMultirangeBinaryDumper(MultirangeBinaryDumper):
|
||||
oid = postgres.types["nummultirange"].oid
|
||||
|
||||
|
||||
class DateMultirangeBinaryDumper(MultirangeBinaryDumper):
|
||||
oid = postgres.types["datemultirange"].oid
|
||||
|
||||
|
||||
class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper):
|
||||
oid = postgres.types["tsmultirange"].oid
|
||||
|
||||
|
||||
class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper):
|
||||
oid = postgres.types["tstzmultirange"].oid
|
||||
|
||||
|
||||
# Text loaders for builtin multirange types
|
||||
|
||||
|
||||
class Int4MultirangeLoader(MultirangeLoader[int]):
|
||||
subtype_oid = postgres.types["int4"].oid
|
||||
|
||||
|
||||
class Int8MultirangeLoader(MultirangeLoader[int]):
|
||||
subtype_oid = postgres.types["int8"].oid
|
||||
|
||||
|
||||
class NumericMultirangeLoader(MultirangeLoader[Decimal]):
|
||||
subtype_oid = postgres.types["numeric"].oid
|
||||
|
||||
|
||||
class DateMultirangeLoader(MultirangeLoader[date]):
|
||||
subtype_oid = postgres.types["date"].oid
|
||||
|
||||
|
||||
class TimestampMultirangeLoader(MultirangeLoader[datetime]):
|
||||
subtype_oid = postgres.types["timestamp"].oid
|
||||
|
||||
|
||||
class TimestampTZMultirangeLoader(MultirangeLoader[datetime]):
|
||||
subtype_oid = postgres.types["timestamptz"].oid
|
||||
|
||||
|
||||
# Binary loaders for builtin multirange types
|
||||
|
||||
|
||||
class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
|
||||
subtype_oid = postgres.types["int4"].oid
|
||||
|
||||
|
||||
class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
|
||||
subtype_oid = postgres.types["int8"].oid
|
||||
|
||||
|
||||
class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]):
|
||||
subtype_oid = postgres.types["numeric"].oid
|
||||
|
||||
|
||||
class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]):
|
||||
subtype_oid = postgres.types["date"].oid
|
||||
|
||||
|
||||
class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
|
||||
subtype_oid = postgres.types["timestamp"].oid
|
||||
|
||||
|
||||
class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
|
||||
subtype_oid = postgres.types["timestamptz"].oid
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
adapters = context.adapters
|
||||
adapters.register_dumper(Multirange, MultirangeBinaryDumper)
|
||||
adapters.register_dumper(Multirange, MultirangeDumper)
|
||||
adapters.register_dumper(Int4Multirange, Int4MultirangeDumper)
|
||||
adapters.register_dumper(Int8Multirange, Int8MultirangeDumper)
|
||||
adapters.register_dumper(NumericMultirange, NumericMultirangeDumper)
|
||||
adapters.register_dumper(DateMultirange, DateMultirangeDumper)
|
||||
adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper)
|
||||
adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeDumper)
|
||||
adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper)
|
||||
adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper)
|
||||
adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper)
|
||||
adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper)
|
||||
adapters.register_dumper(TimestampMultirange, TimestampMultirangeBinaryDumper)
|
||||
adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeBinaryDumper)
|
||||
adapters.register_loader("int4multirange", Int4MultirangeLoader)
|
||||
adapters.register_loader("int8multirange", Int8MultirangeLoader)
|
||||
adapters.register_loader("nummultirange", NumericMultirangeLoader)
|
||||
adapters.register_loader("datemultirange", DateMultirangeLoader)
|
||||
adapters.register_loader("tsmultirange", TimestampMultirangeLoader)
|
||||
adapters.register_loader("tstzmultirange", TimestampTZMultirangeLoader)
|
||||
adapters.register_loader("int4multirange", Int4MultirangeBinaryLoader)
|
||||
adapters.register_loader("int8multirange", Int8MultirangeBinaryLoader)
|
||||
adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader)
|
||||
adapters.register_loader("datemultirange", DateMultirangeBinaryLoader)
|
||||
adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader)
|
||||
adapters.register_loader("tstzmultirange", TimestampTZMultirangeBinaryLoader)
|
||||
201
srcs/.venv/lib/python3.11/site-packages/psycopg/types/net.py
Normal file
201
srcs/.venv/lib/python3.11/site-packages/psycopg/types/net.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Adapters for network types.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from typing import Callable, Optional, Type, Union, TYPE_CHECKING
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .. import postgres
|
||||
from ..pq import Format
|
||||
from ..abc import AdaptContext
|
||||
from ..adapt import Buffer, Dumper, Loader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import ipaddress
|
||||
|
||||
Address: TypeAlias = Union["ipaddress.IPv4Address", "ipaddress.IPv6Address"]
|
||||
Interface: TypeAlias = Union["ipaddress.IPv4Interface", "ipaddress.IPv6Interface"]
|
||||
Network: TypeAlias = Union["ipaddress.IPv4Network", "ipaddress.IPv6Network"]
|
||||
|
||||
# These objects will be imported lazily
|
||||
ip_address: Callable[[str], Address] = None # type: ignore[assignment]
|
||||
ip_interface: Callable[[str], Interface] = None # type: ignore[assignment]
|
||||
ip_network: Callable[[str], Network] = None # type: ignore[assignment]
|
||||
IPv4Address: "Type[ipaddress.IPv4Address]" = None # type: ignore[assignment]
|
||||
IPv6Address: "Type[ipaddress.IPv6Address]" = None # type: ignore[assignment]
|
||||
IPv4Interface: "Type[ipaddress.IPv4Interface]" = None # type: ignore[assignment]
|
||||
IPv6Interface: "Type[ipaddress.IPv6Interface]" = None # type: ignore[assignment]
|
||||
IPv4Network: "Type[ipaddress.IPv4Network]" = None # type: ignore[assignment]
|
||||
IPv6Network: "Type[ipaddress.IPv6Network]" = None # type: ignore[assignment]
|
||||
|
||||
PGSQL_AF_INET = 2
|
||||
PGSQL_AF_INET6 = 3
|
||||
IPV4_PREFIXLEN = 32
|
||||
IPV6_PREFIXLEN = 128
|
||||
|
||||
|
||||
class _LazyIpaddress:
|
||||
def _ensure_module(self) -> None:
|
||||
global ip_address, ip_interface, ip_network
|
||||
global IPv4Address, IPv6Address, IPv4Interface, IPv6Interface
|
||||
global IPv4Network, IPv6Network
|
||||
|
||||
if ip_address is None:
|
||||
from ipaddress import ip_address, ip_interface, ip_network
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
from ipaddress import IPv4Interface, IPv6Interface
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
|
||||
|
||||
class InterfaceDumper(Dumper):
|
||||
oid = postgres.types["inet"].oid
|
||||
|
||||
def dump(self, obj: Interface) -> bytes:
|
||||
return str(obj).encode()
|
||||
|
||||
|
||||
class NetworkDumper(Dumper):
|
||||
oid = postgres.types["cidr"].oid
|
||||
|
||||
def dump(self, obj: Network) -> bytes:
|
||||
return str(obj).encode()
|
||||
|
||||
|
||||
class _AIBinaryDumper(Dumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["inet"].oid
|
||||
|
||||
|
||||
class AddressBinaryDumper(_AIBinaryDumper):
|
||||
def dump(self, obj: Address) -> bytes:
|
||||
packed = obj.packed
|
||||
family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
|
||||
head = bytes((family, obj.max_prefixlen, 0, len(packed)))
|
||||
return head + packed
|
||||
|
||||
|
||||
class InterfaceBinaryDumper(_AIBinaryDumper):
|
||||
def dump(self, obj: Interface) -> bytes:
|
||||
packed = obj.packed
|
||||
family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
|
||||
head = bytes((family, obj.network.prefixlen, 0, len(packed)))
|
||||
return head + packed
|
||||
|
||||
|
||||
class InetBinaryDumper(_AIBinaryDumper, _LazyIpaddress):
|
||||
"""Either an address or an interface to inet
|
||||
|
||||
Used when looking up by oid.
|
||||
"""
|
||||
|
||||
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
|
||||
super().__init__(cls, context)
|
||||
self._ensure_module()
|
||||
|
||||
def dump(self, obj: Union[Address, Interface]) -> bytes:
|
||||
packed = obj.packed
|
||||
family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
|
||||
if isinstance(obj, (IPv4Interface, IPv6Interface)):
|
||||
prefixlen = obj.network.prefixlen
|
||||
else:
|
||||
prefixlen = obj.max_prefixlen
|
||||
|
||||
head = bytes((family, prefixlen, 0, len(packed)))
|
||||
return head + packed
|
||||
|
||||
|
||||
class NetworkBinaryDumper(Dumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["cidr"].oid
|
||||
|
||||
def dump(self, obj: Network) -> bytes:
|
||||
packed = obj.network_address.packed
|
||||
family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
|
||||
head = bytes((family, obj.prefixlen, 1, len(packed)))
|
||||
return head + packed
|
||||
|
||||
|
||||
class _LazyIpaddressLoader(Loader, _LazyIpaddress):
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
self._ensure_module()
|
||||
|
||||
|
||||
class InetLoader(_LazyIpaddressLoader):
|
||||
def load(self, data: Buffer) -> Union[Address, Interface]:
|
||||
if isinstance(data, memoryview):
|
||||
data = bytes(data)
|
||||
|
||||
if b"/" in data:
|
||||
return ip_interface(data.decode())
|
||||
else:
|
||||
return ip_address(data.decode())
|
||||
|
||||
|
||||
class InetBinaryLoader(_LazyIpaddressLoader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> Union[Address, Interface]:
|
||||
if isinstance(data, memoryview):
|
||||
data = bytes(data)
|
||||
|
||||
prefix = data[1]
|
||||
packed = data[4:]
|
||||
if data[0] == PGSQL_AF_INET:
|
||||
if prefix == IPV4_PREFIXLEN:
|
||||
return IPv4Address(packed)
|
||||
else:
|
||||
return IPv4Interface((packed, prefix))
|
||||
else:
|
||||
if prefix == IPV6_PREFIXLEN:
|
||||
return IPv6Address(packed)
|
||||
else:
|
||||
return IPv6Interface((packed, prefix))
|
||||
|
||||
|
||||
class CidrLoader(_LazyIpaddressLoader):
|
||||
def load(self, data: Buffer) -> Network:
|
||||
if isinstance(data, memoryview):
|
||||
data = bytes(data)
|
||||
|
||||
return ip_network(data.decode())
|
||||
|
||||
|
||||
class CidrBinaryLoader(_LazyIpaddressLoader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> Network:
|
||||
if isinstance(data, memoryview):
|
||||
data = bytes(data)
|
||||
|
||||
prefix = data[1]
|
||||
packed = data[4:]
|
||||
if data[0] == PGSQL_AF_INET:
|
||||
return IPv4Network((packed, prefix))
|
||||
else:
|
||||
return IPv6Network((packed, prefix))
|
||||
|
||||
return ip_network(data.decode())
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
adapters = context.adapters
|
||||
adapters.register_dumper("ipaddress.IPv4Address", InterfaceDumper)
|
||||
adapters.register_dumper("ipaddress.IPv6Address", InterfaceDumper)
|
||||
adapters.register_dumper("ipaddress.IPv4Interface", InterfaceDumper)
|
||||
adapters.register_dumper("ipaddress.IPv6Interface", InterfaceDumper)
|
||||
adapters.register_dumper("ipaddress.IPv4Network", NetworkDumper)
|
||||
adapters.register_dumper("ipaddress.IPv6Network", NetworkDumper)
|
||||
adapters.register_dumper("ipaddress.IPv4Address", AddressBinaryDumper)
|
||||
adapters.register_dumper("ipaddress.IPv6Address", AddressBinaryDumper)
|
||||
adapters.register_dumper("ipaddress.IPv4Interface", InterfaceBinaryDumper)
|
||||
adapters.register_dumper("ipaddress.IPv6Interface", InterfaceBinaryDumper)
|
||||
adapters.register_dumper("ipaddress.IPv4Network", NetworkBinaryDumper)
|
||||
adapters.register_dumper("ipaddress.IPv6Network", NetworkBinaryDumper)
|
||||
adapters.register_dumper(None, InetBinaryDumper)
|
||||
adapters.register_loader("inet", InetLoader)
|
||||
adapters.register_loader("inet", InetBinaryLoader)
|
||||
adapters.register_loader("cidr", CidrLoader)
|
||||
adapters.register_loader("cidr", CidrBinaryLoader)
|
||||
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Adapters for None.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from ..abc import AdaptContext, NoneType
|
||||
from ..adapt import Dumper
|
||||
|
||||
|
||||
class NoneDumper(Dumper):
|
||||
"""
|
||||
Not a complete dumper as it doesn't implement dump(), but it implements
|
||||
quote(), so it can be used in sql composition.
|
||||
"""
|
||||
|
||||
def dump(self, obj: None) -> bytes:
|
||||
raise NotImplementedError("NULL is passed to Postgres in other ways")
|
||||
|
||||
def quote(self, obj: None) -> bytes:
|
||||
return b"NULL"
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
context.adapters.register_dumper(NoneType, NoneDumper)
|
||||
495
srcs/.venv/lib/python3.11/site-packages/psycopg/types/numeric.py
Normal file
495
srcs/.venv/lib/python3.11/site-packages/psycopg/types/numeric.py
Normal file
@@ -0,0 +1,495 @@
|
||||
"""
|
||||
Adapters for numeric types.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import struct
|
||||
from math import log
|
||||
from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast
|
||||
from decimal import Decimal, DefaultContext, Context
|
||||
|
||||
from .. import postgres
|
||||
from .. import errors as e
|
||||
from ..pq import Format
|
||||
from ..abc import AdaptContext
|
||||
from ..adapt import Buffer, Dumper, Loader, PyFormat
|
||||
from .._struct import pack_int2, pack_uint2, unpack_int2
|
||||
from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4
|
||||
from .._struct import pack_int8, unpack_int8
|
||||
from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8
|
||||
|
||||
# Exposed here
|
||||
from .._wrappers import (
|
||||
Int2 as Int2,
|
||||
Int4 as Int4,
|
||||
Int8 as Int8,
|
||||
IntNumeric as IntNumeric,
|
||||
Oid as Oid,
|
||||
Float4 as Float4,
|
||||
Float8 as Float8,
|
||||
)
|
||||
|
||||
|
||||
class _IntDumper(Dumper):
|
||||
def dump(self, obj: Any) -> Buffer:
|
||||
t = type(obj)
|
||||
if t is not int:
|
||||
# Convert to int in order to dump IntEnum correctly
|
||||
if issubclass(t, int):
|
||||
obj = int(obj)
|
||||
else:
|
||||
raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
|
||||
|
||||
return str(obj).encode()
|
||||
|
||||
def quote(self, obj: Any) -> Buffer:
|
||||
value = self.dump(obj)
|
||||
return value if obj >= 0 else b" " + value
|
||||
|
||||
|
||||
class _SpecialValuesDumper(Dumper):
|
||||
_special: Dict[bytes, bytes] = {}
|
||||
|
||||
def dump(self, obj: Any) -> bytes:
|
||||
return str(obj).encode()
|
||||
|
||||
def quote(self, obj: Any) -> bytes:
|
||||
value = self.dump(obj)
|
||||
|
||||
if value in self._special:
|
||||
return self._special[value]
|
||||
|
||||
return value if obj >= 0 else b" " + value
|
||||
|
||||
|
||||
class FloatDumper(_SpecialValuesDumper):
|
||||
oid = postgres.types["float8"].oid
|
||||
|
||||
_special = {
|
||||
b"inf": b"'Infinity'::float8",
|
||||
b"-inf": b"'-Infinity'::float8",
|
||||
b"nan": b"'NaN'::float8",
|
||||
}
|
||||
|
||||
|
||||
class Float4Dumper(FloatDumper):
|
||||
oid = postgres.types["float4"].oid
|
||||
|
||||
|
||||
class FloatBinaryDumper(Dumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["float8"].oid
|
||||
|
||||
def dump(self, obj: float) -> bytes:
|
||||
return pack_float8(obj)
|
||||
|
||||
|
||||
class Float4BinaryDumper(FloatBinaryDumper):
|
||||
oid = postgres.types["float4"].oid
|
||||
|
||||
def dump(self, obj: float) -> bytes:
|
||||
return pack_float4(obj)
|
||||
|
||||
|
||||
class DecimalDumper(_SpecialValuesDumper):
|
||||
oid = postgres.types["numeric"].oid
|
||||
|
||||
def dump(self, obj: Decimal) -> bytes:
|
||||
if obj.is_nan():
|
||||
# cover NaN and sNaN
|
||||
return b"NaN"
|
||||
else:
|
||||
return str(obj).encode()
|
||||
|
||||
_special = {
|
||||
b"Infinity": b"'Infinity'::numeric",
|
||||
b"-Infinity": b"'-Infinity'::numeric",
|
||||
b"NaN": b"'NaN'::numeric",
|
||||
}
|
||||
|
||||
|
||||
class Int2Dumper(_IntDumper):
|
||||
oid = postgres.types["int2"].oid
|
||||
|
||||
|
||||
class Int4Dumper(_IntDumper):
|
||||
oid = postgres.types["int4"].oid
|
||||
|
||||
|
||||
class Int8Dumper(_IntDumper):
|
||||
oid = postgres.types["int8"].oid
|
||||
|
||||
|
||||
class IntNumericDumper(_IntDumper):
|
||||
oid = postgres.types["numeric"].oid
|
||||
|
||||
|
||||
class OidDumper(_IntDumper):
|
||||
oid = postgres.types["oid"].oid
|
||||
|
||||
|
||||
class IntDumper(Dumper):
|
||||
def dump(self, obj: Any) -> bytes:
|
||||
raise TypeError(
|
||||
f"{type(self).__name__} is a dispatcher to other dumpers:"
|
||||
" dump() is not supposed to be called"
|
||||
)
|
||||
|
||||
def get_key(self, obj: int, format: PyFormat) -> type:
|
||||
return self.upgrade(obj, format).cls
|
||||
|
||||
_int2_dumper = Int2Dumper(Int2)
|
||||
_int4_dumper = Int4Dumper(Int4)
|
||||
_int8_dumper = Int8Dumper(Int8)
|
||||
_int_numeric_dumper = IntNumericDumper(IntNumeric)
|
||||
|
||||
def upgrade(self, obj: int, format: PyFormat) -> Dumper:
|
||||
if -(2**31) <= obj < 2**31:
|
||||
if -(2**15) <= obj < 2**15:
|
||||
return self._int2_dumper
|
||||
else:
|
||||
return self._int4_dumper
|
||||
else:
|
||||
if -(2**63) <= obj < 2**63:
|
||||
return self._int8_dumper
|
||||
else:
|
||||
return self._int_numeric_dumper
|
||||
|
||||
|
||||
class Int2BinaryDumper(Int2Dumper):
|
||||
format = Format.BINARY
|
||||
|
||||
def dump(self, obj: int) -> bytes:
|
||||
return pack_int2(obj)
|
||||
|
||||
|
||||
class Int4BinaryDumper(Int4Dumper):
|
||||
format = Format.BINARY
|
||||
|
||||
def dump(self, obj: int) -> bytes:
|
||||
return pack_int4(obj)
|
||||
|
||||
|
||||
class Int8BinaryDumper(Int8Dumper):
|
||||
format = Format.BINARY
|
||||
|
||||
def dump(self, obj: int) -> bytes:
|
||||
return pack_int8(obj)
|
||||
|
||||
|
||||
# Ratio between number of bits required to store a number and number of pg
|
||||
# decimal digits required.
|
||||
BIT_PER_PGDIGIT = log(2) / log(10_000)
|
||||
|
||||
|
||||
class IntNumericBinaryDumper(IntNumericDumper):
|
||||
format = Format.BINARY
|
||||
|
||||
def dump(self, obj: int) -> Buffer:
|
||||
return dump_int_to_numeric_binary(obj)
|
||||
|
||||
|
||||
class OidBinaryDumper(OidDumper):
|
||||
format = Format.BINARY
|
||||
|
||||
def dump(self, obj: int) -> bytes:
|
||||
return pack_uint4(obj)
|
||||
|
||||
|
||||
class IntBinaryDumper(IntDumper):
|
||||
format = Format.BINARY
|
||||
|
||||
_int2_dumper = Int2BinaryDumper(Int2)
|
||||
_int4_dumper = Int4BinaryDumper(Int4)
|
||||
_int8_dumper = Int8BinaryDumper(Int8)
|
||||
_int_numeric_dumper = IntNumericBinaryDumper(IntNumeric)
|
||||
|
||||
|
||||
class IntLoader(Loader):
|
||||
def load(self, data: Buffer) -> int:
|
||||
# it supports bytes directly
|
||||
return int(data)
|
||||
|
||||
|
||||
class Int2BinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> int:
|
||||
return unpack_int2(data)[0]
|
||||
|
||||
|
||||
class Int4BinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> int:
|
||||
return unpack_int4(data)[0]
|
||||
|
||||
|
||||
class Int8BinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> int:
|
||||
return unpack_int8(data)[0]
|
||||
|
||||
|
||||
class OidBinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> int:
|
||||
return unpack_uint4(data)[0]
|
||||
|
||||
|
||||
class FloatLoader(Loader):
|
||||
def load(self, data: Buffer) -> float:
|
||||
# it supports bytes directly
|
||||
return float(data)
|
||||
|
||||
|
||||
class Float4BinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> float:
|
||||
return unpack_float4(data)[0]
|
||||
|
||||
|
||||
class Float8BinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> float:
|
||||
return unpack_float8(data)[0]
|
||||
|
||||
|
||||
class NumericLoader(Loader):
|
||||
def load(self, data: Buffer) -> Decimal:
|
||||
if isinstance(data, memoryview):
|
||||
data = bytes(data)
|
||||
return Decimal(data.decode())
|
||||
|
||||
|
||||
DEC_DIGITS = 4 # decimal digits per Postgres "digit"
|
||||
NUMERIC_POS = 0x0000
|
||||
NUMERIC_NEG = 0x4000
|
||||
NUMERIC_NAN = 0xC000
|
||||
NUMERIC_PINF = 0xD000
|
||||
NUMERIC_NINF = 0xF000
|
||||
|
||||
_decimal_special = {
|
||||
NUMERIC_NAN: Decimal("NaN"),
|
||||
NUMERIC_PINF: Decimal("Infinity"),
|
||||
NUMERIC_NINF: Decimal("-Infinity"),
|
||||
}
|
||||
|
||||
|
||||
class _ContextMap(DefaultDict[int, Context]):
|
||||
"""
|
||||
Cache for decimal contexts to use when the precision requires it.
|
||||
|
||||
Note: if the default context is used (prec=28) you can get an invalid
|
||||
operation or a rounding to 0:
|
||||
|
||||
- Decimal(1000).shift(24) = Decimal('1000000000000000000000000000')
|
||||
- Decimal(1000).shift(25) = Decimal('0')
|
||||
- Decimal(1000).shift(30) raises InvalidOperation
|
||||
"""
|
||||
|
||||
def __missing__(self, key: int) -> Context:
|
||||
val = Context(prec=key)
|
||||
self[key] = val
|
||||
return val
|
||||
|
||||
|
||||
_contexts = _ContextMap()
|
||||
for i in range(DefaultContext.prec):
|
||||
_contexts[i] = DefaultContext
|
||||
|
||||
_unpack_numeric_head = cast(
|
||||
Callable[[Buffer], Tuple[int, int, int, int]],
|
||||
struct.Struct("!HhHH").unpack_from,
|
||||
)
|
||||
_pack_numeric_head = cast(
|
||||
Callable[[int, int, int, int], bytes],
|
||||
struct.Struct("!HhHH").pack,
|
||||
)
|
||||
|
||||
|
||||
class NumericBinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> Decimal:
|
||||
ndigits, weight, sign, dscale = _unpack_numeric_head(data)
|
||||
if sign == NUMERIC_POS or sign == NUMERIC_NEG:
|
||||
val = 0
|
||||
for i in range(8, len(data), 2):
|
||||
val = val * 10_000 + data[i] * 0x100 + data[i + 1]
|
||||
|
||||
shift = dscale - (ndigits - weight - 1) * DEC_DIGITS
|
||||
ctx = _contexts[(weight + 2) * DEC_DIGITS + dscale]
|
||||
return (
|
||||
Decimal(val if sign == NUMERIC_POS else -val)
|
||||
.scaleb(-dscale, ctx)
|
||||
.shift(shift, ctx)
|
||||
)
|
||||
else:
|
||||
try:
|
||||
return _decimal_special[sign]
|
||||
except KeyError:
|
||||
raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None
|
||||
|
||||
|
||||
NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0)
|
||||
NUMERIC_PINF_BIN = _pack_numeric_head(0, 0, NUMERIC_PINF, 0)
|
||||
NUMERIC_NINF_BIN = _pack_numeric_head(0, 0, NUMERIC_NINF, 0)
|
||||
|
||||
|
||||
class DecimalBinaryDumper(Dumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["numeric"].oid
|
||||
|
||||
def dump(self, obj: Decimal) -> Buffer:
|
||||
return dump_decimal_to_numeric_binary(obj)
|
||||
|
||||
|
||||
class NumericDumper(DecimalDumper):
|
||||
def dump(self, obj: Union[Decimal, int]) -> bytes:
|
||||
if isinstance(obj, int):
|
||||
return str(obj).encode()
|
||||
else:
|
||||
return super().dump(obj)
|
||||
|
||||
|
||||
class NumericBinaryDumper(Dumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["numeric"].oid
|
||||
|
||||
def dump(self, obj: Union[Decimal, int]) -> Buffer:
|
||||
if isinstance(obj, int):
|
||||
return dump_int_to_numeric_binary(obj)
|
||||
else:
|
||||
return dump_decimal_to_numeric_binary(obj)
|
||||
|
||||
|
||||
def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]:
|
||||
sign, digits, exp = obj.as_tuple()
|
||||
if exp == "n" or exp == "N":
|
||||
return NUMERIC_NAN_BIN
|
||||
elif exp == "F":
|
||||
return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN
|
||||
|
||||
# Weights of py digits into a pg digit according to their positions.
|
||||
# Starting with an index wi != 0 is equivalent to prepending 0's to
|
||||
# the digits tuple, but without really changing it.
|
||||
weights = (1000, 100, 10, 1)
|
||||
wi = 0
|
||||
|
||||
ndigits = nzdigits = len(digits)
|
||||
|
||||
# Find the last nonzero digit
|
||||
while nzdigits > 0 and digits[nzdigits - 1] == 0:
|
||||
nzdigits -= 1
|
||||
|
||||
if exp <= 0:
|
||||
dscale = -exp
|
||||
else:
|
||||
dscale = 0
|
||||
# align the py digits to the pg digits if there's some py exponent
|
||||
ndigits += exp % DEC_DIGITS
|
||||
|
||||
if not nzdigits:
|
||||
return _pack_numeric_head(0, 0, NUMERIC_POS, dscale)
|
||||
|
||||
# Equivalent of 0-padding left to align the py digits to the pg digits
|
||||
# but without changing the digits tuple.
|
||||
mod = (ndigits - dscale) % DEC_DIGITS
|
||||
if mod:
|
||||
wi = DEC_DIGITS - mod
|
||||
ndigits += wi
|
||||
|
||||
tmp = nzdigits + wi
|
||||
out = bytearray(
|
||||
_pack_numeric_head(
|
||||
tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1), # ndigits
|
||||
(ndigits + exp) // DEC_DIGITS - 1, # weight
|
||||
NUMERIC_NEG if sign else NUMERIC_POS, # sign
|
||||
dscale,
|
||||
)
|
||||
)
|
||||
|
||||
pgdigit = 0
|
||||
for i in range(nzdigits):
|
||||
pgdigit += weights[wi] * digits[i]
|
||||
wi += 1
|
||||
if wi >= DEC_DIGITS:
|
||||
out += pack_uint2(pgdigit)
|
||||
pgdigit = wi = 0
|
||||
|
||||
if pgdigit:
|
||||
out += pack_uint2(pgdigit)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def dump_int_to_numeric_binary(obj: int) -> bytearray:
|
||||
ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1
|
||||
out = bytearray(b"\x00\x00" * (ndigits + 4))
|
||||
if obj < 0:
|
||||
sign = NUMERIC_NEG
|
||||
obj = -obj
|
||||
else:
|
||||
sign = NUMERIC_POS
|
||||
|
||||
out[:8] = _pack_numeric_head(ndigits, ndigits - 1, sign, 0)
|
||||
i = 8 + (ndigits - 1) * 2
|
||||
while obj:
|
||||
rem = obj % 10_000
|
||||
obj //= 10_000
|
||||
out[i : i + 2] = pack_uint2(rem)
|
||||
i -= 2
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
adapters = context.adapters
|
||||
adapters.register_dumper(int, IntDumper)
|
||||
adapters.register_dumper(int, IntBinaryDumper)
|
||||
adapters.register_dumper(float, FloatDumper)
|
||||
adapters.register_dumper(float, FloatBinaryDumper)
|
||||
adapters.register_dumper(Int2, Int2Dumper)
|
||||
adapters.register_dumper(Int4, Int4Dumper)
|
||||
adapters.register_dumper(Int8, Int8Dumper)
|
||||
adapters.register_dumper(IntNumeric, IntNumericDumper)
|
||||
adapters.register_dumper(Oid, OidDumper)
|
||||
|
||||
# The binary dumper is currently some 30% slower, so default to text
|
||||
# (see tests/scripts/testdec.py for a rough benchmark)
|
||||
# Also, must be after IntNumericDumper
|
||||
adapters.register_dumper("decimal.Decimal", DecimalBinaryDumper)
|
||||
adapters.register_dumper("decimal.Decimal", DecimalDumper)
|
||||
|
||||
# Used only by oid, can take both int and Decimal as input
|
||||
adapters.register_dumper(None, NumericBinaryDumper)
|
||||
adapters.register_dumper(None, NumericDumper)
|
||||
|
||||
adapters.register_dumper(Float4, Float4Dumper)
|
||||
adapters.register_dumper(Float8, FloatDumper)
|
||||
adapters.register_dumper(Int2, Int2BinaryDumper)
|
||||
adapters.register_dumper(Int4, Int4BinaryDumper)
|
||||
adapters.register_dumper(Int8, Int8BinaryDumper)
|
||||
adapters.register_dumper(Oid, OidBinaryDumper)
|
||||
adapters.register_dumper(Float4, Float4BinaryDumper)
|
||||
adapters.register_dumper(Float8, FloatBinaryDumper)
|
||||
adapters.register_loader("int2", IntLoader)
|
||||
adapters.register_loader("int4", IntLoader)
|
||||
adapters.register_loader("int8", IntLoader)
|
||||
adapters.register_loader("oid", IntLoader)
|
||||
adapters.register_loader("int2", Int2BinaryLoader)
|
||||
adapters.register_loader("int4", Int4BinaryLoader)
|
||||
adapters.register_loader("int8", Int8BinaryLoader)
|
||||
adapters.register_loader("oid", OidBinaryLoader)
|
||||
adapters.register_loader("float4", FloatLoader)
|
||||
adapters.register_loader("float8", FloatLoader)
|
||||
adapters.register_loader("float4", Float4BinaryLoader)
|
||||
adapters.register_loader("float8", Float8BinaryLoader)
|
||||
adapters.register_loader("numeric", NumericLoader)
|
||||
adapters.register_loader("numeric", NumericBinaryLoader)
|
||||
708
srcs/.venv/lib/python3.11/site-packages/psycopg/types/range.py
Normal file
708
srcs/.venv/lib/python3.11/site-packages/psycopg/types/range.py
Normal file
@@ -0,0 +1,708 @@
|
||||
"""
|
||||
Support for range types adaptation.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
import re
|
||||
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Type, Tuple
|
||||
from typing import cast
|
||||
from decimal import Decimal
|
||||
from datetime import date, datetime
|
||||
|
||||
from .. import errors as e
|
||||
from .. import postgres
|
||||
from ..pq import Format
|
||||
from ..abc import AdaptContext, Buffer, Dumper, DumperKey
|
||||
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
|
||||
from .._compat import cache
|
||||
from .._struct import pack_len, unpack_len
|
||||
from ..postgres import INVALID_OID, TEXT_OID
|
||||
from .._typeinfo import RangeInfo as RangeInfo # exported here
|
||||
|
||||
RANGE_EMPTY = 0x01 # range is empty
|
||||
RANGE_LB_INC = 0x02 # lower bound is inclusive
|
||||
RANGE_UB_INC = 0x04 # upper bound is inclusive
|
||||
RANGE_LB_INF = 0x08 # lower bound is -infinity
|
||||
RANGE_UB_INF = 0x10 # upper bound is +infinity
|
||||
|
||||
_EMPTY_HEAD = bytes([RANGE_EMPTY])
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Range(Generic[T]):
|
||||
"""Python representation for a PostgreSQL range type.
|
||||
|
||||
:param lower: lower bound for the range. `!None` means unbound
|
||||
:param upper: upper bound for the range. `!None` means unbound
|
||||
:param bounds: one of the literal strings ``()``, ``[)``, ``(]``, ``[]``,
|
||||
representing whether the lower or upper bounds are included
|
||||
:param empty: if `!True`, the range is empty
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ("_lower", "_upper", "_bounds")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lower: Optional[T] = None,
|
||||
upper: Optional[T] = None,
|
||||
bounds: str = "[)",
|
||||
empty: bool = False,
|
||||
):
|
||||
if not empty:
|
||||
if bounds not in ("[)", "(]", "()", "[]"):
|
||||
raise ValueError("bound flags not valid: %r" % bounds)
|
||||
|
||||
self._lower = lower
|
||||
self._upper = upper
|
||||
|
||||
# Make bounds consistent with infs
|
||||
if lower is None and bounds[0] == "[":
|
||||
bounds = "(" + bounds[1]
|
||||
if upper is None and bounds[1] == "]":
|
||||
bounds = bounds[0] + ")"
|
||||
|
||||
self._bounds = bounds
|
||||
else:
|
||||
self._lower = self._upper = None
|
||||
self._bounds = ""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self._bounds:
|
||||
args = f"{self._lower!r}, {self._upper!r}, {self._bounds!r}"
|
||||
else:
|
||||
args = "empty=True"
|
||||
|
||||
return f"{self.__class__.__name__}({args})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
if not self._bounds:
|
||||
return "empty"
|
||||
|
||||
items = [
|
||||
self._bounds[0],
|
||||
str(self._lower),
|
||||
", ",
|
||||
str(self._upper),
|
||||
self._bounds[1],
|
||||
]
|
||||
return "".join(items)
|
||||
|
||||
@property
|
||||
def lower(self) -> Optional[T]:
|
||||
"""The lower bound of the range. `!None` if empty or unbound."""
|
||||
return self._lower
|
||||
|
||||
@property
|
||||
def upper(self) -> Optional[T]:
|
||||
"""The upper bound of the range. `!None` if empty or unbound."""
|
||||
return self._upper
|
||||
|
||||
@property
|
||||
def bounds(self) -> str:
|
||||
"""The bounds string (two characters from '[', '(', ']', ')')."""
|
||||
return self._bounds
|
||||
|
||||
@property
|
||||
def isempty(self) -> bool:
|
||||
"""`!True` if the range is empty."""
|
||||
return not self._bounds
|
||||
|
||||
@property
|
||||
def lower_inf(self) -> bool:
|
||||
"""`!True` if the range doesn't have a lower bound."""
|
||||
if not self._bounds:
|
||||
return False
|
||||
return self._lower is None
|
||||
|
||||
@property
|
||||
def upper_inf(self) -> bool:
|
||||
"""`!True` if the range doesn't have an upper bound."""
|
||||
if not self._bounds:
|
||||
return False
|
||||
return self._upper is None
|
||||
|
||||
@property
|
||||
def lower_inc(self) -> bool:
|
||||
"""`!True` if the lower bound is included in the range."""
|
||||
if not self._bounds or self._lower is None:
|
||||
return False
|
||||
return self._bounds[0] == "["
|
||||
|
||||
@property
|
||||
def upper_inc(self) -> bool:
|
||||
"""`!True` if the upper bound is included in the range."""
|
||||
if not self._bounds or self._upper is None:
|
||||
return False
|
||||
return self._bounds[1] == "]"
|
||||
|
||||
def __contains__(self, x: T) -> bool:
|
||||
if not self._bounds:
|
||||
return False
|
||||
|
||||
if self._lower is not None:
|
||||
if self._bounds[0] == "[":
|
||||
# It doesn't seem that Python has an ABC for ordered types.
|
||||
if x < self._lower: # type: ignore[operator]
|
||||
return False
|
||||
else:
|
||||
if x <= self._lower: # type: ignore[operator]
|
||||
return False
|
||||
|
||||
if self._upper is not None:
|
||||
if self._bounds[1] == "]":
|
||||
if x > self._upper: # type: ignore[operator]
|
||||
return False
|
||||
else:
|
||||
if x >= self._upper: # type: ignore[operator]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._bounds)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Range):
|
||||
return False
|
||||
return (
|
||||
self._lower == other._lower
|
||||
and self._upper == other._upper
|
||||
and self._bounds == other._bounds
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._lower, self._upper, self._bounds))
|
||||
|
||||
# as the postgres docs describe for the server-side stuff,
|
||||
# ordering is rather arbitrary, but will remain stable
|
||||
# and consistent.
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Range):
|
||||
return NotImplemented
|
||||
for attr in ("_lower", "_upper", "_bounds"):
|
||||
self_value = getattr(self, attr)
|
||||
other_value = getattr(other, attr)
|
||||
if self_value == other_value:
|
||||
pass
|
||||
elif self_value is None:
|
||||
return True
|
||||
elif other_value is None:
|
||||
return False
|
||||
else:
|
||||
return cast(bool, self_value < other_value)
|
||||
return False
|
||||
|
||||
def __le__(self, other: Any) -> bool:
|
||||
return self == other or self < other # type: ignore
|
||||
|
||||
def __gt__(self, other: Any) -> bool:
|
||||
if isinstance(other, Range):
|
||||
return other < self
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __ge__(self, other: Any) -> bool:
|
||||
return self == other or self > other # type: ignore
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
return {
|
||||
slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot)
|
||||
}
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
for slot, value in state.items():
|
||||
setattr(self, slot, value)
|
||||
|
||||
|
||||
# Subclasses to specify a specific subtype. Usually not needed: only needed
|
||||
# in binary copy, where switching to text is not an option.
|
||||
|
||||
|
||||
class Int4Range(Range[int]):
|
||||
pass
|
||||
|
||||
|
||||
class Int8Range(Range[int]):
|
||||
pass
|
||||
|
||||
|
||||
class NumericRange(Range[Decimal]):
|
||||
pass
|
||||
|
||||
|
||||
class DateRange(Range[date]):
|
||||
pass
|
||||
|
||||
|
||||
class TimestampRange(Range[datetime]):
|
||||
pass
|
||||
|
||||
|
||||
class TimestamptzRange(Range[datetime]):
|
||||
pass
|
||||
|
||||
|
||||
class BaseRangeDumper(RecursiveDumper):
|
||||
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
|
||||
super().__init__(cls, context)
|
||||
self.sub_dumper: Optional[Dumper] = None
|
||||
self._adapt_format = PyFormat.from_pq(self.format)
|
||||
|
||||
def get_key(self, obj: Range[Any], format: PyFormat) -> DumperKey:
|
||||
# If we are a subclass whose oid is specified we don't need upgrade
|
||||
if self.cls is not Range:
|
||||
return self.cls
|
||||
|
||||
item = self._get_item(obj)
|
||||
if item is not None:
|
||||
sd = self._tx.get_dumper(item, self._adapt_format)
|
||||
return (self.cls, sd.get_key(item, format))
|
||||
else:
|
||||
return (self.cls,)
|
||||
|
||||
def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper":
|
||||
# If we are a subclass whose oid is specified we don't need upgrade
|
||||
if self.cls is not Range:
|
||||
return self
|
||||
|
||||
item = self._get_item(obj)
|
||||
if item is None:
|
||||
return RangeDumper(self.cls)
|
||||
|
||||
dumper: BaseRangeDumper
|
||||
if type(item) is int:
|
||||
# postgres won't cast int4range -> int8range so we must use
|
||||
# text format and unknown oid here
|
||||
sd = self._tx.get_dumper(item, PyFormat.TEXT)
|
||||
dumper = RangeDumper(self.cls, self._tx)
|
||||
dumper.sub_dumper = sd
|
||||
dumper.oid = INVALID_OID
|
||||
return dumper
|
||||
|
||||
sd = self._tx.get_dumper(item, format)
|
||||
dumper = type(self)(self.cls, self._tx)
|
||||
dumper.sub_dumper = sd
|
||||
if sd.oid == INVALID_OID and isinstance(item, str):
|
||||
# Work around the normal mapping where text is dumped as unknown
|
||||
dumper.oid = self._get_range_oid(TEXT_OID)
|
||||
else:
|
||||
dumper.oid = self._get_range_oid(sd.oid)
|
||||
|
||||
return dumper
|
||||
|
||||
def _get_item(self, obj: Range[Any]) -> Any:
|
||||
"""
|
||||
Return a member representative of the range
|
||||
"""
|
||||
rv = obj.lower
|
||||
return rv if rv is not None else obj.upper
|
||||
|
||||
def _get_range_oid(self, sub_oid: int) -> int:
|
||||
"""
|
||||
Return the oid of the range from the oid of its elements.
|
||||
"""
|
||||
info = self._tx.adapters.types.get_by_subtype(RangeInfo, sub_oid)
|
||||
return info.oid if info else INVALID_OID
|
||||
|
||||
|
||||
class RangeDumper(BaseRangeDumper):
|
||||
"""
|
||||
Dumper for range types.
|
||||
|
||||
The dumper can upgrade to one specific for a different range type.
|
||||
"""
|
||||
|
||||
def dump(self, obj: Range[Any]) -> Buffer:
|
||||
item = self._get_item(obj)
|
||||
if item is not None:
|
||||
dump = self._tx.get_dumper(item, self._adapt_format).dump
|
||||
else:
|
||||
dump = fail_dump
|
||||
|
||||
return dump_range_text(obj, dump)
|
||||
|
||||
|
||||
def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
|
||||
if obj.isempty:
|
||||
return b"empty"
|
||||
|
||||
parts: List[Buffer] = [b"[" if obj.lower_inc else b"("]
|
||||
|
||||
def dump_item(item: Any) -> Buffer:
|
||||
ad = dump(item)
|
||||
if not ad:
|
||||
return b'""'
|
||||
elif _re_needs_quotes.search(ad):
|
||||
return b'"' + _re_esc.sub(rb"\1\1", ad) + b'"'
|
||||
else:
|
||||
return ad
|
||||
|
||||
if obj.lower is not None:
|
||||
parts.append(dump_item(obj.lower))
|
||||
|
||||
parts.append(b",")
|
||||
|
||||
if obj.upper is not None:
|
||||
parts.append(dump_item(obj.upper))
|
||||
|
||||
parts.append(b"]" if obj.upper_inc else b")")
|
||||
|
||||
return b"".join(parts)
|
||||
|
||||
|
||||
_re_needs_quotes = re.compile(rb'[",\\\s()\[\]]')
|
||||
_re_esc = re.compile(rb"([\\\"])")
|
||||
|
||||
|
||||
class RangeBinaryDumper(BaseRangeDumper):
|
||||
format = Format.BINARY
|
||||
|
||||
def dump(self, obj: Range[Any]) -> Buffer:
|
||||
item = self._get_item(obj)
|
||||
if item is not None:
|
||||
dump = self._tx.get_dumper(item, self._adapt_format).dump
|
||||
else:
|
||||
dump = fail_dump
|
||||
|
||||
return dump_range_binary(obj, dump)
|
||||
|
||||
|
||||
def dump_range_binary(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
|
||||
if not obj:
|
||||
return _EMPTY_HEAD
|
||||
|
||||
out = bytearray([0]) # will replace the head later
|
||||
|
||||
head = 0
|
||||
if obj.lower_inc:
|
||||
head |= RANGE_LB_INC
|
||||
if obj.upper_inc:
|
||||
head |= RANGE_UB_INC
|
||||
|
||||
if obj.lower is not None:
|
||||
data = dump(obj.lower)
|
||||
out += pack_len(len(data))
|
||||
out += data
|
||||
else:
|
||||
head |= RANGE_LB_INF
|
||||
|
||||
if obj.upper is not None:
|
||||
data = dump(obj.upper)
|
||||
out += pack_len(len(data))
|
||||
out += data
|
||||
else:
|
||||
head |= RANGE_UB_INF
|
||||
|
||||
out[0] = head
|
||||
return out
|
||||
|
||||
|
||||
def fail_dump(obj: Any) -> Buffer:
|
||||
raise e.InternalError("trying to dump a range element without information")
|
||||
|
||||
|
||||
class BaseRangeLoader(RecursiveLoader, Generic[T]):
|
||||
"""Generic loader for a range.
|
||||
|
||||
Subclasses must specify the oid of the subtype and the class to load.
|
||||
"""
|
||||
|
||||
subtype_oid: int
|
||||
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
|
||||
|
||||
|
||||
class RangeLoader(BaseRangeLoader[T]):
|
||||
def load(self, data: Buffer) -> Range[T]:
|
||||
return load_range_text(data, self._load)[0]
|
||||
|
||||
|
||||
def load_range_text(
|
||||
data: Buffer, load: Callable[[Buffer], Any]
|
||||
) -> Tuple[Range[Any], int]:
|
||||
if data == b"empty":
|
||||
return Range(empty=True), 5
|
||||
|
||||
m = _re_range.match(data)
|
||||
if m is None:
|
||||
raise e.DataError(
|
||||
f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'"
|
||||
)
|
||||
|
||||
lower = None
|
||||
item = m.group(3)
|
||||
if item is None:
|
||||
item = m.group(2)
|
||||
if item is not None:
|
||||
lower = load(_re_undouble.sub(rb"\1", item))
|
||||
else:
|
||||
lower = load(item)
|
||||
|
||||
upper = None
|
||||
item = m.group(5)
|
||||
if item is None:
|
||||
item = m.group(4)
|
||||
if item is not None:
|
||||
upper = load(_re_undouble.sub(rb"\1", item))
|
||||
else:
|
||||
upper = load(item)
|
||||
|
||||
bounds = (m.group(1) + m.group(6)).decode()
|
||||
|
||||
return Range(lower, upper, bounds), m.end()
|
||||
|
||||
|
||||
_re_range = re.compile(
|
||||
rb"""
|
||||
( \(|\[ ) # lower bound flag
|
||||
(?: # lower bound:
|
||||
" ( (?: [^"] | "")* ) " # - a quoted string
|
||||
| ( [^",]+ ) # - or an unquoted string
|
||||
)? # - or empty (not caught)
|
||||
,
|
||||
(?: # upper bound:
|
||||
" ( (?: [^"] | "")* ) " # - a quoted string
|
||||
| ( [^"\)\]]+ ) # - or an unquoted string
|
||||
)? # - or empty (not caught)
|
||||
( \)|\] ) # upper bound flag
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
|
||||
_re_undouble = re.compile(rb'(["\\])\1')
|
||||
|
||||
|
||||
class RangeBinaryLoader(BaseRangeLoader[T]):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> Range[T]:
|
||||
return load_range_binary(data, self._load)
|
||||
|
||||
|
||||
def load_range_binary(data: Buffer, load: Callable[[Buffer], Any]) -> Range[Any]:
|
||||
head = data[0]
|
||||
if head & RANGE_EMPTY:
|
||||
return Range(empty=True)
|
||||
|
||||
lb = "[" if head & RANGE_LB_INC else "("
|
||||
ub = "]" if head & RANGE_UB_INC else ")"
|
||||
|
||||
pos = 1 # after the head
|
||||
if head & RANGE_LB_INF:
|
||||
min = None
|
||||
else:
|
||||
length = unpack_len(data, pos)[0]
|
||||
pos += 4
|
||||
min = load(data[pos : pos + length])
|
||||
pos += length
|
||||
|
||||
if head & RANGE_UB_INF:
|
||||
max = None
|
||||
else:
|
||||
length = unpack_len(data, pos)[0]
|
||||
pos += 4
|
||||
max = load(data[pos : pos + length])
|
||||
pos += length
|
||||
|
||||
return Range(min, max, lb + ub)
|
||||
|
||||
|
||||
def register_range(info: RangeInfo, context: Optional[AdaptContext] = None) -> None:
|
||||
"""Register the adapters to load and dump a range type.
|
||||
|
||||
:param info: The object with the information about the range to register.
|
||||
:param context: The context where to register the adapters. If `!None`,
|
||||
register it globally.
|
||||
|
||||
Register loaders so that loading data of this type will result in a `Range`
|
||||
with bounds parsed as the right subtype.
|
||||
|
||||
.. note::
|
||||
|
||||
Registering the adapters doesn't affect objects already created, even
|
||||
if they are children of the registered context. For instance,
|
||||
registering the adapter globally doesn't affect already existing
|
||||
connections.
|
||||
"""
|
||||
# A friendly error warning instead of an AttributeError in case fetch()
|
||||
# failed and it wasn't noticed.
|
||||
if not info:
|
||||
raise TypeError("no info passed. Is the requested range available?")
|
||||
|
||||
# Register arrays and type info
|
||||
info.register(context)
|
||||
|
||||
adapters = context.adapters if context else postgres.adapters
|
||||
|
||||
# generate and register a customized text loader
|
||||
loader: Type[BaseRangeLoader[Any]]
|
||||
loader = _make_loader(info.name, info.subtype_oid)
|
||||
adapters.register_loader(info.oid, loader)
|
||||
|
||||
# generate and register a customized binary loader
|
||||
loader = _make_binary_loader(info.name, info.subtype_oid)
|
||||
adapters.register_loader(info.oid, loader)
|
||||
|
||||
|
||||
# Cache all dynamically-generated types to avoid leaks in case the types
|
||||
# cannot be GC'd.
|
||||
|
||||
|
||||
@cache
|
||||
def _make_loader(name: str, oid: int) -> Type[RangeLoader[Any]]:
|
||||
return type(f"{name.title()}Loader", (RangeLoader,), {"subtype_oid": oid})
|
||||
|
||||
|
||||
@cache
|
||||
def _make_binary_loader(name: str, oid: int) -> Type[RangeBinaryLoader[Any]]:
|
||||
return type(
|
||||
f"{name.title()}BinaryLoader", (RangeBinaryLoader,), {"subtype_oid": oid}
|
||||
)
|
||||
|
||||
|
||||
# Text dumpers for builtin range types wrappers
|
||||
# These are registered on specific subtypes so that the upgrade mechanism
|
||||
# doesn't kick in.
|
||||
|
||||
|
||||
class Int4RangeDumper(RangeDumper):
|
||||
oid = postgres.types["int4range"].oid
|
||||
|
||||
|
||||
class Int8RangeDumper(RangeDumper):
|
||||
oid = postgres.types["int8range"].oid
|
||||
|
||||
|
||||
class NumericRangeDumper(RangeDumper):
|
||||
oid = postgres.types["numrange"].oid
|
||||
|
||||
|
||||
class DateRangeDumper(RangeDumper):
|
||||
oid = postgres.types["daterange"].oid
|
||||
|
||||
|
||||
class TimestampRangeDumper(RangeDumper):
|
||||
oid = postgres.types["tsrange"].oid
|
||||
|
||||
|
||||
class TimestamptzRangeDumper(RangeDumper):
|
||||
oid = postgres.types["tstzrange"].oid
|
||||
|
||||
|
||||
# Binary dumpers for builtin range types wrappers
|
||||
# These are registered on specific subtypes so that the upgrade mechanism
|
||||
# doesn't kick in.
|
||||
|
||||
|
||||
class Int4RangeBinaryDumper(RangeBinaryDumper):
|
||||
oid = postgres.types["int4range"].oid
|
||||
|
||||
|
||||
class Int8RangeBinaryDumper(RangeBinaryDumper):
|
||||
oid = postgres.types["int8range"].oid
|
||||
|
||||
|
||||
class NumericRangeBinaryDumper(RangeBinaryDumper):
|
||||
oid = postgres.types["numrange"].oid
|
||||
|
||||
|
||||
class DateRangeBinaryDumper(RangeBinaryDumper):
|
||||
oid = postgres.types["daterange"].oid
|
||||
|
||||
|
||||
class TimestampRangeBinaryDumper(RangeBinaryDumper):
|
||||
oid = postgres.types["tsrange"].oid
|
||||
|
||||
|
||||
class TimestamptzRangeBinaryDumper(RangeBinaryDumper):
|
||||
oid = postgres.types["tstzrange"].oid
|
||||
|
||||
|
||||
# Text loaders for builtin range types
|
||||
|
||||
|
||||
class Int4RangeLoader(RangeLoader[int]):
|
||||
subtype_oid = postgres.types["int4"].oid
|
||||
|
||||
|
||||
class Int8RangeLoader(RangeLoader[int]):
|
||||
subtype_oid = postgres.types["int8"].oid
|
||||
|
||||
|
||||
class NumericRangeLoader(RangeLoader[Decimal]):
|
||||
subtype_oid = postgres.types["numeric"].oid
|
||||
|
||||
|
||||
class DateRangeLoader(RangeLoader[date]):
|
||||
subtype_oid = postgres.types["date"].oid
|
||||
|
||||
|
||||
class TimestampRangeLoader(RangeLoader[datetime]):
|
||||
subtype_oid = postgres.types["timestamp"].oid
|
||||
|
||||
|
||||
class TimestampTZRangeLoader(RangeLoader[datetime]):
|
||||
subtype_oid = postgres.types["timestamptz"].oid
|
||||
|
||||
|
||||
# Binary loaders for builtin range types
|
||||
|
||||
|
||||
class Int4RangeBinaryLoader(RangeBinaryLoader[int]):
|
||||
subtype_oid = postgres.types["int4"].oid
|
||||
|
||||
|
||||
class Int8RangeBinaryLoader(RangeBinaryLoader[int]):
|
||||
subtype_oid = postgres.types["int8"].oid
|
||||
|
||||
|
||||
class NumericRangeBinaryLoader(RangeBinaryLoader[Decimal]):
|
||||
subtype_oid = postgres.types["numeric"].oid
|
||||
|
||||
|
||||
class DateRangeBinaryLoader(RangeBinaryLoader[date]):
|
||||
subtype_oid = postgres.types["date"].oid
|
||||
|
||||
|
||||
class TimestampRangeBinaryLoader(RangeBinaryLoader[datetime]):
|
||||
subtype_oid = postgres.types["timestamp"].oid
|
||||
|
||||
|
||||
class TimestampTZRangeBinaryLoader(RangeBinaryLoader[datetime]):
|
||||
subtype_oid = postgres.types["timestamptz"].oid
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
adapters = context.adapters
|
||||
adapters.register_dumper(Range, RangeBinaryDumper)
|
||||
adapters.register_dumper(Range, RangeDumper)
|
||||
adapters.register_dumper(Int4Range, Int4RangeDumper)
|
||||
adapters.register_dumper(Int8Range, Int8RangeDumper)
|
||||
adapters.register_dumper(NumericRange, NumericRangeDumper)
|
||||
adapters.register_dumper(DateRange, DateRangeDumper)
|
||||
adapters.register_dumper(TimestampRange, TimestampRangeDumper)
|
||||
adapters.register_dumper(TimestamptzRange, TimestamptzRangeDumper)
|
||||
adapters.register_dumper(Int4Range, Int4RangeBinaryDumper)
|
||||
adapters.register_dumper(Int8Range, Int8RangeBinaryDumper)
|
||||
adapters.register_dumper(NumericRange, NumericRangeBinaryDumper)
|
||||
adapters.register_dumper(DateRange, DateRangeBinaryDumper)
|
||||
adapters.register_dumper(TimestampRange, TimestampRangeBinaryDumper)
|
||||
adapters.register_dumper(TimestamptzRange, TimestamptzRangeBinaryDumper)
|
||||
adapters.register_loader("int4range", Int4RangeLoader)
|
||||
adapters.register_loader("int8range", Int8RangeLoader)
|
||||
adapters.register_loader("numrange", NumericRangeLoader)
|
||||
adapters.register_loader("daterange", DateRangeLoader)
|
||||
adapters.register_loader("tsrange", TimestampRangeLoader)
|
||||
adapters.register_loader("tstzrange", TimestampTZRangeLoader)
|
||||
adapters.register_loader("int4range", Int4RangeBinaryLoader)
|
||||
adapters.register_loader("int8range", Int8RangeBinaryLoader)
|
||||
adapters.register_loader("numrange", NumericRangeBinaryLoader)
|
||||
adapters.register_loader("daterange", DateRangeBinaryLoader)
|
||||
adapters.register_loader("tsrange", TimestampRangeBinaryLoader)
|
||||
adapters.register_loader("tstzrange", TimestampTZRangeBinaryLoader)
|
||||
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Adapters for PostGIS geometries
|
||||
"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from .. import postgres
|
||||
from ..abc import AdaptContext, Buffer
|
||||
from ..adapt import Dumper, Loader
|
||||
from ..pq import Format
|
||||
from .._compat import cache
|
||||
from .._typeinfo import TypeInfo
|
||||
|
||||
|
||||
try:
|
||||
from shapely.wkb import loads, dumps
|
||||
from shapely.geometry.base import BaseGeometry
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The module psycopg.types.shapely requires the package 'Shapely'"
|
||||
" to be installed"
|
||||
)
|
||||
|
||||
|
||||
class GeometryBinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> "BaseGeometry":
|
||||
if not isinstance(data, bytes):
|
||||
data = bytes(data)
|
||||
return loads(data)
|
||||
|
||||
|
||||
class GeometryLoader(Loader):
|
||||
def load(self, data: Buffer) -> "BaseGeometry":
|
||||
# it's a hex string in binary
|
||||
if isinstance(data, memoryview):
|
||||
data = bytes(data)
|
||||
return loads(data.decode(), hex=True)
|
||||
|
||||
|
||||
class BaseGeometryBinaryDumper(Dumper):
|
||||
format = Format.BINARY
|
||||
|
||||
def dump(self, obj: "BaseGeometry") -> bytes:
|
||||
return dumps(obj) # type: ignore
|
||||
|
||||
|
||||
class BaseGeometryDumper(Dumper):
|
||||
def dump(self, obj: "BaseGeometry") -> bytes:
|
||||
return dumps(obj, hex=True).encode() # type: ignore
|
||||
|
||||
|
||||
def register_shapely(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
|
||||
"""Register Shapely dumper and loaders."""
|
||||
|
||||
# A friendly error warning instead of an AttributeError in case fetch()
|
||||
# failed and it wasn't noticed.
|
||||
if not info:
|
||||
raise TypeError("no info passed. Is the 'postgis' extension loaded?")
|
||||
|
||||
info.register(context)
|
||||
adapters = context.adapters if context else postgres.adapters
|
||||
|
||||
adapters.register_loader(info.oid, GeometryBinaryLoader)
|
||||
adapters.register_loader(info.oid, GeometryLoader)
|
||||
# Default binary dump
|
||||
adapters.register_dumper(BaseGeometry, _make_dumper(info.oid))
|
||||
adapters.register_dumper(BaseGeometry, _make_binary_dumper(info.oid))
|
||||
|
||||
|
||||
# Cache all dynamically-generated types to avoid leaks in case the types
|
||||
# cannot be GC'd.
|
||||
|
||||
|
||||
@cache
|
||||
def _make_dumper(oid_in: int) -> Type[BaseGeometryDumper]:
|
||||
class GeometryDumper(BaseGeometryDumper):
|
||||
oid = oid_in
|
||||
|
||||
return GeometryDumper
|
||||
|
||||
|
||||
@cache
|
||||
def _make_binary_dumper(oid_in: int) -> Type[BaseGeometryBinaryDumper]:
|
||||
class GeometryBinaryDumper(BaseGeometryBinaryDumper):
|
||||
oid = oid_in
|
||||
|
||||
return GeometryBinaryDumper
|
||||
229
srcs/.venv/lib/python3.11/site-packages/psycopg/types/string.py
Normal file
229
srcs/.venv/lib/python3.11/site-packages/psycopg/types/string.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Adapters for textual types.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from typing import Optional, Union, TYPE_CHECKING
|
||||
|
||||
from .. import postgres
|
||||
from ..pq import Format, Escaping
|
||||
from ..abc import AdaptContext
|
||||
from ..adapt import Buffer, Dumper, Loader
|
||||
from ..errors import DataError
|
||||
from .._encodings import conn_encoding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pq.abc import Escaping as EscapingProto
|
||||
|
||||
|
||||
class _BaseStrDumper(Dumper):
|
||||
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
|
||||
super().__init__(cls, context)
|
||||
enc = conn_encoding(self.connection)
|
||||
self._encoding = enc if enc != "ascii" else "utf-8"
|
||||
|
||||
|
||||
class _StrBinaryDumper(_BaseStrDumper):
|
||||
"""
|
||||
Base class to dump a Python strings to a Postgres text type, in binary format.
|
||||
|
||||
Subclasses shall specify the oids of real types (text, varchar, name...).
|
||||
"""
|
||||
|
||||
format = Format.BINARY
|
||||
|
||||
def dump(self, obj: str) -> bytes:
|
||||
# the server will raise DataError subclass if the string contains 0x00
|
||||
return obj.encode(self._encoding)
|
||||
|
||||
|
||||
class _StrDumper(_BaseStrDumper):
|
||||
"""
|
||||
Base class to dump a Python strings to a Postgres text type, in text format.
|
||||
|
||||
Subclasses shall specify the oids of real types (text, varchar, name...).
|
||||
"""
|
||||
|
||||
def dump(self, obj: str) -> bytes:
|
||||
if "\x00" in obj:
|
||||
raise DataError("PostgreSQL text fields cannot contain NUL (0x00) bytes")
|
||||
else:
|
||||
return obj.encode(self._encoding)
|
||||
|
||||
|
||||
# The next are concrete dumpers, each one specifying the oid they dump to.
|
||||
|
||||
|
||||
class StrBinaryDumper(_StrBinaryDumper):
|
||||
oid = postgres.types["text"].oid
|
||||
|
||||
|
||||
class StrBinaryDumperVarchar(_StrBinaryDumper):
|
||||
oid = postgres.types["varchar"].oid
|
||||
|
||||
|
||||
class StrBinaryDumperName(_StrBinaryDumper):
|
||||
oid = postgres.types["name"].oid
|
||||
|
||||
|
||||
class StrDumper(_StrDumper):
|
||||
"""
|
||||
Dumper for strings in text format to the text oid.
|
||||
|
||||
Note that this dumper is not used by default because the type is too strict
|
||||
and PostgreSQL would require an explicit casts to everything that is not a
|
||||
text field. However it is useful where the unknown oid is ambiguous and the
|
||||
text oid is required, for instance with variadic functions.
|
||||
"""
|
||||
|
||||
oid = postgres.types["text"].oid
|
||||
|
||||
|
||||
class StrDumperVarchar(_StrDumper):
|
||||
oid = postgres.types["varchar"].oid
|
||||
|
||||
|
||||
class StrDumperName(_StrDumper):
|
||||
oid = postgres.types["name"].oid
|
||||
|
||||
|
||||
class StrDumperUnknown(_StrDumper):
|
||||
"""
|
||||
Dumper for strings in text format to the unknown oid.
|
||||
|
||||
This dumper is the default dumper for strings and allows to use Python
|
||||
strings to represent almost every data type. In a few places, however, the
|
||||
unknown oid is not accepted (for instance in variadic functions such as
|
||||
'concat()'). In that case either a cast on the placeholder ('%s::text') or
|
||||
the StrTextDumper should be used.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TextLoader(Loader):
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
enc = conn_encoding(self.connection)
|
||||
self._encoding = enc if enc != "ascii" else ""
|
||||
|
||||
def load(self, data: Buffer) -> Union[bytes, str]:
|
||||
if self._encoding:
|
||||
if isinstance(data, memoryview):
|
||||
data = bytes(data)
|
||||
return data.decode(self._encoding)
|
||||
else:
|
||||
# return bytes for SQL_ASCII db
|
||||
if not isinstance(data, bytes):
|
||||
data = bytes(data)
|
||||
return data
|
||||
|
||||
|
||||
class TextBinaryLoader(TextLoader):
|
||||
format = Format.BINARY
|
||||
|
||||
|
||||
class BytesDumper(Dumper):
|
||||
oid = postgres.types["bytea"].oid
|
||||
_qprefix = b""
|
||||
|
||||
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
|
||||
super().__init__(cls, context)
|
||||
self._esc = Escaping(self.connection.pgconn if self.connection else None)
|
||||
|
||||
def dump(self, obj: Buffer) -> Buffer:
|
||||
return self._esc.escape_bytea(obj)
|
||||
|
||||
def quote(self, obj: Buffer) -> bytes:
|
||||
escaped = self.dump(obj)
|
||||
|
||||
# We cannot use the base quoting because escape_bytea already returns
|
||||
# the quotes content. if scs is off it will escape the backslashes in
|
||||
# the format, otherwise it won't, but it doesn't tell us what quotes to
|
||||
# use.
|
||||
if self.connection:
|
||||
if not self._qprefix:
|
||||
scs = self.connection.pgconn.parameter_status(
|
||||
b"standard_conforming_strings"
|
||||
)
|
||||
self._qprefix = b"'" if scs == b"on" else b" E'"
|
||||
|
||||
return self._qprefix + escaped + b"'"
|
||||
|
||||
# We don't have a connection, so someone is using us to generate a file
|
||||
# to use off-line or something like that. PQescapeBytea, like its
|
||||
# string counterpart, is not predictable whether it will escape
|
||||
# backslashes.
|
||||
rv: bytes = b" E'" + escaped + b"'"
|
||||
if self._esc.escape_bytea(b"\x00") == b"\\000":
|
||||
rv = rv.replace(b"\\", b"\\\\")
|
||||
return rv
|
||||
|
||||
|
||||
class BytesBinaryDumper(Dumper):
|
||||
format = Format.BINARY
|
||||
oid = postgres.types["bytea"].oid
|
||||
|
||||
def dump(self, obj: Buffer) -> Buffer:
|
||||
return obj
|
||||
|
||||
|
||||
class ByteaLoader(Loader):
|
||||
_escaping: "EscapingProto"
|
||||
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
if not hasattr(self.__class__, "_escaping"):
|
||||
self.__class__._escaping = Escaping()
|
||||
|
||||
def load(self, data: Buffer) -> bytes:
|
||||
return self._escaping.unescape_bytea(data)
|
||||
|
||||
|
||||
class ByteaBinaryLoader(Loader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> Buffer:
|
||||
return data
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
adapters = context.adapters
|
||||
|
||||
# NOTE: the order the dumpers are registered is relevant. The last one
|
||||
# registered becomes the default for each type. Usually, binary is the
|
||||
# default dumper. For text we use the text dumper as default because it
|
||||
# plays the role of unknown, and it can be cast automatically to other
|
||||
# types. However, before that, we register dumper with 'text', 'varchar',
|
||||
# 'name' oids, which will be used when a text dumper is looked up by oid.
|
||||
adapters.register_dumper(str, StrBinaryDumperName)
|
||||
adapters.register_dumper(str, StrBinaryDumperVarchar)
|
||||
adapters.register_dumper(str, StrBinaryDumper)
|
||||
adapters.register_dumper(str, StrDumperName)
|
||||
adapters.register_dumper(str, StrDumperVarchar)
|
||||
adapters.register_dumper(str, StrDumper)
|
||||
adapters.register_dumper(str, StrDumperUnknown)
|
||||
|
||||
adapters.register_loader(postgres.INVALID_OID, TextLoader)
|
||||
adapters.register_loader("bpchar", TextLoader)
|
||||
adapters.register_loader("name", TextLoader)
|
||||
adapters.register_loader("text", TextLoader)
|
||||
adapters.register_loader("varchar", TextLoader)
|
||||
adapters.register_loader('"char"', TextLoader)
|
||||
adapters.register_loader("bpchar", TextBinaryLoader)
|
||||
adapters.register_loader("name", TextBinaryLoader)
|
||||
adapters.register_loader("text", TextBinaryLoader)
|
||||
adapters.register_loader("varchar", TextBinaryLoader)
|
||||
adapters.register_loader('"char"', TextBinaryLoader)
|
||||
|
||||
adapters.register_dumper(bytes, BytesDumper)
|
||||
adapters.register_dumper(bytearray, BytesDumper)
|
||||
adapters.register_dumper(memoryview, BytesDumper)
|
||||
adapters.register_dumper(bytes, BytesBinaryDumper)
|
||||
adapters.register_dumper(bytearray, BytesBinaryDumper)
|
||||
adapters.register_dumper(memoryview, BytesBinaryDumper)
|
||||
|
||||
adapters.register_loader("bytea", ByteaLoader)
|
||||
adapters.register_loader(postgres.INVALID_OID, ByteaBinaryLoader)
|
||||
adapters.register_loader("bytea", ByteaBinaryLoader)
|
||||
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Adapters for the UUID type.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
from typing import Callable, Optional, TYPE_CHECKING
|
||||
|
||||
from .. import postgres
|
||||
from ..pq import Format
|
||||
from ..abc import AdaptContext
|
||||
from ..adapt import Buffer, Dumper, Loader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import uuid
|
||||
|
||||
# Importing the uuid module is slow, so import it only on request.
|
||||
UUID: Callable[..., "uuid.UUID"] = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class UUIDDumper(Dumper):
|
||||
oid = postgres.types["uuid"].oid
|
||||
|
||||
def dump(self, obj: "uuid.UUID") -> bytes:
|
||||
return obj.hex.encode()
|
||||
|
||||
|
||||
class UUIDBinaryDumper(UUIDDumper):
|
||||
format = Format.BINARY
|
||||
|
||||
def dump(self, obj: "uuid.UUID") -> bytes:
|
||||
return obj.bytes
|
||||
|
||||
|
||||
class UUIDLoader(Loader):
|
||||
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
|
||||
super().__init__(oid, context)
|
||||
global UUID
|
||||
if UUID is None:
|
||||
from uuid import UUID
|
||||
|
||||
def load(self, data: Buffer) -> "uuid.UUID":
|
||||
if isinstance(data, memoryview):
|
||||
data = bytes(data)
|
||||
return UUID(data.decode())
|
||||
|
||||
|
||||
class UUIDBinaryLoader(UUIDLoader):
|
||||
format = Format.BINARY
|
||||
|
||||
def load(self, data: Buffer) -> "uuid.UUID":
|
||||
if isinstance(data, memoryview):
|
||||
data = bytes(data)
|
||||
return UUID(bytes=data)
|
||||
|
||||
|
||||
def register_default_adapters(context: AdaptContext) -> None:
|
||||
adapters = context.adapters
|
||||
adapters.register_dumper("uuid.UUID", UUIDDumper)
|
||||
adapters.register_dumper("uuid.UUID", UUIDBinaryDumper)
|
||||
adapters.register_loader("uuid", UUIDLoader)
|
||||
adapters.register_loader("uuid", UUIDBinaryLoader)
|
||||
14
srcs/.venv/lib/python3.11/site-packages/psycopg/version.py
Normal file
14
srcs/.venv/lib/python3.11/site-packages/psycopg/version.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
psycopg distribution version file.
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
# Use a versioning scheme as defined in
|
||||
# https://www.python.org/dev/peps/pep-0440/
|
||||
|
||||
# STOP AND READ! if you change:
|
||||
__version__ = "3.1.13"
|
||||
# also change:
|
||||
# - `docs/news.rst` to declare this as the current version or an unreleased one
|
||||
# - `psycopg_c/psycopg_c/version.py` to the same version.
|
||||
393
srcs/.venv/lib/python3.11/site-packages/psycopg/waiting.py
Normal file
393
srcs/.venv/lib/python3.11/site-packages/psycopg/waiting.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Code concerned with waiting in different contexts (blocking, async, etc).
|
||||
|
||||
These functions are designed to consume the generators returned by the
|
||||
`generators` module function and to return their final value.
|
||||
|
||||
"""
|
||||
|
||||
# Copyright (C) 2020 The Psycopg Team
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
import select
|
||||
import selectors
|
||||
from typing import Optional
|
||||
from asyncio import get_event_loop, wait_for, Event, TimeoutError
|
||||
from selectors import DefaultSelector
|
||||
|
||||
from . import errors as e
|
||||
from .abc import RV, PQGen, PQGenConn, WaitFunc
|
||||
from ._enums import Wait as Wait, Ready as Ready # re-exported
|
||||
from ._cmodule import _psycopg
|
||||
|
||||
WAIT_R = Wait.R
|
||||
WAIT_W = Wait.W
|
||||
WAIT_RW = Wait.RW
|
||||
READY_R = Ready.R
|
||||
READY_W = Ready.W
|
||||
READY_RW = Ready.RW
|
||||
|
||||
|
||||
def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
|
||||
"""
|
||||
Wait for a generator using the best strategy available.
|
||||
|
||||
:param gen: a generator performing database operations and yielding
|
||||
`Ready` values when it would block.
|
||||
:param fileno: the file descriptor to wait on.
|
||||
:param timeout: timeout (in seconds) to check for other interrupt, e.g.
|
||||
to allow Ctrl-C.
|
||||
:type timeout: float
|
||||
:return: whatever `!gen` returns on completion.
|
||||
|
||||
Consume `!gen`, scheduling `fileno` for completion when it is reported to
|
||||
block. Once ready again send the ready state back to `!gen`.
|
||||
"""
|
||||
try:
|
||||
s = next(gen)
|
||||
with DefaultSelector() as sel:
|
||||
while True:
|
||||
sel.register(fileno, s)
|
||||
rlist = None
|
||||
while not rlist:
|
||||
rlist = sel.select(timeout=timeout)
|
||||
sel.unregister(fileno)
|
||||
# note: this line should require a cast, but mypy doesn't complain
|
||||
ready: Ready = rlist[0][1]
|
||||
assert s & ready
|
||||
s = gen.send(ready)
|
||||
|
||||
except StopIteration as ex:
|
||||
rv: RV = ex.args[0] if ex.args else None
|
||||
return rv
|
||||
|
||||
|
||||
def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
|
||||
"""
|
||||
Wait for a connection generator using the best strategy available.
|
||||
|
||||
:param gen: a generator performing database operations and yielding
|
||||
(fd, `Ready`) pairs when it would block.
|
||||
:param timeout: timeout (in seconds) to check for other interrupt, e.g.
|
||||
to allow Ctrl-C. If zero or None, wait indefinitely.
|
||||
:type timeout: float
|
||||
:return: whatever `!gen` returns on completion.
|
||||
|
||||
Behave like in `wait()`, but take the fileno to wait from the generator
|
||||
itself, which might change during processing.
|
||||
"""
|
||||
try:
|
||||
fileno, s = next(gen)
|
||||
if not timeout:
|
||||
timeout = None
|
||||
with DefaultSelector() as sel:
|
||||
while True:
|
||||
sel.register(fileno, s)
|
||||
rlist = sel.select(timeout=timeout)
|
||||
sel.unregister(fileno)
|
||||
if not rlist:
|
||||
raise e.ConnectionTimeout("connection timeout expired")
|
||||
ready: Ready = rlist[0][1] # type: ignore[assignment]
|
||||
fileno, s = gen.send(ready)
|
||||
|
||||
except StopIteration as ex:
|
||||
rv: RV = ex.args[0] if ex.args else None
|
||||
return rv
|
||||
|
||||
|
||||
async def wait_async(
|
||||
gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
|
||||
) -> RV:
|
||||
"""
|
||||
Coroutine waiting for a generator to complete.
|
||||
|
||||
:param gen: a generator performing database operations and yielding
|
||||
`Ready` values when it would block.
|
||||
:param fileno: the file descriptor to wait on.
|
||||
:return: whatever `!gen` returns on completion.
|
||||
|
||||
Behave like in `wait()`, but exposing an `asyncio` interface.
|
||||
"""
|
||||
# Use an event to block and restart after the fd state changes.
|
||||
# Not sure this is the best implementation but it's a start.
|
||||
ev = Event()
|
||||
loop = get_event_loop()
|
||||
ready: Ready
|
||||
s: Wait
|
||||
|
||||
def wakeup(state: Ready) -> None:
|
||||
nonlocal ready
|
||||
ready |= state # type: ignore[assignment]
|
||||
ev.set()
|
||||
|
||||
try:
|
||||
s = next(gen)
|
||||
while True:
|
||||
reader = s & WAIT_R
|
||||
writer = s & WAIT_W
|
||||
if not reader and not writer:
|
||||
raise e.InternalError(f"bad poll status: {s}")
|
||||
ev.clear()
|
||||
ready = 0 # type: ignore[assignment]
|
||||
if reader:
|
||||
loop.add_reader(fileno, wakeup, READY_R)
|
||||
if writer:
|
||||
loop.add_writer(fileno, wakeup, READY_W)
|
||||
try:
|
||||
if timeout is None:
|
||||
await ev.wait()
|
||||
else:
|
||||
try:
|
||||
await wait_for(ev.wait(), timeout)
|
||||
except TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
if reader:
|
||||
loop.remove_reader(fileno)
|
||||
if writer:
|
||||
loop.remove_writer(fileno)
|
||||
s = gen.send(ready)
|
||||
|
||||
except StopIteration as ex:
|
||||
rv: RV = ex.args[0] if ex.args else None
|
||||
return rv
|
||||
|
||||
|
||||
async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
|
||||
"""
|
||||
Coroutine waiting for a connection generator to complete.
|
||||
|
||||
:param gen: a generator performing database operations and yielding
|
||||
(fd, `Ready`) pairs when it would block.
|
||||
:param timeout: timeout (in seconds) to check for other interrupt, e.g.
|
||||
to allow Ctrl-C. If zero or None, wait indefinitely.
|
||||
:return: whatever `!gen` returns on completion.
|
||||
|
||||
Behave like in `wait()`, but take the fileno to wait from the generator
|
||||
itself, which might change during processing.
|
||||
"""
|
||||
# Use an event to block and restart after the fd state changes.
|
||||
# Not sure this is the best implementation but it's a start.
|
||||
ev = Event()
|
||||
loop = get_event_loop()
|
||||
ready: Ready
|
||||
s: Wait
|
||||
|
||||
def wakeup(state: Ready) -> None:
|
||||
nonlocal ready
|
||||
ready = state
|
||||
ev.set()
|
||||
|
||||
try:
|
||||
fileno, s = next(gen)
|
||||
if not timeout:
|
||||
timeout = None
|
||||
while True:
|
||||
reader = s & WAIT_R
|
||||
writer = s & WAIT_W
|
||||
if not reader and not writer:
|
||||
raise e.InternalError(f"bad poll status: {s}")
|
||||
ev.clear()
|
||||
ready = 0 # type: ignore[assignment]
|
||||
if reader:
|
||||
loop.add_reader(fileno, wakeup, READY_R)
|
||||
if writer:
|
||||
loop.add_writer(fileno, wakeup, READY_W)
|
||||
try:
|
||||
await wait_for(ev.wait(), timeout)
|
||||
finally:
|
||||
if reader:
|
||||
loop.remove_reader(fileno)
|
||||
if writer:
|
||||
loop.remove_writer(fileno)
|
||||
fileno, s = gen.send(ready)
|
||||
|
||||
except TimeoutError:
|
||||
raise e.ConnectionTimeout("connection timeout expired")
|
||||
|
||||
except StopIteration as ex:
|
||||
rv: RV = ex.args[0] if ex.args else None
|
||||
return rv
|
||||
|
||||
|
||||
# Specialised implementation of wait functions.
|
||||
|
||||
|
||||
def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
|
||||
"""
|
||||
Wait for a generator using select where supported.
|
||||
|
||||
BUG: on Linux, can't select on FD >= 1024. On Windows it's fine.
|
||||
"""
|
||||
try:
|
||||
s = next(gen)
|
||||
|
||||
empty = ()
|
||||
fnlist = (fileno,)
|
||||
while True:
|
||||
rl, wl, xl = select.select(
|
||||
fnlist if s & WAIT_R else empty,
|
||||
fnlist if s & WAIT_W else empty,
|
||||
fnlist,
|
||||
timeout,
|
||||
)
|
||||
ready = 0
|
||||
if rl:
|
||||
ready = READY_R
|
||||
if wl:
|
||||
ready |= READY_W
|
||||
if not ready:
|
||||
continue
|
||||
# assert s & ready
|
||||
s = gen.send(ready) # type: ignore
|
||||
|
||||
except StopIteration as ex:
|
||||
rv: RV = ex.args[0] if ex.args else None
|
||||
return rv
|
||||
|
||||
|
||||
if hasattr(selectors, "EpollSelector"):
|
||||
_epoll_evmasks = {
|
||||
WAIT_R: select.EPOLLONESHOT | select.EPOLLIN | select.EPOLLERR,
|
||||
WAIT_W: select.EPOLLONESHOT | select.EPOLLOUT | select.EPOLLERR,
|
||||
WAIT_RW: select.EPOLLONESHOT
|
||||
| (select.EPOLLIN | select.EPOLLOUT | select.EPOLLERR),
|
||||
}
|
||||
else:
|
||||
_epoll_evmasks = {}
|
||||
|
||||
|
||||
def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
|
||||
"""
|
||||
Wait for a generator using epoll where supported.
|
||||
|
||||
Parameters are like for `wait()`. If it is detected that the best selector
|
||||
strategy is `epoll` then this function will be used instead of `wait`.
|
||||
|
||||
See also: https://linux.die.net/man/2/epoll_ctl
|
||||
|
||||
BUG: if the connection FD is closed, `epoll.poll()` hangs. Same for
|
||||
EpollSelector. For this reason, wait_poll() is currently preferable.
|
||||
To reproduce the bug:
|
||||
|
||||
export PSYCOPG_WAIT_FUNC=wait_epoll
|
||||
pytest tests/test_concurrency.py::test_concurrent_close
|
||||
"""
|
||||
try:
|
||||
s = next(gen)
|
||||
|
||||
if timeout is None or timeout < 0:
|
||||
timeout = 0
|
||||
else:
|
||||
timeout = int(timeout * 1000.0)
|
||||
|
||||
with select.epoll() as epoll:
|
||||
evmask = _epoll_evmasks[s]
|
||||
epoll.register(fileno, evmask)
|
||||
while True:
|
||||
fileevs = None
|
||||
while not fileevs:
|
||||
fileevs = epoll.poll(timeout)
|
||||
ev = fileevs[0][1]
|
||||
ready = 0
|
||||
if ev & ~select.EPOLLOUT:
|
||||
ready = READY_R
|
||||
if ev & ~select.EPOLLIN:
|
||||
ready |= READY_W
|
||||
# assert s & ready
|
||||
s = gen.send(ready)
|
||||
evmask = _epoll_evmasks[s]
|
||||
epoll.modify(fileno, evmask)
|
||||
|
||||
except StopIteration as ex:
|
||||
rv: RV = ex.args[0] if ex.args else None
|
||||
return rv
|
||||
|
||||
|
||||
if hasattr(selectors, "PollSelector"):
|
||||
_poll_evmasks = {
|
||||
WAIT_R: select.POLLIN,
|
||||
WAIT_W: select.POLLOUT,
|
||||
WAIT_RW: select.POLLIN | select.POLLOUT,
|
||||
}
|
||||
else:
|
||||
_poll_evmasks = {}
|
||||
|
||||
|
||||
def wait_poll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
|
||||
"""
|
||||
Wait for a generator using poll where supported.
|
||||
|
||||
Parameters are like for `wait()`.
|
||||
"""
|
||||
try:
|
||||
s = next(gen)
|
||||
|
||||
if timeout is None or timeout < 0:
|
||||
timeout = 0
|
||||
else:
|
||||
timeout = int(timeout * 1000.0)
|
||||
|
||||
poll = select.poll()
|
||||
evmask = _poll_evmasks[s]
|
||||
poll.register(fileno, evmask)
|
||||
while True:
|
||||
fileevs = None
|
||||
while not fileevs:
|
||||
fileevs = poll.poll(timeout)
|
||||
ev = fileevs[0][1]
|
||||
ready = 0
|
||||
if ev & ~select.POLLOUT:
|
||||
ready = READY_R
|
||||
if ev & ~select.POLLIN:
|
||||
ready |= READY_W
|
||||
# assert s & ready
|
||||
s = gen.send(ready)
|
||||
evmask = _poll_evmasks[s]
|
||||
poll.modify(fileno, evmask)
|
||||
|
||||
except StopIteration as ex:
|
||||
rv: RV = ex.args[0] if ex.args else None
|
||||
return rv
|
||||
|
||||
|
||||
if _psycopg:
|
||||
wait_c = _psycopg.wait_c
|
||||
|
||||
|
||||
# Choose the best wait strategy for the platform.
|
||||
#
|
||||
# the selectors objects have a generic interface but come with some overhead,
|
||||
# so we also offer more finely tuned implementations.
|
||||
|
||||
wait: WaitFunc
|
||||
|
||||
# Allow the user to choose a specific function for testing
|
||||
if "PSYCOPG_WAIT_FUNC" in os.environ:
|
||||
fname = os.environ["PSYCOPG_WAIT_FUNC"]
|
||||
if not fname.startswith("wait_") or fname not in globals():
|
||||
raise ImportError(
|
||||
"PSYCOPG_WAIT_FUNC should be the name of an available wait function;"
|
||||
f" got {fname!r}"
|
||||
)
|
||||
wait = globals()[fname]
|
||||
|
||||
# On Windows, for the moment, avoid using wait_c, because it was reported to
|
||||
# use excessive CPU (see #645).
|
||||
# TODO: investigate why.
|
||||
elif _psycopg and sys.platform != "win32":
|
||||
wait = wait_c
|
||||
|
||||
elif selectors.DefaultSelector is getattr(selectors, "SelectSelector", None):
|
||||
# On Windows, SelectSelector should be the default.
|
||||
wait = wait_select
|
||||
|
||||
elif hasattr(selectors, "PollSelector"):
|
||||
# On linux, EpollSelector is the default. However, it hangs if the fd is
|
||||
# closed while polling.
|
||||
wait = wait_poll
|
||||
|
||||
else:
|
||||
wait = wait_selector
|
||||
Reference in New Issue
Block a user