docker setup
@@ -0,0 +1,61 @@
from django.core import signals
from django.db.utils import (
    DEFAULT_DB_ALIAS,
    DJANGO_VERSION_PICKLE_KEY,
    ConnectionHandler,
    ConnectionRouter,
    DatabaseError,
    DataError,
    Error,
    IntegrityError,
    InterfaceError,
    InternalError,
    NotSupportedError,
    OperationalError,
    ProgrammingError,
)
from django.utils.connection import ConnectionProxy

__all__ = [
    "connection",
    "connections",
    "router",
    "DatabaseError",
    "IntegrityError",
    "InternalError",
    "ProgrammingError",
    "DataError",
    "NotSupportedError",
    "Error",
    "InterfaceError",
    "OperationalError",
    "DEFAULT_DB_ALIAS",
    "DJANGO_VERSION_PICKLE_KEY",
]

connections = ConnectionHandler()

router = ConnectionRouter()

# For backwards compatibility. Prefer connections['default'] instead.
connection = ConnectionProxy(connections, DEFAULT_DB_ALIAS)


# Register an event to reset saved queries when a Django request is started.
def reset_queries(**kwargs):
    for conn in connections.all(initialized_only=True):
        conn.queries_log.clear()


signals.request_started.connect(reset_queries)


# Register an event to reset transaction state and close connections past
# their lifetime.
def close_old_connections(**kwargs):
    for conn in connections.all(initialized_only=True):
        conn.close_if_unusable_or_obsolete()


signals.request_started.connect(close_old_connections)
signals.request_finished.connect(close_old_connections)
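The module above is the public face of django.db: `connections` lazily creates one wrapper per configured alias, `router` consults the configured database routers, and `connection` is a backwards-compatible proxy for the default alias. A minimal usage sketch follows; the "replica" alias is hypothetical and assumes a matching DATABASES entry.

from django.db import connection, connections, reset_queries

# Default alias, via the ConnectionProxy.
with connection.cursor() as cursor:
    cursor.execute("SELECT 1")
    print(cursor.fetchone())

# An explicitly named alias (assumes DATABASES["replica"] is configured).
with connections["replica"].cursor() as cursor:
    cursor.execute("SELECT 1")

# Clear the per-connection query logs, as the request_started signal does.
reset_queries()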
@@ -0,0 +1,802 @@
import _thread
import copy
import datetime
import logging
import threading
import time
import warnings
from collections import deque
from contextlib import contextmanager

from django.db.backends.utils import debug_transaction

try:
    import zoneinfo
except ImportError:
    from backports import zoneinfo

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, DatabaseError, NotSupportedError
from django.db.backends import utils
from django.db.backends.base.validation import BaseDatabaseValidation
from django.db.backends.signals import connection_created
from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseErrorWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property

NO_DB_ALIAS = "__no_db__"
RAN_DB_VERSION_CHECK = set()

logger = logging.getLogger("django.db.backends.base")


# RemovedInDjango50Warning
def timezone_constructor(tzname):
    if settings.USE_DEPRECATED_PYTZ:
        import pytz

        return pytz.timezone(tzname)
    return zoneinfo.ZoneInfo(tzname)


class BaseDatabaseWrapper:
    """Represent a database connection."""

    # Mapping of Field objects to their column types.
    data_types = {}
    # Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
    data_types_suffix = {}
    # Mapping of Field objects to their SQL for CHECK constraints.
    data_type_check_constraints = {}
    ops = None
    vendor = "unknown"
    display_name = "unknown"
    SchemaEditorClass = None
    # Classes instantiated in __init__().
    client_class = None
    creation_class = None
    features_class = None
    introspection_class = None
    ops_class = None
    validation_class = BaseDatabaseValidation

    queries_limit = 9000

    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
        # Connection related attributes.
        # The underlying database connection.
        self.connection = None
        # `settings_dict` should be a dictionary containing keys such as
        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
        # to disambiguate it from Django settings modules.
        self.settings_dict = settings_dict
        self.alias = alias
        # Query logging in debug mode or when explicitly enabled.
        self.queries_log = deque(maxlen=self.queries_limit)
        self.force_debug_cursor = False

        # Transaction related attributes.
        # Tracks if the connection is in autocommit mode. Per PEP 249, by
        # default, it isn't.
        self.autocommit = False
        # Tracks if the connection is in a transaction managed by 'atomic'.
        self.in_atomic_block = False
        # Increment to generate unique savepoint ids.
        self.savepoint_state = 0
        # List of savepoints created by 'atomic'.
        self.savepoint_ids = []
        # Stack of active 'atomic' blocks.
        self.atomic_blocks = []
        # Tracks if the outermost 'atomic' block should commit on exit,
        # ie. if autocommit was active on entry.
        self.commit_on_exit = True
        # Tracks if the transaction should be rolled back to the next
        # available savepoint because of an exception in an inner block.
        self.needs_rollback = False
        self.rollback_exc = None

        # Connection termination related attributes.
        self.close_at = None
        self.closed_in_transaction = False
        self.errors_occurred = False
        self.health_check_enabled = False
        self.health_check_done = False

        # Thread-safety related attributes.
        self._thread_sharing_lock = threading.Lock()
        self._thread_sharing_count = 0
        self._thread_ident = _thread.get_ident()

        # A list of no-argument functions to run when the transaction commits.
        # Each entry is an (sids, func, robust) tuple, where sids is a set of
        # the active savepoint IDs when this function was registered and robust
        # specifies whether it's allowed for the function to fail.
        self.run_on_commit = []

        # Should we run the on-commit hooks the next time set_autocommit(True)
        # is called?
        self.run_commit_hooks_on_set_autocommit_on = False

        # A stack of wrappers to be invoked around execute()/executemany()
        # calls. Each entry is a function taking five arguments: execute, sql,
        # params, many, and context. It's the function's responsibility to
        # call execute(sql, params, many, context).
        self.execute_wrappers = []

        self.client = self.client_class(self)
        self.creation = self.creation_class(self)
        self.features = self.features_class(self)
        self.introspection = self.introspection_class(self)
        self.ops = self.ops_class(self)
        self.validation = self.validation_class(self)

    def __repr__(self):
        return (
            f"<{self.__class__.__qualname__} "
            f"vendor={self.vendor!r} alias={self.alias!r}>"
        )

    def ensure_timezone(self):
        """
        Ensure the connection's timezone is set to `self.timezone_name` and
        return whether it changed or not.
        """
        return False

    @cached_property
    def timezone(self):
        """
        Return a tzinfo of the database connection time zone.

        This is only used when time zone support is enabled. When a datetime is
        read from the database, it is always returned in this time zone.

        When the database backend supports time zones, it doesn't matter which
        time zone Django uses, as long as aware datetimes are used everywhere.
        Other users connecting to the database can choose their own time zone.

        When the database backend doesn't support time zones, the time zone
        Django uses may be constrained by the requirements of other users of
        the database.
        """
        if not settings.USE_TZ:
            return None
        elif self.settings_dict["TIME_ZONE"] is None:
            return datetime.timezone.utc
        else:
            return timezone_constructor(self.settings_dict["TIME_ZONE"])

    @cached_property
    def timezone_name(self):
        """
        Name of the time zone of the database connection.
        """
        if not settings.USE_TZ:
            return settings.TIME_ZONE
        elif self.settings_dict["TIME_ZONE"] is None:
            return "UTC"
        else:
            return self.settings_dict["TIME_ZONE"]

    @property
    def queries_logged(self):
        return self.force_debug_cursor or settings.DEBUG

    @property
    def queries(self):
        if len(self.queries_log) == self.queries_log.maxlen:
            warnings.warn(
                "Limit for query logging exceeded, only the last {} queries "
                "will be returned.".format(self.queries_log.maxlen)
            )
        return list(self.queries_log)

    def get_database_version(self):
        """Return a tuple of the database's version."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_database_version() "
            "method."
        )

    def check_database_version_supported(self):
        """
        Raise an error if the database version isn't supported by this
        version of Django.
        """
        if (
            self.features.minimum_database_version is not None
            and self.get_database_version() < self.features.minimum_database_version
        ):
            db_version = ".".join(map(str, self.get_database_version()))
            min_db_version = ".".join(map(str, self.features.minimum_database_version))
            raise NotSupportedError(
                f"{self.display_name} {min_db_version} or later is required "
                f"(found {db_version})."
            )

    # ##### Backend-specific methods for creating connections and cursors #####

    def get_connection_params(self):
        """Return a dict of parameters suitable for get_new_connection."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_connection_params() "
            "method"
        )

    def get_new_connection(self, conn_params):
        """Open a connection to the database."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_new_connection() "
            "method"
        )

    def init_connection_state(self):
        """Initialize the database connection settings."""
        global RAN_DB_VERSION_CHECK
        if self.alias not in RAN_DB_VERSION_CHECK:
            self.check_database_version_supported()
            RAN_DB_VERSION_CHECK.add(self.alias)

    def create_cursor(self, name=None):
        """Create a cursor. Assume that a connection is established."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a create_cursor() method"
        )

    # ##### Backend-specific methods for creating connections #####

    @async_unsafe
    def connect(self):
        """Connect to the database. Assume that the connection is closed."""
        # Check for invalid configurations.
        self.check_settings()
        # In case the previous connection was closed while in an atomic block
        self.in_atomic_block = False
        self.savepoint_ids = []
        self.atomic_blocks = []
        self.needs_rollback = False
        # Reset parameters defining when to close/health-check the connection.
        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
        max_age = self.settings_dict["CONN_MAX_AGE"]
        self.close_at = None if max_age is None else time.monotonic() + max_age
        self.closed_in_transaction = False
        self.errors_occurred = False
        # New connections are healthy.
        self.health_check_done = True
        # Establish the connection
        conn_params = self.get_connection_params()
        self.connection = self.get_new_connection(conn_params)
        self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
        self.init_connection_state()
        connection_created.send(sender=self.__class__, connection=self)

        self.run_on_commit = []

    def check_settings(self):
        if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ:
            raise ImproperlyConfigured(
                "Connection '%s' cannot set TIME_ZONE because USE_TZ is False."
                % self.alias
            )

    @async_unsafe
    def ensure_connection(self):
        """Guarantee that a connection to the database is established."""
        if self.connection is None:
            with self.wrap_database_errors:
                self.connect()

    # ##### Backend-specific wrappers for PEP-249 connection methods #####

    def _prepare_cursor(self, cursor):
        """
        Validate the connection is usable and perform database cursor wrapping.
        """
        self.validate_thread_sharing()
        if self.queries_logged:
            wrapped_cursor = self.make_debug_cursor(cursor)
        else:
            wrapped_cursor = self.make_cursor(cursor)
        return wrapped_cursor

    def _cursor(self, name=None):
        self.close_if_health_check_failed()
        self.ensure_connection()
        with self.wrap_database_errors:
            return self._prepare_cursor(self.create_cursor(name))

    def _commit(self):
        if self.connection is not None:
            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
                return self.connection.commit()

    def _rollback(self):
        if self.connection is not None:
            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
                return self.connection.rollback()

    def _close(self):
        if self.connection is not None:
            with self.wrap_database_errors:
                return self.connection.close()

    # ##### Generic wrappers for PEP-249 connection methods #####

    @async_unsafe
    def cursor(self):
        """Create a cursor, opening a connection if necessary."""
        return self._cursor()

    @async_unsafe
    def commit(self):
        """Commit a transaction and reset the dirty flag."""
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._commit()
        # A successful commit means that the database connection works.
        self.errors_occurred = False
        self.run_commit_hooks_on_set_autocommit_on = True

    @async_unsafe
    def rollback(self):
        """Roll back a transaction and reset the dirty flag."""
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._rollback()
        # A successful rollback means that the database connection works.
        self.errors_occurred = False
        self.needs_rollback = False
        self.run_on_commit = []

    @async_unsafe
    def close(self):
        """Close the connection to the database."""
        self.validate_thread_sharing()
        self.run_on_commit = []

        # Don't call validate_no_atomic_block() to avoid making it difficult
        # to get rid of a connection in an invalid state. The next connect()
        # will reset the transaction state anyway.
        if self.closed_in_transaction or self.connection is None:
            return
        try:
            self._close()
        finally:
            if self.in_atomic_block:
                self.closed_in_transaction = True
                self.needs_rollback = True
            else:
                self.connection = None

    # ##### Backend-specific savepoint management methods #####

    def _savepoint(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_create_sql(sid))

    def _savepoint_rollback(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_rollback_sql(sid))

    def _savepoint_commit(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_commit_sql(sid))

    def _savepoint_allowed(self):
        # Savepoints cannot be created outside a transaction
        return self.features.uses_savepoints and not self.get_autocommit()

    # ##### Generic savepoint management methods #####

    @async_unsafe
    def savepoint(self):
        """
        Create a savepoint inside the current transaction. Return an
        identifier for the savepoint that will be used for the subsequent
        rollback or commit. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        thread_ident = _thread.get_ident()
        tid = str(thread_ident).replace("-", "")

        self.savepoint_state += 1
        sid = "s%s_x%d" % (tid, self.savepoint_state)

        self.validate_thread_sharing()
        self._savepoint(sid)

        return sid

    @async_unsafe
    def savepoint_rollback(self, sid):
        """
        Roll back to a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_rollback(sid)

        # Remove any callbacks registered while this savepoint was active.
        self.run_on_commit = [
            (sids, func, robust)
            for (sids, func, robust) in self.run_on_commit
            if sid not in sids
        ]

    @async_unsafe
    def savepoint_commit(self, sid):
        """
        Release a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_commit(sid)

    @async_unsafe
    def clean_savepoints(self):
        """
        Reset the counter used to generate unique savepoint ids in this thread.
        """
        self.savepoint_state = 0

    # ##### Backend-specific transaction management methods #####

    def _set_autocommit(self, autocommit):
        """
        Backend-specific implementation to enable or disable autocommit.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
        )

    # ##### Generic transaction management methods #####

    def get_autocommit(self):
        """Get the autocommit state."""
        self.ensure_connection()
        return self.autocommit

    def set_autocommit(
        self, autocommit, force_begin_transaction_with_broken_autocommit=False
    ):
        """
        Enable or disable autocommit.

        The usual way to start a transaction is to turn autocommit off.
        SQLite does not properly start a transaction when disabling
        autocommit. To avoid this buggy behavior and to actually enter a new
        transaction, an explicit BEGIN is required. Using
        force_begin_transaction_with_broken_autocommit=True will issue an
        explicit BEGIN with SQLite. This option will be ignored for other
        backends.
        """
        self.validate_no_atomic_block()
        self.close_if_health_check_failed()
        self.ensure_connection()

        start_transaction_under_autocommit = (
            force_begin_transaction_with_broken_autocommit
            and not autocommit
            and hasattr(self, "_start_transaction_under_autocommit")
        )

        if start_transaction_under_autocommit:
            self._start_transaction_under_autocommit()
        elif autocommit:
            self._set_autocommit(autocommit)
        else:
            with debug_transaction(self, "BEGIN"):
                self._set_autocommit(autocommit)
        self.autocommit = autocommit

        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
            self.run_and_clear_commit_hooks()
            self.run_commit_hooks_on_set_autocommit_on = False

    def get_rollback(self):
        """Get the "needs rollback" flag -- for *advanced use* only."""
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        return self.needs_rollback

    def set_rollback(self, rollback):
        """
        Set or unset the "needs rollback" flag -- for *advanced use* only.
        """
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        self.needs_rollback = rollback

    def validate_no_atomic_block(self):
        """Raise an error if an atomic block is active."""
        if self.in_atomic_block:
            raise TransactionManagementError(
                "This is forbidden when an 'atomic' block is active."
            )

    def validate_no_broken_transaction(self):
        if self.needs_rollback:
            raise TransactionManagementError(
                "An error occurred in the current transaction. You can't "
                "execute queries until the end of the 'atomic' block."
            ) from self.rollback_exc

    # ##### Foreign key constraints checks handling #####

    @contextmanager
    def constraint_checks_disabled(self):
        """
        Disable foreign key constraint checking.
        """
        disabled = self.disable_constraint_checking()
        try:
            yield
        finally:
            if disabled:
                self.enable_constraint_checking()

    def disable_constraint_checking(self):
        """
        Backends can implement as needed to temporarily disable foreign key
        constraint checking. Should return True if the constraints were
        disabled and will need to be reenabled.
        """
        return False

    def enable_constraint_checking(self):
        """
        Backends can implement as needed to re-enable foreign key constraint
        checking.
        """
        pass

    def check_constraints(self, table_names=None):
        """
        Backends can override this method if they can apply constraint
        checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
        IntegrityError if any invalid foreign key references are encountered.
        """
        pass

    # ##### Connection termination handling #####

    def is_usable(self):
        """
        Test if the database connection is usable.

        This method may assume that self.connection is not None.

        Actual implementations should take care not to raise exceptions
        as that may prevent Django from recycling unusable connections.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require an is_usable() method"
        )

    def close_if_health_check_failed(self):
        """Close existing connection if it fails a health check."""
        if (
            self.connection is None
            or not self.health_check_enabled
            or self.health_check_done
        ):
            return

        if not self.is_usable():
            self.close()
        self.health_check_done = True

    def close_if_unusable_or_obsolete(self):
        """
        Close the current connection if unrecoverable errors have occurred
        or if it outlived its maximum age.
        """
        if self.connection is not None:
            self.health_check_done = False
            # If the application didn't restore the original autocommit setting,
            # don't take chances, drop the connection.
            if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
                self.close()
                return

            # If an exception other than DataError or IntegrityError occurred
            # since the last commit / rollback, check if the connection works.
            if self.errors_occurred:
                if self.is_usable():
                    self.errors_occurred = False
                    self.health_check_done = True
                else:
                    self.close()
                    return

            if self.close_at is not None and time.monotonic() >= self.close_at:
                self.close()
                return

    # ##### Thread safety handling #####

    @property
    def allow_thread_sharing(self):
        with self._thread_sharing_lock:
            return self._thread_sharing_count > 0

    def inc_thread_sharing(self):
        with self._thread_sharing_lock:
            self._thread_sharing_count += 1

    def dec_thread_sharing(self):
        with self._thread_sharing_lock:
            if self._thread_sharing_count <= 0:
                raise RuntimeError(
                    "Cannot decrement the thread sharing count below zero."
                )
            self._thread_sharing_count -= 1

    def validate_thread_sharing(self):
        """
        Validate that the connection isn't accessed by another thread than the
        one which originally created it, unless the connection was explicitly
        authorized to be shared between threads (via the `inc_thread_sharing()`
        method). Raise an exception if the validation fails.
        """
        if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
            raise DatabaseError(
                "DatabaseWrapper objects created in a "
                "thread can only be used in that same thread. The object "
                "with alias '%s' was created in thread id %s and this is "
                "thread id %s." % (self.alias, self._thread_ident, _thread.get_ident())
            )

    # ##### Miscellaneous #####

    def prepare_database(self):
        """
        Hook to do any database check or preparation, generally called before
        migrating a project or an app.
        """
        pass

    @cached_property
    def wrap_database_errors(self):
        """
        Context manager and decorator that re-throws backend-specific database
        exceptions using Django's common wrappers.
        """
        return DatabaseErrorWrapper(self)

    def chunked_cursor(self):
        """
        Return a cursor that tries to avoid caching in the database (if
        supported by the database), otherwise return a regular cursor.
        """
        return self.cursor()

    def make_debug_cursor(self, cursor):
        """Create a cursor that logs all queries in self.queries_log."""
        return utils.CursorDebugWrapper(cursor, self)

    def make_cursor(self, cursor):
        """Create a cursor without debug logging."""
        return utils.CursorWrapper(cursor, self)

    @contextmanager
    def temporary_connection(self):
        """
        Context manager that ensures that a connection is established, and
        if it opened one, closes it to avoid leaving a dangling connection.
        This is useful for operations outside of the request-response cycle.

        Provide a cursor: with self.temporary_connection() as cursor: ...
        """
        must_close = self.connection is None
        try:
            with self.cursor() as cursor:
                yield cursor
        finally:
            if must_close:
                self.close()

    @contextmanager
    def _nodb_cursor(self):
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.
        """
        conn = self.__class__({**self.settings_dict, "NAME": None}, alias=NO_DB_ALIAS)
        try:
            with conn.cursor() as cursor:
                yield cursor
        finally:
            conn.close()

    def schema_editor(self, *args, **kwargs):
        """
        Return a new instance of this backend's SchemaEditor.
        """
        if self.SchemaEditorClass is None:
            raise NotImplementedError(
                "The SchemaEditorClass attribute of this database wrapper is still None"
            )
        return self.SchemaEditorClass(self, *args, **kwargs)

    def on_commit(self, func, robust=False):
        if not callable(func):
            raise TypeError("on_commit()'s callback must be a callable.")
        if self.in_atomic_block:
            # Transaction in progress; save for execution on commit.
            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
        elif not self.get_autocommit():
            raise TransactionManagementError(
                "on_commit() cannot be used in manual transaction management"
            )
        else:
            # No transaction in progress and in autocommit mode; execute
            # immediately.
            if robust:
                try:
                    func()
                except Exception as e:
                    logger.error(
                        f"Error calling {func.__qualname__} in on_commit() (%s).",
                        e,
                        exc_info=True,
                    )
            else:
                func()

    def run_and_clear_commit_hooks(self):
        self.validate_no_atomic_block()
        current_run_on_commit = self.run_on_commit
        self.run_on_commit = []
        while current_run_on_commit:
            _, func, robust = current_run_on_commit.pop(0)
            if robust:
                try:
                    func()
                except Exception as e:
                    logger.error(
                        f"Error calling {func.__qualname__} in on_commit() during "
                        f"transaction (%s).",
                        e,
                        exc_info=True,
                    )
            else:
                func()

    @contextmanager
    def execute_wrapper(self, wrapper):
        """
        Return a context manager under which the wrapper is applied to suitable
        database query executions.
        """
        self.execute_wrappers.append(wrapper)
        try:
            yield
        finally:
            self.execute_wrappers.pop()

    def copy(self, alias=None):
        """
        Return a copy of this connection.

        For tests that require two connections to the same database.
        """
        settings_dict = copy.deepcopy(self.settings_dict)
        if alias is None:
            alias = self.alias
        return type(self)(settings_dict, alias)
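Most application code reaches the wrapper above indirectly. The sketch below is illustrative only: it assumes the default alias, and the `log_sql` wrapper and `notify` callback are hypothetical names. It shows how execute_wrapper() and on_commit() are normally driven through the public connection and transaction APIs; the robust flag corresponds to the (sids, func, robust) tuples stored in run_on_commit above.

from django.db import connection, transaction

def log_sql(execute, sql, params, many, context):
    # Called around every execute()/executemany() while the wrapper is installed.
    print("SQL:", sql)
    return execute(sql, params, many, context)

def notify():
    print("transaction committed")

with connection.execute_wrapper(log_sql):
    with transaction.atomic():
        # Queued in connection.run_on_commit; runs only if the outermost atomic
        # block commits. robust=True logs callback failures instead of raising.
        transaction.on_commit(notify, robust=True)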
@@ -0,0 +1,28 @@
import os
import subprocess


class BaseDatabaseClient:
    """Encapsulate backend-specific methods for opening a client shell."""

    # This should be a string representing the name of the executable
    # (e.g., "psql"). Subclasses must override this.
    executable_name = None

    def __init__(self, connection):
        # connection is an instance of BaseDatabaseWrapper.
        self.connection = connection

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        raise NotImplementedError(
            "subclasses of BaseDatabaseClient must provide a "
            "settings_to_cmd_args_env() method or override a runshell()."
        )

    def runshell(self, parameters):
        args, env = self.settings_to_cmd_args_env(
            self.connection.settings_dict, parameters
        )
        env = {**os.environ, **env} if env else None
        subprocess.run(args, env=env, check=True)
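A concrete backend only has to supply executable_name and settings_to_cmd_args_env(); runshell() then merges the returned environment over os.environ and delegates to subprocess.run(). A hypothetical subclass for a fictional `mydb` command-line client, assuming the standard django.db.backends.base.client import path:

from django.db.backends.base.client import BaseDatabaseClient

class DatabaseClient(BaseDatabaseClient):
    executable_name = "mydb"  # hypothetical CLI binary

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        args = [cls.executable_name, settings_dict["NAME"], *parameters]
        # Pass the password through the environment rather than argv.
        env = {}
        if settings_dict.get("PASSWORD"):
            env["MYDB_PASSWORD"] = settings_dict["PASSWORD"]
        return args, env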
@@ -0,0 +1,381 @@
import os
import sys
from io import StringIO

from django.apps import apps
from django.conf import settings
from django.core import serializers
from django.db import router
from django.db.transaction import atomic
from django.utils.module_loading import import_string

# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = "test_"


class BaseDatabaseCreation:
    """
    Encapsulate backend-specific differences pertaining to creation and
    destruction of the test database.
    """

    def __init__(self, connection):
        self.connection = connection

    def _nodb_cursor(self):
        return self.connection._nodb_cursor()

    def log(self, msg):
        sys.stderr.write(msg + os.linesep)

    def create_test_db(
        self, verbosity=1, autoclobber=False, serialize=True, keepdb=False
    ):
        """
        Create a test database, prompting the user for confirmation if the
        database already exists. Return the name of the test database created.
        """
        # Don't import django.core.management if it isn't needed.
        from django.core.management import call_command

        test_database_name = self._get_test_db_name()

        if verbosity >= 1:
            action = "Creating"
            if keepdb:
                action = "Using existing"

            self.log(
                "%s test database for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, test_database_name),
                )
            )

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. This is to handle the case
        # where the test DB doesn't exist, in which case we need to
        # create it, then just not destroy it. If we instead skip
        # this, we will get an exception.
        self._create_test_db(verbosity, autoclobber, keepdb)

        self.connection.close()
        settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
        self.connection.settings_dict["NAME"] = test_database_name

        try:
            if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
                # Disable migrations for all apps.
                old_migration_modules = settings.MIGRATION_MODULES
                settings.MIGRATION_MODULES = {
                    app.label: None for app in apps.get_app_configs()
                }
            # We report migrate messages at one level lower than that
            # requested. This ensures we don't get flooded with messages during
            # testing (unless you really ask to be flooded).
            call_command(
                "migrate",
                verbosity=max(verbosity - 1, 0),
                interactive=False,
                database=self.connection.alias,
                run_syncdb=True,
            )
        finally:
            if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
                settings.MIGRATION_MODULES = old_migration_modules

        # We then serialize the current state of the database into a string
        # and store it on the connection. This slightly horrific process is so people
        # who are testing on databases without transactions or who are using
        # a TransactionTestCase still get a clean database on every test run.
        if serialize:
            self.connection._test_serialized_contents = self.serialize_db_to_string()

        call_command("createcachetable", database=self.connection.alias)

        # Ensure a connection for the side effect of initializing the test database.
        self.connection.ensure_connection()

        if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
            self.mark_expected_failures_and_skips()

        return test_database_name

    def set_as_test_mirror(self, primary_settings_dict):
        """
        Set this database up to be used in testing as a mirror of a primary
        database whose settings are given.
        """
        self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]

    def serialize_db_to_string(self):
        """
        Serialize all data in the database into a JSON string.
        Designed only for test runner usage; will not handle large
        amounts of data.
        """

        # Iteratively return every object for all models to serialize.
        def get_objects():
            from django.db.migrations.loader import MigrationLoader

            loader = MigrationLoader(self.connection)
            for app_config in apps.get_app_configs():
                if (
                    app_config.models_module is not None
                    and app_config.label in loader.migrated_apps
                    and app_config.name not in settings.TEST_NON_SERIALIZED_APPS
                ):
                    for model in app_config.get_models():
                        if model._meta.can_migrate(
                            self.connection
                        ) and router.allow_migrate_model(self.connection.alias, model):
                            queryset = model._base_manager.using(
                                self.connection.alias,
                            ).order_by(model._meta.pk.name)
                            yield from queryset.iterator()

        # Serialize to a string
        out = StringIO()
        serializers.serialize("json", get_objects(), indent=None, stream=out)
        return out.getvalue()

    def deserialize_db_from_string(self, data):
        """
        Reload the database with data from a string generated by
        the serialize_db_to_string() method.
        """
        data = StringIO(data)
        table_names = set()
        # Load data in a transaction to handle forward references and cycles.
        with atomic(using=self.connection.alias):
            # Disable constraint checks, because some databases (MySQL) don't
            # support deferred checks.
            with self.connection.constraint_checks_disabled():
                for obj in serializers.deserialize(
                    "json", data, using=self.connection.alias
                ):
                    obj.save()
                    table_names.add(obj.object.__class__._meta.db_table)
            # Manually check for any invalid keys that might have been added,
            # because constraint checks were disabled.
            self.connection.check_constraints(table_names=table_names)

    def _get_database_display_str(self, verbosity, database_name):
        """
        Return display string for a database for use in various actions.
        """
        return "'%s'%s" % (
            self.connection.alias,
            (" ('%s')" % database_name) if verbosity >= 2 else "",
        )

    def _get_test_db_name(self):
        """
        Internal implementation - return the name of the test DB that will be
        created. Only useful when called from create_test_db() and
        _create_test_db() and when no external munging is done with the 'NAME'
        settings.
        """
        if self.connection.settings_dict["TEST"]["NAME"]:
            return self.connection.settings_dict["TEST"]["NAME"]
        return TEST_DATABASE_PREFIX + self.connection.settings_dict["NAME"]

    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
        cursor.execute("CREATE DATABASE %(dbname)s %(suffix)s" % parameters)

    def _create_test_db(self, verbosity, autoclobber, keepdb=False):
        """
        Internal implementation - create the test db tables.
        """
        test_database_name = self._get_test_db_name()
        test_db_params = {
            "dbname": self.connection.ops.quote_name(test_database_name),
            "suffix": self.sql_table_creation_suffix(),
        }
        # Create the test database and connect to it.
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params, keepdb)
            except Exception as e:
                # if we want to keep the db, then no need to do any of the below,
                # just return and skip it all.
                if keepdb:
                    return test_database_name

                self.log("Got an error creating the test database: %s" % e)
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        "database '%s', or 'no' to cancel: " % test_database_name
                    )
                if autoclobber or confirm == "yes":
                    try:
                        if verbosity >= 1:
                            self.log(
                                "Destroying old test database for alias %s..."
                                % (
                                    self._get_database_display_str(
                                        verbosity, test_database_name
                                    ),
                                )
                            )
                        cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
                        self._execute_create_test_db(cursor, test_db_params, keepdb)
                    except Exception as e:
                        self.log("Got an error recreating the test database: %s" % e)
                        sys.exit(2)
                else:
                    self.log("Tests cancelled.")
                    sys.exit(1)

        return test_database_name

    def clone_test_db(self, suffix, verbosity=1, autoclobber=False, keepdb=False):
        """
        Clone a test database.
        """
        source_database_name = self.connection.settings_dict["NAME"]

        if verbosity >= 1:
            action = "Cloning test database"
            if keepdb:
                action = "Using existing clone"
            self.log(
                "%s for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, source_database_name),
                )
            )

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. See create_test_db for details.
        self._clone_test_db(suffix, verbosity, keepdb)

    def get_test_db_clone_settings(self, suffix):
        """
        Return a modified connection settings dict for the n-th clone of a DB.
        """
        # When this function is called, the test database has been created
        # already and its name has been copied to settings_dict['NAME'] so
        # we don't need to call _get_test_db_name.
        orig_settings_dict = self.connection.settings_dict
        return {
            **orig_settings_dict,
            "NAME": "{}_{}".format(orig_settings_dict["NAME"], suffix),
        }

    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        """
        Internal implementation - duplicate the test db tables.
        """
        raise NotImplementedError(
            "The database backend doesn't support cloning databases. "
            "Disable the option to run tests in parallel processes."
        )

    def destroy_test_db(
        self, old_database_name=None, verbosity=1, keepdb=False, suffix=None
    ):
        """
        Destroy a test database, prompting the user for confirmation if the
        database already exists.
        """
        self.connection.close()
        if suffix is None:
            test_database_name = self.connection.settings_dict["NAME"]
        else:
            test_database_name = self.get_test_db_clone_settings(suffix)["NAME"]

        if verbosity >= 1:
            action = "Destroying"
            if keepdb:
                action = "Preserving"
            self.log(
                "%s test database for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, test_database_name),
                )
            )

        # if we want to preserve the database
        # skip the actual destroying piece.
        if not keepdb:
            self._destroy_test_db(test_database_name, verbosity)

        # Restore the original database name
        if old_database_name is not None:
            settings.DATABASES[self.connection.alias]["NAME"] = old_database_name
            self.connection.settings_dict["NAME"] = old_database_name

    def _destroy_test_db(self, test_database_name, verbosity):
        """
        Internal implementation - remove the test db tables.
        """
        # Remove the test database to clean up after
        # ourselves. Connect to the previous database (not the test database)
        # to do so, because it's not allowed to delete a database while being
        # connected to it.
        with self._nodb_cursor() as cursor:
            cursor.execute(
                "DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name)
            )

    def mark_expected_failures_and_skips(self):
        """
        Mark tests in Django's test suite which are expected failures on this
        database and tests which should be skipped on this database.
        """
        # Only load unittest if we're actually testing.
        from unittest import expectedFailure, skip

        for test_name in self.connection.features.django_test_expected_failures:
            test_case_name, _, test_method_name = test_name.rpartition(".")
            test_app = test_name.split(".")[0]
            # Importing a test app that isn't installed raises RuntimeError.
            if test_app in settings.INSTALLED_APPS:
                test_case = import_string(test_case_name)
                test_method = getattr(test_case, test_method_name)
                setattr(test_case, test_method_name, expectedFailure(test_method))
        for reason, tests in self.connection.features.django_test_skips.items():
            for test_name in tests:
                test_case_name, _, test_method_name = test_name.rpartition(".")
                test_app = test_name.split(".")[0]
                # Importing a test app that isn't installed raises RuntimeError.
                if test_app in settings.INSTALLED_APPS:
                    test_case = import_string(test_case_name)
                    test_method = getattr(test_case, test_method_name)
                    setattr(test_case, test_method_name, skip(reason)(test_method))

    def sql_table_creation_suffix(self):
        """
        SQL to append to the end of the test table creation statements.
        """
        return ""

    def test_db_signature(self):
        """
        Return a tuple with elements of self.connection.settings_dict (a
        DATABASES setting value) that uniquely identify a database
        accordingly to the RDBMS particularities.
        """
        settings_dict = self.connection.settings_dict
        return (
            settings_dict["HOST"],
            settings_dict["PORT"],
            settings_dict["ENGINE"],
            self._get_test_db_name(),
        )

    def setup_worker_connection(self, _worker_id):
        settings_dict = self.get_test_db_clone_settings(str(_worker_id))
        # connection.settings_dict must be updated in place for changes to be
        # reflected in django.db.connections. If the following line assigned
        # connection.settings_dict = settings_dict, new threads would connect
        # to the default database instead of the appropriate clone.
        self.connection.settings_dict.update(settings_dict)
        self.connection.close()
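The naming rules in _get_test_db_name() and get_test_db_clone_settings() are plain string composition; the snippet below restates them standalone so the resulting names are easy to see (the "appdb" name is made up for illustration):

TEST_DATABASE_PREFIX = "test_"

def test_db_name(name, explicit_test_name=None):
    # Mirrors _get_test_db_name(): an explicit TEST["NAME"] wins, otherwise
    # the configured NAME gets the "test_" prefix.
    return explicit_test_name or TEST_DATABASE_PREFIX + name

def clone_db_name(test_name, suffix):
    # Mirrors get_test_db_clone_settings(): "<test db name>_<worker suffix>".
    return "{}_{}".format(test_name, suffix)

assert test_db_name("appdb") == "test_appdb"
assert clone_db_name(test_db_name("appdb"), "1") == "test_appdb_1"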
@@ -0,0 +1,393 @@
from django.db import ProgrammingError
from django.utils.functional import cached_property


class BaseDatabaseFeatures:
    # An optional tuple indicating the minimum supported database version.
    minimum_database_version = None
    gis_enabled = False
    # Oracle can't group by LOB (large object) data types.
    allows_group_by_lob = True
    allows_group_by_selected_pks = False
    allows_group_by_select_index = True
    empty_fetchmany_value = []
    update_can_self_select = True

    # Does the backend distinguish between '' and None?
    interprets_empty_strings_as_nulls = False

    # Does the backend allow inserting duplicate NULL rows in a nullable
    # unique field? All core backends implement this correctly, but other
    # databases such as SQL Server do not.
    supports_nullable_unique_constraints = True

    # Does the backend allow inserting duplicate rows when a unique_together
    # constraint exists and some fields are nullable but not all of them?
    supports_partially_nullable_unique_constraints = True
    # Does the backend support initially deferrable unique constraints?
    supports_deferrable_unique_constraints = False

    can_use_chunked_reads = True
    can_return_columns_from_insert = False
    can_return_rows_from_bulk_insert = False
    has_bulk_insert = True
    uses_savepoints = True
    can_release_savepoints = False

    # If True, don't use integer foreign keys referring to, e.g., positive
    # integer primary keys.
    related_fields_match_type = False
    allow_sliced_subqueries_with_in = True
    has_select_for_update = False
    has_select_for_update_nowait = False
    has_select_for_update_skip_locked = False
    has_select_for_update_of = False
    has_select_for_no_key_update = False
    # Does the database's SELECT FOR UPDATE OF syntax require a column rather
    # than a table?
    select_for_update_of_column = False

    # Does the default test database allow multiple connections?
    # Usually an indication that the test database is in-memory
    test_db_allows_multiple_connections = True

    # Can an object be saved without an explicit primary key?
    supports_unspecified_pk = False

    # Can a fixture contain forward references? i.e., are
    # FK constraints checked at the end of transaction, or
    # at the end of each save operation?
    supports_forward_references = True

    # Does the backend truncate names properly when they are too long?
    truncates_names = False

    # Is there a REAL datatype in addition to floats/doubles?
    has_real_datatype = False
    supports_subqueries_in_group_by = True

    # Does the backend ignore unnecessary ORDER BY clauses in subqueries?
    ignores_unnecessary_order_by_in_subqueries = True

    # Is there a true datatype for uuid?
    has_native_uuid_field = False

    # Is there a true datatype for timedeltas?
    has_native_duration_field = False

    # Does the database driver support same type temporal data subtraction
    # by returning the type used to store duration field?
    supports_temporal_subtraction = False

    # Does the __regex lookup support backreferencing and grouping?
    supports_regex_backreferencing = True

    # Can date/datetime lookups be performed using a string?
    supports_date_lookup_using_string = True

    # Can datetimes with timezones be used?
    supports_timezones = True

    # Does the database have a copy of the zoneinfo database?
    has_zoneinfo_database = True

    # When performing a GROUP BY, is an ORDER BY NULL required
    # to remove any ordering?
    requires_explicit_null_ordering_when_grouping = False

    # Does the backend order NULL values as largest or smallest?
    nulls_order_largest = False

    # Does the backend support NULLS FIRST and NULLS LAST in ORDER BY?
    supports_order_by_nulls_modifier = True

    # Does the backend order NULLS FIRST by default?
    order_by_nulls_first = False

    # The database's limit on the number of query parameters.
    max_query_params = None

    # Can an object have an autoincrement primary key of 0?
    allows_auto_pk_0 = True

    # Do we need to NULL a ForeignKey out, or can the constraint check be
    # deferred
    can_defer_constraint_checks = False

    # Does the backend support tablespaces? Default to False because it isn't
    # in the SQL standard.
    supports_tablespaces = False

    # Does the backend reset sequences between tests?
    supports_sequence_reset = True

    # Can the backend introspect the default value of a column?
    can_introspect_default = True

    # Confirm support for introspected foreign keys
    # Every database can do this reliably, except MySQL,
    # which can't do it for MyISAM tables
    can_introspect_foreign_keys = True

    # Map fields which some backends may not be able to differentiate to the
    # field it's introspected as.
    introspected_field_types = {
        "AutoField": "AutoField",
        "BigAutoField": "BigAutoField",
        "BigIntegerField": "BigIntegerField",
        "BinaryField": "BinaryField",
        "BooleanField": "BooleanField",
        "CharField": "CharField",
        "DurationField": "DurationField",
        "GenericIPAddressField": "GenericIPAddressField",
        "IntegerField": "IntegerField",
        "PositiveBigIntegerField": "PositiveBigIntegerField",
        "PositiveIntegerField": "PositiveIntegerField",
        "PositiveSmallIntegerField": "PositiveSmallIntegerField",
        "SmallAutoField": "SmallAutoField",
        "SmallIntegerField": "SmallIntegerField",
        "TimeField": "TimeField",
    }

    # Can the backend introspect the column order (ASC/DESC) for indexes?
    supports_index_column_ordering = True

    # Does the backend support introspection of materialized views?
    can_introspect_materialized_views = False

    # Support for the DISTINCT ON clause
    can_distinct_on_fields = False

    # Does the backend prevent running SQL queries in broken transactions?
    atomic_transactions = True

    # Can we roll back DDL in a transaction?
    can_rollback_ddl = False

    schema_editor_uses_clientside_param_binding = False

    # Does it support operations requiring references rename in a transaction?
    supports_atomic_references_rename = True

    # Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
    supports_combined_alters = False

    # Does it support foreign keys?
    supports_foreign_keys = True

    # Can it create foreign key constraints inline when adding columns?
    can_create_inline_fk = True

    # Can an index be renamed?
    can_rename_index = False

    # Does it automatically index foreign keys?
    indexes_foreign_keys = True

    # Does it support CHECK constraints?
    supports_column_check_constraints = True
    supports_table_check_constraints = True
    # Does the backend support introspection of CHECK constraints?
    can_introspect_check_constraints = True

    # Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value})
    # parameter passing? Note this can be provided by the backend even if not
    # supported by the Python driver
    supports_paramstyle_pyformat = True

    # Does the backend require literal defaults, rather than parameterized ones?
    requires_literal_defaults = False

    # Does the backend require a connection reset after each material schema change?
    connection_persists_old_columns = False

    # What kind of error does the backend throw when accessing closed cursor?
    closed_cursor_error_class = ProgrammingError

    # Does 'a' LIKE 'A' match?
    has_case_insensitive_like = False

    # Suffix for backends that don't support "SELECT xxx;" queries.
    bare_select_suffix = ""

    # If NULL is implied on columns without needing to be explicitly specified
    implied_column_null = False

    # Does the backend support "select for update" queries with limit (and offset)?
    supports_select_for_update_with_limit = True

    # Does the backend ignore null expressions in GREATEST and LEAST queries unless
    # every expression is null?
    greatest_least_ignores_nulls = False

    # Can the backend clone databases for parallel test execution?
    # Defaults to False to allow third-party backends to opt-in.
    can_clone_databases = False

    # Does the backend consider table names with different casing to
    # be equal?
    ignores_table_name_case = False

    # Place FOR UPDATE right after FROM clause. Used on MSSQL.
    for_update_after_from = False

    # Combinatorial flags
    supports_select_union = True
    supports_select_intersection = True
    supports_select_difference = True
    supports_slicing_ordering_in_compound = False
    supports_parentheses_in_compound = True
    requires_compound_order_by_subquery = False

    # Does the database support SQL 2003 FILTER (WHERE ...) in aggregate
    # expressions?
    supports_aggregate_filter_clause = False

    # Does the backend support indexing a TextField?
    supports_index_on_text_field = True

    # Does the backend support window expressions (expression OVER (...))?
    supports_over_clause = False
    supports_frame_range_fixed_distance = False
    only_supports_unbounded_with_preceding_and_following = False

    # Does the backend support CAST with precision?
    supports_cast_with_precision = True

    # How many second decimals does the database return when casting a value to
    # a type with time?
    time_cast_precision = 6

    # SQL to create a procedure for use by the Django test suite. The
    # functionality of the procedure isn't important.
    create_test_procedure_without_params_sql = None
    create_test_procedure_with_int_param_sql = None

    # SQL to create a table with a composite primary key for use by the Django
    # test suite.
    create_test_table_with_composite_primary_key = None

    # Does the backend support keyword parameters for cursor.callproc()?
    supports_callproc_kwargs = False

    # What formats does the backend EXPLAIN syntax support?
    supported_explain_formats = set()

    # Does the backend support the default parameter in lead() and lag()?
    supports_default_in_lead_lag = True

    # Does the backend support ignoring constraint or uniqueness errors during
    # INSERT?
    supports_ignore_conflicts = True
    # Does the backend support updating rows on constraint or uniqueness errors
    # during INSERT?
    supports_update_conflicts = False
    supports_update_conflicts_with_target = False

    # Does this backend require casting the results of CASE expressions used
    # in UPDATE statements to ensure the expression has the correct type?
    requires_casted_case_in_updates = False

    # Does the backend support partial indexes (CREATE INDEX ... WHERE ...)?
    supports_partial_indexes = True
    supports_functions_in_partial_indexes = True
    # Does the backend support covering indexes (CREATE INDEX ... INCLUDE ...)?
    supports_covering_indexes = False
    # Does the backend support indexes on expressions?
    supports_expression_indexes = True
    # Does the backend treat COLLATE as an indexed expression?
    collate_as_index_expression = False

    # Does the database allow more than one constraint or index on the same
    # field(s)?
    allows_multiple_constraints_on_same_fields = True

    # Does the backend support boolean expressions in SELECT and GROUP BY
    # clauses?
    supports_boolean_expr_in_select_clause = True
    # Does the backend support comparing boolean expressions in WHERE clauses?
    # Eg: WHERE (price > 0) IS NOT NULL
    supports_comparing_boolean_expr = True

    # Does the backend support JSONField?
    supports_json_field = True
    # Can the backend introspect a JSONField?
    can_introspect_json_field = True
    # Does the backend support primitives in JSONField?
    supports_primitives_in_json_field = True
    # Is there a true datatype for JSON?
    has_native_json_field = False
    # Does the backend use PostgreSQL-style JSON operators like '->'?
    has_json_operators = False
    # Does the backend support __contains and __contained_by lookups for
    # a JSONField?
    supports_json_field_contains = True
    # Does value__d__contains={'f': 'g'} (without a list around the dict) match
    # {'d': [{'f': 'g'}]}?
    json_key_contains_list_matching_requires_list = False
    # Does the backend support JSONObject() database function?
    has_json_object_function = True

    # Does the backend support column collations?
    supports_collation_on_charfield = True
    supports_collation_on_textfield = True
    # Does the backend support non-deterministic collations?
    supports_non_deterministic_collations = True

    # Does the backend support column and table comments?
    supports_comments = False
    # Does the backend support column comments in ADD COLUMN statements?
    supports_comments_inline = False

    # Does the backend support the logical XOR operator?
    supports_logical_xor = False

    # Set to (exception, message) if null characters in text are disallowed.
    prohibits_null_characters_in_text_exception = None

    # Does the backend support unlimited character columns?
    supports_unlimited_charfield = False

    # Collation names for use by the Django test suite.
    test_collations = {
        "ci": None,  # Case-insensitive.
        "cs": None,  # Case-sensitive.
        "non_default": None,  # Non-default.
"swedish_ci": None, # Swedish case-insensitive.
|
||||
}
|
||||
# SQL template override for tests.aggregation.tests.NowUTC
|
||||
test_now_utc_template = None
|
||||
|
||||
# A set of dotted paths to tests in Django's test suite that are expected
|
||||
# to fail on this database.
|
||||
django_test_expected_failures = set()
|
||||
# A map of reasons to sets of dotted paths to tests in Django's test suite
|
||||
# that should be skipped for this database.
|
||||
django_test_skips = {}
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
@cached_property
|
||||
def supports_explaining_query_execution(self):
|
||||
"""Does this backend support explaining query execution?"""
|
||||
return self.connection.ops.explain_prefix is not None
|
||||
|
||||
@cached_property
|
||||
def supports_transactions(self):
|
||||
"""Confirm support for transactions."""
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute("CREATE TABLE ROLLBACK_TEST (X INT)")
|
||||
self.connection.set_autocommit(False)
|
||||
cursor.execute("INSERT INTO ROLLBACK_TEST (X) VALUES (8)")
|
||||
self.connection.rollback()
|
||||
self.connection.set_autocommit(True)
|
||||
cursor.execute("SELECT COUNT(X) FROM ROLLBACK_TEST")
|
||||
(count,) = cursor.fetchone()
|
||||
cursor.execute("DROP TABLE ROLLBACK_TEST")
|
||||
return count == 0
|
||||
|
||||
def allows_group_by_selected_pks_on_model(self, model):
|
||||
if not self.allows_group_by_selected_pks:
|
||||
return False
|
||||
return model._meta.managed
|
||||
@@ -0,0 +1,212 @@
|
||||
from collections import namedtuple
|
||||
|
||||
# Structure returned by DatabaseIntrospection.get_table_list()
|
||||
TableInfo = namedtuple("TableInfo", ["name", "type"])
|
||||
|
||||
# Structure returned by the DB-API cursor.description interface (PEP 249)
|
||||
FieldInfo = namedtuple(
|
||||
"FieldInfo",
|
||||
"name type_code display_size internal_size precision scale null_ok "
|
||||
"default collation",
|
||||
)
|
||||
|
||||
|
||||
class BaseDatabaseIntrospection:
|
||||
"""Encapsulate backend-specific introspection utilities."""
|
||||
|
||||
data_types_reverse = {}
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def get_field_type(self, data_type, description):
|
||||
"""
|
||||
Hook for a database backend to use the cursor description to
|
||||
match a Django field type to a database column.
|
||||
|
||||
For Oracle, the column data_type on its own is insufficient to
|
||||
distinguish between a FloatField and IntegerField, for example.
|
||||
"""
|
||||
return self.data_types_reverse[data_type]
|
||||
|
||||
def identifier_converter(self, name):
|
||||
"""
|
||||
Apply a conversion to the identifier for the purposes of comparison.
|
||||
|
||||
The default identifier converter is for case sensitive comparison.
|
||||
"""
|
||||
return name
|
||||
|
||||
def table_names(self, cursor=None, include_views=False):
|
||||
"""
|
||||
Return a list of names of all tables that exist in the database.
|
||||
Sort the returned table list by Python's default sorting. Do NOT use
|
||||
the database's ORDER BY here to avoid subtle differences in sorting
|
||||
order between databases.
|
||||
"""
|
||||
|
||||
def get_names(cursor):
|
||||
return sorted(
|
||||
ti.name
|
||||
for ti in self.get_table_list(cursor)
|
||||
if include_views or ti.type == "t"
|
||||
)
|
||||
|
||||
if cursor is None:
|
||||
with self.connection.cursor() as cursor:
|
||||
return get_names(cursor)
|
||||
return get_names(cursor)
|
||||
|
||||
def get_table_list(self, cursor):
|
||||
"""
|
||||
Return an unsorted list of TableInfo named tuples of all tables and
|
||||
views that exist in the database.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseIntrospection may require a get_table_list() "
|
||||
"method"
|
||||
)
|
||||
|
||||
def get_table_description(self, cursor, table_name):
|
||||
"""
|
||||
Return a description of the table with the DB-API cursor.description
|
||||
interface.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseIntrospection may require a "
|
||||
"get_table_description() method."
|
||||
)
|
||||
|
||||
def get_migratable_models(self):
|
||||
from django.apps import apps
|
||||
from django.db import router
|
||||
|
||||
return (
|
||||
model
|
||||
for app_config in apps.get_app_configs()
|
||||
for model in router.get_migratable_models(app_config, self.connection.alias)
|
||||
if model._meta.can_migrate(self.connection)
|
||||
)
|
||||
|
||||
def django_table_names(self, only_existing=False, include_views=True):
|
||||
"""
|
||||
Return a list of all table names that have associated Django models and
|
||||
are in INSTALLED_APPS.
|
||||
|
||||
If only_existing is True, include only the tables in the database.
|
||||
"""
|
||||
tables = set()
|
||||
for model in self.get_migratable_models():
|
||||
if not model._meta.managed:
|
||||
continue
|
||||
tables.add(model._meta.db_table)
|
||||
tables.update(
|
||||
f.m2m_db_table()
|
||||
for f in model._meta.local_many_to_many
|
||||
if f.remote_field.through._meta.managed
|
||||
)
|
||||
tables = list(tables)
|
||||
if only_existing:
|
||||
existing_tables = set(self.table_names(include_views=include_views))
|
||||
tables = [
|
||||
t for t in tables if self.identifier_converter(t) in existing_tables
|
||||
]
|
||||
return tables
|
||||
|
||||
def installed_models(self, tables):
|
||||
"""
|
||||
Return a set of all models represented by the provided list of table
|
||||
names.
|
||||
"""
|
||||
tables = set(map(self.identifier_converter, tables))
|
||||
return {
|
||||
m
|
||||
for m in self.get_migratable_models()
|
||||
if self.identifier_converter(m._meta.db_table) in tables
|
||||
}
|
||||
|
||||
def sequence_list(self):
|
||||
"""
|
||||
Return a list of information about all DB sequences for all models in
|
||||
all apps.
|
||||
"""
|
||||
sequence_list = []
|
||||
with self.connection.cursor() as cursor:
|
||||
for model in self.get_migratable_models():
|
||||
if not model._meta.managed:
|
||||
continue
|
||||
if model._meta.swapped:
|
||||
continue
|
||||
sequence_list.extend(
|
||||
self.get_sequences(
|
||||
cursor, model._meta.db_table, model._meta.local_fields
|
||||
)
|
||||
)
|
||||
for f in model._meta.local_many_to_many:
|
||||
# If this is an m2m using an intermediate table,
|
||||
# we don't need to reset the sequence.
|
||||
if f.remote_field.through._meta.auto_created:
|
||||
sequence = self.get_sequences(cursor, f.m2m_db_table())
|
||||
sequence_list.extend(
|
||||
sequence or [{"table": f.m2m_db_table(), "column": None}]
|
||||
)
|
||||
return sequence_list
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
"""
|
||||
Return a list of introspected sequences for table_name. Each sequence
|
||||
is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
|
||||
'name' key can be added if the backend supports named sequences.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseIntrospection may require a get_sequences() "
|
||||
"method"
|
||||
)
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||
representing all foreign keys in the given table.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseIntrospection may require a "
|
||||
"get_relations() method."
|
||||
)
|
||||
|
||||
def get_primary_key_column(self, cursor, table_name):
|
||||
"""
|
||||
Return the name of the primary key column for the given table.
|
||||
"""
|
||||
columns = self.get_primary_key_columns(cursor, table_name)
|
||||
return columns[0] if columns else None
|
||||
|
||||
def get_primary_key_columns(self, cursor, table_name):
|
||||
"""Return a list of primary key columns for the given table."""
|
||||
for constraint in self.get_constraints(cursor, table_name).values():
|
||||
if constraint["primary_key"]:
|
||||
return constraint["columns"]
|
||||
return None
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve any constraints or keys (unique, pk, fk, check, index)
|
||||
across one or more columns.
|
||||
|
||||
Return a dict mapping constraint names to their attributes,
|
||||
where attributes is a dict with keys:
|
||||
* columns: List of columns this covers
|
||||
* primary_key: True if primary key, False otherwise
|
||||
* unique: True if this is a unique constraint, False otherwise
|
||||
* foreign_key: (table, column) of target, or None
|
||||
* check: True if check constraint, False otherwise
|
||||
* index: True if index, False otherwise.
|
||||
* orders: The order (ASC/DESC) defined for the columns of indexes
|
||||
* type: The type of the index (btree, hash, etc.)
|
||||
|
||||
Some backends may return special constraint names that don't exist
|
||||
if they don't name constraints of a certain type (e.g. SQLite)
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseIntrospection may require a get_constraints() "
|
||||
"method"
|
||||
)
|
||||
@@ -0,0 +1,778 @@
|
||||
import datetime
|
||||
import decimal
|
||||
import json
|
||||
from importlib import import_module
|
||||
|
||||
import sqlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import NotSupportedError, transaction
|
||||
from django.db.backends import utils
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
|
||||
class BaseDatabaseOperations:
|
||||
"""
|
||||
Encapsulate backend-specific differences, such as the way a backend
|
||||
performs ordering or calculates the ID of a recently-inserted row.
|
||||
"""
|
||||
|
||||
compiler_module = "django.db.models.sql.compiler"
|
||||
|
||||
# Integer field safe ranges by `internal_type` as documented
|
||||
# in docs/ref/models/fields.txt.
|
||||
integer_field_ranges = {
|
||||
"SmallIntegerField": (-32768, 32767),
|
||||
"IntegerField": (-2147483648, 2147483647),
|
||||
"BigIntegerField": (-9223372036854775808, 9223372036854775807),
|
||||
"PositiveBigIntegerField": (0, 9223372036854775807),
|
||||
"PositiveSmallIntegerField": (0, 32767),
|
||||
"PositiveIntegerField": (0, 2147483647),
|
||||
"SmallAutoField": (-32768, 32767),
|
||||
"AutoField": (-2147483648, 2147483647),
|
||||
"BigAutoField": (-9223372036854775808, 9223372036854775807),
|
||||
}
|
||||
set_operators = {
|
||||
"union": "UNION",
|
||||
"intersection": "INTERSECT",
|
||||
"difference": "EXCEPT",
|
||||
}
|
||||
# Mapping of Field.get_internal_type() (typically the model field's class
|
||||
# name) to the data type to use for the Cast() function, if different from
|
||||
# DatabaseWrapper.data_types.
|
||||
cast_data_types = {}
|
||||
# CharField data type if the max_length argument isn't provided.
|
||||
cast_char_field_without_max_length = None
|
||||
|
||||
# Start and end points for window expressions.
|
||||
PRECEDING = "PRECEDING"
|
||||
FOLLOWING = "FOLLOWING"
|
||||
UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING
|
||||
UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING
|
||||
CURRENT_ROW = "CURRENT ROW"
|
||||
|
||||
# Prefix for EXPLAIN queries, or None if EXPLAIN isn't supported.
|
||||
explain_prefix = None
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
self._cache = None
|
||||
|
||||
def autoinc_sql(self, table, column):
|
||||
"""
|
||||
Return any SQL needed to support auto-incrementing primary keys, or
|
||||
None if no SQL is necessary.
|
||||
|
||||
This SQL is executed when a table is created.
|
||||
"""
|
||||
return None
|
||||
|
||||
def bulk_batch_size(self, fields, objs):
|
||||
"""
|
||||
Return the maximum allowed batch size for the backend. The fields
|
||||
are the fields going to be inserted in the batch, the objs contains
|
||||
all the objects to be inserted.
|
||||
"""
|
||||
return len(objs)
|
||||
|
||||
def format_for_duration_arithmetic(self, sql):
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a "
|
||||
"format_for_duration_arithmetic() method."
|
||||
)
|
||||
|
||||
def cache_key_culling_sql(self):
|
||||
"""
|
||||
Return an SQL query that retrieves the first cache key greater than the
|
||||
n smallest.
|
||||
|
||||
This is used by the 'db' cache backend to determine where to start
|
||||
culling.
|
||||
"""
|
||||
cache_key = self.quote_name("cache_key")
|
||||
return f"SELECT {cache_key} FROM %s ORDER BY {cache_key} LIMIT 1 OFFSET %%s"
|
||||
|
||||
def unification_cast_sql(self, output_field):
|
||||
"""
|
||||
Given a field instance, return the SQL that casts the result of a union
|
||||
to that type. The resulting string should contain a '%s' placeholder
|
||||
for the expression being cast.
|
||||
"""
|
||||
return "%s"
|
||||
|
||||
def date_extract_sql(self, lookup_type, sql, params):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
|
||||
extracts a value from the given date field field_name.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a date_extract_sql() "
|
||||
"method"
|
||||
)
|
||||
|
||||
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
|
||||
truncates the given date or datetime field field_name to a date object
|
||||
with only the given specificity.
|
||||
|
||||
If `tzname` is provided, the given value is truncated in a specific
|
||||
timezone.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a date_trunc_sql() "
|
||||
"method."
|
||||
)
|
||||
|
||||
def datetime_cast_date_sql(self, sql, params, tzname):
|
||||
"""
|
||||
Return the SQL to cast a datetime value to date value.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a "
|
||||
"datetime_cast_date_sql() method."
|
||||
)
|
||||
|
||||
def datetime_cast_time_sql(self, sql, params, tzname):
|
||||
"""
|
||||
Return the SQL to cast a datetime value to time value.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a "
|
||||
"datetime_cast_time_sql() method"
|
||||
)
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
|
||||
'second', return the SQL that extracts a value from the given
|
||||
datetime field field_name.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a datetime_extract_sql() "
|
||||
"method"
|
||||
)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
|
||||
'second', return the SQL that truncates the given datetime field
|
||||
field_name to a datetime object with only the given specificity.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() "
|
||||
"method"
|
||||
)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
"""
|
||||
Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
|
||||
that truncates the given time or datetime field field_name to a time
|
||||
object with only the given specificity.
|
||||
|
||||
If `tzname` is provided, the given value is truncated in a specific
|
||||
timezone.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
|
||||
)
|
||||
|
||||
def time_extract_sql(self, lookup_type, sql, params):
|
||||
"""
|
||||
Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
|
||||
that extracts a value from the given time field field_name.
|
||||
"""
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def deferrable_sql(self):
|
||||
"""
|
||||
Return the SQL to make a constraint "initially deferred" during a
|
||||
CREATE TABLE statement.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def distinct_sql(self, fields, params):
|
||||
"""
|
||||
Return an SQL DISTINCT clause which removes duplicate rows from the
|
||||
result set. If any fields are given, only check the given fields for
|
||||
duplicates.
|
||||
"""
|
||||
if fields:
|
||||
raise NotSupportedError(
|
||||
"DISTINCT ON fields is not supported by this database backend"
|
||||
)
|
||||
else:
|
||||
return ["DISTINCT"], []
|
||||
|
||||
def fetch_returned_insert_columns(self, cursor, returning_params):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table, return the newly created data.
|
||||
"""
|
||||
return cursor.fetchone()
|
||||
|
||||
def field_cast_sql(self, db_type, internal_type):
|
||||
"""
|
||||
Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type
|
||||
(e.g. 'GenericIPAddressField'), return the SQL to cast it before using
|
||||
it in a WHERE statement. The resulting string should contain a '%s'
|
||||
placeholder for the column being searched against.
|
||||
"""
|
||||
return "%s"
|
||||
|
||||
def force_no_ordering(self):
|
||||
"""
|
||||
Return a list used in the "ORDER BY" clause to force no ordering at
|
||||
all. Return an empty list to include nothing in the ordering.
|
||||
"""
|
||||
return []
|
||||
|
||||
def for_update_sql(self, nowait=False, skip_locked=False, of=(), no_key=False):
|
||||
"""
|
||||
Return the FOR UPDATE SQL clause to lock rows for an update operation.
|
||||
"""
|
||||
return "FOR%s UPDATE%s%s%s" % (
|
||||
" NO KEY" if no_key else "",
|
||||
" OF %s" % ", ".join(of) if of else "",
|
||||
" NOWAIT" if nowait else "",
|
||||
" SKIP LOCKED" if skip_locked else "",
|
||||
)
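# For example, nowait=True with of=("self",) renders "FOR UPDATE OF self NOWAIT",
# while the default arguments render plain "FOR UPDATE".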
|
||||
|
||||
def _get_limit_offset_params(self, low_mark, high_mark):
|
||||
offset = low_mark or 0
|
||||
if high_mark is not None:
|
||||
return (high_mark - offset), offset
|
||||
elif offset:
|
||||
return self.connection.ops.no_limit_value(), offset
|
||||
return None, offset
|
||||
|
||||
def limit_offset_sql(self, low_mark, high_mark):
|
||||
"""Return LIMIT/OFFSET SQL clause."""
|
||||
limit, offset = self._get_limit_offset_params(low_mark, high_mark)
|
||||
return " ".join(
|
||||
sql
|
||||
for sql in (
|
||||
("LIMIT %d" % limit) if limit else None,
|
||||
("OFFSET %d" % offset) if offset else None,
|
||||
)
|
||||
if sql
|
||||
)
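# For example, low_mark=5 and high_mark=15 produce "LIMIT 10 OFFSET 5"; with no
# high_mark, the backend's no_limit_value() is used for the LIMIT (or it is
# omitted when that value is None).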
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
"""
|
||||
Return a string of the query last executed by the given cursor, with
|
||||
placeholders replaced with actual values.
|
||||
|
||||
`sql` is the raw query containing placeholders and `params` is the
|
||||
sequence of parameters. These are used by default, but this method
|
||||
exists for database backends to provide a better implementation
|
||||
according to their own quoting schemes.
|
||||
"""
|
||||
|
||||
# Convert params to contain string values.
|
||||
def to_string(s):
|
||||
return force_str(s, strings_only=True, errors="replace")
|
||||
|
||||
if isinstance(params, (list, tuple)):
|
||||
u_params = tuple(to_string(val) for val in params)
|
||||
elif params is None:
|
||||
u_params = ()
|
||||
else:
|
||||
u_params = {to_string(k): to_string(v) for k, v in params.items()}
|
||||
|
||||
return "QUERY = %r - PARAMS = %r" % (sql, u_params)
|
||||
|
||||
def last_insert_id(self, cursor, table_name, pk_name):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT statement into
|
||||
a table that has an auto-incrementing ID, return the newly created ID.
|
||||
|
||||
`pk_name` is the name of the primary-key column.
|
||||
"""
|
||||
return cursor.lastrowid
|
||||
|
||||
def lookup_cast(self, lookup_type, internal_type=None):
|
||||
"""
|
||||
Return the string to use in a query when performing lookups
|
||||
("contains", "like", etc.). It should contain a '%s' placeholder for
|
||||
the column being searched against.
|
||||
"""
|
||||
return "%s"
|
||||
|
||||
def max_in_list_size(self):
|
||||
"""
|
||||
Return the maximum number of items that can be passed in a single 'IN'
|
||||
list condition, or None if the backend does not impose a limit.
|
||||
"""
|
||||
return None
|
||||
|
||||
def max_name_length(self):
|
||||
"""
|
||||
Return the maximum length of table and column names, or None if there
|
||||
is no limit.
|
||||
"""
|
||||
return None
|
||||
|
||||
def no_limit_value(self):
|
||||
"""
|
||||
Return the value to use for the LIMIT when we want "LIMIT
|
||||
infinity". Return None if the limit clause can be omitted in this case.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a no_limit_value() method"
|
||||
)
|
||||
|
||||
def pk_default_value(self):
|
||||
"""
|
||||
Return the value to use during an INSERT statement to specify that
|
||||
the field should use its default value.
|
||||
"""
|
||||
return "DEFAULT"
|
||||
|
||||
def prepare_sql_script(self, sql):
|
||||
"""
|
||||
Take an SQL script that may contain multiple lines and return a list
|
||||
of statements to feed to successive cursor.execute() calls.
|
||||
|
||||
Since few databases are able to process raw SQL scripts in a single
|
||||
cursor.execute() call and PEP 249 doesn't talk about this use case,
|
||||
the default implementation is conservative.
|
||||
"""
|
||||
return [
|
||||
sqlparse.format(statement, strip_comments=True)
|
||||
for statement in sqlparse.split(sql)
|
||||
if statement
|
||||
]
|
||||
|
||||
def process_clob(self, value):
|
||||
"""
|
||||
Return the value of a CLOB column, for backends that return a locator
|
||||
object that requires additional processing.
|
||||
"""
|
||||
return value
|
||||
|
||||
def return_insert_columns(self, fields):
|
||||
"""
|
||||
For backends that support returning columns as part of an insert query,
|
||||
return the SQL and params to append to the INSERT query. The returned
|
||||
fragment should contain a format string to hold the appropriate column.
|
||||
"""
|
||||
pass
|
||||
|
||||
def compiler(self, compiler_name):
|
||||
"""
|
||||
Return the SQLCompiler class corresponding to the given name,
|
||||
in the namespace corresponding to the `compiler_module` attribute
|
||||
on this backend.
|
||||
"""
|
||||
if self._cache is None:
|
||||
self._cache = import_module(self.compiler_module)
|
||||
return getattr(self._cache, compiler_name)
|
||||
|
||||
def quote_name(self, name):
|
||||
"""
|
||||
Return a quoted version of the given table, index, or column name. Do
|
||||
not quote the given name if it's already been quoted.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a quote_name() method"
|
||||
)
|
||||
|
||||
def regex_lookup(self, lookup_type):
|
||||
"""
|
||||
Return the string to use in a query when performing regular expression
|
||||
lookups (using "regex" or "iregex"). It should contain a '%s'
|
||||
placeholder for the column being searched against.
|
||||
|
||||
If the feature is not supported (or part of it is not supported), raise
|
||||
NotImplementedError.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a regex_lookup() method"
|
||||
)
|
||||
|
||||
def savepoint_create_sql(self, sid):
|
||||
"""
|
||||
Return the SQL for starting a new savepoint. Only required if the
|
||||
"uses_savepoints" feature is True. The "sid" parameter is a string
|
||||
for the savepoint id.
|
||||
"""
|
||||
return "SAVEPOINT %s" % self.quote_name(sid)
|
||||
|
||||
def savepoint_commit_sql(self, sid):
|
||||
"""
|
||||
Return the SQL for committing the given savepoint.
|
||||
"""
|
||||
return "RELEASE SAVEPOINT %s" % self.quote_name(sid)
|
||||
|
||||
def savepoint_rollback_sql(self, sid):
|
||||
"""
|
||||
Return the SQL for rolling back the given savepoint.
|
||||
"""
|
||||
return "ROLLBACK TO SAVEPOINT %s" % self.quote_name(sid)
|
||||
|
||||
def set_time_zone_sql(self):
|
||||
"""
|
||||
Return the SQL that will set the connection's time zone.
|
||||
|
||||
Return '' if the backend doesn't support time zones.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
|
||||
"""
|
||||
Return a list of SQL statements required to remove all data from
|
||||
the given database tables (without actually removing the tables
|
||||
themselves).
|
||||
|
||||
The `style` argument is a Style object as returned by either
|
||||
color_style() or no_style() in django.core.management.color.
|
||||
|
||||
If `reset_sequences` is True, the list includes SQL statements required
|
||||
to reset the sequences.
|
||||
|
||||
The `allow_cascade` argument determines whether truncation may cascade
|
||||
to tables with foreign keys pointing to the tables being truncated.
|
||||
PostgreSQL requires a cascade even if these tables are empty.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations must provide an sql_flush() method"
|
||||
)
|
||||
|
||||
def execute_sql_flush(self, sql_list):
|
||||
"""Execute a list of SQL statements to flush the database."""
|
||||
with transaction.atomic(
|
||||
using=self.connection.alias,
|
||||
savepoint=self.connection.features.can_rollback_ddl,
|
||||
):
|
||||
with self.connection.cursor() as cursor:
|
||||
for sql in sql_list:
|
||||
cursor.execute(sql)
|
||||
|
||||
def sequence_reset_by_name_sql(self, style, sequences):
|
||||
"""
|
||||
Return a list of the SQL statements required to reset sequences
|
||||
passed in `sequences`.
|
||||
|
||||
The `style` argument is a Style object as returned by either
|
||||
color_style() or no_style() in django.core.management.color.
|
||||
"""
|
||||
return []
|
||||
|
||||
def sequence_reset_sql(self, style, model_list):
|
||||
"""
|
||||
Return a list of the SQL statements required to reset sequences for
|
||||
the given models.
|
||||
|
||||
The `style` argument is a Style object as returned by either
|
||||
color_style() or no_style() in django.core.management.color.
|
||||
"""
|
||||
return [] # No sequence reset required by default.
|
||||
|
||||
def start_transaction_sql(self):
|
||||
"""Return the SQL statement required to start a transaction."""
|
||||
return "BEGIN;"
|
||||
|
||||
def end_transaction_sql(self, success=True):
|
||||
"""Return the SQL statement required to end a transaction."""
|
||||
if not success:
|
||||
return "ROLLBACK;"
|
||||
return "COMMIT;"
|
||||
|
||||
def tablespace_sql(self, tablespace, inline=False):
|
||||
"""
|
||||
Return the SQL that will be used in a query to define the tablespace.
|
||||
|
||||
Return '' if the backend doesn't support tablespaces.
|
||||
|
||||
If `inline` is True, append the SQL to a row; otherwise append it to
|
||||
the entire CREATE TABLE or CREATE INDEX statement.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def prep_for_like_query(self, x):
|
||||
"""Prepare a value for use in a LIKE query."""
|
||||
return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")
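# For example, prep_for_like_query("50%_off") returns r"50\%\_off", so the LIKE
# wildcard characters are matched literally.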
|
||||
|
||||
# Same as prep_for_like_query(), but called for "iexact" matches, which
|
||||
# need not necessarily be implemented using "LIKE" in the backend.
|
||||
prep_for_iexact_query = prep_for_like_query
|
||||
|
||||
def validate_autopk_value(self, value):
|
||||
"""
|
||||
Certain backends do not accept some values for "serial" fields
|
||||
(for example zero in MySQL). Raise a ValueError if the value is
|
||||
invalid, otherwise return the validated value.
|
||||
"""
|
||||
return value
|
||||
|
||||
def adapt_unknown_value(self, value):
|
||||
"""
|
||||
Transform a value to something compatible with the backend driver.
|
||||
|
||||
This method only depends on the type of the value. It's designed for
|
||||
cases where the target type isn't known, such as .raw() SQL queries.
|
||||
As a consequence it may not work perfectly in all circumstances.
|
||||
"""
|
||||
if isinstance(value, datetime.datetime): # must be before date
|
||||
return self.adapt_datetimefield_value(value)
|
||||
elif isinstance(value, datetime.date):
|
||||
return self.adapt_datefield_value(value)
|
||||
elif isinstance(value, datetime.time):
|
||||
return self.adapt_timefield_value(value)
|
||||
elif isinstance(value, decimal.Decimal):
|
||||
return self.adapt_decimalfield_value(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
def adapt_integerfield_value(self, value, internal_type):
|
||||
return value
|
||||
|
||||
def adapt_datefield_value(self, value):
|
||||
"""
|
||||
Transform a date value to an object compatible with what is expected
|
||||
by the backend driver for date columns.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
def adapt_datetimefield_value(self, value):
|
||||
"""
|
||||
Transform a datetime value to an object compatible with what is expected
|
||||
by the backend driver for datetime columns.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
return str(value)
|
||||
|
||||
def adapt_timefield_value(self, value):
|
||||
"""
|
||||
Transform a time value to an object compatible with what is expected
|
||||
by the backend driver for time columns.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("Django does not support timezone-aware times.")
|
||||
return str(value)
|
||||
|
||||
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
|
||||
"""
|
||||
Transform a decimal.Decimal value to an object compatible with what is
|
||||
expected by the backend driver for decimal (numeric) columns.
|
||||
"""
|
||||
return utils.format_number(value, max_digits, decimal_places)
|
||||
|
||||
def adapt_ipaddressfield_value(self, value):
|
||||
"""
|
||||
Transform a string representation of an IP address into the expected
|
||||
type for the backend driver.
|
||||
"""
|
||||
return value or None
|
||||
|
||||
def adapt_json_value(self, value, encoder):
|
||||
return json.dumps(value, cls=encoder)
|
||||
|
||||
def year_lookup_bounds_for_date_field(self, value, iso_year=False):
|
||||
"""
|
||||
Return a two-element list with the lower and upper bound to be used
|
||||
with a BETWEEN operator to query a DateField value using a year
|
||||
lookup.
|
||||
|
||||
`value` is an int, containing the looked-up year.
|
||||
If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
|
||||
"""
|
||||
if iso_year:
|
||||
first = datetime.date.fromisocalendar(value, 1, 1)
|
||||
second = datetime.date.fromisocalendar(
|
||||
value + 1, 1, 1
|
||||
) - datetime.timedelta(days=1)
|
||||
else:
|
||||
first = datetime.date(value, 1, 1)
|
||||
second = datetime.date(value, 12, 31)
|
||||
first = self.adapt_datefield_value(first)
|
||||
second = self.adapt_datefield_value(second)
|
||||
return [first, second]
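# For example, value=2021 yields the adapted equivalents of
# [date(2021, 1, 1), date(2021, 12, 31)].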
|
||||
|
||||
def year_lookup_bounds_for_datetime_field(self, value, iso_year=False):
|
||||
"""
|
||||
Return a two-element list with the lower and upper bound to be used
|
||||
with a BETWEEN operator to query a DateTimeField value using a year
|
||||
lookup.
|
||||
|
||||
`value` is an int, containing the looked-up year.
|
||||
If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
|
||||
"""
|
||||
if iso_year:
|
||||
first = datetime.datetime.fromisocalendar(value, 1, 1)
|
||||
second = datetime.datetime.fromisocalendar(
|
||||
value + 1, 1, 1
|
||||
) - datetime.timedelta(microseconds=1)
|
||||
else:
|
||||
first = datetime.datetime(value, 1, 1)
|
||||
second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
|
||||
if settings.USE_TZ:
|
||||
tz = timezone.get_current_timezone()
|
||||
first = timezone.make_aware(first, tz)
|
||||
second = timezone.make_aware(second, tz)
|
||||
first = self.adapt_datetimefield_value(first)
|
||||
second = self.adapt_datetimefield_value(second)
|
||||
return [first, second]
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
"""
|
||||
Return a list of functions needed to convert field data.
|
||||
|
||||
Some field types on some backends do not provide data in the correct
|
||||
format; this is the hook for converter functions.
|
||||
"""
|
||||
return []
|
||||
|
||||
def convert_durationfield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
return datetime.timedelta(0, 0, value)
|
||||
|
||||
def check_expression_support(self, expression):
|
||||
"""
|
||||
Check that the backend supports the provided expression.
|
||||
|
||||
This is used on specific backends to rule out known expressions
|
||||
that have problematic or nonexistent implementations. If the
|
||||
expression has a known problem, the backend should raise
|
||||
NotSupportedError.
|
||||
"""
|
||||
pass
|
||||
|
||||
def conditional_expression_supported_in_where_clause(self, expression):
|
||||
"""
|
||||
Return True, if the conditional expression is supported in the WHERE
|
||||
clause.
|
||||
"""
|
||||
return True
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
"""
|
||||
Combine a list of subexpressions into a single expression, using
|
||||
the provided connecting operator. This is required because operators
|
||||
can vary between backends (e.g., Oracle with %% and &) and between
|
||||
subexpression types (e.g., date expressions).
|
||||
"""
|
||||
conn = " %s " % connector
|
||||
return conn.join(sub_expressions)
|
||||
|
||||
def combine_duration_expression(self, connector, sub_expressions):
|
||||
return self.combine_expression(connector, sub_expressions)
|
||||
|
||||
def binary_placeholder_sql(self, value):
|
||||
"""
|
||||
Some backends require special syntax to insert binary content (MySQL
|
||||
for example uses '_binary %s').
|
||||
"""
|
||||
return "%s"
|
||||
|
||||
def modify_insert_params(self, placeholder, params):
|
||||
"""
|
||||
Allow modification of insert parameters. Needed for Oracle Spatial
|
||||
backend due to #10888.
|
||||
"""
|
||||
return params
|
||||
|
||||
def integer_field_range(self, internal_type):
|
||||
"""
|
||||
Given an integer field internal type (e.g. 'PositiveIntegerField'),
|
||||
return a tuple of the (min_value, max_value) form representing the
|
||||
range of the column type bound to the field.
|
||||
"""
|
||||
return self.integer_field_ranges[internal_type]
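# For example, integer_field_range("PositiveSmallIntegerField") returns (0, 32767).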
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
if self.connection.features.supports_temporal_subtraction:
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
return "(%s - %s)" % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params)
|
||||
raise NotSupportedError(
|
||||
"This backend does not support %s subtraction." % internal_type
|
||||
)
|
||||
|
||||
def window_frame_start(self, start):
|
||||
if isinstance(start, int):
|
||||
if start < 0:
|
||||
return "%d %s" % (abs(start), self.PRECEDING)
|
||||
elif start == 0:
|
||||
return self.CURRENT_ROW
|
||||
elif start is None:
|
||||
return self.UNBOUNDED_PRECEDING
|
||||
raise ValueError(
|
||||
"start argument must be a negative integer, zero, or None, but got '%s'."
|
||||
% start
|
||||
)
|
||||
|
||||
def window_frame_end(self, end):
|
||||
if isinstance(end, int):
|
||||
if end == 0:
|
||||
return self.CURRENT_ROW
|
||||
elif end > 0:
|
||||
return "%d %s" % (end, self.FOLLOWING)
|
||||
elif end is None:
|
||||
return self.UNBOUNDED_FOLLOWING
|
||||
raise ValueError(
|
||||
"end argument must be a positive integer, zero, or None, but got '%s'."
|
||||
% end
|
||||
)
|
||||
|
||||
def window_frame_rows_start_end(self, start=None, end=None):
|
||||
"""
|
||||
Return SQL for start and end points in an OVER clause window frame.
|
||||
"""
|
||||
if not self.connection.features.supports_over_clause:
|
||||
raise NotSupportedError("This backend does not support window expressions.")
|
||||
return self.window_frame_start(start), self.window_frame_end(end)
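# For example, start=-3 and end=None map to ("3 PRECEDING", "UNBOUNDED FOLLOWING").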
|
||||
|
||||
def window_frame_range_start_end(self, start=None, end=None):
|
||||
start_, end_ = self.window_frame_rows_start_end(start, end)
|
||||
features = self.connection.features
|
||||
if features.only_supports_unbounded_with_preceding_and_following and (
|
||||
(start and start < 0) or (end and end > 0)
|
||||
):
|
||||
raise NotSupportedError(
|
||||
"%s only supports UNBOUNDED together with PRECEDING and "
|
||||
"FOLLOWING." % self.connection.display_name
|
||||
)
|
||||
return start_, end_
|
||||
|
||||
def explain_query_prefix(self, format=None, **options):
|
||||
if not self.connection.features.supports_explaining_query_execution:
|
||||
raise NotSupportedError(
|
||||
"This backend does not support explaining query execution."
|
||||
)
|
||||
if format:
|
||||
supported_formats = self.connection.features.supported_explain_formats
|
||||
normalized_format = format.upper()
|
||||
if normalized_format not in supported_formats:
|
||||
msg = "%s is not a recognized format." % normalized_format
|
||||
if supported_formats:
|
||||
msg += " Allowed formats: %s" % ", ".join(sorted(supported_formats))
|
||||
else:
|
||||
msg += (
|
||||
f" {self.connection.display_name} does not support any formats."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if options:
|
||||
raise ValueError("Unknown options: %s" % ", ".join(sorted(options.keys())))
|
||||
return self.explain_prefix
|
||||
|
||||
def insert_statement(self, on_conflict=None):
|
||||
return "INSERT INTO"
|
||||
|
||||
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
|
||||
return ""
|
||||
@@ -0,0 +1,29 @@
class BaseDatabaseValidation:
    """Encapsulate backend-specific validation."""

    def __init__(self, connection):
        self.connection = connection

    def check(self, **kwargs):
        return []

    def check_field(self, field, **kwargs):
        errors = []
        # Backends may implement a check_field_type() method.
        if (
            hasattr(self, "check_field_type")
            and
            # Ignore any related fields.
            not getattr(field, "remote_field", None)
        ):
            # Ignore fields with unsupported features.
            db_supports_all_required_features = all(
                getattr(self.connection.features, feature, False)
                for feature in field.model._meta.required_db_features
            )
            if db_supports_all_required_features:
                field_type = field.db_type(self.connection)
                # Ignore non-concrete fields.
                if field_type is not None:
                    errors.extend(self.check_field_type(field, field_type))
        return errors
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Helpers to manipulate deferred DDL statements that might need to be adjusted or
|
||||
discarded when executing a migration.
|
||||
"""
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class Reference:
|
||||
"""Base class that defines the reference interface."""
|
||||
|
||||
def references_table(self, table):
|
||||
"""
|
||||
Return whether or not this instance references the specified table.
|
||||
"""
|
||||
return False
|
||||
|
||||
def references_column(self, table, column):
|
||||
"""
|
||||
Return whether or not this instance references the specified column.
|
||||
"""
|
||||
return False
|
||||
|
||||
def rename_table_references(self, old_table, new_table):
|
||||
"""
|
||||
Rename all references to old_table to new_table.
|
||||
"""
|
||||
pass
|
||||
|
||||
def rename_column_references(self, table, old_column, new_column):
|
||||
"""
|
||||
Rename all references to old_column to new_column.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s %r>" % (self.__class__.__name__, str(self))
|
||||
|
||||
def __str__(self):
|
||||
raise NotImplementedError(
|
||||
"Subclasses must define how they should be converted to string."
|
||||
)
|
||||
|
||||
|
||||
class Table(Reference):
|
||||
"""Hold a reference to a table."""
|
||||
|
||||
def __init__(self, table, quote_name):
|
||||
self.table = table
|
||||
self.quote_name = quote_name
|
||||
|
||||
def references_table(self, table):
|
||||
return self.table == table
|
||||
|
||||
def rename_table_references(self, old_table, new_table):
|
||||
if self.table == old_table:
|
||||
self.table = new_table
|
||||
|
||||
def __str__(self):
|
||||
return self.quote_name(self.table)
|
||||
|
||||
|
||||
class TableColumns(Table):
|
||||
"""Base class for references to multiple columns of a table."""
|
||||
|
||||
def __init__(self, table, columns):
|
||||
self.table = table
|
||||
self.columns = columns
|
||||
|
||||
def references_column(self, table, column):
|
||||
return self.table == table and column in self.columns
|
||||
|
||||
def rename_column_references(self, table, old_column, new_column):
|
||||
if self.table == table:
|
||||
for index, column in enumerate(self.columns):
|
||||
if column == old_column:
|
||||
self.columns[index] = new_column
|
||||
|
||||
|
||||
class Columns(TableColumns):
|
||||
"""Hold a reference to one or many columns."""
|
||||
|
||||
def __init__(self, table, columns, quote_name, col_suffixes=()):
|
||||
self.quote_name = quote_name
|
||||
self.col_suffixes = col_suffixes
|
||||
super().__init__(table, columns)
|
||||
|
||||
def __str__(self):
|
||||
def col_str(column, idx):
|
||||
col = self.quote_name(column)
|
||||
try:
|
||||
suffix = self.col_suffixes[idx]
|
||||
if suffix:
|
||||
col = "{} {}".format(col, suffix)
|
||||
except IndexError:
|
||||
pass
|
||||
return col
|
||||
|
||||
return ", ".join(
|
||||
col_str(column, idx) for idx, column in enumerate(self.columns)
|
||||
)
|
||||
|
||||
|
||||
class IndexName(TableColumns):
|
||||
"""Hold a reference to an index name."""
|
||||
|
||||
def __init__(self, table, columns, suffix, create_index_name):
|
||||
self.suffix = suffix
|
||||
self.create_index_name = create_index_name
|
||||
super().__init__(table, columns)
|
||||
|
||||
def __str__(self):
|
||||
return self.create_index_name(self.table, self.columns, self.suffix)
|
||||
|
||||
|
||||
class IndexColumns(Columns):
|
||||
def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()):
|
||||
self.opclasses = opclasses
|
||||
super().__init__(table, columns, quote_name, col_suffixes)
|
||||
|
||||
def __str__(self):
|
||||
def col_str(column, idx):
|
||||
# Index.__init__() guarantees that self.opclasses is the same
|
||||
# length as self.columns.
|
||||
col = "{} {}".format(self.quote_name(column), self.opclasses[idx])
|
||||
try:
|
||||
suffix = self.col_suffixes[idx]
|
||||
if suffix:
|
||||
col = "{} {}".format(col, suffix)
|
||||
except IndexError:
|
||||
pass
|
||||
return col
|
||||
|
||||
return ", ".join(
|
||||
col_str(column, idx) for idx, column in enumerate(self.columns)
|
||||
)
|
||||
|
||||
|
||||
class ForeignKeyName(TableColumns):
|
||||
"""Hold a reference to a foreign key name."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
from_table,
|
||||
from_columns,
|
||||
to_table,
|
||||
to_columns,
|
||||
suffix_template,
|
||||
create_fk_name,
|
||||
):
|
||||
self.to_reference = TableColumns(to_table, to_columns)
|
||||
self.suffix_template = suffix_template
|
||||
self.create_fk_name = create_fk_name
|
||||
super().__init__(
|
||||
from_table,
|
||||
from_columns,
|
||||
)
|
||||
|
||||
def references_table(self, table):
|
||||
return super().references_table(table) or self.to_reference.references_table(
|
||||
table
|
||||
)
|
||||
|
||||
def references_column(self, table, column):
|
||||
return super().references_column(
|
||||
table, column
|
||||
) or self.to_reference.references_column(table, column)
|
||||
|
||||
def rename_table_references(self, old_table, new_table):
|
||||
super().rename_table_references(old_table, new_table)
|
||||
self.to_reference.rename_table_references(old_table, new_table)
|
||||
|
||||
def rename_column_references(self, table, old_column, new_column):
|
||||
super().rename_column_references(table, old_column, new_column)
|
||||
self.to_reference.rename_column_references(table, old_column, new_column)
|
||||
|
||||
def __str__(self):
|
||||
suffix = self.suffix_template % {
|
||||
"to_table": self.to_reference.table,
|
||||
"to_column": self.to_reference.columns[0],
|
||||
}
|
||||
return self.create_fk_name(self.table, self.columns, suffix)
|
||||
|
||||
|
||||
class Statement(Reference):
|
||||
"""
|
||||
Statement template and formatting parameters container.
|
||||
|
||||
Allows keeping a reference to a statement without interpolating identifiers
|
||||
that might have to be adjusted if they're referencing a table or column
|
||||
that is removed.
|
||||
"""
|
||||
|
||||
def __init__(self, template, **parts):
|
||||
self.template = template
|
||||
self.parts = parts
|
||||
|
||||
def references_table(self, table):
|
||||
return any(
|
||||
hasattr(part, "references_table") and part.references_table(table)
|
||||
for part in self.parts.values()
|
||||
)
|
||||
|
||||
def references_column(self, table, column):
|
||||
return any(
|
||||
hasattr(part, "references_column") and part.references_column(table, column)
|
||||
for part in self.parts.values()
|
||||
)
|
||||
|
||||
def rename_table_references(self, old_table, new_table):
|
||||
for part in self.parts.values():
|
||||
if hasattr(part, "rename_table_references"):
|
||||
part.rename_table_references(old_table, new_table)
|
||||
|
||||
def rename_column_references(self, table, old_column, new_column):
|
||||
for part in self.parts.values():
|
||||
if hasattr(part, "rename_column_references"):
|
||||
part.rename_column_references(table, old_column, new_column)
|
||||
|
||||
def __str__(self):
|
||||
return self.template % self.parts
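# Interpolation happens lazily, each time the statement is rendered, so renames
# recorded through the rename_*_references() hooks above are reflected in the
# final SQL. Illustrative sketch with hypothetical identifiers:
# str(Statement("DROP INDEX %(name)s", name=IndexName("app_model", ["col"], "_idx", create_index_name))).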
|
||||
|
||||
|
||||
class Expressions(TableColumns):
|
||||
def __init__(self, table, expressions, compiler, quote_value):
|
||||
self.compiler = compiler
|
||||
self.expressions = expressions
|
||||
self.quote_value = quote_value
|
||||
columns = [
|
||||
col.target.column
|
||||
for col in self.compiler.query._gen_cols([self.expressions])
|
||||
]
|
||||
super().__init__(table, columns)
|
||||
|
||||
def rename_table_references(self, old_table, new_table):
|
||||
if self.table != old_table:
|
||||
return
|
||||
self.expressions = self.expressions.relabeled_clone({old_table: new_table})
|
||||
super().rename_table_references(old_table, new_table)
|
||||
|
||||
def rename_column_references(self, table, old_column, new_column):
|
||||
if self.table != table:
|
||||
return
|
||||
expressions = deepcopy(self.expressions)
|
||||
self.columns = []
|
||||
for col in self.compiler.query._gen_cols([expressions]):
|
||||
if col.target.column == old_column:
|
||||
col.target.column = new_column
|
||||
self.columns.append(col.target.column)
|
||||
self.expressions = expressions
|
||||
|
||||
def __str__(self):
|
||||
sql, params = self.compiler.compile(self.expressions)
|
||||
params = map(self.quote_value, params)
|
||||
return sql % tuple(params)
|
||||
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Dummy database backend for Django.
|
||||
|
||||
Django uses this if the database ENGINE setting is empty (None or empty string).
|
||||
|
||||
Each of these API functions, except connection.close(), raises
|
||||
ImproperlyConfigured.
|
||||
"""
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.backends.base.client import BaseDatabaseClient
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
from django.db.backends.base.introspection import BaseDatabaseIntrospection
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.backends.dummy.features import DummyDatabaseFeatures
|
||||
|
||||
|
||||
def complain(*args, **kwargs):
|
||||
raise ImproperlyConfigured(
|
||||
"settings.DATABASES is improperly configured. "
|
||||
"Please supply the ENGINE value. Check "
|
||||
"settings documentation for more details."
|
||||
)
|
||||
|
||||
|
||||
def ignore(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
quote_name = complain
|
||||
|
||||
|
||||
class DatabaseClient(BaseDatabaseClient):
|
||||
runshell = complain
|
||||
|
||||
|
||||
class DatabaseCreation(BaseDatabaseCreation):
|
||||
create_test_db = ignore
|
||||
destroy_test_db = ignore
|
||||
|
||||
|
||||
class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||
get_table_list = complain
|
||||
get_table_description = complain
|
||||
get_relations = complain
|
||||
get_indexes = complain
|
||||
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
operators = {}
|
||||
# Override the base class implementations with null
|
||||
# implementations. Anything that tries to actually
|
||||
# do something raises complain; anything that tries
|
||||
# to rollback or undo something raises ignore.
|
||||
_cursor = complain
|
||||
ensure_connection = complain
|
||||
_commit = complain
|
||||
_rollback = ignore
|
||||
_close = ignore
|
||||
_savepoint = ignore
|
||||
_savepoint_commit = complain
|
||||
_savepoint_rollback = ignore
|
||||
_set_autocommit = complain
|
||||
# Classes instantiated in __init__().
|
||||
client_class = DatabaseClient
|
||||
creation_class = DatabaseCreation
|
||||
features_class = DummyDatabaseFeatures
|
||||
introspection_class = DatabaseIntrospection
|
||||
ops_class = DatabaseOperations
|
||||
|
||||
def is_usable(self):
|
||||
return True
|
||||
@@ -0,0 +1,6 @@
from django.db.backends.base.features import BaseDatabaseFeatures


class DummyDatabaseFeatures(BaseDatabaseFeatures):
    supports_transactions = False
    uses_savepoints = False
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
MySQL database backend for Django.
|
||||
|
||||
Requires mysqlclient: https://pypi.org/project/mysqlclient/
|
||||
"""
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import IntegrityError
|
||||
from django.db.backends import utils as backend_utils
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.utils.asyncio import async_unsafe
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
try:
|
||||
import MySQLdb as Database
|
||||
except ImportError as err:
|
||||
raise ImproperlyConfigured(
|
||||
"Error loading MySQLdb module.\nDid you install mysqlclient?"
|
||||
) from err
|
||||
|
||||
from MySQLdb.constants import CLIENT, FIELD_TYPE
|
||||
from MySQLdb.converters import conversions
|
||||
|
||||
# Some of these import MySQLdb, so import them after checking if it's installed.
|
||||
from .client import DatabaseClient
|
||||
from .creation import DatabaseCreation
|
||||
from .features import DatabaseFeatures
|
||||
from .introspection import DatabaseIntrospection
|
||||
from .operations import DatabaseOperations
|
||||
from .schema import DatabaseSchemaEditor
|
||||
from .validation import DatabaseValidation
|
||||
|
||||
version = Database.version_info
|
||||
if version < (1, 4, 3):
|
||||
raise ImproperlyConfigured(
|
||||
"mysqlclient 1.4.3 or newer is required; you have %s." % Database.__version__
|
||||
)
|
||||
|
||||
|
||||
# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
|
||||
# terms of actual behavior as they are signed and include days -- and Django
|
||||
# expects time.
|
||||
django_conversions = {
|
||||
**conversions,
|
||||
**{FIELD_TYPE.TIME: backend_utils.typecast_time},
|
||||
}
|
||||
|
||||
# This should match the numerical portion of the version numbers (we can treat
|
||||
# versions like 5.0.24 and 5.0.24a as the same).
|
||||
server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
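# For example, an (assumed) server version string such as "8.0.32-log" yields
# the groups ("8", "0", "32") when searched with this pattern.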
|
||||
|
||||
|
||||
class CursorWrapper:
|
||||
"""
|
||||
A thin wrapper around MySQLdb's normal cursor class that catches particular
|
||||
exception instances and reraises them with the correct types.
|
||||
|
||||
Implemented as a wrapper, rather than a subclass, so that it isn't stuck
|
||||
to the particular underlying representation returned by Connection.cursor().
|
||||
"""
|
||||
|
||||
codes_for_integrityerror = (
|
||||
1048, # Column cannot be null
|
||||
1690, # BIGINT UNSIGNED value is out of range
|
||||
3819, # CHECK constraint is violated
|
||||
4025, # CHECK constraint failed
|
||||
)
|
||||
|
||||
def __init__(self, cursor):
|
||||
self.cursor = cursor
|
||||
|
||||
def execute(self, query, args=None):
|
||||
try:
|
||||
# If args is None, no string interpolation is performed.
|
||||
return self.cursor.execute(query, args)
|
||||
except Database.OperationalError as e:
|
||||
# Map some error codes to IntegrityError, since they seem to be
|
||||
# misclassified and Django would prefer the more logical place.
|
||||
if e.args[0] in self.codes_for_integrityerror:
|
||||
raise IntegrityError(*tuple(e.args))
|
||||
raise
|
||||
|
||||
def executemany(self, query, args):
|
||||
try:
|
||||
return self.cursor.executemany(query, args)
|
||||
except Database.OperationalError as e:
|
||||
# Map some error codes to IntegrityError, since they seem to be
|
||||
# misclassified and Django would prefer the more logical place.
|
||||
if e.args[0] in self.codes_for_integrityerror:
|
||||
raise IntegrityError(*tuple(e.args))
|
||||
raise
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.cursor, attr)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.cursor)
|
||||
|
||||
|
||||
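
# Illustrative sketch, not part of the original module: with a CursorWrapper
# around a live cursor, an error code from codes_for_integrityerror surfaces
# as Django's IntegrityError instead of MySQLdb's OperationalError, e.g.
# (assuming a hypothetical table with a CHECK constraint on qty):
#
#     try:
#         cursor.execute("INSERT INTO app_item (qty) VALUES (-1)")
#     except IntegrityError:
#         ...  # handled like any other integrity violation
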
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
vendor = "mysql"
|
||||
# This dictionary maps Field objects to their associated MySQL column
|
||||
# types, as strings. Column-type strings can contain format strings; they'll
|
||||
# be interpolated against the values of Field.__dict__ before being output.
|
||||
# If a column type is set to None, it won't be included in the output.
|
||||
data_types = {
|
||||
"AutoField": "integer AUTO_INCREMENT",
|
||||
"BigAutoField": "bigint AUTO_INCREMENT",
|
||||
"BinaryField": "longblob",
|
||||
"BooleanField": "bool",
|
||||
"CharField": "varchar(%(max_length)s)",
|
||||
"DateField": "date",
|
||||
"DateTimeField": "datetime(6)",
|
||||
"DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
|
||||
"DurationField": "bigint",
|
||||
"FileField": "varchar(%(max_length)s)",
|
||||
"FilePathField": "varchar(%(max_length)s)",
|
||||
"FloatField": "double precision",
|
||||
"IntegerField": "integer",
|
||||
"BigIntegerField": "bigint",
|
||||
"IPAddressField": "char(15)",
|
||||
"GenericIPAddressField": "char(39)",
|
||||
"JSONField": "json",
|
||||
"OneToOneField": "integer",
|
||||
"PositiveBigIntegerField": "bigint UNSIGNED",
|
||||
"PositiveIntegerField": "integer UNSIGNED",
|
||||
"PositiveSmallIntegerField": "smallint UNSIGNED",
|
||||
"SlugField": "varchar(%(max_length)s)",
|
||||
"SmallAutoField": "smallint AUTO_INCREMENT",
|
||||
"SmallIntegerField": "smallint",
|
||||
"TextField": "longtext",
|
||||
"TimeField": "time(6)",
|
||||
"UUIDField": "char(32)",
|
||||
}
|
||||
|
||||
# For these data types:
|
||||
# - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
|
||||
# as nullable
|
||||
# - all versions of MySQL and MariaDB don't support full width database
|
||||
# indexes
|
||||
_limited_data_types = (
|
||||
"tinyblob",
|
||||
"blob",
|
||||
"mediumblob",
|
||||
"longblob",
|
||||
"tinytext",
|
||||
"text",
|
||||
"mediumtext",
|
||||
"longtext",
|
||||
"json",
|
||||
)
|
||||
|
||||
operators = {
|
||||
"exact": "= %s",
|
||||
"iexact": "LIKE %s",
|
||||
"contains": "LIKE BINARY %s",
|
||||
"icontains": "LIKE %s",
|
||||
"gt": "> %s",
|
||||
"gte": ">= %s",
|
||||
"lt": "< %s",
|
||||
"lte": "<= %s",
|
||||
"startswith": "LIKE BINARY %s",
|
||||
"endswith": "LIKE BINARY %s",
|
||||
"istartswith": "LIKE %s",
|
||||
"iendswith": "LIKE %s",
|
||||
}
|
||||
|
||||
    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
    # escaped on database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
        "icontains": "LIKE CONCAT('%%', {}, '%%')",
        "startswith": "LIKE BINARY CONCAT({}, '%%')",
        "istartswith": "LIKE CONCAT({}, '%%')",
        "endswith": "LIKE BINARY CONCAT('%%', {})",
        "iendswith": "LIKE CONCAT('%%', {})",
    }

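    # Illustrative sketch, not part of the original module: for a lookup whose
    # right-hand side is an expression, e.g. a hypothetical
    # Author.objects.filter(name__startswith=F("initials")), the clause is
    # built from pattern_ops["startswith"] with the right-hand side wrapped in
    # pattern_esc, giving SQL roughly like:
    #   `name` LIKE BINARY CONCAT(REPLACE(REPLACE(REPLACE(`initials`,
    #       '\\', '\\\\'), '%', '\%'), '_', '\_'), '%')
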
    isolation_levels = {
        "read uncommitted",
        "read committed",
        "repeatable read",
        "serializable",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    validation_class = DatabaseValidation

def get_database_version(self):
|
||||
return self.mysql_version
|
||||
|
||||
def get_connection_params(self):
|
||||
kwargs = {
|
||||
"conv": django_conversions,
|
||||
"charset": "utf8",
|
||||
}
|
||||
settings_dict = self.settings_dict
|
||||
if settings_dict["USER"]:
|
||||
kwargs["user"] = settings_dict["USER"]
|
||||
if settings_dict["NAME"]:
|
||||
kwargs["database"] = settings_dict["NAME"]
|
||||
if settings_dict["PASSWORD"]:
|
||||
kwargs["password"] = settings_dict["PASSWORD"]
|
||||
if settings_dict["HOST"].startswith("/"):
|
||||
kwargs["unix_socket"] = settings_dict["HOST"]
|
||||
elif settings_dict["HOST"]:
|
||||
kwargs["host"] = settings_dict["HOST"]
|
||||
if settings_dict["PORT"]:
|
||||
kwargs["port"] = int(settings_dict["PORT"])
|
||||
# We need the number of potentially affected rows after an
|
||||
# "UPDATE", not the number of changed rows.
|
||||
kwargs["client_flag"] = CLIENT.FOUND_ROWS
|
||||
# Validate the transaction isolation level, if specified.
|
||||
options = settings_dict["OPTIONS"].copy()
|
||||
isolation_level = options.pop("isolation_level", "read committed")
|
||||
if isolation_level:
|
||||
isolation_level = isolation_level.lower()
|
||||
if isolation_level not in self.isolation_levels:
|
||||
raise ImproperlyConfigured(
|
||||
"Invalid transaction isolation level '%s' specified.\n"
|
||||
"Use one of %s, or None."
|
||||
% (
|
||||
isolation_level,
|
||||
", ".join("'%s'" % s for s in sorted(self.isolation_levels)),
|
||||
)
|
||||
)
|
||||
self.isolation_level = isolation_level
|
||||
kwargs.update(options)
|
||||
return kwargs
|
||||
|
||||
@async_unsafe
|
||||
def get_new_connection(self, conn_params):
|
||||
connection = Database.connect(**conn_params)
|
||||
# bytes encoder in mysqlclient doesn't work and was added only to
|
||||
# prevent KeyErrors in Django < 2.0. We can remove this workaround when
|
||||
# mysqlclient 2.1 becomes the minimal mysqlclient supported by Django.
|
||||
# See https://github.com/PyMySQL/mysqlclient/issues/489
|
||||
if connection.encoders.get(bytes) is bytes:
|
||||
connection.encoders.pop(bytes)
|
||||
return connection
|
||||
|
||||
def init_connection_state(self):
|
||||
super().init_connection_state()
|
||||
assignments = []
|
||||
if self.features.is_sql_auto_is_null_enabled:
|
||||
# SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
|
||||
# a recently inserted row will return when the field is tested
|
||||
# for NULL. Disabling this brings this aspect of MySQL in line
|
||||
# with SQL standards.
|
||||
assignments.append("SET SQL_AUTO_IS_NULL = 0")
|
||||
|
||||
if self.isolation_level:
|
||||
assignments.append(
|
||||
"SET SESSION TRANSACTION ISOLATION LEVEL %s"
|
||||
% self.isolation_level.upper()
|
||||
)
|
||||
|
||||
if assignments:
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute("; ".join(assignments))
|
||||
|
||||
@async_unsafe
|
||||
def create_cursor(self, name=None):
|
||||
cursor = self.connection.cursor()
|
||||
return CursorWrapper(cursor)
|
||||
|
||||
def _rollback(self):
|
||||
try:
|
||||
BaseDatabaseWrapper._rollback(self)
|
||||
except Database.NotSupportedError:
|
||||
pass
|
||||
|
||||
def _set_autocommit(self, autocommit):
|
||||
with self.wrap_database_errors:
|
||||
self.connection.autocommit(autocommit)
|
||||
|
||||
def disable_constraint_checking(self):
|
||||
"""
|
||||
Disable foreign key checks, primarily for use in adding rows with
|
||||
forward references. Always return True to indicate constraint checks
|
||||
need to be re-enabled.
|
||||
"""
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute("SET foreign_key_checks=0")
|
||||
return True
|
||||
|
||||
def enable_constraint_checking(self):
|
||||
"""
|
||||
Re-enable foreign key checks after they have been disabled.
|
||||
"""
|
||||
# Override needs_rollback in case constraint_checks_disabled is
|
||||
# nested inside transaction.atomic.
|
||||
self.needs_rollback, needs_rollback = False, self.needs_rollback
|
||||
try:
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute("SET foreign_key_checks=1")
|
||||
finally:
|
||||
self.needs_rollback = needs_rollback
|
||||
|
||||
def check_constraints(self, table_names=None):
|
||||
"""
|
||||
Check each table name in `table_names` for rows with invalid foreign
|
||||
key references. This method is intended to be used in conjunction with
|
||||
`disable_constraint_checking()` and `enable_constraint_checking()`, to
|
||||
determine if rows with invalid references were entered while constraint
|
||||
checks were off.
|
||||
"""
|
||||
with self.cursor() as cursor:
|
||||
if table_names is None:
|
||||
table_names = self.introspection.table_names(cursor)
|
||||
for table_name in table_names:
|
||||
primary_key_column_name = self.introspection.get_primary_key_column(
|
||||
cursor, table_name
|
||||
)
|
||||
if not primary_key_column_name:
|
||||
continue
|
||||
relations = self.introspection.get_relations(cursor, table_name)
|
||||
for column_name, (
|
||||
referenced_column_name,
|
||||
referenced_table_name,
|
||||
) in relations.items():
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
|
||||
LEFT JOIN `%s` as REFERRED
|
||||
ON (REFERRING.`%s` = REFERRED.`%s`)
|
||||
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
|
||||
"""
|
||||
% (
|
||||
primary_key_column_name,
|
||||
column_name,
|
||||
table_name,
|
||||
referenced_table_name,
|
||||
column_name,
|
||||
referenced_column_name,
|
||||
column_name,
|
||||
referenced_column_name,
|
||||
)
|
||||
)
|
||||
for bad_row in cursor.fetchall():
|
||||
raise IntegrityError(
|
||||
"The row in table '%s' with primary key '%s' has an "
|
||||
"invalid foreign key: %s.%s contains a value '%s' that "
|
||||
"does not have a corresponding value in %s.%s."
|
||||
% (
|
||||
table_name,
|
||||
bad_row[0],
|
||||
table_name,
|
||||
column_name,
|
||||
bad_row[1],
|
||||
referenced_table_name,
|
||||
referenced_column_name,
|
||||
)
|
||||
)
|
||||
|
||||
def is_usable(self):
|
||||
try:
|
||||
self.connection.ping()
|
||||
except Database.Error:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@cached_property
|
||||
def display_name(self):
|
||||
return "MariaDB" if self.mysql_is_mariadb else "MySQL"
|
||||
|
||||
@cached_property
|
||||
def data_type_check_constraints(self):
|
||||
if self.features.supports_column_check_constraints:
|
||||
check_constraints = {
|
||||
"PositiveBigIntegerField": "`%(column)s` >= 0",
|
||||
"PositiveIntegerField": "`%(column)s` >= 0",
|
||||
"PositiveSmallIntegerField": "`%(column)s` >= 0",
|
||||
}
|
||||
if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
|
||||
# MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
|
||||
# a check constraint.
|
||||
check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
|
||||
return check_constraints
|
||||
return {}
|
||||
|
||||
@cached_property
|
||||
def mysql_server_data(self):
|
||||
with self.temporary_connection() as cursor:
|
||||
# Select some server variables and test if the time zone
|
||||
# definitions are installed. CONVERT_TZ returns NULL if 'UTC'
|
||||
# timezone isn't loaded into the mysql.time_zone table.
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT VERSION(),
|
||||
@@sql_mode,
|
||||
@@default_storage_engine,
|
||||
@@sql_auto_is_null,
|
||||
@@lower_case_table_names,
|
||||
CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
|
||||
"""
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return {
|
||||
"version": row[0],
|
||||
"sql_mode": row[1],
|
||||
"default_storage_engine": row[2],
|
||||
"sql_auto_is_null": bool(row[3]),
|
||||
"lower_case_table_names": bool(row[4]),
|
||||
"has_zoneinfo_database": bool(row[5]),
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def mysql_server_info(self):
|
||||
return self.mysql_server_data["version"]
|
||||
|
||||
@cached_property
|
||||
def mysql_version(self):
|
||||
match = server_version_re.match(self.mysql_server_info)
|
||||
if not match:
|
||||
raise Exception(
|
||||
"Unable to determine MySQL version from version string %r"
|
||||
% self.mysql_server_info
|
||||
)
|
||||
return tuple(int(x) for x in match.groups())
|
||||
|
||||
@cached_property
|
||||
def mysql_is_mariadb(self):
|
||||
return "mariadb" in self.mysql_server_info.lower()
|
||||
|
||||
@cached_property
|
||||
def sql_mode(self):
|
||||
sql_mode = self.mysql_server_data["sql_mode"]
|
||||
return set(sql_mode.split(",") if sql_mode else ())
|
||||
@@ -0,0 +1,72 @@
|
||||
import signal
|
||||
|
||||
from django.db.backends.base.client import BaseDatabaseClient
|
||||
|
||||
|
||||
class DatabaseClient(BaseDatabaseClient):
|
||||
executable_name = "mysql"
|
||||
|
||||
@classmethod
|
||||
def settings_to_cmd_args_env(cls, settings_dict, parameters):
|
||||
args = [cls.executable_name]
|
||||
env = None
|
||||
database = settings_dict["OPTIONS"].get(
|
||||
"database",
|
||||
settings_dict["OPTIONS"].get("db", settings_dict["NAME"]),
|
||||
)
|
||||
user = settings_dict["OPTIONS"].get("user", settings_dict["USER"])
|
||||
password = settings_dict["OPTIONS"].get(
|
||||
"password",
|
||||
settings_dict["OPTIONS"].get("passwd", settings_dict["PASSWORD"]),
|
||||
)
|
||||
host = settings_dict["OPTIONS"].get("host", settings_dict["HOST"])
|
||||
port = settings_dict["OPTIONS"].get("port", settings_dict["PORT"])
|
||||
server_ca = settings_dict["OPTIONS"].get("ssl", {}).get("ca")
|
||||
client_cert = settings_dict["OPTIONS"].get("ssl", {}).get("cert")
|
||||
client_key = settings_dict["OPTIONS"].get("ssl", {}).get("key")
|
||||
defaults_file = settings_dict["OPTIONS"].get("read_default_file")
|
||||
charset = settings_dict["OPTIONS"].get("charset")
|
||||
# Seems to be no good way to set sql_mode with CLI.
|
||||
|
||||
if defaults_file:
|
||||
args += ["--defaults-file=%s" % defaults_file]
|
||||
if user:
|
||||
args += ["--user=%s" % user]
|
||||
if password:
|
||||
# The MYSQL_PWD environment variable usage is discouraged per
|
||||
# MySQL's documentation due to the possibility of exposure through
|
||||
# `ps` on old Unix flavors but --password suffers from the same
|
||||
# flaw on even more systems. Usage of an environment variable also
|
||||
# prevents password exposure if the subprocess.run(check=True) call
|
||||
# raises a CalledProcessError since the string representation of
|
||||
# the latter includes all of the provided `args`.
|
||||
env = {"MYSQL_PWD": password}
|
||||
if host:
|
||||
if "/" in host:
|
||||
args += ["--socket=%s" % host]
|
||||
else:
|
||||
args += ["--host=%s" % host]
|
||||
if port:
|
||||
args += ["--port=%s" % port]
|
||||
if server_ca:
|
||||
args += ["--ssl-ca=%s" % server_ca]
|
||||
if client_cert:
|
||||
args += ["--ssl-cert=%s" % client_cert]
|
||||
if client_key:
|
||||
args += ["--ssl-key=%s" % client_key]
|
||||
if charset:
|
||||
args += ["--default-character-set=%s" % charset]
|
||||
if database:
|
||||
args += [database]
|
||||
args.extend(parameters)
|
||||
return args, env
|
||||
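# Illustrative sketch, not part of the original module: with a hypothetical
# settings_dict {"NAME": "mydb", "USER": "app", "PASSWORD": "s3cret",
# "HOST": "db.local", "PORT": "3306", "OPTIONS": {}},
# DatabaseClient.settings_to_cmd_args_env(settings_dict, []) returns
# (["mysql", "--user=app", "--host=db.local", "--port=3306", "mydb"],
#  {"MYSQL_PWD": "s3cret"}), keeping the password out of the argument list.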
|
||||
def runshell(self, parameters):
|
||||
sigint_handler = signal.getsignal(signal.SIGINT)
|
||||
try:
|
||||
# Allow SIGINT to pass to mysql to abort queries.
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
super().runshell(parameters)
|
||||
finally:
|
||||
# Restore the original SIGINT handler.
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
@@ -0,0 +1,84 @@
|
||||
from django.core.exceptions import FieldError, FullResultSet
|
||||
from django.db.models.expressions import Col
|
||||
from django.db.models.sql import compiler
|
||||
|
||||
|
||||
class SQLCompiler(compiler.SQLCompiler):
|
||||
def as_subquery_condition(self, alias, columns, compiler):
|
||||
qn = compiler.quote_name_unless_alias
|
||||
qn2 = self.connection.ops.quote_name
|
||||
sql, params = self.as_sql()
|
||||
return (
|
||||
"(%s) IN (%s)"
|
||||
% (
|
||||
", ".join("%s.%s" % (qn(alias), qn2(column)) for column in columns),
|
||||
sql,
|
||||
),
|
||||
params,
|
||||
)
|
||||
|
||||
|
||||
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
|
||||
def as_sql(self):
|
||||
# Prefer the non-standard DELETE FROM syntax over the SQL generated by
|
||||
# the SQLDeleteCompiler's default implementation when multiple tables
|
||||
# are involved since MySQL/MariaDB will generate a more efficient query
|
||||
# plan than when using a subquery.
|
||||
where, having, qualify = self.query.where.split_having_qualify(
|
||||
must_group_by=self.query.group_by is not None
|
||||
)
|
||||
if self.single_alias or having or qualify:
|
||||
# DELETE FROM cannot be used when filtering against aggregates or
|
||||
# window functions as it doesn't allow for GROUP BY/HAVING clauses
|
||||
# and the subquery wrapping (necessary to emulate QUALIFY).
|
||||
return super().as_sql()
|
||||
result = [
|
||||
"DELETE %s FROM"
|
||||
% self.quote_name_unless_alias(self.query.get_initial_alias())
|
||||
]
|
||||
from_sql, params = self.get_from_clause()
|
||||
result.extend(from_sql)
|
||||
try:
|
||||
where_sql, where_params = self.compile(where)
|
||||
except FullResultSet:
|
||||
pass
|
||||
else:
|
||||
result.append("WHERE %s" % where_sql)
|
||||
params.extend(where_params)
|
||||
return " ".join(result), tuple(params)
|
||||
|
||||
|
||||
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
|
||||
def as_sql(self):
|
||||
update_query, update_params = super().as_sql()
|
||||
# MySQL and MariaDB support UPDATE ... ORDER BY syntax.
|
||||
if self.query.order_by:
|
||||
order_by_sql = []
|
||||
order_by_params = []
|
||||
db_table = self.query.get_meta().db_table
|
||||
try:
|
||||
for resolved, (sql, params, _) in self.get_order_by():
|
||||
if (
|
||||
isinstance(resolved.expression, Col)
|
||||
and resolved.expression.alias != db_table
|
||||
):
|
||||
# Ignore ordering if it contains joined fields, because
|
||||
# they cannot be used in the ORDER BY clause.
|
||||
raise FieldError
|
||||
order_by_sql.append(sql)
|
||||
order_by_params.extend(params)
|
||||
update_query += " ORDER BY " + ", ".join(order_by_sql)
|
||||
update_params += tuple(order_by_params)
|
||||
except FieldError:
|
||||
# Ignore ordering if it contains annotations, because they're
|
||||
# removed in .update() and cannot be resolved.
|
||||
pass
|
||||
return update_query, update_params
|
||||
|
||||
|
||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
|
||||
pass
|
||||
@@ -0,0 +1,87 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
|
||||
from .client import DatabaseClient
|
||||
|
||||
|
||||
class DatabaseCreation(BaseDatabaseCreation):
|
||||
def sql_table_creation_suffix(self):
|
||||
suffix = []
|
||||
test_settings = self.connection.settings_dict["TEST"]
|
||||
if test_settings["CHARSET"]:
|
||||
suffix.append("CHARACTER SET %s" % test_settings["CHARSET"])
|
||||
if test_settings["COLLATION"]:
|
||||
suffix.append("COLLATE %s" % test_settings["COLLATION"])
|
||||
return " ".join(suffix)
|
||||
|
||||
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
|
||||
try:
|
||||
super()._execute_create_test_db(cursor, parameters, keepdb)
|
||||
except Exception as e:
|
||||
if len(e.args) < 1 or e.args[0] != 1007:
|
||||
# All errors except "database exists" (1007) cancel tests.
|
||||
self.log("Got an error creating the test database: %s" % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
raise
|
||||
|
||||
def _clone_test_db(self, suffix, verbosity, keepdb=False):
|
||||
source_database_name = self.connection.settings_dict["NAME"]
|
||||
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
|
||||
test_db_params = {
|
||||
"dbname": self.connection.ops.quote_name(target_database_name),
|
||||
"suffix": self.sql_table_creation_suffix(),
|
||||
}
|
||||
with self._nodb_cursor() as cursor:
|
||||
try:
|
||||
self._execute_create_test_db(cursor, test_db_params, keepdb)
|
||||
except Exception:
|
||||
if keepdb:
|
||||
# If the database should be kept, skip everything else.
|
||||
return
|
||||
try:
|
||||
if verbosity >= 1:
|
||||
self.log(
|
||||
"Destroying old test database for alias %s..."
|
||||
% (
|
||||
self._get_database_display_str(
|
||||
verbosity, target_database_name
|
||||
),
|
||||
)
|
||||
)
|
||||
cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
|
||||
self._execute_create_test_db(cursor, test_db_params, keepdb)
|
||||
except Exception as e:
|
||||
self.log("Got an error recreating the test database: %s" % e)
|
||||
sys.exit(2)
|
||||
self._clone_db(source_database_name, target_database_name)
|
||||
|
||||
def _clone_db(self, source_database_name, target_database_name):
|
||||
cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(
|
||||
self.connection.settings_dict, []
|
||||
)
|
||||
dump_cmd = [
|
||||
"mysqldump",
|
||||
*cmd_args[1:-1],
|
||||
"--routines",
|
||||
"--events",
|
||||
source_database_name,
|
||||
]
|
||||
dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None
|
||||
load_cmd = cmd_args
|
||||
load_cmd[-1] = target_database_name
|
||||
|
||||
with subprocess.Popen(
|
||||
dump_cmd, stdout=subprocess.PIPE, env=dump_env
|
||||
) as dump_proc:
|
||||
with subprocess.Popen(
|
||||
load_cmd,
|
||||
stdin=dump_proc.stdout,
|
||||
stdout=subprocess.DEVNULL,
|
||||
env=load_env,
|
||||
):
|
||||
# Allow dump_proc to receive a SIGPIPE if the load process exits.
|
||||
dump_proc.stdout.close()
|
||||
@@ -0,0 +1,350 @@
|
||||
import operator
|
||||
|
||||
from django.db.backends.base.features import BaseDatabaseFeatures
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
empty_fetchmany_value = ()
|
||||
allows_group_by_selected_pks = True
|
||||
related_fields_match_type = True
|
||||
# MySQL doesn't support sliced subqueries with IN/ALL/ANY/SOME.
|
||||
allow_sliced_subqueries_with_in = False
|
||||
has_select_for_update = True
|
||||
supports_forward_references = False
|
||||
supports_regex_backreferencing = False
|
||||
supports_date_lookup_using_string = False
|
||||
supports_timezones = False
|
||||
requires_explicit_null_ordering_when_grouping = True
|
||||
atomic_transactions = False
|
||||
can_clone_databases = True
|
||||
supports_comments = True
|
||||
supports_comments_inline = True
|
||||
supports_temporal_subtraction = True
|
||||
supports_slicing_ordering_in_compound = True
|
||||
supports_index_on_text_field = False
|
||||
supports_update_conflicts = True
|
||||
create_test_procedure_without_params_sql = """
|
||||
CREATE PROCEDURE test_procedure ()
|
||||
BEGIN
|
||||
DECLARE V_I INTEGER;
|
||||
SET V_I = 1;
|
||||
END;
|
||||
"""
|
||||
create_test_procedure_with_int_param_sql = """
|
||||
CREATE PROCEDURE test_procedure (P_I INTEGER)
|
||||
BEGIN
|
||||
DECLARE V_I INTEGER;
|
||||
SET V_I = P_I;
|
||||
END;
|
||||
"""
|
||||
create_test_table_with_composite_primary_key = """
|
||||
CREATE TABLE test_table_composite_pk (
|
||||
column_1 INTEGER NOT NULL,
|
||||
column_2 INTEGER NOT NULL,
|
||||
PRIMARY KEY(column_1, column_2)
|
||||
)
|
||||
"""
|
||||
# Neither MySQL nor MariaDB support partial indexes.
|
||||
supports_partial_indexes = False
|
||||
# COLLATE must be wrapped in parentheses because MySQL treats COLLATE as an
|
||||
# indexed expression.
|
||||
collate_as_index_expression = True
|
||||
|
||||
supports_order_by_nulls_modifier = False
|
||||
order_by_nulls_first = True
|
||||
supports_logical_xor = True
|
||||
|
||||
@cached_property
|
||||
def minimum_database_version(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return (10, 4)
|
||||
else:
|
||||
return (8,)
|
||||
|
||||
@cached_property
|
||||
def test_collations(self):
|
||||
charset = "utf8"
|
||||
if (
|
||||
self.connection.mysql_is_mariadb
|
||||
and self.connection.mysql_version >= (10, 6)
|
||||
) or (
|
||||
not self.connection.mysql_is_mariadb
|
||||
and self.connection.mysql_version >= (8, 0, 30)
|
||||
):
|
||||
# utf8 is an alias for utf8mb3 in MariaDB 10.6+ and MySQL 8.0.30+.
|
||||
charset = "utf8mb3"
|
||||
return {
|
||||
"ci": f"{charset}_general_ci",
|
||||
"non_default": f"{charset}_esperanto_ci",
|
||||
"swedish_ci": f"{charset}_swedish_ci",
|
||||
}
|
||||
|
||||
test_now_utc_template = "UTC_TIMESTAMP(6)"
|
||||
|
||||
@cached_property
|
||||
def django_test_skips(self):
|
||||
skips = {
|
||||
"This doesn't work on MySQL.": {
|
||||
"db_functions.comparison.test_greatest.GreatestTests."
|
||||
"test_coalesce_workaround",
|
||||
"db_functions.comparison.test_least.LeastTests."
|
||||
"test_coalesce_workaround",
|
||||
},
|
||||
"Running on MySQL requires utf8mb4 encoding (#18392).": {
|
||||
"model_fields.test_textfield.TextFieldTests.test_emoji",
|
||||
"model_fields.test_charfield.TestCharField.test_emoji",
|
||||
},
|
||||
"MySQL doesn't support functional indexes on a function that "
|
||||
"returns JSON": {
|
||||
"schema.tests.SchemaTests.test_func_index_json_key_transform",
|
||||
},
|
||||
"MySQL supports multiplying and dividing DurationFields by a "
|
||||
"scalar value but it's not implemented (#25287).": {
|
||||
"expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide",
|
||||
},
"UPDATE ... ORDER BY syntax on MySQL/MariaDB does not support ordering by"
" related fields.": {
"update.tests.AdvancedTests."
|
||||
"test_update_ordered_by_inline_m2m_annotation",
|
||||
"update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation",
|
||||
},
|
||||
}
|
||||
if self.connection.mysql_is_mariadb and (
|
||||
10,
|
||||
4,
|
||||
3,
|
||||
) < self.connection.mysql_version < (10, 5, 2):
|
||||
skips.update(
|
||||
{
|
||||
"https://jira.mariadb.org/browse/MDEV-19598": {
|
||||
"schema.tests.SchemaTests."
|
||||
"test_alter_not_unique_field_to_primary_key",
|
||||
},
|
||||
}
|
||||
)
|
||||
if self.connection.mysql_is_mariadb and (
|
||||
10,
|
||||
4,
|
||||
12,
|
||||
) < self.connection.mysql_version < (10, 5):
|
||||
skips.update(
|
||||
{
|
||||
"https://jira.mariadb.org/browse/MDEV-22775": {
|
||||
"schema.tests.SchemaTests."
|
||||
"test_alter_pk_with_self_referential_field",
|
||||
},
|
||||
}
|
||||
)
|
||||
if not self.supports_explain_analyze:
|
||||
skips.update(
|
||||
{
|
||||
"MariaDB and MySQL >= 8.0.18 specific.": {
|
||||
"queries.test_explain.ExplainTests.test_mysql_analyze",
|
||||
},
|
||||
}
|
||||
)
|
||||
if "ONLY_FULL_GROUP_BY" in self.connection.sql_mode:
|
||||
skips.update(
|
||||
{
|
||||
"GROUP BY cannot contain nonaggregated column when "
|
||||
"ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #34262.": {
|
||||
"aggregation.tests.AggregateTestCase."
|
||||
"test_group_by_nested_expression_with_params",
|
||||
},
|
||||
}
|
||||
)
|
||||
if self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
|
||||
10,
|
||||
5,
|
||||
2,
|
||||
):
|
||||
skips.update(
|
||||
{
|
||||
"ALTER TABLE ... RENAME COLUMN statement doesn't rename inline "
|
||||
"constraints on MariaDB 10.5.2+, this is fixed in Django 5.0+ "
|
||||
"(#34320).": {
|
||||
"schema.tests.SchemaTests."
|
||||
"test_rename_field_with_check_to_truncated_name",
|
||||
},
|
||||
}
|
||||
)
|
||||
return skips
|
||||
|
||||
@cached_property
|
||||
def _mysql_storage_engine(self):
|
||||
"Internal method used in Django tests. Don't rely on this from your code"
|
||||
return self.connection.mysql_server_data["default_storage_engine"]
|
||||
|
||||
@cached_property
|
||||
def allows_auto_pk_0(self):
|
||||
"""
|
||||
Autoincrement primary key can be set to 0 if it doesn't generate new
|
||||
autoincrement values.
|
||||
"""
|
||||
return "NO_AUTO_VALUE_ON_ZERO" in self.connection.sql_mode
|
||||
|
||||
@cached_property
|
||||
def update_can_self_select(self):
|
||||
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
|
||||
10,
|
||||
3,
|
||||
2,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def can_introspect_foreign_keys(self):
|
||||
"Confirm support for introspected foreign keys"
|
||||
return self._mysql_storage_engine != "MyISAM"
|
||||
|
||||
@cached_property
|
||||
def introspected_field_types(self):
|
||||
return {
|
||||
**super().introspected_field_types,
|
||||
"BinaryField": "TextField",
|
||||
"BooleanField": "IntegerField",
|
||||
"DurationField": "BigIntegerField",
|
||||
"GenericIPAddressField": "CharField",
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def can_return_columns_from_insert(self):
|
||||
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
|
||||
10,
|
||||
5,
|
||||
0,
|
||||
)
|
||||
|
||||
can_return_rows_from_bulk_insert = property(
|
||||
operator.attrgetter("can_return_columns_from_insert")
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def has_zoneinfo_database(self):
|
||||
return self.connection.mysql_server_data["has_zoneinfo_database"]
|
||||
|
||||
@cached_property
|
||||
def is_sql_auto_is_null_enabled(self):
|
||||
return self.connection.mysql_server_data["sql_auto_is_null"]
|
||||
|
||||
@cached_property
|
||||
def supports_over_clause(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return True
|
||||
return self.connection.mysql_version >= (8, 0, 2)
|
||||
|
||||
supports_frame_range_fixed_distance = property(
|
||||
operator.attrgetter("supports_over_clause")
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def supports_column_check_constraints(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return True
|
||||
return self.connection.mysql_version >= (8, 0, 16)
|
||||
|
||||
supports_table_check_constraints = property(
|
||||
operator.attrgetter("supports_column_check_constraints")
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def can_introspect_check_constraints(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return True
|
||||
return self.connection.mysql_version >= (8, 0, 16)
|
||||
|
||||
@cached_property
|
||||
def has_select_for_update_skip_locked(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return self.connection.mysql_version >= (10, 6)
|
||||
return self.connection.mysql_version >= (8, 0, 1)
|
||||
|
||||
@cached_property
|
||||
def has_select_for_update_nowait(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return True
|
||||
return self.connection.mysql_version >= (8, 0, 1)
|
||||
|
||||
@cached_property
|
||||
def has_select_for_update_of(self):
|
||||
return (
|
||||
not self.connection.mysql_is_mariadb
|
||||
and self.connection.mysql_version >= (8, 0, 1)
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def supports_explain_analyze(self):
|
||||
return self.connection.mysql_is_mariadb or self.connection.mysql_version >= (
|
||||
8,
|
||||
0,
|
||||
18,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def supported_explain_formats(self):
|
||||
# Alias MySQL's TRADITIONAL to TEXT for consistency with other
|
||||
# backends.
|
||||
formats = {"JSON", "TEXT", "TRADITIONAL"}
|
||||
if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
|
||||
8,
|
||||
0,
|
||||
16,
|
||||
):
|
||||
formats.add("TREE")
|
||||
return formats
|
||||
|
||||
@cached_property
|
||||
def supports_transactions(self):
|
||||
"""
|
||||
All storage engines except MyISAM support transactions.
|
||||
"""
|
||||
return self._mysql_storage_engine != "MyISAM"
|
||||
|
||||
uses_savepoints = property(operator.attrgetter("supports_transactions"))
|
||||
can_release_savepoints = property(operator.attrgetter("supports_transactions"))
|
||||
|
||||
@cached_property
|
||||
def ignores_table_name_case(self):
|
||||
return self.connection.mysql_server_data["lower_case_table_names"]
|
||||
|
||||
@cached_property
|
||||
def supports_default_in_lead_lag(self):
|
||||
# To be added in https://jira.mariadb.org/browse/MDEV-12981.
|
||||
return not self.connection.mysql_is_mariadb
|
||||
|
||||
@cached_property
|
||||
def can_introspect_json_field(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return self.can_introspect_check_constraints
|
||||
return True
|
||||
|
||||
@cached_property
|
||||
def supports_index_column_ordering(self):
|
||||
if self._mysql_storage_engine != "InnoDB":
|
||||
return False
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return self.connection.mysql_version >= (10, 8)
|
||||
return self.connection.mysql_version >= (8, 0, 1)
|
||||
|
||||
@cached_property
|
||||
def supports_expression_indexes(self):
|
||||
return (
|
||||
not self.connection.mysql_is_mariadb
|
||||
and self._mysql_storage_engine != "MyISAM"
|
||||
and self.connection.mysql_version >= (8, 0, 13)
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def supports_select_intersection(self):
|
||||
is_mariadb = self.connection.mysql_is_mariadb
|
||||
return is_mariadb or self.connection.mysql_version >= (8, 0, 31)
|
||||
|
||||
supports_select_difference = property(
|
||||
operator.attrgetter("supports_select_intersection")
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def can_rename_index(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return self.connection.mysql_version >= (10, 5, 2)
|
||||
return True
|
||||
@@ -0,0 +1,349 @@
|
||||
from collections import namedtuple
|
||||
|
||||
import sqlparse
|
||||
from MySQLdb.constants import FIELD_TYPE
|
||||
|
||||
from django.db.backends.base.introspection import BaseDatabaseIntrospection
|
||||
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
|
||||
from django.db.backends.base.introspection import TableInfo as BaseTableInfo
|
||||
from django.db.models import Index
|
||||
from django.utils.datastructures import OrderedSet
|
||||
|
||||
FieldInfo = namedtuple(
|
||||
"FieldInfo",
|
||||
BaseFieldInfo._fields + ("extra", "is_unsigned", "has_json_constraint", "comment"),
|
||||
)
|
||||
InfoLine = namedtuple(
|
||||
"InfoLine",
|
||||
"col_name data_type max_len num_prec num_scale extra column_default "
|
||||
"collation is_unsigned comment",
|
||||
)
|
||||
TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",))
|
||||
|
||||
|
||||
class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||
data_types_reverse = {
|
||||
FIELD_TYPE.BLOB: "TextField",
|
||||
FIELD_TYPE.CHAR: "CharField",
|
||||
FIELD_TYPE.DECIMAL: "DecimalField",
|
||||
FIELD_TYPE.NEWDECIMAL: "DecimalField",
|
||||
FIELD_TYPE.DATE: "DateField",
|
||||
FIELD_TYPE.DATETIME: "DateTimeField",
|
||||
FIELD_TYPE.DOUBLE: "FloatField",
|
||||
FIELD_TYPE.FLOAT: "FloatField",
|
||||
FIELD_TYPE.INT24: "IntegerField",
|
||||
FIELD_TYPE.JSON: "JSONField",
|
||||
FIELD_TYPE.LONG: "IntegerField",
|
||||
FIELD_TYPE.LONGLONG: "BigIntegerField",
|
||||
FIELD_TYPE.SHORT: "SmallIntegerField",
|
||||
FIELD_TYPE.STRING: "CharField",
|
||||
FIELD_TYPE.TIME: "TimeField",
|
||||
FIELD_TYPE.TIMESTAMP: "DateTimeField",
|
||||
FIELD_TYPE.TINY: "IntegerField",
|
||||
FIELD_TYPE.TINY_BLOB: "TextField",
|
||||
FIELD_TYPE.MEDIUM_BLOB: "TextField",
|
||||
FIELD_TYPE.LONG_BLOB: "TextField",
|
||||
FIELD_TYPE.VAR_STRING: "CharField",
|
||||
}
|
||||
|
||||
def get_field_type(self, data_type, description):
|
||||
field_type = super().get_field_type(data_type, description)
|
||||
if "auto_increment" in description.extra:
|
||||
if field_type == "IntegerField":
|
||||
return "AutoField"
|
||||
elif field_type == "BigIntegerField":
|
||||
return "BigAutoField"
|
||||
elif field_type == "SmallIntegerField":
|
||||
return "SmallAutoField"
|
||||
if description.is_unsigned:
|
||||
if field_type == "BigIntegerField":
|
||||
return "PositiveBigIntegerField"
|
||||
elif field_type == "IntegerField":
|
||||
return "PositiveIntegerField"
|
||||
elif field_type == "SmallIntegerField":
|
||||
return "PositiveSmallIntegerField"
|
||||
# JSON data type is an alias for LONGTEXT in MariaDB, use check
|
||||
# constraints clauses to introspect JSONField.
|
||||
if description.has_json_constraint:
|
||||
return "JSONField"
|
||||
return field_type
|
||||
|
||||
def get_table_list(self, cursor):
|
||||
"""Return a list of table and view names in the current database."""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
table_name,
|
||||
table_type,
|
||||
table_comment
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
"""
|
||||
)
|
||||
return [
|
||||
TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1]), row[2])
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
|
||||
def get_table_description(self, cursor, table_name):
|
||||
"""
|
||||
Return a description of the table with the DB-API cursor.description
|
||||
interface.
|
||||
"""
|
||||
json_constraints = {}
|
||||
if (
|
||||
self.connection.mysql_is_mariadb
|
||||
and self.connection.features.can_introspect_json_field
|
||||
):
|
||||
# JSON data type is an alias for LONGTEXT in MariaDB, select
|
||||
# JSON_VALID() constraints to introspect JSONField.
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT c.constraint_name AS column_name
|
||||
FROM information_schema.check_constraints AS c
|
||||
WHERE
|
||||
c.table_name = %s AND
|
||||
LOWER(c.check_clause) =
|
||||
'json_valid(`' + LOWER(c.constraint_name) + '`)' AND
|
||||
c.constraint_schema = DATABASE()
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
json_constraints = {row[0] for row in cursor.fetchall()}
|
||||
# A default collation for the given table.
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT table_collation
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = %s
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
default_column_collation = row[0] if row else ""
|
||||
# information_schema database gives more accurate results for some figures:
|
||||
# - varchar length returned by cursor.description is an internal length,
|
||||
# not visible length (#5725)
|
||||
# - precision and scale (for decimal fields) (#5014)
|
||||
# - auto_increment is not available in cursor.description
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
column_name, data_type, character_maximum_length,
|
||||
numeric_precision, numeric_scale, extra, column_default,
|
||||
CASE
|
||||
WHEN collation_name = %s THEN NULL
|
||||
ELSE collation_name
|
||||
END AS collation_name,
|
||||
CASE
|
||||
WHEN column_type LIKE '%% unsigned' THEN 1
|
||||
ELSE 0
|
||||
END AS is_unsigned,
|
||||
column_comment
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = %s AND table_schema = DATABASE()
|
||||
""",
|
||||
[default_column_collation, table_name],
|
||||
)
|
||||
field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}
|
||||
|
||||
cursor.execute(
|
||||
"SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
|
||||
)
|
||||
|
||||
def to_int(i):
|
||||
return int(i) if i is not None else i
|
||||
|
||||
fields = []
|
||||
for line in cursor.description:
|
||||
info = field_info[line[0]]
|
||||
fields.append(
|
||||
FieldInfo(
|
||||
*line[:2],
|
||||
to_int(info.max_len) or line[2],
|
||||
to_int(info.max_len) or line[3],
|
||||
to_int(info.num_prec) or line[4],
|
||||
to_int(info.num_scale) or line[5],
|
||||
line[6],
|
||||
info.column_default,
|
||||
info.collation,
|
||||
info.extra,
|
||||
info.is_unsigned,
|
||||
line[0] in json_constraints,
|
||||
info.comment,
|
||||
)
|
||||
)
|
||||
return fields
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
for field_info in self.get_table_description(cursor, table_name):
|
||||
if "auto_increment" in field_info.extra:
|
||||
# MySQL allows only one auto-increment column per table.
|
||||
return [{"table": table_name, "column": field_info.name}]
|
||||
return []
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||
representing all foreign keys in the given table.
|
||||
"""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT column_name, referenced_column_name, referenced_table_name
|
||||
FROM information_schema.key_column_usage
|
||||
WHERE table_name = %s
|
||||
AND table_schema = DATABASE()
|
||||
AND referenced_table_name IS NOT NULL
|
||||
AND referenced_column_name IS NOT NULL
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
return {
|
||||
field_name: (other_field, other_table)
|
||||
for field_name, other_field, other_table in cursor.fetchall()
|
||||
}
|
||||
|
||||
def get_storage_engine(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve the storage engine for a given table. Return the default
|
||||
storage engine if the table doesn't exist.
|
||||
"""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT engine
|
||||
FROM information_schema.tables
|
||||
WHERE
|
||||
table_name = %s AND
|
||||
table_schema = DATABASE()
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
return self.connection.features._mysql_storage_engine
|
||||
return result[0]
|
||||
|
||||
def _parse_constraint_columns(self, check_clause, columns):
|
||||
check_columns = OrderedSet()
|
||||
statement = sqlparse.parse(check_clause)[0]
|
||||
tokens = (token for token in statement.flatten() if not token.is_whitespace)
|
||||
for token in tokens:
|
||||
if (
|
||||
token.ttype == sqlparse.tokens.Name
|
||||
and self.connection.ops.quote_name(token.value) == token.value
|
||||
and token.value[1:-1] in columns
|
||||
):
|
||||
check_columns.add(token.value[1:-1])
|
||||
return check_columns
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve any constraints or keys (unique, pk, fk, check, index) across
|
||||
one or more columns.
|
||||
"""
|
||||
constraints = {}
|
||||
# Get the actual constraint names and columns
|
||||
name_query = """
|
||||
SELECT kc.`constraint_name`, kc.`column_name`,
|
||||
kc.`referenced_table_name`, kc.`referenced_column_name`,
|
||||
c.`constraint_type`
|
||||
FROM
|
||||
information_schema.key_column_usage AS kc,
|
||||
information_schema.table_constraints AS c
|
||||
WHERE
|
||||
kc.table_schema = DATABASE() AND
|
||||
c.table_schema = kc.table_schema AND
|
||||
c.constraint_name = kc.constraint_name AND
|
||||
c.constraint_type != 'CHECK' AND
|
||||
kc.table_name = %s
|
||||
ORDER BY kc.`ordinal_position`
|
||||
"""
|
||||
cursor.execute(name_query, [table_name])
|
||||
for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
|
||||
if constraint not in constraints:
|
||||
constraints[constraint] = {
|
||||
"columns": OrderedSet(),
|
||||
"primary_key": kind == "PRIMARY KEY",
|
||||
"unique": kind in {"PRIMARY KEY", "UNIQUE"},
|
||||
"index": False,
|
||||
"check": False,
|
||||
"foreign_key": (ref_table, ref_column) if ref_column else None,
|
||||
}
|
||||
if self.connection.features.supports_index_column_ordering:
|
||||
constraints[constraint]["orders"] = []
|
||||
constraints[constraint]["columns"].add(column)
|
||||
# Add check constraints.
|
||||
if self.connection.features.can_introspect_check_constraints:
|
||||
unnamed_constraints_index = 0
|
||||
columns = {
|
||||
info.name for info in self.get_table_description(cursor, table_name)
|
||||
}
|
||||
if self.connection.mysql_is_mariadb:
|
||||
type_query = """
|
||||
SELECT c.constraint_name, c.check_clause
|
||||
FROM information_schema.check_constraints AS c
|
||||
WHERE
|
||||
c.constraint_schema = DATABASE() AND
|
||||
c.table_name = %s
|
||||
"""
|
||||
else:
|
||||
type_query = """
|
||||
SELECT cc.constraint_name, cc.check_clause
|
||||
FROM
|
||||
information_schema.check_constraints AS cc,
|
||||
information_schema.table_constraints AS tc
|
||||
WHERE
|
||||
cc.constraint_schema = DATABASE() AND
|
||||
tc.table_schema = cc.constraint_schema AND
|
||||
cc.constraint_name = tc.constraint_name AND
|
||||
tc.constraint_type = 'CHECK' AND
|
||||
tc.table_name = %s
|
||||
"""
|
||||
cursor.execute(type_query, [table_name])
|
||||
for constraint, check_clause in cursor.fetchall():
|
||||
constraint_columns = self._parse_constraint_columns(
|
||||
check_clause, columns
|
||||
)
|
||||
# Ensure uniqueness of unnamed constraints. Unnamed unique
|
||||
# and check columns constraints have the same name as
|
||||
# a column.
|
||||
if set(constraint_columns) == {constraint}:
|
||||
unnamed_constraints_index += 1
|
||||
constraint = "__unnamed_constraint_%s__" % unnamed_constraints_index
|
||||
constraints[constraint] = {
|
||||
"columns": constraint_columns,
|
||||
"primary_key": False,
|
||||
"unique": False,
|
||||
"index": False,
|
||||
"check": True,
|
||||
"foreign_key": None,
|
||||
}
|
||||
# Now add in the indexes
|
||||
cursor.execute(
|
||||
"SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name)
|
||||
)
|
||||
for table, non_unique, index, colseq, column, order, type_ in [
|
||||
x[:6] + (x[10],) for x in cursor.fetchall()
|
||||
]:
|
||||
if index not in constraints:
|
||||
constraints[index] = {
|
||||
"columns": OrderedSet(),
|
||||
"primary_key": False,
|
||||
"unique": not non_unique,
|
||||
"check": False,
|
||||
"foreign_key": None,
|
||||
}
|
||||
if self.connection.features.supports_index_column_ordering:
|
||||
constraints[index]["orders"] = []
|
||||
constraints[index]["index"] = True
|
||||
constraints[index]["type"] = (
|
||||
Index.suffix if type_ == "BTREE" else type_.lower()
|
||||
)
|
||||
constraints[index]["columns"].add(column)
|
||||
if self.connection.features.supports_index_column_ordering:
|
||||
constraints[index]["orders"].append("DESC" if order == "D" else "ASC")
|
||||
# Convert the sorted sets to lists
|
||||
for constraint in constraints.values():
|
||||
constraint["columns"] = list(constraint["columns"])
|
||||
return constraints
|
||||
@@ -0,0 +1,464 @@
|
||||
import uuid
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.backends.utils import split_tzname_delta
|
||||
from django.db.models import Exists, ExpressionWrapper, Lookup
|
||||
from django.db.models.constants import OnConflict
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
compiler_module = "django.db.backends.mysql.compiler"
|
||||
|
||||
# MySQL stores positive fields as UNSIGNED ints.
|
||||
integer_field_ranges = {
|
||||
**BaseDatabaseOperations.integer_field_ranges,
|
||||
"PositiveSmallIntegerField": (0, 65535),
|
||||
"PositiveIntegerField": (0, 4294967295),
|
||||
"PositiveBigIntegerField": (0, 18446744073709551615),
|
||||
}
|
||||
cast_data_types = {
|
||||
"AutoField": "signed integer",
|
||||
"BigAutoField": "signed integer",
|
||||
"SmallAutoField": "signed integer",
|
||||
"CharField": "char(%(max_length)s)",
|
||||
"DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
|
||||
"TextField": "char",
|
||||
"IntegerField": "signed integer",
|
||||
"BigIntegerField": "signed integer",
|
||||
"SmallIntegerField": "signed integer",
|
||||
"PositiveBigIntegerField": "unsigned integer",
|
||||
"PositiveIntegerField": "unsigned integer",
|
||||
"PositiveSmallIntegerField": "unsigned integer",
|
||||
"DurationField": "signed integer",
|
||||
}
|
||||
cast_char_field_without_max_length = "char"
|
||||
explain_prefix = "EXPLAIN"
|
||||
|
||||
# EXTRACT format cannot be passed in parameters.
|
||||
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
|
||||
|
||||
def date_extract_sql(self, lookup_type, sql, params):
|
||||
# https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
|
||||
if lookup_type == "week_day":
|
||||
# DAYOFWEEK() returns an integer, 1-7, Sunday=1.
|
||||
return f"DAYOFWEEK({sql})", params
|
||||
elif lookup_type == "iso_week_day":
|
||||
# WEEKDAY() returns an integer, 0-6, Monday=0.
|
||||
return f"WEEKDAY({sql}) + 1", params
|
||||
elif lookup_type == "week":
|
||||
# Override the value of default_week_format for consistency with
|
||||
# other database backends.
|
||||
# Mode 3: Monday, 1-53, with 4 or more days this year.
|
||||
return f"WEEK({sql}, 3)", params
|
||||
elif lookup_type == "iso_year":
|
||||
# Get the year part from the YEARWEEK function, which returns a
|
||||
# number as year * 100 + week.
|
||||
return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
|
||||
else:
|
||||
# EXTRACT returns 1-53 based on ISO-8601 for the week number.
|
||||
lookup_type = lookup_type.upper()
|
||||
if not self._extract_format_re.fullmatch(lookup_type):
|
||||
raise ValueError(f"Invalid lookup type: {lookup_type!r}")
|
||||
return f"EXTRACT({lookup_type} FROM {sql})", params
|
||||
|
||||
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
fields = {
|
||||
"year": "%Y-01-01",
|
||||
"month": "%Y-%m-01",
|
||||
}
|
||||
if lookup_type in fields:
|
||||
format_str = fields[lookup_type]
|
||||
return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
|
||||
elif lookup_type == "quarter":
|
||||
return (
|
||||
f"MAKEDATE(YEAR({sql}), 1) + "
|
||||
f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
|
||||
(*params, *params),
|
||||
)
|
||||
elif lookup_type == "week":
|
||||
return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
|
||||
else:
|
||||
return f"DATE({sql})", params
|
||||
|
||||
def _prepare_tzname_delta(self, tzname):
|
||||
tzname, sign, offset = split_tzname_delta(tzname)
|
||||
return f"{sign}{offset}" if offset else tzname
|
||||
|
||||
def _convert_sql_to_tz(self, sql, params, tzname):
|
||||
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
|
||||
return f"CONVERT_TZ({sql}, %s, %s)", (
|
||||
*params,
|
||||
self.connection.timezone_name,
|
||||
self._prepare_tzname_delta(tzname),
|
||||
)
|
||||
return sql, params
|
||||
|
||||
def datetime_cast_date_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"DATE({sql})", params
|
||||
|
||||
def datetime_cast_time_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"TIME({sql})", params
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
fields = ["year", "month", "day", "hour", "minute", "second"]
|
||||
format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
|
||||
format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
|
||||
if lookup_type == "quarter":
|
||||
return (
|
||||
f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
|
||||
f"INTERVAL QUARTER({sql}) QUARTER - "
|
||||
f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
|
||||
), (*params, *params, "%Y-%m-01 00:00:00")
|
||||
if lookup_type == "week":
|
||||
return (
|
||||
f"CAST(DATE_FORMAT("
|
||||
f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
|
||||
), (*params, *params, "%Y-%m-%d 00:00:00")
|
||||
try:
|
||||
i = fields.index(lookup_type) + 1
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
format_str = "".join(format[:i] + format_def[i:])
|
||||
return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
|
||||
return sql, params
|
||||
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
fields = {
|
||||
"hour": "%H:00:00",
|
||||
"minute": "%H:%i:00",
|
||||
"second": "%H:%i:%s",
|
||||
}
|
||||
if lookup_type in fields:
|
||||
format_str = fields[lookup_type]
|
||||
return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
|
||||
else:
|
||||
return f"TIME({sql})", params
|
||||
|
||||
def fetch_returned_insert_rows(self, cursor):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table, return the tuple of returned data.
|
||||
"""
|
||||
return cursor.fetchall()
|
||||
|
||||
def format_for_duration_arithmetic(self, sql):
|
||||
return "INTERVAL %s MICROSECOND" % sql
|
||||
|
||||
def force_no_ordering(self):
|
||||
"""
|
||||
"ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
|
||||
columns. If no ordering would otherwise be applied, we don't want any
|
||||
implicit sorting going on.
|
||||
"""
|
||||
return [(None, ("NULL", [], False))]
|
||||
|
||||
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
|
||||
return value
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
# With MySQLdb, cursor objects have an (undocumented) "_executed"
|
||||
# attribute where the exact query sent to the database is saved.
|
||||
# See MySQLdb/cursors.py in the source distribution.
|
||||
# MySQLdb returns string, PyMySQL bytes.
|
||||
return force_str(getattr(cursor, "_executed", None), errors="replace")
|
||||
|
||||
def no_limit_value(self):
|
||||
# 2**64 - 1, as recommended by the MySQL documentation
|
||||
return 18446744073709551615
|
||||
|
||||
def quote_name(self, name):
|
||||
if name.startswith("`") and name.endswith("`"):
|
||||
return name # Quoting once is enough.
|
||||
return "`%s`" % name
|
||||
|
||||
def return_insert_columns(self, fields):
|
||||
# MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
|
||||
# statement.
|
||||
if not fields:
|
||||
return "", ()
|
||||
columns = [
|
||||
"%s.%s"
|
||||
% (
|
||||
self.quote_name(field.model._meta.db_table),
|
||||
self.quote_name(field.column),
|
||||
)
|
||||
for field in fields
|
||||
]
|
||||
return "RETURNING %s" % ", ".join(columns), ()
|
||||
|
||||
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
|
||||
if not tables:
|
||||
return []
|
||||
|
||||
sql = ["SET FOREIGN_KEY_CHECKS = 0;"]
|
||||
if reset_sequences:
|
||||
# It's faster to TRUNCATE tables that require a sequence reset
|
||||
# since ALTER TABLE AUTO_INCREMENT is slower than TRUNCATE.
|
||||
sql.extend(
|
||||
"%s %s;"
|
||||
% (
|
||||
style.SQL_KEYWORD("TRUNCATE"),
|
||||
style.SQL_FIELD(self.quote_name(table_name)),
|
||||
)
|
||||
for table_name in tables
|
||||
)
|
||||
else:
|
||||
# Otherwise issue a simple DELETE since it's faster than TRUNCATE
|
||||
# and preserves sequences.
|
||||
sql.extend(
|
||||
"%s %s %s;"
|
||||
% (
|
||||
style.SQL_KEYWORD("DELETE"),
|
||||
style.SQL_KEYWORD("FROM"),
|
||||
style.SQL_FIELD(self.quote_name(table_name)),
|
||||
)
|
||||
for table_name in tables
|
||||
)
|
||||
sql.append("SET FOREIGN_KEY_CHECKS = 1;")
|
||||
return sql
|
||||
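# Illustrative sketch, not part of the original module: for a hypothetical
# tables=["app_tag"] with reset_sequences=True, sql_flush() returns, ignoring
# colorization, ["SET FOREIGN_KEY_CHECKS = 0;", "TRUNCATE `app_tag`;",
# "SET FOREIGN_KEY_CHECKS = 1;"], so the auto-increment counter restarts
# without issuing per-row DELETEs.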
|
||||
def sequence_reset_by_name_sql(self, style, sequences):
|
||||
return [
|
||||
"%s %s %s %s = 1;"
|
||||
% (
|
||||
style.SQL_KEYWORD("ALTER"),
|
||||
style.SQL_KEYWORD("TABLE"),
|
||||
style.SQL_FIELD(self.quote_name(sequence_info["table"])),
|
||||
style.SQL_FIELD("AUTO_INCREMENT"),
|
||||
)
|
||||
for sequence_info in sequences
|
||||
]
|
||||
|
||||
def validate_autopk_value(self, value):
|
||||
# Zero in AUTO_INCREMENT field does not work without the
|
||||
# NO_AUTO_VALUE_ON_ZERO SQL mode.
|
||||
if value == 0 and not self.connection.features.allows_auto_pk_0:
|
||||
raise ValueError(
|
||||
"The database backend does not accept 0 as a value for AutoField."
|
||||
)
|
||||
return value
|
||||
|
||||
def adapt_datetimefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
# MySQL doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
if settings.USE_TZ:
|
||||
value = timezone.make_naive(value, self.connection.timezone)
|
||||
else:
|
||||
raise ValueError(
|
||||
"MySQL backend does not support timezone-aware datetimes when "
|
||||
"USE_TZ is False."
|
||||
)
|
||||
return str(value)
|
||||
|
||||
def adapt_timefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
# MySQL doesn't support tz-aware times
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("MySQL backend does not support timezone-aware times.")
|
||||
|
||||
return value.isoformat(timespec="microseconds")
|
||||
|
||||
def max_name_length(self):
|
||||
return 64
|
||||
|
||||
def pk_default_value(self):
|
||||
return "NULL"
|
||||
|
||||
def bulk_insert_sql(self, fields, placeholder_rows):
|
||||
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
|
||||
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
|
||||
return "VALUES " + values_sql
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
if connector == "^":
|
||||
return "POW(%s)" % ",".join(sub_expressions)
|
||||
# Convert the result to a signed integer since MySQL's binary operators
|
||||
# return an unsigned integer.
|
||||
elif connector in ("&", "|", "<<", "#"):
|
||||
connector = "^" if connector == "#" else connector
|
||||
return "CONVERT(%s, SIGNED)" % connector.join(sub_expressions)
|
||||
elif connector == ">>":
|
||||
lhs, rhs = sub_expressions
|
||||
return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
|
||||
return super().combine_expression(connector, sub_expressions)
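# Illustrative examples (not part of the original source), assuming
# sub_expressions = ["a", "b"]:
#   "^"  -> "POW(a,b)"
#   "&"  -> "CONVERT(a&b, SIGNED)"
#   "#"  -> "CONVERT(a^b, SIGNED)"  (XOR)
#   ">>" -> "FLOOR(a / POW(2, b))"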
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
converters = super().get_db_converters(expression)
|
||||
internal_type = expression.output_field.get_internal_type()
|
||||
if internal_type == "BooleanField":
|
||||
converters.append(self.convert_booleanfield_value)
|
||||
elif internal_type == "DateTimeField":
|
||||
if settings.USE_TZ:
|
||||
converters.append(self.convert_datetimefield_value)
|
||||
elif internal_type == "UUIDField":
|
||||
converters.append(self.convert_uuidfield_value)
|
||||
return converters
|
||||
|
||||
def convert_booleanfield_value(self, value, expression, connection):
|
||||
if value in (0, 1):
|
||||
value = bool(value)
|
||||
return value
|
||||
|
||||
def convert_datetimefield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = timezone.make_aware(value, self.connection.timezone)
|
||||
return value
|
||||
|
||||
def convert_uuidfield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
|
||||
def binary_placeholder_sql(self, value):
|
||||
return (
|
||||
"_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
|
||||
)
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
if internal_type == "TimeField":
|
||||
if self.connection.mysql_is_mariadb:
|
||||
# MariaDB includes the microsecond component in TIME_TO_SEC as
|
||||
# a decimal. MySQL returns an integer without microseconds.
|
||||
return (
|
||||
"CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) "
|
||||
"* 1000000 AS SIGNED)"
|
||||
) % {
|
||||
"lhs": lhs_sql,
|
||||
"rhs": rhs_sql,
|
||||
}, (
|
||||
*lhs_params,
|
||||
*rhs_params,
|
||||
)
|
||||
return (
|
||||
"((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -"
|
||||
" (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))"
|
||||
) % {"lhs": lhs_sql, "rhs": rhs_sql}, tuple(lhs_params) * 2 + tuple(
|
||||
rhs_params
|
||||
) * 2
|
||||
params = (*rhs_params, *lhs_params)
|
||||
return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params
|
||||
|
||||
def explain_query_prefix(self, format=None, **options):
|
||||
# Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
|
||||
if format and format.upper() == "TEXT":
|
||||
format = "TRADITIONAL"
|
||||
elif (
|
||||
not format and "TREE" in self.connection.features.supported_explain_formats
|
||||
):
|
||||
# Use TREE by default (if supported) as it's more informative.
|
||||
format = "TREE"
|
||||
analyze = options.pop("analyze", False)
|
||||
prefix = super().explain_query_prefix(format, **options)
|
||||
if analyze and self.connection.features.supports_explain_analyze:
|
||||
# MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
|
||||
prefix = (
|
||||
"ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
|
||||
)
|
||||
if format and not (analyze and not self.connection.mysql_is_mariadb):
|
||||
# Only MariaDB supports the analyze option with formats.
|
||||
prefix += " FORMAT=%s" % format
|
||||
return prefix
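# Illustrative sketch (not from the original source): on a recent MySQL 8.0
# server with format=None and analyze=False this typically yields
# "EXPLAIN FORMAT=TREE", while on MariaDB with analyze=True it yields
# "ANALYZE" (with " FORMAT=..." appended only when a format is requested).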
|
||||
|
||||
def regex_lookup(self, lookup_type):
|
||||
# REGEXP_LIKE doesn't exist in MariaDB.
|
||||
if self.connection.mysql_is_mariadb:
|
||||
if lookup_type == "regex":
|
||||
return "%s REGEXP BINARY %s"
|
||||
return "%s REGEXP %s"
|
||||
|
||||
match_option = "c" if lookup_type == "regex" else "i"
|
||||
return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
|
||||
|
||||
def insert_statement(self, on_conflict=None):
|
||||
if on_conflict == OnConflict.IGNORE:
|
||||
return "INSERT IGNORE INTO"
|
||||
return super().insert_statement(on_conflict=on_conflict)
|
||||
|
||||
def lookup_cast(self, lookup_type, internal_type=None):
|
||||
lookup = "%s"
|
||||
if internal_type == "JSONField":
|
||||
if self.connection.mysql_is_mariadb or lookup_type in (
|
||||
"iexact",
|
||||
"contains",
|
||||
"icontains",
|
||||
"startswith",
|
||||
"istartswith",
|
||||
"endswith",
|
||||
"iendswith",
|
||||
"regex",
|
||||
"iregex",
|
||||
):
|
||||
lookup = "JSON_UNQUOTE(%s)"
|
||||
return lookup
|
||||
|
||||
def conditional_expression_supported_in_where_clause(self, expression):
|
||||
# MySQL ignores indexes with boolean fields unless they're compared
|
||||
# directly to a boolean value.
|
||||
if isinstance(expression, (Exists, Lookup)):
|
||||
return True
|
||||
if isinstance(expression, ExpressionWrapper) and expression.conditional:
|
||||
return self.conditional_expression_supported_in_where_clause(
|
||||
expression.expression
|
||||
)
|
||||
if getattr(expression, "conditional", False):
|
||||
return False
|
||||
return super().conditional_expression_supported_in_where_clause(expression)
|
||||
|
||||
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
|
||||
if on_conflict == OnConflict.UPDATE:
|
||||
conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
|
||||
# The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
|
||||
# aliases for the new row and its columns available in MySQL
|
||||
# 8.0.19+.
|
||||
if not self.connection.mysql_is_mariadb:
|
||||
if self.connection.mysql_version >= (8, 0, 19):
|
||||
conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
|
||||
field_sql = "%(field)s = new.%(field)s"
|
||||
else:
|
||||
field_sql = "%(field)s = VALUES(%(field)s)"
|
||||
# Use VALUE() on MariaDB.
|
||||
else:
|
||||
field_sql = "%(field)s = VALUE(%(field)s)"
|
||||
|
||||
fields = ", ".join(
|
||||
[
|
||||
field_sql % {"field": field}
|
||||
for field in map(self.quote_name, update_fields)
|
||||
]
|
||||
)
|
||||
return conflict_suffix_sql % {"fields": fields}
|
||||
return super().on_conflict_suffix_sql(
|
||||
fields,
|
||||
on_conflict,
|
||||
update_fields,
|
||||
unique_fields,
|
||||
)
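# Illustrative sketch (not from the original source): for an upsert updating a
# single "name" column, MySQL >= 8.0.19 produces roughly
# "AS new ON DUPLICATE KEY UPDATE `name` = new.`name`", while MariaDB produces
# "ON DUPLICATE KEY UPDATE `name` = VALUE(`name`)".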
@@ -0,0 +1,243 @@
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
from django.db.models import NOT_PROVIDED, F, UniqueConstraint
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
|
||||
|
||||
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
sql_rename_table = "RENAME TABLE %(old_table)s TO %(new_table)s"
|
||||
|
||||
sql_alter_column_null = "MODIFY %(column)s %(type)s NULL"
|
||||
sql_alter_column_not_null = "MODIFY %(column)s %(type)s NOT NULL"
|
||||
sql_alter_column_type = "MODIFY %(column)s %(type)s%(collation)s%(comment)s"
|
||||
sql_alter_column_no_default_null = "ALTER COLUMN %(column)s SET DEFAULT NULL"
|
||||
|
||||
# No 'CASCADE': it works as a no-op in MySQL but is undocumented
|
||||
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
|
||||
|
||||
sql_delete_unique = "ALTER TABLE %(table)s DROP INDEX %(name)s"
|
||||
sql_create_column_inline_fk = (
|
||||
", ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) "
|
||||
"REFERENCES %(to_table)s(%(to_column)s)"
|
||||
)
|
||||
sql_delete_fk = "ALTER TABLE %(table)s DROP FOREIGN KEY %(name)s"
|
||||
|
||||
sql_delete_index = "DROP INDEX %(name)s ON %(table)s"
|
||||
sql_rename_index = "ALTER TABLE %(table)s RENAME INDEX %(old_name)s TO %(new_name)s"
|
||||
|
||||
sql_create_pk = (
|
||||
"ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
|
||||
)
|
||||
sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY"
|
||||
|
||||
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
|
||||
|
||||
sql_alter_table_comment = "ALTER TABLE %(table)s COMMENT = %(comment)s"
|
||||
sql_alter_column_comment = None
|
||||
|
||||
@property
|
||||
def sql_delete_check(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
# The name of the column check constraint is the same as the field
# name on MariaDB. Adding an IF EXISTS clause prevents migrations
# from crashing. The constraint is removed during a "MODIFY" column
# statement.
|
||||
return "ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s"
|
||||
return "ALTER TABLE %(table)s DROP CHECK %(name)s"
|
||||
|
||||
@property
|
||||
def sql_rename_column(self):
|
||||
# MariaDB >= 10.5.2 and MySQL >= 8.0.4 support an
|
||||
# "ALTER TABLE ... RENAME COLUMN" statement.
|
||||
if self.connection.mysql_is_mariadb:
|
||||
if self.connection.mysql_version >= (10, 5, 2):
|
||||
return super().sql_rename_column
|
||||
elif self.connection.mysql_version >= (8, 0, 4):
|
||||
return super().sql_rename_column
|
||||
return "ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s"
|
||||
|
||||
def quote_value(self, value):
|
||||
self.connection.ensure_connection()
|
||||
if isinstance(value, str):
|
||||
value = value.replace("%", "%%")
|
||||
# MySQLdb escapes to string, PyMySQL to bytes.
|
||||
quoted = self.connection.connection.escape(
|
||||
value, self.connection.connection.encoders
|
||||
)
|
||||
if isinstance(value, str) and isinstance(quoted, bytes):
|
||||
quoted = quoted.decode()
|
||||
return quoted
|
||||
|
||||
def _is_limited_data_type(self, field):
|
||||
db_type = field.db_type(self.connection)
|
||||
return (
|
||||
db_type is not None
|
||||
and db_type.lower() in self.connection._limited_data_types
|
||||
)
|
||||
|
||||
def skip_default(self, field):
|
||||
if not self._supports_limited_data_type_defaults:
|
||||
return self._is_limited_data_type(field)
|
||||
return False
|
||||
|
||||
def skip_default_on_alter(self, field):
|
||||
if self._is_limited_data_type(field) and not self.connection.mysql_is_mariadb:
|
||||
# MySQL doesn't support defaults for BLOB and TEXT in the
|
||||
# ALTER COLUMN statement.
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def _supports_limited_data_type_defaults(self):
|
||||
# MariaDB and MySQL >= 8.0.13 support defaults for BLOB and TEXT.
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return True
|
||||
return self.connection.mysql_version >= (8, 0, 13)
|
||||
|
||||
def _column_default_sql(self, field):
|
||||
if (
|
||||
not self.connection.mysql_is_mariadb
|
||||
and self._supports_limited_data_type_defaults
|
||||
and self._is_limited_data_type(field)
|
||||
):
|
||||
# MySQL supports defaults for BLOB and TEXT columns only if the
|
||||
# default value is written as an expression i.e. in parentheses.
|
||||
return "(%s)"
|
||||
return super()._column_default_sql(field)
|
||||
|
||||
def add_field(self, model, field):
|
||||
super().add_field(model, field)
|
||||
|
||||
# Simulate the effect of a one-off default.
|
||||
# field.default may be unhashable, so a set isn't used for "in" check.
|
||||
if self.skip_default(field) and field.default not in (None, NOT_PROVIDED):
|
||||
effective_default = self.effective_default(field)
|
||||
self.execute(
|
||||
"UPDATE %(table)s SET %(column)s = %%s"
|
||||
% {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"column": self.quote_name(field.column),
|
||||
},
|
||||
[effective_default],
|
||||
)
|
||||
|
||||
def remove_constraint(self, model, constraint):
|
||||
if (
|
||||
isinstance(constraint, UniqueConstraint)
|
||||
and constraint.create_sql(model, self) is not None
|
||||
):
|
||||
self._create_missing_fk_index(
|
||||
model,
|
||||
fields=constraint.fields,
|
||||
expressions=constraint.expressions,
|
||||
)
|
||||
super().remove_constraint(model, constraint)
|
||||
|
||||
def remove_index(self, model, index):
|
||||
self._create_missing_fk_index(
|
||||
model,
|
||||
fields=[field_name for field_name, _ in index.fields_orders],
|
||||
expressions=index.expressions,
|
||||
)
|
||||
super().remove_index(model, index)
|
||||
|
||||
def _field_should_be_indexed(self, model, field):
|
||||
if not super()._field_should_be_indexed(model, field):
|
||||
return False
|
||||
|
||||
storage = self.connection.introspection.get_storage_engine(
|
||||
self.connection.cursor(), model._meta.db_table
|
||||
)
|
||||
# No need to create an index for ForeignKey fields except if
|
||||
# db_constraint=False because the index from that constraint won't be
|
||||
# created.
|
||||
if (
|
||||
storage == "InnoDB"
|
||||
and field.get_internal_type() == "ForeignKey"
|
||||
and field.db_constraint
|
||||
):
|
||||
return False
|
||||
return not self._is_limited_data_type(field)
|
||||
|
||||
def _create_missing_fk_index(
|
||||
self,
|
||||
model,
|
||||
*,
|
||||
fields,
|
||||
expressions=None,
|
||||
):
|
||||
"""
|
||||
MySQL can remove an implicit FK index on a field when that field is
|
||||
covered by another index like a unique_together. "covered" here means
|
||||
that the more complex index has the FK field as its first field (see
|
||||
https://bugs.mysql.com/bug.php?id=37910).
|
||||
|
||||
Manually create an implicit FK index to make it possible to remove the
|
||||
composed index.
|
||||
"""
first_field_name = None
|
||||
if fields:
|
||||
first_field_name = fields[0]
|
||||
elif (
|
||||
expressions
|
||||
and self.connection.features.supports_expression_indexes
|
||||
and isinstance(expressions[0], F)
|
||||
and LOOKUP_SEP not in expressions[0].name
|
||||
):
|
||||
first_field_name = expressions[0].name
|
||||
|
||||
if not first_field_name:
|
||||
return
|
||||
|
||||
first_field = model._meta.get_field(first_field_name)
|
||||
if first_field.get_internal_type() == "ForeignKey":
|
||||
column = self.connection.introspection.identifier_converter(
|
||||
first_field.column
|
||||
)
|
||||
with self.connection.cursor() as cursor:
|
||||
constraint_names = [
|
||||
name
|
||||
for name, infodict in self.connection.introspection.get_constraints(
|
||||
cursor, model._meta.db_table
|
||||
).items()
|
||||
if infodict["index"] and infodict["columns"][0] == column
|
||||
]
|
||||
# There are no other indexes that start with the FK field, only
|
||||
# the index that is expected to be deleted.
|
||||
if len(constraint_names) == 1:
|
||||
self.execute(
|
||||
self._create_index_sql(model, fields=[first_field], suffix="")
|
||||
)
|
||||
|
||||
def _delete_composed_index(self, model, fields, *args):
|
||||
self._create_missing_fk_index(model, fields=fields)
|
||||
return super()._delete_composed_index(model, fields, *args)
|
||||
|
||||
def _set_field_new_type_null_status(self, field, new_type):
|
||||
"""
|
||||
Keep the null property of the old field. If it has changed, it will be
|
||||
handled separately.
|
||||
"""
|
||||
if field.null:
|
||||
new_type += " NULL"
|
||||
else:
|
||||
new_type += " NOT NULL"
|
||||
return new_type
|
||||
|
||||
def _alter_column_type_sql(
|
||||
self, model, old_field, new_field, new_type, old_collation, new_collation
|
||||
):
|
||||
new_type = self._set_field_new_type_null_status(old_field, new_type)
|
||||
return super()._alter_column_type_sql(
|
||||
model, old_field, new_field, new_type, old_collation, new_collation
|
||||
)
|
||||
|
||||
def _rename_field_sql(self, table, old_field, new_field, new_type):
|
||||
new_type = self._set_field_new_type_null_status(old_field, new_type)
|
||||
return super()._rename_field_sql(table, old_field, new_field, new_type)
|
||||
|
||||
def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment):
|
||||
# The comment is altered when altering the column type.
return "", []
|
||||
|
||||
def _comment_sql(self, comment):
|
||||
comment_sql = super()._comment_sql(comment)
|
||||
return f" COMMENT {comment_sql}"
|
||||
@@ -0,0 +1,77 @@
|
||||
from django.core import checks
|
||||
from django.db.backends.base.validation import BaseDatabaseValidation
|
||||
from django.utils.version import get_docs_version
|
||||
|
||||
|
||||
class DatabaseValidation(BaseDatabaseValidation):
|
||||
def check(self, **kwargs):
|
||||
issues = super().check(**kwargs)
|
||||
issues.extend(self._check_sql_mode(**kwargs))
|
||||
return issues
|
||||
|
||||
def _check_sql_mode(self, **kwargs):
|
||||
if not (
|
||||
self.connection.sql_mode & {"STRICT_TRANS_TABLES", "STRICT_ALL_TABLES"}
|
||||
):
|
||||
return [
|
||||
checks.Warning(
|
||||
"%s Strict Mode is not set for database connection '%s'"
|
||||
% (self.connection.display_name, self.connection.alias),
|
||||
hint=(
|
||||
"%s's Strict Mode fixes many data integrity problems in "
|
||||
"%s, such as data truncation upon insertion, by "
|
||||
"escalating warnings into errors. It is strongly "
|
||||
"recommended you activate it. See: "
|
||||
"https://docs.djangoproject.com/en/%s/ref/databases/"
|
||||
"#mysql-sql-mode"
|
||||
% (
|
||||
self.connection.display_name,
|
||||
self.connection.display_name,
|
||||
get_docs_version(),
|
||||
),
|
||||
),
|
||||
id="mysql.W002",
|
||||
)
|
||||
]
|
||||
return []
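# Illustrative sketch (not from the original source): strict mode is commonly
# enabled via the connection's OPTIONS in settings.DATABASES, e.g.
#   "OPTIONS": {"init_command": "SET sql_mode='STRICT_TRANS_TABLES'"}
# which silences the mysql.W002 warning on subsequent checks.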
|
||||
|
||||
def check_field_type(self, field, field_type):
|
||||
"""
|
||||
MySQL has the following field length restriction:
|
||||
No character (varchar) fields can have a length exceeding 255
|
||||
characters if they have a unique index on them.
|
||||
MySQL doesn't support a database index on some data types.
|
||||
"""
errors = []
|
||||
if (
|
||||
field_type.startswith("varchar")
|
||||
and field.unique
|
||||
and (field.max_length is None or int(field.max_length) > 255)
|
||||
):
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
"%s may not allow unique CharFields to have a max_length "
|
||||
"> 255." % self.connection.display_name,
|
||||
obj=field,
|
||||
hint=(
|
||||
"See: https://docs.djangoproject.com/en/%s/ref/"
|
||||
"databases/#mysql-character-fields" % get_docs_version()
|
||||
),
|
||||
id="mysql.W003",
|
||||
)
|
||||
)
|
||||
|
||||
if field.db_index and field_type.lower() in self.connection._limited_data_types:
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
"%s does not support a database index on %s columns."
|
||||
% (self.connection.display_name, field_type),
|
||||
hint=(
|
||||
"An index won't be created. Silence this warning if "
|
||||
"you don't care about it."
|
||||
),
|
||||
obj=field,
|
||||
id="fields.W162",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
@@ -0,0 +1,592 @@
|
||||
"""
|
||||
Oracle database backend for Django.
|
||||
|
||||
Requires cx_Oracle: https://oracle.github.io/python-cx_Oracle/
|
||||
"""
|
||||
import datetime
|
||||
import decimal
|
||||
import os
|
||||
import platform
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import IntegrityError
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.backends.utils import debug_transaction
|
||||
from django.utils.asyncio import async_unsafe
|
||||
from django.utils.encoding import force_bytes, force_str
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
def _setup_environment(environ):
|
||||
# Cygwin requires some special voodoo to set the environment variables
|
||||
# properly so that Oracle will see them.
|
||||
if platform.system().upper().startswith("CYGWIN"):
|
||||
try:
|
||||
import ctypes
|
||||
except ImportError as e:
|
||||
raise ImproperlyConfigured(
|
||||
"Error loading ctypes: %s; "
|
||||
"the Oracle backend requires ctypes to "
|
||||
"operate correctly under Cygwin." % e
|
||||
)
|
||||
kernel32 = ctypes.CDLL("kernel32")
|
||||
for name, value in environ:
|
||||
kernel32.SetEnvironmentVariableA(name, value)
|
||||
else:
|
||||
os.environ.update(environ)
|
||||
|
||||
|
||||
_setup_environment(
|
||||
[
|
||||
# Oracle takes client-side character set encoding from the environment.
|
||||
("NLS_LANG", ".AL32UTF8"),
|
||||
# This prevents Unicode from getting mangled by getting encoded into the
|
||||
# potentially non-Unicode database character set.
|
||||
("ORA_NCHAR_LITERAL_REPLACE", "TRUE"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import cx_Oracle as Database
|
||||
except ImportError as e:
|
||||
raise ImproperlyConfigured("Error loading cx_Oracle module: %s" % e)
|
||||
|
||||
# Some of these import cx_Oracle, so import them after checking if it's installed.
|
||||
from .client import DatabaseClient # NOQA
|
||||
from .creation import DatabaseCreation # NOQA
|
||||
from .features import DatabaseFeatures # NOQA
|
||||
from .introspection import DatabaseIntrospection # NOQA
|
||||
from .operations import DatabaseOperations # NOQA
|
||||
from .schema import DatabaseSchemaEditor # NOQA
|
||||
from .utils import Oracle_datetime, dsn # NOQA
|
||||
from .validation import DatabaseValidation # NOQA
|
||||
|
||||
|
||||
@contextmanager
|
||||
def wrap_oracle_errors():
|
||||
try:
|
||||
yield
|
||||
except Database.DatabaseError as e:
|
||||
# cx_Oracle raises a cx_Oracle.DatabaseError exception with the
|
||||
# following attributes and values:
|
||||
# code = 2091
|
||||
# message = 'ORA-02091: transaction rolled back
|
||||
# 'ORA-02291: integrity constraint (TEST_DJANGOTEST.SYS
|
||||
# _C00102056) violated - parent key not found'
|
||||
# or:
|
||||
# 'ORA-00001: unique constraint (DJANGOTEST.DEFERRABLE_
|
||||
# PINK_CONSTRAINT) violated
|
||||
# Convert that case to Django's IntegrityError exception.
|
||||
x = e.args[0]
|
||||
if (
|
||||
hasattr(x, "code")
|
||||
and hasattr(x, "message")
|
||||
and x.code == 2091
|
||||
and ("ORA-02291" in x.message or "ORA-00001" in x.message)
|
||||
):
|
||||
raise IntegrityError(*tuple(e.args))
|
||||
raise
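# Illustrative usage (not part of the original source):
#   with wrap_oracle_errors():
#       cursor.execute(insert_sql, params)
# A deferred constraint violation reported as ORA-02091 (with ORA-02291 or
# ORA-00001 in the message) then surfaces as django.db.IntegrityError rather
# than a generic DatabaseError.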
|
||||
|
||||
|
||||
class _UninitializedOperatorsDescriptor:
|
||||
def __get__(self, instance, cls=None):
|
||||
# If connection.operators is looked up before a connection has been
|
||||
# created, transparently initialize connection.operators to avert an
|
||||
# AttributeError.
|
||||
if instance is None:
|
||||
raise AttributeError("operators not available as class attribute")
|
||||
# Creating a cursor will initialize the operators.
|
||||
instance.cursor().close()
|
||||
return instance.__dict__["operators"]
|
||||
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
vendor = "oracle"
|
||||
display_name = "Oracle"
|
||||
# This dictionary maps Field objects to their associated Oracle column
|
||||
# types, as strings. Column-type strings can contain format strings; they'll
|
||||
# be interpolated against the values of Field.__dict__ before being output.
|
||||
# If a column type is set to None, it won't be included in the output.
|
||||
#
|
||||
# Any format strings starting with "qn_" are quoted before being used in the
# output (the "qn_" prefix is stripped before the lookup is performed).
|
||||
data_types = {
|
||||
"AutoField": "NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY",
|
||||
"BigAutoField": "NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY",
|
||||
"BinaryField": "BLOB",
|
||||
"BooleanField": "NUMBER(1)",
|
||||
"CharField": "NVARCHAR2(%(max_length)s)",
|
||||
"DateField": "DATE",
|
||||
"DateTimeField": "TIMESTAMP",
|
||||
"DecimalField": "NUMBER(%(max_digits)s, %(decimal_places)s)",
|
||||
"DurationField": "INTERVAL DAY(9) TO SECOND(6)",
|
||||
"FileField": "NVARCHAR2(%(max_length)s)",
|
||||
"FilePathField": "NVARCHAR2(%(max_length)s)",
|
||||
"FloatField": "DOUBLE PRECISION",
|
||||
"IntegerField": "NUMBER(11)",
|
||||
"JSONField": "NCLOB",
|
||||
"BigIntegerField": "NUMBER(19)",
|
||||
"IPAddressField": "VARCHAR2(15)",
|
||||
"GenericIPAddressField": "VARCHAR2(39)",
|
||||
"OneToOneField": "NUMBER(11)",
|
||||
"PositiveBigIntegerField": "NUMBER(19)",
|
||||
"PositiveIntegerField": "NUMBER(11)",
|
||||
"PositiveSmallIntegerField": "NUMBER(11)",
|
||||
"SlugField": "NVARCHAR2(%(max_length)s)",
|
||||
"SmallAutoField": "NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY",
|
||||
"SmallIntegerField": "NUMBER(11)",
|
||||
"TextField": "NCLOB",
|
||||
"TimeField": "TIMESTAMP",
|
||||
"URLField": "VARCHAR2(%(max_length)s)",
|
||||
"UUIDField": "VARCHAR2(32)",
|
||||
}
|
||||
data_type_check_constraints = {
|
||||
"BooleanField": "%(qn_column)s IN (0,1)",
|
||||
"JSONField": "%(qn_column)s IS JSON",
|
||||
"PositiveBigIntegerField": "%(qn_column)s >= 0",
|
||||
"PositiveIntegerField": "%(qn_column)s >= 0",
|
||||
"PositiveSmallIntegerField": "%(qn_column)s >= 0",
|
||||
}
|
||||
|
||||
# Oracle doesn't support a database index on these columns.
|
||||
_limited_data_types = ("clob", "nclob", "blob")
|
||||
|
||||
operators = _UninitializedOperatorsDescriptor()
|
||||
|
||||
_standard_operators = {
|
||||
"exact": "= %s",
|
||||
"iexact": "= UPPER(%s)",
|
||||
"contains": (
|
||||
"LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
|
||||
),
|
||||
"icontains": (
|
||||
"LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) "
|
||||
"ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
|
||||
),
|
||||
"gt": "> %s",
|
||||
"gte": ">= %s",
|
||||
"lt": "< %s",
|
||||
"lte": "<= %s",
|
||||
"startswith": (
|
||||
"LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
|
||||
),
|
||||
"endswith": (
|
||||
"LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
|
||||
),
|
||||
"istartswith": (
|
||||
"LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) "
|
||||
"ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
|
||||
),
|
||||
"iendswith": (
|
||||
"LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) "
|
||||
"ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
|
||||
),
|
||||
}
|
||||
|
||||
_likec_operators = {
|
||||
**_standard_operators,
|
||||
"contains": "LIKEC %s ESCAPE '\\'",
|
||||
"icontains": "LIKEC UPPER(%s) ESCAPE '\\'",
|
||||
"startswith": "LIKEC %s ESCAPE '\\'",
|
||||
"endswith": "LIKEC %s ESCAPE '\\'",
|
||||
"istartswith": "LIKEC UPPER(%s) ESCAPE '\\'",
|
||||
"iendswith": "LIKEC UPPER(%s) ESCAPE '\\'",
|
||||
}
|
||||
|
||||
# The patterns below are used to generate SQL pattern lookup clauses when
|
||||
# the right-hand side of the lookup isn't a raw string (it might be an expression
|
||||
# or the result of a bilateral transformation).
|
||||
# In those cases, special characters for LIKE operators (e.g. \, %, _)
|
||||
# should be escaped on the database side.
|
||||
#
|
||||
# Note: we use str.format() here for readability as '%' is used as a wildcard for
|
||||
# the LIKE operator.
|
||||
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
|
||||
_pattern_ops = {
|
||||
"contains": "'%%' || {} || '%%'",
|
||||
"icontains": "'%%' || UPPER({}) || '%%'",
|
||||
"startswith": "{} || '%%'",
|
||||
"istartswith": "UPPER({}) || '%%'",
|
||||
"endswith": "'%%' || {}",
|
||||
"iendswith": "'%%' || UPPER({})",
|
||||
}
|
||||
|
||||
_standard_pattern_ops = {
|
||||
k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)"
|
||||
" ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
|
||||
for k, v in _pattern_ops.items()
|
||||
}
|
||||
_likec_pattern_ops = {
|
||||
k: "LIKEC " + v + " ESCAPE '\\'" for k, v in _pattern_ops.items()
|
||||
}
|
||||
|
||||
Database = Database
|
||||
SchemaEditorClass = DatabaseSchemaEditor
|
||||
# Classes instantiated in __init__().
|
||||
client_class = DatabaseClient
|
||||
creation_class = DatabaseCreation
|
||||
features_class = DatabaseFeatures
|
||||
introspection_class = DatabaseIntrospection
|
||||
ops_class = DatabaseOperations
|
||||
validation_class = DatabaseValidation
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
use_returning_into = self.settings_dict["OPTIONS"].get(
|
||||
"use_returning_into", True
|
||||
)
|
||||
self.features.can_return_columns_from_insert = use_returning_into
|
||||
|
||||
def get_database_version(self):
|
||||
return self.oracle_version
|
||||
|
||||
def get_connection_params(self):
|
||||
conn_params = self.settings_dict["OPTIONS"].copy()
|
||||
if "use_returning_into" in conn_params:
|
||||
del conn_params["use_returning_into"]
|
||||
return conn_params
|
||||
|
||||
@async_unsafe
|
||||
def get_new_connection(self, conn_params):
|
||||
return Database.connect(
|
||||
user=self.settings_dict["USER"],
|
||||
password=self.settings_dict["PASSWORD"],
|
||||
dsn=dsn(self.settings_dict),
|
||||
**conn_params,
|
||||
)
|
||||
|
||||
def init_connection_state(self):
|
||||
super().init_connection_state()
|
||||
cursor = self.create_cursor()
|
||||
# Set the territory first. The territory overrides NLS_DATE_FORMAT
|
||||
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of
|
||||
# these are set in a single statement, it isn't clear what is supposed
|
||||
# to happen.
|
||||
cursor.execute("ALTER SESSION SET NLS_TERRITORY = 'AMERICA'")
|
||||
# Set Oracle date to ANSI date format. This only needs to execute
|
||||
# once when we create a new connection. We also set the Territory
|
||||
# to 'AMERICA' which forces Sunday to evaluate to a '1' in
|
||||
# TO_CHAR().
|
||||
cursor.execute(
|
||||
"ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'"
|
||||
" NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'"
|
||||
+ (" TIME_ZONE = 'UTC'" if settings.USE_TZ else "")
|
||||
)
|
||||
cursor.close()
|
||||
if "operators" not in self.__dict__:
|
||||
# Ticket #14149: Check whether our LIKE implementation will
|
||||
# work for this connection or whether we need to fall back on LIKEC.
|
||||
# This check is performed only once per DatabaseWrapper
|
||||
# instance per thread, since subsequent connections will use
|
||||
# the same settings.
|
||||
cursor = self.create_cursor()
|
||||
try:
|
||||
cursor.execute(
|
||||
"SELECT 1 FROM DUAL WHERE DUMMY %s"
|
||||
% self._standard_operators["contains"],
|
||||
["X"],
|
||||
)
|
||||
except Database.DatabaseError:
|
||||
self.operators = self._likec_operators
|
||||
self.pattern_ops = self._likec_pattern_ops
|
||||
else:
|
||||
self.operators = self._standard_operators
|
||||
self.pattern_ops = self._standard_pattern_ops
|
||||
cursor.close()
|
||||
self.connection.stmtcachesize = 20
|
||||
# Ensure all changes are preserved even when AUTOCOMMIT is False.
|
||||
if not self.get_autocommit():
|
||||
self.commit()
|
||||
|
||||
@async_unsafe
|
||||
def create_cursor(self, name=None):
|
||||
return FormatStylePlaceholderCursor(self.connection)
|
||||
|
||||
def _commit(self):
|
||||
if self.connection is not None:
|
||||
with debug_transaction(self, "COMMIT"), wrap_oracle_errors():
|
||||
return self.connection.commit()
|
||||
|
||||
# Oracle doesn't support releasing savepoints. But we fake them when query
|
||||
# logging is enabled to keep query counts consistent with other backends.
|
||||
def _savepoint_commit(self, sid):
|
||||
if self.queries_logged:
|
||||
self.queries_log.append(
|
||||
{
|
||||
"sql": "-- RELEASE SAVEPOINT %s (faked)" % self.ops.quote_name(sid),
|
||||
"time": "0.000",
|
||||
}
|
||||
)
|
||||
|
||||
def _set_autocommit(self, autocommit):
|
||||
with self.wrap_database_errors:
|
||||
self.connection.autocommit = autocommit
|
||||
|
||||
def check_constraints(self, table_names=None):
|
||||
"""
|
||||
Check constraints by setting them to immediate. Return them to deferred
|
||||
afterward.
|
||||
"""
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
|
||||
cursor.execute("SET CONSTRAINTS ALL DEFERRED")
|
||||
|
||||
def is_usable(self):
|
||||
try:
|
||||
self.connection.ping()
|
||||
except Database.Error:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@cached_property
|
||||
def cx_oracle_version(self):
|
||||
return tuple(int(x) for x in Database.version.split("."))
|
||||
|
||||
@cached_property
|
||||
def oracle_version(self):
|
||||
with self.temporary_connection():
|
||||
return tuple(int(x) for x in self.connection.version.split("."))
|
||||
|
||||
|
||||
class OracleParam:
|
||||
"""
|
||||
Wrapper object for formatting parameters for Oracle. If the string
|
||||
representation of the value is large enough (greater than 4000 characters)
|
||||
the input size needs to be set as CLOB. Alternatively, if the parameter
|
||||
has an `input_size` attribute, then the value of the `input_size` attribute
|
||||
will be used instead. Otherwise, no input size will be set for the
|
||||
parameter when executing the query.
|
||||
"""
|
||||
|
||||
def __init__(self, param, cursor, strings_only=False):
|
||||
# With raw SQL queries, datetimes can reach this function
|
||||
# without being converted by DateTimeField.get_db_prep_value.
|
||||
if settings.USE_TZ and (
|
||||
isinstance(param, datetime.datetime)
|
||||
and not isinstance(param, Oracle_datetime)
|
||||
):
|
||||
param = Oracle_datetime.from_datetime(param)
|
||||
|
||||
string_size = 0
|
||||
# Oracle doesn't recognize True and False correctly.
|
||||
if param is True:
|
||||
param = 1
|
||||
elif param is False:
|
||||
param = 0
|
||||
if hasattr(param, "bind_parameter"):
|
||||
self.force_bytes = param.bind_parameter(cursor)
|
||||
elif isinstance(param, (Database.Binary, datetime.timedelta)):
|
||||
self.force_bytes = param
|
||||
else:
|
||||
# To transmit to the database, we need Unicode if supported
|
||||
# To get size right, we must consider bytes.
|
||||
self.force_bytes = force_str(param, cursor.charset, strings_only)
|
||||
if isinstance(self.force_bytes, str):
|
||||
# We could optimize by only converting up to 4000 bytes here
|
||||
string_size = len(force_bytes(param, cursor.charset, strings_only))
|
||||
if hasattr(param, "input_size"):
|
||||
# If parameter has `input_size` attribute, use that.
|
||||
self.input_size = param.input_size
|
||||
elif string_size > 4000:
|
||||
# Mark any string param greater than 4000 characters as a CLOB.
|
||||
self.input_size = Database.CLOB
|
||||
elif isinstance(param, datetime.datetime):
|
||||
self.input_size = Database.TIMESTAMP
|
||||
else:
|
||||
self.input_size = None
|
||||
|
||||
|
||||
class VariableWrapper:
|
||||
"""
|
||||
An adapter class for cursor variables that prevents the wrapped object
|
||||
from being converted into a string when used to instantiate an OracleParam.
|
||||
This can be used generally for any other object that should be passed into
|
||||
Cursor.execute as-is.
|
||||
"""
|
||||
|
||||
def __init__(self, var):
|
||||
self.var = var
|
||||
|
||||
def bind_parameter(self, cursor):
|
||||
return self.var
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.var, key)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key == "var":
|
||||
self.__dict__[key] = value
|
||||
else:
|
||||
setattr(self.var, key, value)
|
||||
|
||||
|
||||
class FormatStylePlaceholderCursor:
|
||||
"""
|
||||
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
|
||||
style. This fixes it -- but note that if you want to use a literal "%s" in
|
||||
a query, you'll need to use "%%s".
|
||||
"""
charset = "utf-8"
|
||||
|
||||
def __init__(self, connection):
|
||||
self.cursor = connection.cursor()
|
||||
self.cursor.outputtypehandler = self._output_type_handler
|
||||
|
||||
@staticmethod
|
||||
def _output_number_converter(value):
|
||||
return decimal.Decimal(value) if "." in value else int(value)
|
||||
|
||||
@staticmethod
|
||||
def _get_decimal_converter(precision, scale):
|
||||
if scale == 0:
|
||||
return int
|
||||
context = decimal.Context(prec=precision)
|
||||
quantize_value = decimal.Decimal(1).scaleb(-scale)
|
||||
return lambda v: decimal.Decimal(v).quantize(quantize_value, context=context)
|
||||
|
||||
@staticmethod
|
||||
def _output_type_handler(cursor, name, defaultType, length, precision, scale):
|
||||
"""
|
||||
Called for each db column fetched from cursors. Return numbers as the
|
||||
appropriate Python type.
|
||||
"""
|
||||
if defaultType == Database.NUMBER:
|
||||
if scale == -127:
|
||||
if precision == 0:
|
||||
# NUMBER column: decimal-precision floating point.
|
||||
# This will normally be an integer from a sequence,
|
||||
# but it could be a decimal value.
|
||||
outconverter = FormatStylePlaceholderCursor._output_number_converter
|
||||
else:
|
||||
# FLOAT column: binary-precision floating point.
|
||||
# This comes from FloatField columns.
|
||||
outconverter = float
|
||||
elif precision > 0:
|
||||
# NUMBER(p,s) column: decimal-precision fixed point.
|
||||
# This comes from IntegerField and DecimalField columns.
|
||||
outconverter = FormatStylePlaceholderCursor._get_decimal_converter(
|
||||
precision, scale
|
||||
)
|
||||
else:
|
||||
# No type information. This normally comes from a
|
||||
# mathematical expression in the SELECT list. Guess int
|
||||
# or Decimal based on whether it has a decimal point.
|
||||
outconverter = FormatStylePlaceholderCursor._output_number_converter
|
||||
return cursor.var(
|
||||
Database.STRING,
|
||||
size=255,
|
||||
arraysize=cursor.arraysize,
|
||||
outconverter=outconverter,
|
||||
)
|
||||
|
||||
def _format_params(self, params):
|
||||
try:
|
||||
return {k: OracleParam(v, self, True) for k, v in params.items()}
|
||||
except AttributeError:
|
||||
return tuple(OracleParam(p, self, True) for p in params)
|
||||
|
||||
def _guess_input_sizes(self, params_list):
|
||||
# Try dict handling; if that fails, treat as sequence
|
||||
if hasattr(params_list[0], "keys"):
|
||||
sizes = {}
|
||||
for params in params_list:
|
||||
for k, value in params.items():
|
||||
if value.input_size:
|
||||
sizes[k] = value.input_size
|
||||
if sizes:
|
||||
self.setinputsizes(**sizes)
|
||||
else:
|
||||
# It's not a list of dicts; it's a list of sequences
|
||||
sizes = [None] * len(params_list[0])
|
||||
for params in params_list:
|
||||
for i, value in enumerate(params):
|
||||
if value.input_size:
|
||||
sizes[i] = value.input_size
|
||||
if sizes:
|
||||
self.setinputsizes(*sizes)
|
||||
|
||||
def _param_generator(self, params):
|
||||
# Try dict handling; if that fails, treat as sequence
|
||||
if hasattr(params, "items"):
|
||||
return {k: v.force_bytes for k, v in params.items()}
|
||||
else:
|
||||
return [p.force_bytes for p in params]
|
||||
|
||||
def _fix_for_params(self, query, params, unify_by_values=False):
|
||||
# cx_Oracle wants no trailing ';' for SQL statements. For PL/SQL, it
# does want a trailing ';' but not a trailing '/'. However, these
|
||||
# characters must be included in the original query in case the query
|
||||
# is being passed to SQL*Plus.
|
||||
if query.endswith(";") or query.endswith("/"):
|
||||
query = query[:-1]
|
||||
if params is None:
|
||||
params = []
|
||||
elif hasattr(params, "keys"):
|
||||
# Handle params as dict
|
||||
args = {k: ":%s" % k for k in params}
|
||||
query %= args
|
||||
elif unify_by_values and params:
|
||||
# Handle params as a dict with unified query parameters by their
|
||||
# values. It can be used only in single query execute() because
|
||||
# executemany() shares the formatted query with each of the params
|
||||
# list. e.g. for input params = [0.75, 2, 0.75, 'sth', 0.75]
|
||||
# params_dict = {0.75: ':arg0', 2: ':arg1', 'sth': ':arg2'}
|
||||
# args = [':arg0', ':arg1', ':arg0', ':arg2', ':arg0']
|
||||
# params = {':arg0': 0.75, ':arg1': 2, ':arg2': 'sth'}
|
||||
params_dict = {
|
||||
param: ":arg%d" % i for i, param in enumerate(dict.fromkeys(params))
|
||||
}
|
||||
args = [params_dict[param] for param in params]
|
||||
params = {value: key for key, value in params_dict.items()}
|
||||
query %= tuple(args)
|
||||
else:
|
||||
# Handle params as sequence
|
||||
args = [(":arg%d" % i) for i in range(len(params))]
|
||||
query %= tuple(args)
|
||||
return query, self._format_params(params)
|
||||
|
||||
def execute(self, query, params=None):
|
||||
query, params = self._fix_for_params(query, params, unify_by_values=True)
|
||||
self._guess_input_sizes([params])
|
||||
with wrap_oracle_errors():
|
||||
return self.cursor.execute(query, self._param_generator(params))
|
||||
|
||||
def executemany(self, query, params=None):
|
||||
if not params:
|
||||
# No params given, nothing to do
|
||||
return None
|
||||
# uniform treatment for sequences and iterables
|
||||
params_iter = iter(params)
|
||||
query, firstparams = self._fix_for_params(query, next(params_iter))
|
||||
# we build a list of formatted params; as we're going to traverse it
|
||||
# more than once, we can't make it lazy by using a generator
|
||||
formatted = [firstparams] + [self._format_params(p) for p in params_iter]
|
||||
self._guess_input_sizes(formatted)
|
||||
with wrap_oracle_errors():
|
||||
return self.cursor.executemany(
|
||||
query, [self._param_generator(p) for p in formatted]
|
||||
)
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self.cursor.close()
|
||||
except Database.InterfaceError:
|
||||
# already closed
|
||||
pass
|
||||
|
||||
def var(self, *args):
|
||||
return VariableWrapper(self.cursor.var(*args))
|
||||
|
||||
def arrayvar(self, *args):
|
||||
return VariableWrapper(self.cursor.arrayvar(*args))
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.cursor, attr)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.cursor)
|
||||
@@ -0,0 +1,27 @@
|
||||
import shutil
|
||||
|
||||
from django.db.backends.base.client import BaseDatabaseClient
|
||||
|
||||
|
||||
class DatabaseClient(BaseDatabaseClient):
|
||||
executable_name = "sqlplus"
|
||||
wrapper_name = "rlwrap"
|
||||
|
||||
@staticmethod
|
||||
def connect_string(settings_dict):
|
||||
from django.db.backends.oracle.utils import dsn
|
||||
|
||||
return '%s/"%s"@%s' % (
|
||||
settings_dict["USER"],
|
||||
settings_dict["PASSWORD"],
|
||||
dsn(settings_dict),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def settings_to_cmd_args_env(cls, settings_dict, parameters):
|
||||
args = [cls.executable_name, "-L", cls.connect_string(settings_dict)]
|
||||
wrapper_path = shutil.which(cls.wrapper_name)
|
||||
if wrapper_path:
|
||||
args = [wrapper_path, *args]
|
||||
args.extend(parameters)
|
||||
return args, None
|
||||
@@ -0,0 +1,464 @@
|
||||
import sys
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import DatabaseError
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
from django.utils.crypto import get_random_string
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
TEST_DATABASE_PREFIX = "test_"
|
||||
|
||||
|
||||
class DatabaseCreation(BaseDatabaseCreation):
|
||||
@cached_property
|
||||
def _maindb_connection(self):
|
||||
"""
|
||||
This is analogous to other backends' `_nodb_connection` property,
|
||||
which allows access to an "administrative" connection which can
|
||||
be used to manage the test databases.
|
||||
For Oracle, the only connection that can be used for that purpose
|
||||
is the main (non-test) connection.
|
||||
"""
|
||||
settings_dict = settings.DATABASES[self.connection.alias]
|
||||
user = settings_dict.get("SAVED_USER") or settings_dict["USER"]
|
||||
password = settings_dict.get("SAVED_PASSWORD") or settings_dict["PASSWORD"]
|
||||
settings_dict = {**settings_dict, "USER": user, "PASSWORD": password}
|
||||
DatabaseWrapper = type(self.connection)
|
||||
return DatabaseWrapper(settings_dict, alias=self.connection.alias)
|
||||
|
||||
def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
|
||||
parameters = self._get_test_db_params()
|
||||
with self._maindb_connection.cursor() as cursor:
|
||||
if self._test_database_create():
|
||||
try:
|
||||
self._execute_test_db_creation(
|
||||
cursor, parameters, verbosity, keepdb
|
||||
)
|
||||
except Exception as e:
|
||||
if "ORA-01543" not in str(e):
|
||||
# All errors except "tablespace already exists" cancel tests
|
||||
self.log("Got an error creating the test database: %s" % e)
|
||||
sys.exit(2)
|
||||
if not autoclobber:
|
||||
confirm = input(
|
||||
"It appears the test database, %s, already exists. "
|
||||
"Type 'yes' to delete it, or 'no' to cancel: "
|
||||
% parameters["user"]
|
||||
)
|
||||
if autoclobber or confirm == "yes":
|
||||
if verbosity >= 1:
|
||||
self.log(
|
||||
"Destroying old test database for alias '%s'..."
|
||||
% self.connection.alias
|
||||
)
|
||||
try:
|
||||
self._execute_test_db_destruction(
|
||||
cursor, parameters, verbosity
|
||||
)
|
||||
except DatabaseError as e:
|
||||
if "ORA-29857" in str(e):
|
||||
self._handle_objects_preventing_db_destruction(
|
||||
cursor, parameters, verbosity, autoclobber
|
||||
)
|
||||
else:
|
||||
# Ran into a database error that isn't about
|
||||
# leftover objects in the tablespace.
|
||||
self.log(
|
||||
"Got an error destroying the old test database: %s"
|
||||
% e
|
||||
)
|
||||
sys.exit(2)
|
||||
except Exception as e:
|
||||
self.log(
|
||||
"Got an error destroying the old test database: %s" % e
|
||||
)
|
||||
sys.exit(2)
|
||||
try:
|
||||
self._execute_test_db_creation(
|
||||
cursor, parameters, verbosity, keepdb
|
||||
)
|
||||
except Exception as e:
|
||||
self.log(
|
||||
"Got an error recreating the test database: %s" % e
|
||||
)
|
||||
sys.exit(2)
|
||||
else:
|
||||
self.log("Tests cancelled.")
|
||||
sys.exit(1)
|
||||
|
||||
if self._test_user_create():
|
||||
if verbosity >= 1:
|
||||
self.log("Creating test user...")
|
||||
try:
|
||||
self._create_test_user(cursor, parameters, verbosity, keepdb)
|
||||
except Exception as e:
|
||||
if "ORA-01920" not in str(e):
|
||||
# All errors except "user already exists" cancel tests
|
||||
self.log("Got an error creating the test user: %s" % e)
|
||||
sys.exit(2)
|
||||
if not autoclobber:
|
||||
confirm = input(
|
||||
"It appears the test user, %s, already exists. Type "
|
||||
"'yes' to delete it, or 'no' to cancel: "
|
||||
% parameters["user"]
|
||||
)
|
||||
if autoclobber or confirm == "yes":
|
||||
try:
|
||||
if verbosity >= 1:
|
||||
self.log("Destroying old test user...")
|
||||
self._destroy_test_user(cursor, parameters, verbosity)
|
||||
if verbosity >= 1:
|
||||
self.log("Creating test user...")
|
||||
self._create_test_user(
|
||||
cursor, parameters, verbosity, keepdb
|
||||
)
|
||||
except Exception as e:
|
||||
self.log("Got an error recreating the test user: %s" % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
self.log("Tests cancelled.")
|
||||
sys.exit(1)
|
||||
# Done with main user -- test user and tablespaces created.
|
||||
self._maindb_connection.close()
|
||||
self._switch_to_test_user(parameters)
|
||||
return self.connection.settings_dict["NAME"]
|
||||
|
||||
def _switch_to_test_user(self, parameters):
|
||||
"""
|
||||
Switch to the user that's used for creating the test database.
|
||||
|
||||
Oracle doesn't have the concept of separate databases under the same
|
||||
user, so a separate user is used; see _create_test_db(). The main user
|
||||
is also needed for cleanup when testing is completed, so save its
|
||||
credentials in the SAVED_USER/SAVED_PASSWORD key in the settings dict.
|
||||
"""
|
||||
real_settings = settings.DATABASES[self.connection.alias]
|
||||
real_settings["SAVED_USER"] = self.connection.settings_dict[
|
||||
"SAVED_USER"
|
||||
] = self.connection.settings_dict["USER"]
|
||||
real_settings["SAVED_PASSWORD"] = self.connection.settings_dict[
|
||||
"SAVED_PASSWORD"
|
||||
] = self.connection.settings_dict["PASSWORD"]
|
||||
real_test_settings = real_settings["TEST"]
|
||||
test_settings = self.connection.settings_dict["TEST"]
|
||||
real_test_settings["USER"] = real_settings["USER"] = test_settings[
|
||||
"USER"
|
||||
] = self.connection.settings_dict["USER"] = parameters["user"]
|
||||
real_settings["PASSWORD"] = self.connection.settings_dict[
|
||||
"PASSWORD"
|
||||
] = parameters["password"]
|
||||
|
||||
def set_as_test_mirror(self, primary_settings_dict):
|
||||
"""
|
||||
Set this database up to be used in testing as a mirror of a primary
|
||||
database whose settings are given.
|
||||
"""
|
||||
self.connection.settings_dict["USER"] = primary_settings_dict["USER"]
|
||||
self.connection.settings_dict["PASSWORD"] = primary_settings_dict["PASSWORD"]
|
||||
|
||||
def _handle_objects_preventing_db_destruction(
|
||||
self, cursor, parameters, verbosity, autoclobber
|
||||
):
|
||||
# There are objects in the test tablespace which prevent dropping it
|
||||
# The easy fix is to drop the test user -- but are we allowed to do so?
|
||||
self.log(
|
||||
"There are objects in the old test database which prevent its destruction."
|
||||
"\nIf they belong to the test user, deleting the user will allow the test "
|
||||
"database to be recreated.\n"
|
||||
"Otherwise, you will need to find and remove each of these objects, "
|
||||
"or use a different tablespace.\n"
|
||||
)
|
||||
if self._test_user_create():
|
||||
if not autoclobber:
|
||||
confirm = input("Type 'yes' to delete user %s: " % parameters["user"])
|
||||
if autoclobber or confirm == "yes":
|
||||
try:
|
||||
if verbosity >= 1:
|
||||
self.log("Destroying old test user...")
|
||||
self._destroy_test_user(cursor, parameters, verbosity)
|
||||
except Exception as e:
|
||||
self.log("Got an error destroying the test user: %s" % e)
|
||||
sys.exit(2)
|
||||
try:
|
||||
if verbosity >= 1:
|
||||
self.log(
|
||||
"Destroying old test database for alias '%s'..."
|
||||
% self.connection.alias
|
||||
)
|
||||
self._execute_test_db_destruction(cursor, parameters, verbosity)
|
||||
except Exception as e:
|
||||
self.log("Got an error destroying the test database: %s" % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
self.log("Tests cancelled -- test database cannot be recreated.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
self.log(
|
||||
"Django is configured to use pre-existing test user '%s',"
|
||||
" and will not attempt to delete it." % parameters["user"]
|
||||
)
|
||||
self.log("Tests cancelled -- test database cannot be recreated.")
|
||||
sys.exit(1)
|
||||
|
||||
def _destroy_test_db(self, test_database_name, verbosity=1):
|
||||
"""
|
||||
Destroy the test user and the test tablespaces, reconnecting first with
the saved credentials of the main (non-test) user.
|
||||
"""
|
||||
self.connection.settings_dict["USER"] = self.connection.settings_dict[
|
||||
"SAVED_USER"
|
||||
]
|
||||
self.connection.settings_dict["PASSWORD"] = self.connection.settings_dict[
|
||||
"SAVED_PASSWORD"
|
||||
]
|
||||
self.connection.close()
|
||||
parameters = self._get_test_db_params()
|
||||
with self._maindb_connection.cursor() as cursor:
|
||||
if self._test_user_create():
|
||||
if verbosity >= 1:
|
||||
self.log("Destroying test user...")
|
||||
self._destroy_test_user(cursor, parameters, verbosity)
|
||||
if self._test_database_create():
|
||||
if verbosity >= 1:
|
||||
self.log("Destroying test database tables...")
|
||||
self._execute_test_db_destruction(cursor, parameters, verbosity)
|
||||
self._maindb_connection.close()
|
||||
|
||||
def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):
|
||||
if verbosity >= 2:
|
||||
self.log("_create_test_db(): dbname = %s" % parameters["user"])
|
||||
if self._test_database_oracle_managed_files():
|
||||
statements = [
|
||||
"""
|
||||
CREATE TABLESPACE %(tblspace)s
|
||||
DATAFILE SIZE %(size)s
|
||||
AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s
|
||||
""",
|
||||
"""
|
||||
CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
|
||||
TEMPFILE SIZE %(size_tmp)s
|
||||
AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s
|
||||
""",
|
||||
]
|
||||
else:
|
||||
statements = [
|
||||
"""
|
||||
CREATE TABLESPACE %(tblspace)s
|
||||
DATAFILE '%(datafile)s' SIZE %(size)s REUSE
|
||||
AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s
|
||||
""",
|
||||
"""
|
||||
CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
|
||||
TEMPFILE '%(datafile_tmp)s' SIZE %(size_tmp)s REUSE
|
||||
AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s
|
||||
""",
|
||||
]
|
||||
# Ignore "tablespace already exists" error when keepdb is on.
|
||||
acceptable_ora_err = "ORA-01543" if keepdb else None
|
||||
self._execute_allow_fail_statements(
|
||||
cursor, statements, parameters, verbosity, acceptable_ora_err
|
||||
)
|
||||
|
||||
def _create_test_user(self, cursor, parameters, verbosity, keepdb=False):
|
||||
if verbosity >= 2:
|
||||
self.log("_create_test_user(): username = %s" % parameters["user"])
|
||||
statements = [
|
||||
"""CREATE USER %(user)s
|
||||
IDENTIFIED BY "%(password)s"
|
||||
DEFAULT TABLESPACE %(tblspace)s
|
||||
TEMPORARY TABLESPACE %(tblspace_temp)s
|
||||
QUOTA UNLIMITED ON %(tblspace)s
|
||||
""",
|
||||
"""GRANT CREATE SESSION,
|
||||
CREATE TABLE,
|
||||
CREATE SEQUENCE,
|
||||
CREATE PROCEDURE,
|
||||
CREATE TRIGGER
|
||||
TO %(user)s""",
|
||||
]
|
||||
# Ignore "user already exists" error when keepdb is on
|
||||
acceptable_ora_err = "ORA-01920" if keepdb else None
|
||||
success = self._execute_allow_fail_statements(
|
||||
cursor, statements, parameters, verbosity, acceptable_ora_err
|
||||
)
|
||||
# If the password was randomly generated, change the user accordingly.
|
||||
if not success and self._test_settings_get("PASSWORD") is None:
|
||||
set_password = 'ALTER USER %(user)s IDENTIFIED BY "%(password)s"'
|
||||
self._execute_statements(cursor, [set_password], parameters, verbosity)
|
||||
# Most test suites can be run without "create view" and
|
||||
# "create materialized view" privileges. But some need it.
|
||||
for object_type in ("VIEW", "MATERIALIZED VIEW"):
|
||||
extra = "GRANT CREATE %(object_type)s TO %(user)s"
|
||||
parameters["object_type"] = object_type
|
||||
success = self._execute_allow_fail_statements(
|
||||
cursor, [extra], parameters, verbosity, "ORA-01031"
|
||||
)
|
||||
if not success and verbosity >= 2:
|
||||
self.log(
|
||||
"Failed to grant CREATE %s permission to test user. This may be ok."
|
||||
% object_type
|
||||
)
|
||||
|
||||
def _execute_test_db_destruction(self, cursor, parameters, verbosity):
|
||||
if verbosity >= 2:
|
||||
self.log("_execute_test_db_destruction(): dbname=%s" % parameters["user"])
|
||||
statements = [
|
||||
"DROP TABLESPACE %(tblspace)s "
|
||||
"INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS",
|
||||
"DROP TABLESPACE %(tblspace_temp)s "
|
||||
"INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS",
|
||||
]
|
||||
self._execute_statements(cursor, statements, parameters, verbosity)
|
||||
|
||||
def _destroy_test_user(self, cursor, parameters, verbosity):
|
||||
if verbosity >= 2:
|
||||
self.log("_destroy_test_user(): user=%s" % parameters["user"])
|
||||
self.log("Be patient. This can take some time...")
|
||||
statements = [
|
||||
"DROP USER %(user)s CASCADE",
|
||||
]
|
||||
self._execute_statements(cursor, statements, parameters, verbosity)
|
||||
|
||||
def _execute_statements(
|
||||
self, cursor, statements, parameters, verbosity, allow_quiet_fail=False
|
||||
):
|
||||
for template in statements:
|
||||
stmt = template % parameters
|
||||
if verbosity >= 2:
|
||||
print(stmt)
|
||||
try:
|
||||
cursor.execute(stmt)
|
||||
except Exception as err:
|
||||
if (not allow_quiet_fail) or verbosity >= 2:
|
||||
self.log("Failed (%s)" % (err))
|
||||
raise
|
||||
|
||||
def _execute_allow_fail_statements(
|
||||
self, cursor, statements, parameters, verbosity, acceptable_ora_err
|
||||
):
|
||||
"""
|
||||
Execute statements which are allowed to fail silently if the Oracle
|
||||
error code given by `acceptable_ora_err` is raised. Return True if the
|
||||
statements execute without an exception, or False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Statement can fail when acceptable_ora_err is not None
|
||||
allow_quiet_fail = (
|
||||
acceptable_ora_err is not None and len(acceptable_ora_err) > 0
|
||||
)
|
||||
self._execute_statements(
|
||||
cursor,
|
||||
statements,
|
||||
parameters,
|
||||
verbosity,
|
||||
allow_quiet_fail=allow_quiet_fail,
|
||||
)
|
||||
return True
|
||||
except DatabaseError as err:
|
||||
description = str(err)
|
||||
if acceptable_ora_err is None or acceptable_ora_err not in description:
|
||||
raise
|
||||
return False
|
||||
|
||||
def _get_test_db_params(self):
|
||||
return {
|
||||
"dbname": self._test_database_name(),
|
||||
"user": self._test_database_user(),
|
||||
"password": self._test_database_passwd(),
|
||||
"tblspace": self._test_database_tblspace(),
|
||||
"tblspace_temp": self._test_database_tblspace_tmp(),
|
||||
"datafile": self._test_database_tblspace_datafile(),
|
||||
"datafile_tmp": self._test_database_tblspace_tmp_datafile(),
|
||||
"maxsize": self._test_database_tblspace_maxsize(),
|
||||
"maxsize_tmp": self._test_database_tblspace_tmp_maxsize(),
|
||||
"size": self._test_database_tblspace_size(),
|
||||
"size_tmp": self._test_database_tblspace_tmp_size(),
|
||||
"extsize": self._test_database_tblspace_extsize(),
|
||||
"extsize_tmp": self._test_database_tblspace_tmp_extsize(),
|
||||
}
|
||||
|
||||
def _test_settings_get(self, key, default=None, prefixed=None):
|
||||
"""
|
||||
Return a value from the test settings dict, or a given default, or a
|
||||
prefixed entry from the main settings dict.
|
||||
"""
|
||||
settings_dict = self.connection.settings_dict
|
||||
val = settings_dict["TEST"].get(key, default)
|
||||
if val is None and prefixed:
|
||||
val = TEST_DATABASE_PREFIX + settings_dict[prefixed]
|
||||
return val
|
||||
|
||||
def _test_database_name(self):
|
||||
return self._test_settings_get("NAME", prefixed="NAME")
|
||||
|
||||
def _test_database_create(self):
|
||||
return self._test_settings_get("CREATE_DB", default=True)
|
||||
|
||||
def _test_user_create(self):
|
||||
return self._test_settings_get("CREATE_USER", default=True)
|
||||
|
||||
def _test_database_user(self):
|
||||
return self._test_settings_get("USER", prefixed="USER")
|
||||
|
||||
def _test_database_passwd(self):
|
||||
password = self._test_settings_get("PASSWORD")
|
||||
if password is None and self._test_user_create():
|
||||
# Oracle passwords are limited to 30 chars and can't contain symbols.
|
||||
password = get_random_string(30)
|
||||
return password
|
||||
|
||||
def _test_database_tblspace(self):
|
||||
return self._test_settings_get("TBLSPACE", prefixed="USER")
|
||||
|
||||
def _test_database_tblspace_tmp(self):
|
||||
settings_dict = self.connection.settings_dict
|
||||
return settings_dict["TEST"].get(
|
||||
"TBLSPACE_TMP", TEST_DATABASE_PREFIX + settings_dict["USER"] + "_temp"
|
||||
)
|
||||
|
||||
def _test_database_tblspace_datafile(self):
|
||||
tblspace = "%s.dbf" % self._test_database_tblspace()
|
||||
return self._test_settings_get("DATAFILE", default=tblspace)
|
||||
|
||||
def _test_database_tblspace_tmp_datafile(self):
|
||||
tblspace = "%s.dbf" % self._test_database_tblspace_tmp()
|
||||
return self._test_settings_get("DATAFILE_TMP", default=tblspace)
|
||||
|
||||
def _test_database_tblspace_maxsize(self):
|
||||
return self._test_settings_get("DATAFILE_MAXSIZE", default="500M")
|
||||
|
||||
def _test_database_tblspace_tmp_maxsize(self):
|
||||
return self._test_settings_get("DATAFILE_TMP_MAXSIZE", default="500M")
|
||||
|
||||
def _test_database_tblspace_size(self):
|
||||
return self._test_settings_get("DATAFILE_SIZE", default="50M")
|
||||
|
||||
def _test_database_tblspace_tmp_size(self):
|
||||
return self._test_settings_get("DATAFILE_TMP_SIZE", default="50M")
|
||||
|
||||
def _test_database_tblspace_extsize(self):
|
||||
return self._test_settings_get("DATAFILE_EXTSIZE", default="25M")
|
||||
|
||||
def _test_database_tblspace_tmp_extsize(self):
|
||||
return self._test_settings_get("DATAFILE_TMP_EXTSIZE", default="25M")
|
||||
|
||||
def _test_database_oracle_managed_files(self):
|
||||
return self._test_settings_get("ORACLE_MANAGED_FILES", default=False)
|
||||
|
||||
def _get_test_db_name(self):
|
||||
"""
|
||||
Return the 'production' DB name to get the test DB creation machinery
|
||||
to work. This isn't a great deal in this case because DB names as
|
||||
handled by Django don't have real counterparts in Oracle.
|
||||
"""
|
||||
return self.connection.settings_dict["NAME"]
|
||||
|
||||
def test_db_signature(self):
|
||||
settings_dict = self.connection.settings_dict
|
||||
return (
|
||||
settings_dict["HOST"],
|
||||
settings_dict["PORT"],
|
||||
settings_dict["ENGINE"],
|
||||
settings_dict["NAME"],
|
||||
self._test_database_user(),
|
||||
)
|
||||
@@ -0,0 +1,152 @@
|
||||
from django.db import DatabaseError, InterfaceError
|
||||
from django.db.backends.base.features import BaseDatabaseFeatures
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
minimum_database_version = (19,)
|
||||
# Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got
|
||||
# BLOB" when grouping by LOBs (#24096).
|
||||
allows_group_by_lob = False
|
||||
allows_group_by_select_index = False
|
||||
interprets_empty_strings_as_nulls = True
|
||||
has_select_for_update = True
|
||||
has_select_for_update_nowait = True
|
||||
has_select_for_update_skip_locked = True
|
||||
has_select_for_update_of = True
|
||||
select_for_update_of_column = True
|
||||
can_return_columns_from_insert = True
|
||||
supports_subqueries_in_group_by = False
|
||||
ignores_unnecessary_order_by_in_subqueries = False
|
||||
supports_transactions = True
|
||||
supports_timezones = False
|
||||
has_native_duration_field = True
|
||||
can_defer_constraint_checks = True
|
||||
supports_partially_nullable_unique_constraints = False
|
||||
supports_deferrable_unique_constraints = True
|
||||
truncates_names = True
|
||||
supports_comments = True
|
||||
supports_tablespaces = True
|
||||
supports_sequence_reset = False
|
||||
can_introspect_materialized_views = True
|
||||
atomic_transactions = False
|
||||
nulls_order_largest = True
|
||||
requires_literal_defaults = True
|
||||
closed_cursor_error_class = InterfaceError
|
||||
bare_select_suffix = " FROM DUAL"
|
||||
# Select for update with limit can be achieved on Oracle, but not with the
|
||||
# current backend.
|
||||
supports_select_for_update_with_limit = False
|
||||
supports_temporal_subtraction = True
|
||||
# Oracle doesn't ignore the case of quoted identifiers, but the current
# backend does, by uppercasing all identifiers.
|
||||
ignores_table_name_case = True
|
||||
supports_index_on_text_field = False
|
||||
create_test_procedure_without_params_sql = """
|
||||
CREATE PROCEDURE "TEST_PROCEDURE" AS
|
||||
V_I INTEGER;
|
||||
BEGIN
|
||||
V_I := 1;
|
||||
END;
|
||||
"""
|
||||
create_test_procedure_with_int_param_sql = """
|
||||
CREATE PROCEDURE "TEST_PROCEDURE" (P_I INTEGER) AS
|
||||
V_I INTEGER;
|
||||
BEGIN
|
||||
V_I := P_I;
|
||||
END;
|
||||
"""
|
||||
create_test_table_with_composite_primary_key = """
|
||||
CREATE TABLE test_table_composite_pk (
|
||||
column_1 NUMBER(11) NOT NULL,
|
||||
column_2 NUMBER(11) NOT NULL,
|
||||
PRIMARY KEY (column_1, column_2)
|
||||
)
|
||||
"""
|
||||
supports_callproc_kwargs = True
|
||||
supports_over_clause = True
|
||||
supports_frame_range_fixed_distance = True
|
||||
supports_ignore_conflicts = False
|
||||
max_query_params = 2**16 - 1
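# 2**16 - 1 == 65535, the largest number of bind parameters this backend
# will place in a single statement.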
|
||||
supports_partial_indexes = False
|
||||
can_rename_index = True
|
||||
supports_slicing_ordering_in_compound = True
|
||||
requires_compound_order_by_subquery = True
|
||||
allows_multiple_constraints_on_same_fields = False
|
||||
supports_boolean_expr_in_select_clause = False
|
||||
supports_comparing_boolean_expr = False
|
||||
supports_primitives_in_json_field = False
|
||||
supports_json_field_contains = False
|
||||
supports_collation_on_textfield = False
|
||||
test_collations = {
|
||||
"ci": "BINARY_CI",
|
||||
"cs": "BINARY",
|
||||
"non_default": "SWEDISH_CI",
|
||||
"swedish_ci": "SWEDISH_CI",
|
||||
}
|
||||
test_now_utc_template = "CURRENT_TIMESTAMP AT TIME ZONE 'UTC'"
|
||||
|
||||
django_test_skips = {
|
||||
"Oracle doesn't support SHA224.": {
|
||||
"db_functions.text.test_sha224.SHA224Tests.test_basic",
|
||||
"db_functions.text.test_sha224.SHA224Tests.test_transform",
|
||||
},
|
||||
"Oracle doesn't correctly calculate ISO 8601 week numbering before "
|
||||
"1583 (the Gregorian calendar was introduced in 1582).": {
|
||||
"db_functions.datetime.test_extract_trunc.DateFunctionTests."
|
||||
"test_trunc_week_before_1000",
|
||||
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests."
|
||||
"test_trunc_week_before_1000",
|
||||
},
|
||||
"Oracle extracts seconds including fractional seconds (#33517).": {
|
||||
"db_functions.datetime.test_extract_trunc.DateFunctionTests."
|
||||
"test_extract_second_func_no_fractional",
|
||||
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests."
|
||||
"test_extract_second_func_no_fractional",
|
||||
},
|
||||
"Oracle doesn't support bitwise XOR.": {
|
||||
"expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor",
|
||||
"expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_null",
|
||||
"expressions.tests.ExpressionOperatorTests."
|
||||
"test_lefthand_bitwise_xor_right_null",
|
||||
},
|
||||
"Oracle requires ORDER BY in row_number, ANSI:SQL doesn't.": {
|
||||
"expressions_window.tests.WindowFunctionTests.test_row_number_no_ordering",
|
||||
},
|
||||
"Raises ORA-00600: internal error code.": {
|
||||
"model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery",
|
||||
},
|
||||
"Oracle doesn't support changing collations on indexed columns (#33671).": {
|
||||
"migrations.test_operations.OperationTests."
|
||||
"test_alter_field_pk_fk_db_collation",
|
||||
},
|
||||
}
|
||||
django_test_expected_failures = {
|
||||
# A bug in Django/cx_Oracle with respect to string handling (#23843).
|
||||
"annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions",
|
||||
"annotations.tests.NonAggregateAnnotationTestCase."
|
||||
"test_custom_functions_can_ref_other_functions",
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def introspected_field_types(self):
|
||||
return {
|
||||
**super().introspected_field_types,
|
||||
"GenericIPAddressField": "CharField",
|
||||
"PositiveBigIntegerField": "BigIntegerField",
|
||||
"PositiveIntegerField": "IntegerField",
|
||||
"PositiveSmallIntegerField": "IntegerField",
|
||||
"SmallIntegerField": "IntegerField",
|
||||
"TimeField": "DateTimeField",
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def supports_collation_on_charfield(self):
|
||||
with self.connection.cursor() as cursor:
|
||||
try:
|
||||
cursor.execute("SELECT CAST('a' AS VARCHAR2(4001)) FROM dual")
|
||||
except DatabaseError as e:
|
||||
if e.args[0].code == 910:
|
||||
return False
|
||||
raise
|
||||
return True
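# A rough reading of this probe: the CAST to VARCHAR2(4001) only succeeds
# when the database allows extended string sizes; ORA-00910 ("specified
# length too long") indicates the standard 4000-byte limit, in which case
# collation on CharField is reported as unsupported.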
|
||||
@@ -0,0 +1,26 @@
|
||||
from django.db.models import DecimalField, DurationField, Func
|
||||
|
||||
|
||||
class IntervalToSeconds(Func):
|
||||
function = ""
|
||||
template = """
|
||||
EXTRACT(day from %(expressions)s) * 86400 +
|
||||
EXTRACT(hour from %(expressions)s) * 3600 +
|
||||
EXTRACT(minute from %(expressions)s) * 60 +
|
||||
EXTRACT(second from %(expressions)s)
|
||||
"""
|
||||
|
||||
def __init__(self, expression, *, output_field=None, **extra):
|
||||
super().__init__(
|
||||
expression, output_field=output_field or DecimalField(), **extra
|
||||
)
|
||||
|
||||
|
||||
class SecondsToInterval(Func):
|
||||
function = "NUMTODSINTERVAL"
|
||||
template = "%(function)s(%(expressions)s, 'SECOND')"
|
||||
|
||||
def __init__(self, expression, *, output_field=None, **extra):
|
||||
super().__init__(
|
||||
expression, output_field=output_field or DurationField(), **extra
|
||||
)
|
||||
@@ -0,0 +1,411 @@
|
||||
from collections import namedtuple
|
||||
|
||||
import cx_Oracle
|
||||
|
||||
from django.db import models
|
||||
from django.db.backends.base.introspection import BaseDatabaseIntrospection
|
||||
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
|
||||
from django.db.backends.base.introspection import TableInfo as BaseTableInfo
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
FieldInfo = namedtuple(
|
||||
"FieldInfo", BaseFieldInfo._fields + ("is_autofield", "is_json", "comment")
|
||||
)
|
||||
TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",))
|
||||
|
||||
|
||||
class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||
cache_bust_counter = 1
|
||||
|
||||
# Maps type objects to Django Field types.
|
||||
@cached_property
|
||||
def data_types_reverse(self):
|
||||
if self.connection.cx_oracle_version < (8,):
|
||||
return {
|
||||
cx_Oracle.BLOB: "BinaryField",
|
||||
cx_Oracle.CLOB: "TextField",
|
||||
cx_Oracle.DATETIME: "DateField",
|
||||
cx_Oracle.FIXED_CHAR: "CharField",
|
||||
cx_Oracle.FIXED_NCHAR: "CharField",
|
||||
cx_Oracle.INTERVAL: "DurationField",
|
||||
cx_Oracle.NATIVE_FLOAT: "FloatField",
|
||||
cx_Oracle.NCHAR: "CharField",
|
||||
cx_Oracle.NCLOB: "TextField",
|
||||
cx_Oracle.NUMBER: "DecimalField",
|
||||
cx_Oracle.STRING: "CharField",
|
||||
cx_Oracle.TIMESTAMP: "DateTimeField",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
cx_Oracle.DB_TYPE_DATE: "DateField",
|
||||
cx_Oracle.DB_TYPE_BINARY_DOUBLE: "FloatField",
|
||||
cx_Oracle.DB_TYPE_BLOB: "BinaryField",
|
||||
cx_Oracle.DB_TYPE_CHAR: "CharField",
|
||||
cx_Oracle.DB_TYPE_CLOB: "TextField",
|
||||
cx_Oracle.DB_TYPE_INTERVAL_DS: "DurationField",
|
||||
cx_Oracle.DB_TYPE_NCHAR: "CharField",
|
||||
cx_Oracle.DB_TYPE_NCLOB: "TextField",
|
||||
cx_Oracle.DB_TYPE_NVARCHAR: "CharField",
|
||||
cx_Oracle.DB_TYPE_NUMBER: "DecimalField",
|
||||
cx_Oracle.DB_TYPE_TIMESTAMP: "DateTimeField",
|
||||
cx_Oracle.DB_TYPE_VARCHAR: "CharField",
|
||||
}
|
||||
|
||||
def get_field_type(self, data_type, description):
|
||||
if data_type == cx_Oracle.NUMBER:
|
||||
precision, scale = description[4:6]
|
||||
if scale == 0:
|
||||
if precision > 11:
|
||||
return (
|
||||
"BigAutoField"
|
||||
if description.is_autofield
|
||||
else "BigIntegerField"
|
||||
)
|
||||
elif 1 < precision < 6 and description.is_autofield:
|
||||
return "SmallAutoField"
|
||||
elif precision == 1:
|
||||
return "BooleanField"
|
||||
elif description.is_autofield:
|
||||
return "AutoField"
|
||||
else:
|
||||
return "IntegerField"
|
||||
elif scale == -127:
|
||||
return "FloatField"
|
||||
elif data_type == cx_Oracle.NCLOB and description.is_json:
|
||||
return "JSONField"
|
||||
|
||||
return super().get_field_type(data_type, description)
|
||||
|
||||
def get_table_list(self, cursor):
|
||||
"""Return a list of table and view names in the current database."""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
user_tables.table_name,
|
||||
't',
|
||||
user_tab_comments.comments
|
||||
FROM user_tables
|
||||
LEFT OUTER JOIN
|
||||
user_tab_comments
|
||||
ON user_tab_comments.table_name = user_tables.table_name
|
||||
WHERE
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM user_mviews
|
||||
WHERE user_mviews.mview_name = user_tables.table_name
|
||||
)
|
||||
UNION ALL
|
||||
SELECT view_name, 'v', NULL FROM user_views
|
||||
UNION ALL
|
||||
SELECT mview_name, 'v', NULL FROM user_mviews
|
||||
"""
|
||||
)
|
||||
return [
|
||||
TableInfo(self.identifier_converter(row[0]), row[1], row[2])
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
|
||||
def get_table_description(self, cursor, table_name):
|
||||
"""
|
||||
Return a description of the table with the DB-API cursor.description
|
||||
interface.
|
||||
"""
|
||||
# user_tab_columns gives data default for columns
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
user_tab_cols.column_name,
|
||||
user_tab_cols.data_default,
|
||||
CASE
|
||||
WHEN user_tab_cols.collation = user_tables.default_collation
|
||||
THEN NULL
|
||||
ELSE user_tab_cols.collation
|
||||
END collation,
|
||||
CASE
|
||||
WHEN user_tab_cols.char_used IS NULL
|
||||
THEN user_tab_cols.data_length
|
||||
ELSE user_tab_cols.char_length
|
||||
END as display_size,
|
||||
CASE
|
||||
WHEN user_tab_cols.identity_column = 'YES' THEN 1
|
||||
ELSE 0
|
||||
END as is_autofield,
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1
|
||||
FROM user_json_columns
|
||||
WHERE
|
||||
user_json_columns.table_name = user_tab_cols.table_name AND
|
||||
user_json_columns.column_name = user_tab_cols.column_name
|
||||
)
|
||||
THEN 1
|
||||
ELSE 0
|
||||
END as is_json,
|
||||
user_col_comments.comments as col_comment
|
||||
FROM user_tab_cols
|
||||
LEFT OUTER JOIN
|
||||
user_tables ON user_tables.table_name = user_tab_cols.table_name
|
||||
LEFT OUTER JOIN
|
||||
user_col_comments ON
|
||||
user_col_comments.column_name = user_tab_cols.column_name AND
|
||||
user_col_comments.table_name = user_tab_cols.table_name
|
||||
WHERE user_tab_cols.table_name = UPPER(%s)
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
field_map = {
|
||||
column: (
|
||||
display_size,
|
||||
default if default != "NULL" else None,
|
||||
collation,
|
||||
is_autofield,
|
||||
is_json,
|
||||
comment,
|
||||
)
|
||||
for (
|
||||
column,
|
||||
default,
|
||||
collation,
|
||||
display_size,
|
||||
is_autofield,
|
||||
is_json,
|
||||
comment,
|
||||
) in cursor.fetchall()
|
||||
}
|
||||
self.cache_bust_counter += 1
|
||||
cursor.execute(
|
||||
"SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format(
|
||||
self.connection.ops.quote_name(table_name), self.cache_bust_counter
|
||||
)
|
||||
)
|
||||
description = []
|
||||
for desc in cursor.description:
|
||||
name = desc[0]
|
||||
(
|
||||
display_size,
|
||||
default,
|
||||
collation,
|
||||
is_autofield,
|
||||
is_json,
|
||||
comment,
|
||||
) = field_map[name]
|
||||
name %= {} # cx_Oracle, for some reason, doubles percent signs.
|
||||
description.append(
|
||||
FieldInfo(
|
||||
self.identifier_converter(name),
|
||||
desc[1],
|
||||
display_size,
|
||||
desc[3],
|
||||
desc[4] or 0,
|
||||
desc[5] or 0,
|
||||
*desc[6:],
|
||||
default,
|
||||
collation,
|
||||
is_autofield,
|
||||
is_json,
|
||||
comment,
|
||||
)
|
||||
)
|
||||
return description
|
||||
|
||||
def identifier_converter(self, name):
|
||||
"""Identifier comparison is case insensitive under Oracle."""
|
||||
return name.lower()
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
user_tab_identity_cols.sequence_name,
|
||||
user_tab_identity_cols.column_name
|
||||
FROM
|
||||
user_tab_identity_cols,
|
||||
user_constraints,
|
||||
user_cons_columns cols
|
||||
WHERE
|
||||
user_constraints.constraint_name = cols.constraint_name
|
||||
AND user_constraints.table_name = user_tab_identity_cols.table_name
|
||||
AND cols.column_name = user_tab_identity_cols.column_name
|
||||
AND user_constraints.constraint_type = 'P'
|
||||
AND user_tab_identity_cols.table_name = UPPER(%s)
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
# Oracle allows only one identity column per table.
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return [
|
||||
{
|
||||
"name": self.identifier_converter(row[0]),
|
||||
"table": self.identifier_converter(table_name),
|
||||
"column": self.identifier_converter(row[1]),
|
||||
}
|
||||
]
|
||||
# To keep backward compatibility for AutoFields that aren't Oracle
|
||||
# identity columns.
|
||||
for f in table_fields:
|
||||
if isinstance(f, models.AutoField):
|
||||
return [{"table": table_name, "column": f.column}]
|
||||
return []
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||
representing all foreign keys in the given table.
|
||||
"""
|
||||
table_name = table_name.upper()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT ca.column_name, cb.table_name, cb.column_name
|
||||
FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb
|
||||
WHERE user_constraints.table_name = %s AND
|
||||
user_constraints.constraint_name = ca.constraint_name AND
|
||||
user_constraints.r_constraint_name = cb.constraint_name AND
|
||||
ca.position = cb.position""",
|
||||
[table_name],
|
||||
)
|
||||
|
||||
return {
|
||||
self.identifier_converter(field_name): (
|
||||
self.identifier_converter(rel_field_name),
|
||||
self.identifier_converter(rel_table_name),
|
||||
)
|
||||
for field_name, rel_table_name, rel_field_name in cursor.fetchall()
|
||||
}
|
||||
|
||||
def get_primary_key_columns(self, cursor, table_name):
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
cols.column_name
|
||||
FROM
|
||||
user_constraints,
|
||||
user_cons_columns cols
|
||||
WHERE
|
||||
user_constraints.constraint_name = cols.constraint_name AND
|
||||
user_constraints.constraint_type = 'P' AND
|
||||
user_constraints.table_name = UPPER(%s)
|
||||
ORDER BY
|
||||
cols.position
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
return [self.identifier_converter(row[0]) for row in cursor.fetchall()]
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve any constraints or keys (unique, pk, fk, check, index) across
|
||||
one or more columns.
|
||||
"""
|
||||
constraints = {}
|
||||
# Loop over the constraints, getting PKs, uniques, and checks
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
user_constraints.constraint_name,
|
||||
LISTAGG(LOWER(cols.column_name), ',')
|
||||
WITHIN GROUP (ORDER BY cols.position),
|
||||
CASE user_constraints.constraint_type
|
||||
WHEN 'P' THEN 1
|
||||
ELSE 0
|
||||
END AS is_primary_key,
|
||||
CASE
|
||||
WHEN user_constraints.constraint_type IN ('P', 'U') THEN 1
|
||||
ELSE 0
|
||||
END AS is_unique,
|
||||
CASE user_constraints.constraint_type
|
||||
WHEN 'C' THEN 1
|
||||
ELSE 0
|
||||
END AS is_check_constraint
|
||||
FROM
|
||||
user_constraints
|
||||
LEFT OUTER JOIN
|
||||
user_cons_columns cols
|
||||
ON user_constraints.constraint_name = cols.constraint_name
|
||||
WHERE
|
||||
user_constraints.constraint_type = ANY('P', 'U', 'C')
|
||||
AND user_constraints.table_name = UPPER(%s)
|
||||
GROUP BY user_constraints.constraint_name, user_constraints.constraint_type
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
for constraint, columns, pk, unique, check in cursor.fetchall():
|
||||
constraint = self.identifier_converter(constraint)
|
||||
constraints[constraint] = {
|
||||
"columns": columns.split(","),
|
||||
"primary_key": pk,
|
||||
"unique": unique,
|
||||
"foreign_key": None,
|
||||
"check": check,
|
||||
"index": unique, # All uniques come with an index
|
||||
}
|
||||
# Foreign key constraints
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
cons.constraint_name,
|
||||
LISTAGG(LOWER(cols.column_name), ',')
|
||||
WITHIN GROUP (ORDER BY cols.position),
|
||||
LOWER(rcols.table_name),
|
||||
LOWER(rcols.column_name)
|
||||
FROM
|
||||
user_constraints cons
|
||||
INNER JOIN
|
||||
user_cons_columns rcols
|
||||
ON rcols.constraint_name = cons.r_constraint_name AND rcols.position = 1
|
||||
LEFT OUTER JOIN
|
||||
user_cons_columns cols
|
||||
ON cons.constraint_name = cols.constraint_name
|
||||
WHERE
|
||||
cons.constraint_type = 'R' AND
|
||||
cons.table_name = UPPER(%s)
|
||||
GROUP BY cons.constraint_name, rcols.table_name, rcols.column_name
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
for constraint, columns, other_table, other_column in cursor.fetchall():
|
||||
constraint = self.identifier_converter(constraint)
|
||||
constraints[constraint] = {
|
||||
"primary_key": False,
|
||||
"unique": False,
|
||||
"foreign_key": (other_table, other_column),
|
||||
"check": False,
|
||||
"index": False,
|
||||
"columns": columns.split(","),
|
||||
}
|
||||
# Now get indexes
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
ind.index_name,
|
||||
LOWER(ind.index_type),
|
||||
LOWER(ind.uniqueness),
|
||||
LISTAGG(LOWER(cols.column_name), ',')
|
||||
WITHIN GROUP (ORDER BY cols.column_position),
|
||||
LISTAGG(cols.descend, ',') WITHIN GROUP (ORDER BY cols.column_position)
|
||||
FROM
|
||||
user_ind_columns cols, user_indexes ind
|
||||
WHERE
|
||||
cols.table_name = UPPER(%s) AND
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM user_constraints cons
|
||||
WHERE ind.index_name = cons.index_name
|
||||
) AND cols.index_name = ind.index_name
|
||||
GROUP BY ind.index_name, ind.index_type, ind.uniqueness
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
for constraint, type_, unique, columns, orders in cursor.fetchall():
|
||||
constraint = self.identifier_converter(constraint)
|
||||
constraints[constraint] = {
|
||||
"primary_key": False,
|
||||
"unique": unique == "unique",
|
||||
"foreign_key": None,
|
||||
"check": False,
|
||||
"index": True,
|
||||
"type": "idx" if type_ == "normal" else type_,
|
||||
"columns": columns.split(","),
|
||||
"orders": orders.split(","),
|
||||
}
|
||||
return constraints
|
||||
@@ -0,0 +1,722 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from functools import lru_cache
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import DatabaseError, NotSupportedError
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
|
||||
from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
|
||||
from django.db.models.expressions import RawSQL
|
||||
from django.db.models.sql.where import WhereNode
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import force_bytes, force_str
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
from .base import Database
|
||||
from .utils import BulkInsertMapper, InsertVar, Oracle_datetime
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
# Oracle uses NUMBER(5), NUMBER(11), and NUMBER(19) for integer fields.
|
||||
# SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by
|
||||
# SmallAutoField, to preserve backward compatibility.
|
||||
integer_field_ranges = {
|
||||
"SmallIntegerField": (-99999999999, 99999999999),
|
||||
"IntegerField": (-99999999999, 99999999999),
|
||||
"BigIntegerField": (-9999999999999999999, 9999999999999999999),
|
||||
"PositiveBigIntegerField": (0, 9999999999999999999),
|
||||
"PositiveSmallIntegerField": (0, 99999999999),
|
||||
"PositiveIntegerField": (0, 99999999999),
|
||||
"SmallAutoField": (-99999, 99999),
|
||||
"AutoField": (-99999999999, 99999999999),
|
||||
"BigAutoField": (-9999999999999999999, 9999999999999999999),
|
||||
}
|
||||
set_operators = {**BaseDatabaseOperations.set_operators, "difference": "MINUS"}
|
||||
|
||||
# TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
|
||||
_sequence_reset_sql = """
|
||||
DECLARE
|
||||
table_value integer;
|
||||
seq_value integer;
|
||||
seq_name user_tab_identity_cols.sequence_name%%TYPE;
|
||||
BEGIN
|
||||
BEGIN
|
||||
SELECT sequence_name INTO seq_name FROM user_tab_identity_cols
|
||||
WHERE table_name = '%(table_name)s' AND
|
||||
column_name = '%(column_name)s';
|
||||
EXCEPTION WHEN NO_DATA_FOUND THEN
|
||||
seq_name := '%(no_autofield_sequence_name)s';
|
||||
END;
|
||||
|
||||
SELECT NVL(MAX(%(column)s), 0) INTO table_value FROM %(table)s;
|
||||
SELECT NVL(last_number - cache_size, 0) INTO seq_value FROM user_sequences
|
||||
WHERE sequence_name = seq_name;
|
||||
WHILE table_value > seq_value LOOP
|
||||
EXECUTE IMMEDIATE 'SELECT "'||seq_name||'".nextval FROM DUAL'
|
||||
INTO seq_value;
|
||||
END LOOP;
|
||||
END;
|
||||
/"""
|
||||
|
||||
# Oracle doesn't support string without precision; use the max string size.
|
||||
cast_char_field_without_max_length = "NVARCHAR2(2000)"
|
||||
cast_data_types = {
|
||||
"AutoField": "NUMBER(11)",
|
||||
"BigAutoField": "NUMBER(19)",
|
||||
"SmallAutoField": "NUMBER(5)",
|
||||
"TextField": cast_char_field_without_max_length,
|
||||
}
|
||||
|
||||
def cache_key_culling_sql(self):
|
||||
cache_key = self.quote_name("cache_key")
|
||||
return (
|
||||
f"SELECT {cache_key} "
|
||||
f"FROM %s "
|
||||
f"ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY"
|
||||
)
|
||||
|
||||
# EXTRACT format cannot be passed in parameters.
|
||||
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
|
||||
|
||||
def date_extract_sql(self, lookup_type, sql, params):
|
||||
extract_sql = f"TO_CHAR({sql}, %s)"
|
||||
extract_param = None
|
||||
if lookup_type == "week_day":
|
||||
# TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
|
||||
extract_param = "D"
|
||||
elif lookup_type == "iso_week_day":
|
||||
extract_sql = f"TO_CHAR({sql} - 1, %s)"
|
||||
extract_param = "D"
|
||||
elif lookup_type == "week":
|
||||
# IW = ISO week number
|
||||
extract_param = "IW"
|
||||
elif lookup_type == "quarter":
|
||||
extract_param = "Q"
|
||||
elif lookup_type == "iso_year":
|
||||
extract_param = "IYYY"
|
||||
else:
|
||||
lookup_type = lookup_type.upper()
|
||||
if not self._extract_format_re.fullmatch(lookup_type):
|
||||
raise ValueError(f"Invalid loookup type: {lookup_type!r}")
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/EXTRACT-datetime.html
|
||||
return f"EXTRACT({lookup_type} FROM {sql})", params
|
||||
return extract_sql, (*params, extract_param)
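# For example, the SQL produced above (a sketch, with a made-up column):
#   date_extract_sql("month", '"T"."CREATED"', ())
#       -> ('EXTRACT(MONTH FROM "T"."CREATED")', ())
#   date_extract_sql("week_day", '"T"."CREATED"', ())
#       -> ('TO_CHAR("T"."CREATED", %s)', ('D',))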
|
||||
|
||||
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
|
||||
trunc_param = None
|
||||
if lookup_type in ("year", "month"):
|
||||
trunc_param = lookup_type.upper()
|
||||
elif lookup_type == "quarter":
|
||||
trunc_param = "Q"
|
||||
elif lookup_type == "week":
|
||||
trunc_param = "IW"
|
||||
else:
|
||||
return f"TRUNC({sql})", params
|
||||
return f"TRUNC({sql}, %s)", (*params, trunc_param)
|
||||
|
||||
# Oracle crashes with "ORA-03113: end-of-file on communication channel"
|
||||
# if the time zone name is passed in parameter. Use interpolation instead.
|
||||
# https://groups.google.com/forum/#!msg/django-developers/zwQju7hbG78/9l934yelwfsJ
|
||||
# This regexp matches all time zone names from the zoneinfo database.
|
||||
_tzname_re = _lazy_re_compile(r"^[\w/:+-]+$")
|
||||
|
||||
def _prepare_tzname_delta(self, tzname):
|
||||
tzname, sign, offset = split_tzname_delta(tzname)
|
||||
return f"{sign}{offset}" if offset else tzname
|
||||
|
||||
def _convert_sql_to_tz(self, sql, params, tzname):
|
||||
if not (settings.USE_TZ and tzname):
|
||||
return sql, params
|
||||
if not self._tzname_re.match(tzname):
|
||||
raise ValueError("Invalid time zone name: %s" % tzname)
|
||||
# Convert from connection timezone to the local time, returning
|
||||
# TIMESTAMP WITH TIME ZONE and cast it back to TIMESTAMP to strip the
|
||||
# TIME ZONE details.
|
||||
if self.connection.timezone_name != tzname:
|
||||
from_timezone_name = self.connection.timezone_name
|
||||
to_timezone_name = self._prepare_tzname_delta(tzname)
|
||||
return (
|
||||
f"CAST((FROM_TZ({sql}, '{from_timezone_name}') AT TIME ZONE "
|
||||
f"'{to_timezone_name}') AS TIMESTAMP)",
|
||||
params,
|
||||
)
|
||||
return sql, params
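# For example, with the connection in 'UTC' and tzname 'Europe/Paris',
# a column reference is rewritten by the branch above to:
#   CAST((FROM_TZ("T"."CREATED", 'UTC') AT TIME ZONE 'Europe/Paris')
#       AS TIMESTAMP)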
|
||||
|
||||
def datetime_cast_date_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"TRUNC({sql})", params
|
||||
|
||||
def datetime_cast_time_sql(self, sql, params, tzname):
|
||||
# Since `TimeField` values are stored as TIMESTAMP, change to the
# default date and convert the field to the specified timezone.
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
convert_datetime_sql = (
|
||||
f"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR({sql}, 'HH24:MI:SS.FF')), "
|
||||
f"'YYYY-MM-DD HH24:MI:SS.FF')"
|
||||
)
|
||||
return (
|
||||
f"CASE WHEN {sql} IS NOT NULL THEN {convert_datetime_sql} ELSE NULL END",
|
||||
(*params, *params),
|
||||
)
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
|
||||
trunc_param = None
|
||||
if lookup_type in ("year", "month"):
|
||||
trunc_param = lookup_type.upper()
|
||||
elif lookup_type == "quarter":
|
||||
trunc_param = "Q"
|
||||
elif lookup_type == "week":
|
||||
trunc_param = "IW"
|
||||
elif lookup_type == "hour":
|
||||
trunc_param = "HH24"
|
||||
elif lookup_type == "minute":
|
||||
trunc_param = "MI"
|
||||
elif lookup_type == "day":
|
||||
return f"TRUNC({sql})", params
|
||||
else:
|
||||
# Cast to DATE removes sub-second precision.
|
||||
return f"CAST({sql} AS DATE)", params
|
||||
return f"TRUNC({sql}, %s)", (*params, trunc_param)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
# The implementation is similar to `datetime_trunc_sql` as both
|
||||
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where
|
||||
# the date part of the latter is ignored.
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
trunc_param = None
|
||||
if lookup_type == "hour":
|
||||
trunc_param = "HH24"
|
||||
elif lookup_type == "minute":
|
||||
trunc_param = "MI"
|
||||
elif lookup_type == "second":
|
||||
# Cast to DATE removes sub-second precision.
|
||||
return f"CAST({sql} AS DATE)", params
|
||||
return f"TRUNC({sql}, %s)", (*params, trunc_param)
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
converters = super().get_db_converters(expression)
|
||||
internal_type = expression.output_field.get_internal_type()
|
||||
if internal_type in ["JSONField", "TextField"]:
|
||||
converters.append(self.convert_textfield_value)
|
||||
elif internal_type == "BinaryField":
|
||||
converters.append(self.convert_binaryfield_value)
|
||||
elif internal_type == "BooleanField":
|
||||
converters.append(self.convert_booleanfield_value)
|
||||
elif internal_type == "DateTimeField":
|
||||
if settings.USE_TZ:
|
||||
converters.append(self.convert_datetimefield_value)
|
||||
elif internal_type == "DateField":
|
||||
converters.append(self.convert_datefield_value)
|
||||
elif internal_type == "TimeField":
|
||||
converters.append(self.convert_timefield_value)
|
||||
elif internal_type == "UUIDField":
|
||||
converters.append(self.convert_uuidfield_value)
|
||||
# Oracle stores empty strings as null. If the field accepts the empty
|
||||
# string, undo this to adhere to the Django convention of using
|
||||
# the empty string instead of null.
|
||||
if expression.output_field.empty_strings_allowed:
|
||||
converters.append(
|
||||
self.convert_empty_bytes
|
||||
if internal_type == "BinaryField"
|
||||
else self.convert_empty_string
|
||||
)
|
||||
return converters
|
||||
|
||||
def convert_textfield_value(self, value, expression, connection):
|
||||
if isinstance(value, Database.LOB):
|
||||
value = value.read()
|
||||
return value
|
||||
|
||||
def convert_binaryfield_value(self, value, expression, connection):
|
||||
if isinstance(value, Database.LOB):
|
||||
value = force_bytes(value.read())
|
||||
return value
|
||||
|
||||
def convert_booleanfield_value(self, value, expression, connection):
|
||||
if value in (0, 1):
|
||||
value = bool(value)
|
||||
return value
|
||||
|
||||
# cx_Oracle always returns datetime.datetime objects for
|
||||
# DATE and TIMESTAMP columns, but Django wants to see a
|
||||
# python datetime.date, .time, or .datetime.
|
||||
|
||||
def convert_datetimefield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = timezone.make_aware(value, self.connection.timezone)
|
||||
return value
|
||||
|
||||
def convert_datefield_value(self, value, expression, connection):
|
||||
if isinstance(value, Database.Timestamp):
|
||||
value = value.date()
|
||||
return value
|
||||
|
||||
def convert_timefield_value(self, value, expression, connection):
|
||||
if isinstance(value, Database.Timestamp):
|
||||
value = value.time()
|
||||
return value
|
||||
|
||||
def convert_uuidfield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def convert_empty_string(value, expression, connection):
|
||||
return "" if value is None else value
|
||||
|
||||
@staticmethod
|
||||
def convert_empty_bytes(value, expression, connection):
|
||||
return b"" if value is None else value
|
||||
|
||||
def deferrable_sql(self):
|
||||
return " DEFERRABLE INITIALLY DEFERRED"
|
||||
|
||||
def fetch_returned_insert_columns(self, cursor, returning_params):
|
||||
columns = []
|
||||
for param in returning_params:
|
||||
value = param.get_value()
|
||||
if value == []:
|
||||
raise DatabaseError(
|
||||
"The database did not return a new row id. Probably "
|
||||
'"ORA-1403: no data found" was raised internally but was '
|
||||
"hidden by the Oracle OCI library (see "
|
||||
"https://code.djangoproject.com/ticket/28859)."
|
||||
)
|
||||
columns.append(value[0])
|
||||
return tuple(columns)
|
||||
|
||||
def no_limit_value(self):
|
||||
return None
|
||||
|
||||
def limit_offset_sql(self, low_mark, high_mark):
|
||||
fetch, offset = self._get_limit_offset_params(low_mark, high_mark)
|
||||
return " ".join(
|
||||
sql
|
||||
for sql in (
|
||||
("OFFSET %d ROWS" % offset) if offset else None,
|
||||
("FETCH FIRST %d ROWS ONLY" % fetch) if fetch else None,
|
||||
)
|
||||
if sql
|
||||
)
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
# https://cx-oracle.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.statement
|
||||
# The DB API definition does not define this attribute.
|
||||
statement = cursor.statement
|
||||
# Unlike Psycopg's `query` and MySQLdb`'s `_executed`, cx_Oracle's
|
||||
# `statement` doesn't contain the query parameters. Substitute
|
||||
# parameters manually.
|
||||
if params:
|
||||
if isinstance(params, (tuple, list)):
|
||||
params = {
|
||||
f":arg{i}": param for i, param in enumerate(dict.fromkeys(params))
|
||||
}
|
||||
elif isinstance(params, dict):
|
||||
params = {f":{key}": val for (key, val) in params.items()}
|
||||
for key in sorted(params, key=len, reverse=True):
|
||||
statement = statement.replace(
|
||||
key, force_str(params[key], errors="replace")
|
||||
)
|
||||
return statement
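# Note: the substituted statement is meant for logging and debugging;
# values are spliced in as text, so the result isn't guaranteed to be
# valid, re-executable SQL.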
|
||||
|
||||
def last_insert_id(self, cursor, table_name, pk_name):
|
||||
sq_name = self._get_sequence_name(cursor, strip_quotes(table_name), pk_name)
|
||||
cursor.execute('"%s".currval' % sq_name)
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
def lookup_cast(self, lookup_type, internal_type=None):
|
||||
if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
|
||||
return "UPPER(%s)"
|
||||
if (
|
||||
lookup_type != "isnull" and internal_type in ("BinaryField", "TextField")
|
||||
) or (lookup_type == "exact" and internal_type == "JSONField"):
|
||||
return "DBMS_LOB.SUBSTR(%s)"
|
||||
return "%s"
|
||||
|
||||
def max_in_list_size(self):
|
||||
return 1000
|
||||
|
||||
def max_name_length(self):
|
||||
return 30
|
||||
|
||||
def pk_default_value(self):
|
||||
return "NULL"
|
||||
|
||||
def prep_for_iexact_query(self, x):
|
||||
return x
|
||||
|
||||
def process_clob(self, value):
|
||||
if value is None:
|
||||
return ""
|
||||
return value.read()
|
||||
|
||||
def quote_name(self, name):
|
||||
# SQL92 requires delimited (quoted) names to be case-sensitive. When
|
||||
# not quoted, Oracle has case-insensitive behavior for identifiers, but
|
||||
# always defaults to uppercase.
|
||||
# We simplify things by making Oracle identifiers always uppercase.
|
||||
if not name.startswith('"') and not name.endswith('"'):
|
||||
name = '"%s"' % truncate_name(name, self.max_name_length())
|
||||
# Oracle puts the query text into a (query % args) construct, so % signs
|
||||
# in names need to be escaped. The '%%' will be collapsed back to '%' at
|
||||
# that stage so we aren't really making the name longer here.
|
||||
name = name.replace("%", "%%")
|
||||
return name.upper()
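# For example, quote_name("django_migrations") returns
# '"DJANGO_MIGRATIONS"', while an already-quoted name is uppercased and
# %-escaped but not wrapped in another pair of quotes.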
|
||||
|
||||
def regex_lookup(self, lookup_type):
|
||||
if lookup_type == "regex":
|
||||
match_option = "'c'"
|
||||
else:
|
||||
match_option = "'i'"
|
||||
return "REGEXP_LIKE(%%s, %%s, %s)" % match_option
|
||||
|
||||
def return_insert_columns(self, fields):
|
||||
if not fields:
|
||||
return "", ()
|
||||
field_names = []
|
||||
params = []
|
||||
for field in fields:
|
||||
field_names.append(
|
||||
"%s.%s"
|
||||
% (
|
||||
self.quote_name(field.model._meta.db_table),
|
||||
self.quote_name(field.column),
|
||||
)
|
||||
)
|
||||
params.append(InsertVar(field))
|
||||
return "RETURNING %s INTO %s" % (
|
||||
", ".join(field_names),
|
||||
", ".join(["%s"] * len(params)),
|
||||
), tuple(params)
|
||||
|
||||
def __foreign_key_constraints(self, table_name, recursive):
|
||||
with self.connection.cursor() as cursor:
|
||||
if recursive:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
user_tables.table_name, rcons.constraint_name
|
||||
FROM
|
||||
user_tables
|
||||
JOIN
|
||||
user_constraints cons
|
||||
ON (user_tables.table_name = cons.table_name
|
||||
AND cons.constraint_type = ANY('P', 'U'))
|
||||
LEFT JOIN
|
||||
user_constraints rcons
|
||||
ON (user_tables.table_name = rcons.table_name
|
||||
AND rcons.constraint_type = 'R')
|
||||
START WITH user_tables.table_name = UPPER(%s)
|
||||
CONNECT BY
|
||||
NOCYCLE PRIOR cons.constraint_name = rcons.r_constraint_name
|
||||
GROUP BY
|
||||
user_tables.table_name, rcons.constraint_name
|
||||
HAVING user_tables.table_name != UPPER(%s)
|
||||
ORDER BY MAX(level) DESC
|
||||
""",
|
||||
(table_name, table_name),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
cons.table_name, cons.constraint_name
|
||||
FROM
|
||||
user_constraints cons
|
||||
WHERE
|
||||
cons.constraint_type = 'R'
|
||||
AND cons.table_name = UPPER(%s)
|
||||
""",
|
||||
(table_name,),
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
@cached_property
|
||||
def _foreign_key_constraints(self):
|
||||
# 512 is large enough to fit the ~330 tables (as of this writing) in
|
||||
# Django's test suite.
|
||||
return lru_cache(maxsize=512)(self.__foreign_key_constraints)
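# Building the lru_cache inside a cached_property (rather than decorating
# the method) gives every DatabaseWrapper instance its own cache instead
# of one shared at the class level.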
|
||||
|
||||
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
|
||||
if not tables:
|
||||
return []
|
||||
|
||||
truncated_tables = {table.upper() for table in tables}
|
||||
constraints = set()
|
||||
# Oracle's TRUNCATE CASCADE only works with ON DELETE CASCADE foreign
|
||||
# keys which Django doesn't define. Emulate the PostgreSQL behavior
|
||||
# which truncates all dependent tables by manually retrieving all
|
||||
# foreign key constraints and resolving dependencies.
|
||||
for table in tables:
|
||||
for foreign_table, constraint in self._foreign_key_constraints(
|
||||
table, recursive=allow_cascade
|
||||
):
|
||||
if allow_cascade:
|
||||
truncated_tables.add(foreign_table)
|
||||
constraints.add((foreign_table, constraint))
|
||||
sql = (
|
||||
[
|
||||
"%s %s %s %s %s %s %s %s;"
|
||||
% (
|
||||
style.SQL_KEYWORD("ALTER"),
|
||||
style.SQL_KEYWORD("TABLE"),
|
||||
style.SQL_FIELD(self.quote_name(table)),
|
||||
style.SQL_KEYWORD("DISABLE"),
|
||||
style.SQL_KEYWORD("CONSTRAINT"),
|
||||
style.SQL_FIELD(self.quote_name(constraint)),
|
||||
style.SQL_KEYWORD("KEEP"),
|
||||
style.SQL_KEYWORD("INDEX"),
|
||||
)
|
||||
for table, constraint in constraints
|
||||
]
|
||||
+ [
|
||||
"%s %s %s;"
|
||||
% (
|
||||
style.SQL_KEYWORD("TRUNCATE"),
|
||||
style.SQL_KEYWORD("TABLE"),
|
||||
style.SQL_FIELD(self.quote_name(table)),
|
||||
)
|
||||
for table in truncated_tables
|
||||
]
|
||||
+ [
|
||||
"%s %s %s %s %s %s;"
|
||||
% (
|
||||
style.SQL_KEYWORD("ALTER"),
|
||||
style.SQL_KEYWORD("TABLE"),
|
||||
style.SQL_FIELD(self.quote_name(table)),
|
||||
style.SQL_KEYWORD("ENABLE"),
|
||||
style.SQL_KEYWORD("CONSTRAINT"),
|
||||
style.SQL_FIELD(self.quote_name(constraint)),
|
||||
)
|
||||
for table, constraint in constraints
|
||||
]
|
||||
)
|
||||
if reset_sequences:
|
||||
sequences = [
|
||||
sequence
|
||||
for sequence in self.connection.introspection.sequence_list()
|
||||
if sequence["table"].upper() in truncated_tables
|
||||
]
|
||||
# Since we've just deleted all the rows, running our sequence ALTER
|
||||
# code will reset the sequence to 0.
|
||||
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
|
||||
return sql
|
||||
|
||||
def sequence_reset_by_name_sql(self, style, sequences):
|
||||
sql = []
|
||||
for sequence_info in sequences:
|
||||
no_autofield_sequence_name = self._get_no_autofield_sequence_name(
|
||||
sequence_info["table"]
|
||||
)
|
||||
table = self.quote_name(sequence_info["table"])
|
||||
column = self.quote_name(sequence_info["column"] or "id")
|
||||
query = self._sequence_reset_sql % {
|
||||
"no_autofield_sequence_name": no_autofield_sequence_name,
|
||||
"table": table,
|
||||
"column": column,
|
||||
"table_name": strip_quotes(table),
|
||||
"column_name": strip_quotes(column),
|
||||
}
|
||||
sql.append(query)
|
||||
return sql
|
||||
|
||||
def sequence_reset_sql(self, style, model_list):
|
||||
output = []
|
||||
query = self._sequence_reset_sql
|
||||
for model in model_list:
|
||||
for f in model._meta.local_fields:
|
||||
if isinstance(f, AutoField):
|
||||
no_autofield_sequence_name = self._get_no_autofield_sequence_name(
|
||||
model._meta.db_table
|
||||
)
|
||||
table = self.quote_name(model._meta.db_table)
|
||||
column = self.quote_name(f.column)
|
||||
output.append(
|
||||
query
|
||||
% {
|
||||
"no_autofield_sequence_name": no_autofield_sequence_name,
|
||||
"table": table,
|
||||
"column": column,
|
||||
"table_name": strip_quotes(table),
|
||||
"column_name": strip_quotes(column),
|
||||
}
|
||||
)
|
||||
# Only one AutoField is allowed per model, so don't
|
||||
# continue to loop
|
||||
break
|
||||
return output
|
||||
|
||||
def start_transaction_sql(self):
|
||||
return ""
|
||||
|
||||
def tablespace_sql(self, tablespace, inline=False):
|
||||
if inline:
|
||||
return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
|
||||
else:
|
||||
return "TABLESPACE %s" % self.quote_name(tablespace)
|
||||
|
||||
def adapt_datefield_value(self, value):
|
||||
"""
|
||||
Transform a date value to an object compatible with what is expected
|
||||
by the backend driver for date columns.
|
||||
The default implementation transforms the date to text, but that is not
|
||||
necessary for Oracle.
|
||||
"""
|
||||
return value
|
||||
|
||||
def adapt_datetimefield_value(self, value):
|
||||
"""
|
||||
Transform a datetime value to an object compatible with what is expected
|
||||
by the backend driver for datetime columns.
|
||||
|
||||
If a naive datetime is passed, assume it is in UTC. Normally Django's
models.DateTimeField ensures that, when USE_TZ is True, the datetime
passed in is timezone-aware.
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
# cx_Oracle doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
if settings.USE_TZ:
|
||||
value = timezone.make_naive(value, self.connection.timezone)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Oracle backend does not support timezone-aware datetimes when "
|
||||
"USE_TZ is False."
|
||||
)
|
||||
|
||||
return Oracle_datetime.from_datetime(value)
|
||||
|
||||
def adapt_timefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
return datetime.datetime.strptime(value, "%H:%M:%S")
|
||||
|
||||
# Oracle doesn't support tz-aware times
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("Oracle backend does not support timezone-aware times.")
|
||||
|
||||
return Oracle_datetime(
|
||||
1900, 1, 1, value.hour, value.minute, value.second, value.microsecond
|
||||
)
|
||||
|
||||
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
|
||||
return value
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
lhs, rhs = sub_expressions
|
||||
if connector == "%%":
|
||||
return "MOD(%s)" % ",".join(sub_expressions)
|
||||
elif connector == "&":
|
||||
return "BITAND(%s)" % ",".join(sub_expressions)
|
||||
elif connector == "|":
|
||||
return "BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s" % {"lhs": lhs, "rhs": rhs}
|
||||
elif connector == "<<":
|
||||
return "(%(lhs)s * POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
|
||||
elif connector == ">>":
|
||||
return "FLOOR(%(lhs)s / POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
|
||||
elif connector == "^":
|
||||
return "POWER(%s)" % ",".join(sub_expressions)
|
||||
elif connector == "#":
|
||||
raise NotSupportedError("Bitwise XOR is not supported in Oracle.")
|
||||
return super().combine_expression(connector, sub_expressions)
|
||||
|
||||
def _get_no_autofield_sequence_name(self, table):
|
||||
"""
|
||||
Manually created sequence name to keep backward compatibility for
|
||||
AutoFields that aren't Oracle identity columns.
|
||||
"""
|
||||
name_length = self.max_name_length() - 3
|
||||
return "%s_SQ" % truncate_name(strip_quotes(table), name_length).upper()
|
||||
|
||||
def _get_sequence_name(self, cursor, table, pk_name):
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT sequence_name
|
||||
FROM user_tab_identity_cols
|
||||
WHERE table_name = UPPER(%s)
|
||||
AND column_name = UPPER(%s)""",
|
||||
[table, pk_name],
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return self._get_no_autofield_sequence_name(table) if row is None else row[0]
|
||||
|
||||
def bulk_insert_sql(self, fields, placeholder_rows):
|
||||
query = []
|
||||
for row in placeholder_rows:
|
||||
select = []
|
||||
for i, placeholder in enumerate(row):
|
||||
# A model without any fields has fields=[None].
|
||||
if fields[i]:
|
||||
internal_type = getattr(
|
||||
fields[i], "target_field", fields[i]
|
||||
).get_internal_type()
|
||||
placeholder = (
|
||||
BulkInsertMapper.types.get(internal_type, "%s") % placeholder
|
||||
)
|
||||
# Add column aliases to the first select to avoid "ORA-00918:
|
||||
# column ambiguously defined" when two or more columns in the
|
||||
# first select have the same value.
|
||||
if not query:
|
||||
placeholder = "%s col_%s" % (placeholder, i)
|
||||
select.append(placeholder)
|
||||
query.append("SELECT %s FROM DUAL" % ", ".join(select))
|
||||
# Bulk insert to tables with Oracle identity columns causes Oracle to
|
||||
# add sequence.nextval to it. Sequence.nextval cannot be used with the
|
||||
# UNION operator. To prevent incorrect SQL, move UNION to a subquery.
|
||||
return "SELECT * FROM (%s)" % " UNION ALL ".join(query)
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
if internal_type == "DateField":
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
params = (*lhs_params, *rhs_params)
|
||||
return (
|
||||
"NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql),
|
||||
params,
|
||||
)
|
||||
return super().subtract_temporals(internal_type, lhs, rhs)
|
||||
|
||||
def bulk_batch_size(self, fields, objs):
|
||||
"""Oracle restricts the number of parameters in a query."""
|
||||
if fields:
|
||||
return self.connection.features.max_query_params // len(fields)
|
||||
return len(objs)
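# For example, with max_query_params = 65535 and 10 fields per object,
# at most 65535 // 10 == 6553 objects are inserted per batch.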
|
||||
|
||||
def conditional_expression_supported_in_where_clause(self, expression):
|
||||
"""
|
||||
Oracle supports only EXISTS(...) or filters in the WHERE clause, others
|
||||
must be compared with True.
|
||||
"""
|
||||
if isinstance(expression, (Exists, Lookup, WhereNode)):
|
||||
return True
|
||||
if isinstance(expression, ExpressionWrapper) and expression.conditional:
|
||||
return self.conditional_expression_supported_in_where_clause(
|
||||
expression.expression
|
||||
)
|
||||
if isinstance(expression, RawSQL) and expression.conditional:
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,250 @@
|
||||
import copy
|
||||
import datetime
|
||||
import re
|
||||
|
||||
from django.db import DatabaseError
|
||||
from django.db.backends.base.schema import (
|
||||
BaseDatabaseSchemaEditor,
|
||||
_related_non_m2m_objects,
|
||||
)
|
||||
from django.utils.duration import duration_iso_string
|
||||
|
||||
|
||||
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
sql_create_column = "ALTER TABLE %(table)s ADD %(column)s %(definition)s"
|
||||
sql_alter_column_type = "MODIFY %(column)s %(type)s%(collation)s"
|
||||
sql_alter_column_null = "MODIFY %(column)s NULL"
|
||||
sql_alter_column_not_null = "MODIFY %(column)s NOT NULL"
|
||||
sql_alter_column_default = "MODIFY %(column)s DEFAULT %(default)s"
|
||||
sql_alter_column_no_default = "MODIFY %(column)s DEFAULT NULL"
|
||||
sql_alter_column_no_default_null = sql_alter_column_no_default
|
||||
|
||||
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
|
||||
sql_create_column_inline_fk = (
|
||||
"CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
|
||||
)
|
||||
sql_delete_table = "DROP TABLE %(table)s CASCADE CONSTRAINTS"
|
||||
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
|
||||
|
||||
def quote_value(self, value):
|
||||
if isinstance(value, (datetime.date, datetime.time, datetime.datetime)):
|
||||
return "'%s'" % value
|
||||
elif isinstance(value, datetime.timedelta):
|
||||
return "'%s'" % duration_iso_string(value)
|
||||
elif isinstance(value, str):
|
||||
return "'%s'" % value.replace("'", "''").replace("%", "%%")
|
||||
elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
return "'%s'" % value.hex()
|
||||
elif isinstance(value, bool):
|
||||
return "1" if value else "0"
|
||||
else:
|
||||
return str(value)
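# For example: quote_value("O'Brien") -> "'O''Brien'", quote_value(True)
# -> "1", and quote_value(datetime.date(2023, 1, 2)) -> "'2023-01-02'".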
|
||||
|
||||
def remove_field(self, model, field):
|
||||
# If the column is an identity column, drop the identity before
|
||||
# removing the field.
|
||||
if self._is_identity_column(model._meta.db_table, field.column):
|
||||
self._drop_identity(model._meta.db_table, field.column)
|
||||
super().remove_field(model, field)
|
||||
|
||||
def delete_model(self, model):
|
||||
# Run superclass action
|
||||
super().delete_model(model)
|
||||
# Clean up manually created sequence.
|
||||
self.execute(
|
||||
"""
|
||||
DECLARE
|
||||
i INTEGER;
|
||||
BEGIN
|
||||
SELECT COUNT(1) INTO i FROM USER_SEQUENCES
|
||||
WHERE SEQUENCE_NAME = '%(sq_name)s';
|
||||
IF i = 1 THEN
|
||||
EXECUTE IMMEDIATE 'DROP SEQUENCE "%(sq_name)s"';
|
||||
END IF;
|
||||
END;
|
||||
/"""
|
||||
% {
|
||||
"sq_name": self.connection.ops._get_no_autofield_sequence_name(
|
||||
model._meta.db_table
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def alter_field(self, model, old_field, new_field, strict=False):
|
||||
try:
|
||||
super().alter_field(model, old_field, new_field, strict)
|
||||
except DatabaseError as e:
|
||||
description = str(e)
|
||||
# If we're changing type to an unsupported type we need a
|
||||
# SQLite-ish workaround
|
||||
if "ORA-22858" in description or "ORA-22859" in description:
|
||||
self._alter_field_type_workaround(model, old_field, new_field)
|
||||
# If an identity column is changing to a non-numeric type, drop the
|
||||
# identity first.
|
||||
elif "ORA-30675" in description:
|
||||
self._drop_identity(model._meta.db_table, old_field.column)
|
||||
self.alter_field(model, old_field, new_field, strict)
|
||||
# If a primary key column is changing to an identity column, drop
|
||||
# the primary key first.
|
||||
elif "ORA-30673" in description and old_field.primary_key:
|
||||
self._delete_primary_key(model, strict=True)
|
||||
self._alter_field_type_workaround(model, old_field, new_field)
|
||||
# If a collation is changing on a primary key, drop the primary key
|
||||
# first.
|
||||
elif "ORA-43923" in description and old_field.primary_key:
|
||||
self._delete_primary_key(model, strict=True)
|
||||
self.alter_field(model, old_field, new_field, strict)
|
||||
# Restore a primary key, if needed.
|
||||
if new_field.primary_key:
|
||||
self.execute(self._create_primary_key_sql(model, new_field))
|
||||
else:
|
||||
raise
|
||||
|
||||
def _alter_field_type_workaround(self, model, old_field, new_field):
|
||||
"""
|
||||
Oracle refuses to change some column types into other types.
|
||||
What we need to do instead is:
|
||||
- Add a nullable version of the desired field with a temporary name. If
|
||||
the new column is an auto field, then the temporary column can't be
|
||||
nullable.
|
||||
- Update the table to transfer values from old to new
|
||||
- Drop old column
|
||||
- Rename the new column and possibly drop the nullable property
|
||||
"""
|
||||
# Make a new field that's like the new one but with a temporary
|
||||
# column name.
|
||||
new_temp_field = copy.deepcopy(new_field)
|
||||
new_temp_field.null = new_field.get_internal_type() not in (
|
||||
"AutoField",
|
||||
"BigAutoField",
|
||||
"SmallAutoField",
|
||||
)
|
||||
new_temp_field.column = self._generate_temp_name(new_field.column)
|
||||
# Add it
|
||||
self.add_field(model, new_temp_field)
|
||||
# Explicit data type conversion
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf
|
||||
# /Data-Type-Comparison-Rules.html#GUID-D0C5A47E-6F93-4C2D-9E49-4F2B86B359DD
|
||||
new_value = self.quote_name(old_field.column)
|
||||
old_type = old_field.db_type(self.connection)
|
||||
if re.match("^N?CLOB", old_type):
|
||||
new_value = "TO_CHAR(%s)" % new_value
|
||||
old_type = "VARCHAR2"
|
||||
if re.match("^N?VARCHAR2", old_type):
|
||||
new_internal_type = new_field.get_internal_type()
|
||||
if new_internal_type == "DateField":
|
||||
new_value = "TO_DATE(%s, 'YYYY-MM-DD')" % new_value
|
||||
elif new_internal_type == "DateTimeField":
|
||||
new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
|
||||
elif new_internal_type == "TimeField":
|
||||
                # TimeFields are stored as TIMESTAMP with a 1900-01-01 date part.
|
||||
new_value = "CONCAT('1900-01-01 ', %s)" % new_value
|
||||
new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
|
||||
# Transfer values across
|
||||
self.execute(
|
||||
"UPDATE %s set %s=%s"
|
||||
% (
|
||||
self.quote_name(model._meta.db_table),
|
||||
self.quote_name(new_temp_field.column),
|
||||
new_value,
|
||||
)
|
||||
)
|
||||
# Drop the old field
|
||||
self.remove_field(model, old_field)
|
||||
# Rename and possibly make the new field NOT NULL
|
||||
super().alter_field(model, new_temp_field, new_field)
|
||||
# Recreate foreign key (if necessary) because the old field is not
|
||||
# passed to the alter_field() and data types of new_temp_field and
|
||||
# new_field always match.
|
||||
new_type = new_field.db_type(self.connection)
|
||||
if (
|
||||
(old_field.primary_key and new_field.primary_key)
|
||||
or (old_field.unique and new_field.unique)
|
||||
) and old_type != new_type:
|
||||
for _, rel in _related_non_m2m_objects(new_temp_field, new_field):
|
||||
if rel.field.db_constraint:
|
||||
self.execute(
|
||||
self._create_fk_sql(rel.related_model, rel.field, "_fk")
|
||||
)
|
||||
|
||||
def _alter_column_type_sql(
|
||||
self, model, old_field, new_field, new_type, old_collation, new_collation
|
||||
):
|
||||
auto_field_types = {"AutoField", "BigAutoField", "SmallAutoField"}
|
||||
# Drop the identity if migrating away from AutoField.
|
||||
if (
|
||||
old_field.get_internal_type() in auto_field_types
|
||||
and new_field.get_internal_type() not in auto_field_types
|
||||
and self._is_identity_column(model._meta.db_table, new_field.column)
|
||||
):
|
||||
self._drop_identity(model._meta.db_table, new_field.column)
|
||||
return super()._alter_column_type_sql(
|
||||
model, old_field, new_field, new_type, old_collation, new_collation
|
||||
)
|
||||
|
||||
def normalize_name(self, name):
|
||||
"""
|
||||
Get the properly shortened and uppercased identifier as returned by
|
||||
quote_name() but without the quotes.
|
||||
"""
|
||||
nn = self.quote_name(name)
|
||||
if nn[0] == '"' and nn[-1] == '"':
|
||||
nn = nn[1:-1]
|
||||
return nn
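    # For example (identifier hypothetical), Oracle folds unquoted identifiers
    # to upper case, so normalize_name() returns the catalog-facing spelling
    # without the surrounding quotes:
    #
    #   editor.normalize_name("django_migrations")  # "DJANGO_MIGRATIONS"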
|
||||
|
||||
def _generate_temp_name(self, for_name):
|
||||
"""Generate temporary names for workarounds that need temp columns."""
|
||||
suffix = hex(hash(for_name)).upper()[1:]
|
||||
return self.normalize_name(for_name + "_" + suffix)
|
||||
|
||||
def prepare_default(self, value):
|
||||
return self.quote_value(value)
|
||||
|
||||
def _field_should_be_indexed(self, model, field):
|
||||
create_index = super()._field_should_be_indexed(model, field)
|
||||
db_type = field.db_type(self.connection)
|
||||
if (
|
||||
db_type is not None
|
||||
and db_type.lower() in self.connection._limited_data_types
|
||||
):
|
||||
return False
|
||||
return create_index
|
||||
|
||||
def _is_identity_column(self, table_name, column_name):
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
CASE WHEN identity_column = 'YES' THEN 1 ELSE 0 END
|
||||
FROM user_tab_cols
|
||||
WHERE table_name = %s AND
|
||||
column_name = %s
|
||||
""",
|
||||
[self.normalize_name(table_name), self.normalize_name(column_name)],
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return row[0] if row else False
|
||||
|
||||
def _drop_identity(self, table_name, column_name):
|
||||
self.execute(
|
||||
"ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY"
|
||||
% {
|
||||
"table": self.quote_name(table_name),
|
||||
"column": self.quote_name(column_name),
|
||||
}
|
||||
)
|
||||
|
||||
def _get_default_collation(self, table_name):
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT default_collation FROM user_tables WHERE table_name = %s
|
||||
""",
|
||||
[self.normalize_name(table_name)],
|
||||
)
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
def _collate_sql(self, collation, old_collation=None, table_name=None):
|
||||
if collation is None and old_collation is not None:
|
||||
collation = self._get_default_collation(table_name)
|
||||
return super()._collate_sql(collation, old_collation, table_name)
|
||||
@@ -0,0 +1,97 @@
|
||||
import datetime
|
||||
|
||||
from .base import Database
|
||||
|
||||
|
||||
class InsertVar:
|
||||
"""
|
||||
A late-binding cursor variable that can be passed to Cursor.execute
|
||||
as a parameter, in order to receive the id of the row created by an
|
||||
insert statement.
|
||||
"""
|
||||
|
||||
types = {
|
||||
"AutoField": int,
|
||||
"BigAutoField": int,
|
||||
"SmallAutoField": int,
|
||||
"IntegerField": int,
|
||||
"BigIntegerField": int,
|
||||
"SmallIntegerField": int,
|
||||
"PositiveBigIntegerField": int,
|
||||
"PositiveSmallIntegerField": int,
|
||||
"PositiveIntegerField": int,
|
||||
"FloatField": Database.NATIVE_FLOAT,
|
||||
"DateTimeField": Database.TIMESTAMP,
|
||||
"DateField": Database.Date,
|
||||
"DecimalField": Database.NUMBER,
|
||||
}
|
||||
|
||||
def __init__(self, field):
|
||||
internal_type = getattr(field, "target_field", field).get_internal_type()
|
||||
self.db_type = self.types.get(internal_type, str)
|
||||
self.bound_param = None
|
||||
|
||||
def bind_parameter(self, cursor):
|
||||
self.bound_param = cursor.cursor.var(self.db_type)
|
||||
return self.bound_param
|
||||
|
||||
def get_value(self):
|
||||
return self.bound_param.getvalue()
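    # Rough usage sketch (not part of this module): the variable is bound to
    # the driver cursor before execute() and read back afterwards to recover
    # the generated primary key. "cursor" is Django's wrapped cursor (it
    # exposes the raw cx_Oracle/oracledb cursor as cursor.cursor); the model,
    # table, and SQL below are hypothetical.
    #
    #   iv = InsertVar(SomeModel._meta.pk)
    #   bound = iv.bind_parameter(cursor)
    #   cursor.execute(
    #       'INSERT INTO "APP_SOMEMODEL" ("NAME") VALUES (:1) '
    #       'RETURNING "ID" INTO :2',
    #       ["example", bound],
    #   )
    #   new_pk = iv.get_value()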
|
||||
|
||||
|
||||
class Oracle_datetime(datetime.datetime):
|
||||
"""
|
||||
A datetime object, with an additional class attribute
|
||||
to tell cx_Oracle to save the microseconds too.
|
||||
"""
|
||||
|
||||
input_size = Database.TIMESTAMP
|
||||
|
||||
@classmethod
|
||||
def from_datetime(cls, dt):
|
||||
return Oracle_datetime(
|
||||
dt.year,
|
||||
dt.month,
|
||||
dt.day,
|
||||
dt.hour,
|
||||
dt.minute,
|
||||
dt.second,
|
||||
dt.microsecond,
|
||||
)
|
||||
|
||||
|
||||
class BulkInsertMapper:
|
||||
BLOB = "TO_BLOB(%s)"
|
||||
DATE = "TO_DATE(%s)"
|
||||
INTERVAL = "CAST(%s as INTERVAL DAY(9) TO SECOND(6))"
|
||||
NCLOB = "TO_NCLOB(%s)"
|
||||
NUMBER = "TO_NUMBER(%s)"
|
||||
TIMESTAMP = "TO_TIMESTAMP(%s)"
|
||||
|
||||
types = {
|
||||
"AutoField": NUMBER,
|
||||
"BigAutoField": NUMBER,
|
||||
"BigIntegerField": NUMBER,
|
||||
"BinaryField": BLOB,
|
||||
"BooleanField": NUMBER,
|
||||
"DateField": DATE,
|
||||
"DateTimeField": TIMESTAMP,
|
||||
"DecimalField": NUMBER,
|
||||
"DurationField": INTERVAL,
|
||||
"FloatField": NUMBER,
|
||||
"IntegerField": NUMBER,
|
||||
"PositiveBigIntegerField": NUMBER,
|
||||
"PositiveIntegerField": NUMBER,
|
||||
"PositiveSmallIntegerField": NUMBER,
|
||||
"SmallAutoField": NUMBER,
|
||||
"SmallIntegerField": NUMBER,
|
||||
"TextField": NCLOB,
|
||||
"TimeField": TIMESTAMP,
|
||||
}
|
||||
|
||||
|
||||
def dsn(settings_dict):
|
||||
if settings_dict["PORT"]:
|
||||
host = settings_dict["HOST"].strip() or "localhost"
|
||||
return Database.makedsn(host, int(settings_dict["PORT"]), settings_dict["NAME"])
|
||||
return settings_dict["NAME"]
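# Illustrative sketch (values hypothetical): dsn() either builds a full DSN via
# the driver's makedsn() when a PORT is configured, or passes NAME through
# untouched (e.g. an Easy Connect string or a tnsnames alias).
#
#   dsn({"HOST": "db.example.com", "PORT": "1521", "NAME": "XEPDB1"})
#   # -> Database.makedsn("db.example.com", 1521, "XEPDB1")
#   dsn({"HOST": "", "PORT": "", "NAME": "localhost/xepdb1"})
#   # -> "localhost/xepdb1"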
@@ -0,0 +1,22 @@
from django.core import checks
from django.db.backends.base.validation import BaseDatabaseValidation


class DatabaseValidation(BaseDatabaseValidation):
    def check_field_type(self, field, field_type):
        """Oracle doesn't support a database index on some data types."""
        errors = []
        if field.db_index and field_type.lower() in self.connection._limited_data_types:
            errors.append(
                checks.Warning(
                    "Oracle does not support a database index on %s columns."
                    % field_type,
                    hint=(
                        "An index won't be created. Silence this warning if "
                        "you don't care about it."
                    ),
                    obj=field,
                    id="fields.W162",
                )
            )
        return errors
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
PostgreSQL database backend for Django.
|
||||
|
||||
Requires psycopg2 >= 2.8.4 or psycopg >= 3.1.8
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import DatabaseError as WrappedDatabaseError
|
||||
from django.db import connections
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
|
||||
from django.utils.asyncio import async_unsafe
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.safestring import SafeString
|
||||
from django.utils.version import get_version_tuple
|
||||
|
||||
try:
|
||||
try:
|
||||
import psycopg as Database
|
||||
except ImportError:
|
||||
import psycopg2 as Database
|
||||
except ImportError:
|
||||
raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")
|
||||
|
||||
|
||||
def psycopg_version():
|
||||
version = Database.__version__.split(" ", 1)[0]
|
||||
return get_version_tuple(version)
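# For example (version strings hypothetical): psycopg2 exposes something like
# "2.9.6 (dt dec pq3 ext lo64)" while psycopg 3 exposes "3.1.12"; both reduce
# to a comparable tuple for the checks below.
#
#   get_version_tuple("2.9.6 (dt dec pq3 ext lo64)".split(" ", 1)[0])  # (2, 9, 6)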
|
||||
|
||||
|
||||
if psycopg_version() < (2, 8, 4):
|
||||
raise ImproperlyConfigured(
|
||||
f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}"
|
||||
)
|
||||
if (3,) <= psycopg_version() < (3, 1, 8):
|
||||
raise ImproperlyConfigured(
|
||||
f"psycopg version 3.1.8 or newer is required; you have {Database.__version__}"
|
||||
)
|
||||
|
||||
|
||||
from .psycopg_any import IsolationLevel, is_psycopg3 # NOQA isort:skip
|
||||
|
||||
if is_psycopg3:
|
||||
from psycopg import adapters, sql
|
||||
from psycopg.pq import Format
|
||||
|
||||
from .psycopg_any import get_adapters_template, register_tzloader
|
||||
|
||||
TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
|
||||
|
||||
else:
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
|
||||
psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
|
||||
psycopg2.extras.register_uuid()
|
||||
|
||||
# Register support for inet[] manually so we don't have to handle the Inet()
|
||||
# object on load all the time.
|
||||
INETARRAY_OID = 1041
|
||||
INETARRAY = psycopg2.extensions.new_array_type(
|
||||
(INETARRAY_OID,),
|
||||
"INETARRAY",
|
||||
psycopg2.extensions.UNICODE,
|
||||
)
|
||||
psycopg2.extensions.register_type(INETARRAY)
|
||||
|
||||
# Some of these import psycopg, so import them after checking if it's installed.
|
||||
from .client import DatabaseClient # NOQA isort:skip
|
||||
from .creation import DatabaseCreation # NOQA isort:skip
|
||||
from .features import DatabaseFeatures # NOQA isort:skip
|
||||
from .introspection import DatabaseIntrospection # NOQA isort:skip
|
||||
from .operations import DatabaseOperations # NOQA isort:skip
|
||||
from .schema import DatabaseSchemaEditor # NOQA isort:skip
|
||||
|
||||
|
||||
def _get_varchar_column(data):
|
||||
if data["max_length"] is None:
|
||||
return "varchar"
|
||||
return "varchar(%(max_length)s)" % data
|
||||
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
vendor = "postgresql"
|
||||
display_name = "PostgreSQL"
|
||||
# This dictionary maps Field objects to their associated PostgreSQL column
|
||||
# types, as strings. Column-type strings can contain format strings; they'll
|
||||
# be interpolated against the values of Field.__dict__ before being output.
|
||||
# If a column type is set to None, it won't be included in the output.
|
||||
data_types = {
|
||||
"AutoField": "integer",
|
||||
"BigAutoField": "bigint",
|
||||
"BinaryField": "bytea",
|
||||
"BooleanField": "boolean",
|
||||
"CharField": _get_varchar_column,
|
||||
"DateField": "date",
|
||||
"DateTimeField": "timestamp with time zone",
|
||||
"DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
|
||||
"DurationField": "interval",
|
||||
"FileField": "varchar(%(max_length)s)",
|
||||
"FilePathField": "varchar(%(max_length)s)",
|
||||
"FloatField": "double precision",
|
||||
"IntegerField": "integer",
|
||||
"BigIntegerField": "bigint",
|
||||
"IPAddressField": "inet",
|
||||
"GenericIPAddressField": "inet",
|
||||
"JSONField": "jsonb",
|
||||
"OneToOneField": "integer",
|
||||
"PositiveBigIntegerField": "bigint",
|
||||
"PositiveIntegerField": "integer",
|
||||
"PositiveSmallIntegerField": "smallint",
|
||||
"SlugField": "varchar(%(max_length)s)",
|
||||
"SmallAutoField": "smallint",
|
||||
"SmallIntegerField": "smallint",
|
||||
"TextField": "text",
|
||||
"TimeField": "time",
|
||||
"UUIDField": "uuid",
|
||||
}
|
||||
data_type_check_constraints = {
|
||||
"PositiveBigIntegerField": '"%(column)s" >= 0',
|
||||
"PositiveIntegerField": '"%(column)s" >= 0',
|
||||
"PositiveSmallIntegerField": '"%(column)s" >= 0',
|
||||
}
|
||||
data_types_suffix = {
|
||||
"AutoField": "GENERATED BY DEFAULT AS IDENTITY",
|
||||
"BigAutoField": "GENERATED BY DEFAULT AS IDENTITY",
|
||||
"SmallAutoField": "GENERATED BY DEFAULT AS IDENTITY",
|
||||
}
|
||||
operators = {
|
||||
"exact": "= %s",
|
||||
"iexact": "= UPPER(%s)",
|
||||
"contains": "LIKE %s",
|
||||
"icontains": "LIKE UPPER(%s)",
|
||||
"regex": "~ %s",
|
||||
"iregex": "~* %s",
|
||||
"gt": "> %s",
|
||||
"gte": ">= %s",
|
||||
"lt": "< %s",
|
||||
"lte": "<= %s",
|
||||
"startswith": "LIKE %s",
|
||||
"endswith": "LIKE %s",
|
||||
"istartswith": "LIKE UPPER(%s)",
|
||||
"iendswith": "LIKE UPPER(%s)",
|
||||
}
|
||||
|
||||
# The patterns below are used to generate SQL pattern lookup clauses when
|
||||
# the right-hand side of the lookup isn't a raw string (it might be an expression
|
||||
# or the result of a bilateral transformation).
|
||||
    # In those cases, special characters for LIKE operators (e.g. \, %, _) should be
|
||||
# escaped on database side.
|
||||
#
|
||||
# Note: we use str.format() here for readability as '%' is used as a wildcard for
|
||||
# the LIKE operator.
|
||||
pattern_esc = (
|
||||
r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
|
||||
)
|
||||
pattern_ops = {
|
||||
"contains": "LIKE '%%' || {} || '%%'",
|
||||
"icontains": "LIKE '%%' || UPPER({}) || '%%'",
|
||||
"startswith": "LIKE {} || '%%'",
|
||||
"istartswith": "LIKE UPPER({}) || '%%'",
|
||||
"endswith": "LIKE '%%' || {}",
|
||||
"iendswith": "LIKE '%%' || UPPER({})",
|
||||
}
|
||||
|
||||
Database = Database
|
||||
SchemaEditorClass = DatabaseSchemaEditor
|
||||
# Classes instantiated in __init__().
|
||||
client_class = DatabaseClient
|
||||
creation_class = DatabaseCreation
|
||||
features_class = DatabaseFeatures
|
||||
introspection_class = DatabaseIntrospection
|
||||
ops_class = DatabaseOperations
|
||||
# PostgreSQL backend-specific attributes.
|
||||
_named_cursor_idx = 0
|
||||
|
||||
def get_database_version(self):
|
||||
"""
|
||||
Return a tuple of the database's version.
|
||||
E.g. for pg_version 120004, return (12, 4).
|
||||
"""
|
||||
return divmod(self.pg_version, 10000)
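    # For instance, a reported pg_version of 150002 (PostgreSQL 15.2) splits as:
    #
    #   divmod(150002, 10000)  # (15, 2)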
|
||||
|
||||
def get_connection_params(self):
|
||||
settings_dict = self.settings_dict
|
||||
# None may be used to connect to the default 'postgres' db
|
||||
if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get(
|
||||
"service"
|
||||
):
|
||||
raise ImproperlyConfigured(
|
||||
"settings.DATABASES is improperly configured. "
|
||||
"Please supply the NAME or OPTIONS['service'] value."
|
||||
)
|
||||
if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
|
||||
raise ImproperlyConfigured(
|
||||
"The database name '%s' (%d characters) is longer than "
|
||||
"PostgreSQL's limit of %d characters. Supply a shorter NAME "
|
||||
"in settings.DATABASES."
|
||||
% (
|
||||
settings_dict["NAME"],
|
||||
len(settings_dict["NAME"]),
|
||||
self.ops.max_name_length(),
|
||||
)
|
||||
)
|
||||
if settings_dict["NAME"]:
|
||||
conn_params = {
|
||||
"dbname": settings_dict["NAME"],
|
||||
**settings_dict["OPTIONS"],
|
||||
}
|
||||
elif settings_dict["NAME"] is None:
|
||||
# Connect to the default 'postgres' db.
|
||||
settings_dict.get("OPTIONS", {}).pop("service", None)
|
||||
conn_params = {"dbname": "postgres", **settings_dict["OPTIONS"]}
|
||||
else:
|
||||
conn_params = {**settings_dict["OPTIONS"]}
|
||||
conn_params["client_encoding"] = "UTF8"
|
||||
|
||||
conn_params.pop("assume_role", None)
|
||||
conn_params.pop("isolation_level", None)
|
||||
server_side_binding = conn_params.pop("server_side_binding", None)
|
||||
conn_params.setdefault(
|
||||
"cursor_factory",
|
||||
ServerBindingCursor
|
||||
if is_psycopg3 and server_side_binding is True
|
||||
else Cursor,
|
||||
)
|
||||
if settings_dict["USER"]:
|
||||
conn_params["user"] = settings_dict["USER"]
|
||||
if settings_dict["PASSWORD"]:
|
||||
conn_params["password"] = settings_dict["PASSWORD"]
|
||||
if settings_dict["HOST"]:
|
||||
conn_params["host"] = settings_dict["HOST"]
|
||||
if settings_dict["PORT"]:
|
||||
conn_params["port"] = settings_dict["PORT"]
|
||||
if is_psycopg3:
|
||||
conn_params["context"] = get_adapters_template(
|
||||
settings.USE_TZ, self.timezone
|
||||
)
|
||||
# Disable prepared statements by default to keep connection poolers
|
||||
# working. Can be reenabled via OPTIONS in the settings dict.
|
||||
conn_params["prepare_threshold"] = conn_params.pop(
|
||||
"prepare_threshold", None
|
||||
)
|
||||
return conn_params
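    # Sketch of the resulting keyword arguments for a minimal settings dict
    # (values hypothetical; keys such as "context" and "prepare_threshold"
    # only appear under psycopg 3):
    #
    #   DATABASES["default"] = {"ENGINE": "django.db.backends.postgresql",
    #                           "NAME": "app", "USER": "app",
    #                           "PASSWORD": "secret", "HOST": "127.0.0.1",
    #                           "PORT": "5432", "OPTIONS": {}}
    #   # get_connection_params() then returns roughly:
    #   # {"dbname": "app", "client_encoding": "UTF8",
    #   #  "cursor_factory": Cursor, "user": "app", "password": "secret",
    #   #  "host": "127.0.0.1", "port": "5432", ...}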
|
||||
|
||||
@async_unsafe
|
||||
def get_new_connection(self, conn_params):
|
||||
# self.isolation_level must be set:
|
||||
# - after connecting to the database in order to obtain the database's
|
||||
# default when no value is explicitly specified in options.
|
||||
# - before calling _set_autocommit() because if autocommit is on, that
|
||||
# will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
|
||||
options = self.settings_dict["OPTIONS"]
|
||||
set_isolation_level = False
|
||||
try:
|
||||
isolation_level_value = options["isolation_level"]
|
||||
except KeyError:
|
||||
self.isolation_level = IsolationLevel.READ_COMMITTED
|
||||
else:
|
||||
# Set the isolation level to the value from OPTIONS.
|
||||
try:
|
||||
self.isolation_level = IsolationLevel(isolation_level_value)
|
||||
set_isolation_level = True
|
||||
except ValueError:
|
||||
raise ImproperlyConfigured(
|
||||
f"Invalid transaction isolation level {isolation_level_value} "
|
||||
f"specified. Use one of the psycopg.IsolationLevel values."
|
||||
)
|
||||
connection = self.Database.connect(**conn_params)
|
||||
if set_isolation_level:
|
||||
connection.isolation_level = self.isolation_level
|
||||
if not is_psycopg3:
|
||||
# Register dummy loads() to avoid a round trip from psycopg2's
|
||||
# decode to json.dumps() to json.loads(), when using a custom
|
||||
# decoder in JSONField.
|
||||
psycopg2.extras.register_default_jsonb(
|
||||
conn_or_curs=connection, loads=lambda x: x
|
||||
)
|
||||
return connection
|
||||
|
||||
def ensure_timezone(self):
|
||||
if self.connection is None:
|
||||
return False
|
||||
conn_timezone_name = self.connection.info.parameter_status("TimeZone")
|
||||
timezone_name = self.timezone_name
|
||||
if timezone_name and conn_timezone_name != timezone_name:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
|
||||
return True
|
||||
return False
|
||||
|
||||
def ensure_role(self):
|
||||
if self.connection is None:
|
||||
return False
|
||||
if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
|
||||
with self.connection.cursor() as cursor:
|
||||
sql = self.ops.compose_sql("SET ROLE %s", [new_role])
|
||||
cursor.execute(sql)
|
||||
return True
|
||||
return False
|
||||
|
||||
def init_connection_state(self):
|
||||
super().init_connection_state()
|
||||
|
||||
# Commit after setting the time zone.
|
||||
commit_tz = self.ensure_timezone()
|
||||
        # Set the role on the connection. This is useful if the credential
        # used to log in is not the same as the role that owns database
        # resources, as can be the case when using temporary or ephemeral
        # credentials.
|
||||
commit_role = self.ensure_role()
|
||||
|
||||
if (commit_role or commit_tz) and not self.get_autocommit():
|
||||
self.connection.commit()
|
||||
|
||||
@async_unsafe
|
||||
def create_cursor(self, name=None):
|
||||
if name:
|
||||
# In autocommit mode, the cursor will be used outside of a
|
||||
# transaction, hence use a holdable cursor.
|
||||
cursor = self.connection.cursor(
|
||||
name, scrollable=False, withhold=self.connection.autocommit
|
||||
)
|
||||
else:
|
||||
cursor = self.connection.cursor()
|
||||
|
||||
if is_psycopg3:
|
||||
# Register the cursor timezone only if the connection disagrees, to
|
||||
# avoid copying the adapter map.
|
||||
tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
|
||||
if self.timezone != tzloader.timezone:
|
||||
register_tzloader(self.timezone, cursor)
|
||||
else:
|
||||
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
|
||||
return cursor
|
||||
|
||||
def tzinfo_factory(self, offset):
|
||||
return self.timezone
|
||||
|
||||
@async_unsafe
|
||||
def chunked_cursor(self):
|
||||
self._named_cursor_idx += 1
|
||||
# Get the current async task
|
||||
# Note that right now this is behind @async_unsafe, so this is
|
||||
        # unreachable, but in the future we'll start loosening this restriction.
|
||||
# For now, it's here so that every use of "threading" is
|
||||
# also async-compatible.
|
||||
try:
|
||||
current_task = asyncio.current_task()
|
||||
except RuntimeError:
|
||||
current_task = None
|
||||
# Current task can be none even if the current_task call didn't error
|
||||
if current_task:
|
||||
task_ident = str(id(current_task))
|
||||
else:
|
||||
task_ident = "sync"
|
||||
# Use that and the thread ident to get a unique name
|
||||
return self._cursor(
|
||||
name="_django_curs_%d_%s_%d"
|
||||
% (
|
||||
# Avoid reusing name in other threads / tasks
|
||||
threading.current_thread().ident,
|
||||
task_ident,
|
||||
self._named_cursor_idx,
|
||||
)
|
||||
)
|
||||
|
||||
def _set_autocommit(self, autocommit):
|
||||
with self.wrap_database_errors:
|
||||
self.connection.autocommit = autocommit
|
||||
|
||||
def check_constraints(self, table_names=None):
|
||||
"""
|
||||
Check constraints by setting them to immediate. Return them to deferred
|
||||
afterward.
|
||||
"""
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
|
||||
cursor.execute("SET CONSTRAINTS ALL DEFERRED")
|
||||
|
||||
def is_usable(self):
|
||||
try:
|
||||
# Use a psycopg cursor directly, bypassing Django's utilities.
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute("SELECT 1")
|
||||
except Database.Error:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@contextmanager
|
||||
def _nodb_cursor(self):
|
||||
cursor = None
|
||||
try:
|
||||
with super()._nodb_cursor() as cursor:
|
||||
yield cursor
|
||||
except (Database.DatabaseError, WrappedDatabaseError):
|
||||
if cursor is not None:
|
||||
raise
|
||||
warnings.warn(
|
||||
"Normally Django will use a connection to the 'postgres' database "
|
||||
"to avoid running initialization queries against the production "
|
||||
"database when it's not needed (for example, when running tests). "
|
||||
"Django was unable to create a connection to the 'postgres' database "
|
||||
"and will use the first PostgreSQL database instead.",
|
||||
RuntimeWarning,
|
||||
)
|
||||
for connection in connections.all():
|
||||
if (
|
||||
connection.vendor == "postgresql"
|
||||
and connection.settings_dict["NAME"] != "postgres"
|
||||
):
|
||||
conn = self.__class__(
|
||||
{
|
||||
**self.settings_dict,
|
||||
"NAME": connection.settings_dict["NAME"],
|
||||
},
|
||||
alias=self.alias,
|
||||
)
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
yield cursor
|
||||
finally:
|
||||
conn.close()
|
||||
break
|
||||
else:
|
||||
raise
|
||||
|
||||
@cached_property
|
||||
def pg_version(self):
|
||||
with self.temporary_connection():
|
||||
return self.connection.info.server_version
|
||||
|
||||
def make_debug_cursor(self, cursor):
|
||||
return CursorDebugWrapper(cursor, self)
|
||||
|
||||
|
||||
if is_psycopg3:
|
||||
|
||||
class CursorMixin:
|
||||
"""
|
||||
A subclass of psycopg cursor implementing callproc.
|
||||
"""
|
||||
|
||||
def callproc(self, name, args=None):
|
||||
if not isinstance(name, sql.Identifier):
|
||||
name = sql.Identifier(name)
|
||||
|
||||
qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
|
||||
if args:
|
||||
for item in args:
|
||||
qparts.append(sql.Literal(item))
|
||||
qparts.append(sql.SQL(","))
|
||||
del qparts[-1]
|
||||
|
||||
qparts.append(sql.SQL(")"))
|
||||
stmt = sql.Composed(qparts)
|
||||
self.execute(stmt)
|
||||
return args
|
||||
|
||||
class ServerBindingCursor(CursorMixin, Database.Cursor):
|
||||
pass
|
||||
|
||||
class Cursor(CursorMixin, Database.ClientCursor):
|
||||
pass
|
||||
|
||||
class CursorDebugWrapper(BaseCursorDebugWrapper):
|
||||
def copy(self, statement):
|
||||
with self.debug_sql(statement):
|
||||
return self.cursor.copy(statement)
|
||||
|
||||
else:
|
||||
Cursor = psycopg2.extensions.cursor
|
||||
|
||||
class CursorDebugWrapper(BaseCursorDebugWrapper):
|
||||
def copy_expert(self, sql, file, *args):
|
||||
with self.debug_sql(sql):
|
||||
return self.cursor.copy_expert(sql, file, *args)
|
||||
|
||||
def copy_to(self, file, table, *args, **kwargs):
|
||||
with self.debug_sql(sql="COPY %s TO STDOUT" % table):
|
||||
return self.cursor.copy_to(file, table, *args, **kwargs)
@@ -0,0 +1,64 @@
import signal

from django.db.backends.base.client import BaseDatabaseClient


class DatabaseClient(BaseDatabaseClient):
    executable_name = "psql"

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        args = [cls.executable_name]
        options = settings_dict.get("OPTIONS", {})

        host = settings_dict.get("HOST")
        port = settings_dict.get("PORT")
        dbname = settings_dict.get("NAME")
        user = settings_dict.get("USER")
        passwd = settings_dict.get("PASSWORD")
        passfile = options.get("passfile")
        service = options.get("service")
        sslmode = options.get("sslmode")
        sslrootcert = options.get("sslrootcert")
        sslcert = options.get("sslcert")
        sslkey = options.get("sslkey")

        if not dbname and not service:
            # Connect to the default 'postgres' db.
            dbname = "postgres"
        if user:
            args += ["-U", user]
        if host:
            args += ["-h", host]
        if port:
            args += ["-p", str(port)]
        args.extend(parameters)
        if dbname:
            args += [dbname]

        env = {}
        if passwd:
            env["PGPASSWORD"] = str(passwd)
        if service:
            env["PGSERVICE"] = str(service)
        if sslmode:
            env["PGSSLMODE"] = str(sslmode)
        if sslrootcert:
            env["PGSSLROOTCERT"] = str(sslrootcert)
        if sslcert:
            env["PGSSLCERT"] = str(sslcert)
        if sslkey:
            env["PGSSLKEY"] = str(sslkey)
        if passfile:
            env["PGPASSFILE"] = str(passfile)
        return args, (env or None)

    def runshell(self, parameters):
        sigint_handler = signal.getsignal(signal.SIGINT)
        try:
            # Allow SIGINT to pass to psql to abort queries.
            signal.signal(signal.SIGINT, signal.SIG_IGN)
            super().runshell(parameters)
        finally:
            # Restore the original SIGINT handler.
            signal.signal(signal.SIGINT, sigint_handler)
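# Illustrative call (settings hypothetical): settings_to_cmd_args_env() turns a
# DATABASES entry plus extra CLI parameters into the psql argv and environment.
#
#   args, env = DatabaseClient.settings_to_cmd_args_env(
#       {"NAME": "app", "USER": "app", "PASSWORD": "secret",
#        "HOST": "127.0.0.1", "PORT": "5432", "OPTIONS": {}},
#       ["--quiet"],
#   )
#   # args == ["psql", "-U", "app", "-h", "127.0.0.1", "-p", "5432",
#   #          "--quiet", "app"]
#   # env == {"PGPASSWORD": "secret"}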
@@ -0,0 +1,86 @@
|
||||
import sys
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
from django.db.backends.postgresql.psycopg_any import errors
|
||||
from django.db.backends.utils import strip_quotes
|
||||
|
||||
|
||||
class DatabaseCreation(BaseDatabaseCreation):
|
||||
def _quote_name(self, name):
|
||||
return self.connection.ops.quote_name(name)
|
||||
|
||||
def _get_database_create_suffix(self, encoding=None, template=None):
|
||||
suffix = ""
|
||||
if encoding:
|
||||
suffix += " ENCODING '{}'".format(encoding)
|
||||
if template:
|
||||
suffix += " TEMPLATE {}".format(self._quote_name(template))
|
||||
return suffix and "WITH" + suffix
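    # For example (values hypothetical; "creation" is a DatabaseCreation
    # instance for a configured connection):
    #
    #   creation._get_database_create_suffix(encoding="UTF8", template="template0")
    #   # -> WITH ENCODING 'UTF8' TEMPLATE "template0"
    #   creation._get_database_create_suffix()
    #   # -> ""  (empty string when neither option is given)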
|
||||
|
||||
def sql_table_creation_suffix(self):
|
||||
test_settings = self.connection.settings_dict["TEST"]
|
||||
if test_settings.get("COLLATION") is not None:
|
||||
raise ImproperlyConfigured(
|
||||
"PostgreSQL does not support collation setting at database "
|
||||
"creation time."
|
||||
)
|
||||
return self._get_database_create_suffix(
|
||||
encoding=test_settings["CHARSET"],
|
||||
template=test_settings.get("TEMPLATE"),
|
||||
)
|
||||
|
||||
def _database_exists(self, cursor, database_name):
|
||||
cursor.execute(
|
||||
"SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s",
|
||||
[strip_quotes(database_name)],
|
||||
)
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
|
||||
try:
|
||||
if keepdb and self._database_exists(cursor, parameters["dbname"]):
|
||||
# If the database should be kept and it already exists, don't
|
||||
# try to create a new one.
|
||||
return
|
||||
super()._execute_create_test_db(cursor, parameters, keepdb)
|
||||
except Exception as e:
|
||||
if not isinstance(e.__cause__, errors.DuplicateDatabase):
|
||||
# All errors except "database already exists" cancel tests.
|
||||
self.log("Got an error creating the test database: %s" % e)
|
||||
sys.exit(2)
|
||||
elif not keepdb:
|
||||
# If the database should be kept, ignore "database already
|
||||
# exists".
|
||||
raise
|
||||
|
||||
def _clone_test_db(self, suffix, verbosity, keepdb=False):
|
||||
# CREATE DATABASE ... WITH TEMPLATE ... requires closing connections
|
||||
# to the template database.
|
||||
self.connection.close()
|
||||
|
||||
source_database_name = self.connection.settings_dict["NAME"]
|
||||
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
|
||||
test_db_params = {
|
||||
"dbname": self._quote_name(target_database_name),
|
||||
"suffix": self._get_database_create_suffix(template=source_database_name),
|
||||
}
|
||||
with self._nodb_cursor() as cursor:
|
||||
try:
|
||||
self._execute_create_test_db(cursor, test_db_params, keepdb)
|
||||
except Exception:
|
||||
try:
|
||||
if verbosity >= 1:
|
||||
self.log(
|
||||
"Destroying old test database for alias %s..."
|
||||
% (
|
||||
self._get_database_display_str(
|
||||
verbosity, target_database_name
|
||||
),
|
||||
)
|
||||
)
|
||||
cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
|
||||
self._execute_create_test_db(cursor, test_db_params, keepdb)
|
||||
except Exception as e:
|
||||
self.log("Got an error cloning the test database: %s" % e)
|
||||
sys.exit(2)
|
||||
@@ -0,0 +1,136 @@
|
||||
import operator
|
||||
|
||||
from django.db import DataError, InterfaceError
|
||||
from django.db.backends.base.features import BaseDatabaseFeatures
|
||||
from django.db.backends.postgresql.psycopg_any import is_psycopg3
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
minimum_database_version = (12,)
|
||||
allows_group_by_selected_pks = True
|
||||
can_return_columns_from_insert = True
|
||||
can_return_rows_from_bulk_insert = True
|
||||
has_real_datatype = True
|
||||
has_native_uuid_field = True
|
||||
has_native_duration_field = True
|
||||
has_native_json_field = True
|
||||
can_defer_constraint_checks = True
|
||||
has_select_for_update = True
|
||||
has_select_for_update_nowait = True
|
||||
has_select_for_update_of = True
|
||||
has_select_for_update_skip_locked = True
|
||||
has_select_for_no_key_update = True
|
||||
can_release_savepoints = True
|
||||
supports_comments = True
|
||||
supports_tablespaces = True
|
||||
supports_transactions = True
|
||||
can_introspect_materialized_views = True
|
||||
can_distinct_on_fields = True
|
||||
can_rollback_ddl = True
|
||||
schema_editor_uses_clientside_param_binding = True
|
||||
supports_combined_alters = True
|
||||
nulls_order_largest = True
|
||||
closed_cursor_error_class = InterfaceError
|
||||
greatest_least_ignores_nulls = True
|
||||
can_clone_databases = True
|
||||
supports_temporal_subtraction = True
|
||||
supports_slicing_ordering_in_compound = True
|
||||
create_test_procedure_without_params_sql = """
|
||||
CREATE FUNCTION test_procedure () RETURNS void AS $$
|
||||
DECLARE
|
||||
V_I INTEGER;
|
||||
BEGIN
|
||||
V_I := 1;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;"""
|
||||
create_test_procedure_with_int_param_sql = """
|
||||
CREATE FUNCTION test_procedure (P_I INTEGER) RETURNS void AS $$
|
||||
DECLARE
|
||||
V_I INTEGER;
|
||||
BEGIN
|
||||
V_I := P_I;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;"""
|
||||
create_test_table_with_composite_primary_key = """
|
||||
CREATE TABLE test_table_composite_pk (
|
||||
column_1 INTEGER NOT NULL,
|
||||
column_2 INTEGER NOT NULL,
|
||||
PRIMARY KEY(column_1, column_2)
|
||||
)
|
||||
"""
|
||||
requires_casted_case_in_updates = True
|
||||
supports_over_clause = True
|
||||
only_supports_unbounded_with_preceding_and_following = True
|
||||
supports_aggregate_filter_clause = True
|
||||
supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
|
||||
supports_deferrable_unique_constraints = True
|
||||
has_json_operators = True
|
||||
json_key_contains_list_matching_requires_list = True
|
||||
supports_update_conflicts = True
|
||||
supports_update_conflicts_with_target = True
|
||||
supports_covering_indexes = True
|
||||
can_rename_index = True
|
||||
test_collations = {
|
||||
"non_default": "sv-x-icu",
|
||||
"swedish_ci": "sv-x-icu",
|
||||
}
|
||||
test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
|
||||
|
||||
django_test_skips = {
|
||||
"opclasses are PostgreSQL only.": {
|
||||
"indexes.tests.SchemaIndexesNotPostgreSQLTests."
|
||||
"test_create_index_ignores_opclasses",
|
||||
},
|
||||
"PostgreSQL requires casting to text.": {
|
||||
"lookup.tests.LookupTests.test_textfield_exact_null",
|
||||
},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def django_test_expected_failures(self):
|
||||
expected_failures = set()
|
||||
if self.uses_server_side_binding:
|
||||
expected_failures.update(
|
||||
{
|
||||
# Parameters passed to expressions in SELECT and GROUP BY
|
||||
# clauses are not recognized as the same values when using
|
||||
# server-side binding cursors (#34255).
|
||||
"aggregation.tests.AggregateTestCase."
|
||||
"test_group_by_nested_expression_with_params",
|
||||
}
|
||||
)
|
||||
return expected_failures
|
||||
|
||||
@cached_property
|
||||
def uses_server_side_binding(self):
|
||||
options = self.connection.settings_dict["OPTIONS"]
|
||||
return is_psycopg3 and options.get("server_side_binding") is True
|
||||
|
||||
@cached_property
|
||||
def prohibits_null_characters_in_text_exception(self):
|
||||
if is_psycopg3:
|
||||
return DataError, "PostgreSQL text fields cannot contain NUL (0x00) bytes"
|
||||
else:
|
||||
return ValueError, "A string literal cannot contain NUL (0x00) characters."
|
||||
|
||||
@cached_property
|
||||
def introspected_field_types(self):
|
||||
return {
|
||||
**super().introspected_field_types,
|
||||
"PositiveBigIntegerField": "BigIntegerField",
|
||||
"PositiveIntegerField": "IntegerField",
|
||||
"PositiveSmallIntegerField": "SmallIntegerField",
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def is_postgresql_13(self):
|
||||
return self.connection.pg_version >= 130000
|
||||
|
||||
@cached_property
|
||||
def is_postgresql_14(self):
|
||||
return self.connection.pg_version >= 140000
|
||||
|
||||
has_bit_xor = property(operator.attrgetter("is_postgresql_14"))
|
||||
supports_covering_spgist_indexes = property(operator.attrgetter("is_postgresql_14"))
|
||||
supports_unlimited_charfield = True
|
||||
@@ -0,0 +1,299 @@
|
||||
from collections import namedtuple
|
||||
|
||||
from django.db.backends.base.introspection import BaseDatabaseIntrospection
|
||||
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
|
||||
from django.db.backends.base.introspection import TableInfo as BaseTableInfo
|
||||
from django.db.models import Index
|
||||
|
||||
FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("is_autofield", "comment"))
|
||||
TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",))
|
||||
|
||||
|
||||
class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||
# Maps type codes to Django Field types.
|
||||
data_types_reverse = {
|
||||
16: "BooleanField",
|
||||
17: "BinaryField",
|
||||
20: "BigIntegerField",
|
||||
21: "SmallIntegerField",
|
||||
23: "IntegerField",
|
||||
25: "TextField",
|
||||
700: "FloatField",
|
||||
701: "FloatField",
|
||||
869: "GenericIPAddressField",
|
||||
1042: "CharField", # blank-padded
|
||||
1043: "CharField",
|
||||
1082: "DateField",
|
||||
1083: "TimeField",
|
||||
1114: "DateTimeField",
|
||||
1184: "DateTimeField",
|
||||
1186: "DurationField",
|
||||
1266: "TimeField",
|
||||
1700: "DecimalField",
|
||||
2950: "UUIDField",
|
||||
3802: "JSONField",
|
||||
}
|
||||
# A hook for subclasses.
|
||||
index_default_access_method = "btree"
|
||||
|
||||
ignored_tables = []
|
||||
|
||||
def get_field_type(self, data_type, description):
|
||||
field_type = super().get_field_type(data_type, description)
|
||||
if description.is_autofield or (
|
||||
# Required for pre-Django 4.1 serial columns.
|
||||
description.default
|
||||
and "nextval" in description.default
|
||||
):
|
||||
if field_type == "IntegerField":
|
||||
return "AutoField"
|
||||
elif field_type == "BigIntegerField":
|
||||
return "BigAutoField"
|
||||
elif field_type == "SmallIntegerField":
|
||||
return "SmallAutoField"
|
||||
return field_type
|
||||
|
||||
def get_table_list(self, cursor):
|
||||
"""Return a list of table and view names in the current database."""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
c.relname,
|
||||
CASE
|
||||
WHEN c.relispartition THEN 'p'
|
||||
WHEN c.relkind IN ('m', 'v') THEN 'v'
|
||||
ELSE 't'
|
||||
END,
|
||||
obj_description(c.oid, 'pg_class')
|
||||
FROM pg_catalog.pg_class c
|
||||
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
|
||||
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
|
||||
AND pg_catalog.pg_table_is_visible(c.oid)
|
||||
"""
|
||||
)
|
||||
return [
|
||||
TableInfo(*row)
|
||||
for row in cursor.fetchall()
|
||||
if row[0] not in self.ignored_tables
|
||||
]
|
||||
|
||||
def get_table_description(self, cursor, table_name):
|
||||
"""
|
||||
Return a description of the table with the DB-API cursor.description
|
||||
interface.
|
||||
"""
|
||||
# Query the pg_catalog tables as cursor.description does not reliably
|
||||
# return the nullable property and information_schema.columns does not
|
||||
# contain details of materialized views.
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
a.attname AS column_name,
|
||||
NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
|
||||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
|
||||
CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation,
|
||||
a.attidentity != '' AS is_autofield,
|
||||
col_description(a.attrelid, a.attnum) AS column_comment
|
||||
FROM pg_attribute a
|
||||
LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
|
||||
LEFT JOIN pg_collation co ON a.attcollation = co.oid
|
||||
JOIN pg_type t ON a.atttypid = t.oid
|
||||
JOIN pg_class c ON a.attrelid = c.oid
|
||||
JOIN pg_namespace n ON c.relnamespace = n.oid
|
||||
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
|
||||
AND c.relname = %s
|
||||
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
|
||||
AND pg_catalog.pg_table_is_visible(c.oid)
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
field_map = {line[0]: line[1:] for line in cursor.fetchall()}
|
||||
cursor.execute(
|
||||
"SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
|
||||
)
|
||||
return [
|
||||
FieldInfo(
|
||||
line.name,
|
||||
line.type_code,
|
||||
# display_size is always None on psycopg2.
|
||||
line.internal_size if line.display_size is None else line.display_size,
|
||||
line.internal_size,
|
||||
line.precision,
|
||||
line.scale,
|
||||
*field_map[line.name],
|
||||
)
|
||||
for line in cursor.description
|
||||
]
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
s.relname AS sequence_name,
|
||||
a.attname AS colname
|
||||
FROM
|
||||
pg_class s
|
||||
JOIN pg_depend d ON d.objid = s.oid
|
||||
AND d.classid = 'pg_class'::regclass
|
||||
AND d.refclassid = 'pg_class'::regclass
|
||||
JOIN pg_attribute a ON d.refobjid = a.attrelid
|
||||
AND d.refobjsubid = a.attnum
|
||||
JOIN pg_class tbl ON tbl.oid = d.refobjid
|
||||
AND tbl.relname = %s
|
||||
AND pg_catalog.pg_table_is_visible(tbl.oid)
|
||||
WHERE
|
||||
s.relkind = 'S';
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
return [
|
||||
{"name": row[0], "table": table_name, "column": row[1]}
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||
representing all foreign keys in the given table.
|
||||
"""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT a1.attname, c2.relname, a2.attname
|
||||
FROM pg_constraint con
|
||||
LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
|
||||
LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
|
||||
LEFT JOIN
|
||||
pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
|
||||
LEFT JOIN
|
||||
pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
|
||||
WHERE
|
||||
c1.relname = %s AND
|
||||
con.contype = 'f' AND
|
||||
c1.relnamespace = c2.relnamespace AND
|
||||
pg_catalog.pg_table_is_visible(c1.oid)
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve any constraints or keys (unique, pk, fk, check, index) across
|
||||
one or more columns. Also retrieve the definition of expression-based
|
||||
indexes.
|
||||
"""
|
||||
constraints = {}
|
||||
# Loop over the key table, collecting things as constraints. The column
|
||||
# array must return column names in the same order in which they were
|
||||
# created.
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
c.conname,
|
||||
array(
|
||||
SELECT attname
|
||||
FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
|
||||
JOIN pg_attribute AS ca ON cols.colid = ca.attnum
|
||||
WHERE ca.attrelid = c.conrelid
|
||||
ORDER BY cols.arridx
|
||||
),
|
||||
c.contype,
|
||||
(SELECT fkc.relname || '.' || fka.attname
|
||||
FROM pg_attribute AS fka
|
||||
JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
|
||||
WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
|
||||
cl.reloptions
|
||||
FROM pg_constraint AS c
|
||||
JOIN pg_class AS cl ON c.conrelid = cl.oid
|
||||
WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
|
||||
""",
|
||||
[table_name],
|
||||
)
|
||||
for constraint, columns, kind, used_cols, options in cursor.fetchall():
|
||||
constraints[constraint] = {
|
||||
"columns": columns,
|
||||
"primary_key": kind == "p",
|
||||
"unique": kind in ["p", "u"],
|
||||
"foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
|
||||
"check": kind == "c",
|
||||
"index": False,
|
||||
"definition": None,
|
||||
"options": options,
|
||||
}
|
||||
# Now get indexes
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
indexname,
|
||||
array_agg(attname ORDER BY arridx),
|
||||
indisunique,
|
||||
indisprimary,
|
||||
array_agg(ordering ORDER BY arridx),
|
||||
amname,
|
||||
exprdef,
|
||||
s2.attoptions
|
||||
FROM (
|
||||
SELECT
|
||||
c2.relname as indexname, idx.*, attr.attname, am.amname,
|
||||
CASE
|
||||
WHEN idx.indexprs IS NOT NULL THEN
|
||||
pg_get_indexdef(idx.indexrelid)
|
||||
END AS exprdef,
|
||||
CASE am.amname
|
||||
WHEN %s THEN
|
||||
CASE (option & 1)
|
||||
WHEN 1 THEN 'DESC' ELSE 'ASC'
|
||||
END
|
||||
END as ordering,
|
||||
c2.reloptions as attoptions
|
||||
FROM (
|
||||
SELECT *
|
||||
FROM
|
||||
pg_index i,
|
||||
unnest(i.indkey, i.indoption)
|
||||
WITH ORDINALITY koi(key, option, arridx)
|
||||
) idx
|
||||
LEFT JOIN pg_class c ON idx.indrelid = c.oid
|
||||
LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
|
||||
LEFT JOIN pg_am am ON c2.relam = am.oid
|
||||
LEFT JOIN
|
||||
pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
|
||||
WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
|
||||
) s2
|
||||
GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
|
||||
""",
|
||||
[self.index_default_access_method, table_name],
|
||||
)
|
||||
for (
|
||||
index,
|
||||
columns,
|
||||
unique,
|
||||
primary,
|
||||
orders,
|
||||
type_,
|
||||
definition,
|
||||
options,
|
||||
) in cursor.fetchall():
|
||||
if index not in constraints:
|
||||
basic_index = (
|
||||
type_ == self.index_default_access_method
|
||||
and
|
||||
# '_btree' references
|
||||
# django.contrib.postgres.indexes.BTreeIndex.suffix.
|
||||
not index.endswith("_btree")
|
||||
and options is None
|
||||
)
|
||||
constraints[index] = {
|
||||
"columns": columns if columns != [None] else [],
|
||||
"orders": orders if orders != [None] else [],
|
||||
"primary_key": primary,
|
||||
"unique": unique,
|
||||
"foreign_key": None,
|
||||
"check": False,
|
||||
"index": True,
|
||||
"type": Index.suffix if basic_index else type_,
|
||||
"definition": definition,
|
||||
"options": options,
|
||||
}
|
||||
return constraints
|
||||
@@ -0,0 +1,404 @@
|
||||
import json
|
||||
from functools import lru_cache, partial
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.backends.postgresql.psycopg_any import (
|
||||
Inet,
|
||||
Jsonb,
|
||||
errors,
|
||||
is_psycopg3,
|
||||
mogrify,
|
||||
)
|
||||
from django.db.backends.utils import split_tzname_delta
|
||||
from django.db.models.constants import OnConflict
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_json_dumps(encoder):
|
||||
if encoder is None:
|
||||
return json.dumps
|
||||
return partial(json.dumps, cls=encoder)
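# For example, a JSONField declared with encoder=DjangoJSONEncoder gets a
# dumps() bound to that encoder, while encoder=None falls back to plain
# json.dumps; lru_cache keeps one partial per encoder class.
#
#   import datetime
#   from django.core.serializers.json import DjangoJSONEncoder
#
#   dumps = get_json_dumps(DjangoJSONEncoder)
#   dumps({"when": datetime.date(2023, 5, 1)})  # '{"when": "2023-05-01"}'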
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
cast_char_field_without_max_length = "varchar"
|
||||
explain_prefix = "EXPLAIN"
|
||||
explain_options = frozenset(
|
||||
[
|
||||
"ANALYZE",
|
||||
"BUFFERS",
|
||||
"COSTS",
|
||||
"SETTINGS",
|
||||
"SUMMARY",
|
||||
"TIMING",
|
||||
"VERBOSE",
|
||||
"WAL",
|
||||
]
|
||||
)
|
||||
cast_data_types = {
|
||||
"AutoField": "integer",
|
||||
"BigAutoField": "bigint",
|
||||
"SmallAutoField": "smallint",
|
||||
}
|
||||
|
||||
if is_psycopg3:
|
||||
from psycopg.types import numeric
|
||||
|
||||
integerfield_type_map = {
|
||||
"SmallIntegerField": numeric.Int2,
|
||||
"IntegerField": numeric.Int4,
|
||||
"BigIntegerField": numeric.Int8,
|
||||
"PositiveSmallIntegerField": numeric.Int2,
|
||||
"PositiveIntegerField": numeric.Int4,
|
||||
"PositiveBigIntegerField": numeric.Int8,
|
||||
}
|
||||
|
||||
def unification_cast_sql(self, output_field):
|
||||
internal_type = output_field.get_internal_type()
|
||||
if internal_type in (
|
||||
"GenericIPAddressField",
|
||||
"IPAddressField",
|
||||
"TimeField",
|
||||
"UUIDField",
|
||||
):
|
||||
# PostgreSQL will resolve a union as type 'text' if input types are
|
||||
# 'unknown'.
|
||||
# https://www.postgresql.org/docs/current/typeconv-union-case.html
|
||||
# These fields cannot be implicitly cast back in the default
|
||||
# PostgreSQL configuration so we need to explicitly cast them.
|
||||
# We must also remove components of the type within brackets:
|
||||
# varchar(255) -> varchar.
|
||||
return (
|
||||
"CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0]
|
||||
)
|
||||
return "%s"
|
||||
|
||||
# EXTRACT format cannot be passed in parameters.
|
||||
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
|
||||
|
||||
def date_extract_sql(self, lookup_type, sql, params):
|
||||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
|
||||
if lookup_type == "week_day":
|
||||
# For consistency across backends, we return Sunday=1, Saturday=7.
|
||||
return f"EXTRACT(DOW FROM {sql}) + 1", params
|
||||
elif lookup_type == "iso_week_day":
|
||||
return f"EXTRACT(ISODOW FROM {sql})", params
|
||||
elif lookup_type == "iso_year":
|
||||
return f"EXTRACT(ISOYEAR FROM {sql})", params
|
||||
|
||||
lookup_type = lookup_type.upper()
|
||||
if not self._extract_format_re.fullmatch(lookup_type):
|
||||
raise ValueError(f"Invalid lookup type: {lookup_type!r}")
|
||||
return f"EXTRACT({lookup_type} FROM {sql})", params
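    # For example (column name hypothetical; "ops" is a DatabaseOperations
    # instance):
    #
    #   ops.date_extract_sql("week_day", '"pub_date"', [])
    #   # -> ('EXTRACT(DOW FROM "pub_date") + 1', [])
    #   ops.date_extract_sql("year", '"pub_date"', [])
    #   # -> ('EXTRACT(YEAR FROM "pub_date")', [])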
|
||||
|
||||
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
|
||||
return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
|
||||
|
||||
def _prepare_tzname_delta(self, tzname):
|
||||
tzname, sign, offset = split_tzname_delta(tzname)
|
||||
if offset:
|
||||
sign = "-" if sign == "+" else "+"
|
||||
return f"{tzname}{sign}{offset}"
|
||||
return tzname
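    # A rough example (input hypothetical): the sign of an explicit offset is
    # flipped because AT TIME ZONE interprets POSIX-style offsets with the
    # opposite sign convention:
    #
    #   ops._prepare_tzname_delta("UTC+05:30")  # "UTC-05:30"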
|
||||
|
||||
def _convert_sql_to_tz(self, sql, params, tzname):
|
||||
if tzname and settings.USE_TZ:
|
||||
tzname_param = self._prepare_tzname_delta(tzname)
|
||||
return f"{sql} AT TIME ZONE %s", (*params, tzname_param)
|
||||
return sql, params
|
||||
|
||||
def datetime_cast_date_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"({sql})::date", params
|
||||
|
||||
def datetime_cast_time_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"({sql})::time", params
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
if lookup_type == "second":
|
||||
# Truncate fractional seconds.
|
||||
return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
|
||||
return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
|
||||
|
||||
def time_extract_sql(self, lookup_type, sql, params):
|
||||
if lookup_type == "second":
|
||||
# Truncate fractional seconds.
|
||||
return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params)
|
||||
|
||||
def deferrable_sql(self):
|
||||
return " DEFERRABLE INITIALLY DEFERRED"
|
||||
|
||||
def fetch_returned_insert_rows(self, cursor):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table, return the tuple of returned data.
|
||||
"""
|
||||
return cursor.fetchall()
|
||||
|
||||
def lookup_cast(self, lookup_type, internal_type=None):
|
||||
lookup = "%s"
|
||||
# Cast text lookups to text to allow things like filter(x__contains=4)
|
||||
if lookup_type in (
|
||||
"iexact",
|
||||
"contains",
|
||||
"icontains",
|
||||
"startswith",
|
||||
"istartswith",
|
||||
"endswith",
|
||||
"iendswith",
|
||||
"regex",
|
||||
"iregex",
|
||||
):
|
||||
if internal_type in ("IPAddressField", "GenericIPAddressField"):
|
||||
lookup = "HOST(%s)"
|
||||
# RemovedInDjango51Warning.
|
||||
elif internal_type in ("CICharField", "CIEmailField", "CITextField"):
|
||||
lookup = "%s::citext"
|
||||
else:
|
||||
lookup = "%s::text"
|
||||
|
||||
# Use UPPER(x) for case-insensitive lookups; it's faster.
|
||||
if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
|
||||
lookup = "UPPER(%s)" % lookup
|
||||
|
||||
return lookup
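    # For example:
    #
    #   ops.lookup_cast("icontains")
    #   # -> "UPPER(%s::text)"
    #   ops.lookup_cast("icontains", internal_type="GenericIPAddressField")
    #   # -> "UPPER(HOST(%s))"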
|
||||
|
||||
def no_limit_value(self):
|
||||
return None
|
||||
|
||||
def prepare_sql_script(self, sql):
|
||||
return [sql]
|
||||
|
||||
def quote_name(self, name):
|
||||
if name.startswith('"') and name.endswith('"'):
|
||||
return name # Quoting once is enough.
|
||||
return '"%s"' % name
|
||||
|
||||
def compose_sql(self, sql, params):
|
||||
return mogrify(sql, params, self.connection)
|
||||
|
||||
def set_time_zone_sql(self):
|
||||
return "SELECT set_config('TimeZone', %s, false)"
|
||||
|
||||
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
|
||||
if not tables:
|
||||
return []
|
||||
|
||||
# Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us
|
||||
# to truncate tables referenced by a foreign key in any other table.
|
||||
sql_parts = [
|
||||
style.SQL_KEYWORD("TRUNCATE"),
|
||||
", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
|
||||
]
|
||||
if reset_sequences:
|
||||
sql_parts.append(style.SQL_KEYWORD("RESTART IDENTITY"))
|
||||
if allow_cascade:
|
||||
sql_parts.append(style.SQL_KEYWORD("CASCADE"))
|
||||
return ["%s;" % " ".join(sql_parts)]
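    # For example (table names illustrative):
    #
    #   from django.core.management.color import no_style
    #   ops.sql_flush(no_style(), ["auth_user", "auth_group"],
    #                 reset_sequences=True, allow_cascade=True)
    #   # -> ['TRUNCATE "auth_user", "auth_group" RESTART IDENTITY CASCADE;']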
|
||||
|
||||
def sequence_reset_by_name_sql(self, style, sequences):
|
||||
# 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
|
||||
# to reset sequence indices
|
||||
sql = []
|
||||
for sequence_info in sequences:
|
||||
table_name = sequence_info["table"]
|
||||
# 'id' will be the case if it's an m2m using an autogenerated
|
||||
# intermediate table (see BaseDatabaseIntrospection.sequence_list).
|
||||
column_name = sequence_info["column"] or "id"
|
||||
sql.append(
|
||||
"%s setval(pg_get_serial_sequence('%s','%s'), 1, false);"
|
||||
% (
|
||||
style.SQL_KEYWORD("SELECT"),
|
||||
style.SQL_TABLE(self.quote_name(table_name)),
|
||||
style.SQL_FIELD(column_name),
|
||||
)
|
||||
)
|
||||
return sql
|
||||
|
||||
def tablespace_sql(self, tablespace, inline=False):
|
||||
if inline:
|
||||
return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
|
||||
else:
|
||||
return "TABLESPACE %s" % self.quote_name(tablespace)
|
||||
|
||||
def sequence_reset_sql(self, style, model_list):
|
||||
from django.db import models
|
||||
|
||||
output = []
|
||||
qn = self.quote_name
|
||||
for model in model_list:
|
||||
# Use `coalesce` to set the sequence for each model to the max pk
|
||||
# value if there are records, or 1 if there are none. Set the
|
||||
# `is_called` property (the third argument to `setval`) to true if
|
||||
# there are records (as the max pk value is already in use),
|
||||
# otherwise set it to false. Use pg_get_serial_sequence to get the
|
||||
# underlying sequence name from the table name and column name.
|
||||
|
||||
for f in model._meta.local_fields:
|
||||
if isinstance(f, models.AutoField):
|
||||
output.append(
|
||||
"%s setval(pg_get_serial_sequence('%s','%s'), "
|
||||
"coalesce(max(%s), 1), max(%s) %s null) %s %s;"
|
||||
% (
|
||||
style.SQL_KEYWORD("SELECT"),
|
||||
style.SQL_TABLE(qn(model._meta.db_table)),
|
||||
style.SQL_FIELD(f.column),
|
||||
style.SQL_FIELD(qn(f.column)),
|
||||
style.SQL_FIELD(qn(f.column)),
|
||||
style.SQL_KEYWORD("IS NOT"),
|
||||
style.SQL_KEYWORD("FROM"),
|
||||
style.SQL_TABLE(qn(model._meta.db_table)),
|
||||
)
|
||||
)
|
||||
# Only one AutoField is allowed per model, so don't bother
|
||||
# continuing.
|
||||
break
|
||||
return output
|
||||
|
||||
def prep_for_iexact_query(self, x):
|
||||
return x
|
||||
|
||||
def max_name_length(self):
|
||||
"""
|
||||
Return the maximum length of an identifier.
|
||||
|
||||
The maximum length of an identifier is 63 by default, but can be
|
||||
changed by recompiling PostgreSQL after editing the NAMEDATALEN
|
||||
macro in src/include/pg_config_manual.h.
|
||||
|
||||
This implementation returns 63, but can be overridden by a custom
|
||||
database backend that inherits most of its behavior from this one.
|
||||
"""
|
||||
return 63
|
||||
|
||||
def distinct_sql(self, fields, params):
|
||||
if fields:
|
||||
params = [param for param_list in params for param in param_list]
|
||||
return (["DISTINCT ON (%s)" % ", ".join(fields)], params)
|
||||
else:
|
||||
return ["DISTINCT"], []
|
||||
|
||||
if is_psycopg3:
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
try:
|
||||
return self.compose_sql(sql, params)
|
||||
except errors.DataError:
|
||||
return None
|
||||
|
||||
else:
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
# https://www.psycopg.org/docs/cursor.html#cursor.query
|
||||
# The query attribute is a Psycopg extension to the DB API 2.0.
|
||||
if cursor.query is not None:
|
||||
return cursor.query.decode()
|
||||
return None
|
||||
|
||||
def return_insert_columns(self, fields):
|
||||
if not fields:
|
||||
return "", ()
|
||||
columns = [
|
||||
"%s.%s"
|
||||
% (
|
||||
self.quote_name(field.model._meta.db_table),
|
||||
self.quote_name(field.column),
|
||||
)
|
||||
for field in fields
|
||||
]
|
||||
return "RETURNING %s" % ", ".join(columns), ()
|
||||
|
||||
def bulk_insert_sql(self, fields, placeholder_rows):
|
||||
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
|
||||
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
|
||||
return "VALUES " + values_sql
|
||||
|
||||
if is_psycopg3:
|
||||
|
||||
def adapt_integerfield_value(self, value, internal_type):
|
||||
if value is None or hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
return self.integerfield_type_map[internal_type](value)
|
||||
|
||||
def adapt_datefield_value(self, value):
|
||||
return value
|
||||
|
||||
def adapt_datetimefield_value(self, value):
|
||||
return value
|
||||
|
||||
def adapt_timefield_value(self, value):
|
||||
return value
|
||||
|
||||
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
|
||||
return value
|
||||
|
||||
def adapt_ipaddressfield_value(self, value):
|
||||
if value:
|
||||
return Inet(value)
|
||||
return None
|
||||
|
||||
def adapt_json_value(self, value, encoder):
|
||||
return Jsonb(value, dumps=get_json_dumps(encoder))
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
if internal_type == "DateField":
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
params = (*lhs_params, *rhs_params)
|
||||
return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), params
|
||||
return super().subtract_temporals(internal_type, lhs, rhs)
|
||||
|
||||
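    # Illustrative example: explain_query_prefix("json", analyze=True) returns
    # roughly "EXPLAIN (ANALYZE true, FORMAT json)", assuming ANALYZE is one of
    # self.explain_options.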
def explain_query_prefix(self, format=None, **options):
|
||||
extra = {}
|
||||
# Normalize options.
|
||||
if options:
|
||||
options = {
|
||||
name.upper(): "true" if value else "false"
|
||||
for name, value in options.items()
|
||||
}
|
||||
for valid_option in self.explain_options:
|
||||
value = options.pop(valid_option, None)
|
||||
if value is not None:
|
||||
extra[valid_option] = value
|
||||
prefix = super().explain_query_prefix(format, **options)
|
||||
if format:
|
||||
extra["FORMAT"] = format
|
||||
if extra:
|
||||
prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items())
|
||||
return prefix
|
||||
|
||||
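    # Illustrative example: with on_conflict=OnConflict.UPDATE,
    # unique_fields=["id"], and update_fields=["name"], this returns
    # 'ON CONFLICT("id") DO UPDATE SET "name" = EXCLUDED."name"'.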
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
|
||||
if on_conflict == OnConflict.IGNORE:
|
||||
return "ON CONFLICT DO NOTHING"
|
||||
if on_conflict == OnConflict.UPDATE:
|
||||
return "ON CONFLICT(%s) DO UPDATE SET %s" % (
|
||||
", ".join(map(self.quote_name, unique_fields)),
|
||||
", ".join(
|
||||
[
|
||||
f"{field} = EXCLUDED.{field}"
|
||||
for field in map(self.quote_name, update_fields)
|
||||
]
|
||||
),
|
||||
)
|
||||
return super().on_conflict_suffix_sql(
|
||||
fields,
|
||||
on_conflict,
|
||||
update_fields,
|
||||
unique_fields,
|
||||
)
|
||||
@@ -0,0 +1,103 @@
import ipaddress
|
||||
from functools import lru_cache
|
||||
|
||||
try:
|
||||
from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors, sql
|
||||
from psycopg.postgres import types
|
||||
from psycopg.types.datetime import TimestamptzLoader
|
||||
from psycopg.types.json import Jsonb
|
||||
from psycopg.types.range import Range, RangeDumper
|
||||
from psycopg.types.string import TextLoader
|
||||
|
||||
Inet = ipaddress.ip_address
|
||||
|
||||
DateRange = DateTimeRange = DateTimeTZRange = NumericRange = Range
|
||||
RANGE_TYPES = (Range,)
|
||||
|
||||
TSRANGE_OID = types["tsrange"].oid
|
||||
TSTZRANGE_OID = types["tstzrange"].oid
|
||||
|
||||
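    # Illustrative example: mogrify("SELECT %s", [42], connection) returns the
    # client-side merged query string, e.g. "SELECT 42".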
def mogrify(sql, params, connection):
|
||||
with connection.cursor() as cursor:
|
||||
return ClientCursor(cursor.connection).mogrify(sql, params)
|
||||
|
||||
# Adapters.
|
||||
class BaseTzLoader(TimestamptzLoader):
|
||||
"""
|
||||
        Load a PostgreSQL timestamptz using a specific timezone.
|
||||
The timezone can be None too, in which case it will be chopped.
|
||||
"""
|
||||
|
||||
timezone = None
|
||||
|
||||
def load(self, data):
|
||||
res = super().load(data)
|
||||
return res.replace(tzinfo=self.timezone)
|
||||
|
||||
def register_tzloader(tz, context):
|
||||
class SpecificTzLoader(BaseTzLoader):
|
||||
timezone = tz
|
||||
|
||||
context.adapters.register_loader("timestamptz", SpecificTzLoader)
|
||||
|
||||
class DjangoRangeDumper(RangeDumper):
|
||||
"""A Range dumper customized for Django."""
|
||||
|
||||
def upgrade(self, obj, format):
|
||||
# Dump ranges containing naive datetimes as tstzrange, because
|
||||
# Django doesn't use tz-aware ones.
|
||||
dumper = super().upgrade(obj, format)
|
||||
if dumper is not self and dumper.oid == TSRANGE_OID:
|
||||
dumper.oid = TSTZRANGE_OID
|
||||
return dumper
|
||||
|
||||
@lru_cache
|
||||
def get_adapters_template(use_tz, timezone):
|
||||
        # Create an adapters map extending the base one.
|
||||
ctx = adapt.AdaptersMap(adapters)
|
||||
        # Register a no-op loader so jsonb values are returned as raw text,
        # avoiding a json.loads()/json.dumps() round trip when JSONField uses a
        # custom decoder.
|
||||
ctx.register_loader("jsonb", TextLoader)
|
||||
# Don't convert automatically from PostgreSQL network types to Python
|
||||
# ipaddress.
|
||||
ctx.register_loader("inet", TextLoader)
|
||||
ctx.register_loader("cidr", TextLoader)
|
||||
ctx.register_dumper(Range, DjangoRangeDumper)
|
||||
        # Register a timestamptz loader configured on the given timezone.
|
||||
# This, however, can be overridden by create_cursor.
|
||||
register_tzloader(timezone, ctx)
|
||||
return ctx
|
||||
|
||||
is_psycopg3 = True
|
||||
|
||||
except ImportError:
|
||||
from enum import IntEnum
|
||||
|
||||
from psycopg2 import errors, extensions, sql # NOQA
|
||||
from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, Inet # NOQA
|
||||
from psycopg2.extras import Json as Jsonb # NOQA
|
||||
from psycopg2.extras import NumericRange, Range # NOQA
|
||||
|
||||
RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange)
|
||||
|
||||
class IsolationLevel(IntEnum):
|
||||
READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
|
||||
READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED
|
||||
REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||
SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE
|
||||
|
||||
def _quote(value, connection=None):
|
||||
adapted = extensions.adapt(value)
|
||||
if hasattr(adapted, "encoding"):
|
||||
adapted.encoding = "utf8"
|
||||
# getquoted() returns a quoted bytestring of the adapted value.
|
||||
return adapted.getquoted().decode()
|
||||
|
||||
sql.quote = _quote
|
||||
|
||||
def mogrify(sql, params, connection):
|
||||
with connection.cursor() as cursor:
|
||||
return cursor.mogrify(sql, params).decode()
|
||||
|
||||
is_psycopg3 = False
|
||||
@@ -0,0 +1,374 @@
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
from django.db.backends.ddl_references import IndexColumns
|
||||
from django.db.backends.postgresql.psycopg_any import sql
|
||||
from django.db.backends.utils import strip_quotes
|
||||
|
||||
|
||||
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
# Setting all constraints to IMMEDIATE to allow changing data in the same
|
||||
# transaction.
|
||||
sql_update_with_default = (
|
||||
"UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
|
||||
"; SET CONSTRAINTS ALL IMMEDIATE"
|
||||
)
|
||||
sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
|
||||
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
|
||||
|
||||
sql_create_index = (
|
||||
"CREATE INDEX %(name)s ON %(table)s%(using)s "
|
||||
"(%(columns)s)%(include)s%(extra)s%(condition)s"
|
||||
)
|
||||
sql_create_index_concurrently = (
|
||||
"CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
|
||||
"(%(columns)s)%(include)s%(extra)s%(condition)s"
|
||||
)
|
||||
sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
|
||||
sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
|
||||
|
||||
# Setting the constraint to IMMEDIATE to allow changing data in the same
|
||||
# transaction.
|
||||
sql_create_column_inline_fk = (
|
||||
"CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
|
||||
"; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
|
||||
)
|
||||
# Setting the constraint to IMMEDIATE runs any deferred checks to allow
|
||||
# dropping it in the same transaction.
|
||||
sql_delete_fk = (
|
||||
"SET CONSTRAINTS %(name)s IMMEDIATE; "
|
||||
"ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
|
||||
)
|
||||
sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
|
||||
|
||||
def execute(self, sql, params=()):
|
||||
# Merge the query client-side, as PostgreSQL won't do it server-side.
|
||||
if params is None:
|
||||
return super().execute(sql, params)
|
||||
sql = self.connection.ops.compose_sql(str(sql), params)
|
||||
# Don't let the superclass touch anything.
|
||||
return super().execute(sql, None)
|
||||
|
||||
sql_add_identity = (
|
||||
"ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
|
||||
"GENERATED BY DEFAULT AS IDENTITY"
|
||||
)
|
||||
    sql_drop_identity = (
|
||||
"ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
|
||||
)
|
||||
|
||||
def quote_value(self, value):
|
||||
if isinstance(value, str):
|
||||
value = value.replace("%", "%%")
|
||||
return sql.quote(value, self.connection.connection)
|
||||
|
||||
def _field_indexes_sql(self, model, field):
|
||||
output = super()._field_indexes_sql(model, field)
|
||||
like_index_statement = self._create_like_index_sql(model, field)
|
||||
if like_index_statement is not None:
|
||||
output.append(like_index_statement)
|
||||
return output
|
||||
|
||||
def _field_data_type(self, field):
|
||||
if field.is_relation:
|
||||
return field.rel_db_type(self.connection)
|
||||
return self.connection.data_types.get(
|
||||
field.get_internal_type(),
|
||||
field.db_type(self.connection),
|
||||
)
|
||||
|
||||
def _field_base_data_types(self, field):
|
||||
# Yield base data types for array fields.
|
||||
if field.base_field.get_internal_type() == "ArrayField":
|
||||
yield from self._field_base_data_types(field.base_field)
|
||||
else:
|
||||
yield self._field_data_type(field.base_field)
|
||||
|
||||
def _create_like_index_sql(self, model, field):
|
||||
"""
|
||||
Return the statement to create an index with varchar operator pattern
|
||||
when the column type is 'varchar' or 'text', otherwise return None.
|
||||
"""
|
||||
db_type = field.db_type(connection=self.connection)
|
||||
if db_type is not None and (field.db_index or field.unique):
|
||||
# Fields with database column types of `varchar` and `text` need
|
||||
# a second index that specifies their operator class, which is
|
||||
# needed when performing correct LIKE queries outside the
|
||||
# C locale. See #12234.
|
||||
#
|
||||
# The same doesn't apply to array fields such as varchar[size]
|
||||
# and text[size], so skip them.
|
||||
if "[" in db_type:
|
||||
return None
|
||||
            # Non-deterministic collations on PostgreSQL don't support indexes
|
||||
# for operator classes varchar_pattern_ops/text_pattern_ops.
|
||||
if getattr(field, "db_collation", None):
|
||||
return None
|
||||
if db_type.startswith("varchar"):
|
||||
return self._create_index_sql(
|
||||
model,
|
||||
fields=[field],
|
||||
suffix="_like",
|
||||
opclasses=["varchar_pattern_ops"],
|
||||
)
|
||||
elif db_type.startswith("text"):
|
||||
return self._create_index_sql(
|
||||
model,
|
||||
fields=[field],
|
||||
suffix="_like",
|
||||
opclasses=["text_pattern_ops"],
|
||||
)
|
||||
return None
|
||||
|
||||
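    # Returns " USING <column>::<type>" only when the (base) data type actually
    # changes, so ALTER TABLE performs an explicit cast; otherwise returns "".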
def _using_sql(self, new_field, old_field):
|
||||
using_sql = " USING %(column)s::%(type)s"
|
||||
new_internal_type = new_field.get_internal_type()
|
||||
old_internal_type = old_field.get_internal_type()
|
||||
if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
|
||||
# Compare base data types for array fields.
|
||||
if list(self._field_base_data_types(old_field)) != list(
|
||||
self._field_base_data_types(new_field)
|
||||
):
|
||||
return using_sql
|
||||
elif self._field_data_type(old_field) != self._field_data_type(new_field):
|
||||
return using_sql
|
||||
return ""
|
||||
|
||||
def _get_sequence_name(self, table, column):
|
||||
with self.connection.cursor() as cursor:
|
||||
for sequence in self.connection.introspection.get_sequences(cursor, table):
|
||||
if sequence["column"] == column:
|
||||
return sequence["name"]
|
||||
return None
|
||||
|
||||
def _alter_column_type_sql(
|
||||
self, model, old_field, new_field, new_type, old_collation, new_collation
|
||||
):
|
||||
# Drop indexes on varchar/text/citext columns that are changing to a
|
||||
# different type.
|
||||
old_db_params = old_field.db_parameters(connection=self.connection)
|
||||
old_type = old_db_params["type"]
|
||||
if (old_field.db_index or old_field.unique) and (
|
||||
(old_type.startswith("varchar") and not new_type.startswith("varchar"))
|
||||
or (old_type.startswith("text") and not new_type.startswith("text"))
|
||||
or (old_type.startswith("citext") and not new_type.startswith("citext"))
|
||||
):
|
||||
index_name = self._create_index_name(
|
||||
model._meta.db_table, [old_field.column], suffix="_like"
|
||||
)
|
||||
self.execute(self._delete_index_sql(model, index_name))
|
||||
|
||||
self.sql_alter_column_type = (
|
||||
"ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
|
||||
)
|
||||
# Cast when data type changed.
|
||||
if using_sql := self._using_sql(new_field, old_field):
|
||||
self.sql_alter_column_type += using_sql
|
||||
new_internal_type = new_field.get_internal_type()
|
||||
old_internal_type = old_field.get_internal_type()
|
||||
        # Handle the IDENTITY clause when ALTER TYPE converts to or from an
        # auto (identity) column.
|
||||
table = strip_quotes(model._meta.db_table)
|
||||
auto_field_types = {
|
||||
"AutoField",
|
||||
"BigAutoField",
|
||||
"SmallAutoField",
|
||||
}
|
||||
old_is_auto = old_internal_type in auto_field_types
|
||||
new_is_auto = new_internal_type in auto_field_types
|
||||
if new_is_auto and not old_is_auto:
|
||||
column = strip_quotes(new_field.column)
|
||||
return (
|
||||
(
|
||||
self.sql_alter_column_type
|
||||
% {
|
||||
"column": self.quote_name(column),
|
||||
"type": new_type,
|
||||
"collation": "",
|
||||
},
|
||||
[],
|
||||
),
|
||||
[
|
||||
(
|
||||
self.sql_add_identity
|
||||
% {
|
||||
"table": self.quote_name(table),
|
||||
"column": self.quote_name(column),
|
||||
},
|
||||
[],
|
||||
),
|
||||
],
|
||||
)
|
||||
elif old_is_auto and not new_is_auto:
|
||||
# Drop IDENTITY if exists (pre-Django 4.1 serial columns don't have
|
||||
# it).
|
||||
self.execute(
|
||||
                self.sql_drop_identity
|
||||
% {
|
||||
"table": self.quote_name(table),
|
||||
"column": self.quote_name(strip_quotes(new_field.column)),
|
||||
}
|
||||
)
|
||||
column = strip_quotes(new_field.column)
|
||||
fragment, _ = super()._alter_column_type_sql(
|
||||
model, old_field, new_field, new_type, old_collation, new_collation
|
||||
)
|
||||
# Drop the sequence if exists (Django 4.1+ identity columns don't
|
||||
# have it).
|
||||
other_actions = []
|
||||
if sequence_name := self._get_sequence_name(table, column):
|
||||
other_actions = [
|
||||
(
|
||||
self.sql_delete_sequence
|
||||
% {
|
||||
"sequence": self.quote_name(sequence_name),
|
||||
},
|
||||
[],
|
||||
)
|
||||
]
|
||||
return fragment, other_actions
|
||||
elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
|
||||
fragment, _ = super()._alter_column_type_sql(
|
||||
model, old_field, new_field, new_type, old_collation, new_collation
|
||||
)
|
||||
column = strip_quotes(new_field.column)
|
||||
db_types = {
|
||||
"AutoField": "integer",
|
||||
"BigAutoField": "bigint",
|
||||
"SmallAutoField": "smallint",
|
||||
}
|
||||
# Alter the sequence type if exists (Django 4.1+ identity columns
|
||||
# don't have it).
|
||||
other_actions = []
|
||||
if sequence_name := self._get_sequence_name(table, column):
|
||||
other_actions = [
|
||||
(
|
||||
self.sql_alter_sequence_type
|
||||
% {
|
||||
"sequence": self.quote_name(sequence_name),
|
||||
"type": db_types[new_internal_type],
|
||||
},
|
||||
[],
|
||||
),
|
||||
]
|
||||
return fragment, other_actions
|
||||
else:
|
||||
return super()._alter_column_type_sql(
|
||||
model, old_field, new_field, new_type, old_collation, new_collation
|
||||
)
|
||||
|
||||
def _alter_column_collation_sql(
|
||||
self, model, new_field, new_type, new_collation, old_field
|
||||
):
|
||||
sql = self.sql_alter_column_collate
|
||||
# Cast when data type changed.
|
||||
if using_sql := self._using_sql(new_field, old_field):
|
||||
sql += using_sql
|
||||
return (
|
||||
sql
|
||||
% {
|
||||
"column": self.quote_name(new_field.column),
|
||||
"type": new_type,
|
||||
"collation": " " + self._collate_sql(new_collation)
|
||||
if new_collation
|
||||
else "",
|
||||
},
|
||||
[],
|
||||
)
|
||||
|
||||
def _alter_field(
|
||||
self,
|
||||
model,
|
||||
old_field,
|
||||
new_field,
|
||||
old_type,
|
||||
new_type,
|
||||
old_db_params,
|
||||
new_db_params,
|
||||
strict=False,
|
||||
):
|
||||
super()._alter_field(
|
||||
model,
|
||||
old_field,
|
||||
new_field,
|
||||
old_type,
|
||||
new_type,
|
||||
old_db_params,
|
||||
new_db_params,
|
||||
strict,
|
||||
)
|
||||
# Added an index? Create any PostgreSQL-specific indexes.
|
||||
if (not (old_field.db_index or old_field.unique) and new_field.db_index) or (
|
||||
not old_field.unique and new_field.unique
|
||||
):
|
||||
like_index_statement = self._create_like_index_sql(model, new_field)
|
||||
if like_index_statement is not None:
|
||||
self.execute(like_index_statement)
|
||||
|
||||
# Removed an index? Drop any PostgreSQL-specific indexes.
|
||||
if old_field.unique and not (new_field.db_index or new_field.unique):
|
||||
index_to_remove = self._create_index_name(
|
||||
model._meta.db_table, [old_field.column], suffix="_like"
|
||||
)
|
||||
self.execute(self._delete_index_sql(model, index_to_remove))
|
||||
|
||||
def _index_columns(self, table, columns, col_suffixes, opclasses):
|
||||
if opclasses:
|
||||
return IndexColumns(
|
||||
table,
|
||||
columns,
|
||||
self.quote_name,
|
||||
col_suffixes=col_suffixes,
|
||||
opclasses=opclasses,
|
||||
)
|
||||
return super()._index_columns(table, columns, col_suffixes, opclasses)
|
||||
|
||||
def add_index(self, model, index, concurrently=False):
|
||||
self.execute(
|
||||
index.create_sql(model, self, concurrently=concurrently), params=None
|
||||
)
|
||||
|
||||
def remove_index(self, model, index, concurrently=False):
|
||||
self.execute(index.remove_sql(model, self, concurrently=concurrently))
|
||||
|
||||
def _delete_index_sql(self, model, name, sql=None, concurrently=False):
|
||||
sql = (
|
||||
self.sql_delete_index_concurrently
|
||||
if concurrently
|
||||
else self.sql_delete_index
|
||||
)
|
||||
return super()._delete_index_sql(model, name, sql)
|
||||
|
||||
def _create_index_sql(
|
||||
self,
|
||||
model,
|
||||
*,
|
||||
fields=None,
|
||||
name=None,
|
||||
suffix="",
|
||||
using="",
|
||||
db_tablespace=None,
|
||||
col_suffixes=(),
|
||||
sql=None,
|
||||
opclasses=(),
|
||||
condition=None,
|
||||
concurrently=False,
|
||||
include=None,
|
||||
expressions=None,
|
||||
):
|
||||
sql = sql or (
|
||||
self.sql_create_index
|
||||
if not concurrently
|
||||
else self.sql_create_index_concurrently
|
||||
)
|
||||
return super()._create_index_sql(
|
||||
model,
|
||||
fields=fields,
|
||||
name=name,
|
||||
suffix=suffix,
|
||||
using=using,
|
||||
db_tablespace=db_tablespace,
|
||||
col_suffixes=col_suffixes,
|
||||
sql=sql,
|
||||
opclasses=opclasses,
|
||||
condition=condition,
|
||||
include=include,
|
||||
expressions=expressions,
|
||||
)
|
||||
@@ -0,0 +1,3 @@
from django.dispatch import Signal

connection_created = Signal()
|
||||
@@ -0,0 +1,512 @@
"""
|
||||
Implementations of SQL functions for SQLite.
|
||||
"""
|
||||
import functools
|
||||
import random
|
||||
import statistics
|
||||
from datetime import timedelta
|
||||
from hashlib import sha1, sha224, sha256, sha384, sha512
|
||||
from math import (
|
||||
acos,
|
||||
asin,
|
||||
atan,
|
||||
atan2,
|
||||
ceil,
|
||||
cos,
|
||||
degrees,
|
||||
exp,
|
||||
floor,
|
||||
fmod,
|
||||
log,
|
||||
pi,
|
||||
radians,
|
||||
sin,
|
||||
sqrt,
|
||||
tan,
|
||||
)
|
||||
from re import search as re_search
|
||||
|
||||
from django.db.backends.base.base import timezone_constructor
|
||||
from django.db.backends.utils import (
|
||||
split_tzname_delta,
|
||||
typecast_time,
|
||||
typecast_timestamp,
|
||||
)
|
||||
from django.utils import timezone
|
||||
from django.utils.crypto import md5
|
||||
from django.utils.duration import duration_microseconds
|
||||
|
||||
|
||||
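# register() is called for every new SQLite connection (see
# DatabaseWrapper.get_new_connection) so these Python implementations back
# SQL functions such as MD5(), LPAD(), and STDDEV_POP().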
def register(connection):
|
||||
create_deterministic_function = functools.partial(
|
||||
connection.create_function,
|
||||
deterministic=True,
|
||||
)
|
||||
create_deterministic_function("django_date_extract", 2, _sqlite_datetime_extract)
|
||||
create_deterministic_function("django_date_trunc", 4, _sqlite_date_trunc)
|
||||
create_deterministic_function(
|
||||
"django_datetime_cast_date", 3, _sqlite_datetime_cast_date
|
||||
)
|
||||
create_deterministic_function(
|
||||
"django_datetime_cast_time", 3, _sqlite_datetime_cast_time
|
||||
)
|
||||
create_deterministic_function(
|
||||
"django_datetime_extract", 4, _sqlite_datetime_extract
|
||||
)
|
||||
create_deterministic_function("django_datetime_trunc", 4, _sqlite_datetime_trunc)
|
||||
create_deterministic_function("django_time_extract", 2, _sqlite_time_extract)
|
||||
create_deterministic_function("django_time_trunc", 4, _sqlite_time_trunc)
|
||||
create_deterministic_function("django_time_diff", 2, _sqlite_time_diff)
|
||||
create_deterministic_function("django_timestamp_diff", 2, _sqlite_timestamp_diff)
|
||||
create_deterministic_function("django_format_dtdelta", 3, _sqlite_format_dtdelta)
|
||||
create_deterministic_function("regexp", 2, _sqlite_regexp)
|
||||
create_deterministic_function("BITXOR", 2, _sqlite_bitxor)
|
||||
create_deterministic_function("COT", 1, _sqlite_cot)
|
||||
create_deterministic_function("LPAD", 3, _sqlite_lpad)
|
||||
create_deterministic_function("MD5", 1, _sqlite_md5)
|
||||
create_deterministic_function("REPEAT", 2, _sqlite_repeat)
|
||||
create_deterministic_function("REVERSE", 1, _sqlite_reverse)
|
||||
create_deterministic_function("RPAD", 3, _sqlite_rpad)
|
||||
create_deterministic_function("SHA1", 1, _sqlite_sha1)
|
||||
create_deterministic_function("SHA224", 1, _sqlite_sha224)
|
||||
create_deterministic_function("SHA256", 1, _sqlite_sha256)
|
||||
create_deterministic_function("SHA384", 1, _sqlite_sha384)
|
||||
create_deterministic_function("SHA512", 1, _sqlite_sha512)
|
||||
create_deterministic_function("SIGN", 1, _sqlite_sign)
|
||||
# Don't use the built-in RANDOM() function because it returns a value
|
||||
# in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
|
||||
connection.create_function("RAND", 0, random.random)
|
||||
connection.create_aggregate("STDDEV_POP", 1, StdDevPop)
|
||||
connection.create_aggregate("STDDEV_SAMP", 1, StdDevSamp)
|
||||
connection.create_aggregate("VAR_POP", 1, VarPop)
|
||||
connection.create_aggregate("VAR_SAMP", 1, VarSamp)
|
||||
# Some math functions are enabled by default in SQLite 3.35+.
|
||||
sql = "select sqlite_compileoption_used('ENABLE_MATH_FUNCTIONS')"
|
||||
if not connection.execute(sql).fetchone()[0]:
|
||||
create_deterministic_function("ACOS", 1, _sqlite_acos)
|
||||
create_deterministic_function("ASIN", 1, _sqlite_asin)
|
||||
create_deterministic_function("ATAN", 1, _sqlite_atan)
|
||||
create_deterministic_function("ATAN2", 2, _sqlite_atan2)
|
||||
create_deterministic_function("CEILING", 1, _sqlite_ceiling)
|
||||
create_deterministic_function("COS", 1, _sqlite_cos)
|
||||
create_deterministic_function("DEGREES", 1, _sqlite_degrees)
|
||||
create_deterministic_function("EXP", 1, _sqlite_exp)
|
||||
create_deterministic_function("FLOOR", 1, _sqlite_floor)
|
||||
create_deterministic_function("LN", 1, _sqlite_ln)
|
||||
create_deterministic_function("LOG", 2, _sqlite_log)
|
||||
create_deterministic_function("MOD", 2, _sqlite_mod)
|
||||
create_deterministic_function("PI", 0, _sqlite_pi)
|
||||
create_deterministic_function("POWER", 2, _sqlite_power)
|
||||
create_deterministic_function("RADIANS", 1, _sqlite_radians)
|
||||
create_deterministic_function("SIN", 1, _sqlite_sin)
|
||||
create_deterministic_function("SQRT", 1, _sqlite_sqrt)
|
||||
create_deterministic_function("TAN", 1, _sqlite_tan)
|
||||
|
||||
|
||||
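# Illustrative example: _sqlite_datetime_parse("2023-01-15 12:30:00",
# "America/Chicago", "UTC") reads the timestamp as UTC and returns an aware
# datetime converted to America/Chicago; unparsable input returns None.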
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
|
||||
if dt is None:
|
||||
return None
|
||||
try:
|
||||
dt = typecast_timestamp(dt)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if conn_tzname:
|
||||
dt = dt.replace(tzinfo=timezone_constructor(conn_tzname))
|
||||
if tzname is not None and tzname != conn_tzname:
|
||||
tzname, sign, offset = split_tzname_delta(tzname)
|
||||
if offset:
|
||||
hours, minutes = offset.split(":")
|
||||
offset_delta = timedelta(hours=int(hours), minutes=int(minutes))
|
||||
dt += offset_delta if sign == "+" else -offset_delta
|
||||
dt = timezone.localtime(dt, timezone_constructor(tzname))
|
||||
return dt
|
||||
|
||||
|
||||
def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
if lookup_type == "year":
|
||||
return f"{dt.year:04d}-01-01"
|
||||
elif lookup_type == "quarter":
|
||||
month_in_quarter = dt.month - (dt.month - 1) % 3
|
||||
return f"{dt.year:04d}-{month_in_quarter:02d}-01"
|
||||
elif lookup_type == "month":
|
||||
return f"{dt.year:04d}-{dt.month:02d}-01"
|
||||
elif lookup_type == "week":
|
||||
dt -= timedelta(days=dt.weekday())
|
||||
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}"
|
||||
elif lookup_type == "day":
|
||||
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}"
|
||||
raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
|
||||
|
||||
|
||||
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
if dt is None:
|
||||
return None
|
||||
dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt_parsed is None:
|
||||
try:
|
||||
dt = typecast_time(dt)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
else:
|
||||
dt = dt_parsed
|
||||
if lookup_type == "hour":
|
||||
return f"{dt.hour:02d}:00:00"
|
||||
elif lookup_type == "minute":
|
||||
return f"{dt.hour:02d}:{dt.minute:02d}:00"
|
||||
elif lookup_type == "second":
|
||||
return f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}"
|
||||
raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
|
||||
|
||||
|
||||
def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
return dt.date().isoformat()
|
||||
|
||||
|
||||
def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
return dt.time().isoformat()
|
||||
|
||||
|
||||
def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
if lookup_type == "week_day":
|
||||
return (dt.isoweekday() % 7) + 1
|
||||
elif lookup_type == "iso_week_day":
|
||||
return dt.isoweekday()
|
||||
elif lookup_type == "week":
|
||||
return dt.isocalendar()[1]
|
||||
elif lookup_type == "quarter":
|
||||
return ceil(dt.month / 3)
|
||||
elif lookup_type == "iso_year":
|
||||
return dt.isocalendar()[0]
|
||||
else:
|
||||
return getattr(dt, lookup_type)
|
||||
|
||||
|
||||
def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
if lookup_type == "year":
|
||||
return f"{dt.year:04d}-01-01 00:00:00"
|
||||
elif lookup_type == "quarter":
|
||||
month_in_quarter = dt.month - (dt.month - 1) % 3
|
||||
return f"{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00"
|
||||
elif lookup_type == "month":
|
||||
return f"{dt.year:04d}-{dt.month:02d}-01 00:00:00"
|
||||
elif lookup_type == "week":
|
||||
dt -= timedelta(days=dt.weekday())
|
||||
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00"
|
||||
elif lookup_type == "day":
|
||||
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00"
|
||||
elif lookup_type == "hour":
|
||||
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00"
|
||||
elif lookup_type == "minute":
|
||||
return (
|
||||
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} "
|
||||
f"{dt.hour:02d}:{dt.minute:02d}:00"
|
||||
)
|
||||
elif lookup_type == "second":
|
||||
return (
|
||||
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} "
|
||||
f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}"
|
||||
)
|
||||
raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
|
||||
|
||||
|
||||
def _sqlite_time_extract(lookup_type, dt):
|
||||
if dt is None:
|
||||
return None
|
||||
try:
|
||||
dt = typecast_time(dt)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
return getattr(dt, lookup_type)
|
||||
|
||||
|
||||
def _sqlite_prepare_dtdelta_param(connector, param):
    if connector in ["+", "-"]:
|
||||
if isinstance(param, int):
|
||||
return timedelta(0, 0, param)
|
||||
else:
|
||||
return typecast_timestamp(param)
|
||||
return param
|
||||
|
||||
|
||||
def _sqlite_format_dtdelta(connector, lhs, rhs):
|
||||
"""
|
||||
LHS and RHS can be either:
|
||||
- An integer number of microseconds
|
||||
- A string representing a datetime
|
||||
- A scalar value, e.g. float
|
||||
"""
|
||||
if connector is None or lhs is None or rhs is None:
|
||||
return None
|
||||
connector = connector.strip()
|
||||
try:
|
||||
real_lhs = _sqlite_prepare_dtdelta_param(connector, lhs)
|
||||
real_rhs = _sqlite_prepare_dtdelta_param(connector, rhs)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
if connector == "+":
|
||||
# typecast_timestamp() returns a date or a datetime without timezone.
|
||||
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
|
||||
out = str(real_lhs + real_rhs)
|
||||
elif connector == "-":
|
||||
out = str(real_lhs - real_rhs)
|
||||
elif connector == "*":
|
||||
out = real_lhs * real_rhs
|
||||
else:
|
||||
out = real_lhs / real_rhs
|
||||
return out
|
||||
|
||||
|
||||
def _sqlite_time_diff(lhs, rhs):
|
||||
if lhs is None or rhs is None:
|
||||
return None
|
||||
left = typecast_time(lhs)
|
||||
right = typecast_time(rhs)
|
||||
return (
|
||||
(left.hour * 60 * 60 * 1000000)
|
||||
+ (left.minute * 60 * 1000000)
|
||||
+ (left.second * 1000000)
|
||||
+ (left.microsecond)
|
||||
- (right.hour * 60 * 60 * 1000000)
|
||||
- (right.minute * 60 * 1000000)
|
||||
- (right.second * 1000000)
|
||||
- (right.microsecond)
|
||||
)
|
||||
|
||||
|
||||
def _sqlite_timestamp_diff(lhs, rhs):
|
||||
if lhs is None or rhs is None:
|
||||
return None
|
||||
left = typecast_timestamp(lhs)
|
||||
right = typecast_timestamp(rhs)
|
||||
return duration_microseconds(left - right)
|
||||
|
||||
|
||||
def _sqlite_regexp(pattern, string):
|
||||
if pattern is None or string is None:
|
||||
return None
|
||||
if not isinstance(string, str):
|
||||
string = str(string)
|
||||
return bool(re_search(pattern, string))
|
||||
|
||||
|
||||
def _sqlite_acos(x):
|
||||
if x is None:
|
||||
return None
|
||||
return acos(x)
|
||||
|
||||
|
||||
def _sqlite_asin(x):
|
||||
if x is None:
|
||||
return None
|
||||
return asin(x)
|
||||
|
||||
|
||||
def _sqlite_atan(x):
|
||||
if x is None:
|
||||
return None
|
||||
return atan(x)
|
||||
|
||||
|
||||
def _sqlite_atan2(y, x):
|
||||
if y is None or x is None:
|
||||
return None
|
||||
return atan2(y, x)
|
||||
|
||||
|
||||
def _sqlite_bitxor(x, y):
|
||||
if x is None or y is None:
|
||||
return None
|
||||
return x ^ y
|
||||
|
||||
|
||||
def _sqlite_ceiling(x):
|
||||
if x is None:
|
||||
return None
|
||||
return ceil(x)
|
||||
|
||||
|
||||
def _sqlite_cos(x):
|
||||
if x is None:
|
||||
return None
|
||||
return cos(x)
|
||||
|
||||
|
||||
def _sqlite_cot(x):
|
||||
if x is None:
|
||||
return None
|
||||
return 1 / tan(x)
|
||||
|
||||
|
||||
def _sqlite_degrees(x):
|
||||
if x is None:
|
||||
return None
|
||||
return degrees(x)
|
||||
|
||||
|
||||
def _sqlite_exp(x):
|
||||
if x is None:
|
||||
return None
|
||||
return exp(x)
|
||||
|
||||
|
||||
def _sqlite_floor(x):
|
||||
if x is None:
|
||||
return None
|
||||
return floor(x)
|
||||
|
||||
|
||||
def _sqlite_ln(x):
|
||||
if x is None:
|
||||
return None
|
||||
return log(x)
|
||||
|
||||
|
||||
def _sqlite_log(base, x):
|
||||
if base is None or x is None:
|
||||
return None
|
||||
# Arguments reversed to match SQL standard.
|
||||
return log(x, base)
|
||||
|
||||
|
||||
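# Illustrative examples: _sqlite_lpad("abc", 5, "xy") returns "xyabc", and
# _sqlite_lpad("abcdef", 4, "x") truncates to "abcd".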
def _sqlite_lpad(text, length, fill_text):
|
||||
if text is None or length is None or fill_text is None:
|
||||
return None
|
||||
delta = length - len(text)
|
||||
if delta <= 0:
|
||||
return text[:length]
|
||||
return (fill_text * length)[:delta] + text
|
||||
|
||||
|
||||
def _sqlite_md5(text):
|
||||
if text is None:
|
||||
return None
|
||||
return md5(text.encode()).hexdigest()
|
||||
|
||||
|
||||
def _sqlite_mod(x, y):
|
||||
if x is None or y is None:
|
||||
return None
|
||||
return fmod(x, y)
|
||||
|
||||
|
||||
def _sqlite_pi():
|
||||
return pi
|
||||
|
||||
|
||||
def _sqlite_power(x, y):
|
||||
if x is None or y is None:
|
||||
return None
|
||||
return x**y
|
||||
|
||||
|
||||
def _sqlite_radians(x):
|
||||
if x is None:
|
||||
return None
|
||||
return radians(x)
|
||||
|
||||
|
||||
def _sqlite_repeat(text, count):
|
||||
if text is None or count is None:
|
||||
return None
|
||||
return text * count
|
||||
|
||||
|
||||
def _sqlite_reverse(text):
|
||||
if text is None:
|
||||
return None
|
||||
return text[::-1]
|
||||
|
||||
|
||||
def _sqlite_rpad(text, length, fill_text):
|
||||
if text is None or length is None or fill_text is None:
|
||||
return None
|
||||
return (text + fill_text * length)[:length]
|
||||
|
||||
|
||||
def _sqlite_sha1(text):
|
||||
if text is None:
|
||||
return None
|
||||
return sha1(text.encode()).hexdigest()
|
||||
|
||||
|
||||
def _sqlite_sha224(text):
|
||||
if text is None:
|
||||
return None
|
||||
return sha224(text.encode()).hexdigest()
|
||||
|
||||
|
||||
def _sqlite_sha256(text):
|
||||
if text is None:
|
||||
return None
|
||||
return sha256(text.encode()).hexdigest()
|
||||
|
||||
|
||||
def _sqlite_sha384(text):
|
||||
if text is None:
|
||||
return None
|
||||
return sha384(text.encode()).hexdigest()
|
||||
|
||||
|
||||
def _sqlite_sha512(text):
|
||||
if text is None:
|
||||
return None
|
||||
return sha512(text.encode()).hexdigest()
|
||||
|
||||
|
||||
def _sqlite_sign(x):
|
||||
if x is None:
|
||||
return None
|
||||
return (x > 0) - (x < 0)
|
||||
|
||||
|
||||
def _sqlite_sin(x):
|
||||
if x is None:
|
||||
return None
|
||||
return sin(x)
|
||||
|
||||
|
||||
def _sqlite_sqrt(x):
|
||||
if x is None:
|
||||
return None
|
||||
return sqrt(x)
|
||||
|
||||
|
||||
def _sqlite_tan(x):
|
||||
if x is None:
|
||||
return None
|
||||
return tan(x)
|
||||
|
||||
|
||||
class ListAggregate(list):
|
||||
step = list.append
|
||||
|
||||
|
||||
class StdDevPop(ListAggregate):
|
||||
finalize = statistics.pstdev
|
||||
|
||||
|
||||
class StdDevSamp(ListAggregate):
|
||||
finalize = statistics.stdev
|
||||
|
||||
|
||||
class VarPop(ListAggregate):
|
||||
finalize = statistics.pvariance
|
||||
|
||||
|
||||
class VarSamp(ListAggregate):
|
||||
finalize = statistics.variance
|
||||
@@ -0,0 +1,347 @@
"""
|
||||
SQLite backend for the sqlite3 module in the standard library.
|
||||
"""
|
||||
import datetime
|
||||
import decimal
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from itertools import chain, tee
|
||||
from sqlite3 import dbapi2 as Database
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import IntegrityError
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.utils.asyncio import async_unsafe
|
||||
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
from ._functions import register as register_functions
|
||||
from .client import DatabaseClient
|
||||
from .creation import DatabaseCreation
|
||||
from .features import DatabaseFeatures
|
||||
from .introspection import DatabaseIntrospection
|
||||
from .operations import DatabaseOperations
|
||||
from .schema import DatabaseSchemaEditor
|
||||
|
||||
|
||||
def decoder(conv_func):
|
||||
"""
|
||||
Convert bytestrings from Python's sqlite3 interface to a regular string.
|
||||
"""
|
||||
return lambda s: conv_func(s.decode())
|
||||
|
||||
|
||||
def adapt_date(val):
|
||||
return val.isoformat()
|
||||
|
||||
|
||||
def adapt_datetime(val):
|
||||
return val.isoformat(" ")
|
||||
|
||||
|
||||
Database.register_converter("bool", b"1".__eq__)
|
||||
Database.register_converter("date", decoder(parse_date))
|
||||
Database.register_converter("time", decoder(parse_time))
|
||||
Database.register_converter("datetime", decoder(parse_datetime))
|
||||
Database.register_converter("timestamp", decoder(parse_datetime))
|
||||
|
||||
Database.register_adapter(decimal.Decimal, str)
|
||||
Database.register_adapter(datetime.date, adapt_date)
|
||||
Database.register_adapter(datetime.datetime, adapt_datetime)
|
||||
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
vendor = "sqlite"
|
||||
display_name = "SQLite"
|
||||
# SQLite doesn't actually support most of these types, but it "does the right
|
||||
# thing" given more verbose field definitions, so leave them as is so that
|
||||
# schema inspection is more useful.
|
||||
data_types = {
|
||||
"AutoField": "integer",
|
||||
"BigAutoField": "integer",
|
||||
"BinaryField": "BLOB",
|
||||
"BooleanField": "bool",
|
||||
"CharField": "varchar(%(max_length)s)",
|
||||
"DateField": "date",
|
||||
"DateTimeField": "datetime",
|
||||
"DecimalField": "decimal",
|
||||
"DurationField": "bigint",
|
||||
"FileField": "varchar(%(max_length)s)",
|
||||
"FilePathField": "varchar(%(max_length)s)",
|
||||
"FloatField": "real",
|
||||
"IntegerField": "integer",
|
||||
"BigIntegerField": "bigint",
|
||||
"IPAddressField": "char(15)",
|
||||
"GenericIPAddressField": "char(39)",
|
||||
"JSONField": "text",
|
||||
"OneToOneField": "integer",
|
||||
"PositiveBigIntegerField": "bigint unsigned",
|
||||
"PositiveIntegerField": "integer unsigned",
|
||||
"PositiveSmallIntegerField": "smallint unsigned",
|
||||
"SlugField": "varchar(%(max_length)s)",
|
||||
"SmallAutoField": "integer",
|
||||
"SmallIntegerField": "smallint",
|
||||
"TextField": "text",
|
||||
"TimeField": "time",
|
||||
"UUIDField": "char(32)",
|
||||
}
|
||||
data_type_check_constraints = {
|
||||
"PositiveBigIntegerField": '"%(column)s" >= 0',
|
||||
"JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
|
||||
"PositiveIntegerField": '"%(column)s" >= 0',
|
||||
"PositiveSmallIntegerField": '"%(column)s" >= 0',
|
||||
}
|
||||
data_types_suffix = {
|
||||
"AutoField": "AUTOINCREMENT",
|
||||
"BigAutoField": "AUTOINCREMENT",
|
||||
"SmallAutoField": "AUTOINCREMENT",
|
||||
}
|
||||
# SQLite requires LIKE statements to include an ESCAPE clause if the value
|
||||
# being escaped has a percent or underscore in it.
|
||||
# See https://www.sqlite.org/lang_expr.html for an explanation.
|
||||
operators = {
|
||||
"exact": "= %s",
|
||||
"iexact": "LIKE %s ESCAPE '\\'",
|
||||
"contains": "LIKE %s ESCAPE '\\'",
|
||||
"icontains": "LIKE %s ESCAPE '\\'",
|
||||
"regex": "REGEXP %s",
|
||||
"iregex": "REGEXP '(?i)' || %s",
|
||||
"gt": "> %s",
|
||||
"gte": ">= %s",
|
||||
"lt": "< %s",
|
||||
"lte": "<= %s",
|
||||
"startswith": "LIKE %s ESCAPE '\\'",
|
||||
"endswith": "LIKE %s ESCAPE '\\'",
|
||||
"istartswith": "LIKE %s ESCAPE '\\'",
|
||||
"iendswith": "LIKE %s ESCAPE '\\'",
|
||||
}
|
||||
|
||||
# The patterns below are used to generate SQL pattern lookup clauses when
|
||||
# the right-hand side of the lookup isn't a raw string (it might be an expression
|
||||
# or the result of a bilateral transformation).
|
||||
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
|
||||
# escaped on database side.
|
||||
#
|
||||
# Note: we use str.format() here for readability as '%' is used as a wildcard for
|
||||
# the LIKE operator.
|
||||
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
|
||||
pattern_ops = {
|
||||
"contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
|
||||
"icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
|
||||
"startswith": r"LIKE {} || '%%' ESCAPE '\'",
|
||||
"istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
|
||||
"endswith": r"LIKE '%%' || {} ESCAPE '\'",
|
||||
"iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
|
||||
}
|
||||
|
||||
Database = Database
|
||||
SchemaEditorClass = DatabaseSchemaEditor
|
||||
# Classes instantiated in __init__().
|
||||
client_class = DatabaseClient
|
||||
creation_class = DatabaseCreation
|
||||
features_class = DatabaseFeatures
|
||||
introspection_class = DatabaseIntrospection
|
||||
ops_class = DatabaseOperations
|
||||
|
||||
def get_connection_params(self):
|
||||
settings_dict = self.settings_dict
|
||||
if not settings_dict["NAME"]:
|
||||
raise ImproperlyConfigured(
|
||||
"settings.DATABASES is improperly configured. "
|
||||
"Please supply the NAME value."
|
||||
)
|
||||
kwargs = {
|
||||
"database": settings_dict["NAME"],
|
||||
"detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
|
||||
**settings_dict["OPTIONS"],
|
||||
}
|
||||
# Always allow the underlying SQLite connection to be shareable
|
||||
# between multiple threads. The safe-guarding will be handled at a
|
||||
# higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
|
||||
# property. This is necessary as the shareability is disabled by
|
||||
# default in sqlite3 and it cannot be changed once a connection is
|
||||
# opened.
|
||||
if "check_same_thread" in kwargs and kwargs["check_same_thread"]:
|
||||
warnings.warn(
|
||||
"The `check_same_thread` option was provided and set to "
|
||||
"True. It will be overridden with False. Use the "
|
||||
"`DatabaseWrapper.allow_thread_sharing` property instead "
|
||||
"for controlling thread shareability.",
|
||||
RuntimeWarning,
|
||||
)
|
||||
kwargs.update({"check_same_thread": False, "uri": True})
|
||||
return kwargs
|
||||
|
||||
def get_database_version(self):
|
||||
return self.Database.sqlite_version_info
|
||||
|
||||
@async_unsafe
|
||||
def get_new_connection(self, conn_params):
|
||||
conn = Database.connect(**conn_params)
|
||||
register_functions(conn)
|
||||
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
# The macOS bundled SQLite defaults legacy_alter_table ON, which
|
||||
# prevents atomic table renames (feature supports_atomic_references_rename)
|
||||
conn.execute("PRAGMA legacy_alter_table = OFF")
|
||||
return conn
|
||||
|
||||
def create_cursor(self, name=None):
|
||||
return self.connection.cursor(factory=SQLiteCursorWrapper)
|
||||
|
||||
@async_unsafe
|
||||
def close(self):
|
||||
self.validate_thread_sharing()
|
||||
# If database is in memory, closing the connection destroys the
|
||||
# database. To prevent accidental data loss, ignore close requests on
|
||||
# an in-memory db.
|
||||
if not self.is_in_memory_db():
|
||||
BaseDatabaseWrapper.close(self)
|
||||
|
||||
def _savepoint_allowed(self):
|
||||
# When 'isolation_level' is not None, sqlite3 commits before each
|
||||
# savepoint; it's a bug. When it is None, savepoints don't make sense
|
||||
# because autocommit is enabled. The only exception is inside 'atomic'
|
||||
# blocks. To work around that bug, on SQLite, 'atomic' starts a
|
||||
        # transaction explicitly rather than simply disabling autocommit.
|
||||
return self.in_atomic_block
|
||||
|
||||
def _set_autocommit(self, autocommit):
|
||||
if autocommit:
|
||||
level = None
|
||||
else:
|
||||
# sqlite3's internal default is ''. It's different from None.
|
||||
# See Modules/_sqlite/connection.c.
|
||||
level = ""
|
||||
# 'isolation_level' is a misleading API.
|
||||
# SQLite always runs at the SERIALIZABLE isolation level.
|
||||
with self.wrap_database_errors:
|
||||
self.connection.isolation_level = level
|
||||
|
||||
def disable_constraint_checking(self):
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute("PRAGMA foreign_keys = OFF")
|
||||
# Foreign key constraints cannot be turned off while in a multi-
|
||||
# statement transaction. Fetch the current state of the pragma
|
||||
# to determine if constraints are effectively disabled.
|
||||
enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
|
||||
return not bool(enabled)
|
||||
|
||||
def enable_constraint_checking(self):
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute("PRAGMA foreign_keys = ON")
|
||||
|
||||
def check_constraints(self, table_names=None):
|
||||
"""
|
||||
Check each table name in `table_names` for rows with invalid foreign
|
||||
key references. This method is intended to be used in conjunction with
|
||||
`disable_constraint_checking()` and `enable_constraint_checking()`, to
|
||||
determine if rows with invalid references were entered while constraint
|
||||
checks were off.
|
||||
"""
|
||||
with self.cursor() as cursor:
|
||||
if table_names is None:
|
||||
violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
|
||||
else:
|
||||
violations = chain.from_iterable(
|
||||
cursor.execute(
|
||||
"PRAGMA foreign_key_check(%s)" % self.ops.quote_name(table_name)
|
||||
).fetchall()
|
||||
for table_name in table_names
|
||||
)
|
||||
# See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
|
||||
for (
|
||||
table_name,
|
||||
rowid,
|
||||
referenced_table_name,
|
||||
foreign_key_index,
|
||||
) in violations:
|
||||
foreign_key = cursor.execute(
|
||||
"PRAGMA foreign_key_list(%s)" % self.ops.quote_name(table_name)
|
||||
).fetchall()[foreign_key_index]
|
||||
column_name, referenced_column_name = foreign_key[3:5]
|
||||
primary_key_column_name = self.introspection.get_primary_key_column(
|
||||
cursor, table_name
|
||||
)
|
||||
primary_key_value, bad_value = cursor.execute(
|
||||
"SELECT %s, %s FROM %s WHERE rowid = %%s"
|
||||
% (
|
||||
self.ops.quote_name(primary_key_column_name),
|
||||
self.ops.quote_name(column_name),
|
||||
self.ops.quote_name(table_name),
|
||||
),
|
||||
(rowid,),
|
||||
).fetchone()
|
||||
raise IntegrityError(
|
||||
"The row in table '%s' with primary key '%s' has an "
|
||||
"invalid foreign key: %s.%s contains a value '%s' that "
|
||||
"does not have a corresponding value in %s.%s."
|
||||
% (
|
||||
table_name,
|
||||
primary_key_value,
|
||||
table_name,
|
||||
column_name,
|
||||
bad_value,
|
||||
referenced_table_name,
|
||||
referenced_column_name,
|
||||
)
|
||||
)
|
||||
|
||||
def is_usable(self):
|
||||
return True
|
||||
|
||||
def _start_transaction_under_autocommit(self):
|
||||
"""
|
||||
Start a transaction explicitly in autocommit mode.
|
||||
|
||||
Staying in autocommit mode works around a bug of sqlite3 that breaks
|
||||
savepoints when autocommit is disabled.
|
||||
"""
|
||||
self.cursor().execute("BEGIN")
|
||||
|
||||
def is_in_memory_db(self):
|
||||
return self.creation.is_in_memory_db(self.settings_dict["NAME"])
|
||||
|
||||
|
||||
FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s")
|
||||
|
||||
|
||||
class SQLiteCursorWrapper(Database.Cursor):
|
||||
"""
|
||||
Django uses the "format" and "pyformat" styles, but Python's sqlite3 module
|
||||
supports neither of these styles.
|
||||
|
||||
This wrapper performs the following conversions:
|
||||
|
||||
- "format" style to "qmark" style
|
||||
- "pyformat" style to "named" style
|
||||
|
||||
In both cases, if you want to use a literal "%s", you'll need to use "%%s".
|
||||
"""
|
||||
|
||||
def execute(self, query, params=None):
|
||||
if params is None:
|
||||
return super().execute(query)
|
||||
# Extract names if params is a mapping, i.e. "pyformat" style is used.
|
||||
param_names = list(params) if isinstance(params, Mapping) else None
|
||||
query = self.convert_query(query, param_names=param_names)
|
||||
return super().execute(query, params)
|
||||
|
||||
def executemany(self, query, param_list):
|
||||
# Extract names if params is a mapping, i.e. "pyformat" style is used.
|
||||
# Peek carefully as a generator can be passed instead of a list/tuple.
|
||||
peekable, param_list = tee(iter(param_list))
|
||||
if (params := next(peekable, None)) and isinstance(params, Mapping):
|
||||
param_names = list(params)
|
||||
else:
|
||||
param_names = None
|
||||
query = self.convert_query(query, param_names=param_names)
|
||||
return super().executemany(query, param_list)
|
||||
|
||||
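    # Illustrative examples: convert_query("SELECT %s") returns "SELECT ?";
    # convert_query("SELECT %(name)s", param_names=["name"]) returns
    # "SELECT :name".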
def convert_query(self, query, *, param_names=None):
|
||||
if param_names is None:
|
||||
# Convert from "format" style to "qmark" style.
|
||||
return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%")
|
||||
else:
|
||||
# Convert from "pyformat" style to "named" style.
|
||||
return query % {name: f":{name}" for name in param_names}
|
||||
@@ -0,0 +1,10 @@
from django.db.backends.base.client import BaseDatabaseClient


class DatabaseClient(BaseDatabaseClient):
    executable_name = "sqlite3"

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        args = [cls.executable_name, settings_dict["NAME"], *parameters]
        return args, None
|
||||
@@ -0,0 +1,159 @@
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from django.db import NotSupportedError
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
|
||||
|
||||
class DatabaseCreation(BaseDatabaseCreation):
|
||||
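    # Illustrative examples: ":memory:" and "file:mem?mode=memory&cache=shared"
    # count as in-memory databases; Path instances never do.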
@staticmethod
|
||||
def is_in_memory_db(database_name):
|
||||
return not isinstance(database_name, Path) and (
|
||||
database_name == ":memory:" or "mode=memory" in database_name
|
||||
)
|
||||
|
||||
def _get_test_db_name(self):
|
||||
test_database_name = self.connection.settings_dict["TEST"]["NAME"] or ":memory:"
|
||||
if test_database_name == ":memory:":
|
||||
return "file:memorydb_%s?mode=memory&cache=shared" % self.connection.alias
|
||||
return test_database_name
|
||||
|
||||
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
|
||||
test_database_name = self._get_test_db_name()
|
||||
|
||||
if keepdb:
|
||||
return test_database_name
|
||||
if not self.is_in_memory_db(test_database_name):
|
||||
# Erase the old test database
|
||||
if verbosity >= 1:
|
||||
self.log(
|
||||
"Destroying old test database for alias %s..."
|
||||
% (self._get_database_display_str(verbosity, test_database_name),)
|
||||
)
|
||||
if os.access(test_database_name, os.F_OK):
|
||||
if not autoclobber:
|
||||
confirm = input(
|
||||
"Type 'yes' if you would like to try deleting the test "
|
||||
"database '%s', or 'no' to cancel: " % test_database_name
|
||||
)
|
||||
if autoclobber or confirm == "yes":
|
||||
try:
|
||||
os.remove(test_database_name)
|
||||
except Exception as e:
|
||||
self.log("Got an error deleting the old test database: %s" % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
self.log("Tests cancelled.")
|
||||
sys.exit(1)
|
||||
return test_database_name
|
||||
|
||||
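    # Illustrative example: with NAME "db.sqlite3" and suffix "1", on-disk
    # clones use NAME "db_1.sqlite3"; in-memory databases depend on the
    # multiprocessing start method instead.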
def get_test_db_clone_settings(self, suffix):
|
||||
orig_settings_dict = self.connection.settings_dict
|
||||
source_database_name = orig_settings_dict["NAME"]
|
||||
|
||||
if not self.is_in_memory_db(source_database_name):
|
||||
root, ext = os.path.splitext(source_database_name)
|
||||
return {**orig_settings_dict, "NAME": f"{root}_{suffix}{ext}"}
|
||||
|
||||
start_method = multiprocessing.get_start_method()
|
||||
if start_method == "fork":
|
||||
return orig_settings_dict
|
||||
if start_method == "spawn":
|
||||
return {
|
||||
**orig_settings_dict,
|
||||
"NAME": f"{self.connection.alias}_{suffix}.sqlite3",
|
||||
}
|
||||
raise NotSupportedError(
|
||||
f"Cloning with start method {start_method!r} is not supported."
|
||||
)
|
||||
|
||||
def _clone_test_db(self, suffix, verbosity, keepdb=False):
|
||||
source_database_name = self.connection.settings_dict["NAME"]
|
||||
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
|
||||
if not self.is_in_memory_db(source_database_name):
|
||||
# Erase the old test database
|
||||
if os.access(target_database_name, os.F_OK):
|
||||
if keepdb:
|
||||
return
|
||||
if verbosity >= 1:
|
||||
self.log(
|
||||
"Destroying old test database for alias %s..."
|
||||
% (
|
||||
self._get_database_display_str(
|
||||
verbosity, target_database_name
|
||||
),
|
||||
)
|
||||
)
|
||||
try:
|
||||
os.remove(target_database_name)
|
||||
except Exception as e:
|
||||
self.log("Got an error deleting the old test database: %s" % e)
|
||||
sys.exit(2)
|
||||
try:
|
||||
shutil.copy(source_database_name, target_database_name)
|
||||
except Exception as e:
|
||||
self.log("Got an error cloning the test database: %s" % e)
|
||||
sys.exit(2)
|
||||
# Forking automatically makes a copy of an in-memory database.
|
||||
# Spawn requires migrating to disk which will be re-opened in
|
||||
# setup_worker_connection.
|
||||
elif multiprocessing.get_start_method() == "spawn":
|
||||
ondisk_db = sqlite3.connect(target_database_name, uri=True)
|
||||
self.connection.connection.backup(ondisk_db)
|
||||
ondisk_db.close()
|
||||
|
||||
def _destroy_test_db(self, test_database_name, verbosity):
|
||||
if test_database_name and not self.is_in_memory_db(test_database_name):
|
||||
# Remove the SQLite database file
|
||||
os.remove(test_database_name)
|
||||
|
||||
def test_db_signature(self):
|
||||
"""
|
||||
Return a tuple that uniquely identifies a test database.
|
||||
|
||||
This takes into account the special cases of ":memory:" and "" for
|
||||
SQLite since the databases will be distinct despite having the same
|
||||
TEST NAME. See https://www.sqlite.org/inmemorydb.html
|
||||
"""
|
||||
test_database_name = self._get_test_db_name()
|
||||
sig = [self.connection.settings_dict["NAME"]]
|
||||
if self.is_in_memory_db(test_database_name):
|
||||
sig.append(self.connection.alias)
|
||||
else:
|
||||
sig.append(test_database_name)
|
||||
return tuple(sig)
|
||||
|
||||
def setup_worker_connection(self, _worker_id):
|
||||
settings_dict = self.get_test_db_clone_settings(_worker_id)
|
||||
# connection.settings_dict must be updated in place for changes to be
|
||||
# reflected in django.db.connections. Otherwise new threads would
|
||||
# connect to the default database instead of the appropriate clone.
|
||||
start_method = multiprocessing.get_start_method()
|
||||
if start_method == "fork":
|
||||
# Update settings_dict in place.
|
||||
self.connection.settings_dict.update(settings_dict)
|
||||
self.connection.close()
|
||||
elif start_method == "spawn":
|
||||
alias = self.connection.alias
|
||||
connection_str = (
|
||||
f"file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared"
|
||||
)
|
||||
source_db = self.connection.Database.connect(
|
||||
f"file:{alias}_{_worker_id}.sqlite3", uri=True
|
||||
)
|
||||
target_db = sqlite3.connect(connection_str, uri=True)
|
||||
source_db.backup(target_db)
|
||||
source_db.close()
|
||||
# Update settings_dict in place.
|
||||
self.connection.settings_dict.update(settings_dict)
|
||||
self.connection.settings_dict["NAME"] = connection_str
|
||||
# Re-open connection to in-memory database before closing copy
|
||||
# connection.
|
||||
self.connection.connect()
|
||||
target_db.close()
|
||||
if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
|
||||
self.mark_expected_failures_and_skips()
|
||||
@@ -0,0 +1,165 @@
|
||||
import operator
|
||||
|
||||
from django.db import transaction
|
||||
from django.db.backends.base.features import BaseDatabaseFeatures
|
||||
from django.db.utils import OperationalError
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from .base import Database
|
||||
|
||||
|
||||
class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
minimum_database_version = (3, 21)
|
||||
test_db_allows_multiple_connections = False
|
||||
supports_unspecified_pk = True
|
||||
supports_timezones = False
|
||||
max_query_params = 999
|
||||
supports_transactions = True
|
||||
atomic_transactions = False
|
||||
can_rollback_ddl = True
|
||||
can_create_inline_fk = False
|
||||
requires_literal_defaults = True
|
||||
can_clone_databases = True
|
||||
supports_temporal_subtraction = True
|
||||
ignores_table_name_case = True
|
||||
supports_cast_with_precision = False
|
||||
time_cast_precision = 3
|
||||
can_release_savepoints = True
|
||||
has_case_insensitive_like = True
|
||||
# Is "ALTER TABLE ... RENAME COLUMN" supported?
|
||||
can_alter_table_rename_column = Database.sqlite_version_info >= (3, 25, 0)
|
||||
# Is "ALTER TABLE ... DROP COLUMN" supported?
|
||||
can_alter_table_drop_column = Database.sqlite_version_info >= (3, 35, 5)
|
||||
supports_parentheses_in_compound = False
|
||||
can_defer_constraint_checks = True
|
||||
supports_over_clause = Database.sqlite_version_info >= (3, 25, 0)
|
||||
supports_frame_range_fixed_distance = Database.sqlite_version_info >= (3, 28, 0)
|
||||
supports_aggregate_filter_clause = Database.sqlite_version_info >= (3, 30, 1)
|
||||
supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
|
||||
# NULLS LAST/FIRST emulation on < 3.30 requires subquery wrapping.
|
||||
requires_compound_order_by_subquery = Database.sqlite_version_info < (3, 30)
|
||||
order_by_nulls_first = True
|
||||
supports_json_field_contains = False
|
||||
supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0)
|
||||
supports_update_conflicts_with_target = supports_update_conflicts
|
||||
test_collations = {
|
||||
"ci": "nocase",
|
||||
"cs": "binary",
|
||||
"non_default": "nocase",
|
||||
}
|
||||
django_test_expected_failures = {
|
||||
# The django_format_dtdelta() function doesn't properly handle mixed
|
||||
# Date/DateTime fields and timedeltas.
|
||||
"expressions.tests.FTimeDeltaTests.test_mixed_comparisons1",
|
||||
}
|
||||
create_test_table_with_composite_primary_key = """
|
||||
CREATE TABLE test_table_composite_pk (
|
||||
column_1 INTEGER NOT NULL,
|
||||
column_2 INTEGER NOT NULL,
|
||||
PRIMARY KEY(column_1, column_2)
|
||||
)
|
||||
"""
|
||||
|
||||
@cached_property
|
||||
def django_test_skips(self):
|
||||
skips = {
|
||||
"SQLite stores values rounded to 15 significant digits.": {
|
||||
"model_fields.test_decimalfield.DecimalFieldTests."
|
||||
"test_fetch_from_db_without_float_rounding",
|
||||
},
|
||||
"SQLite naively remakes the table on field alteration.": {
|
||||
"schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops",
|
||||
"schema.tests.SchemaTests.test_unique_and_reverse_m2m",
|
||||
"schema.tests.SchemaTests."
|
||||
"test_alter_field_default_doesnt_perform_queries",
|
||||
"schema.tests.SchemaTests."
|
||||
"test_rename_column_renames_deferred_sql_references",
|
||||
},
|
||||
"SQLite doesn't support negative precision for ROUND().": {
|
||||
"db_functions.math.test_round.RoundTests."
|
||||
"test_null_with_negative_precision",
|
||||
"db_functions.math.test_round.RoundTests."
|
||||
"test_decimal_with_negative_precision",
|
||||
"db_functions.math.test_round.RoundTests."
|
||||
"test_float_with_negative_precision",
|
||||
"db_functions.math.test_round.RoundTests."
|
||||
"test_integer_with_negative_precision",
|
||||
},
|
||||
}
|
||||
if Database.sqlite_version_info < (3, 27):
|
||||
skips.update(
|
||||
{
|
||||
"Nondeterministic failure on SQLite < 3.27.": {
|
||||
"expressions_window.tests.WindowFunctionTests."
|
||||
"test_subquery_row_range_rank",
|
||||
},
|
||||
}
|
||||
)
|
||||
if self.connection.is_in_memory_db():
|
||||
skips.update(
|
||||
{
|
||||
"the sqlite backend's close() method is a no-op when using an "
|
||||
"in-memory database": {
|
||||
"servers.test_liveserverthread.LiveServerThreadTest."
|
||||
"test_closes_connections",
|
||||
"servers.tests.LiveServerTestCloseConnectionTest."
|
||||
"test_closes_connections",
|
||||
},
|
||||
"For SQLite in-memory tests, closing the connection destroys"
|
||||
"the database.": {
|
||||
"test_utils.tests.AssertNumQueriesUponConnectionTests."
|
||||
"test_ignores_connection_configuration_queries",
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
skips.update(
|
||||
{
|
||||
"Only connections to in-memory SQLite databases are passed to the "
|
||||
"server thread.": {
|
||||
"servers.tests.LiveServerInMemoryDatabaseLockTest."
|
||||
"test_in_memory_database_lock",
|
||||
},
|
||||
"multiprocessing's start method is checked only for in-memory "
|
||||
"SQLite databases": {
|
||||
"backends.sqlite.test_creation.TestDbSignatureTests."
|
||||
"test_get_test_db_clone_settings_not_supported",
|
||||
},
|
||||
}
|
||||
)
|
||||
return skips
|
||||
|
||||
@cached_property
|
||||
def supports_atomic_references_rename(self):
|
||||
return Database.sqlite_version_info >= (3, 26, 0)
|
||||
|
||||
@cached_property
|
||||
def introspected_field_types(self):
|
||||
return {
|
||||
**super().introspected_field_types,
|
||||
"BigAutoField": "AutoField",
|
||||
"DurationField": "BigIntegerField",
|
||||
"GenericIPAddressField": "CharField",
|
||||
"SmallAutoField": "AutoField",
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def supports_json_field(self):
|
||||
with self.connection.cursor() as cursor:
|
||||
try:
|
||||
with transaction.atomic(self.connection.alias):
|
||||
cursor.execute('SELECT JSON(\'{"a": "b"}\')')
|
||||
except OperationalError:
|
||||
return False
|
||||
return True
|
||||
|
||||
can_introspect_json_field = property(operator.attrgetter("supports_json_field"))
|
||||
has_json_object_function = property(operator.attrgetter("supports_json_field"))
|
||||
|
||||
@cached_property
|
||||
def can_return_columns_from_insert(self):
|
||||
return Database.sqlite_version_info >= (3, 35)
|
||||
|
||||
can_return_rows_from_bulk_insert = property(
|
||||
operator.attrgetter("can_return_columns_from_insert")
|
||||
)
|
||||
@@ -0,0 +1,434 @@
|
||||
from collections import namedtuple
|
||||
|
||||
import sqlparse
|
||||
|
||||
from django.db import DatabaseError
|
||||
from django.db.backends.base.introspection import BaseDatabaseIntrospection
|
||||
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
|
||||
from django.db.backends.base.introspection import TableInfo
|
||||
from django.db.models import Index
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
FieldInfo = namedtuple(
|
||||
"FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint")
|
||||
)
|
||||
|
||||
field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")
|
||||
|
||||
|
||||
def get_field_size(name):
|
||||
"""Extract the size number from a "varchar(11)" type name"""
|
||||
m = field_size_re.search(name)
|
||||
return int(m[1]) if m else None
|
||||
|
||||
|
||||
# This light wrapper "fakes" a dictionary interface, because some SQLite data
|
||||
# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
|
||||
# as a simple dictionary lookup.
|
||||
class FlexibleFieldLookupDict:
|
||||
# Maps SQL types to Django Field types. Some of the SQL types have multiple
|
||||
# entries here because SQLite allows for anything and doesn't normalize the
|
||||
# field type; it uses whatever was given.
|
||||
base_data_types_reverse = {
|
||||
"bool": "BooleanField",
|
||||
"boolean": "BooleanField",
|
||||
"smallint": "SmallIntegerField",
|
||||
"smallint unsigned": "PositiveSmallIntegerField",
|
||||
"smallinteger": "SmallIntegerField",
|
||||
"int": "IntegerField",
|
||||
"integer": "IntegerField",
|
||||
"bigint": "BigIntegerField",
|
||||
"integer unsigned": "PositiveIntegerField",
|
||||
"bigint unsigned": "PositiveBigIntegerField",
|
||||
"decimal": "DecimalField",
|
||||
"real": "FloatField",
|
||||
"text": "TextField",
|
||||
"char": "CharField",
|
||||
"varchar": "CharField",
|
||||
"blob": "BinaryField",
|
||||
"date": "DateField",
|
||||
"datetime": "DateTimeField",
|
||||
"time": "TimeField",
|
||||
}
|
||||
|
||||
def __getitem__(self, key):
|
||||
key = key.lower().split("(", 1)[0].strip()
|
||||
return self.base_data_types_reverse[key]
|
||||
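# For example, the lookup strips any parenthesised size before matching
# (values below are illustrative):
#   FlexibleFieldLookupDict()["varchar(30)"]  # -> "CharField"
#   FlexibleFieldLookupDict()["bigint unsigned"]  # -> "PositiveBigIntegerField"
#   get_field_size("varchar(30)")  # -> 30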
|
||||
|
||||
class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||
data_types_reverse = FlexibleFieldLookupDict()
|
||||
|
||||
def get_field_type(self, data_type, description):
|
||||
field_type = super().get_field_type(data_type, description)
|
||||
if description.pk and field_type in {
|
||||
"BigIntegerField",
|
||||
"IntegerField",
|
||||
"SmallIntegerField",
|
||||
}:
|
||||
# No support for BigAutoField or SmallAutoField as SQLite treats
|
||||
# all integer primary keys as signed 64-bit integers.
|
||||
return "AutoField"
|
||||
if description.has_json_constraint:
|
||||
return "JSONField"
|
||||
return field_type
|
||||
|
||||
def get_table_list(self, cursor):
|
||||
"""Return a list of table and view names in the current database."""
|
||||
# Skip the sqlite_sequence system table used for autoincrement key
|
||||
# generation.
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT name, type FROM sqlite_master
|
||||
WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
|
||||
ORDER BY name"""
|
||||
)
|
||||
return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]
|
||||
|
||||
def get_table_description(self, cursor, table_name):
|
||||
"""
|
||||
Return a description of the table with the DB-API cursor.description
|
||||
interface.
|
||||
"""
|
||||
cursor.execute(
|
||||
"PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
|
||||
)
|
||||
table_info = cursor.fetchall()
|
||||
if not table_info:
|
||||
raise DatabaseError(f"Table {table_name} does not exist (empty pragma).")
|
||||
collations = self._get_column_collations(cursor, table_name)
|
||||
json_columns = set()
|
||||
if self.connection.features.can_introspect_json_field:
|
||||
for line in table_info:
|
||||
column = line[1]
|
||||
json_constraint_sql = '%%json_valid("%s")%%' % column
|
||||
has_json_constraint = cursor.execute(
|
||||
"""
|
||||
SELECT sql
|
||||
FROM sqlite_master
|
||||
WHERE
|
||||
type = 'table' AND
|
||||
name = %s AND
|
||||
sql LIKE %s
|
||||
""",
|
||||
[table_name, json_constraint_sql],
|
||||
).fetchone()
|
||||
if has_json_constraint:
|
||||
json_columns.add(column)
|
||||
return [
|
||||
FieldInfo(
|
||||
name,
|
||||
data_type,
|
||||
get_field_size(data_type),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
not notnull,
|
||||
default,
|
||||
collations.get(name),
|
||||
pk == 1,
|
||||
name in json_columns,
|
||||
)
|
||||
for cid, name, data_type, notnull, default, pk in table_info
|
||||
]
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
pk_col = self.get_primary_key_column(cursor, table_name)
|
||||
return [{"table": table_name, "column": pk_col}]
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {column_name: (ref_column_name, ref_table_name)}
|
||||
representing all foreign keys in the given table.
|
||||
"""
|
||||
cursor.execute(
|
||||
"PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
|
||||
)
|
||||
return {
|
||||
column_name: (ref_column_name, ref_table_name)
|
||||
for (
|
||||
_,
|
||||
_,
|
||||
ref_table_name,
|
||||
column_name,
|
||||
ref_column_name,
|
||||
*_,
|
||||
) in cursor.fetchall()
|
||||
}
|
||||
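# For a hypothetical "book" table whose "author_id" column references
# "author"."id", the mapping returned above would be:
#   {"author_id": ("id", "author")}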
|
||||
def get_primary_key_columns(self, cursor, table_name):
|
||||
cursor.execute(
|
||||
"PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
|
||||
)
|
||||
return [name for _, name, *_, pk in cursor.fetchall() if pk]
|
||||
|
||||
def _parse_column_or_constraint_definition(self, tokens, columns):
|
||||
token = None
|
||||
is_constraint_definition = None
|
||||
field_name = None
|
||||
constraint_name = None
|
||||
unique = False
|
||||
unique_columns = []
|
||||
check = False
|
||||
check_columns = []
|
||||
braces_deep = 0
|
||||
for token in tokens:
|
||||
if token.match(sqlparse.tokens.Punctuation, "("):
|
||||
braces_deep += 1
|
||||
elif token.match(sqlparse.tokens.Punctuation, ")"):
|
||||
braces_deep -= 1
|
||||
if braces_deep < 0:
|
||||
# End of columns and constraints for table definition.
|
||||
break
|
||||
elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
|
||||
# End of current column or constraint definition.
|
||||
break
|
||||
# Detect column or constraint definition by first token.
|
||||
if is_constraint_definition is None:
|
||||
is_constraint_definition = token.match(
|
||||
sqlparse.tokens.Keyword, "CONSTRAINT"
|
||||
)
|
||||
if is_constraint_definition:
|
||||
continue
|
||||
if is_constraint_definition:
|
||||
# Detect constraint name by second token.
|
||||
if constraint_name is None:
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
constraint_name = token.value
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
constraint_name = token.value[1:-1]
|
||||
# Start constraint columns parsing after UNIQUE keyword.
|
||||
if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
|
||||
unique = True
|
||||
unique_braces_deep = braces_deep
|
||||
elif unique:
|
||||
if unique_braces_deep == braces_deep:
|
||||
if unique_columns:
|
||||
# Stop constraint parsing.
|
||||
unique = False
|
||||
continue
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
unique_columns.append(token.value)
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
unique_columns.append(token.value[1:-1])
|
||||
else:
|
||||
# Detect field name by first token.
|
||||
if field_name is None:
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
field_name = token.value
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
field_name = token.value[1:-1]
|
||||
if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
|
||||
unique_columns = [field_name]
|
||||
# Start constraint columns parsing after CHECK keyword.
|
||||
if token.match(sqlparse.tokens.Keyword, "CHECK"):
|
||||
check = True
|
||||
check_braces_deep = braces_deep
|
||||
elif check:
|
||||
if check_braces_deep == braces_deep:
|
||||
if check_columns:
|
||||
# Stop constraint parsing.
|
||||
check = False
|
||||
continue
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
if token.value in columns:
|
||||
check_columns.append(token.value)
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
if token.value[1:-1] in columns:
|
||||
check_columns.append(token.value[1:-1])
|
||||
unique_constraint = (
|
||||
{
|
||||
"unique": True,
|
||||
"columns": unique_columns,
|
||||
"primary_key": False,
|
||||
"foreign_key": None,
|
||||
"check": False,
|
||||
"index": False,
|
||||
}
|
||||
if unique_columns
|
||||
else None
|
||||
)
|
||||
check_constraint = (
|
||||
{
|
||||
"check": True,
|
||||
"columns": check_columns,
|
||||
"primary_key": False,
|
||||
"unique": False,
|
||||
"foreign_key": None,
|
||||
"index": False,
|
||||
}
|
||||
if check_columns
|
||||
else None
|
||||
)
|
||||
return constraint_name, unique_constraint, check_constraint, token
|
||||
|
||||
def _parse_table_constraints(self, sql, columns):
|
||||
# Check constraint parsing is based on the SQLite syntax diagram.
|
||||
# https://www.sqlite.org/syntaxdiagrams.html#table-constraint
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
constraints = {}
|
||||
unnamed_constrains_index = 0
|
||||
tokens = (token for token in statement.flatten() if not token.is_whitespace)
|
||||
# Go to columns and constraint definition
|
||||
for token in tokens:
|
||||
if token.match(sqlparse.tokens.Punctuation, "("):
|
||||
break
|
||||
# Parse columns and constraint definition
|
||||
while True:
|
||||
(
|
||||
constraint_name,
|
||||
unique,
|
||||
check,
|
||||
end_token,
|
||||
) = self._parse_column_or_constraint_definition(tokens, columns)
|
||||
if unique:
|
||||
if constraint_name:
|
||||
constraints[constraint_name] = unique
|
||||
else:
|
||||
unnamed_constrains_index += 1
|
||||
constraints[
|
||||
"__unnamed_constraint_%s__" % unnamed_constrains_index
|
||||
] = unique
|
||||
if check:
|
||||
if constraint_name:
|
||||
constraints[constraint_name] = check
|
||||
else:
|
||||
unnamed_constrains_index += 1
|
||||
constraints[
|
||||
"__unnamed_constraint_%s__" % unnamed_constrains_index
|
||||
] = check
|
||||
if end_token.match(sqlparse.tokens.Punctuation, ")"):
|
||||
break
|
||||
return constraints
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve any constraints or keys (unique, pk, fk, check, index) across
|
||||
one or more columns.
|
||||
"""
|
||||
constraints = {}
|
||||
# Find inline check constraints.
|
||||
try:
|
||||
table_schema = cursor.execute(
|
||||
"SELECT sql FROM sqlite_master WHERE type='table' and name=%s"
|
||||
% (self.connection.ops.quote_name(table_name),)
|
||||
).fetchone()[0]
|
||||
except TypeError:
|
||||
# table_name is a view.
|
||||
pass
|
||||
else:
|
||||
columns = {
|
||||
info.name for info in self.get_table_description(cursor, table_name)
|
||||
}
|
||||
constraints.update(self._parse_table_constraints(table_schema, columns))
|
||||
|
||||
# Get the index info
|
||||
cursor.execute(
|
||||
"PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
|
||||
)
|
||||
for row in cursor.fetchall():
|
||||
# SQLite 3.8.9+ has 5 columns, however older versions only give 3
|
||||
# columns. Discard last 2 columns if there.
|
||||
number, index, unique = row[:3]
|
||||
cursor.execute(
|
||||
"SELECT sql FROM sqlite_master "
|
||||
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
|
||||
)
|
||||
# There's at most one row.
|
||||
(sql,) = cursor.fetchone() or (None,)
|
||||
# Inline constraints are already detected in
|
||||
# _parse_table_constraints(). The reasons to avoid fetching inline
|
||||
# constraints from `PRAGMA index_list` are:
|
||||
# - Inline constraints can have a different name and information
|
||||
# than what `PRAGMA index_list` gives.
|
||||
# - Not all inline constraints may appear in `PRAGMA index_list`.
|
||||
if not sql:
|
||||
# An inline constraint
|
||||
continue
|
||||
# Get the index info for that index
|
||||
cursor.execute(
|
||||
"PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
|
||||
)
|
||||
for index_rank, column_rank, column in cursor.fetchall():
|
||||
if index not in constraints:
|
||||
constraints[index] = {
|
||||
"columns": [],
|
||||
"primary_key": False,
|
||||
"unique": bool(unique),
|
||||
"foreign_key": None,
|
||||
"check": False,
|
||||
"index": True,
|
||||
}
|
||||
constraints[index]["columns"].append(column)
|
||||
# Add type and column orders for indexes
|
||||
if constraints[index]["index"]:
|
||||
# SQLite doesn't support any index type other than b-tree
|
||||
constraints[index]["type"] = Index.suffix
|
||||
orders = self._get_index_columns_orders(sql)
|
||||
if orders is not None:
|
||||
constraints[index]["orders"] = orders
|
||||
# Get the PK
|
||||
pk_columns = self.get_primary_key_columns(cursor, table_name)
|
||||
if pk_columns:
|
||||
# SQLite doesn't actually give a name to the PK constraint,
|
||||
# so we invent one. This is fine, as the SQLite backend never
|
||||
# deletes PK constraints by name, as you can't delete constraints
|
||||
# in SQLite; we remake the table with a new PK instead.
|
||||
constraints["__primary__"] = {
|
||||
"columns": pk_columns,
|
||||
"primary_key": True,
|
||||
"unique": False, # It's not actually a unique constraint.
|
||||
"foreign_key": None,
|
||||
"check": False,
|
||||
"index": False,
|
||||
}
|
||||
relations = enumerate(self.get_relations(cursor, table_name).items())
|
||||
constraints.update(
|
||||
{
|
||||
f"fk_{index}": {
|
||||
"columns": [column_name],
|
||||
"primary_key": False,
|
||||
"unique": False,
|
||||
"foreign_key": (ref_table_name, ref_column_name),
|
||||
"check": False,
|
||||
"index": False,
|
||||
}
|
||||
for index, (column_name, (ref_column_name, ref_table_name)) in relations
|
||||
}
|
||||
)
|
||||
return constraints
|
||||
|
||||
def _get_index_columns_orders(self, sql):
|
||||
tokens = sqlparse.parse(sql)[0]
|
||||
for token in tokens:
|
||||
if isinstance(token, sqlparse.sql.Parenthesis):
|
||||
columns = str(token).strip("()").split(", ")
|
||||
return ["DESC" if info.endswith("DESC") else "ASC" for info in columns]
|
||||
return None
|
||||
|
||||
def _get_column_collations(self, cursor, table_name):
|
||||
row = cursor.execute(
|
||||
"""
|
||||
SELECT sql
|
||||
FROM sqlite_master
|
||||
WHERE type = 'table' AND name = %s
|
||||
""",
|
||||
[table_name],
|
||||
).fetchone()
|
||||
if not row:
|
||||
return {}
|
||||
|
||||
sql = row[0]
|
||||
columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
|
||||
collations = {}
|
||||
for column in columns:
|
||||
tokens = column[1:].split()
|
||||
column_name = tokens[0].strip('"')
|
||||
for index, token in enumerate(tokens):
|
||||
if token == "COLLATE":
|
||||
collation = tokens[index + 1]
|
||||
break
|
||||
else:
|
||||
collation = None
|
||||
collations[column_name] = collation
|
||||
return collations
|
||||
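# As a sketch, for a table created with (names illustrative)
#   CREATE TABLE "book" ("id" integer PRIMARY KEY, "title" varchar(50) COLLATE NOCASE)
# the method above returns {"id": None, "title": "NOCASE"}.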
@@ -0,0 +1,434 @@
|
||||
import datetime
|
||||
import decimal
|
||||
import uuid
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import DatabaseError, NotSupportedError, models
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.models.constants import OnConflict
|
||||
from django.db.models.expressions import Col
|
||||
from django.utils import timezone
|
||||
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
cast_char_field_without_max_length = "text"
|
||||
cast_data_types = {
|
||||
"DateField": "TEXT",
|
||||
"DateTimeField": "TEXT",
|
||||
}
|
||||
explain_prefix = "EXPLAIN QUERY PLAN"
|
||||
# List of datatypes that cannot be extracted with JSON_EXTRACT() on
|
||||
# SQLite. Use JSON_TYPE() instead.
|
||||
jsonfield_datatype_values = frozenset(["null", "false", "true"])
|
||||
|
||||
def bulk_batch_size(self, fields, objs):
|
||||
"""
|
||||
SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
|
||||
999 variables per query.
|
||||
|
||||
If there's only a single field to insert, the limit is 500
|
||||
(SQLITE_MAX_COMPOUND_SELECT).
|
||||
"""
|
||||
if len(fields) == 1:
|
||||
return 500
|
||||
elif len(fields) > 1:
|
||||
return self.connection.features.max_query_params // len(fields)
|
||||
else:
|
||||
return len(objs)
|
||||
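# For example, with the default 999-variable limit above, objects with three
# concrete fields are inserted 999 // 3 = 333 rows at a time, while
# single-field inserts fall back to the 500-row compound-SELECT limit.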
|
||||
def check_expression_support(self, expression):
|
||||
bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
|
||||
bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
|
||||
if isinstance(expression, bad_aggregates):
|
||||
for expr in expression.get_source_expressions():
|
||||
try:
|
||||
output_field = expr.output_field
|
||||
except (AttributeError, FieldError):
|
||||
# Not every subexpression has an output_field which is fine
|
||||
# to ignore.
|
||||
pass
|
||||
else:
|
||||
if isinstance(output_field, bad_fields):
|
||||
raise NotSupportedError(
|
||||
"You cannot use Sum, Avg, StdDev, and Variance "
|
||||
"aggregations on date/time fields in sqlite3 "
|
||||
"since date/time is saved as text."
|
||||
)
|
||||
if (
|
||||
isinstance(expression, models.Aggregate)
|
||||
and expression.distinct
|
||||
and len(expression.source_expressions) > 1
|
||||
):
|
||||
raise NotSupportedError(
|
||||
"SQLite doesn't support DISTINCT on aggregate functions "
|
||||
"accepting multiple arguments."
|
||||
)
|
||||
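# So, for instance, Sum("created_at") over a DateTimeField raises
# NotSupportedError here, as does any aggregate combining DISTINCT with more
# than one source expression, while Sum("price") over a numeric column passes
# through unchanged ("created_at" and "price" are illustrative field names).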
|
||||
def date_extract_sql(self, lookup_type, sql, params):
|
||||
"""
|
||||
Support EXTRACT with a user-defined function django_date_extract()
|
||||
that's registered in connect(). Use single quotes because this is a
|
||||
string and could otherwise cause a collision with a field name.
|
||||
"""
|
||||
return f"django_date_extract(%s, {sql})", (lookup_type.lower(), *params)
|
||||
|
||||
def fetch_returned_insert_rows(self, cursor):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table, return the list of returned data.
|
||||
"""
|
||||
return cursor.fetchall()
|
||||
|
||||
def format_for_duration_arithmetic(self, sql):
|
||||
"""Do nothing since formatting is handled in the custom function."""
|
||||
return sql
|
||||
|
||||
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
return f"django_date_trunc(%s, {sql}, %s, %s)", (
|
||||
lookup_type.lower(),
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
return f"django_time_trunc(%s, {sql}, %s, %s)", (
|
||||
lookup_type.lower(),
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def _convert_tznames_to_sql(self, tzname):
|
||||
if tzname and settings.USE_TZ:
|
||||
return tzname, self.connection.timezone_name
|
||||
return None, None
|
||||
|
||||
def datetime_cast_date_sql(self, sql, params, tzname):
|
||||
return f"django_datetime_cast_date({sql}, %s, %s)", (
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_cast_time_sql(self, sql, params, tzname):
|
||||
return f"django_datetime_cast_time({sql}, %s, %s)", (
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
|
||||
return f"django_datetime_extract(%s, {sql}, %s, %s)", (
|
||||
lookup_type.lower(),
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
return f"django_datetime_trunc(%s, {sql}, %s, %s)", (
|
||||
lookup_type.lower(),
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def time_extract_sql(self, lookup_type, sql, params):
|
||||
return f"django_time_extract(%s, {sql})", (lookup_type.lower(), *params)
|
||||
|
||||
def pk_default_value(self):
|
||||
return "NULL"
|
||||
|
||||
def _quote_params_for_last_executed_query(self, params):
|
||||
"""
|
||||
Only for last_executed_query! Don't use this to execute SQL queries!
|
||||
"""
|
||||
# This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
|
||||
# number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
|
||||
# number of return values, default = 2000). Since Python's sqlite3
|
||||
# module doesn't expose the get_limit() C API, assume the default
|
||||
# limits are in effect and split the work in batches if needed.
|
||||
BATCH_SIZE = 999
|
||||
if len(params) > BATCH_SIZE:
|
||||
results = ()
|
||||
for index in range(0, len(params), BATCH_SIZE):
|
||||
chunk = params[index : index + BATCH_SIZE]
|
||||
results += self._quote_params_for_last_executed_query(chunk)
|
||||
return results
|
||||
|
||||
sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
|
||||
# Bypass Django's wrappers and use the underlying sqlite3 connection
|
||||
# to avoid logging this query - it would trigger infinite recursion.
|
||||
cursor = self.connection.connection.cursor()
|
||||
# Native sqlite3 cursors cannot be used as context managers.
|
||||
try:
|
||||
return cursor.execute(sql, params).fetchone()
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
# Python substitutes parameters in Modules/_sqlite/cursor.c with:
|
||||
# bind_parameters(state, self->statement, parameters);
|
||||
# Unfortunately there is no way to reach self->statement from Python,
|
||||
# so we quote and substitute parameters manually.
|
||||
if params:
|
||||
if isinstance(params, (list, tuple)):
|
||||
params = self._quote_params_for_last_executed_query(params)
|
||||
else:
|
||||
values = tuple(params.values())
|
||||
values = self._quote_params_for_last_executed_query(values)
|
||||
params = dict(zip(params, values))
|
||||
return sql % params
|
||||
# For consistency with SQLiteCursorWrapper.execute(), just return sql
|
||||
# when there are no parameters. See #13648 and #17158.
|
||||
else:
|
||||
return sql
|
||||
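# Roughly, for sql = "SELECT * FROM t WHERE name = %s" and params = ["O'Hara"]
# (table and value illustrative), the parameter is run through QUOTE() and
# interpolated, yielding: SELECT * FROM t WHERE name = 'O''Hara'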
|
||||
def quote_name(self, name):
|
||||
if name.startswith('"') and name.endswith('"'):
|
||||
return name # Quoting once is enough.
|
||||
return '"%s"' % name
|
||||
|
||||
def no_limit_value(self):
|
||||
return -1
|
||||
|
||||
def __references_graph(self, table_name):
|
||||
query = """
|
||||
WITH tables AS (
|
||||
SELECT %s name
|
||||
UNION
|
||||
SELECT sqlite_master.name
|
||||
FROM sqlite_master
|
||||
JOIN tables ON (sql REGEXP %s || tables.name || %s)
|
||||
) SELECT name FROM tables;
|
||||
"""
|
||||
params = (
|
||||
table_name,
|
||||
r'(?i)\s+references\s+("|\')?',
|
||||
r'("|\')?\s*\(',
|
||||
)
|
||||
with self.connection.cursor() as cursor:
|
||||
results = cursor.execute(query, params)
|
||||
return [row[0] for row in results.fetchall()]
|
||||
|
||||
@cached_property
|
||||
def _references_graph(self):
|
||||
# 512 is large enough to fit the ~330 tables (as of this writing) in
|
||||
# Django's test suite.
|
||||
return lru_cache(maxsize=512)(self.__references_graph)
|
||||
|
||||
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
|
||||
if tables and allow_cascade:
|
||||
# Simulate TRUNCATE CASCADE by recursively collecting the tables
|
||||
# referencing the tables to be flushed.
|
||||
tables = set(
|
||||
chain.from_iterable(self._references_graph(table) for table in tables)
|
||||
)
|
||||
sql = [
|
||||
"%s %s %s;"
|
||||
% (
|
||||
style.SQL_KEYWORD("DELETE"),
|
||||
style.SQL_KEYWORD("FROM"),
|
||||
style.SQL_FIELD(self.quote_name(table)),
|
||||
)
|
||||
for table in tables
|
||||
]
|
||||
if reset_sequences:
|
||||
sequences = [{"table": table} for table in tables]
|
||||
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
|
||||
return sql
|
||||
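# As an illustrative example, flushing a single table named "app_book" with
# reset_sequences=True yields statements along the lines of:
#   DELETE FROM "app_book";
#   UPDATE "sqlite_sequence" SET "seq" = 0 WHERE "name" IN ('app_book');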
|
||||
def sequence_reset_by_name_sql(self, style, sequences):
|
||||
if not sequences:
|
||||
return []
|
||||
return [
|
||||
"%s %s %s %s = 0 %s %s %s (%s);"
|
||||
% (
|
||||
style.SQL_KEYWORD("UPDATE"),
|
||||
style.SQL_TABLE(self.quote_name("sqlite_sequence")),
|
||||
style.SQL_KEYWORD("SET"),
|
||||
style.SQL_FIELD(self.quote_name("seq")),
|
||||
style.SQL_KEYWORD("WHERE"),
|
||||
style.SQL_FIELD(self.quote_name("name")),
|
||||
style.SQL_KEYWORD("IN"),
|
||||
", ".join(
|
||||
["'%s'" % sequence_info["table"] for sequence_info in sequences]
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
def adapt_datetimefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
# SQLite doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
if settings.USE_TZ:
|
||||
value = timezone.make_naive(value, self.connection.timezone)
|
||||
else:
|
||||
raise ValueError(
|
||||
"SQLite backend does not support timezone-aware datetimes when "
|
||||
"USE_TZ is False."
|
||||
)
|
||||
|
||||
return str(value)
|
||||
|
||||
def adapt_timefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
# SQLite doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("SQLite backend does not support timezone-aware times.")
|
||||
|
||||
return str(value)
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
converters = super().get_db_converters(expression)
|
||||
internal_type = expression.output_field.get_internal_type()
|
||||
if internal_type == "DateTimeField":
|
||||
converters.append(self.convert_datetimefield_value)
|
||||
elif internal_type == "DateField":
|
||||
converters.append(self.convert_datefield_value)
|
||||
elif internal_type == "TimeField":
|
||||
converters.append(self.convert_timefield_value)
|
||||
elif internal_type == "DecimalField":
|
||||
converters.append(self.get_decimalfield_converter(expression))
|
||||
elif internal_type == "UUIDField":
|
||||
converters.append(self.convert_uuidfield_value)
|
||||
elif internal_type == "BooleanField":
|
||||
converters.append(self.convert_booleanfield_value)
|
||||
return converters
|
||||
|
||||
def convert_datetimefield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
if not isinstance(value, datetime.datetime):
|
||||
value = parse_datetime(value)
|
||||
if settings.USE_TZ and not timezone.is_aware(value):
|
||||
value = timezone.make_aware(value, self.connection.timezone)
|
||||
return value
|
||||
|
||||
def convert_datefield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
if not isinstance(value, datetime.date):
|
||||
value = parse_date(value)
|
||||
return value
|
||||
|
||||
def convert_timefield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
if not isinstance(value, datetime.time):
|
||||
value = parse_time(value)
|
||||
return value
|
||||
|
||||
def get_decimalfield_converter(self, expression):
|
||||
# SQLite stores only 15 significant digits. Digits coming from
|
||||
# float inaccuracy must be removed.
|
||||
create_decimal = decimal.Context(prec=15).create_decimal_from_float
|
||||
if isinstance(expression, Col):
|
||||
quantize_value = decimal.Decimal(1).scaleb(
|
||||
-expression.output_field.decimal_places
|
||||
)
|
||||
|
||||
def converter(value, expression, connection):
|
||||
if value is not None:
|
||||
return create_decimal(value).quantize(
|
||||
quantize_value, context=expression.output_field.context
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
def converter(value, expression, connection):
|
||||
if value is not None:
|
||||
return create_decimal(value)
|
||||
|
||||
return converter
|
||||
|
||||
def convert_uuidfield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
|
||||
def convert_booleanfield_value(self, value, expression, connection):
|
||||
return bool(value) if value in (1, 0) else value
|
||||
|
||||
def bulk_insert_sql(self, fields, placeholder_rows):
|
||||
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
|
||||
values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
|
||||
return f"VALUES {values_sql}"
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
# SQLite doesn't have a ^ operator, so use the user-defined POWER
|
||||
# function that's registered in connect().
|
||||
if connector == "^":
|
||||
return "POWER(%s)" % ",".join(sub_expressions)
|
||||
elif connector == "#":
|
||||
return "BITXOR(%s)" % ",".join(sub_expressions)
|
||||
return super().combine_expression(connector, sub_expressions)
|
||||
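# For example (column names illustrative):
#   combine_expression("^", ['"a"', '"b"'])  # -> 'POWER("a","b")'
#   combine_expression("#", ['"a"', '"b"'])  # -> 'BITXOR("a","b")'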
|
||||
def combine_duration_expression(self, connector, sub_expressions):
|
||||
if connector not in ["+", "-", "*", "/"]:
|
||||
raise DatabaseError("Invalid connector for timedelta: %s." % connector)
|
||||
fn_params = ["'%s'" % connector] + sub_expressions
|
||||
if len(fn_params) > 3:
|
||||
raise ValueError("Too many params for timedelta operations.")
|
||||
return "django_format_dtdelta(%s)" % ", ".join(fn_params)
|
||||
|
||||
def integer_field_range(self, internal_type):
|
||||
# SQLite doesn't enforce any integer constraints
|
||||
return (None, None)
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
params = (*lhs_params, *rhs_params)
|
||||
if internal_type == "TimeField":
|
||||
return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params
|
||||
return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params
|
||||
|
||||
def insert_statement(self, on_conflict=None):
|
||||
if on_conflict == OnConflict.IGNORE:
|
||||
return "INSERT OR IGNORE INTO"
|
||||
return super().insert_statement(on_conflict=on_conflict)
|
||||
|
||||
def return_insert_columns(self, fields):
|
||||
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
|
||||
if not fields:
|
||||
return "", ()
|
||||
columns = [
|
||||
"%s.%s"
|
||||
% (
|
||||
self.quote_name(field.model._meta.db_table),
|
||||
self.quote_name(field.column),
|
||||
)
|
||||
for field in fields
|
||||
]
|
||||
return "RETURNING %s" % ", ".join(columns), ()
|
||||
|
||||
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
|
||||
if (
|
||||
on_conflict == OnConflict.UPDATE
|
||||
and self.connection.features.supports_update_conflicts_with_target
|
||||
):
|
||||
return "ON CONFLICT(%s) DO UPDATE SET %s" % (
|
||||
", ".join(map(self.quote_name, unique_fields)),
|
||||
", ".join(
|
||||
[
|
||||
f"{field} = EXCLUDED.{field}"
|
||||
for field in map(self.quote_name, update_fields)
|
||||
]
|
||||
),
|
||||
)
|
||||
return super().on_conflict_suffix_sql(
|
||||
fields,
|
||||
on_conflict,
|
||||
update_fields,
|
||||
unique_fields,
|
||||
)
|
||||
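# Sketch of the UPSERT clause generated above, for a hypothetical unique
# field "email" and update field "name":
#   ON CONFLICT("email") DO UPDATE SET "name" = EXCLUDED."name"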
@@ -0,0 +1,576 @@
|
||||
import copy
|
||||
from decimal import Decimal
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.db import NotSupportedError
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
from django.db.backends.ddl_references import Statement
|
||||
from django.db.backends.utils import strip_quotes
|
||||
from django.db.models import UniqueConstraint
|
||||
from django.db.transaction import atomic
|
||||
|
||||
|
||||
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
sql_delete_table = "DROP TABLE %(table)s"
|
||||
sql_create_fk = None
|
||||
sql_create_inline_fk = (
|
||||
"REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
|
||||
)
|
||||
sql_create_column_inline_fk = sql_create_inline_fk
|
||||
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
|
||||
sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
|
||||
sql_delete_unique = "DROP INDEX %(name)s"
|
||||
|
||||
def __enter__(self):
|
||||
# Some SQLite schema alterations need foreign key constraints to be
|
||||
# disabled. Enforce it here for the duration of the schema edition.
|
||||
if not self.connection.disable_constraint_checking():
|
||||
raise NotSupportedError(
|
||||
"SQLite schema editor cannot be used while foreign key "
|
||||
"constraint checks are enabled. Make sure to disable them "
|
||||
"before entering a transaction.atomic() context because "
|
||||
"SQLite does not support disabling them in the middle of "
|
||||
"a multi-statement transaction."
|
||||
)
|
||||
return super().__enter__()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.connection.check_constraints()
|
||||
super().__exit__(exc_type, exc_value, traceback)
|
||||
self.connection.enable_constraint_checking()
|
||||
|
||||
def quote_value(self, value):
|
||||
# The backend "mostly works" without this function and there are use
|
||||
# cases for compiling Python without the sqlite3 libraries (e.g.
|
||||
# security hardening).
|
||||
try:
|
||||
import sqlite3
|
||||
|
||||
value = sqlite3.adapt(value)
|
||||
except ImportError:
|
||||
pass
|
||||
except sqlite3.ProgrammingError:
|
||||
pass
|
||||
# Manual emulation of SQLite parameter quoting
|
||||
if isinstance(value, bool):
|
||||
return str(int(value))
|
||||
elif isinstance(value, (Decimal, float, int)):
|
||||
return str(value)
|
||||
elif isinstance(value, str):
|
||||
return "'%s'" % value.replace("'", "''")
|
||||
elif value is None:
|
||||
return "NULL"
|
||||
elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
# Bytes are only allowed for BLOB fields, encoded as string
|
||||
# literals containing hexadecimal data and preceded by a single "X"
|
||||
# character.
|
||||
return "X'%s'" % value.hex()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot quote parameter value %r of type %s" % (value, type(value))
|
||||
)
|
||||
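# A few illustrative literals produced by the manual quoting above:
#   quote_value(True)         # -> '1'
#   quote_value("O'Reilly")   # -> "'O''Reilly'"
#   quote_value(b"\x1f\x8b")  # -> "X'1f8b'"
#   quote_value(None)         # -> 'NULL'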
|
||||
def prepare_default(self, value):
|
||||
return self.quote_value(value)
|
||||
|
||||
def _is_referenced_by_fk_constraint(
|
||||
self, table_name, column_name=None, ignore_self=False
|
||||
):
|
||||
"""
|
||||
Return whether or not the provided table name is referenced by another
|
||||
one. If `column_name` is specified, only references pointing to that
|
||||
column are considered. If `ignore_self` is True, self-referential
|
||||
constraints are ignored.
|
||||
"""
|
||||
with self.connection.cursor() as cursor:
|
||||
for other_table in self.connection.introspection.get_table_list(cursor):
|
||||
if ignore_self and other_table.name == table_name:
|
||||
continue
|
||||
relations = self.connection.introspection.get_relations(
|
||||
cursor, other_table.name
|
||||
)
|
||||
for constraint_column, constraint_table in relations.values():
|
||||
if constraint_table == table_name and (
|
||||
column_name is None or constraint_column == column_name
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def alter_db_table(
|
||||
self, model, old_db_table, new_db_table, disable_constraints=True
|
||||
):
|
||||
if (
|
||||
not self.connection.features.supports_atomic_references_rename
|
||||
and disable_constraints
|
||||
and self._is_referenced_by_fk_constraint(old_db_table)
|
||||
):
|
||||
if self.connection.in_atomic_block:
|
||||
raise NotSupportedError(
|
||||
(
|
||||
"Renaming the %r table while in a transaction is not "
|
||||
"supported on SQLite < 3.26 because it would break referential "
|
||||
"integrity. Try adding `atomic = False` to the Migration class."
|
||||
)
|
||||
% old_db_table
|
||||
)
|
||||
self.connection.enable_constraint_checking()
|
||||
super().alter_db_table(model, old_db_table, new_db_table)
|
||||
self.connection.disable_constraint_checking()
|
||||
else:
|
||||
super().alter_db_table(model, old_db_table, new_db_table)
|
||||
|
||||
def alter_field(self, model, old_field, new_field, strict=False):
|
||||
if not self._field_should_be_altered(old_field, new_field):
|
||||
return
|
||||
old_field_name = old_field.name
|
||||
table_name = model._meta.db_table
|
||||
_, old_column_name = old_field.get_attname_column()
|
||||
if (
|
||||
new_field.name != old_field_name
|
||||
and not self.connection.features.supports_atomic_references_rename
|
||||
and self._is_referenced_by_fk_constraint(
|
||||
table_name, old_column_name, ignore_self=True
|
||||
)
|
||||
):
|
||||
if self.connection.in_atomic_block:
|
||||
raise NotSupportedError(
|
||||
(
|
||||
"Renaming the %r.%r column while in a transaction is not "
|
||||
"supported on SQLite < 3.26 because it would break referential "
|
||||
"integrity. Try adding `atomic = False` to the Migration class."
|
||||
)
|
||||
% (model._meta.db_table, old_field_name)
|
||||
)
|
||||
with atomic(self.connection.alias):
|
||||
super().alter_field(model, old_field, new_field, strict=strict)
|
||||
# Follow SQLite's documented procedure for performing changes
|
||||
# that don't affect the on-disk content.
|
||||
# https://sqlite.org/lang_altertable.html#otheralter
|
||||
with self.connection.cursor() as cursor:
|
||||
schema_version = cursor.execute("PRAGMA schema_version").fetchone()[
|
||||
0
|
||||
]
|
||||
cursor.execute("PRAGMA writable_schema = 1")
|
||||
references_template = ' REFERENCES "%s" ("%%s") ' % table_name
|
||||
new_column_name = new_field.get_attname_column()[1]
|
||||
search = references_template % old_column_name
|
||||
replacement = references_template % new_column_name
|
||||
cursor.execute(
|
||||
"UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
|
||||
(search, replacement),
|
||||
)
|
||||
cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1))
|
||||
cursor.execute("PRAGMA writable_schema = 0")
|
||||
# The integrity check will raise an exception and rollback
|
||||
# the transaction if the sqlite_master updates corrupt the
|
||||
# database.
|
||||
cursor.execute("PRAGMA integrity_check")
|
||||
# Perform a VACUUM to refresh the database representation from
|
||||
# the sqlite_master table.
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute("VACUUM")
|
||||
else:
|
||||
super().alter_field(model, old_field, new_field, strict=strict)
|
||||
|
||||
def _remake_table(
|
||||
self, model, create_field=None, delete_field=None, alter_fields=None
|
||||
):
|
||||
"""
|
||||
Shortcut to transform a model from old_model into new_model
|
||||
|
||||
This follows the correct procedure to perform non-rename or column
|
||||
addition operations based on SQLite's documentation
|
||||
|
||||
https://www.sqlite.org/lang_altertable.html#caution
|
||||
|
||||
The essential steps are:
|
||||
1. Create a table with the updated definition called "new__app_model"
|
||||
2. Copy the data from the existing "app_model" table to the new table
|
||||
3. Drop the "app_model" table
|
||||
4. Rename the "new__app_model" table to "app_model"
|
||||
5. Restore any index of the previous "app_model" table.
|
||||
"""
|
||||
|
||||
# Self-referential fields must be recreated rather than copied from
|
||||
# the old model to ensure their remote_field.field_name doesn't refer
|
||||
# to an altered field.
|
||||
def is_self_referential(f):
|
||||
return f.is_relation and f.remote_field.model is model
|
||||
|
||||
# Work out the new fields dict / mapping
|
||||
body = {
|
||||
f.name: f.clone() if is_self_referential(f) else f
|
||||
for f in model._meta.local_concrete_fields
|
||||
}
|
||||
# Since mapping might mix column names and default values,
|
||||
# its values must be already quoted.
|
||||
mapping = {
|
||||
f.column: self.quote_name(f.column)
|
||||
for f in model._meta.local_concrete_fields
|
||||
}
|
||||
# This maps field names (not columns) for things like unique_together
|
||||
rename_mapping = {}
|
||||
# If any of the new or altered fields is introducing a new PK,
|
||||
# remove the old one
|
||||
restore_pk_field = None
|
||||
alter_fields = alter_fields or []
|
||||
if getattr(create_field, "primary_key", False) or any(
|
||||
getattr(new_field, "primary_key", False) for _, new_field in alter_fields
|
||||
):
|
||||
for name, field in list(body.items()):
|
||||
if field.primary_key and not any(
|
||||
# Do not remove the old primary key when an altered field
|
||||
# that introduces a primary key is the same field.
|
||||
name == new_field.name
|
||||
for _, new_field in alter_fields
|
||||
):
|
||||
field.primary_key = False
|
||||
restore_pk_field = field
|
||||
if field.auto_created:
|
||||
del body[name]
|
||||
del mapping[field.column]
|
||||
# Add in any created fields
|
||||
if create_field:
|
||||
body[create_field.name] = create_field
|
||||
# Choose a default and insert it into the copy map
|
||||
if not create_field.many_to_many and create_field.concrete:
|
||||
mapping[create_field.column] = self.prepare_default(
|
||||
self.effective_default(create_field),
|
||||
)
|
||||
# Add in any altered fields
|
||||
for alter_field in alter_fields:
|
||||
old_field, new_field = alter_field
|
||||
body.pop(old_field.name, None)
|
||||
mapping.pop(old_field.column, None)
|
||||
body[new_field.name] = new_field
|
||||
if old_field.null and not new_field.null:
|
||||
case_sql = "coalesce(%(col)s, %(default)s)" % {
|
||||
"col": self.quote_name(old_field.column),
|
||||
"default": self.prepare_default(self.effective_default(new_field)),
|
||||
}
|
||||
mapping[new_field.column] = case_sql
|
||||
else:
|
||||
mapping[new_field.column] = self.quote_name(old_field.column)
|
||||
rename_mapping[old_field.name] = new_field.name
|
||||
# Remove any deleted fields
|
||||
if delete_field:
|
||||
del body[delete_field.name]
|
||||
del mapping[delete_field.column]
|
||||
# Remove any implicit M2M tables
|
||||
if (
|
||||
delete_field.many_to_many
|
||||
and delete_field.remote_field.through._meta.auto_created
|
||||
):
|
||||
return self.delete_model(delete_field.remote_field.through)
|
||||
# Work inside a new app registry
|
||||
apps = Apps()
|
||||
|
||||
# Work out the new value of unique_together, taking renames into
|
||||
# account
|
||||
unique_together = [
|
||||
[rename_mapping.get(n, n) for n in unique]
|
||||
for unique in model._meta.unique_together
|
||||
]
|
||||
|
||||
# Work out the new value for index_together, taking renames into
|
||||
# account
|
||||
index_together = [
|
||||
[rename_mapping.get(n, n) for n in index]
|
||||
for index in model._meta.index_together
|
||||
]
|
||||
|
||||
indexes = model._meta.indexes
|
||||
if delete_field:
|
||||
indexes = [
|
||||
index for index in indexes if delete_field.name not in index.fields
|
||||
]
|
||||
|
||||
constraints = list(model._meta.constraints)
|
||||
|
||||
# Provide isolated instances of the fields to the new model body so
|
||||
# that the existing model's internals aren't interfered with when
|
||||
# the dummy model is constructed.
|
||||
body_copy = copy.deepcopy(body)
|
||||
|
||||
# Construct a new model with the new fields to allow self referential
|
||||
# primary key to resolve to. This model won't ever be materialized as a
|
||||
# table and solely exists for foreign key reference resolution purposes.
|
||||
# This wouldn't be required if the schema editor was operating on model
|
||||
# states instead of rendered models.
|
||||
meta_contents = {
|
||||
"app_label": model._meta.app_label,
|
||||
"db_table": model._meta.db_table,
|
||||
"unique_together": unique_together,
|
||||
"index_together": index_together,
|
||||
"indexes": indexes,
|
||||
"constraints": constraints,
|
||||
"apps": apps,
|
||||
}
|
||||
meta = type("Meta", (), meta_contents)
|
||||
body_copy["Meta"] = meta
|
||||
body_copy["__module__"] = model.__module__
|
||||
type(model._meta.object_name, model.__bases__, body_copy)
|
||||
|
||||
# Construct a model with a renamed table name.
|
||||
body_copy = copy.deepcopy(body)
|
||||
meta_contents = {
|
||||
"app_label": model._meta.app_label,
|
||||
"db_table": "new__%s" % strip_quotes(model._meta.db_table),
|
||||
"unique_together": unique_together,
|
||||
"index_together": index_together,
|
||||
"indexes": indexes,
|
||||
"constraints": constraints,
|
||||
"apps": apps,
|
||||
}
|
||||
meta = type("Meta", (), meta_contents)
|
||||
body_copy["Meta"] = meta
|
||||
body_copy["__module__"] = model.__module__
|
||||
new_model = type("New%s" % model._meta.object_name, model.__bases__, body_copy)
|
||||
|
||||
# Create a new table with the updated schema.
|
||||
self.create_model(new_model)
|
||||
|
||||
# Copy data from the old table into the new table
|
||||
self.execute(
|
||||
"INSERT INTO %s (%s) SELECT %s FROM %s"
|
||||
% (
|
||||
self.quote_name(new_model._meta.db_table),
|
||||
", ".join(self.quote_name(x) for x in mapping),
|
||||
", ".join(mapping.values()),
|
||||
self.quote_name(model._meta.db_table),
|
||||
)
|
||||
)
|
||||
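# The generated copy statement looks roughly like (names illustrative):
#   INSERT INTO "new__app_book" ("id", "title") SELECT "id", "title" FROM "app_book"
# with newly created columns selecting their quoted default value instead of a
# source column.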
|
||||
# Delete the old table to make way for the new
|
||||
self.delete_model(model, handle_autom2m=False)
|
||||
|
||||
# Rename the new table to take the place of the old one
|
||||
self.alter_db_table(
|
||||
new_model,
|
||||
new_model._meta.db_table,
|
||||
model._meta.db_table,
|
||||
disable_constraints=False,
|
||||
)
|
||||
|
||||
# Run deferred SQL on correct table
|
||||
for sql in self.deferred_sql:
|
||||
self.execute(sql)
|
||||
self.deferred_sql = []
|
||||
# Fix any PK-removed field
|
||||
if restore_pk_field:
|
||||
restore_pk_field.primary_key = True
|
||||
|
||||
def delete_model(self, model, handle_autom2m=True):
|
||||
if handle_autom2m:
|
||||
super().delete_model(model)
|
||||
else:
|
||||
# Delete the table (and only that)
|
||||
self.execute(
|
||||
self.sql_delete_table
|
||||
% {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
}
|
||||
)
|
||||
# Remove all deferred statements referencing the deleted table.
|
||||
for sql in list(self.deferred_sql):
|
||||
if isinstance(sql, Statement) and sql.references_table(
|
||||
model._meta.db_table
|
||||
):
|
||||
self.deferred_sql.remove(sql)
|
||||
|
||||
def add_field(self, model, field):
|
||||
"""Create a field on a model."""
|
||||
# Special-case implicit M2M tables.
|
||||
if field.many_to_many and field.remote_field.through._meta.auto_created:
|
||||
self.create_model(field.remote_field.through)
|
||||
elif (
|
||||
# Primary keys and unique fields are not supported in ALTER TABLE
|
||||
# ADD COLUMN.
|
||||
field.primary_key
|
||||
or field.unique
|
||||
or
|
||||
# Fields with default values cannot be handled by ALTER TABLE ADD
|
||||
# COLUMN statement because DROP DEFAULT is not supported in
|
||||
# ALTER TABLE.
|
||||
not field.null
|
||||
or self.effective_default(field) is not None
|
||||
):
|
||||
self._remake_table(model, create_field=field)
|
||||
else:
|
||||
super().add_field(model, field)
|
||||
|
||||
def remove_field(self, model, field):
|
||||
"""
|
||||
Remove a field from a model. Usually involves deleting a column,
|
||||
but for M2Ms may involve deleting a table.
|
||||
"""
|
||||
# M2M fields are a special case
|
||||
if field.many_to_many:
|
||||
# For implicit M2M tables, delete the auto-created table
|
||||
if field.remote_field.through._meta.auto_created:
|
||||
self.delete_model(field.remote_field.through)
|
||||
# For explicit "through" M2M fields, do nothing
|
||||
elif (
|
||||
self.connection.features.can_alter_table_drop_column
|
||||
# Primary keys, unique fields, indexed fields, and foreign keys are
|
||||
# not supported in ALTER TABLE DROP COLUMN.
|
||||
and not field.primary_key
|
||||
and not field.unique
|
||||
and not field.db_index
|
||||
and not (field.remote_field and field.db_constraint)
|
||||
):
|
||||
super().remove_field(model, field)
|
||||
# For everything else, remake.
|
||||
else:
|
||||
# It might not actually have a column behind it
|
||||
if field.db_parameters(connection=self.connection)["type"] is None:
|
||||
return
|
||||
self._remake_table(model, delete_field=field)
|
||||
|
||||
def _alter_field(
|
||||
self,
|
||||
model,
|
||||
old_field,
|
||||
new_field,
|
||||
old_type,
|
||||
new_type,
|
||||
old_db_params,
|
||||
new_db_params,
|
||||
strict=False,
|
||||
):
|
||||
"""Perform a "physical" (non-ManyToMany) field update."""
|
||||
# Use "ALTER TABLE ... RENAME COLUMN" if only the column name
|
||||
# changed and there aren't any constraints.
|
||||
if (
|
||||
self.connection.features.can_alter_table_rename_column
|
||||
and old_field.column != new_field.column
|
||||
and self.column_sql(model, old_field) == self.column_sql(model, new_field)
|
||||
and not (
|
||||
old_field.remote_field
|
||||
and old_field.db_constraint
|
||||
or new_field.remote_field
|
||||
and new_field.db_constraint
|
||||
)
|
||||
):
|
||||
return self.execute(
|
||||
self._rename_field_sql(
|
||||
model._meta.db_table, old_field, new_field, new_type
|
||||
)
|
||||
)
|
||||
# Alter by remaking table
|
||||
self._remake_table(model, alter_fields=[(old_field, new_field)])
|
||||
# Rebuild tables with FKs pointing to this field.
|
||||
old_collation = old_db_params.get("collation")
|
||||
new_collation = new_db_params.get("collation")
|
||||
if new_field.unique and (
|
||||
old_type != new_type or old_collation != new_collation
|
||||
):
|
||||
related_models = set()
|
||||
opts = new_field.model._meta
|
||||
for remote_field in opts.related_objects:
|
||||
# Ignore self-relationship since the table was already rebuilt.
|
||||
if remote_field.related_model == model:
|
||||
continue
|
||||
if not remote_field.many_to_many:
|
||||
if remote_field.field_name == new_field.name:
|
||||
related_models.add(remote_field.related_model)
|
||||
elif new_field.primary_key and remote_field.through._meta.auto_created:
|
||||
related_models.add(remote_field.through)
|
||||
if new_field.primary_key:
|
||||
for many_to_many in opts.many_to_many:
|
||||
# Ignore self-relationship since the table was already rebuilt.
|
||||
if many_to_many.related_model == model:
|
||||
continue
|
||||
if many_to_many.remote_field.through._meta.auto_created:
|
||||
related_models.add(many_to_many.remote_field.through)
|
||||
for related_model in related_models:
|
||||
self._remake_table(related_model)
|
||||
|
||||
def _alter_many_to_many(self, model, old_field, new_field, strict):
|
||||
"""Alter M2Ms to repoint their to= endpoints."""
|
||||
if (
|
||||
old_field.remote_field.through._meta.db_table
|
||||
== new_field.remote_field.through._meta.db_table
|
||||
):
|
||||
# The field name didn't change, but some options did, so we have to
|
||||
# propagate this altering.
|
||||
self._remake_table(
|
||||
old_field.remote_field.through,
|
||||
alter_fields=[
|
||||
(
|
||||
# The field that points to the target model is needed,
|
||||
# so that table can be remade with the new m2m field -
|
||||
# this is m2m_reverse_field_name().
|
||||
old_field.remote_field.through._meta.get_field(
|
||||
old_field.m2m_reverse_field_name()
|
||||
),
|
||||
new_field.remote_field.through._meta.get_field(
|
||||
new_field.m2m_reverse_field_name()
|
||||
),
|
||||
),
|
||||
(
|
||||
# The field that points to the model itself is needed,
|
||||
# so that table can be remade with the new self field -
|
||||
# this is m2m_field_name().
|
||||
old_field.remote_field.through._meta.get_field(
|
||||
old_field.m2m_field_name()
|
||||
),
|
||||
new_field.remote_field.through._meta.get_field(
|
||||
new_field.m2m_field_name()
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return
|
||||
|
||||
# Make a new through table
|
||||
self.create_model(new_field.remote_field.through)
|
||||
# Copy the data across
|
||||
self.execute(
|
||||
"INSERT INTO %s (%s) SELECT %s FROM %s"
|
||||
% (
|
||||
self.quote_name(new_field.remote_field.through._meta.db_table),
|
||||
", ".join(
|
||||
[
|
||||
"id",
|
||||
new_field.m2m_column_name(),
|
||||
new_field.m2m_reverse_name(),
|
||||
]
|
||||
),
|
||||
", ".join(
|
||||
[
|
||||
"id",
|
||||
old_field.m2m_column_name(),
|
||||
old_field.m2m_reverse_name(),
|
||||
]
|
||||
),
|
||||
self.quote_name(old_field.remote_field.through._meta.db_table),
|
||||
)
|
||||
)
|
||||
# Delete the old through table
|
||||
self.delete_model(old_field.remote_field.through)
|
||||
|
||||
def add_constraint(self, model, constraint):
|
||||
if isinstance(constraint, UniqueConstraint) and (
|
||||
constraint.condition
|
||||
or constraint.contains_expressions
|
||||
or constraint.include
|
||||
or constraint.deferrable
|
||||
):
|
||||
super().add_constraint(model, constraint)
|
||||
else:
|
||||
self._remake_table(model)
|
||||
|
||||
def remove_constraint(self, model, constraint):
|
||||
if isinstance(constraint, UniqueConstraint) and (
|
||||
constraint.condition
|
||||
or constraint.contains_expressions
|
||||
or constraint.include
|
||||
or constraint.deferrable
|
||||
):
|
||||
super().remove_constraint(model, constraint)
|
||||
else:
|
||||
self._remake_table(model)
|
||||
|
||||
def _collate_sql(self, collation):
|
||||
return "COLLATE " + collation
|
||||
@@ -0,0 +1,320 @@
|
||||
import datetime
|
||||
import decimal
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.db import NotSupportedError
|
||||
from django.utils.crypto import md5
|
||||
from django.utils.dateparse import parse_time
|
||||
|
||||
logger = logging.getLogger("django.db.backends")
|
||||
|
||||
|
||||
class CursorWrapper:
|
||||
def __init__(self, cursor, db):
|
||||
self.cursor = cursor
|
||||
self.db = db
|
||||
|
||||
WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])
|
||||
|
||||
def __getattr__(self, attr):
|
||||
cursor_attr = getattr(self.cursor, attr)
|
||||
if attr in CursorWrapper.WRAP_ERROR_ATTRS:
|
||||
return self.db.wrap_database_errors(cursor_attr)
|
||||
else:
|
||||
return cursor_attr
|
||||
|
||||
def __iter__(self):
|
||||
with self.db.wrap_database_errors:
|
||||
yield from self.cursor
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
# Close instead of passing through to avoid backend-specific behavior
|
||||
# (#17671). Catch errors liberally because errors in cleanup code
|
||||
# aren't useful.
|
||||
try:
|
||||
self.close()
|
||||
except self.db.Database.Error:
|
||||
pass
|
||||
|
||||
# The following methods cannot be implemented in __getattr__, because the
|
||||
# code must run when the method is invoked, not just when it is accessed.
|
||||
|
||||
def callproc(self, procname, params=None, kparams=None):
|
||||
# Keyword parameters for callproc aren't supported in PEP 249, but the
|
||||
# database driver may support them (e.g. cx_Oracle).
|
||||
if kparams is not None and not self.db.features.supports_callproc_kwargs:
|
||||
raise NotSupportedError(
|
||||
"Keyword parameters for callproc are not supported on this "
|
||||
"database backend."
|
||||
)
|
||||
self.db.validate_no_broken_transaction()
|
||||
with self.db.wrap_database_errors:
|
||||
if params is None and kparams is None:
|
||||
return self.cursor.callproc(procname)
|
||||
elif kparams is None:
|
||||
return self.cursor.callproc(procname, params)
|
||||
else:
|
||||
params = params or ()
|
||||
return self.cursor.callproc(procname, params, kparams)
|
||||
|
||||
def execute(self, sql, params=None):
|
||||
return self._execute_with_wrappers(
|
||||
sql, params, many=False, executor=self._execute
|
||||
)
|
||||
|
||||
def executemany(self, sql, param_list):
|
||||
return self._execute_with_wrappers(
|
||||
sql, param_list, many=True, executor=self._executemany
|
||||
)
|
||||
|
||||
def _execute_with_wrappers(self, sql, params, many, executor):
|
||||
context = {"connection": self.db, "cursor": self}
|
||||
for wrapper in reversed(self.db.execute_wrappers):
|
||||
executor = functools.partial(wrapper, executor)
|
||||
return executor(sql, params, many, context)
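The executor chain built here is what connection.execute_wrapper() hooks into. A minimal sketch of a wrapper; the logging body is only illustrative:

from django.db import connection

def log_sql(execute, sql, params, many, context):
    # Same signature the executor chain above passes along.
    print("running:", sql)
    return execute(sql, params, many, context)

with connection.execute_wrapper(log_sql):
    pass  # any ORM query issued here goes through log_sql first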
|
||||
|
||||
def _execute(self, sql, params, *ignored_wrapper_args):
|
||||
self.db.validate_no_broken_transaction()
|
||||
with self.db.wrap_database_errors:
|
||||
if params is None:
|
||||
# params default might be backend specific.
|
||||
return self.cursor.execute(sql)
|
||||
else:
|
||||
return self.cursor.execute(sql, params)
|
||||
|
||||
def _executemany(self, sql, param_list, *ignored_wrapper_args):
|
||||
self.db.validate_no_broken_transaction()
|
||||
with self.db.wrap_database_errors:
|
||||
return self.cursor.executemany(sql, param_list)
|
||||
|
||||
|
||||
class CursorDebugWrapper(CursorWrapper):
|
||||
# XXX callproc isn't instrumented at this time.
|
||||
|
||||
def execute(self, sql, params=None):
|
||||
with self.debug_sql(sql, params, use_last_executed_query=True):
|
||||
return super().execute(sql, params)
|
||||
|
||||
def executemany(self, sql, param_list):
|
||||
with self.debug_sql(sql, param_list, many=True):
|
||||
return super().executemany(sql, param_list)
|
||||
|
||||
@contextmanager
|
||||
def debug_sql(
|
||||
self, sql=None, params=None, use_last_executed_query=False, many=False
|
||||
):
|
||||
start = time.monotonic()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
stop = time.monotonic()
|
||||
duration = stop - start
|
||||
if use_last_executed_query:
|
||||
sql = self.db.ops.last_executed_query(self.cursor, sql, params)
|
||||
try:
|
||||
times = len(params) if many else ""
|
||||
except TypeError:
|
||||
# params could be an iterator.
|
||||
times = "?"
|
||||
self.db.queries_log.append(
|
||||
{
|
||||
"sql": "%s times: %s" % (times, sql) if many else sql,
|
||||
"time": "%.3f" % duration,
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
"(%.3f) %s; args=%s; alias=%s",
|
||||
duration,
|
||||
sql,
|
||||
params,
|
||||
self.db.alias,
|
||||
extra={
|
||||
"duration": duration,
|
||||
"sql": sql,
|
||||
"params": params,
|
||||
"alias": self.db.alias,
|
||||
},
|
||||
)
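The entries appended to queries_log here are what connection.queries exposes when DEBUG=True. For example:

from django.db import connection, reset_queries

reset_queries()
# ... run some ORM queries ...
for entry in connection.queries:  # populated only when DEBUG=True
    print(entry["time"], entry["sql"])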
|
||||
|
||||
|
||||
@contextmanager
|
||||
def debug_transaction(connection, sql):
|
||||
start = time.monotonic()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if connection.queries_logged:
|
||||
stop = time.monotonic()
|
||||
duration = stop - start
|
||||
connection.queries_log.append(
|
||||
{
|
||||
"sql": "%s" % sql,
|
||||
"time": "%.3f" % duration,
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
"(%.3f) %s; args=%s; alias=%s",
|
||||
duration,
|
||||
sql,
|
||||
None,
|
||||
connection.alias,
|
||||
extra={
|
||||
"duration": duration,
|
||||
"sql": sql,
|
||||
"alias": connection.alias,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def split_tzname_delta(tzname):
    """
    Split a time zone name into a 3-tuple of (name, sign, offset).
    """
    for sign in ["+", "-"]:
        if sign in tzname:
            name, offset = tzname.rsplit(sign, 1)
            if offset and parse_time(offset):
                return name, sign, offset
    return tzname, None, None
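For example:

from django.db.backends.utils import split_tzname_delta

split_tzname_delta("Asia/Kolkata")  # ("Asia/Kolkata", None, None)
split_tzname_delta("UTC+05:30")     # ("UTC", "+", "05:30")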
|
||||
|
||||
|
||||
###############################################
|
||||
# Converters from database (string) to Python #
|
||||
###############################################
|
||||
|
||||
|
||||
def typecast_date(s):
|
||||
return (
|
||||
datetime.date(*map(int, s.split("-"))) if s else None
|
||||
) # return None if s is null
|
||||
|
||||
|
||||
def typecast_time(s): # does NOT store time zone information
|
||||
if not s:
|
||||
return None
|
||||
hour, minutes, seconds = s.split(":")
|
||||
if "." in seconds: # check whether seconds have a fractional part
|
||||
seconds, microseconds = seconds.split(".")
|
||||
else:
|
||||
microseconds = "0"
|
||||
return datetime.time(
|
||||
int(hour), int(minutes), int(seconds), int((microseconds + "000000")[:6])
|
||||
)
|
||||
|
||||
|
||||
def typecast_timestamp(s): # does NOT store time zone information
|
||||
# "2005-07-29 15:48:00.590358-05"
|
||||
# "2005-07-29 09:56:00-05"
|
||||
if not s:
|
||||
return None
|
||||
if " " not in s:
|
||||
return typecast_date(s)
|
||||
d, t = s.split()
|
||||
# Remove timezone information.
|
||||
if "-" in t:
|
||||
t, _ = t.split("-", 1)
|
||||
elif "+" in t:
|
||||
t, _ = t.split("+", 1)
|
||||
dates = d.split("-")
|
||||
times = t.split(":")
|
||||
seconds = times[2]
|
||||
if "." in seconds: # check whether seconds have a fractional part
|
||||
seconds, microseconds = seconds.split(".")
|
||||
else:
|
||||
microseconds = "0"
|
||||
return datetime.datetime(
|
||||
int(dates[0]),
|
||||
int(dates[1]),
|
||||
int(dates[2]),
|
||||
int(times[0]),
|
||||
int(times[1]),
|
||||
int(seconds),
|
||||
int((microseconds + "000000")[:6]),
|
||||
)
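These converters accept the string formats shown in the comments above, for example:

from django.db.backends.utils import typecast_date, typecast_time, typecast_timestamp

typecast_date("2005-07-29")       # datetime.date(2005, 7, 29)
typecast_time("15:48:00.590358")  # datetime.time(15, 48, 0, 590358)
typecast_timestamp("2005-07-29 15:48:00.590358-05")
# datetime.datetime(2005, 7, 29, 15, 48, 0, 590358) -- the trailing offset is dropped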
|
||||
|
||||
|
||||
###############################################
|
||||
# Converters from Python to database (string) #
|
||||
###############################################
|
||||
|
||||
|
||||
def split_identifier(identifier):
    """
    Split an SQL identifier into a two element tuple of (namespace, name).

    The identifier could be a table, column, or sequence name, which might
    be prefixed by a namespace.
    """
    try:
        namespace, name = identifier.split('"."')
    except ValueError:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
|
||||
|
||||
|
||||
def truncate_name(identifier, length=None, hash_len=4):
|
||||
"""
|
||||
Shorten an SQL identifier to a repeatable mangled version with the given
|
||||
length.
|
||||
|
||||
If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
|
||||
truncate the table portion only.
|
||||
"""
|
||||
namespace, name = split_identifier(identifier)
|
||||
|
||||
if length is None or len(name) <= length:
|
||||
return identifier
|
||||
|
||||
digest = names_digest(name, length=hash_len)
|
||||
return "%s%s%s" % (
|
||||
'%s"."' % namespace if namespace else "",
|
||||
name[: length - hash_len],
|
||||
digest,
|
||||
)
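For example; the 4-character digest suffix is deterministic but not spelled out here:

from django.db.backends.utils import split_identifier, truncate_name

split_identifier('"USER"."TABLE"')  # ("USER", "TABLE")
split_identifier("events")          # ("", "events")

truncate_name("some_table_with_a_very_long_name", length=20)
# "some_table_with_" + a 4-character md5-based digest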
|
||||
|
||||
|
||||
def names_digest(*args, length):
    """
    Generate a 32-bit digest of a set of arguments that can be used to shorten
    identifying names.
    """
    h = md5(usedforsecurity=False)
    for arg in args:
        h.update(arg.encode())
    return h.hexdigest()[:length]
|
||||
|
||||
|
||||
def format_number(value, max_digits, decimal_places):
|
||||
"""
|
||||
Format a number into a string with the requisite number of digits and
|
||||
decimal places.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
context = decimal.getcontext().copy()
|
||||
if max_digits is not None:
|
||||
context.prec = max_digits
|
||||
if decimal_places is not None:
|
||||
value = value.quantize(
|
||||
decimal.Decimal(1).scaleb(-decimal_places), context=context
|
||||
)
|
||||
else:
|
||||
context.traps[decimal.Rounded] = 1
|
||||
value = context.create_decimal(value)
|
||||
return "{:f}".format(value)
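For example:

from decimal import Decimal
from django.db.backends.utils import format_number

format_number(Decimal("3.14159"), max_digits=5, decimal_places=2)  # "3.14"
format_number(Decimal("1234.5"), max_digits=6, decimal_places=2)   # "1234.50"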
|
||||
|
||||
|
||||
def strip_quotes(table_name):
    """
    Strip quotes off of quoted table names to make them safe for use in index
    names, sequence names, etc. For example '"USER"."TABLE"' (an Oracle naming
    scheme) becomes 'USER"."TABLE'.
    """
    has_quotes = table_name.startswith('"') and table_name.endswith('"')
    return table_name[1:-1] if has_quotes else table_name
|
||||
@@ -0,0 +1,115 @@
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db.models import signals
|
||||
from django.db.models.aggregates import * # NOQA
|
||||
from django.db.models.aggregates import __all__ as aggregates_all
|
||||
from django.db.models.constraints import * # NOQA
|
||||
from django.db.models.constraints import __all__ as constraints_all
|
||||
from django.db.models.deletion import (
|
||||
CASCADE,
|
||||
DO_NOTHING,
|
||||
PROTECT,
|
||||
RESTRICT,
|
||||
SET,
|
||||
SET_DEFAULT,
|
||||
SET_NULL,
|
||||
ProtectedError,
|
||||
RestrictedError,
|
||||
)
|
||||
from django.db.models.enums import * # NOQA
|
||||
from django.db.models.enums import __all__ as enums_all
|
||||
from django.db.models.expressions import (
|
||||
Case,
|
||||
Exists,
|
||||
Expression,
|
||||
ExpressionList,
|
||||
ExpressionWrapper,
|
||||
F,
|
||||
Func,
|
||||
OrderBy,
|
||||
OuterRef,
|
||||
RowRange,
|
||||
Subquery,
|
||||
Value,
|
||||
ValueRange,
|
||||
When,
|
||||
Window,
|
||||
WindowFrame,
|
||||
)
|
||||
from django.db.models.fields import * # NOQA
|
||||
from django.db.models.fields import __all__ as fields_all
|
||||
from django.db.models.fields.files import FileField, ImageField
|
||||
from django.db.models.fields.json import JSONField
|
||||
from django.db.models.fields.proxy import OrderWrt
|
||||
from django.db.models.indexes import * # NOQA
|
||||
from django.db.models.indexes import __all__ as indexes_all
|
||||
from django.db.models.lookups import Lookup, Transform
|
||||
from django.db.models.manager import Manager
|
||||
from django.db.models.query import Prefetch, QuerySet, prefetch_related_objects
|
||||
from django.db.models.query_utils import FilteredRelation, Q
|
||||
|
||||
# Imports that would create circular imports if sorted
|
||||
from django.db.models.base import DEFERRED, Model # isort:skip
|
||||
from django.db.models.fields.related import ( # isort:skip
|
||||
ForeignKey,
|
||||
ForeignObject,
|
||||
OneToOneField,
|
||||
ManyToManyField,
|
||||
ForeignObjectRel,
|
||||
ManyToOneRel,
|
||||
ManyToManyRel,
|
||||
OneToOneRel,
|
||||
)
|
||||
|
||||
|
||||
__all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all
|
||||
__all__ += [
|
||||
"ObjectDoesNotExist",
|
||||
"signals",
|
||||
"CASCADE",
|
||||
"DO_NOTHING",
|
||||
"PROTECT",
|
||||
"RESTRICT",
|
||||
"SET",
|
||||
"SET_DEFAULT",
|
||||
"SET_NULL",
|
||||
"ProtectedError",
|
||||
"RestrictedError",
|
||||
"Case",
|
||||
"Exists",
|
||||
"Expression",
|
||||
"ExpressionList",
|
||||
"ExpressionWrapper",
|
||||
"F",
|
||||
"Func",
|
||||
"OrderBy",
|
||||
"OuterRef",
|
||||
"RowRange",
|
||||
"Subquery",
|
||||
"Value",
|
||||
"ValueRange",
|
||||
"When",
|
||||
"Window",
|
||||
"WindowFrame",
|
||||
"FileField",
|
||||
"ImageField",
|
||||
"JSONField",
|
||||
"OrderWrt",
|
||||
"Lookup",
|
||||
"Transform",
|
||||
"Manager",
|
||||
"Prefetch",
|
||||
"Q",
|
||||
"QuerySet",
|
||||
"prefetch_related_objects",
|
||||
"DEFERRED",
|
||||
"Model",
|
||||
"FilteredRelation",
|
||||
"ForeignKey",
|
||||
"ForeignObject",
|
||||
"OneToOneField",
|
||||
"ManyToManyField",
|
||||
"ForeignObjectRel",
|
||||
"ManyToOneRel",
|
||||
"ManyToManyRel",
|
||||
"OneToOneRel",
|
||||
]
|
||||
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Classes to represent the definitions of aggregate functions.
|
||||
"""
|
||||
from django.core.exceptions import FieldError, FullResultSet
|
||||
from django.db.models.expressions import Case, Func, Star, Value, When
|
||||
from django.db.models.fields import IntegerField
|
||||
from django.db.models.functions.comparison import Coalesce
|
||||
from django.db.models.functions.mixins import (
|
||||
FixDurationInputMixin,
|
||||
NumericOutputFieldMixin,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Aggregate",
|
||||
"Avg",
|
||||
"Count",
|
||||
"Max",
|
||||
"Min",
|
||||
"StdDev",
|
||||
"Sum",
|
||||
"Variance",
|
||||
]
|
||||
|
||||
|
||||
class Aggregate(Func):
|
||||
template = "%(function)s(%(distinct)s%(expressions)s)"
|
||||
contains_aggregate = True
|
||||
name = None
|
||||
filter_template = "%s FILTER (WHERE %%(filter)s)"
|
||||
window_compatible = True
|
||||
allow_distinct = False
|
||||
empty_result_set_value = None
|
||||
|
||||
def __init__(
|
||||
self, *expressions, distinct=False, filter=None, default=None, **extra
|
||||
):
|
||||
if distinct and not self.allow_distinct:
|
||||
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
|
||||
if default is not None and self.empty_result_set_value is not None:
|
||||
raise TypeError(f"{self.__class__.__name__} does not allow default.")
|
||||
self.distinct = distinct
|
||||
self.filter = filter
|
||||
self.default = default
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def get_source_fields(self):
|
||||
# Don't return the filter expression since it's not a source field.
|
||||
return [e._output_field_or_none for e in super().get_source_expressions()]
|
||||
|
||||
def get_source_expressions(self):
|
||||
source_expressions = super().get_source_expressions()
|
||||
if self.filter:
|
||||
return source_expressions + [self.filter]
|
||||
return source_expressions
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.filter = self.filter and exprs.pop()
|
||||
return super().set_source_expressions(exprs)
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
# Aggregates are not allowed in UPDATE queries, so ignore for_save
|
||||
c = super().resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.filter = c.filter and c.filter.resolve_expression(
|
||||
query, allow_joins, reuse, summarize
|
||||
)
|
||||
if summarize:
|
||||
# Summarized aggregates cannot refer to summarized aggregates.
|
||||
for ref in c.get_refs():
|
||||
if query.annotations[ref].is_summary:
|
||||
raise FieldError(
|
||||
f"Cannot compute {c.name}('{ref}'): '{ref}' is an aggregate"
|
||||
)
|
||||
elif not self.is_summary:
|
||||
# Call Aggregate.get_source_expressions() to avoid
|
||||
# returning self.filter and including that in this loop.
|
||||
expressions = super(Aggregate, c).get_source_expressions()
|
||||
for index, expr in enumerate(expressions):
|
||||
if expr.contains_aggregate:
|
||||
before_resolved = self.get_source_expressions()[index]
|
||||
name = (
|
||||
before_resolved.name
|
||||
if hasattr(before_resolved, "name")
|
||||
else repr(before_resolved)
|
||||
)
|
||||
raise FieldError(
|
||||
"Cannot compute %s('%s'): '%s' is an aggregate"
|
||||
% (c.name, name, name)
|
||||
)
|
||||
if (default := c.default) is None:
|
||||
return c
|
||||
if hasattr(default, "resolve_expression"):
|
||||
default = default.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
if default._output_field_or_none is None:
|
||||
default.output_field = c._output_field_or_none
|
||||
else:
|
||||
default = Value(default, c._output_field_or_none)
|
||||
c.default = None # Reset the default argument before wrapping.
|
||||
coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
|
||||
coalesce.is_summary = c.is_summary
|
||||
return coalesce
|
||||
|
||||
@property
|
||||
def default_alias(self):
|
||||
expressions = self.get_source_expressions()
|
||||
if len(expressions) == 1 and hasattr(expressions[0], "name"):
|
||||
return "%s__%s" % (expressions[0].name, self.name.lower())
|
||||
raise TypeError("Complex expressions require an alias")
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return []
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
|
||||
if self.filter:
|
||||
if connection.features.supports_aggregate_filter_clause:
|
||||
try:
|
||||
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
||||
except FullResultSet:
|
||||
pass
|
||||
else:
|
||||
template = self.filter_template % extra_context.get(
|
||||
"template", self.template
|
||||
)
|
||||
sql, params = super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
filter=filter_sql,
|
||||
**extra_context,
|
||||
)
|
||||
return sql, (*params, *filter_params)
|
||||
else:
|
||||
copy = self.copy()
|
||||
copy.filter = None
|
||||
source_expressions = copy.get_source_expressions()
|
||||
condition = When(self.filter, then=source_expressions[0])
|
||||
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
|
||||
return super(Aggregate, copy).as_sql(
|
||||
compiler, connection, **extra_context
|
||||
)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def _get_repr_options(self):
|
||||
options = super()._get_repr_options()
|
||||
if self.distinct:
|
||||
options["distinct"] = self.distinct
|
||||
if self.filter:
|
||||
options["filter"] = self.filter
|
||||
return options
|
||||
|
||||
|
||||
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
|
||||
function = "AVG"
|
||||
name = "Avg"
|
||||
allow_distinct = True
|
||||
|
||||
|
||||
class Count(Aggregate):
|
||||
function = "COUNT"
|
||||
name = "Count"
|
||||
output_field = IntegerField()
|
||||
allow_distinct = True
|
||||
empty_result_set_value = 0
|
||||
|
||||
def __init__(self, expression, filter=None, **extra):
|
||||
if expression == "*":
|
||||
expression = Star()
|
||||
if isinstance(expression, Star) and filter is not None:
|
||||
raise ValueError("Star cannot be used with filter. Please specify a field.")
|
||||
super().__init__(expression, filter=filter, **extra)
|
||||
|
||||
|
||||
class Max(Aggregate):
|
||||
function = "MAX"
|
||||
name = "Max"
|
||||
|
||||
|
||||
class Min(Aggregate):
|
||||
function = "MIN"
|
||||
name = "Min"
|
||||
|
||||
|
||||
class StdDev(NumericOutputFieldMixin, Aggregate):
|
||||
name = "StdDev"
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def _get_repr_options(self):
|
||||
return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
|
||||
|
||||
|
||||
class Sum(FixDurationInputMixin, Aggregate):
|
||||
function = "SUM"
|
||||
name = "Sum"
|
||||
allow_distinct = True
|
||||
|
||||
|
||||
class Variance(NumericOutputFieldMixin, Aggregate):
|
||||
name = "Variance"
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = "VAR_SAMP" if sample else "VAR_POP"
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def _get_repr_options(self):
|
||||
return {**super()._get_repr_options(), "sample": self.function == "VAR_SAMP"}
|
||||
2531  srcs/.venv/lib/python3.11/site-packages/django/db/models/base.py  (Normal file)
File diff suppressed because it is too large
@@ -0,0 +1,12 @@
"""
Constants used across the ORM in general.
"""
from enum import Enum

# Separator used to split filter strings apart.
LOOKUP_SEP = "__"


class OnConflict(Enum):
    IGNORE = "ignore"
    UPDATE = "update"
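LOOKUP_SEP is the double underscore used in query keyword arguments, e.g.:

from django.db.models.constants import LOOKUP_SEP

"author__name__icontains".split(LOOKUP_SEP)  # ["author", "name", "icontains"]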
|
||||
@@ -0,0 +1,371 @@
|
||||
from enum import Enum
|
||||
|
||||
from django.core.exceptions import FieldError, ValidationError
|
||||
from django.db import connections
|
||||
from django.db.models.expressions import Exists, ExpressionList, F, OrderBy
|
||||
from django.db.models.indexes import IndexExpression
|
||||
from django.db.models.lookups import Exact
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.sql.query import Query
|
||||
from django.db.utils import DEFAULT_DB_ALIAS
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
|
||||
|
||||
|
||||
class BaseConstraint:
|
||||
default_violation_error_message = _("Constraint “%(name)s” is violated.")
|
||||
violation_error_message = None
|
||||
|
||||
def __init__(self, name, violation_error_message=None):
|
||||
self.name = name
|
||||
if violation_error_message is not None:
|
||||
self.violation_error_message = violation_error_message
|
||||
else:
|
||||
self.violation_error_message = self.default_violation_error_message
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return False
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
def get_violation_error_message(self):
|
||||
return self.violation_error_message % {"name": self.name}
|
||||
|
||||
def deconstruct(self):
|
||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace("django.db.models.constraints", "django.db.models")
|
||||
kwargs = {"name": self.name}
|
||||
if (
|
||||
self.violation_error_message is not None
|
||||
and self.violation_error_message != self.default_violation_error_message
|
||||
):
|
||||
kwargs["violation_error_message"] = self.violation_error_message
|
||||
return (path, (), kwargs)
|
||||
|
||||
def clone(self):
|
||||
_, args, kwargs = self.deconstruct()
|
||||
return self.__class__(*args, **kwargs)
|
||||
|
||||
|
||||
class CheckConstraint(BaseConstraint):
|
||||
def __init__(self, *, check, name, violation_error_message=None):
|
||||
self.check = check
|
||||
if not getattr(check, "conditional", False):
|
||||
raise TypeError(
|
||||
"CheckConstraint.check must be a Q instance or boolean expression."
|
||||
)
|
||||
super().__init__(name, violation_error_message=violation_error_message)
|
||||
|
||||
def _get_check_sql(self, model, schema_editor):
|
||||
query = Query(model=model, alias_cols=False)
|
||||
where = query.build_where(self.check)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
check = self._get_check_sql(model, schema_editor)
|
||||
return schema_editor._check_sql(self.name, check)
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
check = self._get_check_sql(model, schema_editor)
|
||||
return schema_editor._create_check_sql(model, self.name, check)
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
return schema_editor._delete_check_sql(model, self.name)
|
||||
|
||||
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
||||
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
|
||||
try:
|
||||
if not Q(self.check).check(against, using=using):
|
||||
raise ValidationError(self.get_violation_error_message())
|
||||
except FieldError:
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: check=%s name=%s>" % (
|
||||
self.__class__.__qualname__,
|
||||
self.check,
|
||||
repr(self.name),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, CheckConstraint):
|
||||
return (
|
||||
self.name == other.name
|
||||
and self.check == other.check
|
||||
and self.violation_error_message == other.violation_error_message
|
||||
)
|
||||
return super().__eq__(other)
|
||||
|
||||
def deconstruct(self):
|
||||
path, args, kwargs = super().deconstruct()
|
||||
kwargs["check"] = self.check
|
||||
return path, args, kwargs
|
||||
|
||||
|
||||
class Deferrable(Enum):
|
||||
DEFERRED = "deferred"
|
||||
IMMEDIATE = "immediate"
|
||||
|
||||
# A similar format was proposed for Python 3.10.
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__qualname__}.{self._name_}"
|
||||
|
||||
|
||||
class UniqueConstraint(BaseConstraint):
|
||||
def __init__(
|
||||
self,
|
||||
*expressions,
|
||||
fields=(),
|
||||
name=None,
|
||||
condition=None,
|
||||
deferrable=None,
|
||||
include=None,
|
||||
opclasses=(),
|
||||
violation_error_message=None,
|
||||
):
|
||||
if not name:
|
||||
raise ValueError("A unique constraint must be named.")
|
||||
if not expressions and not fields:
|
||||
raise ValueError(
|
||||
"At least one field or expression is required to define a "
|
||||
"unique constraint."
|
||||
)
|
||||
if expressions and fields:
|
||||
raise ValueError(
|
||||
"UniqueConstraint.fields and expressions are mutually exclusive."
|
||||
)
|
||||
if not isinstance(condition, (type(None), Q)):
|
||||
raise ValueError("UniqueConstraint.condition must be a Q instance.")
|
||||
if condition and deferrable:
|
||||
raise ValueError("UniqueConstraint with conditions cannot be deferred.")
|
||||
if include and deferrable:
|
||||
raise ValueError("UniqueConstraint with include fields cannot be deferred.")
|
||||
if opclasses and deferrable:
|
||||
raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
|
||||
if expressions and deferrable:
|
||||
raise ValueError("UniqueConstraint with expressions cannot be deferred.")
|
||||
if expressions and opclasses:
|
||||
raise ValueError(
|
||||
"UniqueConstraint.opclasses cannot be used with expressions. "
|
||||
"Use django.contrib.postgres.indexes.OpClass() instead."
|
||||
)
|
||||
if not isinstance(deferrable, (type(None), Deferrable)):
|
||||
raise ValueError(
|
||||
"UniqueConstraint.deferrable must be a Deferrable instance."
|
||||
)
|
||||
if not isinstance(include, (type(None), list, tuple)):
|
||||
raise ValueError("UniqueConstraint.include must be a list or tuple.")
|
||||
if not isinstance(opclasses, (list, tuple)):
|
||||
raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
|
||||
if opclasses and len(fields) != len(opclasses):
|
||||
raise ValueError(
|
||||
"UniqueConstraint.fields and UniqueConstraint.opclasses must "
|
||||
"have the same number of elements."
|
||||
)
|
||||
self.fields = tuple(fields)
|
||||
self.condition = condition
|
||||
self.deferrable = deferrable
|
||||
self.include = tuple(include) if include else ()
|
||||
self.opclasses = opclasses
|
||||
self.expressions = tuple(
|
||||
F(expression) if isinstance(expression, str) else expression
|
||||
for expression in expressions
|
||||
)
|
||||
super().__init__(name, violation_error_message=violation_error_message)
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return bool(self.expressions)
|
||||
|
||||
def _get_condition_sql(self, model, schema_editor):
|
||||
if self.condition is None:
|
||||
return None
|
||||
query = Query(model=model, alias_cols=False)
|
||||
where = query.build_where(self.condition)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def _get_index_expressions(self, model, schema_editor):
|
||||
if not self.expressions:
|
||||
return None
|
||||
index_expressions = []
|
||||
for expression in self.expressions:
|
||||
index_expression = IndexExpression(expression)
|
||||
index_expression.set_wrapper_classes(schema_editor.connection)
|
||||
index_expressions.append(index_expression)
|
||||
return ExpressionList(*index_expressions).resolve_expression(
|
||||
Query(model, alias_cols=False),
|
||||
)
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
fields = [model._meta.get_field(field_name) for field_name in self.fields]
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._unique_sql(
|
||||
model,
|
||||
fields,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
)
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
fields = [model._meta.get_field(field_name) for field_name in self.fields]
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._create_unique_sql(
|
||||
model,
|
||||
fields,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
)
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._delete_unique_sql(
|
||||
model,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s:%s%s%s%s%s%s%s>" % (
|
||||
self.__class__.__qualname__,
|
||||
"" if not self.fields else " fields=%s" % repr(self.fields),
|
||||
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
|
||||
" name=%s" % repr(self.name),
|
||||
"" if self.condition is None else " condition=%s" % self.condition,
|
||||
"" if self.deferrable is None else " deferrable=%r" % self.deferrable,
|
||||
"" if not self.include else " include=%s" % repr(self.include),
|
||||
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, UniqueConstraint):
|
||||
return (
|
||||
self.name == other.name
|
||||
and self.fields == other.fields
|
||||
and self.condition == other.condition
|
||||
and self.deferrable == other.deferrable
|
||||
and self.include == other.include
|
||||
and self.opclasses == other.opclasses
|
||||
and self.expressions == other.expressions
|
||||
and self.violation_error_message == other.violation_error_message
|
||||
)
|
||||
return super().__eq__(other)
|
||||
|
||||
def deconstruct(self):
|
||||
path, args, kwargs = super().deconstruct()
|
||||
if self.fields:
|
||||
kwargs["fields"] = self.fields
|
||||
if self.condition:
|
||||
kwargs["condition"] = self.condition
|
||||
if self.deferrable:
|
||||
kwargs["deferrable"] = self.deferrable
|
||||
if self.include:
|
||||
kwargs["include"] = self.include
|
||||
if self.opclasses:
|
||||
kwargs["opclasses"] = self.opclasses
|
||||
return path, self.expressions, kwargs
|
||||
|
||||
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
||||
queryset = model._default_manager.using(using)
|
||||
if self.fields:
|
||||
lookup_kwargs = {}
|
||||
for field_name in self.fields:
|
||||
if exclude and field_name in exclude:
|
||||
return
|
||||
field = model._meta.get_field(field_name)
|
||||
lookup_value = getattr(instance, field.attname)
|
||||
if lookup_value is None or (
|
||||
lookup_value == ""
|
||||
and connections[using].features.interprets_empty_strings_as_nulls
|
||||
):
|
||||
# A composite constraint containing NULL value cannot cause
|
||||
# a violation since NULL != NULL in SQL.
|
||||
return
|
||||
lookup_kwargs[field.name] = lookup_value
|
||||
queryset = queryset.filter(**lookup_kwargs)
|
||||
else:
|
||||
# Ignore constraints with excluded fields.
|
||||
if exclude:
|
||||
for expression in self.expressions:
|
||||
if hasattr(expression, "flatten"):
|
||||
for expr in expression.flatten():
|
||||
if isinstance(expr, F) and expr.name in exclude:
|
||||
return
|
||||
elif isinstance(expression, F) and expression.name in exclude:
|
||||
return
|
||||
replacements = {
|
||||
F(field): value
|
||||
for field, value in instance._get_field_value_map(
|
||||
meta=model._meta, exclude=exclude
|
||||
).items()
|
||||
}
|
||||
expressions = []
|
||||
for expr in self.expressions:
|
||||
# Ignore ordering.
|
||||
if isinstance(expr, OrderBy):
|
||||
expr = expr.expression
|
||||
expressions.append(Exact(expr, expr.replace_expressions(replacements)))
|
||||
queryset = queryset.filter(*expressions)
|
||||
model_class_pk = instance._get_pk_val(model._meta)
|
||||
if not instance._state.adding and model_class_pk is not None:
|
||||
queryset = queryset.exclude(pk=model_class_pk)
|
||||
if not self.condition:
|
||||
if queryset.exists():
|
||||
if self.expressions:
|
||||
raise ValidationError(self.get_violation_error_message())
|
||||
# When fields are defined, use the unique_error_message() for
|
||||
# backward compatibility.
|
||||
for model, constraints in instance.get_constraints():
|
||||
for constraint in constraints:
|
||||
if constraint is self:
|
||||
raise ValidationError(
|
||||
instance.unique_error_message(model, self.fields)
|
||||
)
|
||||
else:
|
||||
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
|
||||
try:
|
||||
if (self.condition & Exists(queryset.filter(self.condition))).check(
|
||||
against, using=using
|
||||
):
|
||||
raise ValidationError(self.get_violation_error_message())
|
||||
except FieldError:
|
||||
pass
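A hedged sketch of how these constraints are declared on a model (Booking and its fields are assumed example names):

from django.db import models
from django.db.models import Q

class Booking(models.Model):
    room = models.CharField(max_length=50)
    guests = models.PositiveIntegerField()
    active = models.BooleanField(default=True)

    class Meta:
        constraints = [
            models.CheckConstraint(check=Q(guests__gte=1), name="booking_guests_gte_1"),
            models.UniqueConstraint(
                fields=["room"],
                condition=Q(active=True),
                name="unique_active_booking_per_room",
            ),
        ]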
|
||||
@@ -0,0 +1,522 @@
|
||||
from collections import Counter, defaultdict
|
||||
from functools import partial, reduce
|
||||
from itertools import chain
|
||||
from operator import attrgetter, or_
|
||||
|
||||
from django.db import IntegrityError, connections, models, transaction
|
||||
from django.db.models import query_utils, signals, sql
|
||||
|
||||
|
||||
class ProtectedError(IntegrityError):
|
||||
def __init__(self, msg, protected_objects):
|
||||
self.protected_objects = protected_objects
|
||||
super().__init__(msg, protected_objects)
|
||||
|
||||
|
||||
class RestrictedError(IntegrityError):
|
||||
def __init__(self, msg, restricted_objects):
|
||||
self.restricted_objects = restricted_objects
|
||||
super().__init__(msg, restricted_objects)
|
||||
|
||||
|
||||
def CASCADE(collector, field, sub_objs, using):
|
||||
collector.collect(
|
||||
sub_objs,
|
||||
source=field.remote_field.model,
|
||||
source_attr=field.name,
|
||||
nullable=field.null,
|
||||
fail_on_restricted=False,
|
||||
)
|
||||
if field.null and not connections[using].features.can_defer_constraint_checks:
|
||||
collector.add_field_update(field, None, sub_objs)
|
||||
|
||||
|
||||
def PROTECT(collector, field, sub_objs, using):
|
||||
raise ProtectedError(
|
||||
"Cannot delete some instances of model '%s' because they are "
|
||||
"referenced through a protected foreign key: '%s.%s'"
|
||||
% (
|
||||
field.remote_field.model.__name__,
|
||||
sub_objs[0].__class__.__name__,
|
||||
field.name,
|
||||
),
|
||||
sub_objs,
|
||||
)
|
||||
|
||||
|
||||
def RESTRICT(collector, field, sub_objs, using):
|
||||
collector.add_restricted_objects(field, sub_objs)
|
||||
collector.add_dependency(field.remote_field.model, field.model)
|
||||
|
||||
|
||||
def SET(value):
|
||||
if callable(value):
|
||||
|
||||
def set_on_delete(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, value(), sub_objs)
|
||||
|
||||
else:
|
||||
|
||||
def set_on_delete(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, value, sub_objs)
|
||||
|
||||
set_on_delete.deconstruct = lambda: ("django.db.models.SET", (value,), {})
|
||||
set_on_delete.lazy_sub_objs = True
|
||||
return set_on_delete
|
||||
|
||||
|
||||
def SET_NULL(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, None, sub_objs)
|
||||
|
||||
|
||||
SET_NULL.lazy_sub_objs = True
|
||||
|
||||
|
||||
def SET_DEFAULT(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, field.get_default(), sub_objs)
|
||||
|
||||
|
||||
SET_DEFAULT.lazy_sub_objs = True
|
||||
|
||||
|
||||
def DO_NOTHING(collector, field, sub_objs, using):
|
||||
pass
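These handlers are passed as on_delete to ForeignKey. A sketch with assumed example models:

from django.conf import settings
from django.db import models


def get_sentinel_user():
    from django.contrib.auth import get_user_model
    return get_user_model().objects.get_or_create(username="deleted")[0]


class Article(models.Model):
    author = models.ForeignKey(
        settings.AUTH_USER_MODEL, on_delete=models.SET(get_sentinel_user)
    )
    section = models.ForeignKey("Section", null=True, on_delete=models.SET_NULL)
    magazine = models.ForeignKey("Magazine", on_delete=models.PROTECT)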
|
||||
|
||||
|
||||
def get_candidate_relations_to_delete(opts):
|
||||
# The candidate relations are the ones that come from N-1 and 1-1 relations.
|
||||
# N-N (i.e., many-to-many) relations aren't candidates for deletion.
|
||||
return (
|
||||
f
|
||||
for f in opts.get_fields(include_hidden=True)
|
||||
if f.auto_created and not f.concrete and (f.one_to_one or f.one_to_many)
|
||||
)
|
||||
|
||||
|
||||
class Collector:
|
||||
def __init__(self, using, origin=None):
|
||||
self.using = using
|
||||
# A Model or QuerySet object.
|
||||
self.origin = origin
|
||||
# Initially, {model: {instances}}, later values become lists.
|
||||
self.data = defaultdict(set)
|
||||
# {(field, value): [instances, …]}
|
||||
self.field_updates = defaultdict(list)
|
||||
# {model: {field: {instances}}}
|
||||
self.restricted_objects = defaultdict(partial(defaultdict, set))
|
||||
# fast_deletes is a list of queryset-likes that can be deleted without
|
||||
# fetching the objects into memory.
|
||||
self.fast_deletes = []
|
||||
|
||||
# Tracks deletion-order dependency for databases without transactions
|
||||
# or ability to defer constraint checks. Only concrete model classes
|
||||
# should be included, as the dependencies exist only between actual
|
||||
# database tables; proxy models are represented here by their concrete
|
||||
# parent.
|
||||
self.dependencies = defaultdict(set) # {model: {models}}
|
||||
|
||||
def add(self, objs, source=None, nullable=False, reverse_dependency=False):
|
||||
"""
|
||||
Add 'objs' to the collection of objects to be deleted. If the call is
|
||||
the result of a cascade, 'source' should be the model that caused it,
|
||||
and 'nullable' should be set to True if the relation can be null.
|
||||
|
||||
Return a list of all objects that were not already collected.
|
||||
"""
|
||||
if not objs:
|
||||
return []
|
||||
new_objs = []
|
||||
model = objs[0].__class__
|
||||
instances = self.data[model]
|
||||
for obj in objs:
|
||||
if obj not in instances:
|
||||
new_objs.append(obj)
|
||||
instances.update(new_objs)
|
||||
# Nullable relationships can be ignored -- they are nulled out before
|
||||
# deleting, and therefore do not affect the order in which objects have
|
||||
# to be deleted.
|
||||
if source is not None and not nullable:
|
||||
self.add_dependency(source, model, reverse_dependency=reverse_dependency)
|
||||
return new_objs
|
||||
|
||||
def add_dependency(self, model, dependency, reverse_dependency=False):
|
||||
if reverse_dependency:
|
||||
model, dependency = dependency, model
|
||||
self.dependencies[model._meta.concrete_model].add(
|
||||
dependency._meta.concrete_model
|
||||
)
|
||||
self.data.setdefault(dependency, self.data.default_factory())
|
||||
|
||||
def add_field_update(self, field, value, objs):
|
||||
"""
|
||||
Schedule a field update. 'objs' must be a homogeneous iterable
|
||||
collection of model instances (e.g. a QuerySet).
|
||||
"""
|
||||
self.field_updates[field, value].append(objs)
|
||||
|
||||
def add_restricted_objects(self, field, objs):
|
||||
if objs:
|
||||
model = objs[0].__class__
|
||||
self.restricted_objects[model][field].update(objs)
|
||||
|
||||
def clear_restricted_objects_from_set(self, model, objs):
|
||||
if model in self.restricted_objects:
|
||||
self.restricted_objects[model] = {
|
||||
field: items - objs
|
||||
for field, items in self.restricted_objects[model].items()
|
||||
}
|
||||
|
||||
def clear_restricted_objects_from_queryset(self, model, qs):
|
||||
if model in self.restricted_objects:
|
||||
objs = set(
|
||||
qs.filter(
|
||||
pk__in=[
|
||||
obj.pk
|
||||
for objs in self.restricted_objects[model].values()
|
||||
for obj in objs
|
||||
]
|
||||
)
|
||||
)
|
||||
self.clear_restricted_objects_from_set(model, objs)
|
||||
|
||||
def _has_signal_listeners(self, model):
|
||||
return signals.pre_delete.has_listeners(
|
||||
model
|
||||
) or signals.post_delete.has_listeners(model)
|
||||
|
||||
def can_fast_delete(self, objs, from_field=None):
|
||||
"""
|
||||
Determine if the objects in the given queryset-like or single object
|
||||
can be fast-deleted. This can be done if there are no cascades, no
|
||||
parents and no signal listeners for the object class.
|
||||
|
||||
The 'from_field' tells where we are coming from - we need this to
|
||||
determine if the objects are in fact to be deleted. Allow also
|
||||
skipping parent -> child -> parent chain preventing fast delete of
|
||||
the child.
|
||||
"""
|
||||
if from_field and from_field.remote_field.on_delete is not CASCADE:
|
||||
return False
|
||||
if hasattr(objs, "_meta"):
|
||||
model = objs._meta.model
|
||||
elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
|
||||
model = objs.model
|
||||
else:
|
||||
return False
|
||||
if self._has_signal_listeners(model):
|
||||
return False
|
||||
# The use of from_field comes from the need to avoid cascade back to
|
||||
# parent when parent delete is cascading to child.
|
||||
opts = model._meta
|
||||
return (
|
||||
all(
|
||||
link == from_field
|
||||
for link in opts.concrete_model._meta.parents.values()
|
||||
)
|
||||
and
|
||||
# Foreign keys pointing to this model.
|
||||
all(
|
||||
related.field.remote_field.on_delete is DO_NOTHING
|
||||
for related in get_candidate_relations_to_delete(opts)
|
||||
)
|
||||
and (
|
||||
# Something like generic foreign key.
|
||||
not any(
|
||||
hasattr(field, "bulk_related_objects")
|
||||
for field in opts.private_fields
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def get_del_batches(self, objs, fields):
|
||||
"""
|
||||
Return the objs in suitably sized batches for the used connection.
|
||||
"""
|
||||
field_names = [field.name for field in fields]
|
||||
conn_batch_size = max(
|
||||
connections[self.using].ops.bulk_batch_size(field_names, objs), 1
|
||||
)
|
||||
if len(objs) > conn_batch_size:
|
||||
return [
|
||||
objs[i : i + conn_batch_size]
|
||||
for i in range(0, len(objs), conn_batch_size)
|
||||
]
|
||||
else:
|
||||
return [objs]
|
||||
|
||||
def collect(
|
||||
self,
|
||||
objs,
|
||||
source=None,
|
||||
nullable=False,
|
||||
collect_related=True,
|
||||
source_attr=None,
|
||||
reverse_dependency=False,
|
||||
keep_parents=False,
|
||||
fail_on_restricted=True,
|
||||
):
|
||||
"""
|
||||
Add 'objs' to the collection of objects to be deleted as well as all
|
||||
parent instances. 'objs' must be a homogeneous iterable collection of
|
||||
model instances (e.g. a QuerySet). If 'collect_related' is True,
|
||||
related objects will be handled by their respective on_delete handler.
|
||||
|
||||
If the call is the result of a cascade, 'source' should be the model
|
||||
that caused it and 'nullable' should be set to True, if the relation
|
||||
can be null.
|
||||
|
||||
If 'reverse_dependency' is True, 'source' will be deleted before the
|
||||
current model, rather than after. (Needed for cascading to parent
|
||||
models, the one case in which the cascade follows the forwards
|
||||
direction of an FK rather than the reverse direction.)
|
||||
|
||||
If 'keep_parents' is True, data of parent models will not be deleted.
|
||||
|
||||
If 'fail_on_restricted' is False, no error is raised even if deleting
such objects is prohibited due to RESTRICT; this defers the restricted
object checks to recursive calls, where the top-level call may need to
collect more objects to determine whether restricted ones can be
deleted.
|
||||
"""
|
||||
if self.can_fast_delete(objs):
|
||||
self.fast_deletes.append(objs)
|
||||
return
|
||||
new_objs = self.add(
|
||||
objs, source, nullable, reverse_dependency=reverse_dependency
|
||||
)
|
||||
if not new_objs:
|
||||
return
|
||||
|
||||
model = new_objs[0].__class__
|
||||
|
||||
if not keep_parents:
|
||||
# Recursively collect concrete model's parent models, but not their
|
||||
# related objects. These will be found by meta.get_fields()
|
||||
concrete_model = model._meta.concrete_model
|
||||
for ptr in concrete_model._meta.parents.values():
|
||||
if ptr:
|
||||
parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
|
||||
self.collect(
|
||||
parent_objs,
|
||||
source=model,
|
||||
source_attr=ptr.remote_field.related_name,
|
||||
collect_related=False,
|
||||
reverse_dependency=True,
|
||||
fail_on_restricted=False,
|
||||
)
|
||||
if not collect_related:
|
||||
return
|
||||
|
||||
if keep_parents:
|
||||
parents = set(model._meta.get_parent_list())
|
||||
model_fast_deletes = defaultdict(list)
|
||||
protected_objects = defaultdict(list)
|
||||
for related in get_candidate_relations_to_delete(model._meta):
|
||||
# Preserve parent reverse relationships if keep_parents=True.
|
||||
if keep_parents and related.model in parents:
|
||||
continue
|
||||
field = related.field
|
||||
on_delete = field.remote_field.on_delete
|
||||
if on_delete == DO_NOTHING:
|
||||
continue
|
||||
related_model = related.related_model
|
||||
if self.can_fast_delete(related_model, from_field=field):
|
||||
model_fast_deletes[related_model].append(field)
|
||||
continue
|
||||
batches = self.get_del_batches(new_objs, [field])
|
||||
for batch in batches:
|
||||
sub_objs = self.related_objects(related_model, [field], batch)
|
||||
# Non-referenced fields can be deferred if no signal receivers
|
||||
# are connected for the related model as they'll never be
|
||||
# exposed to the user. Skip field deferring when some
|
||||
# relationships are select_related as interactions between both
|
||||
# features are hard to get right. This should only happen in
|
||||
# the rare cases where .related_objects is overridden anyway.
|
||||
if not (
|
||||
sub_objs.query.select_related
|
||||
or self._has_signal_listeners(related_model)
|
||||
):
|
||||
referenced_fields = set(
|
||||
chain.from_iterable(
|
||||
(rf.attname for rf in rel.field.foreign_related_fields)
|
||||
for rel in get_candidate_relations_to_delete(
|
||||
related_model._meta
|
||||
)
|
||||
)
|
||||
)
|
||||
sub_objs = sub_objs.only(*tuple(referenced_fields))
|
||||
if getattr(on_delete, "lazy_sub_objs", False) or sub_objs:
|
||||
try:
|
||||
on_delete(self, field, sub_objs, self.using)
|
||||
except ProtectedError as error:
|
||||
key = "'%s.%s'" % (field.model.__name__, field.name)
|
||||
protected_objects[key] += error.protected_objects
|
||||
if protected_objects:
|
||||
raise ProtectedError(
|
||||
"Cannot delete some instances of model %r because they are "
|
||||
"referenced through protected foreign keys: %s."
|
||||
% (
|
||||
model.__name__,
|
||||
", ".join(protected_objects),
|
||||
),
|
||||
set(chain.from_iterable(protected_objects.values())),
|
||||
)
|
||||
for related_model, related_fields in model_fast_deletes.items():
|
||||
batches = self.get_del_batches(new_objs, related_fields)
|
||||
for batch in batches:
|
||||
sub_objs = self.related_objects(related_model, related_fields, batch)
|
||||
self.fast_deletes.append(sub_objs)
|
||||
for field in model._meta.private_fields:
|
||||
if hasattr(field, "bulk_related_objects"):
|
||||
# It's something like generic foreign key.
|
||||
sub_objs = field.bulk_related_objects(new_objs, self.using)
|
||||
self.collect(
|
||||
sub_objs, source=model, nullable=True, fail_on_restricted=False
|
||||
)
|
||||
|
||||
if fail_on_restricted:
|
||||
# Raise an error if collected restricted objects (RESTRICT) aren't
|
||||
# candidates for deletion also collected via CASCADE.
|
||||
for related_model, instances in self.data.items():
|
||||
self.clear_restricted_objects_from_set(related_model, instances)
|
||||
for qs in self.fast_deletes:
|
||||
self.clear_restricted_objects_from_queryset(qs.model, qs)
|
||||
if self.restricted_objects.values():
|
||||
restricted_objects = defaultdict(list)
|
||||
for related_model, fields in self.restricted_objects.items():
|
||||
for field, objs in fields.items():
|
||||
if objs:
|
||||
key = "'%s.%s'" % (related_model.__name__, field.name)
|
||||
restricted_objects[key] += objs
|
||||
if restricted_objects:
|
||||
raise RestrictedError(
|
||||
"Cannot delete some instances of model %r because "
|
||||
"they are referenced through restricted foreign keys: "
|
||||
"%s."
|
||||
% (
|
||||
model.__name__,
|
||||
", ".join(restricted_objects),
|
||||
),
|
||||
set(chain.from_iterable(restricted_objects.values())),
|
||||
)
|
||||
|
||||
def related_objects(self, related_model, related_fields, objs):
|
||||
"""
|
||||
Get a QuerySet of the related model to objs via related fields.
|
||||
"""
|
||||
predicate = query_utils.Q.create(
|
||||
[(f"{related_field.name}__in", objs) for related_field in related_fields],
|
||||
connector=query_utils.Q.OR,
|
||||
)
|
||||
return related_model._base_manager.using(self.using).filter(predicate)
|
||||
|
||||
def instances_with_model(self):
|
||||
for model, instances in self.data.items():
|
||||
for obj in instances:
|
||||
yield model, obj
|
||||
|
||||
def sort(self):
|
||||
sorted_models = []
|
||||
concrete_models = set()
|
||||
models = list(self.data)
|
||||
while len(sorted_models) < len(models):
|
||||
found = False
|
||||
for model in models:
|
||||
if model in sorted_models:
|
||||
continue
|
||||
dependencies = self.dependencies.get(model._meta.concrete_model)
|
||||
if not (dependencies and dependencies.difference(concrete_models)):
|
||||
sorted_models.append(model)
|
||||
concrete_models.add(model._meta.concrete_model)
|
||||
found = True
|
||||
if not found:
|
||||
return
|
||||
self.data = {model: self.data[model] for model in sorted_models}
|
||||
|
||||
def delete(self):
|
||||
# sort instance collections
|
||||
for model, instances in self.data.items():
|
||||
self.data[model] = sorted(instances, key=attrgetter("pk"))
|
||||
|
||||
# if possible, bring the models in an order suitable for databases that
|
||||
# don't support transactions or cannot defer constraint checks until the
|
||||
# end of a transaction.
|
||||
self.sort()
|
||||
# number of objects deleted for each model label
|
||||
deleted_counter = Counter()
|
||||
|
||||
# Optimize for the case with a single obj and no dependencies
|
||||
if len(self.data) == 1 and len(instances) == 1:
|
||||
instance = list(instances)[0]
|
||||
if self.can_fast_delete(instance):
|
||||
with transaction.mark_for_rollback_on_error(self.using):
|
||||
count = sql.DeleteQuery(model).delete_batch(
|
||||
[instance.pk], self.using
|
||||
)
|
||||
setattr(instance, model._meta.pk.attname, None)
|
||||
return count, {model._meta.label: count}
|
||||
|
||||
with transaction.atomic(using=self.using, savepoint=False):
|
||||
# send pre_delete signals
|
||||
for model, obj in self.instances_with_model():
|
||||
if not model._meta.auto_created:
|
||||
signals.pre_delete.send(
|
||||
sender=model,
|
||||
instance=obj,
|
||||
using=self.using,
|
||||
origin=self.origin,
|
||||
)
|
||||
|
||||
# fast deletes
|
||||
for qs in self.fast_deletes:
|
||||
count = qs._raw_delete(using=self.using)
|
||||
if count:
|
||||
deleted_counter[qs.model._meta.label] += count
|
||||
|
||||
# update fields
|
||||
for (field, value), instances_list in self.field_updates.items():
|
||||
updates = []
|
||||
objs = []
|
||||
for instances in instances_list:
|
||||
if (
|
||||
isinstance(instances, models.QuerySet)
|
||||
and instances._result_cache is None
|
||||
):
|
||||
updates.append(instances)
|
||||
else:
|
||||
objs.extend(instances)
|
||||
if updates:
|
||||
combined_updates = reduce(or_, updates)
|
||||
combined_updates.update(**{field.name: value})
|
||||
if objs:
|
||||
model = objs[0].__class__
|
||||
query = sql.UpdateQuery(model)
|
||||
query.update_batch(
|
||||
list({obj.pk for obj in objs}), {field.name: value}, self.using
|
||||
)
|
||||
|
||||
# reverse instance collections
|
||||
for instances in self.data.values():
|
||||
instances.reverse()
|
||||
|
||||
# delete instances
|
||||
for model, instances in self.data.items():
|
||||
query = sql.DeleteQuery(model)
|
||||
pk_list = [obj.pk for obj in instances]
|
||||
count = query.delete_batch(pk_list, self.using)
|
||||
if count:
|
||||
deleted_counter[model._meta.label] += count
|
||||
|
||||
if not model._meta.auto_created:
|
||||
for obj in instances:
|
||||
signals.post_delete.send(
|
||||
sender=model,
|
||||
instance=obj,
|
||||
using=self.using,
|
||||
origin=self.origin,
|
||||
)
|
||||
|
||||
for model, instances in self.data.items():
|
||||
for instance in instances:
|
||||
setattr(instance, model._meta.pk.attname, None)
|
||||
return sum(deleted_counter.values()), dict(deleted_counter)
|
||||
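# Illustrative sketch (not part of this module): the Collector above is what
# backs QuerySet.delete() and Model.delete(). Its delete() returns the total
# number of rows removed plus a per-model breakdown. Model names below are
# hypothetical.
#
# >>> Entry.objects.filter(pub_date__year=2005).delete()
# (5, {'blog.Entry': 5})
# >>> author.delete()  # cascades to the related objects collected above
# (3, {'blog.Entry': 2, 'blog.Author': 1})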
@@ -0,0 +1,92 @@
|
||||
import enum
|
||||
from types import DynamicClassAttribute
|
||||
|
||||
from django.utils.functional import Promise
|
||||
|
||||
__all__ = ["Choices", "IntegerChoices", "TextChoices"]
|
||||
|
||||
|
||||
class ChoicesMeta(enum.EnumMeta):
|
||||
"""A metaclass for creating a enum choices."""
|
||||
|
||||
def __new__(metacls, classname, bases, classdict, **kwds):
|
||||
labels = []
|
||||
for key in classdict._member_names:
|
||||
value = classdict[key]
|
||||
if (
|
||||
isinstance(value, (list, tuple))
|
||||
and len(value) > 1
|
||||
and isinstance(value[-1], (Promise, str))
|
||||
):
|
||||
*value, label = value
|
||||
value = tuple(value)
|
||||
else:
|
||||
label = key.replace("_", " ").title()
|
||||
labels.append(label)
|
||||
# Use dict.__setitem__() to suppress defenses against double
|
||||
# assignment in enum's classdict.
|
||||
dict.__setitem__(classdict, key, value)
|
||||
cls = super().__new__(metacls, classname, bases, classdict, **kwds)
|
||||
for member, label in zip(cls.__members__.values(), labels):
|
||||
member._label_ = label
|
||||
return enum.unique(cls)
|
||||
|
||||
def __contains__(cls, member):
|
||||
if not isinstance(member, enum.Enum):
|
||||
# Allow non-enums to match against member values.
|
||||
return any(x.value == member for x in cls)
|
||||
return super().__contains__(member)
|
||||
|
||||
@property
|
||||
def names(cls):
|
||||
empty = ["__empty__"] if hasattr(cls, "__empty__") else []
|
||||
return empty + [member.name for member in cls]
|
||||
|
||||
@property
|
||||
def choices(cls):
|
||||
empty = [(None, cls.__empty__)] if hasattr(cls, "__empty__") else []
|
||||
return empty + [(member.value, member.label) for member in cls]
|
||||
|
||||
@property
|
||||
def labels(cls):
|
||||
return [label for _, label in cls.choices]
|
||||
|
||||
@property
|
||||
def values(cls):
|
||||
return [value for value, _ in cls.choices]
|
||||
|
||||
|
||||
class Choices(enum.Enum, metaclass=ChoicesMeta):
|
||||
"""Class for creating enumerated choices."""
|
||||
|
||||
@DynamicClassAttribute
|
||||
def label(self):
|
||||
return self._label_
|
||||
|
||||
@property
|
||||
def do_not_call_in_templates(self):
|
||||
return True
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
Use value when cast to str, so that Choices set as model instance
|
||||
attributes are rendered as expected in templates and similar contexts.
|
||||
"""
|
||||
return str(self.value)
|
||||
|
||||
# A similar format was proposed for Python 3.10.
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__qualname__}.{self._name_}"
|
||||
|
||||
|
||||
class IntegerChoices(int, Choices):
|
||||
"""Class for creating enumerated integer choices."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TextChoices(str, Choices):
|
||||
"""Class for creating enumerated string choices."""
|
||||
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name
|
||||
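# Illustrative sketch (not part of this module): a minimal use of the
# enumeration types defined above. The class and member names are
# hypothetical.
#
# >>> class Vehicle(TextChoices):
# ...     CAR = "C", "Carriage"
# ...     TRUCK = "T"            # label defaults to "Truck"
# >>> Vehicle.choices
# [('C', 'Carriage'), ('T', 'Truck')]
# >>> Vehicle.labels
# ['Carriage', 'Truck']
# >>> "C" in Vehicle             # non-enum values match by value
# True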
@@ -0,0 +1,510 @@
|
||||
import datetime
|
||||
import posixpath
|
||||
|
||||
from django import forms
|
||||
from django.core import checks
|
||||
from django.core.files.base import File
|
||||
from django.core.files.images import ImageFile
|
||||
from django.core.files.storage import Storage, default_storage
|
||||
from django.core.files.utils import validate_file_name
|
||||
from django.db.models import signals
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.query_utils import DeferredAttribute
|
||||
from django.db.models.utils import AltersData
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class FieldFile(File, AltersData):
|
||||
def __init__(self, instance, field, name):
|
||||
super().__init__(None, name)
|
||||
self.instance = instance
|
||||
self.field = field
|
||||
self.storage = field.storage
|
||||
self._committed = True
|
||||
|
||||
def __eq__(self, other):
|
||||
# Older code may be expecting FileField values to be simple strings.
|
||||
# By overriding the == operator, it can remain backwards compatible.
|
||||
if hasattr(other, "name"):
|
||||
return self.name == other.name
|
||||
return self.name == other
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
# The standard File contains most of the necessary properties, but
|
||||
# FieldFiles can be instantiated without a name, so that needs to
|
||||
# be checked for here.
|
||||
|
||||
def _require_file(self):
|
||||
if not self:
|
||||
raise ValueError(
|
||||
"The '%s' attribute has no file associated with it." % self.field.name
|
||||
)
|
||||
|
||||
def _get_file(self):
|
||||
self._require_file()
|
||||
if getattr(self, "_file", None) is None:
|
||||
self._file = self.storage.open(self.name, "rb")
|
||||
return self._file
|
||||
|
||||
def _set_file(self, file):
|
||||
self._file = file
|
||||
|
||||
def _del_file(self):
|
||||
del self._file
|
||||
|
||||
file = property(_get_file, _set_file, _del_file)
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
self._require_file()
|
||||
return self.storage.path(self.name)
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
self._require_file()
|
||||
return self.storage.url(self.name)
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
self._require_file()
|
||||
if not self._committed:
|
||||
return self.file.size
|
||||
return self.storage.size(self.name)
|
||||
|
||||
def open(self, mode="rb"):
|
||||
self._require_file()
|
||||
if getattr(self, "_file", None) is None:
|
||||
self.file = self.storage.open(self.name, mode)
|
||||
else:
|
||||
self.file.open(mode)
|
||||
return self
|
||||
|
||||
# open() doesn't alter the file's contents, but it does reset the pointer
|
||||
open.alters_data = True
|
||||
|
||||
# In addition to the standard File API, FieldFiles have extra methods
|
||||
# to further manipulate the underlying file, as well as update the
|
||||
# associated model instance.
|
||||
|
||||
def save(self, name, content, save=True):
|
||||
name = self.field.generate_filename(self.instance, name)
|
||||
self.name = self.storage.save(name, content, max_length=self.field.max_length)
|
||||
setattr(self.instance, self.field.attname, self.name)
|
||||
self._committed = True
|
||||
|
||||
# Save the object because it has changed, unless save is False
|
||||
if save:
|
||||
self.instance.save()
|
||||
|
||||
save.alters_data = True
|
||||
|
||||
def delete(self, save=True):
|
||||
if not self:
|
||||
return
|
||||
# Only close the file if it's already open, which we know by the
|
||||
# presence of self._file
|
||||
if hasattr(self, "_file"):
|
||||
self.close()
|
||||
del self.file
|
||||
|
||||
self.storage.delete(self.name)
|
||||
|
||||
self.name = None
|
||||
setattr(self.instance, self.field.attname, self.name)
|
||||
self._committed = False
|
||||
|
||||
if save:
|
||||
self.instance.save()
|
||||
|
||||
delete.alters_data = True
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
file = getattr(self, "_file", None)
|
||||
return file is None or file.closed
|
||||
|
||||
def close(self):
|
||||
file = getattr(self, "_file", None)
|
||||
if file is not None:
|
||||
file.close()
|
||||
|
||||
def __getstate__(self):
|
||||
# FieldFile needs access to its associated model field, an instance and
|
||||
# the file's name. Everything else will be restored later, by
|
||||
# FileDescriptor below.
|
||||
return {
|
||||
"name": self.name,
|
||||
"closed": False,
|
||||
"_committed": True,
|
||||
"_file": None,
|
||||
"instance": self.instance,
|
||||
"field": self.field,
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
self.storage = self.field.storage
|
||||
|
||||
|
||||
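# Illustrative sketch (not part of this module): how FieldFile.save() and
# FieldFile.delete() are typically used on a model instance with a FileField.
# The model and field names are hypothetical.
#
# >>> from django.core.files.base import ContentFile
# >>> doc = Document.objects.get(pk=1)
# >>> doc.attachment.save("notes.txt", ContentFile(b"hello"), save=True)
# >>> doc.attachment.delete(save=False)  # clears the field without saving the row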
class FileDescriptor(DeferredAttribute):
|
||||
"""
|
||||
The descriptor for the file attribute on the model instance. Return a
|
||||
FieldFile when accessed so you can write code like::
|
||||
|
||||
>>> from myapp.models import MyModel
|
||||
>>> instance = MyModel.objects.get(pk=1)
|
||||
>>> instance.file.size
|
||||
|
||||
Assign a file object on assignment so you can do::
|
||||
|
||||
>>> with open('/path/to/hello.world') as f:
|
||||
... instance.file = File(f)
|
||||
"""
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is None:
|
||||
return self
|
||||
|
||||
# This is slightly complicated, so worth an explanation.
|
||||
# instance.file needs to ultimately return some instance of `File`,
|
||||
# probably a subclass. Additionally, this returned object needs to have
|
||||
# the FieldFile API so that users can easily do things like
|
||||
# instance.file.path and have that delegated to the file storage engine.
|
||||
# Easy enough if we're strict about assignment in __set__, but if you
|
||||
# peek below you can see that we're not. So depending on the current
|
||||
# value of the field we have to dynamically construct some sort of
|
||||
# "thing" to return.
|
||||
|
||||
# The instance dict contains whatever was originally assigned
|
||||
# in __set__.
|
||||
file = super().__get__(instance, cls)
|
||||
|
||||
# If this value is a string (instance.file = "path/to/file") or None
|
||||
# then we simply wrap it with the appropriate attribute class according
|
||||
# to the file field. [This is FieldFile for FileFields and
|
||||
# ImageFieldFile for ImageFields; it's also conceivable that user
|
||||
# subclasses might also want to subclass the attribute class]. This
|
||||
# object understands how to convert a path to a file, and also how to
|
||||
# handle None.
|
||||
if isinstance(file, str) or file is None:
|
||||
attr = self.field.attr_class(instance, self.field, file)
|
||||
instance.__dict__[self.field.attname] = attr
|
||||
|
||||
# Other types of files may be assigned as well, but they need to have
|
||||
# the FieldFile interface added to them. Thus, we wrap any other type of
|
||||
# File inside a FieldFile (well, the field's attr_class, which is
|
||||
# usually FieldFile).
|
||||
elif isinstance(file, File) and not isinstance(file, FieldFile):
|
||||
file_copy = self.field.attr_class(instance, self.field, file.name)
|
||||
file_copy.file = file
|
||||
file_copy._committed = False
|
||||
instance.__dict__[self.field.attname] = file_copy
|
||||
|
||||
# Finally, because of the (some would say boneheaded) way pickle works,
|
||||
# the underlying FieldFile might not actually itself have an associated
|
||||
# file. So we need to reset the details of the FieldFile in those cases.
|
||||
elif isinstance(file, FieldFile) and not hasattr(file, "field"):
|
||||
file.instance = instance
|
||||
file.field = self.field
|
||||
file.storage = self.field.storage
|
||||
|
||||
# Make sure that the instance is correct.
|
||||
elif isinstance(file, FieldFile) and instance is not file.instance:
|
||||
file.instance = instance
|
||||
|
||||
# That was fun, wasn't it?
|
||||
return instance.__dict__[self.field.attname]
|
||||
|
||||
def __set__(self, instance, value):
|
||||
instance.__dict__[self.field.attname] = value
|
||||
|
||||
|
||||
class FileField(Field):
|
||||
# The class to wrap instance attributes in. Accessing the file object off
|
||||
# the instance will always return an instance of attr_class.
|
||||
attr_class = FieldFile
|
||||
|
||||
# The descriptor to use for accessing the attribute off of the class.
|
||||
descriptor_class = FileDescriptor
|
||||
|
||||
description = _("File")
|
||||
|
||||
def __init__(
|
||||
self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs
|
||||
):
|
||||
self._primary_key_set_explicitly = "primary_key" in kwargs
|
||||
|
||||
self.storage = storage or default_storage
|
||||
if callable(self.storage):
|
||||
# Hold a reference to the callable for deconstruct().
|
||||
self._storage_callable = self.storage
|
||||
self.storage = self.storage()
|
||||
if not isinstance(self.storage, Storage):
|
||||
raise TypeError(
|
||||
"%s.storage must be a subclass/instance of %s.%s"
|
||||
% (
|
||||
self.__class__.__qualname__,
|
||||
Storage.__module__,
|
||||
Storage.__qualname__,
|
||||
)
|
||||
)
|
||||
self.upload_to = upload_to
|
||||
|
||||
kwargs.setdefault("max_length", 100)
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return [
|
||||
*super().check(**kwargs),
|
||||
*self._check_primary_key(),
|
||||
*self._check_upload_to(),
|
||||
]
|
||||
|
||||
def _check_primary_key(self):
|
||||
if self._primary_key_set_explicitly:
|
||||
return [
|
||||
checks.Error(
|
||||
"'primary_key' is not a valid argument for a %s."
|
||||
% self.__class__.__name__,
|
||||
obj=self,
|
||||
id="fields.E201",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _check_upload_to(self):
|
||||
if isinstance(self.upload_to, str) and self.upload_to.startswith("/"):
|
||||
return [
|
||||
checks.Error(
|
||||
"%s's 'upload_to' argument must be a relative path, not an "
|
||||
"absolute path." % self.__class__.__name__,
|
||||
obj=self,
|
||||
id="fields.E202",
|
||||
hint="Remove the leading slash.",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if kwargs.get("max_length") == 100:
|
||||
del kwargs["max_length"]
|
||||
kwargs["upload_to"] = self.upload_to
|
||||
storage = getattr(self, "_storage_callable", self.storage)
|
||||
if storage is not default_storage:
|
||||
kwargs["storage"] = storage
|
||||
return name, path, args, kwargs
|
||||
|
||||
def get_internal_type(self):
|
||||
return "FileField"
|
||||
|
||||
def get_prep_value(self, value):
|
||||
value = super().get_prep_value(value)
|
||||
# Need to convert File objects provided via a form to string for
|
||||
# database insertion.
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
file = super().pre_save(model_instance, add)
|
||||
if file and not file._committed:
|
||||
# Commit the file to storage prior to saving the model
|
||||
file.save(file.name, file.file, save=False)
|
||||
return file
|
||||
|
||||
def contribute_to_class(self, cls, name, **kwargs):
|
||||
super().contribute_to_class(cls, name, **kwargs)
|
||||
setattr(cls, self.attname, self.descriptor_class(self))
|
||||
|
||||
def generate_filename(self, instance, filename):
|
||||
"""
|
||||
Apply (if callable) or prepend (if a string) upload_to to the filename,
|
||||
then delegate further processing of the name to the storage backend.
|
||||
Until the storage layer, all file paths are expected to be Unix style
|
||||
(with forward slashes).
|
||||
"""
|
||||
if callable(self.upload_to):
|
||||
filename = self.upload_to(instance, filename)
|
||||
else:
|
||||
dirname = datetime.datetime.now().strftime(str(self.upload_to))
|
||||
filename = posixpath.join(dirname, filename)
|
||||
filename = validate_file_name(filename, allow_relative_path=True)
|
||||
return self.storage.generate_filename(filename)
|
||||
|
||||
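# Illustrative sketch (not part of this module): the two upload_to forms that
# generate_filename() above handles. The callable name is hypothetical.
#
# >>> def user_directory_path(instance, filename):
# ...     return f"user_{instance.user.id}/{filename}"
# >>> FileField(upload_to=user_directory_path)     # callable form
# >>> FileField(upload_to="uploads/%Y/%m/%d/")     # strftime string form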
def save_form_data(self, instance, data):
|
||||
# Important: None means "no change", other false value means "clear"
|
||||
# This subtle distinction (rather than a more explicit marker) is
|
||||
# needed because we need to consume values that are also sane for a
|
||||
# regular (non Model-) Form to find in its cleaned_data dictionary.
|
||||
if data is not None:
|
||||
# This value will be converted to str and stored in the
|
||||
# database, so leaving False as-is is not acceptable.
|
||||
setattr(instance, self.name, data or "")
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.FileField,
|
||||
"max_length": self.max_length,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ImageFileDescriptor(FileDescriptor):
|
||||
"""
|
||||
Just like the FileDescriptor, but for ImageFields. The only difference is
|
||||
assigning the width/height to the width_field/height_field, if appropriate.
|
||||
"""
|
||||
|
||||
def __set__(self, instance, value):
|
||||
previous_file = instance.__dict__.get(self.field.attname)
|
||||
super().__set__(instance, value)
|
||||
|
||||
# To prevent recalculating image dimensions when we are instantiating
|
||||
# an object from the database (bug #11084), only update dimensions if
|
||||
# the field had a value before this assignment. Since the default
|
||||
# value for FileField subclasses is an instance of field.attr_class,
|
||||
# previous_file will only be None when we are called from
|
||||
# Model.__init__(). The ImageField.update_dimension_fields method
|
||||
# hooked up to the post_init signal handles the Model.__init__() cases.
|
||||
# Assignment happening outside of Model.__init__() will trigger the
|
||||
# update right here.
|
||||
if previous_file is not None:
|
||||
self.field.update_dimension_fields(instance, force=True)
|
||||
|
||||
|
||||
class ImageFieldFile(ImageFile, FieldFile):
|
||||
def delete(self, save=True):
|
||||
# Clear the image dimensions cache
|
||||
if hasattr(self, "_dimensions_cache"):
|
||||
del self._dimensions_cache
|
||||
super().delete(save)
|
||||
|
||||
|
||||
class ImageField(FileField):
|
||||
attr_class = ImageFieldFile
|
||||
descriptor_class = ImageFileDescriptor
|
||||
description = _("Image")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
width_field=None,
|
||||
height_field=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.width_field, self.height_field = width_field, height_field
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return [
|
||||
*super().check(**kwargs),
|
||||
*self._check_image_library_installed(),
|
||||
]
|
||||
|
||||
def _check_image_library_installed(self):
|
||||
try:
|
||||
from PIL import Image # NOQA
|
||||
except ImportError:
|
||||
return [
|
||||
checks.Error(
|
||||
"Cannot use ImageField because Pillow is not installed.",
|
||||
hint=(
|
||||
"Get Pillow at https://pypi.org/project/Pillow/ "
|
||||
'or run command "python -m pip install Pillow".'
|
||||
),
|
||||
obj=self,
|
||||
id="fields.E210",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.width_field:
|
||||
kwargs["width_field"] = self.width_field
|
||||
if self.height_field:
|
||||
kwargs["height_field"] = self.height_field
|
||||
return name, path, args, kwargs
|
||||
|
||||
def contribute_to_class(self, cls, name, **kwargs):
|
||||
super().contribute_to_class(cls, name, **kwargs)
|
||||
# Attach update_dimension_fields so that dimension fields declared
|
||||
# after their corresponding image field don't stay cleared by
|
||||
# Model.__init__, see bug #11196.
|
||||
# Only run post-initialization dimension update on non-abstract models
|
||||
if not cls._meta.abstract:
|
||||
signals.post_init.connect(self.update_dimension_fields, sender=cls)
|
||||
|
||||
def update_dimension_fields(self, instance, force=False, *args, **kwargs):
|
||||
"""
|
||||
Update field's width and height fields, if defined.
|
||||
|
||||
This method is hooked up to model's post_init signal to update
|
||||
dimensions after instantiating a model instance. However, dimensions
|
||||
won't be updated if the dimensions fields are already populated. This
|
||||
avoids unnecessary recalculation when loading an object from the
|
||||
database.
|
||||
|
||||
Dimensions can be forced to update with force=True, which is how
|
||||
ImageFileDescriptor.__set__ calls this method.
|
||||
"""
|
||||
# Nothing to update if the field doesn't have dimension fields or if
|
||||
# the field is deferred.
|
||||
has_dimension_fields = self.width_field or self.height_field
|
||||
if not has_dimension_fields or self.attname not in instance.__dict__:
|
||||
return
|
||||
|
||||
# getattr will call the ImageFileDescriptor's __get__ method, which
|
||||
# coerces the assigned value into an instance of self.attr_class
|
||||
# (ImageFieldFile in this case).
|
||||
file = getattr(instance, self.attname)
|
||||
|
||||
# Nothing to update if we have no file and not being forced to update.
|
||||
if not file and not force:
|
||||
return
|
||||
|
||||
dimension_fields_filled = not (
|
||||
(self.width_field and not getattr(instance, self.width_field))
|
||||
or (self.height_field and not getattr(instance, self.height_field))
|
||||
)
|
||||
# When both dimension fields have values, we are most likely loading
|
||||
# data from the database or updating an image field that already had
|
||||
# an image stored. In the first case, we don't want to update the
|
||||
# dimension fields because we are already getting their values from the
|
||||
# database. In the second case, we do want to update the dimensions
|
||||
# fields and will skip this return because force will be True since we
|
||||
# were called from ImageFileDescriptor.__set__.
|
||||
if dimension_fields_filled and not force:
|
||||
return
|
||||
|
||||
# file should be an instance of ImageFieldFile or should be None.
|
||||
if file:
|
||||
width = file.width
|
||||
height = file.height
|
||||
else:
|
||||
# No file, so clear dimensions fields.
|
||||
width = None
|
||||
height = None
|
||||
|
||||
# Update the width and height fields.
|
||||
if self.width_field:
|
||||
setattr(instance, self.width_field, width)
|
||||
if self.height_field:
|
||||
setattr(instance, self.height_field, height)
|
||||
|
||||
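# Illustrative sketch (not part of this module): the dimension fields that
# update_dimension_fields() above keeps in sync. The model name is
# hypothetical.
#
# >>> class Photo(models.Model):
# ...     image = models.ImageField(
# ...         upload_to="photos/", width_field="width", height_field="height"
# ...     )
# ...     width = models.IntegerField(null=True, blank=True)
# ...     height = models.IntegerField(null=True, blank=True)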
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.ImageField,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,638 @@
|
||||
import json
|
||||
import warnings
|
||||
|
||||
from django import forms
|
||||
from django.core import checks, exceptions
|
||||
from django.db import NotSupportedError, connections, router
|
||||
from django.db.models import expressions, lookups
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.fields import TextField
|
||||
from django.db.models.lookups import (
|
||||
FieldGetDbPrepValueMixin,
|
||||
PostgresOperatorLookup,
|
||||
Transform,
|
||||
)
|
||||
from django.utils.deprecation import RemovedInDjango51Warning
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from . import Field
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
|
||||
__all__ = ["JSONField"]
|
||||
|
||||
|
||||
class JSONField(CheckFieldDefaultMixin, Field):
|
||||
empty_strings_allowed = False
|
||||
description = _("A JSON object")
|
||||
default_error_messages = {
|
||||
"invalid": _("Value must be valid JSON."),
|
||||
}
|
||||
_default_hint = ("dict", "{}")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
encoder=None,
|
||||
decoder=None,
|
||||
**kwargs,
|
||||
):
|
||||
if encoder and not callable(encoder):
|
||||
raise ValueError("The encoder parameter must be a callable object.")
|
||||
if decoder and not callable(decoder):
|
||||
raise ValueError("The decoder parameter must be a callable object.")
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
databases = kwargs.get("databases") or []
|
||||
errors.extend(self._check_supported(databases))
|
||||
return errors
|
||||
|
||||
def _check_supported(self, databases):
|
||||
errors = []
|
||||
for db in databases:
|
||||
if not router.allow_migrate_model(db, self.model):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if (
|
||||
self.model._meta.required_db_vendor
|
||||
and self.model._meta.required_db_vendor != connection.vendor
|
||||
):
|
||||
continue
|
||||
if not (
|
||||
"supports_json_field" in self.model._meta.required_db_features
|
||||
or connection.features.supports_json_field
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"%s does not support JSONFields." % connection.display_name,
|
||||
obj=self.model,
|
||||
id="fields.E180",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.encoder is not None:
|
||||
kwargs["encoder"] = self.encoder
|
||||
if self.decoder is not None:
|
||||
kwargs["decoder"] = self.decoder
|
||||
return name, path, args, kwargs
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
if value is None:
|
||||
return value
|
||||
# Some backends (SQLite at least) extract non-string values in their
|
||||
# SQL datatypes.
|
||||
if isinstance(expression, KeyTransform) and not isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.loads(value, cls=self.decoder)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
def get_internal_type(self):
|
||||
return "JSONField"
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if not prepared:
|
||||
value = self.get_prep_value(value)
|
||||
# RemovedInDjango51Warning: When the deprecation ends, replace with:
|
||||
# if (
|
||||
# isinstance(value, expressions.Value)
|
||||
# and isinstance(value.output_field, JSONField)
|
||||
# ):
|
||||
# value = value.value
|
||||
# elif hasattr(value, "as_sql"): ...
|
||||
if isinstance(value, expressions.Value):
|
||||
if isinstance(value.value, str) and not isinstance(
|
||||
value.output_field, JSONField
|
||||
):
|
||||
try:
|
||||
value = json.loads(value.value, cls=self.decoder)
|
||||
except json.JSONDecodeError:
|
||||
value = value.value
|
||||
else:
|
||||
warnings.warn(
|
||||
"Providing an encoded JSON string via Value() is deprecated. "
|
||||
f"Use Value({value!r}, output_field=JSONField()) instead.",
|
||||
category=RemovedInDjango51Warning,
|
||||
)
|
||||
elif isinstance(value.output_field, JSONField):
|
||||
value = value.value
|
||||
else:
|
||||
return value
|
||||
elif hasattr(value, "as_sql"):
|
||||
return value
|
||||
return connection.ops.adapt_json_value(value, self.encoder)
|
||||
|
||||
def get_db_prep_save(self, value, connection):
|
||||
if value is None:
|
||||
return value
|
||||
return self.get_db_prep_value(value, connection)
|
||||
|
||||
def get_transform(self, name):
|
||||
transform = super().get_transform(name)
|
||||
if transform:
|
||||
return transform
|
||||
return KeyTransformFactory(name)
|
||||
|
||||
def validate(self, value, model_instance):
|
||||
super().validate(value, model_instance)
|
||||
try:
|
||||
json.dumps(value, cls=self.encoder)
|
||||
except TypeError:
|
||||
raise exceptions.ValidationError(
|
||||
self.error_messages["invalid"],
|
||||
code="invalid",
|
||||
params={"value": value},
|
||||
)
|
||||
|
||||
def value_to_string(self, obj):
|
||||
return self.value_from_object(obj)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.JSONField,
|
||||
"encoder": self.encoder,
|
||||
"decoder": self.decoder,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def compile_json_path(key_transforms, include_root=True):
|
||||
path = ["$"] if include_root else []
|
||||
for key_transform in key_transforms:
|
||||
try:
|
||||
num = int(key_transform)
|
||||
except ValueError: # non-integer
|
||||
path.append(".")
|
||||
path.append(json.dumps(key_transform))
|
||||
else:
|
||||
path.append("[%s]" % num)
|
||||
return "".join(path)
|
||||
|
||||
|
||||
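# Illustrative sketch (not part of this module): what compile_json_path()
# produces for a chain of key transforms. Integer-looking keys become array
# indexes; everything else becomes a quoted object key.
#
# >>> compile_json_path(["owner", "pets", "0", "name"])
# '$."owner"."pets"[0]."name"'
# >>> compile_json_path(["0"], include_root=False)
# '[0]'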
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
|
||||
lookup_name = "contains"
|
||||
postgres_operator = "@>"
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contains lookup is not supported on this database backend."
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(lhs_params) + tuple(rhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
|
||||
|
||||
|
||||
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
|
||||
lookup_name = "contained_by"
|
||||
postgres_operator = "<@"
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contained_by lookup is not supported on this database backend."
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(rhs_params) + tuple(lhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
|
||||
|
||||
|
||||
class HasKeyLookup(PostgresOperatorLookup):
|
||||
logical_operator = None
|
||||
|
||||
def compile_json_path_final_key(self, key_transform):
|
||||
# Compile the final key without interpreting ints as array elements.
|
||||
return ".%s" % json.dumps(key_transform)
|
||||
|
||||
def as_sql(self, compiler, connection, template=None):
|
||||
# Process JSON path from the left-hand side.
|
||||
if isinstance(self.lhs, KeyTransform):
|
||||
lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
|
||||
compiler, connection
|
||||
)
|
||||
lhs_json_path = compile_json_path(lhs_key_transforms)
|
||||
else:
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
lhs_json_path = "$"
|
||||
sql = template % lhs
|
||||
# Process JSON path from the right-hand side.
|
||||
rhs = self.rhs
|
||||
rhs_params = []
|
||||
if not isinstance(rhs, (list, tuple)):
|
||||
rhs = [rhs]
|
||||
for key in rhs:
|
||||
if isinstance(key, KeyTransform):
|
||||
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
|
||||
else:
|
||||
rhs_key_transforms = [key]
|
||||
*rhs_key_transforms, final_key = rhs_key_transforms
|
||||
rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
|
||||
rhs_json_path += self.compile_json_path_final_key(final_key)
|
||||
rhs_params.append(lhs_json_path + rhs_json_path)
|
||||
# Add condition for each key.
|
||||
if self.logical_operator:
|
||||
sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
|
||||
return sql, tuple(lhs_params) + tuple(rhs_params)
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
sql, params = self.as_sql(
|
||||
compiler, connection, template="JSON_EXISTS(%s, '%%s')"
|
||||
)
|
||||
# Add paths directly into SQL because path expressions cannot be passed
|
||||
# as bind variables on Oracle.
|
||||
return sql % tuple(params), []
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
*_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
|
||||
for key in rhs_key_transforms[:-1]:
|
||||
self.lhs = KeyTransform(key, self.lhs)
|
||||
self.rhs = rhs_key_transforms[-1]
|
||||
return super().as_postgresql(compiler, connection)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
|
||||
)
|
||||
|
||||
|
||||
class HasKey(HasKeyLookup):
|
||||
lookup_name = "has_key"
|
||||
postgres_operator = "?"
|
||||
prepare_rhs = False
|
||||
|
||||
|
||||
class HasKeys(HasKeyLookup):
|
||||
lookup_name = "has_keys"
|
||||
postgres_operator = "?&"
|
||||
logical_operator = " AND "
|
||||
|
||||
def get_prep_lookup(self):
|
||||
return [str(item) for item in self.rhs]
|
||||
|
||||
|
||||
class HasAnyKeys(HasKeys):
|
||||
lookup_name = "has_any_keys"
|
||||
postgres_operator = "?|"
|
||||
logical_operator = " OR "
|
||||
|
||||
|
||||
class HasKeyOrArrayIndex(HasKey):
|
||||
def compile_json_path_final_key(self, key_transform):
|
||||
return compile_json_path([key_transform], include_root=False)
|
||||
|
||||
|
||||
class CaseInsensitiveMixin:
|
||||
"""
|
||||
Mixin to allow case-insensitive comparison of JSON values on MySQL.
|
||||
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
|
||||
Because utf8mb4_bin is a binary collation, comparison of JSON values is
|
||||
case-sensitive.
|
||||
"""
|
||||
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % lhs, lhs_params
|
||||
return lhs, lhs_params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % rhs, rhs_params
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class JSONExact(lookups.Exact):
|
||||
can_use_none_as_rhs = True
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
# Treat None lookup values as null.
|
||||
if rhs == "%s" and rhs_params == [None]:
|
||||
rhs_params = ["null"]
|
||||
if connection.vendor == "mysql":
|
||||
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
|
||||
rhs %= tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
|
||||
pass
|
||||
|
||||
|
||||
JSONField.register_lookup(DataContains)
|
||||
JSONField.register_lookup(ContainedBy)
|
||||
JSONField.register_lookup(HasKey)
|
||||
JSONField.register_lookup(HasKeys)
|
||||
JSONField.register_lookup(HasAnyKeys)
|
||||
JSONField.register_lookup(JSONExact)
|
||||
JSONField.register_lookup(JSONIContains)
|
||||
|
||||
|
||||
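# Illustrative sketch (not part of this module): the containment and key
# lookups registered above, as they would appear in a query. The model and
# field names are hypothetical.
#
# >>> Dog.objects.filter(data__contains={"breed": "collie"})
# >>> Dog.objects.filter(data__has_key="owner")
# >>> Dog.objects.filter(data__has_any_keys=["owner", "breeder"])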
class KeyTransform(Transform):
|
||||
postgres_operator = "->"
|
||||
postgres_nested_operator = "#>"
|
||||
|
||||
def __init__(self, key_name, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.key_name = str(key_name)
|
||||
|
||||
def preprocess_lhs(self, compiler, connection):
|
||||
key_transforms = [self.key_name]
|
||||
previous = self.lhs
|
||||
while isinstance(previous, KeyTransform):
|
||||
key_transforms.insert(0, previous.key_name)
|
||||
previous = previous.lhs
|
||||
lhs, params = compiler.compile(previous)
|
||||
if connection.vendor == "oracle":
|
||||
# Escape string-formatting.
|
||||
key_transforms = [key.replace("%", "%%") for key in key_transforms]
|
||||
return lhs, params, key_transforms
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return (
|
||||
"COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
|
||||
% ((lhs, json_path) * 2)
|
||||
), tuple(params) * 2
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
if len(key_transforms) > 1:
|
||||
sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
|
||||
return sql, tuple(params) + (key_transforms,)
|
||||
try:
|
||||
lookup = int(self.key_name)
|
||||
except ValueError:
|
||||
lookup = self.key_name
|
||||
return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
datatype_values = ",".join(
|
||||
[repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
|
||||
)
|
||||
return (
|
||||
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
|
||||
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
|
||||
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
|
||||
|
||||
|
||||
class KeyTextTransform(KeyTransform):
|
||||
postgres_operator = "->>"
|
||||
postgres_nested_operator = "#>>"
|
||||
output_field = TextField()
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
if connection.mysql_is_mariadb:
|
||||
# MariaDB doesn't support -> and ->> operators (see MDEV-13594).
|
||||
sql, params = super().as_mysql(compiler, connection)
|
||||
return "JSON_UNQUOTE(%s)" % sql, params
|
||||
else:
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)
|
||||
|
||||
@classmethod
|
||||
def from_lookup(cls, lookup):
|
||||
transform, *keys = lookup.split(LOOKUP_SEP)
|
||||
if not keys:
|
||||
raise ValueError("Lookup must contain key or index transforms.")
|
||||
for key in keys:
|
||||
transform = cls(key, transform)
|
||||
return transform
|
||||
|
||||
|
||||
KT = KeyTextTransform.from_lookup
|
||||
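# Illustrative sketch (not part of this module): KT() builds a chain of
# KeyTextTransform instances from a lookup string, which is useful in
# annotations and ordering. The model and field names are hypothetical.
#
# >>> Author.objects.annotate(name_str=KT("data__name")).order_by("name_str")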
|
||||
|
||||
class KeyTransformTextLookupMixin:
|
||||
"""
|
||||
Mixin for combining with a lookup expecting a text lhs from a JSONField
|
||||
key lookup. On PostgreSQL, make use of the ->> operator instead of casting
|
||||
key values to text and performing the lookup on the resulting
|
||||
representation.
|
||||
"""
|
||||
|
||||
def __init__(self, key_transform, *args, **kwargs):
|
||||
if not isinstance(key_transform, KeyTransform):
|
||||
raise TypeError(
|
||||
"Transform should be an instance of KeyTransform in order to "
|
||||
"use this lookup."
|
||||
)
|
||||
key_text_transform = KeyTextTransform(
|
||||
key_transform.key_name,
|
||||
*key_transform.source_expressions,
|
||||
**key_transform.extra,
|
||||
)
|
||||
super().__init__(key_text_transform, *args, **kwargs)
|
||||
|
||||
|
||||
class KeyTransformIsNull(lookups.IsNull):
|
||||
# key__isnull=False is the same as has_key='key'
|
||||
def as_oracle(self, compiler, connection):
|
||||
sql, params = HasKeyOrArrayIndex(
|
||||
self.lhs.lhs,
|
||||
self.lhs.key_name,
|
||||
).as_oracle(compiler, connection)
|
||||
if not self.rhs:
|
||||
return sql, params
|
||||
# Column doesn't have a key or IS NULL.
|
||||
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
|
||||
return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
template = "JSON_TYPE(%s, %%s) IS NULL"
|
||||
if not self.rhs:
|
||||
template = "JSON_TYPE(%s, %%s) IS NOT NULL"
|
||||
return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
)
|
||||
|
||||
|
||||
class KeyTransformIn(lookups.In):
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
sql, params = super().resolve_expression_parameter(
|
||||
compiler,
|
||||
connection,
|
||||
sql,
|
||||
param,
|
||||
)
|
||||
if (
|
||||
not hasattr(param, "as_sql")
|
||||
and not connection.features.has_native_json_field
|
||||
):
|
||||
if connection.vendor == "oracle":
|
||||
value = json.loads(param)
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
if isinstance(value, (list, dict)):
|
||||
sql %= "JSON_QUERY"
|
||||
else:
|
||||
sql %= "JSON_VALUE"
|
||||
elif connection.vendor == "mysql" or (
|
||||
connection.vendor == "sqlite"
|
||||
and params[0] not in connection.ops.jsonfield_datatype_values
|
||||
):
|
||||
sql = "JSON_EXTRACT(%s, '$')"
|
||||
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
|
||||
sql = "JSON_UNQUOTE(%s)" % sql
|
||||
return sql, params
|
||||
|
||||
|
||||
class KeyTransformExact(JSONExact):
|
||||
def process_rhs(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
return super(lookups.Exact, self).process_rhs(compiler, connection)
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == "oracle":
|
||||
func = []
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
for value in rhs_params:
|
||||
value = json.loads(value)
|
||||
if isinstance(value, (list, dict)):
|
||||
func.append(sql % "JSON_QUERY")
|
||||
else:
|
||||
func.append(sql % "JSON_VALUE")
|
||||
rhs %= tuple(func)
|
||||
elif connection.vendor == "sqlite":
|
||||
func = []
|
||||
for value in rhs_params:
|
||||
if value in connection.ops.jsonfield_datatype_values:
|
||||
func.append("%s")
|
||||
else:
|
||||
func.append("JSON_EXTRACT(%s, '$')")
|
||||
rhs %= tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if rhs_params == ["null"]:
|
||||
# Field has key and it's NULL.
|
||||
has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
|
||||
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
|
||||
is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
|
||||
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
|
||||
return (
|
||||
"%s AND %s" % (has_key_sql, is_null_sql),
|
||||
tuple(has_key_params) + tuple(is_null_params),
|
||||
)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class KeyTransformIExact(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIContains(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIStartsWith(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIEndsWith(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIRegex(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformNumericLookupMixin:
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if not connection.features.has_native_json_field:
|
||||
rhs_params = [json.loads(value) for value in rhs_params]
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
KeyTransform.register_lookup(KeyTransformIn)
|
||||
KeyTransform.register_lookup(KeyTransformExact)
|
||||
KeyTransform.register_lookup(KeyTransformIExact)
|
||||
KeyTransform.register_lookup(KeyTransformIsNull)
|
||||
KeyTransform.register_lookup(KeyTransformIContains)
|
||||
KeyTransform.register_lookup(KeyTransformStartsWith)
|
||||
KeyTransform.register_lookup(KeyTransformIStartsWith)
|
||||
KeyTransform.register_lookup(KeyTransformEndsWith)
|
||||
KeyTransform.register_lookup(KeyTransformIEndsWith)
|
||||
KeyTransform.register_lookup(KeyTransformRegex)
|
||||
KeyTransform.register_lookup(KeyTransformIRegex)
|
||||
|
||||
KeyTransform.register_lookup(KeyTransformLt)
|
||||
KeyTransform.register_lookup(KeyTransformLte)
|
||||
KeyTransform.register_lookup(KeyTransformGt)
|
||||
KeyTransform.register_lookup(KeyTransformGte)
|
||||
|
||||
|
||||
class KeyTransformFactory:
|
||||
def __init__(self, key_name):
|
||||
self.key_name = key_name
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return KeyTransform(self.key_name, *args, **kwargs)
|
||||
@@ -0,0 +1,59 @@
|
||||
from django.core import checks
|
||||
|
||||
NOT_PROVIDED = object()
|
||||
|
||||
|
||||
class FieldCacheMixin:
|
||||
"""Provide an API for working with the model's fields value cache."""
|
||||
|
||||
def get_cache_name(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cached_value(self, instance, default=NOT_PROVIDED):
|
||||
cache_name = self.get_cache_name()
|
||||
try:
|
||||
return instance._state.fields_cache[cache_name]
|
||||
except KeyError:
|
||||
if default is NOT_PROVIDED:
|
||||
raise
|
||||
return default
|
||||
|
||||
def is_cached(self, instance):
|
||||
return self.get_cache_name() in instance._state.fields_cache
|
||||
|
||||
def set_cached_value(self, instance, value):
|
||||
instance._state.fields_cache[self.get_cache_name()] = value
|
||||
|
||||
def delete_cached_value(self, instance):
|
||||
del instance._state.fields_cache[self.get_cache_name()]
|
||||
|
||||
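# Illustrative sketch (not part of this module): the per-instance cache that
# FieldCacheMixin manages lives on instance._state.fields_cache and is what
# lets a ForeignKey avoid re-querying a related row. The model and field
# names are hypothetical.
#
# >>> entry = Entry.objects.select_related("author").get(pk=1)
# >>> entry._meta.get_field("author").is_cached(entry)
# True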
|
||||
class CheckFieldDefaultMixin:
|
||||
_default_hint = ("<valid default>", "<invalid default>")
|
||||
|
||||
def _check_default(self):
|
||||
if (
|
||||
self.has_default()
|
||||
and self.default is not None
|
||||
and not callable(self.default)
|
||||
):
|
||||
return [
|
||||
checks.Warning(
|
||||
"%s default should be a callable instead of an instance "
|
||||
"so that it's not shared between all field instances."
|
||||
% (self.__class__.__name__,),
|
||||
hint=(
|
||||
"Use a callable instead, e.g., use `%s` instead of "
|
||||
"`%s`." % self._default_hint
|
||||
),
|
||||
obj=self,
|
||||
id="fields.E010",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
errors.extend(self._check_default())
|
||||
return errors
|
||||
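# Illustrative sketch (not part of this module): the mutable-default warning
# (fields.E010) that _check_default() above reports, and the callable form it
# suggests instead.
#
# >>> models.JSONField(default={})    # warned: one dict shared by all instances
# >>> models.JSONField(default=dict)  # preferred: a fresh dict per instance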
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Field-like classes that aren't really fields. It's easier to use objects that
|
||||
have the same attributes as fields sometimes (avoids a lot of special casing).
|
||||
"""
|
||||
|
||||
from django.db.models import fields
|
||||
|
||||
|
||||
class OrderWrt(fields.IntegerField):
|
||||
"""
|
||||
A proxy for the _order database field that is used when
|
||||
Meta.order_with_respect_to is specified.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["name"] = "_order"
|
||||
kwargs["editable"] = False
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -0,0 +1,209 @@
|
||||
import warnings
|
||||
|
||||
from django.db.models.lookups import (
|
||||
Exact,
|
||||
GreaterThan,
|
||||
GreaterThanOrEqual,
|
||||
In,
|
||||
IsNull,
|
||||
LessThan,
|
||||
LessThanOrEqual,
|
||||
)
|
||||
from django.utils.deprecation import RemovedInDjango50Warning
|
||||
|
||||
|
||||
class MultiColSource:
|
||||
contains_aggregate = False
|
||||
contains_over_clause = False
|
||||
|
||||
def __init__(self, alias, targets, sources, field):
|
||||
self.targets, self.sources, self.field, self.alias = (
|
||||
targets,
|
||||
sources,
|
||||
field,
|
||||
alias,
|
||||
)
|
||||
self.output_field = self.field
|
||||
|
||||
def __repr__(self):
|
||||
return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field)
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self.__class__(
|
||||
relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
|
||||
)
|
||||
|
||||
def get_lookup(self, lookup):
|
||||
return self.output_field.get_lookup(lookup)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
|
||||
def get_normalized_value(value, lhs):
|
||||
from django.db.models import Model
|
||||
|
||||
if isinstance(value, Model):
|
||||
if value.pk is None:
|
||||
# When the deprecation ends, replace with:
|
||||
# raise ValueError(
|
||||
# "Model instances passed to related filters must be saved."
|
||||
# )
|
||||
warnings.warn(
|
||||
"Passing unsaved model instances to related filters is deprecated.",
|
||||
RemovedInDjango50Warning,
|
||||
)
|
||||
value_list = []
|
||||
sources = lhs.output_field.path_infos[-1].target_fields
|
||||
for source in sources:
|
||||
while not isinstance(value, source.model) and source.remote_field:
|
||||
source = source.remote_field.model._meta.get_field(
|
||||
source.remote_field.field_name
|
||||
)
|
||||
try:
|
||||
value_list.append(getattr(value, source.attname))
|
||||
except AttributeError:
|
||||
# A case like Restaurant.objects.filter(place=restaurant_instance),
|
||||
# where place is a OneToOneField and the primary key of Restaurant.
|
||||
return (value.pk,)
|
||||
return tuple(value_list)
|
||||
if not isinstance(value, tuple):
|
||||
return (value,)
|
||||
return value
|
||||
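# Illustrative sketch (not part of this module): get_normalized_value() is
# what lets related filters accept either a model instance or a raw key
# value. The model names are hypothetical; both queries below normalize to
# the same lookup value.
#
# >>> Entry.objects.filter(author=author_instance)
# >>> Entry.objects.filter(author=author_instance.pk)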
|
||||
|
||||
class RelatedIn(In):
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.lhs, MultiColSource):
|
||||
if self.rhs_is_direct_value():
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
|
||||
# We need to run the related field's get_prep_value(). Consider
|
||||
# case ForeignKey to IntegerField given value 'abc'. The
|
||||
# ForeignKey itself doesn't have validation for non-integers,
|
||||
# so we must run validation using the target field.
|
||||
if hasattr(self.lhs.output_field, "path_infos"):
|
||||
# Run the target field's get_prep_value. We can safely
|
||||
# assume there is only one as we don't get to the direct
|
||||
# value branch otherwise.
|
||||
target_field = self.lhs.output_field.path_infos[-1].target_fields[
|
||||
-1
|
||||
]
|
||||
self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
|
||||
elif not getattr(self.rhs, "has_select_fields", True) and not getattr(
|
||||
self.lhs.field.target_field, "primary_key", False
|
||||
):
|
||||
if (
|
||||
getattr(self.lhs.output_field, "primary_key", False)
|
||||
and self.lhs.output_field.model == self.rhs.model
|
||||
):
|
||||
# A case like
|
||||
# Restaurant.objects.filter(place__in=restaurant_qs), where
|
||||
# place is a OneToOneField and the primary key of
|
||||
# Restaurant.
|
||||
target_field = self.lhs.field.name
|
||||
else:
|
||||
target_field = self.lhs.field.target_field.name
|
||||
self.rhs.set_values([target_field])
|
||||
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        if isinstance(self.lhs, MultiColSource):
            # For multicolumn lookups we need to build a multicolumn where clause.
            # This clause is either a SubqueryConstraint (for values that need
            # to be compiled to SQL) or an OR-combined list of
            # (col1 = val1 AND col2 = val2 AND ...) clauses.
            from django.db.models.sql.where import (
                AND,
                OR,
                SubqueryConstraint,
                WhereNode,
            )

            root_constraint = WhereNode(connector=OR)
            if self.rhs_is_direct_value():
                values = [get_normalized_value(value, self.lhs) for value in self.rhs]
                for value in values:
                    value_constraint = WhereNode()
                    for source, target, val in zip(
                        self.lhs.sources, self.lhs.targets, value
                    ):
                        lookup_class = target.get_lookup("exact")
                        lookup = lookup_class(
                            target.get_col(self.lhs.alias, source), val
                        )
                        value_constraint.add(lookup, AND)
                    root_constraint.add(value_constraint, OR)
            else:
                root_constraint.add(
                    SubqueryConstraint(
                        self.lhs.alias,
                        [target.column for target in self.lhs.targets],
                        [source.name for source in self.lhs.sources],
                        self.rhs,
                    ),
                    AND,
                )
            return root_constraint.as_sql(compiler, connection)
        return super().as_sql(compiler, connection)


class RelatedLookupMixin:
    def get_prep_lookup(self):
        if not isinstance(self.lhs, MultiColSource) and not hasattr(
            self.rhs, "resolve_expression"
        ):
            # If we get here, we are dealing with single-column relations.
            self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
            # We need to run the related field's get_prep_value(). Consider case
            # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
            # doesn't have validation for non-integers, so we must run validation
            # using the target field.
            if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
                # Get the target field. We can safely assume there is only one
                # as we don't get to the direct value branch otherwise.
                target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
                self.rhs = target_field.get_prep_value(self.rhs)

        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        if isinstance(self.lhs, MultiColSource):
            assert self.rhs_is_direct_value()
            self.rhs = get_normalized_value(self.rhs, self.lhs)
            from django.db.models.sql.where import AND, WhereNode

            root_constraint = WhereNode()
            for target, source, val in zip(
                self.lhs.targets, self.lhs.sources, self.rhs
            ):
                lookup_class = target.get_lookup(self.lookup_name)
                root_constraint.add(
                    lookup_class(target.get_col(self.lhs.alias, source), val), AND
                )
            return root_constraint.as_sql(compiler, connection)
        return super().as_sql(compiler, connection)


class RelatedExact(RelatedLookupMixin, Exact):
    pass


class RelatedLessThan(RelatedLookupMixin, LessThan):
    pass


class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    pass


class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    pass


class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    pass


class RelatedIsNull(RelatedLookupMixin, IsNull):
    pass
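# For orientation (an illustrative note, not part of the module above): given a
# hypothetical ForeignObject spanning two columns (a, b) and a direct-value
# lookup rhs = [obj1, obj2], RelatedIn.as_sql() builds a clause shaped like
#
#     (a = a1 AND b = b1) OR (a = a2 AND b = b2)
#
# whereas a queryset rhs is wrapped in a single SubqueryConstraint instead.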
@@ -0,0 +1,402 @@
|
||||
"""
|
||||
"Rel objects" for related fields.
|
||||
|
||||
"Rel objects" (for lack of a better name) carry information about the relation
|
||||
modeled by a related field and provide some utility functions. They're stored
|
||||
in the ``remote_field`` attribute of the field.
|
||||
|
||||
They also act as reverse fields for the purposes of the Meta API because
|
||||
they're the closest concept currently available.
|
||||
"""
|
||||
|
||||
from django.core import exceptions
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
|
||||
from . import BLANK_CHOICE_DASH
|
||||
from .mixins import FieldCacheMixin
|
||||
|
||||
|
||||
class ForeignObjectRel(FieldCacheMixin):
|
||||
"""
|
||||
Used by ForeignObject to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
# Field flags
|
||||
auto_created = True
|
||||
concrete = False
|
||||
editable = False
|
||||
is_relation = True
|
||||
|
||||
# Reverse relations are always nullable (Django can't enforce that a
|
||||
# foreign key on the related model points to this model).
|
||||
null = True
|
||||
empty_strings_allowed = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
self.field = field
|
||||
self.model = to
|
||||
self.related_name = related_name
|
||||
self.related_query_name = related_query_name
|
||||
self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
|
||||
self.parent_link = parent_link
|
||||
self.on_delete = on_delete
|
||||
|
||||
self.symmetrical = False
|
||||
self.multiple = True
|
||||
|
||||
# Some of the following cached_properties can't be initialized in
|
||||
# __init__ as the field doesn't have its model yet. Calling these methods
|
||||
# before field.contribute_to_class() has been called will result in
|
||||
# AttributeError
|
||||
@cached_property
|
||||
def hidden(self):
|
||||
return self.is_hidden()
|
||||
|
||||
@cached_property
|
||||
def name(self):
|
||||
return self.field.related_query_name()
|
||||
|
||||
@property
|
||||
def remote_field(self):
|
||||
return self.field
|
||||
|
||||
@property
|
||||
def target_field(self):
|
||||
"""
|
||||
When filtering against this relation, return the field on the remote
|
||||
model against which the filtering should happen.
|
||||
"""
|
||||
target_fields = self.path_infos[-1].target_fields
|
||||
if len(target_fields) > 1:
|
||||
raise exceptions.FieldError(
|
||||
"Can't use target_field for multicolumn relations."
|
||||
)
|
||||
return target_fields[0]
|
||||
|
||||
@cached_property
|
||||
def related_model(self):
|
||||
if not self.field.model:
|
||||
raise AttributeError(
|
||||
"This property can't be accessed before self.field.contribute_to_class "
|
||||
"has been called."
|
||||
)
|
||||
return self.field.model
|
||||
|
||||
@cached_property
|
||||
def many_to_many(self):
|
||||
return self.field.many_to_many
|
||||
|
||||
@cached_property
|
||||
def many_to_one(self):
|
||||
return self.field.one_to_many
|
||||
|
||||
@cached_property
|
||||
def one_to_many(self):
|
||||
return self.field.many_to_one
|
||||
|
||||
@cached_property
|
||||
def one_to_one(self):
|
||||
return self.field.one_to_one
|
||||
|
||||
def get_lookup(self, lookup_name):
|
||||
return self.field.get_lookup(lookup_name)
|
||||
|
||||
def get_lookups(self):
|
||||
return self.field.get_lookups()
|
||||
|
||||
def get_transform(self, name):
|
||||
return self.field.get_transform(name)
|
||||
|
||||
def get_internal_type(self):
|
||||
return self.field.get_internal_type()
|
||||
|
||||
@property
|
||||
def db_type(self):
|
||||
return self.field.db_type
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: %s.%s>" % (
|
||||
type(self).__name__,
|
||||
self.related_model._meta.app_label,
|
||||
self.related_model._meta.model_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return (
|
||||
self.field,
|
||||
self.model,
|
||||
self.related_name,
|
||||
self.related_query_name,
|
||||
make_hashable(self.limit_choices_to),
|
||||
self.parent_link,
|
||||
self.on_delete,
|
||||
self.symmetrical,
|
||||
self.multiple,
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
# Delete the path_infos cached property because it can be recalculated
|
||||
# at first invocation after deserialization. The attribute must be
|
||||
# removed because subclasses like ManyToOneRel may have a PathInfo
|
||||
# which contains an intermediate M2M table that's been dynamically
|
||||
# created and doesn't exist in the .models module.
|
||||
# This is a reverse relation, so there is no reverse_path_infos to
|
||||
# delete.
|
||||
state.pop("path_infos", None)
|
||||
return state
|
||||
|
||||
def get_choices(
|
||||
self,
|
||||
include_blank=True,
|
||||
blank_choice=BLANK_CHOICE_DASH,
|
||||
limit_choices_to=None,
|
||||
ordering=(),
|
||||
):
|
||||
"""
|
||||
Return choices with a default blank choices included, for use
|
||||
as <select> choices for this field.
|
||||
|
||||
Analog of django.db.models.fields.Field.get_choices(), provided
|
||||
initially for utilization by RelatedFieldListFilter.
|
||||
"""
|
||||
limit_choices_to = limit_choices_to or self.limit_choices_to
|
||||
qs = self.related_model._default_manager.complex_filter(limit_choices_to)
|
||||
if ordering:
|
||||
qs = qs.order_by(*ordering)
|
||||
return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs]
|
||||
|
||||
def is_hidden(self):
|
||||
"""Should the related object be hidden?"""
|
||||
return bool(self.related_name) and self.related_name[-1] == "+"
|
||||
|
||||
def get_joining_columns(self):
|
||||
return self.field.get_reverse_joining_columns()
|
||||
|
||||
def get_extra_restriction(self, alias, related_alias):
|
||||
return self.field.get_extra_restriction(related_alias, alias)
|
||||
|
||||
def set_field_name(self):
|
||||
"""
|
||||
Set the related field's name; this is not available until later stages
of app loading, so set_field_name is called from
set_attributes_from_rel().
|
||||
"""
|
||||
# By default foreign object doesn't relate to any remote field (for
|
||||
# example custom multicolumn joins currently have no remote field).
|
||||
self.field_name = None
|
||||
|
||||
def get_accessor_name(self, model=None):
|
||||
# This method encapsulates the logic that decides what name to give an
|
||||
# accessor descriptor that retrieves related many-to-one or
|
||||
# many-to-many objects. It uses the lowercased object_name + "_set",
|
||||
# but this can be overridden with the "related_name" option. Due to
|
||||
# backwards compatibility ModelForms need to be able to provide an
|
||||
# alternate model. See BaseInlineFormSet.get_default_prefix().
|
||||
opts = model._meta if model else self.related_model._meta
|
||||
model = model or self.related_model
|
||||
if self.multiple:
|
||||
# If this is a symmetrical m2m relation on self, there is no
|
||||
# reverse accessor.
|
||||
if self.symmetrical and model == self.model:
|
||||
return None
|
||||
if self.related_name:
|
||||
return self.related_name
|
||||
return opts.model_name + ("_set" if self.multiple else "")
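# Illustrative example (hypothetical models, not part of this module): for
#
#     class Article(models.Model):
#         author = models.ForeignKey(User, on_delete=models.CASCADE)
#
# get_accessor_name() yields "article_set", so the reverse manager is
# user.article_set; passing related_name="articles" on the ForeignKey would
# make it user.articles instead.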
|
||||
|
||||
def get_path_info(self, filtered_relation=None):
|
||||
if filtered_relation:
|
||||
return self.field.get_reverse_path_info(filtered_relation)
|
||||
else:
|
||||
return self.field.reverse_path_infos
|
||||
|
||||
@cached_property
|
||||
def path_infos(self):
|
||||
return self.get_path_info()
|
||||
|
||||
def get_cache_name(self):
|
||||
"""
|
||||
Return the name of the cache key to use for storing an instance of the
|
||||
forward model on the reverse model.
|
||||
"""
|
||||
return self.get_accessor_name()
|
||||
|
||||
|
||||
class ManyToOneRel(ForeignObjectRel):
|
||||
"""
|
||||
Used by the ForeignKey field to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
|
||||
Note: Because we somewhat abuse the Rel objects by using them as reverse
|
||||
fields we get the funny situation where
|
||||
``ManyToOneRel.many_to_one == False`` and
|
||||
``ManyToOneRel.one_to_many == True``. This is unfortunate but the actual
|
||||
ManyToOneRel class is a private API and there is work underway to turn
|
||||
reverse relations into actual fields.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
parent_link=parent_link,
|
||||
on_delete=on_delete,
|
||||
)
|
||||
|
||||
self.field_name = field_name
|
||||
|
||||
def __getstate__(self):
|
||||
state = super().__getstate__()
|
||||
state.pop("related_model", None)
|
||||
return state
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return super().identity + (self.field_name,)
|
||||
|
||||
def get_related_field(self):
|
||||
"""
|
||||
Return the Field in the 'to' object to which this relationship is tied.
|
||||
"""
|
||||
field = self.model._meta.get_field(self.field_name)
|
||||
if not field.concrete:
|
||||
raise exceptions.FieldDoesNotExist(
|
||||
"No related field named '%s'" % self.field_name
|
||||
)
|
||||
return field
|
||||
|
||||
def set_field_name(self):
|
||||
self.field_name = self.field_name or self.model._meta.pk.name
|
||||
|
||||
|
||||
class OneToOneRel(ManyToOneRel):
|
||||
"""
|
||||
Used by OneToOneField to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
parent_link=parent_link,
|
||||
on_delete=on_delete,
|
||||
)
|
||||
|
||||
self.multiple = False
|
||||
|
||||
|
||||
class ManyToManyRel(ForeignObjectRel):
|
||||
"""
|
||||
Used by ManyToManyField to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
symmetrical=True,
|
||||
through=None,
|
||||
through_fields=None,
|
||||
db_constraint=True,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
)
|
||||
|
||||
if through and not db_constraint:
|
||||
raise ValueError("Can't supply a through model and db_constraint=False")
|
||||
self.through = through
|
||||
|
||||
if through_fields and not through:
|
||||
raise ValueError("Cannot specify through_fields without a through model")
|
||||
self.through_fields = through_fields
|
||||
|
||||
self.symmetrical = symmetrical
|
||||
self.db_constraint = db_constraint
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return super().identity + (
|
||||
self.through,
|
||||
make_hashable(self.through_fields),
|
||||
self.db_constraint,
|
||||
)
|
||||
|
||||
def get_related_field(self):
|
||||
"""
|
||||
Return the field in the 'to' object to which this relationship is tied.
|
||||
Provided for symmetry with ManyToOneRel.
|
||||
"""
|
||||
opts = self.through._meta
|
||||
if self.through_fields:
|
||||
field = opts.get_field(self.through_fields[0])
|
||||
else:
|
||||
for field in opts.fields:
|
||||
rel = getattr(field, "remote_field", None)
|
||||
if rel and rel.model == self.model:
|
||||
break
|
||||
return field.foreign_related_fields[0]
|
||||
@@ -0,0 +1,190 @@
|
||||
from .comparison import Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf
|
||||
from .datetime import (
|
||||
Extract,
|
||||
ExtractDay,
|
||||
ExtractHour,
|
||||
ExtractIsoWeekDay,
|
||||
ExtractIsoYear,
|
||||
ExtractMinute,
|
||||
ExtractMonth,
|
||||
ExtractQuarter,
|
||||
ExtractSecond,
|
||||
ExtractWeek,
|
||||
ExtractWeekDay,
|
||||
ExtractYear,
|
||||
Now,
|
||||
Trunc,
|
||||
TruncDate,
|
||||
TruncDay,
|
||||
TruncHour,
|
||||
TruncMinute,
|
||||
TruncMonth,
|
||||
TruncQuarter,
|
||||
TruncSecond,
|
||||
TruncTime,
|
||||
TruncWeek,
|
||||
TruncYear,
|
||||
)
|
||||
from .math import (
|
||||
Abs,
|
||||
ACos,
|
||||
ASin,
|
||||
ATan,
|
||||
ATan2,
|
||||
Ceil,
|
||||
Cos,
|
||||
Cot,
|
||||
Degrees,
|
||||
Exp,
|
||||
Floor,
|
||||
Ln,
|
||||
Log,
|
||||
Mod,
|
||||
Pi,
|
||||
Power,
|
||||
Radians,
|
||||
Random,
|
||||
Round,
|
||||
Sign,
|
||||
Sin,
|
||||
Sqrt,
|
||||
Tan,
|
||||
)
|
||||
from .text import (
|
||||
MD5,
|
||||
SHA1,
|
||||
SHA224,
|
||||
SHA256,
|
||||
SHA384,
|
||||
SHA512,
|
||||
Chr,
|
||||
Concat,
|
||||
ConcatPair,
|
||||
Left,
|
||||
Length,
|
||||
Lower,
|
||||
LPad,
|
||||
LTrim,
|
||||
Ord,
|
||||
Repeat,
|
||||
Replace,
|
||||
Reverse,
|
||||
Right,
|
||||
RPad,
|
||||
RTrim,
|
||||
StrIndex,
|
||||
Substr,
|
||||
Trim,
|
||||
Upper,
|
||||
)
|
||||
from .window import (
|
||||
CumeDist,
|
||||
DenseRank,
|
||||
FirstValue,
|
||||
Lag,
|
||||
LastValue,
|
||||
Lead,
|
||||
NthValue,
|
||||
Ntile,
|
||||
PercentRank,
|
||||
Rank,
|
||||
RowNumber,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# comparison and conversion
|
||||
"Cast",
|
||||
"Coalesce",
|
||||
"Collate",
|
||||
"Greatest",
|
||||
"JSONObject",
|
||||
"Least",
|
||||
"NullIf",
|
||||
# datetime
|
||||
"Extract",
|
||||
"ExtractDay",
|
||||
"ExtractHour",
|
||||
"ExtractMinute",
|
||||
"ExtractMonth",
|
||||
"ExtractQuarter",
|
||||
"ExtractSecond",
|
||||
"ExtractWeek",
|
||||
"ExtractIsoWeekDay",
|
||||
"ExtractWeekDay",
|
||||
"ExtractIsoYear",
|
||||
"ExtractYear",
|
||||
"Now",
|
||||
"Trunc",
|
||||
"TruncDate",
|
||||
"TruncDay",
|
||||
"TruncHour",
|
||||
"TruncMinute",
|
||||
"TruncMonth",
|
||||
"TruncQuarter",
|
||||
"TruncSecond",
|
||||
"TruncTime",
|
||||
"TruncWeek",
|
||||
"TruncYear",
|
||||
# math
|
||||
"Abs",
|
||||
"ACos",
|
||||
"ASin",
|
||||
"ATan",
|
||||
"ATan2",
|
||||
"Ceil",
|
||||
"Cos",
|
||||
"Cot",
|
||||
"Degrees",
|
||||
"Exp",
|
||||
"Floor",
|
||||
"Ln",
|
||||
"Log",
|
||||
"Mod",
|
||||
"Pi",
|
||||
"Power",
|
||||
"Radians",
|
||||
"Random",
|
||||
"Round",
|
||||
"Sign",
|
||||
"Sin",
|
||||
"Sqrt",
|
||||
"Tan",
|
||||
# text
|
||||
"MD5",
|
||||
"SHA1",
|
||||
"SHA224",
|
||||
"SHA256",
|
||||
"SHA384",
|
||||
"SHA512",
|
||||
"Chr",
|
||||
"Concat",
|
||||
"ConcatPair",
|
||||
"Left",
|
||||
"Length",
|
||||
"Lower",
|
||||
"LPad",
|
||||
"LTrim",
|
||||
"Ord",
|
||||
"Repeat",
|
||||
"Replace",
|
||||
"Reverse",
|
||||
"Right",
|
||||
"RPad",
|
||||
"RTrim",
|
||||
"StrIndex",
|
||||
"Substr",
|
||||
"Trim",
|
||||
"Upper",
|
||||
# window
|
||||
"CumeDist",
|
||||
"DenseRank",
|
||||
"FirstValue",
|
||||
"Lag",
|
||||
"LastValue",
|
||||
"Lead",
|
||||
"NthValue",
|
||||
"Ntile",
|
||||
"PercentRank",
|
||||
"Rank",
|
||||
"RowNumber",
|
||||
]
|
||||
@@ -0,0 +1,220 @@
"""Database functions that do comparisons or type conversions."""
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields import TextField
from django.db.models.fields.json import JSONField
from django.utils.regex_helper import _lazy_re_compile


class Cast(Func):
    """Coerce an expression to a new field type."""

    function = "CAST"
    template = "%(function)s(%(expressions)s AS %(db_type)s)"

    def __init__(self, expression, output_field):
        super().__init__(expression, output_field=output_field)

    def as_sql(self, compiler, connection, **extra_context):
        extra_context["db_type"] = self.output_field.cast_db_type(connection)
        return super().as_sql(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        db_type = self.output_field.db_type(connection)
        if db_type in {"datetime", "time"}:
            # Use strftime as datetime/time don't keep fractional seconds.
            template = "strftime(%%s, %(expressions)s)"
            sql, params = super().as_sql(
                compiler, connection, template=template, **extra_context
            )
            format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
            params.insert(0, format_string)
            return sql, params
        elif db_type == "date":
            template = "date(%(expressions)s)"
            return super().as_sql(
                compiler, connection, template=template, **extra_context
            )
        return self.as_sql(compiler, connection, **extra_context)

    def as_mysql(self, compiler, connection, **extra_context):
        template = None
        output_type = self.output_field.get_internal_type()
        # MySQL doesn't support explicit cast to float.
        if output_type == "FloatField":
            template = "(%(expressions)s + 0.0)"
        # MariaDB doesn't support explicit cast to JSON.
        elif output_type == "JSONField" and connection.mysql_is_mariadb:
            template = "JSON_EXTRACT(%(expressions)s, '$')"
        return self.as_sql(compiler, connection, template=template, **extra_context)

    def as_postgresql(self, compiler, connection, **extra_context):
        # CAST would be valid too, but the :: shortcut syntax is more readable.
        # 'expressions' is wrapped in parentheses in case it's a complex
        # expression.
        return self.as_sql(
            compiler,
            connection,
            template="(%(expressions)s)::%(db_type)s",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        if self.output_field.get_internal_type() == "JSONField":
            # Oracle doesn't support explicit cast to JSON.
            template = "JSON_QUERY(%(expressions)s, '$')"
            return super().as_sql(
                compiler, connection, template=template, **extra_context
            )
        return self.as_sql(compiler, connection, **extra_context)


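# Usage sketch (assumes a hypothetical ``Ticket`` model with an integer
# ``price`` column; not part of this module):
#
#     from django.db.models import FloatField
#     from django.db.models.functions import Cast
#
#     Ticket.objects.annotate(price_float=Cast("price", FloatField()))
#
# On PostgreSQL this renders via as_postgresql() above as ("price")::double
# precision; other backends fall back to the generic CAST template.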
class Coalesce(Func):
|
||||
"""Return, from left to right, the first non-null expression."""
|
||||
|
||||
function = "COALESCE"
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Coalesce must take at least two expressions")
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
@property
|
||||
def empty_result_set_value(self):
|
||||
for expression in self.get_source_expressions():
|
||||
result = expression.empty_result_set_value
|
||||
if result is NotImplemented or result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
|
||||
# so convert all fields to NCLOB when that type is expected.
|
||||
if self.output_field.get_internal_type() == "TextField":
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(
|
||||
[
|
||||
Func(expression, function="TO_NCLOB")
|
||||
for expression in self.get_source_expressions()
|
||||
]
|
||||
)
|
||||
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
|
||||
return self.as_sql(compiler, connection, **extra_context)
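# Usage sketch (hypothetical ``Author`` model; not part of this module):
#
#     from django.db.models import Value
#     from django.db.models.functions import Coalesce
#
#     Author.objects.annotate(
#         display_name=Coalesce("nickname", "name", Value("Anonymous"))
#     )
#
# The first non-null value per row is returned; at least two expressions are
# required, as enforced in __init__ above.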
|
||||
|
||||
|
||||
class Collate(Func):
|
||||
function = "COLLATE"
|
||||
template = "%(expressions)s %(function)s %(collation)s"
|
||||
# Inspired from
|
||||
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
|
||||
collation_re = _lazy_re_compile(r"^[\w\-]+$")
|
||||
|
||||
def __init__(self, expression, collation):
|
||||
if not (collation and self.collation_re.match(collation)):
|
||||
raise ValueError("Invalid collation name: %r." % collation)
|
||||
self.collation = collation
|
||||
super().__init__(expression)
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Greatest(Func):
|
||||
"""
|
||||
Return the maximum expression.
|
||||
|
||||
If any expression is null the return value is database-specific:
|
||||
On PostgreSQL, the maximum not-null expression is returned.
|
||||
On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
|
||||
"""
|
||||
|
||||
function = "GREATEST"
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Greatest must take at least two expressions")
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
"""Use the MAX function on SQLite."""
|
||||
return super().as_sqlite(compiler, connection, function="MAX", **extra_context)
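# Usage sketch (hypothetical ``Blog`` model): Greatest picks the per-row
# maximum, with the null handling described in the docstring above.
#
#     from django.db.models.functions import Greatest
#
#     Blog.objects.annotate(
#         last_activity=Greatest("updated_at", "last_comment_at")
#     )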
|
||||
|
||||
|
||||
class JSONObject(Func):
|
||||
function = "JSON_OBJECT"
|
||||
output_field = JSONField()
|
||||
|
||||
def __init__(self, **fields):
|
||||
expressions = []
|
||||
for key, value in fields.items():
|
||||
expressions.extend((Value(key), value))
|
||||
super().__init__(*expressions)
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if not connection.features.has_json_object_function:
|
||||
raise NotSupportedError(
|
||||
"JSONObject() is not supported on this database backend."
|
||||
)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
copy = self.copy()
|
||||
copy.set_source_expressions(
|
||||
[
|
||||
Cast(expression, TextField()) if index % 2 == 0 else expression
|
||||
for index, expression in enumerate(copy.get_source_expressions())
|
||||
]
|
||||
)
|
||||
return super(JSONObject, copy).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="JSONB_BUILD_OBJECT",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
class ArgJoiner:
|
||||
def join(self, args):
|
||||
args = [" VALUE ".join(arg) for arg in zip(args[::2], args[1::2])]
|
||||
return ", ".join(args)
|
||||
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
arg_joiner=ArgJoiner(),
|
||||
template="%(function)s(%(expressions)s RETURNING CLOB)",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class Least(Func):
|
||||
"""
|
||||
Return the minimum expression.
|
||||
|
||||
If any expression is null the return value is database-specific:
|
||||
On PostgreSQL, return the minimum not-null expression.
|
||||
On MySQL, Oracle, and SQLite, if any expression is null, return null.
|
||||
"""
|
||||
|
||||
function = "LEAST"
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Least must take at least two expressions")
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
"""Use the MIN function on SQLite."""
|
||||
return super().as_sqlite(compiler, connection, function="MIN", **extra_context)
|
||||
|
||||
|
||||
class NullIf(Func):
|
||||
function = "NULLIF"
|
||||
arity = 2
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
expression1 = self.get_source_expressions()[0]
|
||||
if isinstance(expression1, Value) and expression1.value is None:
|
||||
raise ValueError("Oracle does not allow Value(None) for expression1.")
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
@@ -0,0 +1,439 @@
|
||||
from datetime import datetime
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.models.expressions import Func
|
||||
from django.db.models.fields import (
|
||||
DateField,
|
||||
DateTimeField,
|
||||
DurationField,
|
||||
Field,
|
||||
IntegerField,
|
||||
TimeField,
|
||||
)
|
||||
from django.db.models.lookups import (
|
||||
Transform,
|
||||
YearExact,
|
||||
YearGt,
|
||||
YearGte,
|
||||
YearLt,
|
||||
YearLte,
|
||||
)
|
||||
from django.utils import timezone
|
||||
|
||||
|
||||
class TimezoneMixin:
|
||||
tzinfo = None
|
||||
|
||||
def get_tzname(self):
|
||||
# Timezone conversions must happen to the input datetime *before*
|
||||
# applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
|
||||
# database as 2016-01-01 01:00:00 +00:00. Any results should be
|
||||
# based on the input datetime not the stored datetime.
|
||||
tzname = None
|
||||
if settings.USE_TZ:
|
||||
if self.tzinfo is None:
|
||||
tzname = timezone.get_current_timezone_name()
|
||||
else:
|
||||
tzname = timezone._get_timezone_name(self.tzinfo)
|
||||
return tzname
|
||||
|
||||
|
||||
class Extract(TimezoneMixin, Transform):
|
||||
lookup_name = None
|
||||
output_field = IntegerField()
|
||||
|
||||
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
|
||||
if self.lookup_name is None:
|
||||
self.lookup_name = lookup_name
|
||||
if self.lookup_name is None:
|
||||
raise ValueError("lookup_name must be provided")
|
||||
self.tzinfo = tzinfo
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
lhs_output_field = self.lhs.output_field
|
||||
if isinstance(lhs_output_field, DateTimeField):
|
||||
tzname = self.get_tzname()
|
||||
sql, params = connection.ops.datetime_extract_sql(
|
||||
self.lookup_name, sql, tuple(params), tzname
|
||||
)
|
||||
elif self.tzinfo is not None:
|
||||
raise ValueError("tzinfo can only be used with DateTimeField.")
|
||||
elif isinstance(lhs_output_field, DateField):
|
||||
sql, params = connection.ops.date_extract_sql(
|
||||
self.lookup_name, sql, tuple(params)
|
||||
)
|
||||
elif isinstance(lhs_output_field, TimeField):
|
||||
sql, params = connection.ops.time_extract_sql(
|
||||
self.lookup_name, sql, tuple(params)
|
||||
)
|
||||
elif isinstance(lhs_output_field, DurationField):
|
||||
if not connection.features.has_native_duration_field:
|
||||
raise ValueError(
|
||||
"Extract requires native DurationField database support."
|
||||
)
|
||||
sql, params = connection.ops.time_extract_sql(
|
||||
self.lookup_name, sql, tuple(params)
|
||||
)
|
||||
else:
|
||||
# resolve_expression has already validated the output_field so this
|
||||
# assert should never be hit.
|
||||
assert False, "Tried to Extract from an invalid type."
|
||||
return sql, params
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
copy = super().resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
field = getattr(copy.lhs, "output_field", None)
|
||||
if field is None:
|
||||
return copy
|
||||
if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
|
||||
raise ValueError(
|
||||
"Extract input expression must be DateField, DateTimeField, "
|
||||
"TimeField, or DurationField."
|
||||
)
|
||||
# Passing dates to functions expecting datetimes is most likely a mistake.
|
||||
if type(field) is DateField and copy.lookup_name in (
|
||||
"hour",
|
||||
"minute",
|
||||
"second",
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot extract time component '%s' from DateField '%s'."
|
||||
% (copy.lookup_name, field.name)
|
||||
)
|
||||
if isinstance(field, DurationField) and copy.lookup_name in (
|
||||
"year",
|
||||
"iso_year",
|
||||
"month",
|
||||
"week",
|
||||
"week_day",
|
||||
"iso_week_day",
|
||||
"quarter",
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot extract component '%s' from DurationField '%s'."
|
||||
% (copy.lookup_name, field.name)
|
||||
)
|
||||
return copy
|
||||
|
||||
|
||||
class ExtractYear(Extract):
|
||||
lookup_name = "year"
|
||||
|
||||
|
||||
class ExtractIsoYear(Extract):
|
||||
"""Return the ISO-8601 week-numbering year."""
|
||||
|
||||
lookup_name = "iso_year"
|
||||
|
||||
|
||||
class ExtractMonth(Extract):
|
||||
lookup_name = "month"
|
||||
|
||||
|
||||
class ExtractDay(Extract):
|
||||
lookup_name = "day"
|
||||
|
||||
|
||||
class ExtractWeek(Extract):
|
||||
"""
|
||||
Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
|
||||
week.
|
||||
"""
|
||||
|
||||
lookup_name = "week"
|
||||
|
||||
|
||||
class ExtractWeekDay(Extract):
|
||||
"""
|
||||
Return Sunday=1 through Saturday=7.
|
||||
|
||||
To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
|
||||
"""
|
||||
|
||||
lookup_name = "week_day"
|
||||
|
||||
|
||||
class ExtractIsoWeekDay(Extract):
|
||||
"""Return Monday=1 through Sunday=7, based on ISO-8601."""
|
||||
|
||||
lookup_name = "iso_week_day"
|
||||
|
||||
|
||||
class ExtractQuarter(Extract):
|
||||
lookup_name = "quarter"
|
||||
|
||||
|
||||
class ExtractHour(Extract):
|
||||
lookup_name = "hour"
|
||||
|
||||
|
||||
class ExtractMinute(Extract):
|
||||
lookup_name = "minute"
|
||||
|
||||
|
||||
class ExtractSecond(Extract):
|
||||
lookup_name = "second"
|
||||
|
||||
|
||||
DateField.register_lookup(ExtractYear)
|
||||
DateField.register_lookup(ExtractMonth)
|
||||
DateField.register_lookup(ExtractDay)
|
||||
DateField.register_lookup(ExtractWeekDay)
|
||||
DateField.register_lookup(ExtractIsoWeekDay)
|
||||
DateField.register_lookup(ExtractWeek)
|
||||
DateField.register_lookup(ExtractIsoYear)
|
||||
DateField.register_lookup(ExtractQuarter)
|
||||
|
||||
TimeField.register_lookup(ExtractHour)
|
||||
TimeField.register_lookup(ExtractMinute)
|
||||
TimeField.register_lookup(ExtractSecond)
|
||||
|
||||
DateTimeField.register_lookup(ExtractHour)
|
||||
DateTimeField.register_lookup(ExtractMinute)
|
||||
DateTimeField.register_lookup(ExtractSecond)
|
||||
|
||||
ExtractYear.register_lookup(YearExact)
|
||||
ExtractYear.register_lookup(YearGt)
|
||||
ExtractYear.register_lookup(YearGte)
|
||||
ExtractYear.register_lookup(YearLt)
|
||||
ExtractYear.register_lookup(YearLte)
|
||||
|
||||
ExtractIsoYear.register_lookup(YearExact)
|
||||
ExtractIsoYear.register_lookup(YearGt)
|
||||
ExtractIsoYear.register_lookup(YearGte)
|
||||
ExtractIsoYear.register_lookup(YearLt)
|
||||
ExtractIsoYear.register_lookup(YearLte)
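# Usage sketch (hypothetical ``Entry`` model with a DateField ``pub_date``):
# the registrations above are what make both forms below work.
#
#     from django.db.models.functions import ExtractIsoWeekDay
#
#     Entry.objects.filter(pub_date__year__gte=2020)
#     Entry.objects.annotate(weekday=ExtractIsoWeekDay("pub_date"))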
|
||||
|
||||
|
||||
class Now(Func):
    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
        # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
        # other databases.
        return self.as_sql(
            compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        return self.as_sql(
            compiler, connection, template="CURRENT_TIMESTAMP(6)", **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        return self.as_sql(
            compiler,
            connection,
            template="STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')",
            **extra_context,
        )


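# Usage sketch (hypothetical ``Order`` model): Now() is evaluated by the
# database, not by Python, so it reflects the database server's clock.
#
#     from django.db.models.functions import Now
#
#     Order.objects.filter(expires_at__lte=Now())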
class TruncBase(TimezoneMixin, Transform):
|
||||
kind = None
|
||||
tzinfo = None
|
||||
|
||||
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
|
||||
# argument.
|
||||
def __init__(
|
||||
self,
|
||||
expression,
|
||||
output_field=None,
|
||||
tzinfo=None,
|
||||
is_dst=timezone.NOT_PASSED,
|
||||
**extra,
|
||||
):
|
||||
self.tzinfo = tzinfo
|
||||
self.is_dst = is_dst
|
||||
super().__init__(expression, output_field=output_field, **extra)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
tzname = None
|
||||
if isinstance(self.lhs.output_field, DateTimeField):
|
||||
tzname = self.get_tzname()
|
||||
elif self.tzinfo is not None:
|
||||
raise ValueError("tzinfo can only be used with DateTimeField.")
|
||||
if isinstance(self.output_field, DateTimeField):
|
||||
sql, params = connection.ops.datetime_trunc_sql(
|
||||
self.kind, sql, tuple(params), tzname
|
||||
)
|
||||
elif isinstance(self.output_field, DateField):
|
||||
sql, params = connection.ops.date_trunc_sql(
|
||||
self.kind, sql, tuple(params), tzname
|
||||
)
|
||||
elif isinstance(self.output_field, TimeField):
|
||||
sql, params = connection.ops.time_trunc_sql(
|
||||
self.kind, sql, tuple(params), tzname
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Trunc only valid on DateField, TimeField, or DateTimeField."
|
||||
)
|
||||
return sql, params
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
copy = super().resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
field = copy.lhs.output_field
|
||||
# DateTimeField is a subclass of DateField so this works for both.
|
||||
if not isinstance(field, (DateField, TimeField)):
|
||||
raise TypeError(
|
||||
"%r isn't a DateField, TimeField, or DateTimeField." % field.name
|
||||
)
|
||||
# If self.output_field was None, then accessing the field will trigger
|
||||
# the resolver to assign it to self.lhs.output_field.
|
||||
if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
|
||||
raise ValueError(
|
||||
"output_field must be either DateField, TimeField, or DateTimeField"
|
||||
)
|
||||
# Passing dates or times to functions expecting datetimes is most
|
||||
# likely a mistake.
|
||||
class_output_field = (
|
||||
self.__class__.output_field
|
||||
if isinstance(self.__class__.output_field, Field)
|
||||
else None
|
||||
)
|
||||
output_field = class_output_field or copy.output_field
|
||||
has_explicit_output_field = (
|
||||
class_output_field or field.__class__ is not copy.output_field.__class__
|
||||
)
|
||||
if type(field) is DateField and (
|
||||
isinstance(output_field, DateTimeField)
|
||||
or copy.kind in ("hour", "minute", "second", "time")
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot truncate DateField '%s' to %s."
|
||||
% (
|
||||
field.name,
|
||||
output_field.__class__.__name__
|
||||
if has_explicit_output_field
|
||||
else "DateTimeField",
|
||||
)
|
||||
)
|
||||
elif isinstance(field, TimeField) and (
|
||||
isinstance(output_field, DateTimeField)
|
||||
or copy.kind in ("year", "quarter", "month", "week", "day", "date")
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot truncate TimeField '%s' to %s."
|
||||
% (
|
||||
field.name,
|
||||
output_field.__class__.__name__
|
||||
if has_explicit_output_field
|
||||
else "DateTimeField",
|
||||
)
|
||||
)
|
||||
return copy
|
||||
|
||||
def convert_value(self, value, expression, connection):
|
||||
if isinstance(self.output_field, DateTimeField):
|
||||
if not settings.USE_TZ:
|
||||
pass
|
||||
elif value is not None:
|
||||
value = value.replace(tzinfo=None)
|
||||
value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
|
||||
elif not connection.features.has_zoneinfo_database:
|
||||
raise ValueError(
|
||||
"Database returned an invalid datetime value. Are time "
|
||||
"zone definitions for your database installed?"
|
||||
)
|
||||
elif isinstance(value, datetime):
|
||||
if value is None:
|
||||
pass
|
||||
elif isinstance(self.output_field, DateField):
|
||||
value = value.date()
|
||||
elif isinstance(self.output_field, TimeField):
|
||||
value = value.time()
|
||||
return value
|
||||
|
||||
|
||||
class Trunc(TruncBase):
|
||||
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
|
||||
# argument.
|
||||
def __init__(
|
||||
self,
|
||||
expression,
|
||||
kind,
|
||||
output_field=None,
|
||||
tzinfo=None,
|
||||
is_dst=timezone.NOT_PASSED,
|
||||
**extra,
|
||||
):
|
||||
self.kind = kind
|
||||
super().__init__(
|
||||
expression, output_field=output_field, tzinfo=tzinfo, is_dst=is_dst, **extra
|
||||
)
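# Usage sketch (hypothetical ``Sale`` model with a DateTimeField ``created``):
#
#     from django.db.models import Count
#     from django.db.models.functions import Trunc
#
#     Sale.objects.annotate(month=Trunc("created", "month")).values(
#         "month"
#     ).annotate(total=Count("id"))
#
# With USE_TZ enabled, the truncation happens in the current time zone (or the
# explicit ``tzinfo`` argument), per TimezoneMixin.get_tzname() above.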
|
||||
|
||||
|
||||
class TruncYear(TruncBase):
|
||||
kind = "year"
|
||||
|
||||
|
||||
class TruncQuarter(TruncBase):
|
||||
kind = "quarter"
|
||||
|
||||
|
||||
class TruncMonth(TruncBase):
|
||||
kind = "month"
|
||||
|
||||
|
||||
class TruncWeek(TruncBase):
|
||||
"""Truncate to midnight on the Monday of the week."""
|
||||
|
||||
kind = "week"
|
||||
|
||||
|
||||
class TruncDay(TruncBase):
|
||||
kind = "day"
|
||||
|
||||
|
||||
class TruncDate(TruncBase):
|
||||
kind = "date"
|
||||
lookup_name = "date"
|
||||
output_field = DateField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to date rather than truncate to date.
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
tzname = self.get_tzname()
|
||||
return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
|
||||
|
||||
|
||||
class TruncTime(TruncBase):
|
||||
kind = "time"
|
||||
lookup_name = "time"
|
||||
output_field = TimeField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to time rather than truncate to time.
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
tzname = self.get_tzname()
|
||||
return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
|
||||
|
||||
|
||||
class TruncHour(TruncBase):
|
||||
kind = "hour"
|
||||
|
||||
|
||||
class TruncMinute(TruncBase):
|
||||
kind = "minute"
|
||||
|
||||
|
||||
class TruncSecond(TruncBase):
|
||||
kind = "second"
|
||||
|
||||
|
||||
DateTimeField.register_lookup(TruncDate)
|
||||
DateTimeField.register_lookup(TruncTime)
|
||||
@@ -0,0 +1,212 @@
|
||||
import math
|
||||
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import FloatField, IntegerField
|
||||
from django.db.models.functions import Cast
|
||||
from django.db.models.functions.mixins import (
|
||||
FixDecimalInputMixin,
|
||||
NumericOutputFieldMixin,
|
||||
)
|
||||
from django.db.models.lookups import Transform
|
||||
|
||||
|
||||
class Abs(Transform):
|
||||
function = "ABS"
|
||||
lookup_name = "abs"
|
||||
|
||||
|
||||
class ACos(NumericOutputFieldMixin, Transform):
|
||||
function = "ACOS"
|
||||
lookup_name = "acos"
|
||||
|
||||
|
||||
class ASin(NumericOutputFieldMixin, Transform):
|
||||
function = "ASIN"
|
||||
lookup_name = "asin"
|
||||
|
||||
|
||||
class ATan(NumericOutputFieldMixin, Transform):
|
||||
function = "ATAN"
|
||||
lookup_name = "atan"
|
||||
|
||||
|
||||
class ATan2(NumericOutputFieldMixin, Func):
|
||||
function = "ATAN2"
|
||||
arity = 2
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if not getattr(
|
||||
connection.ops, "spatialite", False
|
||||
) or connection.ops.spatial_version >= (5, 0, 0):
|
||||
return self.as_sql(compiler, connection)
|
||||
# This function is usually ATan2(y, x), returning the inverse tangent
|
||||
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
|
||||
# Cast integers to float to avoid inconsistent/buggy behavior if the
|
||||
# arguments are mixed between integer and float or decimal.
|
||||
# https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(
|
||||
[
|
||||
Cast(expression, FloatField())
|
||||
if isinstance(expression.output_field, IntegerField)
|
||||
else expression
|
||||
for expression in self.get_source_expressions()[::-1]
|
||||
]
|
||||
)
|
||||
return clone.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Ceil(Transform):
|
||||
function = "CEILING"
|
||||
lookup_name = "ceil"
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="CEIL", **extra_context)
|
||||
|
||||
|
||||
class Cos(NumericOutputFieldMixin, Transform):
|
||||
function = "COS"
|
||||
lookup_name = "cos"
|
||||
|
||||
|
||||
class Cot(NumericOutputFieldMixin, Transform):
|
||||
function = "COT"
|
||||
lookup_name = "cot"
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, template="(1 / TAN(%(expressions)s))", **extra_context
|
||||
)
|
||||
|
||||
|
||||
class Degrees(NumericOutputFieldMixin, Transform):
|
||||
function = "DEGREES"
|
||||
lookup_name = "degrees"
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="((%%(expressions)s) * 180 / %s)" % math.pi,
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class Exp(NumericOutputFieldMixin, Transform):
|
||||
function = "EXP"
|
||||
lookup_name = "exp"
|
||||
|
||||
|
||||
class Floor(Transform):
|
||||
function = "FLOOR"
|
||||
lookup_name = "floor"
|
||||
|
||||
|
||||
class Ln(NumericOutputFieldMixin, Transform):
|
||||
function = "LN"
|
||||
lookup_name = "ln"
|
||||
|
||||
|
||||
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
||||
function = "LOG"
|
||||
arity = 2
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if not getattr(connection.ops, "spatialite", False):
|
||||
return self.as_sql(compiler, connection)
|
||||
# This function is usually Log(b, x) returning the logarithm of x to
|
||||
# the base b, but on SpatiaLite it's Log(x, b).
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(self.get_source_expressions()[::-1])
|
||||
return clone.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
||||
function = "MOD"
|
||||
arity = 2
|
||||
|
||||
|
||||
class Pi(NumericOutputFieldMixin, Func):
|
||||
function = "PI"
|
||||
arity = 0
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, template=str(math.pi), **extra_context
|
||||
)
|
||||
|
||||
|
||||
class Power(NumericOutputFieldMixin, Func):
|
||||
function = "POWER"
|
||||
arity = 2
|
||||
|
||||
|
||||
class Radians(NumericOutputFieldMixin, Transform):
|
||||
function = "RADIANS"
|
||||
lookup_name = "radians"
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="((%%(expressions)s) * %s / 180)" % math.pi,
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class Random(NumericOutputFieldMixin, Func):
|
||||
function = "RANDOM"
|
||||
arity = 0
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="RAND", **extra_context)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, function="DBMS_RANDOM.VALUE", **extra_context
|
||||
)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="RAND", **extra_context)
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return []
|
||||
|
||||
|
||||
class Round(FixDecimalInputMixin, Transform):
|
||||
function = "ROUND"
|
||||
lookup_name = "round"
|
||||
arity = None # Override Transform's arity=1 to enable passing precision.
|
||||
|
||||
def __init__(self, expression, precision=0, **extra):
|
||||
super().__init__(expression, precision, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
precision = self.get_source_expressions()[1]
|
||||
if isinstance(precision, Value) and precision.value < 0:
|
||||
raise ValueError("SQLite does not support negative precision.")
|
||||
return super().as_sqlite(compiler, connection, **extra_context)
|
||||
|
||||
def _resolve_output_field(self):
|
||||
source = self.get_source_expressions()[0]
|
||||
return source.output_field
|
||||
|
||||
|
||||
class Sign(Transform):
|
||||
function = "SIGN"
|
||||
lookup_name = "sign"
|
||||
|
||||
|
||||
class Sin(NumericOutputFieldMixin, Transform):
|
||||
function = "SIN"
|
||||
lookup_name = "sin"
|
||||
|
||||
|
||||
class Sqrt(NumericOutputFieldMixin, Transform):
|
||||
function = "SQRT"
|
||||
lookup_name = "sqrt"
|
||||
|
||||
|
||||
class Tan(NumericOutputFieldMixin, Transform):
|
||||
function = "TAN"
|
||||
lookup_name = "tan"
|
||||
@@ -0,0 +1,57 @@
import sys

from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.functions import Cast


class FixDecimalInputMixin:
    def as_postgresql(self, compiler, connection, **extra_context):
        # Cast FloatField to DecimalField as PostgreSQL doesn't support the
        # following function signatures:
        # - LOG(double, double)
        # - MOD(double, double)
        output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
        clone = self.copy()
        clone.set_source_expressions(
            [
                Cast(expression, output_field)
                if isinstance(expression.output_field, FloatField)
                else expression
                for expression in self.get_source_expressions()
            ]
        )
        return clone.as_sql(compiler, connection, **extra_context)


class FixDurationInputMixin:
    def as_mysql(self, compiler, connection, **extra_context):
        sql, params = super().as_sql(compiler, connection, **extra_context)
        if self.output_field.get_internal_type() == "DurationField":
            sql = "CAST(%s AS SIGNED)" % sql
        return sql, params

    def as_oracle(self, compiler, connection, **extra_context):
        if self.output_field.get_internal_type() == "DurationField":
            expression = self.get_source_expressions()[0]
            options = self._get_repr_options()
            from django.db.backends.oracle.functions import (
                IntervalToSeconds,
                SecondsToInterval,
            )

            return compiler.compile(
                SecondsToInterval(
                    self.__class__(IntervalToSeconds(expression), **options)
                )
            )
        return super().as_sql(compiler, connection, **extra_context)


class NumericOutputFieldMixin:
    def _resolve_output_field(self):
        source_fields = self.get_source_fields()
        if any(isinstance(s, DecimalField) for s in source_fields):
            return DecimalField()
        if any(isinstance(s, IntegerField) for s in source_fields):
            return FloatField()
        return super()._resolve_output_field() if source_fields else FloatField()
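# Illustrative note (not part of this module): NumericOutputFieldMixin means
# that, e.g., Power("decimal_col", 2) resolves to a DecimalField while
# Power("integer_col", 2) resolves to a FloatField, with no explicit
# output_field argument; the column names here are hypothetical.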
@@ -0,0 +1,365 @@
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import CharField, IntegerField, TextField
|
||||
from django.db.models.functions import Cast, Coalesce
|
||||
from django.db.models.lookups import Transform
|
||||
|
||||
|
||||
class MySQLSHA2Mixin:
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="SHA2(%%(expressions)s, %s)" % self.function[3:],
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class OracleHashMixin:
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=(
|
||||
"LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
|
||||
"%(expressions)s, 'AL32UTF8'), '%(function)s')))"
|
||||
),
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class PostgreSQLSHAMixin:
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
|
||||
function=self.function.lower(),
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class Chr(Transform):
|
||||
function = "CHR"
|
||||
lookup_name = "chr"
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="CHAR",
|
||||
template="%(function)s(%(expressions)s USING utf16)",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="%(function)s(%(expressions)s USING NCHAR_CS)",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="CHAR", **extra_context)
|
||||
|
||||
|
||||
class ConcatPair(Func):
|
||||
"""
|
||||
Concatenate two arguments together. This is used by `Concat` because not
|
||||
all backend databases support more than two arguments.
|
||||
"""
|
||||
|
||||
function = "CONCAT"
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
coalesced = self.coalesce()
|
||||
return super(ConcatPair, coalesced).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="%(expressions)s",
|
||||
arg_joiner=" || ",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
copy = self.copy()
|
||||
copy.set_source_expressions(
|
||||
[
|
||||
Cast(expression, TextField())
|
||||
for expression in copy.get_source_expressions()
|
||||
]
|
||||
)
|
||||
return super(ConcatPair, copy).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="CONCAT_WS",
|
||||
template="%(function)s('', %(expressions)s)",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def coalesce(self):
|
||||
# NULL on either side makes the whole expression NULL; wrap each argument in Coalesce.
|
||||
c = self.copy()
|
||||
c.set_source_expressions(
|
||||
[
|
||||
Coalesce(expression, Value(""))
|
||||
for expression in c.get_source_expressions()
|
||||
]
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
class Concat(Func):
|
||||
"""
|
||||
Concatenate text fields together. Backends that result in an entire
|
||||
null expression when any arguments are null will wrap each argument in
|
||||
coalesce functions to ensure a non-null result.
|
||||
"""
|
||||
|
||||
function = None
|
||||
template = "%(expressions)s"
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Concat must take at least two expressions")
|
||||
paired = self._paired(expressions)
|
||||
super().__init__(paired, **extra)
|
||||
|
||||
def _paired(self, expressions):
|
||||
# wrap pairs of expressions in successive concat functions
|
||||
# exp = [a, b, c, d]
|
||||
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
|
||||
if len(expressions) == 2:
|
||||
return ConcatPair(*expressions)
|
||||
return ConcatPair(expressions[0], self._paired(expressions[1:]))
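# Usage sketch (hypothetical ``Author`` model): the pairing above means the
# call below compiles to nested two-argument CONCAT calls (or ``||`` joins on
# SQLite).
#
#     from django.db.models import Value
#     from django.db.models.functions import Concat
#
#     Author.objects.annotate(
#         full_name=Concat("first_name", Value(" "), "last_name")
#     )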
|
||||
|
||||
|
||||
class Left(Func):
|
||||
function = "LEFT"
|
||||
arity = 2
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, length, **extra):
|
||||
"""
|
||||
expression: the name of a field, or an expression returning a string
|
||||
length: the number of characters to return from the start of the string
|
||||
"""
|
||||
if not hasattr(length, "resolve_expression"):
|
||||
if length < 1:
|
||||
raise ValueError("'length' must be greater than 0.")
|
||||
super().__init__(expression, length, **extra)
|
||||
|
||||
def get_substr(self):
|
||||
return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return self.get_substr().as_oracle(compiler, connection, **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return self.get_substr().as_sqlite(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Length(Transform):
|
||||
"""Return the number of characters in the expression."""
|
||||
|
||||
function = "LENGTH"
|
||||
lookup_name = "length"
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
        return super().as_sql(
            compiler, connection, function="CHAR_LENGTH", **extra_context
        )


class Lower(Transform):
    function = "LOWER"
    lookup_name = "lower"


class LPad(Func):
    function = "LPAD"
    output_field = CharField()

    def __init__(self, expression, length, fill_text=Value(" "), **extra):
        if (
            not hasattr(length, "resolve_expression")
            and length is not None
            and length < 0
        ):
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)


class LTrim(Transform):
    function = "LTRIM"
    lookup_name = "ltrim"


class MD5(OracleHashMixin, Transform):
    function = "MD5"
    lookup_name = "md5"


class Ord(Transform):
    function = "ASCII"
    lookup_name = "ord"
    output_field = IntegerField()

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="ORD", **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="UNICODE", **extra_context)


class Repeat(Func):
    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression, number, **extra):
        if (
            not hasattr(number, "resolve_expression")
            and number is not None
            and number < 0
        ):
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)

    def as_oracle(self, compiler, connection, **extra_context):
        expression, number = self.source_expressions
        length = None if number is None else Length(expression) * number
        rpad = RPad(expression, length, expression)
        return rpad.as_sql(compiler, connection, **extra_context)


class Replace(Func):
    function = "REPLACE"

    def __init__(self, expression, text, replacement=Value(""), **extra):
        super().__init__(expression, text, replacement, **extra)


class Reverse(Transform):
    function = "REVERSE"
    lookup_name = "reverse"

    def as_oracle(self, compiler, connection, **extra_context):
        # REVERSE in Oracle is undocumented and doesn't support multi-byte
        # strings. Use a special subquery instead.
        return super().as_sql(
            compiler,
            connection,
            template=(
                "(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
                "(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s "
                "FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
                "GROUP BY %(expressions)s)"
            ),
            **extra_context,
        )


class Right(Left):
    function = "RIGHT"

    def get_substr(self):
        return Substr(
            self.source_expressions[0], self.source_expressions[1] * Value(-1)
        )


class RPad(LPad):
    function = "RPAD"


class RTrim(Transform):
    function = "RTRIM"
    lookup_name = "rtrim"


class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
    function = "SHA1"
    lookup_name = "sha1"


class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    function = "SHA224"
    lookup_name = "sha224"

    def as_oracle(self, compiler, connection, **extra_context):
        raise NotSupportedError("SHA224 is not supported on Oracle.")


class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    function = "SHA256"
    lookup_name = "sha256"


class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    function = "SHA384"
    lookup_name = "sha384"


class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    function = "SHA512"
    lookup_name = "sha512"


class StrIndex(Func):
    """
    Return a positive integer corresponding to the 1-indexed position of the
    first occurrence of a substring inside another string, or 0 if the
    substring is not found.
    """

    function = "INSTR"
    arity = 2
    output_field = IntegerField()

    def as_postgresql(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="STRPOS", **extra_context)


class Substr(Func):
    function = "SUBSTRING"
    output_field = CharField()

    def __init__(self, expression, pos, length=None, **extra):
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return
        """
        if not hasattr(pos, "resolve_expression"):
            if pos < 1:
                raise ValueError("'pos' must be greater than 0")
        expressions = [expression, pos]
        if length is not None:
            expressions.append(length)
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)


class Trim(Transform):
    function = "TRIM"
    lookup_name = "trim"


class Upper(Transform):
    function = "UPPER"
    lookup_name = "upper"
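# Editor's note: a minimal usage sketch, not part of the vendored file, showing
# how the string functions above are typically applied through annotate() and
# filter(). The `Author` model and its `name` CharField are hypothetical.
from django.db.models import Value
from django.db.models.functions import Lower, LPad, StrIndex, Substr

authors = Author.objects.annotate(
    name_lower=Lower("name"),                  # LOWER(name)
    name_padded=LPad("name", 10, Value("*")),  # LPAD(name, 10, '*')
    space_at=StrIndex("name", Value(" ")),     # 1-indexed position, 0 if absent
    first_chars=Substr("name", 1, 5),          # SUBSTRING(name, 1, 5)
).filter(space_at__gt=0)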
@@ -0,0 +1,120 @@
from django.db.models.expressions import Func
from django.db.models.fields import FloatField, IntegerField

__all__ = [
    "CumeDist",
    "DenseRank",
    "FirstValue",
    "Lag",
    "LastValue",
    "Lead",
    "NthValue",
    "Ntile",
    "PercentRank",
    "Rank",
    "RowNumber",
]


class CumeDist(Func):
    function = "CUME_DIST"
    output_field = FloatField()
    window_compatible = True


class DenseRank(Func):
    function = "DENSE_RANK"
    output_field = IntegerField()
    window_compatible = True


class FirstValue(Func):
    arity = 1
    function = "FIRST_VALUE"
    window_compatible = True


class LagLeadFunction(Func):
    window_compatible = True

    def __init__(self, expression, offset=1, default=None, **extra):
        if expression is None:
            raise ValueError(
                "%s requires a non-null source expression." % self.__class__.__name__
            )
        if offset is None or offset <= 0:
            raise ValueError(
                "%s requires a positive integer for the offset."
                % self.__class__.__name__
            )
        args = (expression, offset)
        if default is not None:
            args += (default,)
        super().__init__(*args, **extra)

    def _resolve_output_field(self):
        sources = self.get_source_expressions()
        return sources[0].output_field


class Lag(LagLeadFunction):
    function = "LAG"


class LastValue(Func):
    arity = 1
    function = "LAST_VALUE"
    window_compatible = True


class Lead(LagLeadFunction):
    function = "LEAD"


class NthValue(Func):
    function = "NTH_VALUE"
    window_compatible = True

    def __init__(self, expression, nth=1, **extra):
        if expression is None:
            raise ValueError(
                "%s requires a non-null source expression." % self.__class__.__name__
            )
        if nth is None or nth <= 0:
            raise ValueError(
                "%s requires a positive integer as for nth." % self.__class__.__name__
            )
        super().__init__(expression, nth, **extra)

    def _resolve_output_field(self):
        sources = self.get_source_expressions()
        return sources[0].output_field


class Ntile(Func):
    function = "NTILE"
    output_field = IntegerField()
    window_compatible = True

    def __init__(self, num_buckets=1, **extra):
        if num_buckets <= 0:
            raise ValueError("num_buckets must be greater than 0.")
        super().__init__(num_buckets, **extra)


class PercentRank(Func):
    function = "PERCENT_RANK"
    output_field = FloatField()
    window_compatible = True


class Rank(Func):
    function = "RANK"
    output_field = IntegerField()
    window_compatible = True


class RowNumber(Func):
    function = "ROW_NUMBER"
    output_field = IntegerField()
    window_compatible = True
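# Editor's note: a brief sketch, not part of the vendored file, showing how the
# window functions above (window_compatible=True) are combined with
# django.db.models.Window. The `Sale` model with `region` and `amount` fields is
# hypothetical.
from django.db.models import F, Window
from django.db.models.functions import Lag, Rank

sales = Sale.objects.annotate(
    # RANK() OVER (PARTITION BY region ORDER BY amount DESC)
    rank_in_region=Window(
        expression=Rank(),
        partition_by=[F("region")],
        order_by=F("amount").desc(),
    ),
    # LAG(amount, 1) OVER (PARTITION BY region ORDER BY amount ASC)
    previous_amount=Window(
        expression=Lag("amount", offset=1),
        partition_by=[F("region")],
        order_by=F("amount").asc(),
    ),
)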
@@ -0,0 +1,295 @@
|
||||
from django.db.backends.utils import names_digest, split_identifier
|
||||
from django.db.models.expressions import Col, ExpressionList, F, Func, OrderBy
|
||||
from django.db.models.functions import Collate
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.sql import Query
|
||||
from django.utils.functional import partition
|
||||
|
||||
__all__ = ["Index"]
|
||||
|
||||
|
||||
class Index:
|
||||
suffix = "idx"
|
||||
# The max length of the name of the index (restricted to 30 for
|
||||
# cross-database compatibility with Oracle)
|
||||
max_name_length = 30
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*expressions,
|
||||
fields=(),
|
||||
name=None,
|
||||
db_tablespace=None,
|
||||
opclasses=(),
|
||||
condition=None,
|
||||
include=None,
|
||||
):
|
||||
if opclasses and not name:
|
||||
raise ValueError("An index must be named to use opclasses.")
|
||||
if not isinstance(condition, (type(None), Q)):
|
||||
raise ValueError("Index.condition must be a Q instance.")
|
||||
if condition and not name:
|
||||
raise ValueError("An index must be named to use condition.")
|
||||
if not isinstance(fields, (list, tuple)):
|
||||
raise ValueError("Index.fields must be a list or tuple.")
|
||||
if not isinstance(opclasses, (list, tuple)):
|
||||
raise ValueError("Index.opclasses must be a list or tuple.")
|
||||
if not expressions and not fields:
|
||||
raise ValueError(
|
||||
"At least one field or expression is required to define an index."
|
||||
)
|
||||
if expressions and fields:
|
||||
raise ValueError(
|
||||
"Index.fields and expressions are mutually exclusive.",
|
||||
)
|
||||
if expressions and not name:
|
||||
raise ValueError("An index must be named to use expressions.")
|
||||
if expressions and opclasses:
|
||||
raise ValueError(
|
||||
"Index.opclasses cannot be used with expressions. Use "
|
||||
"django.contrib.postgres.indexes.OpClass() instead."
|
||||
)
|
||||
if opclasses and len(fields) != len(opclasses):
|
||||
raise ValueError(
|
||||
"Index.fields and Index.opclasses must have the same number of "
|
||||
"elements."
|
||||
)
|
||||
if fields and not all(isinstance(field, str) for field in fields):
|
||||
raise ValueError("Index.fields must contain only strings with field names.")
|
||||
if include and not name:
|
||||
raise ValueError("A covering index must be named.")
|
||||
if not isinstance(include, (type(None), list, tuple)):
|
||||
raise ValueError("Index.include must be a list or tuple.")
|
||||
self.fields = list(fields)
|
||||
# A list of 2-tuple with the field name and ordering ('' or 'DESC').
|
||||
self.fields_orders = [
|
||||
(field_name[1:], "DESC") if field_name.startswith("-") else (field_name, "")
|
||||
for field_name in self.fields
|
||||
]
|
||||
self.name = name or ""
|
||||
self.db_tablespace = db_tablespace
|
||||
self.opclasses = opclasses
|
||||
self.condition = condition
|
||||
self.include = tuple(include) if include else ()
|
||||
self.expressions = tuple(
|
||||
F(expression) if isinstance(expression, str) else expression
|
||||
for expression in expressions
|
||||
)
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return bool(self.expressions)
|
||||
|
||||
def _get_condition_sql(self, model, schema_editor):
|
||||
if self.condition is None:
|
||||
return None
|
||||
query = Query(model=model, alias_cols=False)
|
||||
where = query.build_where(self.condition)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def create_sql(self, model, schema_editor, using="", **kwargs):
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
if self.expressions:
|
||||
index_expressions = []
|
||||
for expression in self.expressions:
|
||||
index_expression = IndexExpression(expression)
|
||||
index_expression.set_wrapper_classes(schema_editor.connection)
|
||||
index_expressions.append(index_expression)
|
||||
expressions = ExpressionList(*index_expressions).resolve_expression(
|
||||
Query(model, alias_cols=False),
|
||||
)
|
||||
fields = None
|
||||
col_suffixes = None
|
||||
else:
|
||||
fields = [
|
||||
model._meta.get_field(field_name)
|
||||
for field_name, _ in self.fields_orders
|
||||
]
|
||||
if schema_editor.connection.features.supports_index_column_ordering:
|
||||
col_suffixes = [order[1] for order in self.fields_orders]
|
||||
else:
|
||||
col_suffixes = [""] * len(self.fields_orders)
|
||||
expressions = None
|
||||
return schema_editor._create_index_sql(
|
||||
model,
|
||||
fields=fields,
|
||||
name=self.name,
|
||||
using=using,
|
||||
db_tablespace=self.db_tablespace,
|
||||
col_suffixes=col_suffixes,
|
||||
opclasses=self.opclasses,
|
||||
condition=condition,
|
||||
include=include,
|
||||
expressions=expressions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def remove_sql(self, model, schema_editor, **kwargs):
|
||||
return schema_editor._delete_index_sql(model, self.name, **kwargs)
|
||||
|
||||
def deconstruct(self):
|
||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace("django.db.models.indexes", "django.db.models")
|
||||
kwargs = {"name": self.name}
|
||||
if self.fields:
|
||||
kwargs["fields"] = self.fields
|
||||
if self.db_tablespace is not None:
|
||||
kwargs["db_tablespace"] = self.db_tablespace
|
||||
if self.opclasses:
|
||||
kwargs["opclasses"] = self.opclasses
|
||||
if self.condition:
|
||||
kwargs["condition"] = self.condition
|
||||
if self.include:
|
||||
kwargs["include"] = self.include
|
||||
return (path, self.expressions, kwargs)
|
||||
|
||||
def clone(self):
|
||||
"""Create a copy of this Index."""
|
||||
_, args, kwargs = self.deconstruct()
|
||||
return self.__class__(*args, **kwargs)
|
||||
|
||||
def set_name_with_model(self, model):
|
||||
"""
|
||||
Generate a unique name for the index.
|
||||
|
||||
The name is divided into 3 parts - table name (12 chars), field name
|
||||
(8 chars) and unique hash + suffix (10 chars). Each part is made to
|
||||
fit its size by truncating the excess length.
|
||||
"""
|
||||
_, table_name = split_identifier(model._meta.db_table)
|
||||
column_names = [
|
||||
model._meta.get_field(field_name).column
|
||||
for field_name, order in self.fields_orders
|
||||
]
|
||||
column_names_with_order = [
|
||||
(("-%s" if order else "%s") % column_name)
|
||||
for column_name, (field_name, order) in zip(
|
||||
column_names, self.fields_orders
|
||||
)
|
||||
]
|
||||
# The length of the parts of the name is based on the default max
|
||||
# length of 30 characters.
|
||||
hash_data = [table_name] + column_names_with_order + [self.suffix]
|
||||
self.name = "%s_%s_%s" % (
|
||||
table_name[:11],
|
||||
column_names[0][:7],
|
||||
"%s_%s" % (names_digest(*hash_data, length=6), self.suffix),
|
||||
)
|
||||
if len(self.name) > self.max_name_length:
|
||||
raise ValueError(
|
||||
"Index too long for multiple database support. Is self.suffix "
|
||||
"longer than 3 characters?"
|
||||
)
|
||||
if self.name[0] == "_" or self.name[0].isdigit():
|
||||
self.name = "D%s" % self.name[1:]
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s:%s%s%s%s%s%s%s>" % (
|
||||
self.__class__.__qualname__,
|
||||
"" if not self.fields else " fields=%s" % repr(self.fields),
|
||||
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
|
||||
"" if not self.name else " name=%s" % repr(self.name),
|
||||
""
|
||||
if self.db_tablespace is None
|
||||
else " db_tablespace=%s" % repr(self.db_tablespace),
|
||||
"" if self.condition is None else " condition=%s" % self.condition,
|
||||
"" if not self.include else " include=%s" % repr(self.include),
|
||||
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.__class__ == other.__class__:
|
||||
return self.deconstruct() == other.deconstruct()
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class IndexExpression(Func):
|
||||
"""Order and wrap expressions for CREATE INDEX statements."""
|
||||
|
||||
template = "%(expressions)s"
|
||||
wrapper_classes = (OrderBy, Collate)
|
||||
|
||||
def set_wrapper_classes(self, connection=None):
|
||||
# Some databases (e.g. MySQL) treats COLLATE as an indexed expression.
|
||||
if connection and connection.features.collate_as_index_expression:
|
||||
self.wrapper_classes = tuple(
|
||||
[
|
||||
wrapper_cls
|
||||
for wrapper_cls in self.wrapper_classes
|
||||
if wrapper_cls is not Collate
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def register_wrappers(cls, *wrapper_classes):
|
||||
cls.wrapper_classes = wrapper_classes
|
||||
|
||||
def resolve_expression(
|
||||
self,
|
||||
query=None,
|
||||
allow_joins=True,
|
||||
reuse=None,
|
||||
summarize=False,
|
||||
for_save=False,
|
||||
):
|
||||
expressions = list(self.flatten())
|
||||
# Split expressions and wrappers.
|
||||
index_expressions, wrappers = partition(
|
||||
lambda e: isinstance(e, self.wrapper_classes),
|
||||
expressions,
|
||||
)
|
||||
wrapper_types = [type(wrapper) for wrapper in wrappers]
|
||||
if len(wrapper_types) != len(set(wrapper_types)):
|
||||
raise ValueError(
|
||||
"Multiple references to %s can't be used in an indexed "
|
||||
"expression."
|
||||
% ", ".join(
|
||||
[wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
|
||||
)
|
||||
)
|
||||
if expressions[1 : len(wrappers) + 1] != wrappers:
|
||||
raise ValueError(
|
||||
"%s must be topmost expressions in an indexed expression."
|
||||
% ", ".join(
|
||||
[wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
|
||||
)
|
||||
)
|
||||
# Wrap expressions in parentheses if they are not column references.
|
||||
root_expression = index_expressions[1]
|
||||
resolve_root_expression = root_expression.resolve_expression(
|
||||
query,
|
||||
allow_joins,
|
||||
reuse,
|
||||
summarize,
|
||||
for_save,
|
||||
)
|
||||
if not isinstance(resolve_root_expression, Col):
|
||||
root_expression = Func(root_expression, template="(%(expressions)s)")
|
||||
|
||||
if wrappers:
|
||||
# Order wrappers and set their expressions.
|
||||
wrappers = sorted(
|
||||
wrappers,
|
||||
key=lambda w: self.wrapper_classes.index(type(w)),
|
||||
)
|
||||
wrappers = [wrapper.copy() for wrapper in wrappers]
|
||||
for i, wrapper in enumerate(wrappers[:-1]):
|
||||
wrapper.set_source_expressions([wrappers[i + 1]])
|
||||
# Set the root expression on the deepest wrapper.
|
||||
wrappers[-1].set_source_expressions([root_expression])
|
||||
self.set_source_expressions([wrappers[0]])
|
||||
else:
|
||||
# Use the root expression, if there are no wrappers.
|
||||
self.set_source_expressions([root_expression])
|
||||
return super().resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
# Casting to numeric is unnecessary.
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
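# Editor's note: an illustrative sketch, not part of the vendored file, of how
# the Index class above is declared in a model's Meta. The `Article` model is
# hypothetical. Expression indexes and conditional (partial) indexes both need
# an explicit name, as enforced in Index.__init__() above; plain field indexes
# get a name from set_name_with_model().
from django.db import models
from django.db.models import Index, Q
from django.db.models.functions import Lower


class Article(models.Model):
    title = models.CharField(max_length=200)
    published = models.BooleanField(default=False)

    class Meta:
        indexes = [
            # Name is auto-generated (table, column, hash + "idx" suffix).
            Index(fields=["title"]),
            # Expression index: must be named.
            Index(Lower("title"), name="article_title_lower_idx"),
            # Partial index with a condition: must be named.
            Index(
                fields=["published"],
                name="article_published_idx",
                condition=Q(published=True),
            ),
        ]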
@@ -0,0 +1,727 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
from django.core.exceptions import EmptyResultSet, FullResultSet
|
||||
from django.db.models.expressions import Case, Expression, Func, Value, When
|
||||
from django.db.models.fields import (
|
||||
BooleanField,
|
||||
CharField,
|
||||
DateTimeField,
|
||||
Field,
|
||||
IntegerField,
|
||||
UUIDField,
|
||||
)
|
||||
from django.db.models.query_utils import RegisterLookupMixin
|
||||
from django.utils.datastructures import OrderedSet
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
|
||||
|
||||
class Lookup(Expression):
|
||||
lookup_name = None
|
||||
prepare_rhs = True
|
||||
can_use_none_as_rhs = False
|
||||
|
||||
def __init__(self, lhs, rhs):
|
||||
self.lhs, self.rhs = lhs, rhs
|
||||
self.rhs = self.get_prep_lookup()
|
||||
self.lhs = self.get_prep_lhs()
|
||||
if hasattr(self.lhs, "get_bilateral_transforms"):
|
||||
bilateral_transforms = self.lhs.get_bilateral_transforms()
|
||||
else:
|
||||
bilateral_transforms = []
|
||||
if bilateral_transforms:
|
||||
# Warn the user as soon as possible if they are trying to apply
|
||||
# a bilateral transformation on a nested QuerySet: that won't work.
|
||||
from django.db.models.sql.query import Query # avoid circular import
|
||||
|
||||
if isinstance(rhs, Query):
|
||||
raise NotImplementedError(
|
||||
"Bilateral transformations on nested querysets are not implemented."
|
||||
)
|
||||
self.bilateral_transforms = bilateral_transforms
|
||||
|
||||
def apply_bilateral_transforms(self, value):
|
||||
for transform in self.bilateral_transforms:
|
||||
value = transform(value)
|
||||
return value
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"
|
||||
|
||||
def batch_process_rhs(self, compiler, connection, rhs=None):
|
||||
if rhs is None:
|
||||
rhs = self.rhs
|
||||
if self.bilateral_transforms:
|
||||
sqls, sqls_params = [], []
|
||||
for p in rhs:
|
||||
value = Value(p, output_field=self.lhs.output_field)
|
||||
value = self.apply_bilateral_transforms(value)
|
||||
value = value.resolve_expression(compiler.query)
|
||||
sql, sql_params = compiler.compile(value)
|
||||
sqls.append(sql)
|
||||
sqls_params.extend(sql_params)
|
||||
else:
|
||||
_, params = self.get_db_prep_lookup(rhs, connection)
|
||||
sqls, sqls_params = ["%s"] * len(params), params
|
||||
return sqls, sqls_params
|
||||
|
||||
def get_source_expressions(self):
|
||||
if self.rhs_is_direct_value():
|
||||
return [self.lhs]
|
||||
return [self.lhs, self.rhs]
|
||||
|
||||
def set_source_expressions(self, new_exprs):
|
||||
if len(new_exprs) == 1:
|
||||
self.lhs = new_exprs[0]
|
||||
else:
|
||||
self.lhs, self.rhs = new_exprs
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
|
||||
return self.rhs
|
||||
if hasattr(self.lhs, "output_field"):
|
||||
if hasattr(self.lhs.output_field, "get_prep_value"):
|
||||
return self.lhs.output_field.get_prep_value(self.rhs)
|
||||
elif self.rhs_is_direct_value():
|
||||
return Value(self.rhs)
|
||||
return self.rhs
|
||||
|
||||
def get_prep_lhs(self):
|
||||
if hasattr(self.lhs, "resolve_expression"):
|
||||
return self.lhs
|
||||
return Value(self.lhs)
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
return ("%s", [value])
|
||||
|
||||
def process_lhs(self, compiler, connection, lhs=None):
|
||||
lhs = lhs or self.lhs
|
||||
if hasattr(lhs, "resolve_expression"):
|
||||
lhs = lhs.resolve_expression(compiler.query)
|
||||
sql, params = compiler.compile(lhs)
|
||||
if isinstance(lhs, Lookup):
|
||||
# Wrapped in parentheses to respect operator precedence.
|
||||
sql = f"({sql})"
|
||||
return sql, params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
value = self.rhs
|
||||
if self.bilateral_transforms:
|
||||
if self.rhs_is_direct_value():
|
||||
# Do not call get_db_prep_lookup here as the value will be
|
||||
# transformed before being used for lookup
|
||||
value = Value(value, output_field=self.lhs.output_field)
|
||||
value = self.apply_bilateral_transforms(value)
|
||||
value = value.resolve_expression(compiler.query)
|
||||
if hasattr(value, "as_sql"):
|
||||
sql, params = compiler.compile(value)
|
||||
# Ensure expression is wrapped in parentheses to respect operator
|
||||
# precedence but avoid double wrapping as it can be misinterpreted
|
||||
# on some backends (e.g. subqueries on SQLite).
|
||||
if sql and sql[0] != "(":
|
||||
sql = "(%s)" % sql
|
||||
return sql, params
|
||||
else:
|
||||
return self.get_db_prep_lookup(value, connection)
|
||||
|
||||
def rhs_is_direct_value(self):
|
||||
return not hasattr(self.rhs, "as_sql")
|
||||
|
||||
def get_group_by_cols(self):
|
||||
cols = []
|
||||
for source in self.get_source_expressions():
|
||||
cols.extend(source.get_group_by_cols())
|
||||
return cols
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
# Oracle doesn't allow EXISTS() and filters to be compared to another
|
||||
# expression unless they're wrapped in a CASE WHEN.
|
||||
wrapped = False
|
||||
exprs = []
|
||||
for expr in (self.lhs, self.rhs):
|
||||
if connection.ops.conditional_expression_supported_in_where_clause(expr):
|
||||
expr = Case(When(expr, then=True), default=False)
|
||||
wrapped = True
|
||||
exprs.append(expr)
|
||||
lookup = type(self)(*exprs) if wrapped else self
|
||||
return lookup.as_sql(compiler, connection)
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return BooleanField()
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return self.__class__, self.lhs, self.rhs
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Lookup):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(make_hashable(self.identity))
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
c.lhs = self.lhs.resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
if hasattr(self.rhs, "resolve_expression"):
|
||||
c.rhs = self.rhs.resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
return c
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
# Wrap filters with a CASE WHEN expression if a database backend
|
||||
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
|
||||
# BY list.
|
||||
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
|
||||
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
|
||||
return sql, params
|
||||
|
||||
|
||||
class Transform(RegisterLookupMixin, Func):
|
||||
"""
|
||||
RegisterLookupMixin() is first so that get_lookup() and get_transform()
|
||||
first examine self and then check output_field.
|
||||
"""
|
||||
|
||||
bilateral = False
|
||||
arity = 1
|
||||
|
||||
@property
|
||||
def lhs(self):
|
||||
return self.get_source_expressions()[0]
|
||||
|
||||
def get_bilateral_transforms(self):
|
||||
if hasattr(self.lhs, "get_bilateral_transforms"):
|
||||
bilateral_transforms = self.lhs.get_bilateral_transforms()
|
||||
else:
|
||||
bilateral_transforms = []
|
||||
if self.bilateral:
|
||||
bilateral_transforms.append(self.__class__)
|
||||
return bilateral_transforms
|
||||
|
||||
|
||||
class BuiltinLookup(Lookup):
|
||||
def process_lhs(self, compiler, connection, lhs=None):
|
||||
lhs_sql, params = super().process_lhs(compiler, connection, lhs)
|
||||
field_internal_type = self.lhs.output_field.get_internal_type()
|
||||
db_type = self.lhs.output_field.db_type(connection=connection)
|
||||
lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
|
||||
lhs_sql = (
|
||||
connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
|
||||
)
|
||||
return lhs_sql, list(params)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs_sql, params = self.process_lhs(compiler, connection)
|
||||
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
|
||||
params.extend(rhs_params)
|
||||
rhs_sql = self.get_rhs_op(connection, rhs_sql)
|
||||
return "%s %s" % (lhs_sql, rhs_sql), params
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return connection.operators[self.lookup_name] % rhs
|
||||
|
||||
|
||||
class FieldGetDbPrepValueMixin:
|
||||
"""
|
||||
Some lookups require Field.get_db_prep_value() to be called on their
|
||||
inputs.
|
||||
"""
|
||||
|
||||
get_db_prep_lookup_value_is_iterable = False
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
# For relational fields, use the 'target_field' attribute of the
|
||||
# output_field.
|
||||
field = getattr(self.lhs.output_field, "target_field", None)
|
||||
get_db_prep_value = (
|
||||
getattr(field, "get_db_prep_value", None)
|
||||
or self.lhs.output_field.get_db_prep_value
|
||||
)
|
||||
return (
|
||||
"%s",
|
||||
[get_db_prep_value(v, connection, prepared=True) for v in value]
|
||||
if self.get_db_prep_lookup_value_is_iterable
|
||||
else [get_db_prep_value(value, connection, prepared=True)],
|
||||
)
|
||||
|
||||
|
||||
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
|
||||
"""
|
||||
Some lookups require Field.get_db_prep_value() to be called on each value
|
||||
in an iterable.
|
||||
"""
|
||||
|
||||
get_db_prep_lookup_value_is_iterable = True
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if hasattr(self.rhs, "resolve_expression"):
|
||||
return self.rhs
|
||||
prepared_values = []
|
||||
for rhs_value in self.rhs:
|
||||
if hasattr(rhs_value, "resolve_expression"):
|
||||
# An expression will be handled by the database but can coexist
|
||||
# alongside real values.
|
||||
pass
|
||||
elif self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):
|
||||
rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
|
||||
prepared_values.append(rhs_value)
|
||||
return prepared_values
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
if self.rhs_is_direct_value():
|
||||
# rhs should be an iterable of values. Use batch_process_rhs()
|
||||
# to prepare/transform those values.
|
||||
return self.batch_process_rhs(compiler, connection)
|
||||
else:
|
||||
return super().process_rhs(compiler, connection)
|
||||
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
params = [param]
|
||||
if hasattr(param, "resolve_expression"):
|
||||
param = param.resolve_expression(compiler.query)
|
||||
if hasattr(param, "as_sql"):
|
||||
sql, params = compiler.compile(param)
|
||||
return sql, params
|
||||
|
||||
def batch_process_rhs(self, compiler, connection, rhs=None):
|
||||
pre_processed = super().batch_process_rhs(compiler, connection, rhs)
|
||||
# The params list may contain expressions which compile to a
|
||||
# sql/param pair. Zip them to get sql and param pairs that refer to the
|
||||
# same argument and attempt to replace them with the result of
|
||||
# compiling the param step.
|
||||
sql, params = zip(
|
||||
*(
|
||||
self.resolve_expression_parameter(compiler, connection, sql, param)
|
||||
for sql, param in zip(*pre_processed)
|
||||
)
|
||||
)
|
||||
params = itertools.chain.from_iterable(params)
|
||||
return sql, tuple(params)
|
||||
|
||||
|
||||
class PostgresOperatorLookup(Lookup):
|
||||
"""Lookup defined by operators on PostgreSQL."""
|
||||
|
||||
postgres_operator = None
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(lhs_params) + tuple(rhs_params)
|
||||
return "%s %s %s" % (lhs, self.postgres_operator, rhs), params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "exact"
|
||||
|
||||
def get_prep_lookup(self):
|
||||
from django.db.models.sql.query import Query # avoid circular import
|
||||
|
||||
if isinstance(self.rhs, Query):
|
||||
if self.rhs.has_limit_one():
|
||||
if not self.rhs.has_select_fields:
|
||||
self.rhs.clear_select_clause()
|
||||
self.rhs.add_fields(["pk"])
|
||||
else:
|
||||
raise ValueError(
|
||||
"The QuerySet value for an exact lookup must be limited to "
|
||||
"one result using slicing."
|
||||
)
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Avoid comparison against direct rhs if lhs is a boolean value. That
|
||||
# turns "boolfield__exact=True" into "WHERE boolean_field" instead of
|
||||
# "WHERE boolean_field = True" when allowed.
|
||||
if (
|
||||
isinstance(self.rhs, bool)
|
||||
and getattr(self.lhs, "conditional", False)
|
||||
and connection.ops.conditional_expression_supported_in_where_clause(
|
||||
self.lhs
|
||||
)
|
||||
):
|
||||
lhs_sql, params = self.process_lhs(compiler, connection)
|
||||
template = "%s" if self.rhs else "NOT %s"
|
||||
return template % lhs_sql, params
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IExact(BuiltinLookup):
|
||||
lookup_name = "iexact"
|
||||
prepare_rhs = False
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
rhs, params = super().process_rhs(qn, connection)
|
||||
if params:
|
||||
params[0] = connection.ops.prep_for_iexact_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "gt"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "gte"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "lt"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "lte"
|
||||
|
||||
|
||||
class IntegerFieldFloatRounding:
|
||||
"""
|
||||
Allow floats to work as query values for IntegerField. Without this, the
|
||||
decimal portion of the float would always be discarded.
|
||||
"""
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if isinstance(self.rhs, float):
|
||||
self.rhs = math.ceil(self.rhs)
|
||||
return super().get_prep_lookup()
|
||||
|
||||
|
||||
@IntegerField.register_lookup
|
||||
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
@IntegerField.register_lookup
|
||||
class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
|
||||
pass
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
lookup_name = "in"
|
||||
|
||||
def get_prep_lookup(self):
|
||||
from django.db.models.sql.query import Query # avoid circular import
|
||||
|
||||
if isinstance(self.rhs, Query):
|
||||
self.rhs.clear_ordering(clear_default=True)
|
||||
if not self.rhs.has_select_fields:
|
||||
self.rhs.clear_select_clause()
|
||||
self.rhs.add_fields(["pk"])
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
db_rhs = getattr(self.rhs, "_db", None)
|
||||
if db_rhs is not None and db_rhs != connection.alias:
|
||||
raise ValueError(
|
||||
"Subqueries aren't allowed across different databases. Force "
|
||||
"the inner query to be evaluated using `list(inner_query)`."
|
||||
)
|
||||
|
||||
if self.rhs_is_direct_value():
|
||||
# Remove None from the list as NULL is never equal to anything.
|
||||
try:
|
||||
rhs = OrderedSet(self.rhs)
|
||||
rhs.discard(None)
|
||||
except TypeError: # Unhashable items in self.rhs
|
||||
rhs = [r for r in self.rhs if r is not None]
|
||||
|
||||
if not rhs:
|
||||
raise EmptyResultSet
|
||||
|
||||
# rhs should be an iterable; use batch_process_rhs() to
|
||||
# prepare/transform those values.
|
||||
sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
|
||||
placeholder = "(" + ", ".join(sqls) + ")"
|
||||
return (placeholder, sqls_params)
|
||||
return super().process_rhs(compiler, connection)
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return "IN %s" % rhs
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
max_in_list_size = connection.ops.max_in_list_size()
|
||||
if (
|
||||
self.rhs_is_direct_value()
|
||||
and max_in_list_size
|
||||
and len(self.rhs) > max_in_list_size
|
||||
):
|
||||
return self.split_parameter_list_as_sql(compiler, connection)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
def split_parameter_list_as_sql(self, compiler, connection):
|
||||
# This is a special case for databases which limit the number of
|
||||
# elements which can appear in an 'IN' clause.
|
||||
max_in_list_size = connection.ops.max_in_list_size()
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.batch_process_rhs(compiler, connection)
|
||||
in_clause_elements = ["("]
|
||||
params = []
|
||||
for offset in range(0, len(rhs_params), max_in_list_size):
|
||||
if offset > 0:
|
||||
in_clause_elements.append(" OR ")
|
||||
in_clause_elements.append("%s IN (" % lhs)
|
||||
params.extend(lhs_params)
|
||||
sqls = rhs[offset : offset + max_in_list_size]
|
||||
sqls_params = rhs_params[offset : offset + max_in_list_size]
|
||||
param_group = ", ".join(sqls)
|
||||
in_clause_elements.append(param_group)
|
||||
in_clause_elements.append(")")
|
||||
params.extend(sqls_params)
|
||||
in_clause_elements.append(")")
|
||||
return "".join(in_clause_elements), params
|
||||
|
||||
|
||||
class PatternLookup(BuiltinLookup):
|
||||
param_pattern = "%%%s%%"
|
||||
prepare_rhs = False
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
# Assume we are in startswith. We need to produce SQL like:
|
||||
# col LIKE %s, ['thevalue%']
|
||||
# For python values we can (and should) do that directly in Python,
|
||||
# but if the value is for example reference to other column, then
|
||||
# we need to add the % pattern match to the lookup by something like
|
||||
# col LIKE othercol || '%%'
|
||||
# So, for Python values we don't need any special pattern, but for
|
||||
# SQL reference values or SQL transformations we need the correct
|
||||
# pattern added.
|
||||
if hasattr(self.rhs, "as_sql") or self.bilateral_transforms:
|
||||
pattern = connection.pattern_ops[self.lookup_name].format(
|
||||
connection.pattern_esc
|
||||
)
|
||||
return pattern.format(rhs)
|
||||
else:
|
||||
return super().get_rhs_op(connection, rhs)
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
rhs, params = super().process_rhs(qn, connection)
|
||||
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
|
||||
params[0] = self.param_pattern % connection.ops.prep_for_like_query(
|
||||
params[0]
|
||||
)
|
||||
return rhs, params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class Contains(PatternLookup):
|
||||
lookup_name = "contains"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IContains(Contains):
|
||||
lookup_name = "icontains"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class StartsWith(PatternLookup):
|
||||
lookup_name = "startswith"
|
||||
param_pattern = "%s%%"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IStartsWith(StartsWith):
|
||||
lookup_name = "istartswith"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class EndsWith(PatternLookup):
|
||||
lookup_name = "endswith"
|
||||
param_pattern = "%%%s"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IEndsWith(EndsWith):
|
||||
lookup_name = "iendswith"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
lookup_name = "range"
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IsNull(BuiltinLookup):
|
||||
lookup_name = "isnull"
|
||||
prepare_rhs = False
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not isinstance(self.rhs, bool):
|
||||
raise ValueError(
|
||||
"The QuerySet value for an isnull lookup must be True or False."
|
||||
)
|
||||
if isinstance(self.lhs, Value):
|
||||
if self.lhs.value is None or (
|
||||
self.lhs.value == ""
|
||||
and connection.features.interprets_empty_strings_as_nulls
|
||||
):
|
||||
result_exception = FullResultSet if self.rhs else EmptyResultSet
|
||||
else:
|
||||
result_exception = EmptyResultSet if self.rhs else FullResultSet
|
||||
raise result_exception
|
||||
sql, params = self.process_lhs(compiler, connection)
|
||||
if self.rhs:
|
||||
return "%s IS NULL" % sql, params
|
||||
else:
|
||||
return "%s IS NOT NULL" % sql, params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class Regex(BuiltinLookup):
|
||||
lookup_name = "regex"
|
||||
prepare_rhs = False
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if self.lookup_name in connection.operators:
|
||||
return super().as_sql(compiler, connection)
|
||||
else:
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
sql_template = connection.ops.regex_lookup(self.lookup_name)
|
||||
return sql_template % (lhs, rhs), lhs_params + rhs_params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IRegex(Regex):
|
||||
lookup_name = "iregex"
|
||||
|
||||
|
||||
class YearLookup(Lookup):
|
||||
def year_lookup_bounds(self, connection, year):
|
||||
from django.db.models.functions import ExtractIsoYear
|
||||
|
||||
iso_year = isinstance(self.lhs, ExtractIsoYear)
|
||||
output_field = self.lhs.lhs.output_field
|
||||
if isinstance(output_field, DateTimeField):
|
||||
bounds = connection.ops.year_lookup_bounds_for_datetime_field(
|
||||
year,
|
||||
iso_year=iso_year,
|
||||
)
|
||||
else:
|
||||
bounds = connection.ops.year_lookup_bounds_for_date_field(
|
||||
year,
|
||||
iso_year=iso_year,
|
||||
)
|
||||
return bounds
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Avoid the extract operation if the rhs is a direct value to allow
|
||||
# indexes to be used.
|
||||
if self.rhs_is_direct_value():
|
||||
# Skip the extract part by directly using the originating field,
|
||||
# that is self.lhs.lhs.
|
||||
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
|
||||
rhs_sql, _ = self.process_rhs(compiler, connection)
|
||||
rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
|
||||
start, finish = self.year_lookup_bounds(connection, self.rhs)
|
||||
params.extend(self.get_bound_params(start, finish))
|
||||
return "%s %s" % (lhs_sql, rhs_sql), params
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
def get_direct_rhs_sql(self, connection, rhs):
|
||||
return connection.operators[self.lookup_name] % rhs
|
||||
|
||||
def get_bound_params(self, start, finish):
|
||||
raise NotImplementedError(
|
||||
"subclasses of YearLookup must provide a get_bound_params() method"
|
||||
)
|
||||
|
||||
|
||||
class YearExact(YearLookup, Exact):
|
||||
def get_direct_rhs_sql(self, connection, rhs):
|
||||
return "BETWEEN %s AND %s"
|
||||
|
||||
def get_bound_params(self, start, finish):
|
||||
return (start, finish)
|
||||
|
||||
|
||||
class YearGt(YearLookup, GreaterThan):
|
||||
def get_bound_params(self, start, finish):
|
||||
return (finish,)
|
||||
|
||||
|
||||
class YearGte(YearLookup, GreaterThanOrEqual):
|
||||
def get_bound_params(self, start, finish):
|
||||
return (start,)
|
||||
|
||||
|
||||
class YearLt(YearLookup, LessThan):
|
||||
def get_bound_params(self, start, finish):
|
||||
return (start,)
|
||||
|
||||
|
||||
class YearLte(YearLookup, LessThanOrEqual):
|
||||
def get_bound_params(self, start, finish):
|
||||
return (finish,)
|
||||
|
||||
|
||||
class UUIDTextMixin:
|
||||
"""
|
||||
Strip hyphens from a value when filtering a UUIDField on backends without
|
||||
a native datatype for UUID.
|
||||
"""
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
if not connection.features.has_native_uuid_field:
|
||||
from django.db.models.functions import Replace
|
||||
|
||||
if self.rhs_is_direct_value():
|
||||
self.rhs = Value(self.rhs)
|
||||
self.rhs = Replace(
|
||||
self.rhs, Value("-"), Value(""), output_field=CharField()
|
||||
)
|
||||
rhs, params = super().process_rhs(qn, connection)
|
||||
return rhs, params
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDIExact(UUIDTextMixin, IExact):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDContains(UUIDTextMixin, Contains):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDIContains(UUIDTextMixin, IContains):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDStartsWith(UUIDTextMixin, StartsWith):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDEndsWith(UUIDTextMixin, EndsWith):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
|
||||
pass
|
||||
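# Editor's note: a short sketch, not part of the vendored file, of registering a
# custom lookup with the machinery defined above. It mirrors the pattern the
# built-in lookups use (a Lookup subclass registered via Field.register_lookup).
from django.db.models import Field
from django.db.models.lookups import Lookup


@Field.register_lookup
class NotEqual(Lookup):
    lookup_name = "ne"

    def as_sql(self, compiler, connection):
        # Compile both sides, then join them with the backend's <> operator.
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        return "%s <> %s" % (lhs, rhs), params


# Usage on a hypothetical Author model:
#     Author.objects.filter(name__ne="anonymous")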
@@ -0,0 +1,213 @@
|
||||
import copy
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from importlib import import_module
|
||||
|
||||
from django.db import router
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
|
||||
class BaseManager:
|
||||
# To retain order, track each time a Manager instance is created.
|
||||
creation_counter = 0
|
||||
|
||||
# Set to True for the 'objects' managers that are automatically created.
|
||||
auto_created = False
|
||||
|
||||
#: If set to True the manager will be serialized into migrations and will
|
||||
#: thus be available in e.g. RunPython operations.
|
||||
use_in_migrations = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# Capture the arguments to make returning them trivial.
|
||||
obj = super().__new__(cls)
|
||||
obj._constructor_args = (args, kwargs)
|
||||
return obj
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._set_creation_counter()
|
||||
self.model = None
|
||||
self.name = None
|
||||
self._db = None
|
||||
self._hints = {}
|
||||
|
||||
def __str__(self):
|
||||
"""Return "app_label.model_label.manager_name"."""
|
||||
return "%s.%s" % (self.model._meta.label, self.name)
|
||||
|
||||
def __class_getitem__(cls, *args, **kwargs):
|
||||
return cls
|
||||
|
||||
def deconstruct(self):
|
||||
"""
|
||||
Return a 5-tuple of the form (as_manager (True), manager_class,
|
||||
queryset_class, args, kwargs).
|
||||
|
||||
Raise a ValueError if the manager is dynamically generated.
|
||||
"""
|
||||
qs_class = self._queryset_class
|
||||
if getattr(self, "_built_with_as_manager", False):
|
||||
# using MyQuerySet.as_manager()
|
||||
return (
|
||||
True, # as_manager
|
||||
None, # manager_class
|
||||
"%s.%s" % (qs_class.__module__, qs_class.__name__), # qs_class
|
||||
None, # args
|
||||
None, # kwargs
|
||||
)
|
||||
else:
|
||||
module_name = self.__module__
|
||||
name = self.__class__.__name__
|
||||
# Make sure it's actually there and not an inner class
|
||||
module = import_module(module_name)
|
||||
if not hasattr(module, name):
|
||||
raise ValueError(
|
||||
"Could not find manager %s in %s.\n"
|
||||
"Please note that you need to inherit from managers you "
|
||||
"dynamically generated with 'from_queryset()'."
|
||||
% (name, module_name)
|
||||
)
|
||||
return (
|
||||
False, # as_manager
|
||||
"%s.%s" % (module_name, name), # manager_class
|
||||
None, # qs_class
|
||||
self._constructor_args[0], # args
|
||||
self._constructor_args[1], # kwargs
|
||||
)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _get_queryset_methods(cls, queryset_class):
|
||||
def create_method(name, method):
|
||||
@wraps(method)
|
||||
def manager_method(self, *args, **kwargs):
|
||||
return getattr(self.get_queryset(), name)(*args, **kwargs)
|
||||
|
||||
return manager_method
|
||||
|
||||
new_methods = {}
|
||||
for name, method in inspect.getmembers(
|
||||
queryset_class, predicate=inspect.isfunction
|
||||
):
|
||||
# Only copy missing methods.
|
||||
if hasattr(cls, name):
|
||||
continue
|
||||
# Only copy public methods or methods with the attribute
|
||||
# queryset_only=False.
|
||||
queryset_only = getattr(method, "queryset_only", None)
|
||||
if queryset_only or (queryset_only is None and name.startswith("_")):
|
||||
continue
|
||||
# Copy the method onto the manager.
|
||||
new_methods[name] = create_method(name, method)
|
||||
return new_methods
|
||||
|
||||
@classmethod
|
||||
def from_queryset(cls, queryset_class, class_name=None):
|
||||
if class_name is None:
|
||||
class_name = "%sFrom%s" % (cls.__name__, queryset_class.__name__)
|
||||
return type(
|
||||
class_name,
|
||||
(cls,),
|
||||
{
|
||||
"_queryset_class": queryset_class,
|
||||
**cls._get_queryset_methods(queryset_class),
|
||||
},
|
||||
)
|
||||
|
||||
def contribute_to_class(self, cls, name):
|
||||
self.name = self.name or name
|
||||
self.model = cls
|
||||
|
||||
setattr(cls, name, ManagerDescriptor(self))
|
||||
|
||||
cls._meta.add_manager(self)
|
||||
|
||||
def _set_creation_counter(self):
|
||||
"""
|
||||
Set the creation counter value for this instance and increment the
|
||||
class-level copy.
|
||||
"""
|
||||
self.creation_counter = BaseManager.creation_counter
|
||||
BaseManager.creation_counter += 1
|
||||
|
||||
def db_manager(self, using=None, hints=None):
|
||||
obj = copy.copy(self)
|
||||
obj._db = using or self._db
|
||||
obj._hints = hints or self._hints
|
||||
return obj
|
||||
|
||||
@property
|
||||
def db(self):
|
||||
return self._db or router.db_for_read(self.model, **self._hints)
|
||||
|
||||
#######################
|
||||
# PROXIES TO QUERYSET #
|
||||
#######################
|
||||
|
||||
def get_queryset(self):
|
||||
"""
|
||||
Return a new QuerySet object. Subclasses can override this method to
|
||||
customize the behavior of the Manager.
|
||||
"""
|
||||
return self._queryset_class(model=self.model, using=self._db, hints=self._hints)
|
||||
|
||||
def all(self):
|
||||
# We can't proxy this method through the `QuerySet` like we do for the
|
||||
# rest of the `QuerySet` methods. This is because `QuerySet.all()`
|
||||
# works by creating a "copy" of the current queryset and in making said
|
||||
# copy, all the cached `prefetch_related` lookups are lost. See the
|
||||
# implementation of `RelatedManager.get_queryset()` for a better
|
||||
# understanding of how this comes into play.
|
||||
return self.get_queryset()
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self._constructor_args == other._constructor_args
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
|
||||
class Manager(BaseManager.from_queryset(QuerySet)):
|
||||
pass
|
||||
|
||||
|
||||
class ManagerDescriptor:
|
||||
def __init__(self, manager):
|
||||
self.manager = manager
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is not None:
|
||||
raise AttributeError(
|
||||
"Manager isn't accessible via %s instances" % cls.__name__
|
||||
)
|
||||
|
||||
if cls._meta.abstract:
|
||||
raise AttributeError(
|
||||
"Manager isn't available; %s is abstract" % (cls._meta.object_name,)
|
||||
)
|
||||
|
||||
if cls._meta.swapped:
|
||||
raise AttributeError(
|
||||
"Manager isn't available; '%s' has been swapped for '%s'"
|
||||
% (
|
||||
cls._meta.label,
|
||||
cls._meta.swapped,
|
||||
)
|
||||
)
|
||||
|
||||
return cls._meta.managers_map[self.manager.name]
|
||||
|
||||
|
||||
class EmptyManager(Manager):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().none()
|
||||
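# Editor's note: a usage sketch, not part of the vendored file, showing the
# common way the Manager/BaseManager machinery above is customized:
# from_queryset() copies public QuerySet methods onto a new manager class, so
# the same helper works both on the manager and in chained queryset calls. The
# `Book` model is hypothetical.
from django.db import models


class PublishedQuerySet(models.QuerySet):
    def published(self):
        return self.filter(is_published=True)


class Book(models.Model):
    title = models.CharField(max_length=200)
    is_published = models.BooleanField(default=False)

    # Both Book.objects.published() and
    # Book.objects.filter(title__icontains="django").published() work.
    objects = models.Manager.from_queryset(PublishedQuerySet)()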
1014 lines  srcs/.venv/lib/python3.11/site-packages/django/db/models/options.py (new file, diff suppressed because it is too large)
2631 lines  srcs/.venv/lib/python3.11/site-packages/django/db/models/query.py (new file, diff suppressed because it is too large)
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
Various data structures used in query construction.
|
||||
|
||||
Factored out from django.db.models.query to avoid making the main module very
|
||||
large and/or so that they can be used by other modules without getting into
|
||||
circular import difficulties.
|
||||
"""
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.utils import tree
|
||||
|
||||
logger = logging.getLogger("django.db.models")
|
||||
|
||||
# PathInfo is used when converting lookups (fk__somecol). The contents
|
||||
# describe the relation in Model terms (model Options and Fields for both
|
||||
# sides of the relation. The join_field is the field backing the relation.
|
||||
PathInfo = namedtuple(
|
||||
"PathInfo",
|
||||
"from_opts to_opts target_fields join_field m2m direct filtered_relation",
|
||||
)
|
||||
|
||||
|
||||
def subclasses(cls):
|
||||
yield cls
|
||||
for subclass in cls.__subclasses__():
|
||||
yield from subclasses(subclass)
|
||||
|
||||
|
||||
class Q(tree.Node):
|
||||
"""
|
||||
Encapsulate filters as objects that can then be combined logically (using
|
||||
`&` and `|`).
|
||||
"""
|
||||
|
||||
# Connection types
|
||||
AND = "AND"
|
||||
OR = "OR"
|
||||
XOR = "XOR"
|
||||
default = AND
|
||||
conditional = True
|
||||
|
||||
def __init__(self, *args, _connector=None, _negated=False, **kwargs):
|
||||
super().__init__(
|
||||
children=[*args, *sorted(kwargs.items())],
|
||||
connector=_connector,
|
||||
negated=_negated,
|
||||
)
|
||||
|
||||
def _combine(self, other, conn):
|
||||
if getattr(other, "conditional", False) is False:
|
||||
raise TypeError(other)
|
||||
if not self:
|
||||
return other.copy()
|
||||
if not other and isinstance(other, Q):
|
||||
return self.copy()
|
||||
|
||||
obj = self.create(connector=conn)
|
||||
obj.add(self, conn)
|
||||
obj.add(other, conn)
|
||||
return obj
|
||||
|
||||
def __or__(self, other):
|
||||
return self._combine(other, self.OR)
|
||||
|
||||
def __and__(self, other):
|
||||
return self._combine(other, self.AND)
|
||||
|
||||
def __xor__(self, other):
|
||||
return self._combine(other, self.XOR)
|
||||
|
||||
def __invert__(self):
|
||||
obj = self.copy()
|
||||
obj.negate()
|
||||
return obj
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
# We must promote any new joins to left outer joins so that when Q is
|
||||
# used as an expression, rows aren't filtered due to joins.
|
||||
clause, joins = query._add_q(
|
||||
self,
|
||||
reuse,
|
||||
allow_joins=allow_joins,
|
||||
split_subq=False,
|
||||
check_filterable=False,
|
||||
summarize=summarize,
|
||||
)
|
||||
query.promote_joins(joins)
|
||||
return clause
|
||||
|
||||
def flatten(self):
|
||||
"""
|
||||
Recursively yield this Q object and all subexpressions, in depth-first
|
||||
order.
|
||||
"""
|
||||
yield self
|
||||
for child in self.children:
|
||||
if isinstance(child, tuple):
|
||||
# Use the lookup.
|
||||
child = child[1]
|
||||
if hasattr(child, "flatten"):
|
||||
yield from child.flatten()
|
||||
else:
|
||||
yield child
|
||||
|
||||
def check(self, against, using=DEFAULT_DB_ALIAS):
|
||||
"""
|
||||
Do a database query to check if the expressions of the Q instance
|
||||
matches against the expressions.
|
||||
"""
|
||||
# Avoid circular imports.
|
||||
from django.db.models import BooleanField, Value
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.db.models.sql import Query
|
||||
from django.db.models.sql.constants import SINGLE
|
||||
|
||||
query = Query(None)
|
||||
for name, value in against.items():
|
||||
if not hasattr(value, "resolve_expression"):
|
||||
value = Value(value)
|
||||
query.add_annotation(value, name, select=False)
|
||||
query.add_annotation(Value(1), "_check")
|
||||
# This will raise a FieldError if a field is missing in "against".
|
||||
if connections[using].features.supports_comparing_boolean_expr:
|
||||
query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
|
||||
else:
|
||||
query.add_q(self)
|
||||
compiler = query.get_compiler(using=using)
|
||||
try:
|
||||
return compiler.execute_sql(SINGLE) is not None
|
||||
except DatabaseError as e:
|
||||
logger.warning("Got a database error calling check() on %r: %s", self, e)
|
||||
return True
|
||||
|
||||
def deconstruct(self):
|
||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||
if path.startswith("django.db.models.query_utils"):
|
||||
path = path.replace("django.db.models.query_utils", "django.db.models")
|
||||
args = tuple(self.children)
|
||||
kwargs = {}
|
||||
if self.connector != self.default:
|
||||
kwargs["_connector"] = self.connector
|
||||
if self.negated:
|
||||
kwargs["_negated"] = True
|
||||
return path, args, kwargs
|
||||
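# Editor's note: an illustrative sketch, not part of the vendored file, of how Q
# nodes built by the class above combine with &, | and ~. The `Entry` model with
# `status` and `created` fields is hypothetical.
from django.db.models import Q

published_or_recent = Q(status="published") | Q(created__year=2023)
not_draft = ~Q(status="draft")

# Combined into a single WHERE clause, roughly:
#     (status = 'published' OR EXTRACT(year FROM created) = 2023)
#     AND NOT (status = 'draft')
entries = Entry.objects.filter(published_or_recent & not_draft)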
|
||||
|
||||
class DeferredAttribute:
|
||||
"""
|
||||
A wrapper for a deferred-loading field. When the value is read from this
|
||||
object the first time, the query is executed.
|
||||
"""
|
||||
|
||||
def __init__(self, field):
|
||||
self.field = field
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
"""
|
||||
Retrieve and caches the value from the datastore on the first lookup.
|
||||
Return the cached value.
|
||||
"""
|
||||
if instance is None:
|
||||
return self
|
||||
data = instance.__dict__
|
||||
field_name = self.field.attname
|
||||
if field_name not in data:
|
||||
# Let's see if the field is part of the parent chain. If so we
|
||||
# might be able to reuse the already loaded value. Refs #18343.
|
||||
val = self._check_parent_chain(instance)
|
||||
if val is None:
|
||||
instance.refresh_from_db(fields=[field_name])
|
||||
else:
|
||||
data[field_name] = val
|
||||
return data[field_name]
|
||||
|
||||
def _check_parent_chain(self, instance):
|
||||
"""
|
||||
Check if the field value can be fetched from a parent field already
|
||||
loaded in the instance. This can be done if the to-be fetched
|
||||
field is a primary key field.
|
||||
"""
|
||||
opts = instance._meta
|
||||
link_field = opts.get_ancestor_link(self.field.model)
|
||||
if self.field.primary_key and self.field != link_field:
|
||||
return getattr(instance, link_field.attname)
|
||||
return None
|
||||
|
||||
|
||||
class class_or_instance_method:
|
||||
"""
|
||||
Hook used in RegisterLookupMixin to return partial functions depending on
|
||||
the caller type (instance or class of models.Field).
|
||||
"""
|
||||
|
||||
def __init__(self, class_method, instance_method):
|
||||
self.class_method = class_method
|
||||
self.instance_method = instance_method
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
return functools.partial(self.class_method, owner)
|
||||
return functools.partial(self.instance_method, instance)
|
||||
|
||||
|
||||
class RegisterLookupMixin:
|
||||
def _get_lookup(self, lookup_name):
|
||||
return self.get_lookups().get(lookup_name, None)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_class_lookups(cls):
|
||||
class_lookups = [
|
||||
parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
|
||||
]
|
||||
return cls.merge_dicts(class_lookups)
|
||||
|
||||
def get_instance_lookups(self):
|
||||
class_lookups = self.get_class_lookups()
|
||||
if instance_lookups := getattr(self, "instance_lookups", None):
|
||||
return {**class_lookups, **instance_lookups}
|
||||
return class_lookups
|
||||
|
||||
get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
|
||||
get_class_lookups = classmethod(get_class_lookups)
|
||||
|
||||
def get_lookup(self, lookup_name):
|
||||
from django.db.models.lookups import Lookup
|
||||
|
||||
found = self._get_lookup(lookup_name)
|
||||
if found is None and hasattr(self, "output_field"):
|
||||
return self.output_field.get_lookup(lookup_name)
|
||||
if found is not None and not issubclass(found, Lookup):
|
||||
return None
|
||||
return found
|
||||
|
||||
def get_transform(self, lookup_name):
|
||||
from django.db.models.lookups import Transform
|
||||
|
||||
found = self._get_lookup(lookup_name)
|
||||
if found is None and hasattr(self, "output_field"):
|
||||
return self.output_field.get_transform(lookup_name)
|
||||
if found is not None and not issubclass(found, Transform):
|
||||
return None
|
||||
return found
|
||||
|
||||
@staticmethod
|
||||
def merge_dicts(dicts):
|
||||
"""
|
||||
Merge dicts in reverse so that earlier dicts take precedence over later ones, e.g.,
|
||||
merge_dicts([a, b]) will prefer the keys in 'a' over those in 'b'.
|
||||
"""
|
||||
merged = {}
|
||||
for d in reversed(dicts):
|
||||
merged.update(d)
|
||||
return merged
|
||||
|
||||
@classmethod
|
||||
def _clear_cached_class_lookups(cls):
|
||||
for subclass in subclasses(cls):
|
||||
subclass.get_class_lookups.cache_clear()
|
||||
|
||||
def register_class_lookup(cls, lookup, lookup_name=None):
|
||||
if lookup_name is None:
|
||||
lookup_name = lookup.lookup_name
|
||||
if "class_lookups" not in cls.__dict__:
|
||||
cls.class_lookups = {}
|
||||
cls.class_lookups[lookup_name] = lookup
|
||||
cls._clear_cached_class_lookups()
|
||||
return lookup
|
||||
|
||||
def register_instance_lookup(self, lookup, lookup_name=None):
|
||||
if lookup_name is None:
|
||||
lookup_name = lookup.lookup_name
|
||||
if "instance_lookups" not in self.__dict__:
|
||||
self.instance_lookups = {}
|
||||
self.instance_lookups[lookup_name] = lookup
|
||||
return lookup
|
||||
|
||||
register_lookup = class_or_instance_method(
|
||||
register_class_lookup, register_instance_lookup
|
||||
)
|
||||
register_class_lookup = classmethod(register_class_lookup)
|
||||
|
||||
def _unregister_class_lookup(cls, lookup, lookup_name=None):
|
||||
"""
|
||||
Remove given lookup from cls lookups. For use in tests only as it's
|
||||
not thread-safe.
|
||||
"""
|
||||
if lookup_name is None:
|
||||
lookup_name = lookup.lookup_name
|
||||
del cls.class_lookups[lookup_name]
|
||||
cls._clear_cached_class_lookups()
|
||||
|
||||
def _unregister_instance_lookup(self, lookup, lookup_name=None):
|
||||
"""
|
||||
Remove given lookup from instance lookups. For use in tests only as
|
||||
it's not thread-safe.
|
||||
"""
|
||||
if lookup_name is None:
|
||||
lookup_name = lookup.lookup_name
|
||||
del self.instance_lookups[lookup_name]
|
||||
|
||||
_unregister_lookup = class_or_instance_method(
|
||||
_unregister_class_lookup, _unregister_instance_lookup
|
||||
)
|
||||
_unregister_class_lookup = classmethod(_unregister_class_lookup)
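A short sketch of how the registration hooks above are typically used; NotEqual is a hypothetical custom lookup in the style of the Django documentation:

from django.db.models import Field
from django.db.models.lookups import Lookup

class NotEqual(Lookup):
    lookup_name = "ne"

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        return "%s <> %s" % (lhs, rhs), lhs_params + rhs_params

# Class-level registration: the lookup becomes available on every Field.
Field.register_lookup(NotEqual)
# Author.objects.filter(name__ne="django") would now compile to "name <> %s".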
|
||||
|
||||
|
||||
def select_related_descend(field, restricted, requested, select_mask, reverse=False):
|
||||
"""
|
||||
Return True if this field should be used to descend deeper for
|
||||
select_related() purposes. Used by both the query construction code
|
||||
(compiler.get_related_selections()) and the model instance creation code
|
||||
(compiler.klass_info).
|
||||
|
||||
Arguments:
|
||||
* field - the field to be checked
|
||||
* restricted - a boolean indicating whether the field list has been
|
||||
manually restricted using a requested clause
|
||||
* requested - The select_related() dictionary.
|
||||
* select_mask - the dictionary of selected fields.
|
||||
* reverse - boolean, True if we are checking a reverse select related
|
||||
"""
|
||||
if not field.remote_field:
|
||||
return False
|
||||
if field.remote_field.parent_link and not reverse:
|
||||
return False
|
||||
if restricted:
|
||||
if reverse and field.related_query_name() not in requested:
|
||||
return False
|
||||
if not reverse and field.name not in requested:
|
||||
return False
|
||||
if not restricted and field.null:
|
||||
return False
|
||||
if (
|
||||
restricted
|
||||
and select_mask
|
||||
and field.name in requested
|
||||
and field not in select_mask
|
||||
):
|
||||
raise FieldError(
|
||||
f"Field {field.model._meta.object_name}.{field.name} cannot be both "
|
||||
"deferred and traversed using select_related at the same time."
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def refs_expression(lookup_parts, annotations):
|
||||
"""
|
||||
Check if the lookup_parts contains references to the given annotations set.
|
||||
Because the LOOKUP_SEP is contained in the default annotation names, check
|
||||
each prefix of the lookup_parts for a match.
|
||||
"""
|
||||
for n in range(1, len(lookup_parts) + 1):
|
||||
level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
|
||||
if annotations.get(level_n_lookup):
|
||||
return level_n_lookup, lookup_parts[n:]
|
||||
return None, ()
|
||||
|
||||
|
||||
def check_rel_lookup_compatibility(model, target_opts, field):
|
||||
"""
|
||||
Check that self.model is compatible with target_opts. Compatibility
|
||||
is OK if:
|
||||
1) model and opts match (where proxy inheritance is removed)
|
||||
2) model is parent of opts' model or the other way around
|
||||
"""
|
||||
|
||||
def check(opts):
|
||||
return (
|
||||
model._meta.concrete_model == opts.concrete_model
|
||||
or opts.concrete_model in model._meta.get_parent_list()
|
||||
or model in opts.get_parent_list()
|
||||
)
|
||||
|
||||
# If the field is a primary key, then doing a query against the field's
|
||||
# model is ok, too. Consider the case:
|
||||
# class Restaurant(models.Model):
|
||||
# place = OneToOneField(Place, primary_key=True):
|
||||
# Restaurant.objects.filter(pk__in=Restaurant.objects.all()).
|
||||
# If we didn't have the primary key check, then pk__in (== place__in) would
|
||||
# give Place's opts as the target opts, but Restaurant isn't compatible
|
||||
# with that. This logic applies only to primary keys, as when doing __in=qs,
|
||||
# we are going to turn this into __in=qs.values('pk') later on.
|
||||
return check(target_opts) or (
|
||||
getattr(field, "primary_key", False) and check(field.model._meta)
|
||||
)
|
||||
|
||||
|
||||
class FilteredRelation:
|
||||
"""Specify custom filtering in the ON clause of SQL joins."""
|
||||
|
||||
def __init__(self, relation_name, *, condition=Q()):
|
||||
if not relation_name:
|
||||
raise ValueError("relation_name cannot be empty.")
|
||||
self.relation_name = relation_name
|
||||
self.alias = None
|
||||
if not isinstance(condition, Q):
|
||||
raise ValueError("condition argument must be a Q() instance.")
|
||||
self.condition = condition
|
||||
self.path = []
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.relation_name == other.relation_name
|
||||
and self.alias == other.alias
|
||||
and self.condition == other.condition
|
||||
)
|
||||
|
||||
def clone(self):
|
||||
clone = FilteredRelation(self.relation_name, condition=self.condition)
|
||||
clone.alias = self.alias
|
||||
clone.path = self.path[:]
|
||||
return clone
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
"""
|
||||
QuerySet.annotate() only accepts expression-like arguments
|
||||
(with a resolve_expression() method).
|
||||
"""
|
||||
raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Resolve the condition in Join.filtered_relation.
|
||||
query = compiler.query
|
||||
where = query.build_filtered_relation_q(self.condition, reuse=set(self.path))
|
||||
return compiler.compile(where)
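A hedged usage sketch for FilteredRelation (model and field names are hypothetical; the pattern follows the public annotate() API):

from django.db.models import FilteredRelation, Q

# Filter the JOIN itself (the ON clause) instead of the whole queryset.
restaurants = Restaurant.objects.annotate(
    vegetarian_pizzas=FilteredRelation(
        "pizzas", condition=Q(pizzas__vegetarian=True)
    ),
).filter(vegetarian_pizzas__name__icontains="mozzarella")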
|
||||
@@ -0,0 +1,54 @@
|
||||
from functools import partial
|
||||
|
||||
from django.db.models.utils import make_model_tuple
|
||||
from django.dispatch import Signal
|
||||
|
||||
class_prepared = Signal()
|
||||
|
||||
|
||||
class ModelSignal(Signal):
|
||||
"""
|
||||
Signal subclass that allows the sender to be lazily specified as a string
|
||||
of the `app_label.ModelName` form.
|
||||
"""
|
||||
|
||||
def _lazy_method(self, method, apps, receiver, sender, **kwargs):
|
||||
from django.db.models.options import Options
|
||||
|
||||
# This partial takes a single optional argument named "sender".
|
||||
partial_method = partial(method, receiver, **kwargs)
|
||||
if isinstance(sender, str):
|
||||
apps = apps or Options.default_apps
|
||||
apps.lazy_model_operation(partial_method, make_model_tuple(sender))
|
||||
else:
|
||||
return partial_method(sender)
|
||||
|
||||
def connect(self, receiver, sender=None, weak=True, dispatch_uid=None, apps=None):
|
||||
self._lazy_method(
|
||||
super().connect,
|
||||
apps,
|
||||
receiver,
|
||||
sender,
|
||||
weak=weak,
|
||||
dispatch_uid=dispatch_uid,
|
||||
)
|
||||
|
||||
def disconnect(self, receiver=None, sender=None, dispatch_uid=None, apps=None):
|
||||
return self._lazy_method(
|
||||
super().disconnect, apps, receiver, sender, dispatch_uid=dispatch_uid
|
||||
)
|
||||
|
||||
|
||||
pre_init = ModelSignal(use_caching=True)
|
||||
post_init = ModelSignal(use_caching=True)
|
||||
|
||||
pre_save = ModelSignal(use_caching=True)
|
||||
post_save = ModelSignal(use_caching=True)
|
||||
|
||||
pre_delete = ModelSignal(use_caching=True)
|
||||
post_delete = ModelSignal(use_caching=True)
|
||||
|
||||
m2m_changed = ModelSignal(use_caching=True)
|
||||
|
||||
pre_migrate = Signal()
|
||||
post_migrate = Signal()
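A small sketch of the lazy string sender that ModelSignal adds on top of Signal (the receiver and app label are hypothetical):

from django.db.models.signals import post_save

def mark_dirty(sender, instance, created, **kwargs):
    ...  # hypothetical receiver, runs after every save of the given model

# The sender may be given as "app_label.ModelName"; the connection is deferred
# through apps.lazy_model_operation() until that model class is prepared.
post_save.connect(mark_dirty, sender="blog.Entry", dispatch_uid="mark_dirty")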
|
||||
@@ -0,0 +1,6 @@
|
||||
from django.db.models.sql.query import * # NOQA
|
||||
from django.db.models.sql.query import Query
|
||||
from django.db.models.sql.subqueries import * # NOQA
|
||||
from django.db.models.sql.where import AND, OR, XOR
|
||||
|
||||
__all__ = ["Query", "AND", "OR", "XOR"]
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Constants specific to the SQL storage portion of the ORM.
|
||||
"""
|
||||
|
||||
# Size of each "chunk" for get_iterator calls.
|
||||
# Larger values are slightly faster at the expense of more storage space.
|
||||
GET_ITERATOR_CHUNK_SIZE = 100
|
||||
|
||||
# Namedtuples for sql.* internal use.
|
||||
|
||||
# How many results to expect from a cursor.execute call
|
||||
MULTI = "multi"
|
||||
SINGLE = "single"
|
||||
CURSOR = "cursor"
|
||||
NO_RESULTS = "no results"
|
||||
|
||||
ORDER_DIR = {
|
||||
"ASC": ("ASC", "DESC"),
|
||||
"DESC": ("DESC", "ASC"),
|
||||
}
|
||||
|
||||
# SQL join types.
|
||||
INNER = "INNER JOIN"
|
||||
LOUTER = "LEFT OUTER JOIN"
|
||||
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Useful auxiliary data structures for query construction. Not useful outside
|
||||
the SQL domain.
|
||||
"""
|
||||
from django.core.exceptions import FullResultSet
|
||||
from django.db.models.sql.constants import INNER, LOUTER
|
||||
|
||||
|
||||
class MultiJoin(Exception):
|
||||
"""
|
||||
Used by join construction code to indicate the point at which a
|
||||
multi-valued join was attempted (if the caller wants to treat that
|
||||
exceptionally).
|
||||
"""
|
||||
|
||||
def __init__(self, names_pos, path_with_names):
|
||||
self.level = names_pos
|
||||
# The path travelled; this includes the path to the multijoin.
|
||||
self.names_with_path = path_with_names
|
||||
|
||||
|
||||
class Empty:
|
||||
pass
|
||||
|
||||
|
||||
class Join:
|
||||
"""
|
||||
Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
|
||||
FROM entry. For example, the SQL generated could be
|
||||
LEFT OUTER JOIN "sometable" T1
|
||||
ON ("othertable"."sometable_id" = "sometable"."id")
|
||||
|
||||
This class is primarily used in Query.alias_map. All entries in alias_map
|
||||
must be Join compatible by providing the following attributes and methods:
|
||||
- table_name (string)
|
||||
- table_alias (possible alias for the table, can be None)
|
||||
- join_type (can be None for those entries that aren't joined from
|
||||
anything)
|
||||
- parent_alias (which table is this join's parent, can be None similarly
|
||||
to join_type)
|
||||
- as_sql()
|
||||
- relabeled_clone()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
table_name,
|
||||
parent_alias,
|
||||
table_alias,
|
||||
join_type,
|
||||
join_field,
|
||||
nullable,
|
||||
filtered_relation=None,
|
||||
):
|
||||
# Join table
|
||||
self.table_name = table_name
|
||||
self.parent_alias = parent_alias
|
||||
# Note: table_alias is not necessarily known at instantiation time.
|
||||
self.table_alias = table_alias
|
||||
# LOUTER or INNER
|
||||
self.join_type = join_type
|
||||
# A list of 2-tuples to use in the ON clause of the JOIN.
|
||||
# Each 2-tuple will create one join condition in the ON clause.
|
||||
self.join_cols = join_field.get_joining_columns()
|
||||
# Along which field (or ForeignObjectRel in the reverse join case)
|
||||
self.join_field = join_field
|
||||
# Is this join nullable?
|
||||
self.nullable = nullable
|
||||
self.filtered_relation = filtered_relation
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
"""
|
||||
Generate the full
|
||||
LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params
|
||||
clause for this join.
|
||||
"""
|
||||
join_conditions = []
|
||||
params = []
|
||||
qn = compiler.quote_name_unless_alias
|
||||
qn2 = connection.ops.quote_name
|
||||
|
||||
# Add a join condition for each pair of joining columns.
|
||||
for lhs_col, rhs_col in self.join_cols:
|
||||
join_conditions.append(
|
||||
"%s.%s = %s.%s"
|
||||
% (
|
||||
qn(self.parent_alias),
|
||||
qn2(lhs_col),
|
||||
qn(self.table_alias),
|
||||
qn2(rhs_col),
|
||||
)
|
||||
)
|
||||
|
||||
# Add a single condition inside parentheses for whatever
|
||||
# get_extra_restriction() returns.
|
||||
extra_cond = self.join_field.get_extra_restriction(
|
||||
self.table_alias, self.parent_alias
|
||||
)
|
||||
if extra_cond:
|
||||
extra_sql, extra_params = compiler.compile(extra_cond)
|
||||
join_conditions.append("(%s)" % extra_sql)
|
||||
params.extend(extra_params)
|
||||
if self.filtered_relation:
|
||||
try:
|
||||
extra_sql, extra_params = compiler.compile(self.filtered_relation)
|
||||
except FullResultSet:
|
||||
pass
|
||||
else:
|
||||
join_conditions.append("(%s)" % extra_sql)
|
||||
params.extend(extra_params)
|
||||
if not join_conditions:
|
||||
# This might be a rel on the other end of an actual declared field.
|
||||
declared_field = getattr(self.join_field, "field", self.join_field)
|
||||
raise ValueError(
|
||||
"Join generated an empty ON clause. %s did not yield either "
|
||||
"joining columns or extra restrictions." % declared_field.__class__
|
||||
)
|
||||
on_clause_sql = " AND ".join(join_conditions)
|
||||
alias_str = (
|
||||
"" if self.table_alias == self.table_name else (" %s" % self.table_alias)
|
||||
)
|
||||
sql = "%s %s%s ON (%s)" % (
|
||||
self.join_type,
|
||||
qn(self.table_name),
|
||||
alias_str,
|
||||
on_clause_sql,
|
||||
)
|
||||
return sql, params
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
|
||||
new_table_alias = change_map.get(self.table_alias, self.table_alias)
|
||||
if self.filtered_relation is not None:
|
||||
filtered_relation = self.filtered_relation.clone()
|
||||
filtered_relation.path = [
|
||||
change_map.get(p, p) for p in self.filtered_relation.path
|
||||
]
|
||||
else:
|
||||
filtered_relation = None
|
||||
return self.__class__(
|
||||
self.table_name,
|
||||
new_parent_alias,
|
||||
new_table_alias,
|
||||
self.join_type,
|
||||
self.join_field,
|
||||
self.nullable,
|
||||
filtered_relation=filtered_relation,
|
||||
)
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return (
|
||||
self.__class__,
|
||||
self.table_name,
|
||||
self.parent_alias,
|
||||
self.join_field,
|
||||
self.filtered_relation,
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Join):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def equals(self, other):
|
||||
# Ignore filtered_relation in equality check.
|
||||
return self.identity[:-1] == other.identity[:-1]
|
||||
|
||||
def demote(self):
|
||||
new = self.relabeled_clone({})
|
||||
new.join_type = INNER
|
||||
return new
|
||||
|
||||
def promote(self):
|
||||
new = self.relabeled_clone({})
|
||||
new.join_type = LOUTER
|
||||
return new
|
||||
|
||||
|
||||
class BaseTable:
|
||||
"""
|
||||
The BaseTable class is used for base table references in FROM clause. For
|
||||
example, the SQL "foo" in
|
||||
SELECT * FROM "foo" WHERE somecond
|
||||
could be generated by this class.
|
||||
"""
|
||||
|
||||
join_type = None
|
||||
parent_alias = None
|
||||
filtered_relation = None
|
||||
|
||||
def __init__(self, table_name, alias):
|
||||
self.table_name = table_name
|
||||
self.table_alias = alias
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
alias_str = (
|
||||
"" if self.table_alias == self.table_name else (" %s" % self.table_alias)
|
||||
)
|
||||
base_sql = compiler.quote_name_unless_alias(self.table_name)
|
||||
return base_sql + alias_str, []
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
return self.__class__(
|
||||
self.table_name, change_map.get(self.table_alias, self.table_alias)
|
||||
)
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return self.__class__, self.table_name, self.table_alias
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, BaseTable):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def equals(self, other):
|
||||
return self.identity == other.identity
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Query subclasses which provide extra functionality beyond simple data retrieval.
|
||||
"""
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
|
||||
from django.db.models.sql.query import Query
|
||||
|
||||
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
|
||||
|
||||
|
||||
class DeleteQuery(Query):
|
||||
"""A DELETE SQL query."""
|
||||
|
||||
compiler = "SQLDeleteCompiler"
|
||||
|
||||
def do_query(self, table, where, using):
|
||||
self.alias_map = {table: self.alias_map[table]}
|
||||
self.where = where
|
||||
cursor = self.get_compiler(using).execute_sql(CURSOR)
|
||||
if cursor:
|
||||
with cursor:
|
||||
return cursor.rowcount
|
||||
return 0
|
||||
|
||||
def delete_batch(self, pk_list, using):
|
||||
"""
|
||||
Set up and execute delete queries for all the objects in pk_list.
|
||||
|
||||
More than one physical query may be executed if there are a
|
||||
lot of values in pk_list.
|
||||
"""
|
||||
# number of objects deleted
|
||||
num_deleted = 0
|
||||
field = self.get_meta().pk
|
||||
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
|
||||
self.clear_where()
|
||||
self.add_filter(
|
||||
f"{field.attname}__in",
|
||||
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
|
||||
)
|
||||
num_deleted += self.do_query(
|
||||
self.get_meta().db_table, self.where, using=using
|
||||
)
|
||||
return num_deleted
|
||||
|
||||
|
||||
class UpdateQuery(Query):
|
||||
"""An UPDATE SQL query."""
|
||||
|
||||
compiler = "SQLUpdateCompiler"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._setup_query()
|
||||
|
||||
def _setup_query(self):
|
||||
"""
|
||||
Run on initialization and at the end of chaining. Any attributes that
|
||||
would normally be set in __init__() should go here instead.
|
||||
"""
|
||||
self.values = []
|
||||
self.related_ids = None
|
||||
self.related_updates = {}
|
||||
|
||||
def clone(self):
|
||||
obj = super().clone()
|
||||
obj.related_updates = self.related_updates.copy()
|
||||
return obj
|
||||
|
||||
def update_batch(self, pk_list, values, using):
|
||||
self.add_update_values(values)
|
||||
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
|
||||
self.clear_where()
|
||||
self.add_filter(
|
||||
"pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
|
||||
)
|
||||
self.get_compiler(using).execute_sql(NO_RESULTS)
|
||||
|
||||
def add_update_values(self, values):
|
||||
"""
|
||||
Convert a dictionary of field name to value mappings into an update
|
||||
query. This is the entry point for the public update() method on
|
||||
querysets.
|
||||
"""
|
||||
values_seq = []
|
||||
for name, val in values.items():
|
||||
field = self.get_meta().get_field(name)
|
||||
direct = (
|
||||
not (field.auto_created and not field.concrete) or not field.concrete
|
||||
)
|
||||
model = field.model._meta.concrete_model
|
||||
if not direct or (field.is_relation and field.many_to_many):
|
||||
raise FieldError(
|
||||
"Cannot update model field %r (only non-relations and "
|
||||
"foreign keys permitted)." % field
|
||||
)
|
||||
if model is not self.get_meta().concrete_model:
|
||||
self.add_related_update(model, field, val)
|
||||
continue
|
||||
values_seq.append((field, model, val))
|
||||
return self.add_update_fields(values_seq)
|
||||
|
||||
def add_update_fields(self, values_seq):
|
||||
"""
|
||||
Append a sequence of (field, model, value) triples to the internal list
|
||||
that will be used to generate the UPDATE query. Might be more usefully
|
||||
called add_update_targets() to hint at the extra information here.
|
||||
"""
|
||||
for field, model, val in values_seq:
|
||||
if hasattr(val, "resolve_expression"):
|
||||
# Resolve expressions here so that annotations are no longer needed
|
||||
val = val.resolve_expression(self, allow_joins=False, for_save=True)
|
||||
self.values.append((field, model, val))
|
||||
|
||||
def add_related_update(self, model, field, value):
|
||||
"""
|
||||
Add (name, value) to an update query for an ancestor model.
|
||||
|
||||
Updates are coalesced so that only one update query per ancestor is run.
|
||||
"""
|
||||
self.related_updates.setdefault(model, []).append((field, None, value))
|
||||
|
||||
def get_related_updates(self):
|
||||
"""
|
||||
Return a list of query objects: one for each update required to an
|
||||
ancestor model. Each query will have the same filtering conditions as
|
||||
the current query but will only update a single table.
|
||||
"""
|
||||
if not self.related_updates:
|
||||
return []
|
||||
result = []
|
||||
for model, values in self.related_updates.items():
|
||||
query = UpdateQuery(model)
|
||||
query.values = values
|
||||
if self.related_ids is not None:
|
||||
query.add_filter("pk__in", self.related_ids[model])
|
||||
result.append(query)
|
||||
return result
|
||||
|
||||
|
||||
class InsertQuery(Query):
|
||||
compiler = "SQLInsertCompiler"
|
||||
|
||||
def __init__(
|
||||
self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields = []
|
||||
self.objs = []
|
||||
self.on_conflict = on_conflict
|
||||
self.update_fields = update_fields or []
|
||||
self.unique_fields = unique_fields or []
|
||||
|
||||
def insert_values(self, fields, objs, raw=False):
|
||||
self.fields = fields
|
||||
self.objs = objs
|
||||
self.raw = raw
|
||||
|
||||
|
||||
class AggregateQuery(Query):
|
||||
"""
|
||||
Take another query as a parameter to the FROM clause and only select the
|
||||
elements in the provided list.
|
||||
"""
|
||||
|
||||
compiler = "SQLAggregateCompiler"
|
||||
|
||||
def __init__(self, model, inner_query):
|
||||
self.inner_query = inner_query
|
||||
super().__init__(model)
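The *_batch() helpers above slice the primary-key list by GET_ITERATOR_CHUNK_SIZE; a minimal standalone sketch of that slicing (not Django API):

GET_ITERATOR_CHUNK_SIZE = 100  # mirrors sql/constants.py

def pk_chunks(pk_list, size=GET_ITERATOR_CHUNK_SIZE):
    # delete_batch()/update_batch() issue one query per slice like this,
    # e.g. DELETE ... WHERE pk IN (<at most `size` values>).
    for offset in range(0, len(pk_list), size):
        yield pk_list[offset : offset + size]

assert len(list(pk_chunks(list(range(250))))) == 3  # chunks of 100, 100, 50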
|
||||
@@ -0,0 +1,355 @@
|
||||
"""
|
||||
Code to manage the creation and SQL rendering of 'where' constraints.
|
||||
"""
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
from django.core.exceptions import EmptyResultSet, FullResultSet
|
||||
from django.db.models.expressions import Case, When
|
||||
from django.db.models.lookups import Exact
|
||||
from django.utils import tree
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
# Connection types
|
||||
AND = "AND"
|
||||
OR = "OR"
|
||||
XOR = "XOR"
|
||||
|
||||
|
||||
class WhereNode(tree.Node):
|
||||
"""
|
||||
An SQL WHERE clause.
|
||||
|
||||
The class is tied to the Query class that created it (in order to create
|
||||
the correct SQL).
|
||||
|
||||
A child is usually an expression producing boolean values. Most likely the
|
||||
expression is a Lookup instance.
|
||||
|
||||
However, a child could also be any class with an as_sql() method and
|
||||
either a relabeled_clone() method or relabel_aliases() and clone() methods,
|
||||
plus a contains_aggregate attribute.
|
||||
"""
|
||||
|
||||
default = AND
|
||||
resolved = False
|
||||
conditional = True
|
||||
|
||||
def split_having_qualify(self, negated=False, must_group_by=False):
|
||||
"""
|
||||
Return three possibly None nodes: one for those parts of self that
|
||||
should be included in the WHERE clause, one for those parts of self
|
||||
that must be included in the HAVING clause, and one for those parts
|
||||
that refer to window functions.
|
||||
"""
|
||||
if not self.contains_aggregate and not self.contains_over_clause:
|
||||
return self, None, None
|
||||
in_negated = negated ^ self.negated
|
||||
# Whether or not children must be connected in the same filtering
|
||||
# clause (WHERE > HAVING > QUALIFY) to maintain logical semantics.
|
||||
must_remain_connected = (
|
||||
(in_negated and self.connector == AND)
|
||||
or (not in_negated and self.connector == OR)
|
||||
or self.connector == XOR
|
||||
)
|
||||
if (
|
||||
must_remain_connected
|
||||
and self.contains_aggregate
|
||||
and not self.contains_over_clause
|
||||
):
|
||||
# It's much cheaper to short-circuit and stash everything in the
|
||||
# HAVING clause than to split the children if possible.
|
||||
return None, self, None
|
||||
where_parts = []
|
||||
having_parts = []
|
||||
qualify_parts = []
|
||||
for c in self.children:
|
||||
if hasattr(c, "split_having_qualify"):
|
||||
where_part, having_part, qualify_part = c.split_having_qualify(
|
||||
in_negated, must_group_by
|
||||
)
|
||||
if where_part is not None:
|
||||
where_parts.append(where_part)
|
||||
if having_part is not None:
|
||||
having_parts.append(having_part)
|
||||
if qualify_part is not None:
|
||||
qualify_parts.append(qualify_part)
|
||||
elif c.contains_over_clause:
|
||||
qualify_parts.append(c)
|
||||
elif c.contains_aggregate:
|
||||
having_parts.append(c)
|
||||
else:
|
||||
where_parts.append(c)
|
||||
if must_remain_connected and qualify_parts:
|
||||
# Disjunctive heterogeneous predicates can be pushed down to
|
||||
# qualify as long as no conditional aggregation is involved.
|
||||
if not where_parts or (where_parts and not must_group_by):
|
||||
return None, None, self
|
||||
elif where_parts:
|
||||
# In theory this should only be enforced when dealing with
|
||||
# where_parts containing predicates against multi-valued
|
||||
# relationships that could affect aggregation results but this
|
||||
# is complex to infer properly.
|
||||
raise NotImplementedError(
|
||||
"Heterogeneous disjunctive predicates against window functions are "
|
||||
"not implemented when performing conditional aggregation."
|
||||
)
|
||||
where_node = (
|
||||
self.create(where_parts, self.connector, self.negated)
|
||||
if where_parts
|
||||
else None
|
||||
)
|
||||
having_node = (
|
||||
self.create(having_parts, self.connector, self.negated)
|
||||
if having_parts
|
||||
else None
|
||||
)
|
||||
qualify_node = (
|
||||
self.create(qualify_parts, self.connector, self.negated)
|
||||
if qualify_parts
|
||||
else None
|
||||
)
|
||||
return where_node, having_node, qualify_node
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
"""
|
||||
Return the SQL version of the where clause and the value to be
|
||||
substituted in. Return '', [] if this node matches everything,
|
||||
None, [] if this node is empty, and raise EmptyResultSet if this
|
||||
node can't match anything.
|
||||
"""
|
||||
result = []
|
||||
result_params = []
|
||||
if self.connector == AND:
|
||||
full_needed, empty_needed = len(self.children), 1
|
||||
else:
|
||||
full_needed, empty_needed = 1, len(self.children)
|
||||
|
||||
if self.connector == XOR and not connection.features.supports_logical_xor:
|
||||
# Convert if the database doesn't support XOR:
|
||||
# a XOR b XOR c XOR ...
|
||||
# to:
|
||||
# (a OR b OR c OR ...) AND (a + b + c + ...) == 1
|
||||
lhs = self.__class__(self.children, OR)
|
||||
rhs_sum = reduce(
|
||||
operator.add,
|
||||
(Case(When(c, then=1), default=0) for c in self.children),
|
||||
)
|
||||
rhs = Exact(1, rhs_sum)
|
||||
return self.__class__([lhs, rhs], AND, self.negated).as_sql(
|
||||
compiler, connection
|
||||
)
|
||||
|
||||
for child in self.children:
|
||||
try:
|
||||
sql, params = compiler.compile(child)
|
||||
except EmptyResultSet:
|
||||
empty_needed -= 1
|
||||
except FullResultSet:
|
||||
full_needed -= 1
|
||||
else:
|
||||
if sql:
|
||||
result.append(sql)
|
||||
result_params.extend(params)
|
||||
else:
|
||||
full_needed -= 1
|
||||
# Check if this node matches nothing or everything.
|
||||
# First check the number of full nodes and empty nodes
|
||||
# to make this node empty/full.
|
||||
# Now, check if this node is full/empty using the
|
||||
# counts.
|
||||
if empty_needed == 0:
|
||||
if self.negated:
|
||||
raise FullResultSet
|
||||
else:
|
||||
raise EmptyResultSet
|
||||
if full_needed == 0:
|
||||
if self.negated:
|
||||
raise EmptyResultSet
|
||||
else:
|
||||
raise FullResultSet
|
||||
conn = " %s " % self.connector
|
||||
sql_string = conn.join(result)
|
||||
if not sql_string:
|
||||
raise FullResultSet
|
||||
if self.negated:
|
||||
# Some backends (Oracle at least) need parentheses around the inner
|
||||
# SQL in the negated case, even if the inner SQL contains just a
|
||||
# single expression.
|
||||
sql_string = "NOT (%s)" % sql_string
|
||||
elif len(result) > 1 or self.resolved:
|
||||
sql_string = "(%s)" % sql_string
|
||||
return sql_string, result_params
|
||||
|
||||
def get_group_by_cols(self):
|
||||
cols = []
|
||||
for child in self.children:
|
||||
cols.extend(child.get_group_by_cols())
|
||||
return cols
|
||||
|
||||
def get_source_expressions(self):
|
||||
return self.children[:]
|
||||
|
||||
def set_source_expressions(self, children):
|
||||
assert len(children) == len(self.children)
|
||||
self.children = children
|
||||
|
||||
def relabel_aliases(self, change_map):
|
||||
"""
|
||||
Relabel the alias values of any children. 'change_map' is a dictionary
|
||||
mapping old (current) alias values to the new values.
|
||||
"""
|
||||
for pos, child in enumerate(self.children):
|
||||
if hasattr(child, "relabel_aliases"):
|
||||
# For example another WhereNode
|
||||
child.relabel_aliases(change_map)
|
||||
elif hasattr(child, "relabeled_clone"):
|
||||
self.children[pos] = child.relabeled_clone(change_map)
|
||||
|
||||
def clone(self):
|
||||
clone = self.create(connector=self.connector, negated=self.negated)
|
||||
for child in self.children:
|
||||
if hasattr(child, "clone"):
|
||||
child = child.clone()
|
||||
clone.children.append(child)
|
||||
return clone
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
clone = self.clone()
|
||||
clone.relabel_aliases(change_map)
|
||||
return clone
|
||||
|
||||
def replace_expressions(self, replacements):
|
||||
if replacement := replacements.get(self):
|
||||
return replacement
|
||||
clone = self.create(connector=self.connector, negated=self.negated)
|
||||
for child in self.children:
|
||||
clone.children.append(child.replace_expressions(replacements))
|
||||
return clone
|
||||
|
||||
def get_refs(self):
|
||||
refs = set()
|
||||
for child in self.children:
|
||||
refs |= child.get_refs()
|
||||
return refs
|
||||
|
||||
@classmethod
|
||||
def _contains_aggregate(cls, obj):
|
||||
if isinstance(obj, tree.Node):
|
||||
return any(cls._contains_aggregate(c) for c in obj.children)
|
||||
return obj.contains_aggregate
|
||||
|
||||
@cached_property
|
||||
def contains_aggregate(self):
|
||||
return self._contains_aggregate(self)
|
||||
|
||||
@classmethod
|
||||
def _contains_over_clause(cls, obj):
|
||||
if isinstance(obj, tree.Node):
|
||||
return any(cls._contains_over_clause(c) for c in obj.children)
|
||||
return obj.contains_over_clause
|
||||
|
||||
@cached_property
|
||||
def contains_over_clause(self):
|
||||
return self._contains_over_clause(self)
|
||||
|
||||
@property
|
||||
def is_summary(self):
|
||||
return any(child.is_summary for child in self.children)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_leaf(expr, query, *args, **kwargs):
|
||||
if hasattr(expr, "resolve_expression"):
|
||||
expr = expr.resolve_expression(query, *args, **kwargs)
|
||||
return expr
|
||||
|
||||
@classmethod
|
||||
def _resolve_node(cls, node, query, *args, **kwargs):
|
||||
if hasattr(node, "children"):
|
||||
for child in node.children:
|
||||
cls._resolve_node(child, query, *args, **kwargs)
|
||||
if hasattr(node, "lhs"):
|
||||
node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
|
||||
if hasattr(node, "rhs"):
|
||||
node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
clone = self.clone()
|
||||
clone._resolve_node(clone, *args, **kwargs)
|
||||
clone.resolved = True
|
||||
return clone
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
from django.db.models import BooleanField
|
||||
|
||||
return BooleanField()
|
||||
|
||||
@property
|
||||
def _output_field_or_none(self):
|
||||
return self.output_field
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
# Wrap filters with a CASE WHEN expression if a database backend
|
||||
# (e.g. Oracle) doesn't support boolean expressions in SELECT or GROUP
|
||||
# BY list.
|
||||
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
|
||||
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
|
||||
return sql, params
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return self.output_field.get_db_converters(connection)
|
||||
|
||||
def get_lookup(self, lookup):
|
||||
return self.output_field.get_lookup(lookup)
|
||||
|
||||
def leaves(self):
|
||||
for child in self.children:
|
||||
if isinstance(child, WhereNode):
|
||||
yield from child.leaves()
|
||||
else:
|
||||
yield child
|
||||
|
||||
|
||||
class NothingNode:
|
||||
"""A node that matches nothing."""
|
||||
|
||||
contains_aggregate = False
|
||||
contains_over_clause = False
|
||||
|
||||
def as_sql(self, compiler=None, connection=None):
|
||||
raise EmptyResultSet
|
||||
|
||||
|
||||
class ExtraWhere:
|
||||
# The contents are a black box - assume no aggregates or windows are used.
|
||||
contains_aggregate = False
|
||||
contains_over_clause = False
|
||||
|
||||
def __init__(self, sqls, params):
|
||||
self.sqls = sqls
|
||||
self.params = params
|
||||
|
||||
def as_sql(self, compiler=None, connection=None):
|
||||
sqls = ["(%s)" % sql for sql in self.sqls]
|
||||
return " AND ".join(sqls), list(self.params or ())
|
||||
|
||||
|
||||
class SubqueryConstraint:
|
||||
# Even if aggregates or windows would be used in a subquery,
|
||||
# the outer query isn't interested about those.
|
||||
contains_aggregate = False
|
||||
contains_over_clause = False
|
||||
|
||||
def __init__(self, alias, columns, targets, query_object):
|
||||
self.alias = alias
|
||||
self.columns = columns
|
||||
self.targets = targets
|
||||
query_object.clear_ordering(clear_default=True)
|
||||
self.query_object = query_object
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
query = self.query_object
|
||||
query.set_values(self.targets)
|
||||
query_compiler = query.get_compiler(connection=connection)
|
||||
return query_compiler.as_subquery_condition(self.alias, self.columns, compiler)
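A standalone sketch of the counting logic WhereNode.as_sql() uses to decide between EmptyResultSet and FullResultSet (simplified: children are plain booleans standing in for always-true / never-true conditions):

def node_matches(children, connector="AND", negated=False):
    # "full" children always match (FullResultSet); "empty" children never do.
    full_needed, empty_needed = (
        (len(children), 1) if connector == "AND" else (1, len(children))
    )
    for child in children:
        if child:
            full_needed -= 1
        else:
            empty_needed -= 1
        if empty_needed == 0:
            return "full" if negated else "empty"
        if full_needed == 0:
            return "empty" if negated else "full"
    return "mixed"

assert node_matches([True, False], "AND") == "empty"  # one empty child empties an AND
assert node_matches([True, False], "OR") == "full"    # one full child satisfies an OR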
|
||||
@@ -0,0 +1,69 @@
|
||||
import functools
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
def make_model_tuple(model):
|
||||
"""
|
||||
Take a model or a string of the form "app_label.ModelName" and return a
|
||||
corresponding ("app_label", "modelname") tuple. If a tuple is passed in,
|
||||
assume it's a valid model tuple already and return it unchanged.
|
||||
"""
|
||||
try:
|
||||
if isinstance(model, tuple):
|
||||
model_tuple = model
|
||||
elif isinstance(model, str):
|
||||
app_label, model_name = model.split(".")
|
||||
model_tuple = app_label, model_name.lower()
|
||||
else:
|
||||
model_tuple = model._meta.app_label, model._meta.model_name
|
||||
assert len(model_tuple) == 2
|
||||
return model_tuple
|
||||
except (ValueError, AssertionError):
|
||||
raise ValueError(
|
||||
"Invalid model reference '%s'. String model references "
|
||||
"must be of the form 'app_label.ModelName'." % model
|
||||
)
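A brief illustration of the normalization performed above:

assert make_model_tuple("auth.User") == ("auth", "user")   # string form is lowercased
assert make_model_tuple(("auth", "user")) == ("auth", "user")
# A model class resolves through its _meta: (app_label, model_name).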
|
||||
|
||||
|
||||
def resolve_callables(mapping):
|
||||
"""
|
||||
Generate key/value pairs for the given mapping where the values are
|
||||
evaluated if they're callable.
|
||||
"""
|
||||
for k, v in mapping.items():
|
||||
yield k, v() if callable(v) else v
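For example, callable values are evaluated lazily, one pair at a time:

assert dict(resolve_callables({"a": 1, "b": lambda: 2})) == {"a": 1, "b": 2}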
|
||||
|
||||
|
||||
def unpickle_named_row(names, values):
|
||||
return create_namedtuple_class(*names)(*values)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def create_namedtuple_class(*names):
|
||||
# Cache type() with @lru_cache since it's too slow to be called for every
|
||||
# QuerySet evaluation.
|
||||
def __reduce__(self):
|
||||
return unpickle_named_row, (names, tuple(self))
|
||||
|
||||
return type(
|
||||
"Row",
|
||||
(namedtuple("Row", names),),
|
||||
{"__reduce__": __reduce__, "__slots__": ()},
|
||||
)
|
||||
|
||||
|
||||
class AltersData:
|
||||
"""
|
||||
Make subclasses preserve the alters_data attribute on overridden methods.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
for fn_name, fn in vars(cls).items():
|
||||
if callable(fn) and not hasattr(fn, "alters_data"):
|
||||
for base in cls.__bases__:
|
||||
if base_fn := getattr(base, fn_name, None):
|
||||
if hasattr(base_fn, "alters_data"):
|
||||
fn.alters_data = base_fn.alters_data
|
||||
break
|
||||
|
||||
super().__init_subclass__(**kwargs)
|
||||
srcs/.venv/lib/python3.11/site-packages/django/db/transaction.py (340 lines)
@@ -0,0 +1,340 @@
|
||||
from contextlib import ContextDecorator, contextmanager
|
||||
|
||||
from django.db import (
|
||||
DEFAULT_DB_ALIAS,
|
||||
DatabaseError,
|
||||
Error,
|
||||
ProgrammingError,
|
||||
connections,
|
||||
)
|
||||
|
||||
|
||||
class TransactionManagementError(ProgrammingError):
|
||||
"""Transaction management is used improperly."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_connection(using=None):
|
||||
"""
|
||||
Get a database connection by name, or the default database connection
|
||||
if no name is provided. This is a private API.
|
||||
"""
|
||||
if using is None:
|
||||
using = DEFAULT_DB_ALIAS
|
||||
return connections[using]
|
||||
|
||||
|
||||
def get_autocommit(using=None):
|
||||
"""Get the autocommit status of the connection."""
|
||||
return get_connection(using).get_autocommit()
|
||||
|
||||
|
||||
def set_autocommit(autocommit, using=None):
|
||||
"""Set the autocommit status of the connection."""
|
||||
return get_connection(using).set_autocommit(autocommit)
|
||||
|
||||
|
||||
def commit(using=None):
|
||||
"""Commit a transaction."""
|
||||
get_connection(using).commit()
|
||||
|
||||
|
||||
def rollback(using=None):
|
||||
"""Roll back a transaction."""
|
||||
get_connection(using).rollback()
|
||||
|
||||
|
||||
def savepoint(using=None):
|
||||
"""
|
||||
Create a savepoint (if supported and required by the backend) inside the
|
||||
current transaction. Return an identifier for the savepoint that will be
|
||||
used for the subsequent rollback or commit.
|
||||
"""
|
||||
return get_connection(using).savepoint()
|
||||
|
||||
|
||||
def savepoint_rollback(sid, using=None):
|
||||
"""
|
||||
Roll back the most recent savepoint (if one exists). Do nothing if
|
||||
savepoints are not supported.
|
||||
"""
|
||||
get_connection(using).savepoint_rollback(sid)
|
||||
|
||||
|
||||
def savepoint_commit(sid, using=None):
|
||||
"""
|
||||
Commit the most recent savepoint (if one exists). Do nothing if
|
||||
savepoints are not supported.
|
||||
"""
|
||||
get_connection(using).savepoint_commit(sid)
|
||||
|
||||
|
||||
def clean_savepoints(using=None):
|
||||
"""
|
||||
Reset the counter used to generate unique savepoint ids in this thread.
|
||||
"""
|
||||
get_connection(using).clean_savepoints()
|
||||
|
||||
|
||||
def get_rollback(using=None):
|
||||
"""Get the "needs rollback" flag -- for *advanced use* only."""
|
||||
return get_connection(using).get_rollback()
|
||||
|
||||
|
||||
def set_rollback(rollback, using=None):
|
||||
"""
|
||||
Set or unset the "needs rollback" flag -- for *advanced use* only.
|
||||
|
||||
When `rollback` is `True`, trigger a rollback when exiting the innermost
|
||||
enclosing atomic block that has `savepoint=True` (that's the default). Use
|
||||
this to force a rollback without raising an exception.
|
||||
|
||||
When `rollback` is `False`, prevent such a rollback. Use this only after
|
||||
rolling back to a known-good state! Otherwise, you break the atomic block
|
||||
and data corruption may occur.
|
||||
"""
|
||||
return get_connection(using).set_rollback(rollback)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def mark_for_rollback_on_error(using=None):
|
||||
"""
|
||||
Internal low-level utility to mark a transaction as "needs rollback" when
|
||||
an exception is raised while not enforcing the enclosed block to be in a
|
||||
transaction. This is needed by Model.save() and friends to avoid starting a
|
||||
transaction when in autocommit mode and a single query is executed.
|
||||
|
||||
It's equivalent to:
|
||||
|
||||
connection = get_connection(using)
|
||||
if connection.get_autocommit():
|
||||
yield
|
||||
else:
|
||||
with transaction.atomic(using=using, savepoint=False):
|
||||
yield
|
||||
|
||||
but it uses low-level utilities to avoid performance overhead.
|
||||
"""
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
connection = get_connection(using)
|
||||
if connection.in_atomic_block:
|
||||
connection.needs_rollback = True
|
||||
connection.rollback_exc = exc
|
||||
raise
|
||||
|
||||
|
||||
def on_commit(func, using=None, robust=False):
|
||||
"""
|
||||
Register `func` to be called when the current transaction is committed.
|
||||
If the current transaction is rolled back, `func` will not be called.
|
||||
"""
|
||||
get_connection(using).on_commit(func, robust)
|
||||
|
||||
|
||||
#################################
|
||||
# Decorators / context managers #
|
||||
#################################
|
||||
|
||||
|
||||
class Atomic(ContextDecorator):
|
||||
"""
|
||||
Guarantee the atomic execution of a given block.
|
||||
|
||||
An instance can be used either as a decorator or as a context manager.
|
||||
|
||||
When it's used as a decorator, __call__ wraps the execution of the
|
||||
decorated function in the instance itself, used as a context manager.
|
||||
|
||||
When it's used as a context manager, __enter__ creates a transaction or a
|
||||
savepoint, depending on whether a transaction is already in progress, and
|
||||
__exit__ commits the transaction or releases the savepoint on normal exit,
|
||||
and rolls back the transaction or to the savepoint on exceptions.
|
||||
|
||||
It's possible to disable the creation of savepoints if the goal is to
|
||||
ensure that some code runs within a transaction without creating overhead.
|
||||
|
||||
A stack of savepoints identifiers is maintained as an attribute of the
|
||||
connection. None denotes the absence of a savepoint.
|
||||
|
||||
This allows reentrancy even if the same Atomic instance is reused. For
|
||||
example, it's possible to define `oa = atomic('other')` and use `@oa` or
|
||||
`with oa:` multiple times.
|
||||
|
||||
Since database connections are thread-local, this is thread-safe.
|
||||
|
||||
An atomic block can be tagged as durable. In this case, raise a
|
||||
RuntimeError if it's nested within another atomic block. This guarantees
|
||||
that database changes in a durable block are committed to the database when
|
||||
the block exits without error.
|
||||
|
||||
This is a private API.
|
||||
"""
|
||||
|
||||
def __init__(self, using, savepoint, durable):
|
||||
self.using = using
|
||||
self.savepoint = savepoint
|
||||
self.durable = durable
|
||||
self._from_testcase = False
|
||||
|
||||
def __enter__(self):
|
||||
connection = get_connection(self.using)
|
||||
|
||||
if (
|
||||
self.durable
|
||||
and connection.atomic_blocks
|
||||
and not connection.atomic_blocks[-1]._from_testcase
|
||||
):
|
||||
raise RuntimeError(
|
||||
"A durable atomic block cannot be nested within another "
|
||||
"atomic block."
|
||||
)
|
||||
if not connection.in_atomic_block:
|
||||
# Reset state when entering an outermost atomic block.
|
||||
connection.commit_on_exit = True
|
||||
connection.needs_rollback = False
|
||||
if not connection.get_autocommit():
|
||||
# Pretend we're already in an atomic block to bypass the code
|
||||
# that disables autocommit to enter a transaction, and make a
|
||||
# note to deal with this case in __exit__.
|
||||
connection.in_atomic_block = True
|
||||
connection.commit_on_exit = False
|
||||
|
||||
if connection.in_atomic_block:
|
||||
# We're already in a transaction; create a savepoint, unless we
|
||||
# were told not to or we're already waiting for a rollback. The
|
||||
# second condition avoids creating useless savepoints and prevents
|
||||
# overwriting needs_rollback until the rollback is performed.
|
||||
if self.savepoint and not connection.needs_rollback:
|
||||
sid = connection.savepoint()
|
||||
connection.savepoint_ids.append(sid)
|
||||
else:
|
||||
connection.savepoint_ids.append(None)
|
||||
else:
|
||||
connection.set_autocommit(
|
||||
False, force_begin_transaction_with_broken_autocommit=True
|
||||
)
|
||||
connection.in_atomic_block = True
|
||||
|
||||
if connection.in_atomic_block:
|
||||
connection.atomic_blocks.append(self)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
connection = get_connection(self.using)
|
||||
|
||||
if connection.in_atomic_block:
|
||||
connection.atomic_blocks.pop()
|
||||
|
||||
if connection.savepoint_ids:
|
||||
sid = connection.savepoint_ids.pop()
|
||||
else:
|
||||
# Prematurely unset this flag to allow using commit or rollback.
|
||||
connection.in_atomic_block = False
|
||||
|
||||
try:
|
||||
if connection.closed_in_transaction:
|
||||
# The database will perform a rollback by itself.
|
||||
# Wait until we exit the outermost block.
|
||||
pass
|
||||
|
||||
elif exc_type is None and not connection.needs_rollback:
|
||||
if connection.in_atomic_block:
|
||||
# Release savepoint if there is one
|
||||
if sid is not None:
|
||||
try:
|
||||
connection.savepoint_commit(sid)
|
||||
except DatabaseError:
|
||||
try:
|
||||
connection.savepoint_rollback(sid)
|
||||
# The savepoint won't be reused. Release it to
|
||||
# minimize overhead for the database server.
|
||||
connection.savepoint_commit(sid)
|
||||
except Error:
|
||||
# If rolling back to a savepoint fails, mark for
|
||||
# rollback at a higher level and avoid shadowing
|
||||
# the original exception.
|
||||
connection.needs_rollback = True
|
||||
raise
|
||||
else:
|
||||
# Commit transaction
|
||||
try:
|
||||
connection.commit()
|
||||
except DatabaseError:
|
||||
try:
|
||||
connection.rollback()
|
||||
except Error:
|
||||
# An error during rollback means that something
|
||||
# went wrong with the connection. Drop it.
|
||||
connection.close()
|
||||
raise
|
||||
else:
|
||||
# This flag will be set to True again if there isn't a savepoint
|
||||
# that allows performing the rollback at this level.
|
||||
connection.needs_rollback = False
|
||||
if connection.in_atomic_block:
|
||||
# Roll back to savepoint if there is one, mark for rollback
|
||||
# otherwise.
|
||||
if sid is None:
|
||||
connection.needs_rollback = True
|
||||
else:
|
||||
try:
|
||||
connection.savepoint_rollback(sid)
|
||||
# The savepoint won't be reused. Release it to
|
||||
# minimize overhead for the database server.
|
||||
connection.savepoint_commit(sid)
|
||||
except Error:
|
||||
# If rolling back to a savepoint fails, mark for
|
||||
# rollback at a higher level and avoid shadowing
|
||||
# the original exception.
|
||||
connection.needs_rollback = True
|
||||
else:
|
||||
# Roll back transaction
|
||||
try:
|
||||
connection.rollback()
|
||||
except Error:
|
||||
# An error during rollback means that something
|
||||
# went wrong with the connection. Drop it.
|
||||
connection.close()
|
||||
|
||||
finally:
|
||||
# Outermost block exit when autocommit was enabled.
|
||||
if not connection.in_atomic_block:
|
||||
if connection.closed_in_transaction:
|
||||
connection.connection = None
|
||||
else:
|
||||
connection.set_autocommit(True)
|
||||
# Outermost block exit when autocommit was disabled.
|
||||
elif not connection.savepoint_ids and not connection.commit_on_exit:
|
||||
if connection.closed_in_transaction:
|
||||
connection.connection = None
|
||||
else:
|
||||
connection.in_atomic_block = False
|
||||
|
||||
|
||||
def atomic(using=None, savepoint=True, durable=False):
|
||||
# Bare decorator: @atomic -- although the first argument is called
|
||||
# `using`, it's actually the function being decorated.
|
||||
if callable(using):
|
||||
return Atomic(DEFAULT_DB_ALIAS, savepoint, durable)(using)
|
||||
# Decorator: @atomic(...) or context manager: with atomic(...): ...
|
||||
else:
|
||||
return Atomic(using, savepoint, durable)
|
||||
|
||||
|
||||
def _non_atomic_requests(view, using):
|
||||
try:
|
||||
view._non_atomic_requests.add(using)
|
||||
except AttributeError:
|
||||
view._non_atomic_requests = {using}
|
||||
return view
|
||||
|
||||
|
||||
def non_atomic_requests(using=None):
|
||||
if callable(using):
|
||||
return _non_atomic_requests(using, DEFAULT_DB_ALIAS)
|
||||
else:
|
||||
if using is None:
|
||||
using = DEFAULT_DB_ALIAS
|
||||
return lambda view: _non_atomic_requests(view, using)
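A usage sketch of the public entry points defined above (the model methods and the notify_transfer helper are hypothetical):

from django.db import transaction

def transfer(source, target, amount):
    # Either both updates commit or neither does.
    with transaction.atomic():
        source.withdraw(amount)
        target.deposit(amount)
        # Runs only after the outermost atomic block commits successfully.
        transaction.on_commit(lambda: notify_transfer(source, target, amount))

@transaction.atomic(using="analytics", durable=True)
def rebuild_report():
    ...  # durable=True raises RuntimeError if nested inside another atomic block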
|
||||
srcs/.venv/lib/python3.11/site-packages/django/db/utils.py (278 lines)
@@ -0,0 +1,278 @@
|
||||
import pkgutil
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
# For backwards compatibility with Django < 3.2
|
||||
from django.utils.connection import ConnectionDoesNotExist # NOQA: F401
|
||||
from django.utils.connection import BaseConnectionHandler
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
DEFAULT_DB_ALIAS = "default"
|
||||
DJANGO_VERSION_PICKLE_KEY = "_django_version"
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InterfaceError(Error):
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseError(Error):
|
||||
pass
|
||||
|
||||
|
||||
class DataError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class OperationalError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class IntegrityError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class InternalError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class ProgrammingError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class NotSupportedError(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseErrorWrapper:
|
||||
"""
|
||||
Context manager and decorator that reraises backend-specific database
|
||||
exceptions using Django's common wrappers.
|
||||
"""
|
||||
|
||||
def __init__(self, wrapper):
|
||||
"""
|
||||
wrapper is a database wrapper.
|
||||
|
||||
It must have a Database attribute defining PEP-249 exceptions.
|
||||
"""
|
||||
self.wrapper = wrapper
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if exc_type is None:
|
||||
return
|
||||
for dj_exc_type in (
|
||||
DataError,
|
||||
OperationalError,
|
||||
IntegrityError,
|
||||
InternalError,
|
||||
ProgrammingError,
|
||||
NotSupportedError,
|
||||
DatabaseError,
|
||||
InterfaceError,
|
||||
Error,
|
||||
):
|
||||
db_exc_type = getattr(self.wrapper.Database, dj_exc_type.__name__)
|
||||
if issubclass(exc_type, db_exc_type):
|
||||
dj_exc_value = dj_exc_type(*exc_value.args)
|
||||
# Only set the 'errors_occurred' flag for errors that may make
|
||||
# the connection unusable.
|
||||
if dj_exc_type not in (DataError, IntegrityError):
|
||||
self.wrapper.errors_occurred = True
|
||||
raise dj_exc_value.with_traceback(traceback) from exc_value
|
||||
|
||||
def __call__(self, func):
|
||||
# Note that we are intentionally not using @wraps here for performance
|
||||
# reasons. Refs #21109.
|
||||
def inner(*args, **kwargs):
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def load_backend(backend_name):
|
||||
"""
|
||||
Return a database backend's "base" module given a fully qualified database
|
||||
backend name, or raise an error if it doesn't exist.
|
||||
"""
|
||||
# This backend was renamed in Django 1.9.
|
||||
if backend_name == "django.db.backends.postgresql_psycopg2":
|
||||
backend_name = "django.db.backends.postgresql"
|
||||
|
||||
try:
|
||||
return import_module("%s.base" % backend_name)
|
||||
except ImportError as e_user:
|
||||
# The database backend wasn't found. Display a helpful error message
|
||||
# listing all built-in database backends.
|
||||
import django.db.backends
|
||||
|
||||
builtin_backends = [
|
||||
name
|
||||
for _, name, ispkg in pkgutil.iter_modules(django.db.backends.__path__)
|
||||
if ispkg and name not in {"base", "dummy"}
|
||||
]
|
||||
if backend_name not in ["django.db.backends.%s" % b for b in builtin_backends]:
|
||||
backend_reprs = map(repr, sorted(builtin_backends))
|
||||
raise ImproperlyConfigured(
|
||||
"%r isn't an available database backend or couldn't be "
|
||||
"imported. Check the above exception. To use one of the "
|
||||
"built-in backends, use 'django.db.backends.XXX', where XXX "
|
||||
"is one of:\n"
|
||||
" %s" % (backend_name, ", ".join(backend_reprs))
|
||||
) from e_user
|
||||
else:
|
||||
# If there's some other error, this must be an error in Django
|
||||
raise
|
||||
|
||||
|
||||
class ConnectionHandler(BaseConnectionHandler):
|
||||
settings_name = "DATABASES"
|
||||
# Connections still need to be an actual thread local, as they're truly
|
||||
# thread-critical. Database backends should use @async_unsafe to protect
|
||||
# their code from async contexts, but this will give those contexts
|
||||
# separate connections in case it's needed as well. There's no cleanup
|
||||
# after async contexts, though, so we don't allow that if we can help it.
|
||||
thread_critical = True
|
||||
|
||||
def configure_settings(self, databases):
|
||||
databases = super().configure_settings(databases)
|
||||
if databases == {}:
|
||||
databases[DEFAULT_DB_ALIAS] = {"ENGINE": "django.db.backends.dummy"}
|
||||
elif DEFAULT_DB_ALIAS not in databases:
|
||||
raise ImproperlyConfigured(
|
||||
f"You must define a '{DEFAULT_DB_ALIAS}' database."
|
||||
)
|
||||
elif databases[DEFAULT_DB_ALIAS] == {}:
|
||||
databases[DEFAULT_DB_ALIAS]["ENGINE"] = "django.db.backends.dummy"
|
||||
|
||||
# Configure default settings.
|
||||
for conn in databases.values():
|
||||
conn.setdefault("ATOMIC_REQUESTS", False)
|
||||
conn.setdefault("AUTOCOMMIT", True)
|
||||
conn.setdefault("ENGINE", "django.db.backends.dummy")
|
||||
if conn["ENGINE"] == "django.db.backends." or not conn["ENGINE"]:
|
||||
conn["ENGINE"] = "django.db.backends.dummy"
|
||||
conn.setdefault("CONN_MAX_AGE", 0)
|
||||
conn.setdefault("CONN_HEALTH_CHECKS", False)
|
||||
conn.setdefault("OPTIONS", {})
|
||||
conn.setdefault("TIME_ZONE", None)
|
||||
for setting in ["NAME", "USER", "PASSWORD", "HOST", "PORT"]:
|
||||
conn.setdefault(setting, "")
|
||||
|
||||
test_settings = conn.setdefault("TEST", {})
|
||||
default_test_settings = [
|
||||
("CHARSET", None),
|
||||
("COLLATION", None),
|
||||
("MIGRATE", True),
|
||||
("MIRROR", None),
|
||||
("NAME", None),
|
||||
]
|
||||
for key, value in default_test_settings:
|
||||
test_settings.setdefault(key, value)
|
||||
return databases
|
||||
|
||||
@property
|
||||
def databases(self):
|
||||
# Maintained for backward compatibility as some 3rd party packages have
|
||||
# made use of this private API in the past. It is no longer used within
|
||||
# Django itself.
|
||||
return self.settings
|
||||
|
||||
def create_connection(self, alias):
|
||||
db = self.settings[alias]
|
||||
backend = load_backend(db["ENGINE"])
|
||||
return backend.DatabaseWrapper(db, alias)
|
||||
|
||||
|
||||
class ConnectionRouter:
|
||||
def __init__(self, routers=None):
|
||||
"""
|
||||
If routers is not specified, default to settings.DATABASE_ROUTERS.
|
||||
"""
|
||||
self._routers = routers
|
||||
|
||||
@cached_property
|
||||
def routers(self):
|
||||
if self._routers is None:
|
||||
self._routers = settings.DATABASE_ROUTERS
|
||||
routers = []
|
||||
for r in self._routers:
|
||||
if isinstance(r, str):
|
||||
router = import_string(r)()
|
||||
else:
|
||||
router = r
|
||||
routers.append(router)
|
||||
return routers
|
||||
|
||||
def _router_func(action):
|
||||
def _route_db(self, model, **hints):
|
||||
chosen_db = None
|
||||
for router in self.routers:
|
||||
try:
|
||||
method = getattr(router, action)
|
||||
except AttributeError:
|
||||
# If the router doesn't have a method, skip to the next one.
|
||||
pass
|
||||
else:
|
||||
chosen_db = method(model, **hints)
|
||||
if chosen_db:
|
||||
return chosen_db
|
||||
instance = hints.get("instance")
|
||||
if instance is not None and instance._state.db:
|
||||
return instance._state.db
|
||||
return DEFAULT_DB_ALIAS
|
||||
|
||||
return _route_db
|
||||
|
||||
db_for_read = _router_func("db_for_read")
|
||||
db_for_write = _router_func("db_for_write")
|
||||
|
||||
def allow_relation(self, obj1, obj2, **hints):
|
||||
for router in self.routers:
|
||||
try:
|
||||
method = router.allow_relation
|
||||
except AttributeError:
|
||||
# If the router doesn't have a method, skip to the next one.
|
||||
pass
|
||||
else:
|
||||
allow = method(obj1, obj2, **hints)
|
||||
if allow is not None:
|
||||
return allow
|
||||
return obj1._state.db == obj2._state.db
|
||||
|
||||
def allow_migrate(self, db, app_label, **hints):
|
||||
for router in self.routers:
|
||||
try:
|
||||
method = router.allow_migrate
|
||||
except AttributeError:
|
||||
# If the router doesn't have a method, skip to the next one.
|
||||
continue
|
||||
|
||||
allow = method(db, app_label, **hints)
|
||||
|
||||
if allow is not None:
|
||||
return allow
|
||||
return True
|
||||
|
||||
def allow_migrate_model(self, db, model):
|
||||
return self.allow_migrate(
|
||||
db,
|
||||
model._meta.app_label,
|
||||
model_name=model._meta.model_name,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def get_migratable_models(self, app_config, db, include_auto_created=False):
|
||||
"""Return app models allowed to be migrated on provided db."""
|
||||
models = app_config.get_models(include_auto_created=include_auto_created)
|
||||
return [model for model in models if self.allow_migrate_model(db, model)]
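A hedged sketch of a custom router that ConnectionRouter would consult (alias and class names are hypothetical; methods are optional and skipped when absent, as the lookup code above shows):

class ReplicaRouter:
    def db_for_read(self, model, **hints):
        return "replica"    # hypothetical read-only alias in DATABASES

    def db_for_write(self, model, **hints):
        return "default"

    def allow_relation(self, obj1, obj2, **hints):
        return None         # no opinion; fall through to the next router

# settings.py (assumption): DATABASE_ROUTERS = ["path.to.ReplicaRouter"]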
|
||||