"""
psycopg async connection objects
"""

# Copyright (C) 2020 The Psycopg Team

import sys
import asyncio
import logging
from types import TracebackType
from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
from contextlib import asynccontextmanager

from . import pq
from . import errors as e
from . import waiting
from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
from ._tpc import Xid
from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
from .conninfo import conninfo_attempts_async
from ._pipeline import AsyncPipeline
from ._encodings import pgconn_encoding
from .connection import BaseConnection, CursorRow, Notify
from .generators import notifies
from .transaction import AsyncTransaction
from .cursor_async import AsyncCursor
from .server_cursor import AsyncServerCursor

if TYPE_CHECKING:
    from .pq.abc import PGconn

TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY

IDLE = pq.TransactionStatus.IDLE
INTRANS = pq.TransactionStatus.INTRANS

logger = logging.getLogger("psycopg")


class AsyncConnection(BaseConnection[Row]):
    """
    Asynchronous wrapper for a connection to the database.
    """

    __module__ = "psycopg"

    cursor_factory: Type[AsyncCursor[Row]]
    server_cursor_factory: Type[AsyncServerCursor[Row]]
    row_factory: AsyncRowFactory[Row]
    _pipeline: Optional[AsyncPipeline]
    _Self = TypeVar("_Self", bound="AsyncConnection[Any]")

    def __init__(
        self,
        pgconn: "PGconn",
        row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row),
    ):
        super().__init__(pgconn)
        self.row_factory = row_factory
        # Acquired by the methods below around each wait on the connection,
        # so that only one of them talks to the server at a time.
        self.lock = asyncio.Lock()
        self.cursor_factory = AsyncCursor
        self.server_cursor_factory = AsyncServerCursor

    @overload
    @classmethod
    async def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: Optional[int] = 5,
        row_factory: AsyncRowFactory[Row],
        cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
        context: Optional[AdaptContext] = None,
        **kwargs: Union[None, int, str],
    ) -> "AsyncConnection[Row]":
        # TODO: returned type should be _Self. See #308.
        ...

    @overload
    @classmethod
    async def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: Optional[int] = 5,
        cursor_factory: Optional[Type[AsyncCursor[Any]]] = None,
        context: Optional[AdaptContext] = None,
        **kwargs: Union[None, int, str],
    ) -> "AsyncConnection[TupleRow]":
        ...

    @classmethod  # type: ignore[misc] # https://github.com/python/mypy/issues/11004
    async def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: Optional[int] = 5,
        context: Optional[AdaptContext] = None,
        row_factory: Optional[AsyncRowFactory[Row]] = None,
        cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
        **kwargs: Any,
    ) -> "AsyncConnection[Any]":
        """
        Connect to a database server and return a new `AsyncConnection`.

        The connection parameters are built from *conninfo* and *kwargs*.
        Each attempt produced by `conninfo_attempts_async()` is tried in
        turn, and the first one that succeeds is used. If every attempt
        fails, the last error is re-raised without its traceback.

        :raise InterfaceError: on Windows, if the running event loop is a
            `!ProactorEventLoop`, which cannot be used in async mode.
        """
        if sys.platform == "win32":
            loop = asyncio.get_running_loop()
            if isinstance(loop, asyncio.ProactorEventLoop):
                raise e.InterfaceError(
                    "Psycopg cannot use the 'ProactorEventLoop' to run in async"
                    " mode. Please use a compatible event loop, for instance by"
                    " setting 'asyncio.set_event_loop_policy"
                    "(WindowsSelectorEventLoopPolicy())'"
                )

        params = await cls._get_connection_params(conninfo, **kwargs)
        timeout = int(params["connect_timeout"])
        rv = None
        # Initialised up front: if no attempt is yielded at all, the check
        # below must not fail with an UnboundLocalError.
        last_ex: Optional[BaseException] = None
        async for attempt in conninfo_attempts_async(params):
            try:
                conninfo = make_conninfo(**attempt)
                rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
                break
            except e._NO_TRACEBACK as ex:
                # Remember the failure and move on to the next attempt.
                last_ex = ex

        if rv is None:
            assert last_ex
            raise last_ex.with_traceback(None)

        rv._autocommit = bool(autocommit)
        if row_factory:
            rv.row_factory = row_factory
        if cursor_factory:
            rv.cursor_factory = cursor_factory
        if context:
            rv._adapters = AdaptersMap(context.adapters)
        rv.prepare_threshold = prepare_threshold
        return rv

    async def __aenter__(self: _Self) -> _Self:
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """
        Finish the connection block.

        On success the transaction is committed. On error it is rolled back,
        logging (but not raising) any rollback failure. The connection is
        then closed, unless it has a `_pool` attribute set.
        """
        if self.closed:
            return

        if exc_type:
            # try to rollback, but if there are problems (connection in a bad
            # state) just warn without clobbering the exception bubbling up.
            try:
                await self.rollback()
            except Exception as exc2:
                logger.warning(
                    "error ignored in rollback on %s: %s",
                    self,
                    exc2,
                )
        else:
            await self.commit()

        # Close the connection only if it doesn't belong to a pool.
        if not getattr(self, "_pool", None):
            await self.close()

    @classmethod
    async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
        """Manipulate connection parameters before connecting."""
        params = conninfo_to_dict(conninfo, **kwargs)

        # Make sure there is an usable connect_timeout
        if "connect_timeout" in params:
            params["connect_timeout"] = int(params["connect_timeout"])
        else:
            # The sync connect function will stop on the default socket timeout
            # Because in async connection mode we need to enforce the timeout
            # ourselves, we need a finite value.
            params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT

        return params

    async def close(self) -> None:
        """Close the database connection. This is a no-op if it is already closed."""
        if self.closed:
            return
        self._closed = True

        # TODO: maybe send a cancel on close, if the connection is ACTIVE?
        self.pgconn.finish()

    @overload
    def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
        ...

    @overload
    def cursor(
        self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
    ) -> AsyncCursor[CursorRow]:
        ...

    @overload
    def cursor(
        self,
        name: str,
        *,
        binary: bool = False,
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ) -> AsyncServerCursor[Row]:
        ...

    @overload
    def cursor(
        self,
        name: str,
        *,
        binary: bool = False,
        row_factory: AsyncRowFactory[CursorRow],
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ) -> AsyncServerCursor[CursorRow]:
        ...

    def cursor(
        self,
        name: str = "",
        *,
        binary: bool = False,
        row_factory: Optional[AsyncRowFactory[Any]] = None,
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]:
        """
        Return a new `AsyncCursor` to send commands and queries to the connection.

        If *name* is given, a server-side cursor is created with
        `server_cursor_factory`. Otherwise a client-side cursor is created
        with `cursor_factory`. If *row_factory* is omitted, the connection's
        `row_factory` is used. *binary* sets the cursor format to binary.
        """
        self._check_connection_ok()

        if not row_factory:
            row_factory = self.row_factory

        cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]]
        if name:
            cur = self.server_cursor_factory(
                self,
                name=name,
                row_factory=row_factory,
                scrollable=scrollable,
                withhold=withhold,
            )
        else:
            cur = self.cursor_factory(self, row_factory=row_factory)

        if binary:
            cur.format = BINARY

        return cur

    async def execute(
        self,
        query: Query,
        params: Optional[Params] = None,
        *,
        prepare: Optional[bool] = None,
        binary: bool = False,
    ) -> AsyncCursor[Row]:
        """
        Execute *query* on a new client-side cursor and return that cursor.

        Errors listed in `errors._NO_TRACEBACK` are re-raised without
        their traceback.
        """
        try:
            cur = self.cursor()
            if binary:
                cur.format = BINARY

            return await cur.execute(query, params, prepare=prepare)

        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)

    async def commit(self) -> None:
        """Commit the current transaction while holding the connection lock."""
        async with self.lock:
            await self.wait(self._commit_gen())

    async def rollback(self) -> None:
        """Roll back the current transaction while holding the connection lock."""
        async with self.lock:
            await self.wait(self._rollback_gen())

    @asynccontextmanager
    async def transaction(
        self,
        savepoint_name: Optional[str] = None,
        force_rollback: bool = False,
    ) -> AsyncIterator[AsyncTransaction]:
        """
        Start a context block with a new transaction or nested transaction.

        :rtype: AsyncTransaction
        """
        tx = AsyncTransaction(self, savepoint_name, force_rollback)
        if self._pipeline:
            # In pipeline mode the transaction block is wrapped between two
            # nested pipeline blocks (presumably so that the pipeline is
            # synced on entering and on leaving the transaction).
            async with self.pipeline(), tx, self.pipeline():
                yield tx
        else:
            async with tx:
                yield tx

    async def notifies(self) -> AsyncGenerator[Notify, None]:
        """
        Yield `Notify` objects as notifications are received from the server.

        This generator never ends by itself; the caller must stop iterating.
        """
        while True:
            async with self.lock:
                try:
                    ns = await self.wait(notifies(self.pgconn))
                except e._NO_TRACEBACK as ex:
                    raise ex.with_traceback(None)
            # Decode outside the lock: no further server communication here.
            enc = pgconn_encoding(self.pgconn)
            for pgn in ns:
                n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
                yield n

    @asynccontextmanager
    async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
        """Context manager to switch the connection into pipeline mode."""
        async with self.lock:
            self._check_connection_ok()

            pipeline = self._pipeline
            if pipeline is None:
                # WARNING: reference loop, broken ahead.
                pipeline = self._pipeline = AsyncPipeline(self)

        try:
            async with pipeline:
                yield pipeline
        finally:
            # Only the outermost block (level back to 0) detaches the
            # pipeline, which breaks the connection <-> pipeline loop.
            if pipeline.level == 0:
                async with self.lock:
                    assert pipeline is self._pipeline
                    self._pipeline = None

    async def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
        """
        Consume the generator *gen* on the connection socket and return its result.

        *timeout* is forwarded to `waiting.wait_async()`. If the wait is
        cancelled or interrupted, a cancel request is sent to the server and
        *gen* is drained (ignoring `QueryCanceled`) before the original
        exception is re-raised.
        """
        try:
            return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
        except (asyncio.CancelledError, KeyboardInterrupt):
            # On Ctrl-C, try to cancel the query in the server, otherwise
            # the connection will remain stuck in ACTIVE state.
            self._try_cancel(self.pgconn)
            try:
                await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
            except e.QueryCanceled:
                pass  # as expected
            raise

    @classmethod
    async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
        """Run the connection generator *gen*, forwarding *timeout* to `wait_conn_async()`."""
        return await waiting.wait_conn_async(gen, timeout)

    def _set_autocommit(self, value: bool) -> None:
        self._no_set_async("autocommit")

    async def set_autocommit(self, value: bool) -> None:
        """Async version of the `~Connection.autocommit` setter."""
        async with self.lock:
            await self.wait(self._set_autocommit_gen(value))

    def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
        self._no_set_async("isolation_level")

    async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
        """Async version of the `~Connection.isolation_level` setter."""
        async with self.lock:
            await self.wait(self._set_isolation_level_gen(value))

    def _set_read_only(self, value: Optional[bool]) -> None:
        self._no_set_async("read_only")

    async def set_read_only(self, value: Optional[bool]) -> None:
        """Async version of the `~Connection.read_only` setter."""
        async with self.lock:
            await self.wait(self._set_read_only_gen(value))

    def _set_deferrable(self, value: Optional[bool]) -> None:
        self._no_set_async("deferrable")

    async def set_deferrable(self, value: Optional[bool]) -> None:
        """Async version of the `~Connection.deferrable` setter."""
        async with self.lock:
            await self.wait(self._set_deferrable_gen(value))

    def _no_set_async(self, attribute: str) -> None:
        """Raise `AttributeError`: property setters are unavailable on async connections."""
        raise AttributeError(
            f"the {attribute!r} property is read-only on async connections:"
            f" please use 'await .set_{attribute}()' instead."
        )

    async def tpc_begin(self, xid: Union[Xid, str]) -> None:
        """Begin a two-phase-commit transaction identified by *xid*."""
        async with self.lock:
            await self.wait(self._tpc_begin_gen(xid))

    async def tpc_prepare(self) -> None:
        """
        Prepare the current two-phase-commit transaction.

        :raise NotSupportedError: if the server answers with
            `ObjectNotInPrerequisiteState` (presumably because prepared
            transactions are not enabled on the server).
        """
        try:
            async with self.lock:
                await self.wait(self._tpc_prepare_gen())
        except e.ObjectNotInPrerequisiteState as ex:
            raise e.NotSupportedError(str(ex)) from None

    async def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
        """Finish a two-phase-commit transaction with a commit (optionally for *xid*)."""
        async with self.lock:
            await self.wait(self._tpc_finish_gen("commit", xid))

    async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
        """Finish a two-phase-commit transaction with a rollback (optionally for *xid*)."""
        async with self.lock:
            await self.wait(self._tpc_finish_gen("rollback", xid))

    async def tpc_recover(self) -> List[Xid]:
        """
        Return the list of `Xid` values returned by the recover query.

        If the connection was idle before the call and the query opened a
        transaction, that transaction is rolled back before returning.
        """
        self._check_tpc()
        status = self.info.transaction_status
        async with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
            await cur.execute(Xid._get_recover_query())
            res = await cur.fetchall()

        if status == IDLE and self.info.transaction_status == INTRANS:
            await self.rollback()

        return res