docker setup
@@ -0,0 +1,802 @@
import _thread
import copy
import datetime
import logging
import threading
import time
import warnings
from collections import deque
from contextlib import contextmanager

from django.db.backends.utils import debug_transaction

try:
    import zoneinfo
except ImportError:
    from backports import zoneinfo

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, DatabaseError, NotSupportedError
from django.db.backends import utils
from django.db.backends.base.validation import BaseDatabaseValidation
from django.db.backends.signals import connection_created
from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseErrorWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property

NO_DB_ALIAS = "__no_db__"
RAN_DB_VERSION_CHECK = set()

logger = logging.getLogger("django.db.backends.base")


# RemovedInDjango50Warning
def timezone_constructor(tzname):
    if settings.USE_DEPRECATED_PYTZ:
        import pytz

        return pytz.timezone(tzname)
    return zoneinfo.ZoneInfo(tzname)
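
# Example (illustrative, not part of the module): with USE_DEPRECATED_PYTZ
# unset, this returns a zoneinfo.ZoneInfo instance:
#
#     tz = timezone_constructor("Europe/Paris")  # ZoneInfo('Europe/Paris')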


class BaseDatabaseWrapper:
    """Represent a database connection."""

    # Mapping of Field objects to their column types.
    data_types = {}
    # Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
    data_types_suffix = {}
    # Mapping of Field objects to their SQL for CHECK constraints.
    data_type_check_constraints = {}
    ops = None
    vendor = "unknown"
    display_name = "unknown"
    SchemaEditorClass = None
    # Classes instantiated in __init__().
    client_class = None
    creation_class = None
    features_class = None
    introspection_class = None
    ops_class = None
    validation_class = BaseDatabaseValidation

    queries_limit = 9000

    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
        # Connection related attributes.
        # The underlying database connection.
        self.connection = None
        # `settings_dict` should be a dictionary containing keys such as
        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
        # to disambiguate it from Django settings modules.
        self.settings_dict = settings_dict
        self.alias = alias
        # Query logging in debug mode or when explicitly enabled.
        self.queries_log = deque(maxlen=self.queries_limit)
        self.force_debug_cursor = False

        # Transaction related attributes.
        # Tracks if the connection is in autocommit mode. Per PEP 249, by
        # default, it isn't.
        self.autocommit = False
        # Tracks if the connection is in a transaction managed by 'atomic'.
        self.in_atomic_block = False
        # Increment to generate unique savepoint ids.
        self.savepoint_state = 0
        # List of savepoints created by 'atomic'.
        self.savepoint_ids = []
        # Stack of active 'atomic' blocks.
        self.atomic_blocks = []
        # Tracks if the outermost 'atomic' block should commit on exit,
        # i.e. if autocommit was active on entry.
        self.commit_on_exit = True
        # Tracks if the transaction should be rolled back to the next
        # available savepoint because of an exception in an inner block.
        self.needs_rollback = False
        self.rollback_exc = None

        # Connection termination related attributes.
        self.close_at = None
        self.closed_in_transaction = False
        self.errors_occurred = False
        self.health_check_enabled = False
        self.health_check_done = False

        # Thread-safety related attributes.
        self._thread_sharing_lock = threading.Lock()
        self._thread_sharing_count = 0
        self._thread_ident = _thread.get_ident()

        # A list of no-argument functions to run when the transaction commits.
        # Each entry is an (sids, func, robust) tuple, where sids is a set of
        # the active savepoint IDs when this function was registered and robust
        # specifies whether it's allowed for the function to fail.
        self.run_on_commit = []

        # Should we run the on-commit hooks the next time set_autocommit(True)
        # is called?
        self.run_commit_hooks_on_set_autocommit_on = False

        # A stack of wrappers to be invoked around execute()/executemany()
        # calls. Each entry is a function taking five arguments: execute, sql,
        # params, many, and context. It's the function's responsibility to
        # call execute(sql, params, many, context).
        self.execute_wrappers = []

        self.client = self.client_class(self)
        self.creation = self.creation_class(self)
        self.features = self.features_class(self)
        self.introspection = self.introspection_class(self)
        self.ops = self.ops_class(self)
        self.validation = self.validation_class(self)

    def __repr__(self):
        return (
            f"<{self.__class__.__qualname__} "
            f"vendor={self.vendor!r} alias={self.alias!r}>"
        )

    def ensure_timezone(self):
        """
        Ensure the connection's timezone is set to `self.timezone_name` and
        return whether it changed or not.
        """
        return False

    @cached_property
    def timezone(self):
        """
        Return a tzinfo of the database connection time zone.

        This is only used when time zone support is enabled. When a datetime is
        read from the database, it is always returned in this time zone.

        When the database backend supports time zones, it doesn't matter which
        time zone Django uses, as long as aware datetimes are used everywhere.
        Other users connecting to the database can choose their own time zone.

        When the database backend doesn't support time zones, the time zone
        Django uses may be constrained by the requirements of other users of
        the database.
        """
        if not settings.USE_TZ:
            return None
        elif self.settings_dict["TIME_ZONE"] is None:
            return datetime.timezone.utc
        else:
            return timezone_constructor(self.settings_dict["TIME_ZONE"])

    @cached_property
    def timezone_name(self):
        """
        Name of the time zone of the database connection.
        """
        if not settings.USE_TZ:
            return settings.TIME_ZONE
        elif self.settings_dict["TIME_ZONE"] is None:
            return "UTC"
        else:
            return self.settings_dict["TIME_ZONE"]

    @property
    def queries_logged(self):
        return self.force_debug_cursor or settings.DEBUG

    @property
    def queries(self):
        if len(self.queries_log) == self.queries_log.maxlen:
            warnings.warn(
                "Limit for query logging exceeded, only the last {} queries "
                "will be returned.".format(self.queries_log.maxlen)
            )
        return list(self.queries_log)

    def get_database_version(self):
        """Return a tuple of the database's version."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_database_version() "
            "method."
        )

    def check_database_version_supported(self):
        """
        Raise an error if the database version isn't supported by this
        version of Django.
        """
        if (
            self.features.minimum_database_version is not None
            and self.get_database_version() < self.features.minimum_database_version
        ):
            db_version = ".".join(map(str, self.get_database_version()))
            min_db_version = ".".join(map(str, self.features.minimum_database_version))
            raise NotSupportedError(
                f"{self.display_name} {min_db_version} or later is required "
                f"(found {db_version})."
            )

    # ##### Backend-specific methods for creating connections and cursors #####

    def get_connection_params(self):
        """Return a dict of parameters suitable for get_new_connection."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_connection_params() "
            "method"
        )

    def get_new_connection(self, conn_params):
        """Open a connection to the database."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_new_connection() "
            "method"
        )

    def init_connection_state(self):
        """Initialize the database connection settings."""
        global RAN_DB_VERSION_CHECK
        if self.alias not in RAN_DB_VERSION_CHECK:
            self.check_database_version_supported()
            RAN_DB_VERSION_CHECK.add(self.alias)

    def create_cursor(self, name=None):
        """Create a cursor. Assume that a connection is established."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a create_cursor() method"
        )

    # ##### Backend-specific methods for creating connections #####

    @async_unsafe
    def connect(self):
        """Connect to the database. Assume that the connection is closed."""
        # Check for invalid configurations.
        self.check_settings()
        # In case the previous connection was closed while in an atomic block
        self.in_atomic_block = False
        self.savepoint_ids = []
        self.atomic_blocks = []
        self.needs_rollback = False
        # Reset parameters defining when to close/health-check the connection.
        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
        max_age = self.settings_dict["CONN_MAX_AGE"]
        self.close_at = None if max_age is None else time.monotonic() + max_age
        self.closed_in_transaction = False
        self.errors_occurred = False
        # New connections are healthy.
        self.health_check_done = True
        # Establish the connection
        conn_params = self.get_connection_params()
        self.connection = self.get_new_connection(conn_params)
        self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
        self.init_connection_state()
        connection_created.send(sender=self.__class__, connection=self)

        self.run_on_commit = []

    def check_settings(self):
        if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ:
            raise ImproperlyConfigured(
                "Connection '%s' cannot set TIME_ZONE because USE_TZ is False."
                % self.alias
            )

    @async_unsafe
    def ensure_connection(self):
        """Guarantee that a connection to the database is established."""
        if self.connection is None:
            with self.wrap_database_errors:
                self.connect()

    # ##### Backend-specific wrappers for PEP-249 connection methods #####

    def _prepare_cursor(self, cursor):
        """
        Validate the connection is usable and perform database cursor wrapping.
        """
        self.validate_thread_sharing()
        if self.queries_logged:
            wrapped_cursor = self.make_debug_cursor(cursor)
        else:
            wrapped_cursor = self.make_cursor(cursor)
        return wrapped_cursor

    def _cursor(self, name=None):
        self.close_if_health_check_failed()
        self.ensure_connection()
        with self.wrap_database_errors:
            return self._prepare_cursor(self.create_cursor(name))

    def _commit(self):
        if self.connection is not None:
            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
                return self.connection.commit()

    def _rollback(self):
        if self.connection is not None:
            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
                return self.connection.rollback()

    def _close(self):
        if self.connection is not None:
            with self.wrap_database_errors:
                return self.connection.close()

    # ##### Generic wrappers for PEP-249 connection methods #####

    @async_unsafe
    def cursor(self):
        """Create a cursor, opening a connection if necessary."""
        return self._cursor()

    @async_unsafe
    def commit(self):
        """Commit a transaction and reset the dirty flag."""
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._commit()
        # A successful commit means that the database connection works.
        self.errors_occurred = False
        self.run_commit_hooks_on_set_autocommit_on = True

    @async_unsafe
    def rollback(self):
        """Roll back a transaction and reset the dirty flag."""
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._rollback()
        # A successful rollback means that the database connection works.
        self.errors_occurred = False
        self.needs_rollback = False
        self.run_on_commit = []

    @async_unsafe
    def close(self):
        """Close the connection to the database."""
        self.validate_thread_sharing()
        self.run_on_commit = []

        # Don't call validate_no_atomic_block() to avoid making it difficult
        # to get rid of a connection in an invalid state. The next connect()
        # will reset the transaction state anyway.
        if self.closed_in_transaction or self.connection is None:
            return
        try:
            self._close()
        finally:
            if self.in_atomic_block:
                self.closed_in_transaction = True
                self.needs_rollback = True
            else:
                self.connection = None

    # ##### Backend-specific savepoint management methods #####

    def _savepoint(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_create_sql(sid))

    def _savepoint_rollback(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_rollback_sql(sid))

    def _savepoint_commit(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_commit_sql(sid))

    def _savepoint_allowed(self):
        # Savepoints cannot be created outside a transaction
        return self.features.uses_savepoints and not self.get_autocommit()

    # ##### Generic savepoint management methods #####

    @async_unsafe
    def savepoint(self):
        """
        Create a savepoint inside the current transaction. Return an
        identifier for the savepoint that will be used for the subsequent
        rollback or commit. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        thread_ident = _thread.get_ident()
        tid = str(thread_ident).replace("-", "")

        self.savepoint_state += 1
        sid = "s%s_x%d" % (tid, self.savepoint_state)

        self.validate_thread_sharing()
        self._savepoint(sid)

        return sid
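
    # Usage sketch (illustrative): pairing savepoint() with a rollback or
    # commit inside an open transaction on the default connection:
    #
    #     from django.db import DatabaseError, connection
    #
    #     sid = connection.savepoint()
    #     try:
    #         ...  # queries that may fail
    #     except DatabaseError:
    #         connection.savepoint_rollback(sid)
    #     else:
    #         connection.savepoint_commit(sid)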

    @async_unsafe
    def savepoint_rollback(self, sid):
        """
        Roll back to a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_rollback(sid)

        # Remove any callbacks registered while this savepoint was active.
        self.run_on_commit = [
            (sids, func, robust)
            for (sids, func, robust) in self.run_on_commit
            if sid not in sids
        ]

    @async_unsafe
    def savepoint_commit(self, sid):
        """
        Release a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_commit(sid)

    @async_unsafe
    def clean_savepoints(self):
        """
        Reset the counter used to generate unique savepoint ids in this thread.
        """
        self.savepoint_state = 0

    # ##### Backend-specific transaction management methods #####

    def _set_autocommit(self, autocommit):
        """
        Backend-specific implementation to enable or disable autocommit.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
        )

    # ##### Generic transaction management methods #####

    def get_autocommit(self):
        """Get the autocommit state."""
        self.ensure_connection()
        return self.autocommit

    def set_autocommit(
        self, autocommit, force_begin_transaction_with_broken_autocommit=False
    ):
        """
        Enable or disable autocommit.

        The usual way to start a transaction is to turn autocommit off.
        SQLite does not properly start a transaction when disabling
        autocommit. To avoid this buggy behavior and to actually enter a new
        transaction, an explicit BEGIN is required. Using
        force_begin_transaction_with_broken_autocommit=True will issue an
        explicit BEGIN with SQLite. This option will be ignored for other
        backends.
        """
        self.validate_no_atomic_block()
        self.close_if_health_check_failed()
        self.ensure_connection()

        start_transaction_under_autocommit = (
            force_begin_transaction_with_broken_autocommit
            and not autocommit
            and hasattr(self, "_start_transaction_under_autocommit")
        )

        if start_transaction_under_autocommit:
            self._start_transaction_under_autocommit()
        elif autocommit:
            self._set_autocommit(autocommit)
        else:
            with debug_transaction(self, "BEGIN"):
                self._set_autocommit(autocommit)
        self.autocommit = autocommit

        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
            self.run_and_clear_commit_hooks()
            self.run_commit_hooks_on_set_autocommit_on = False
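
    # Usage sketch (illustrative): manual transaction management outside of
    # atomic(), assuming autocommit is currently on:
    #
    #     from django.db import connection
    #
    #     connection.set_autocommit(False)
    #     try:
    #         ...  # queries forming a single transaction
    #         connection.commit()
    #     finally:
    #         connection.set_autocommit(True)  # also runs pending commit hooks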

    def get_rollback(self):
        """Get the "needs rollback" flag -- for *advanced use* only."""
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        return self.needs_rollback

    def set_rollback(self, rollback):
        """
        Set or unset the "needs rollback" flag -- for *advanced use* only.
        """
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        self.needs_rollback = rollback

    def validate_no_atomic_block(self):
        """Raise an error if an atomic block is active."""
        if self.in_atomic_block:
            raise TransactionManagementError(
                "This is forbidden when an 'atomic' block is active."
            )

    def validate_no_broken_transaction(self):
        if self.needs_rollback:
            raise TransactionManagementError(
                "An error occurred in the current transaction. You can't "
                "execute queries until the end of the 'atomic' block."
            ) from self.rollback_exc

    # ##### Foreign key constraints checks handling #####

    @contextmanager
    def constraint_checks_disabled(self):
        """
        Disable foreign key constraint checking.
        """
        disabled = self.disable_constraint_checking()
        try:
            yield
        finally:
            if disabled:
                self.enable_constraint_checking()
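
    # Usage sketch (illustrative): insert rows with forward references, then
    # verify them manually once checks are re-enabled:
    #
    #     with connection.constraint_checks_disabled():
    #         ...  # inserts that may reference rows created later
    #     connection.check_constraints()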

    def disable_constraint_checking(self):
        """
        Backends can implement as needed to temporarily disable foreign key
        constraint checking. Should return True if the constraints were
        disabled and will need to be reenabled.
        """
        return False

    def enable_constraint_checking(self):
        """
        Backends can implement as needed to re-enable foreign key constraint
        checking.
        """
        pass

    def check_constraints(self, table_names=None):
        """
        Backends can override this method if they can apply constraint
        checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
        IntegrityError if any invalid foreign key references are encountered.
        """
        pass

    # ##### Connection termination handling #####

    def is_usable(self):
        """
        Test if the database connection is usable.

        This method may assume that self.connection is not None.

        Actual implementations should take care not to raise exceptions
        as that may prevent Django from recycling unusable connections.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require an is_usable() method"
        )

    def close_if_health_check_failed(self):
        """Close existing connection if it fails a health check."""
        if (
            self.connection is None
            or not self.health_check_enabled
            or self.health_check_done
        ):
            return

        if not self.is_usable():
            self.close()
        self.health_check_done = True

    def close_if_unusable_or_obsolete(self):
        """
        Close the current connection if unrecoverable errors have occurred
        or if it outlived its maximum age.
        """
        if self.connection is not None:
            self.health_check_done = False
            # If the application didn't restore the original autocommit setting,
            # don't take chances, drop the connection.
            if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
                self.close()
                return

            # If an exception other than DataError or IntegrityError occurred
            # since the last commit / rollback, check if the connection works.
            if self.errors_occurred:
                if self.is_usable():
                    self.errors_occurred = False
                    self.health_check_done = True
                else:
                    self.close()
                    return

            if self.close_at is not None and time.monotonic() >= self.close_at:
                self.close()
                return
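
    # Example (illustrative): with CONN_MAX_AGE = 60, connect() sets close_at
    # to time.monotonic() + 60; the first call to this method after that
    # deadline closes the connection so the next query opens a fresh one.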

    # ##### Thread safety handling #####

    @property
    def allow_thread_sharing(self):
        with self._thread_sharing_lock:
            return self._thread_sharing_count > 0

    def inc_thread_sharing(self):
        with self._thread_sharing_lock:
            self._thread_sharing_count += 1

    def dec_thread_sharing(self):
        with self._thread_sharing_lock:
            if self._thread_sharing_count <= 0:
                raise RuntimeError(
                    "Cannot decrement the thread sharing count below zero."
                )
            self._thread_sharing_count -= 1

    def validate_thread_sharing(self):
        """
        Validate that the connection isn't accessed by another thread than the
        one which originally created it, unless the connection was explicitly
        authorized to be shared between threads (via the `inc_thread_sharing()`
        method). Raise an exception if the validation fails.
        """
        if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
            raise DatabaseError(
                "DatabaseWrapper objects created in a "
                "thread can only be used in that same thread. The object "
                "with alias '%s' was created in thread id %s and this is "
                "thread id %s." % (self.alias, self._thread_ident, _thread.get_ident())
            )

    # ##### Miscellaneous #####

    def prepare_database(self):
        """
        Hook to do any database check or preparation, generally called before
        migrating a project or an app.
        """
        pass

    @cached_property
    def wrap_database_errors(self):
        """
        Context manager and decorator that re-throws backend-specific database
        exceptions using Django's common wrappers.
        """
        return DatabaseErrorWrapper(self)

    def chunked_cursor(self):
        """
        Return a cursor that tries to avoid caching in the database (if
        supported by the database), otherwise return a regular cursor.
        """
        return self.cursor()

    def make_debug_cursor(self, cursor):
        """Create a cursor that logs all queries in self.queries_log."""
        return utils.CursorDebugWrapper(cursor, self)

    def make_cursor(self, cursor):
        """Create a cursor without debug logging."""
        return utils.CursorWrapper(cursor, self)

    @contextmanager
    def temporary_connection(self):
        """
        Context manager that ensures that a connection is established, and
        if it opened one, closes it to avoid leaving a dangling connection.
        This is useful for operations outside of the request-response cycle.

        Provide a cursor: with self.temporary_connection() as cursor: ...
        """
        must_close = self.connection is None
        try:
            with self.cursor() as cursor:
                yield cursor
        finally:
            if must_close:
                self.close()

    @contextmanager
    def _nodb_cursor(self):
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.
        """
        conn = self.__class__({**self.settings_dict, "NAME": None}, alias=NO_DB_ALIAS)
        try:
            with conn.cursor() as cursor:
                yield cursor
        finally:
            conn.close()

    def schema_editor(self, *args, **kwargs):
        """
        Return a new instance of this backend's SchemaEditor.
        """
        if self.SchemaEditorClass is None:
            raise NotImplementedError(
                "The SchemaEditorClass attribute of this database wrapper is still None"
            )
        return self.SchemaEditorClass(self, *args, **kwargs)

    def on_commit(self, func, robust=False):
        if not callable(func):
            raise TypeError("on_commit()'s callback must be a callable.")
        if self.in_atomic_block:
            # Transaction in progress; save for execution on commit.
            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
        elif not self.get_autocommit():
            raise TransactionManagementError(
                "on_commit() cannot be used in manual transaction management"
            )
        else:
            # No transaction in progress and in autocommit mode; execute
            # immediately.
            if robust:
                try:
                    func()
                except Exception as e:
                    logger.error(
                        f"Error calling {func.__qualname__} in on_commit() (%s).",
                        e,
                        exc_info=True,
                    )
            else:
                func()
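
    # Usage sketch (illustrative): callbacks registered inside an atomic
    # block only run if the outermost block commits; robust=True logs and
    # swallows callback errors instead of propagating them.
    # (send_welcome_email is a hypothetical callable.)
    #
    #     from django.db import transaction
    #
    #     with transaction.atomic():
    #         transaction.on_commit(send_welcome_email, robust=True)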

    def run_and_clear_commit_hooks(self):
        self.validate_no_atomic_block()
        current_run_on_commit = self.run_on_commit
        self.run_on_commit = []
        while current_run_on_commit:
            _, func, robust = current_run_on_commit.pop(0)
            if robust:
                try:
                    func()
                except Exception as e:
                    logger.error(
                        f"Error calling {func.__qualname__} in on_commit() during "
                        f"transaction (%s).",
                        e,
                        exc_info=True,
                    )
            else:
                func()

    @contextmanager
    def execute_wrapper(self, wrapper):
        """
        Return a context manager under which the wrapper is applied to suitable
        database query executions.
        """
        self.execute_wrappers.append(wrapper)
        try:
            yield
        finally:
            self.execute_wrappers.pop()
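
    # Usage sketch (illustrative): a wrapper takes (execute, sql, params,
    # many, context) and must call execute itself:
    #
    #     def blocker(execute, sql, params, many, context):
    #         if sql.lstrip().upper().startswith("DELETE"):
    #             raise RuntimeError("DELETE statements are blocked here.")
    #         return execute(sql, params, many, context)
    #
    #     from django.db import connection
    #
    #     with connection.execute_wrapper(blocker):
    #         ...  # every query in this block passes through blocker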

    def copy(self, alias=None):
        """
        Return a copy of this connection.

        For tests that require two connections to the same database.
        """
        settings_dict = copy.deepcopy(self.settings_dict)
        if alias is None:
            alias = self.alias
        return type(self)(settings_dict, alias)

@@ -0,0 +1,28 @@
import os
import subprocess


class BaseDatabaseClient:
    """Encapsulate backend-specific methods for opening a client shell."""

    # This should be a string representing the name of the executable
    # (e.g., "psql"). Subclasses must override this.
    executable_name = None

    def __init__(self, connection):
        # connection is an instance of BaseDatabaseWrapper.
        self.connection = connection

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        raise NotImplementedError(
|             "subclasses of BaseDatabaseClient must provide a " | ||||
|             "settings_to_cmd_args_env() method or override a runshell()." | ||||
        )

    def runshell(self, parameters):
        args, env = self.settings_to_cmd_args_env(
            self.connection.settings_dict, parameters
        )
        env = {**os.environ, **env} if env else None
        subprocess.run(args, env=env, check=True)
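

# Subclass sketch (illustrative, for a hypothetical backend): build the argv
# list for the shell binary and pass secrets through the environment. The
# "examplesql" executable and EXAMPLESQL_PASSWORD variable are invented names.
#
#     class DatabaseClient(BaseDatabaseClient):
#         executable_name = "examplesql"
#
#         @classmethod
#         def settings_to_cmd_args_env(cls, settings_dict, parameters):
#             args = [cls.executable_name, settings_dict["NAME"], *parameters]
#             env = {"EXAMPLESQL_PASSWORD": settings_dict["PASSWORD"]}
#             return args, env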

@@ -0,0 +1,381 @@
import os
import sys
from io import StringIO

from django.apps import apps
from django.conf import settings
from django.core import serializers
from django.db import router
from django.db.transaction import atomic
from django.utils.module_loading import import_string

# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = "test_"


class BaseDatabaseCreation:
    """
    Encapsulate backend-specific differences pertaining to creation and
    destruction of the test database.
    """

    def __init__(self, connection):
        self.connection = connection

    def _nodb_cursor(self):
        return self.connection._nodb_cursor()

    def log(self, msg):
        sys.stderr.write(msg + os.linesep)

    def create_test_db(
        self, verbosity=1, autoclobber=False, serialize=True, keepdb=False
    ):
        """
        Create a test database, prompting the user for confirmation if the
        database already exists. Return the name of the test database created.
        """
        # Don't import django.core.management if it isn't needed.
        from django.core.management import call_command

        test_database_name = self._get_test_db_name()

        if verbosity >= 1:
            action = "Creating"
            if keepdb:
                action = "Using existing"

            self.log(
                "%s test database for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, test_database_name),
                )
            )

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. This is to handle the case
        # where the test DB doesn't exist, in which case we need to
        # create it, then just not destroy it. If we instead skip
        # this, we will get an exception.
        self._create_test_db(verbosity, autoclobber, keepdb)

        self.connection.close()
        settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
        self.connection.settings_dict["NAME"] = test_database_name

        try:
            if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
                # Disable migrations for all apps.
                old_migration_modules = settings.MIGRATION_MODULES
                settings.MIGRATION_MODULES = {
                    app.label: None for app in apps.get_app_configs()
                }
            # We report migrate messages at one level lower than that
            # requested. This ensures we don't get flooded with messages during
            # testing (unless you really ask to be flooded).
            call_command(
                "migrate",
                verbosity=max(verbosity - 1, 0),
                interactive=False,
                database=self.connection.alias,
                run_syncdb=True,
            )
        finally:
            if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
                settings.MIGRATION_MODULES = old_migration_modules

        # We then serialize the current state of the database into a string
        # and store it on the connection. This slightly horrific process is so people
        # who are testing on databases without transactions or who are using
        # a TransactionTestCase still get a clean database on every test run.
        if serialize:
            self.connection._test_serialized_contents = self.serialize_db_to_string()

        call_command("createcachetable", database=self.connection.alias)

        # Ensure a connection for the side effect of initializing the test database.
        self.connection.ensure_connection()

        if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
            self.mark_expected_failures_and_skips()

        return test_database_name

    def set_as_test_mirror(self, primary_settings_dict):
        """
        Set this database up to be used in testing as a mirror of a primary
        database whose settings are given.
        """
        self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]

    def serialize_db_to_string(self):
        """
        Serialize all data in the database into a JSON string.
        Designed only for test runner usage; will not handle large
        amounts of data.
        """

        # Iteratively return every object for all models to serialize.
        def get_objects():
            from django.db.migrations.loader import MigrationLoader

            loader = MigrationLoader(self.connection)
            for app_config in apps.get_app_configs():
                if (
                    app_config.models_module is not None
                    and app_config.label in loader.migrated_apps
                    and app_config.name not in settings.TEST_NON_SERIALIZED_APPS
                ):
                    for model in app_config.get_models():
                        if model._meta.can_migrate(
                            self.connection
                        ) and router.allow_migrate_model(self.connection.alias, model):
                            queryset = model._base_manager.using(
                                self.connection.alias,
                            ).order_by(model._meta.pk.name)
                            yield from queryset.iterator()

        # Serialize to a string
        out = StringIO()
        serializers.serialize("json", get_objects(), indent=None, stream=out)
        return out.getvalue()

    def deserialize_db_from_string(self, data):
        """
        Reload the database with data from a string generated by
        the serialize_db_to_string() method.
        """
        data = StringIO(data)
        table_names = set()
        # Load data in a transaction to handle forward references and cycles.
        with atomic(using=self.connection.alias):
            # Disable constraint checks, because some databases (MySQL) don't
            # support deferred checks.
            with self.connection.constraint_checks_disabled():
                for obj in serializers.deserialize(
                    "json", data, using=self.connection.alias
                ):
                    obj.save()
                    table_names.add(obj.object.__class__._meta.db_table)
            # Manually check for any invalid keys that might have been added,
            # because constraint checks were disabled.
            self.connection.check_constraints(table_names=table_names)
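
    # Roundtrip sketch (illustrative): the string produced by
    # serialize_db_to_string() can be fed straight back in to restore state,
    # which is how a clean database is rebuilt between test runs:
    #
    #     data = connection.creation.serialize_db_to_string()
    #     ...  # tests flush the database
    #     connection.creation.deserialize_db_from_string(data)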

    def _get_database_display_str(self, verbosity, database_name):
        """
        Return display string for a database for use in various actions.
        """
        return "'%s'%s" % (
            self.connection.alias,
            (" ('%s')" % database_name) if verbosity >= 2 else "",
        )

    def _get_test_db_name(self):
        """
        Internal implementation - return the name of the test DB that will be
        created. Only useful when called from create_test_db() and
        _create_test_db() and when no external munging is done with the 'NAME'
        settings.
        """
        if self.connection.settings_dict["TEST"]["NAME"]:
            return self.connection.settings_dict["TEST"]["NAME"]
        return TEST_DATABASE_PREFIX + self.connection.settings_dict["NAME"]

    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
        cursor.execute("CREATE DATABASE %(dbname)s %(suffix)s" % parameters)

    def _create_test_db(self, verbosity, autoclobber, keepdb=False):
        """
        Internal implementation - create the test db tables.
        """
        test_database_name = self._get_test_db_name()
        test_db_params = {
            "dbname": self.connection.ops.quote_name(test_database_name),
            "suffix": self.sql_table_creation_suffix(),
        }
        # Create the test database and connect to it.
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params, keepdb)
            except Exception as e:
                # if we want to keep the db, then no need to do any of the below,
                # just return and skip it all.
                if keepdb:
                    return test_database_name

                self.log("Got an error creating the test database: %s" % e)
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        "database '%s', or 'no' to cancel: " % test_database_name
                    )
                if autoclobber or confirm == "yes":
                    try:
                        if verbosity >= 1:
                            self.log(
                                "Destroying old test database for alias %s..."
                                % (
                                    self._get_database_display_str(
                                        verbosity, test_database_name
                                    ),
                                )
                            )
                        cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
                        self._execute_create_test_db(cursor, test_db_params, keepdb)
                    except Exception as e:
                        self.log("Got an error recreating the test database: %s" % e)
                        sys.exit(2)
                else:
                    self.log("Tests cancelled.")
                    sys.exit(1)

        return test_database_name

    def clone_test_db(self, suffix, verbosity=1, autoclobber=False, keepdb=False):
        """
        Clone a test database.
        """
        source_database_name = self.connection.settings_dict["NAME"]

        if verbosity >= 1:
            action = "Cloning test database"
            if keepdb:
                action = "Using existing clone"
            self.log(
                "%s for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, source_database_name),
                )
            )

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. See create_test_db for details.
        self._clone_test_db(suffix, verbosity, keepdb)

    def get_test_db_clone_settings(self, suffix):
        """
        Return a modified connection settings dict for the n-th clone of a DB.
        """
        # When this function is called, the test database has been created
        # already and its name has been copied to settings_dict['NAME'] so
        # we don't need to call _get_test_db_name.
        orig_settings_dict = self.connection.settings_dict
        return {
            **orig_settings_dict,
            "NAME": "{}_{}".format(orig_settings_dict["NAME"], suffix),
        }
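
    # Example (illustrative): with settings_dict["NAME"] == "test_app" and
    # suffix "1", this returns a dict whose NAME is "test_app_1".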

    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        """
        Internal implementation - duplicate the test db tables.
        """
        raise NotImplementedError(
            "The database backend doesn't support cloning databases. "
            "Disable the option to run tests in parallel processes."
        )

    def destroy_test_db(
        self, old_database_name=None, verbosity=1, keepdb=False, suffix=None
    ):
        """
        Destroy a test database, prompting the user for confirmation if the
        database already exists.
        """
        self.connection.close()
        if suffix is None:
            test_database_name = self.connection.settings_dict["NAME"]
        else:
            test_database_name = self.get_test_db_clone_settings(suffix)["NAME"]

        if verbosity >= 1:
            action = "Destroying"
            if keepdb:
                action = "Preserving"
            self.log(
                "%s test database for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, test_database_name),
                )
            )

        # if we want to preserve the database
        # skip the actual destroying piece.
        if not keepdb:
            self._destroy_test_db(test_database_name, verbosity)

        # Restore the original database name
        if old_database_name is not None:
            settings.DATABASES[self.connection.alias]["NAME"] = old_database_name
            self.connection.settings_dict["NAME"] = old_database_name

    def _destroy_test_db(self, test_database_name, verbosity):
        """
        Internal implementation - remove the test db tables.
        """
        # Remove the test database to clean up after
        # ourselves. Connect to the previous database (not the test database)
        # to do so, because it's not allowed to delete a database while being
        # connected to it.
        with self._nodb_cursor() as cursor:
            cursor.execute(
                "DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name)
            )

    def mark_expected_failures_and_skips(self):
        """
        Mark tests in Django's test suite which are expected failures on this
        database and tests which should be skipped on this database.
|         """ | ||||
|         # Only load unittest if we're actually testing. | ||||
|         from unittest import expectedFailure, skip | ||||
|  | ||||
|         for test_name in self.connection.features.django_test_expected_failures: | ||||
|             test_case_name, _, test_method_name = test_name.rpartition(".") | ||||
|             test_app = test_name.split(".")[0] | ||||
|             # Importing a test app that isn't installed raises RuntimeError. | ||||
|             if test_app in settings.INSTALLED_APPS: | ||||
|                 test_case = import_string(test_case_name) | ||||
|                 test_method = getattr(test_case, test_method_name) | ||||
|                 setattr(test_case, test_method_name, expectedFailure(test_method)) | ||||
|         for reason, tests in self.connection.features.django_test_skips.items(): | ||||
|             for test_name in tests: | ||||
|                 test_case_name, _, test_method_name = test_name.rpartition(".") | ||||
|                 test_app = test_name.split(".")[0] | ||||
|                 # Importing a test app that isn't installed raises RuntimeError. | ||||
|                 if test_app in settings.INSTALLED_APPS: | ||||
|                     test_case = import_string(test_case_name) | ||||
|                     test_method = getattr(test_case, test_method_name) | ||||
|                     setattr(test_case, test_method_name, skip(reason)(test_method)) | ||||
|  | ||||
|     def sql_table_creation_suffix(self): | ||||
|         """ | ||||
|         SQL to append to the end of the test table creation statements. | ||||
|         """ | ||||
|         return "" | ||||
|  | ||||
|     def test_db_signature(self): | ||||
|         """ | ||||
        Return a tuple of elements of self.connection.settings_dict (a
        DATABASES setting value) that uniquely identify a database,
        according to the particularities of the RDBMS.
|         """ | ||||
|         settings_dict = self.connection.settings_dict | ||||
|         return ( | ||||
|             settings_dict["HOST"], | ||||
|             settings_dict["PORT"], | ||||
|             settings_dict["ENGINE"], | ||||
|             self._get_test_db_name(), | ||||
|         ) | ||||
|  | ||||
|     def setup_worker_connection(self, _worker_id): | ||||
|         settings_dict = self.get_test_db_clone_settings(str(_worker_id)) | ||||
|         # connection.settings_dict must be updated in place for changes to be | ||||
|         # reflected in django.db.connections. If the following line assigned | ||||
|         # connection.settings_dict = settings_dict, new threads would connect | ||||
|         # to the default database instead of the appropriate clone. | ||||
|         self.connection.settings_dict.update(settings_dict) | ||||
|         self.connection.close() | ||||

@@ -0,0 +1,393 @@
| from django.db import ProgrammingError | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
|  | ||||
| class BaseDatabaseFeatures: | ||||
|     # An optional tuple indicating the minimum supported database version. | ||||
|     minimum_database_version = None | ||||
|     gis_enabled = False | ||||
|     # Oracle can't group by LOB (large object) data types. | ||||
|     allows_group_by_lob = True | ||||
|     allows_group_by_selected_pks = False | ||||
|     allows_group_by_select_index = True | ||||
|     empty_fetchmany_value = [] | ||||
|     update_can_self_select = True | ||||
|  | ||||
|     # Does the backend distinguish between '' and None? | ||||
|     interprets_empty_strings_as_nulls = False | ||||
|  | ||||
|     # Does the backend allow inserting duplicate NULL rows in a nullable | ||||
|     # unique field? All core backends implement this correctly, but other | ||||
|     # databases such as SQL Server do not. | ||||
|     supports_nullable_unique_constraints = True | ||||
|  | ||||
|     # Does the backend allow inserting duplicate rows when a unique_together | ||||
|     # constraint exists and some fields are nullable but not all of them? | ||||
|     supports_partially_nullable_unique_constraints = True | ||||
|     # Does the backend support initially deferrable unique constraints? | ||||
|     supports_deferrable_unique_constraints = False | ||||
|  | ||||
|     can_use_chunked_reads = True | ||||
|     can_return_columns_from_insert = False | ||||
|     can_return_rows_from_bulk_insert = False | ||||
|     has_bulk_insert = True | ||||
|     uses_savepoints = True | ||||
|     can_release_savepoints = False | ||||
|  | ||||
|     # If True, don't use integer foreign keys referring to, e.g., positive | ||||
|     # integer primary keys. | ||||
|     related_fields_match_type = False | ||||
|     allow_sliced_subqueries_with_in = True | ||||
|     has_select_for_update = False | ||||
|     has_select_for_update_nowait = False | ||||
|     has_select_for_update_skip_locked = False | ||||
|     has_select_for_update_of = False | ||||
|     has_select_for_no_key_update = False | ||||
|     # Does the database's SELECT FOR UPDATE OF syntax require a column rather | ||||
|     # than a table? | ||||
|     select_for_update_of_column = False | ||||
|  | ||||
|     # Does the default test database allow multiple connections? | ||||
|     # If not, that's usually an indication that the test database is | ||||
|     # in-memory. | ||||
|     test_db_allows_multiple_connections = True | ||||
|  | ||||
|     # Can an object be saved without an explicit primary key? | ||||
|     supports_unspecified_pk = False | ||||
|  | ||||
|     # Can a fixture contain forward references? i.e., are | ||||
|     # FK constraints checked at the end of transaction, or | ||||
|     # at the end of each save operation? | ||||
|     supports_forward_references = True | ||||
|  | ||||
|     # Does the backend truncate names properly when they are too long? | ||||
|     truncates_names = False | ||||
|  | ||||
|     # Is there a REAL datatype in addition to floats/doubles? | ||||
|     has_real_datatype = False | ||||
|     supports_subqueries_in_group_by = True | ||||
|  | ||||
|     # Does the backend ignore unnecessary ORDER BY clauses in subqueries? | ||||
|     ignores_unnecessary_order_by_in_subqueries = True | ||||
|  | ||||
|     # Is there a true datatype for uuid? | ||||
|     has_native_uuid_field = False | ||||
|  | ||||
|     # Is there a true datatype for timedeltas? | ||||
|     has_native_duration_field = False | ||||
|  | ||||
|     # Does the database driver support same-type temporal data subtraction | ||||
|     # by returning the type used to store a duration field? | ||||
|     supports_temporal_subtraction = False | ||||
|  | ||||
|     # Does the __regex lookup support backreferencing and grouping? | ||||
|     supports_regex_backreferencing = True | ||||
|  | ||||
|     # Can date/datetime lookups be performed using a string? | ||||
|     supports_date_lookup_using_string = True | ||||
|  | ||||
|     # Can datetimes with timezones be used? | ||||
|     supports_timezones = True | ||||
|  | ||||
|     # Does the database have a copy of the zoneinfo database? | ||||
|     has_zoneinfo_database = True | ||||
|  | ||||
|     # When performing a GROUP BY, is an ORDER BY NULL required | ||||
|     # to remove any ordering? | ||||
|     requires_explicit_null_ordering_when_grouping = False | ||||
|  | ||||
|     # Does the backend order NULL values as largest or smallest? | ||||
|     nulls_order_largest = False | ||||
|  | ||||
|     # Does the backend support NULLS FIRST and NULLS LAST in ORDER BY? | ||||
|     supports_order_by_nulls_modifier = True | ||||
|  | ||||
|     # Does the backend order NULLS FIRST by default? | ||||
|     order_by_nulls_first = False | ||||
|  | ||||
|     # The database's limit on the number of query parameters. | ||||
|     max_query_params = None | ||||
|  | ||||
|     # Can an object have an autoincrement primary key of 0? | ||||
|     allows_auto_pk_0 = True | ||||
|  | ||||
|     # Do we need to NULL a ForeignKey out, or can the constraint check be | ||||
|     # deferred? | ||||
|     can_defer_constraint_checks = False | ||||
|  | ||||
|     # Does the backend support tablespaces? Default to False because it isn't | ||||
|     # in the SQL standard. | ||||
|     supports_tablespaces = False | ||||
|  | ||||
|     # Does the backend reset sequences between tests? | ||||
|     supports_sequence_reset = True | ||||
|  | ||||
|     # Can the backend introspect the default value of a column? | ||||
|     can_introspect_default = True | ||||
|  | ||||
|     # Can the backend introspect foreign keys? Every database can do this | ||||
|     # reliably, except MySQL, which can't do it for MyISAM tables. | ||||
|     can_introspect_foreign_keys = True | ||||
|  | ||||
|     # Map fields that some backends may not be able to differentiate to the | ||||
|     # field type they're introspected as. | ||||
|     introspected_field_types = { | ||||
|         "AutoField": "AutoField", | ||||
|         "BigAutoField": "BigAutoField", | ||||
|         "BigIntegerField": "BigIntegerField", | ||||
|         "BinaryField": "BinaryField", | ||||
|         "BooleanField": "BooleanField", | ||||
|         "CharField": "CharField", | ||||
|         "DurationField": "DurationField", | ||||
|         "GenericIPAddressField": "GenericIPAddressField", | ||||
|         "IntegerField": "IntegerField", | ||||
|         "PositiveBigIntegerField": "PositiveBigIntegerField", | ||||
|         "PositiveIntegerField": "PositiveIntegerField", | ||||
|         "PositiveSmallIntegerField": "PositiveSmallIntegerField", | ||||
|         "SmallAutoField": "SmallAutoField", | ||||
|         "SmallIntegerField": "SmallIntegerField", | ||||
|         "TimeField": "TimeField", | ||||
|     } | ||||
|  | ||||
|     # Can the backend introspect the column order (ASC/DESC) for indexes? | ||||
|     supports_index_column_ordering = True | ||||
|  | ||||
|     # Does the backend support introspection of materialized views? | ||||
|     can_introspect_materialized_views = False | ||||
|  | ||||
|     # Support for the DISTINCT ON clause | ||||
|     can_distinct_on_fields = False | ||||
|  | ||||
|     # Does the backend prevent running SQL queries in broken transactions? | ||||
|     atomic_transactions = True | ||||
|  | ||||
|     # Can we roll back DDL in a transaction? | ||||
|     can_rollback_ddl = False | ||||
|  | ||||
|     schema_editor_uses_clientside_param_binding = False | ||||
|  | ||||
|     # Does it support operations that require renaming references in a | ||||
|     # transaction? | ||||
|     supports_atomic_references_rename = True | ||||
|  | ||||
|     # Can we issue more than one ALTER COLUMN clause in an ALTER TABLE? | ||||
|     supports_combined_alters = False | ||||
|  | ||||
|     # Does it support foreign keys? | ||||
|     supports_foreign_keys = True | ||||
|  | ||||
|     # Can it create foreign key constraints inline when adding columns? | ||||
|     can_create_inline_fk = True | ||||
|  | ||||
|     # Can an index be renamed? | ||||
|     can_rename_index = False | ||||
|  | ||||
|     # Does it automatically index foreign keys? | ||||
|     indexes_foreign_keys = True | ||||
|  | ||||
|     # Does it support CHECK constraints? | ||||
|     supports_column_check_constraints = True | ||||
|     supports_table_check_constraints = True | ||||
|     # Does the backend support introspection of CHECK constraints? | ||||
|     can_introspect_check_constraints = True | ||||
|  | ||||
|     # Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value}) | ||||
|     # parameter passing? Note this can be provided by the backend even if not | ||||
|     # supported by the Python driver. | ||||
|     supports_paramstyle_pyformat = True | ||||
|  | ||||
|     # Does the backend require literal defaults, rather than parameterized ones? | ||||
|     requires_literal_defaults = False | ||||
|  | ||||
|     # Does the backend require a connection reset after each material schema change? | ||||
|     connection_persists_old_columns = False | ||||
|  | ||||
|     # What kind of error does the backend throw when accessing a closed cursor? | ||||
|     closed_cursor_error_class = ProgrammingError | ||||
|  | ||||
|     # Does 'a' LIKE 'A' match? | ||||
|     has_case_insensitive_like = False | ||||
|  | ||||
|     # Suffix for backends that don't support "SELECT xxx;" queries. | ||||
|     bare_select_suffix = "" | ||||
|  | ||||
|     # If NULL is implied on columns without needing to be explicitly specified | ||||
|     implied_column_null = False | ||||
|  | ||||
|     # Does the backend support "select for update" queries with limit (and offset)? | ||||
|     supports_select_for_update_with_limit = True | ||||
|  | ||||
|     # Does the backend ignore null expressions in GREATEST and LEAST queries unless | ||||
|     # every expression is null? | ||||
|     greatest_least_ignores_nulls = False | ||||
|  | ||||
|     # Can the backend clone databases for parallel test execution? | ||||
|     # Defaults to False to allow third-party backends to opt-in. | ||||
|     can_clone_databases = False | ||||
|  | ||||
|     # Does the backend consider table names with different casing to | ||||
|     # be equal? | ||||
|     ignores_table_name_case = False | ||||
|  | ||||
|     # Place FOR UPDATE right after the FROM clause. Used on MSSQL. | ||||
|     for_update_after_from = False | ||||
|  | ||||
|     # Combinatorial flags | ||||
|     supports_select_union = True | ||||
|     supports_select_intersection = True | ||||
|     supports_select_difference = True | ||||
|     supports_slicing_ordering_in_compound = False | ||||
|     supports_parentheses_in_compound = True | ||||
|     requires_compound_order_by_subquery = False | ||||
|  | ||||
|     # Does the database support SQL 2003 FILTER (WHERE ...) in aggregate | ||||
|     # expressions? | ||||
|     supports_aggregate_filter_clause = False | ||||
|  | ||||
|     # Does the backend support indexing a TextField? | ||||
|     supports_index_on_text_field = True | ||||
|  | ||||
|     # Does the backend support window expressions (expression OVER (...))? | ||||
|     supports_over_clause = False | ||||
|     supports_frame_range_fixed_distance = False | ||||
|     only_supports_unbounded_with_preceding_and_following = False | ||||
|  | ||||
|     # Does the backend support CAST with precision? | ||||
|     supports_cast_with_precision = True | ||||
|  | ||||
|     # How many fractional-second digits does the database return when casting | ||||
|     # a value to a type that includes a time component? | ||||
|     time_cast_precision = 6 | ||||
|  | ||||
|     # SQL to create a procedure for use by the Django test suite. The | ||||
|     # functionality of the procedure isn't important. | ||||
|     create_test_procedure_without_params_sql = None | ||||
|     create_test_procedure_with_int_param_sql = None | ||||
|  | ||||
|     # SQL to create a table with a composite primary key for use by the Django | ||||
|     # test suite. | ||||
|     create_test_table_with_composite_primary_key = None | ||||
|  | ||||
|     # Does the backend support keyword parameters for cursor.callproc()? | ||||
|     supports_callproc_kwargs = False | ||||
|  | ||||
|     # What formats does the backend EXPLAIN syntax support? | ||||
|     supported_explain_formats = set() | ||||
|  | ||||
|     # Does the backend support the default parameter in lead() and lag()? | ||||
|     supports_default_in_lead_lag = True | ||||
|  | ||||
|     # Does the backend support ignoring constraint or uniqueness errors during | ||||
|     # INSERT? | ||||
|     supports_ignore_conflicts = True | ||||
|     # Does the backend support updating rows on constraint or uniqueness errors | ||||
|     # during INSERT? | ||||
|     supports_update_conflicts = False | ||||
|     supports_update_conflicts_with_target = False | ||||
|  | ||||
|     # Does this backend require casting the results of CASE expressions used | ||||
|     # in UPDATE statements to ensure the expression has the correct type? | ||||
|     requires_casted_case_in_updates = False | ||||
|  | ||||
|     # Does the backend support partial indexes (CREATE INDEX ... WHERE ...)? | ||||
|     supports_partial_indexes = True | ||||
|     supports_functions_in_partial_indexes = True | ||||
|     # Does the backend support covering indexes (CREATE INDEX ... INCLUDE ...)? | ||||
|     supports_covering_indexes = False | ||||
|     # Does the backend support indexes on expressions? | ||||
|     supports_expression_indexes = True | ||||
|     # Does the backend treat COLLATE as an indexed expression? | ||||
|     collate_as_index_expression = False | ||||
|  | ||||
|     # Does the database allow more than one constraint or index on the same | ||||
|     # field(s)? | ||||
|     allows_multiple_constraints_on_same_fields = True | ||||
|  | ||||
|     # Does the backend support boolean expressions in SELECT and GROUP BY | ||||
|     # clauses? | ||||
|     supports_boolean_expr_in_select_clause = True | ||||
|     # Does the backend support comparing boolean expressions in WHERE clauses? | ||||
|     # Eg: WHERE (price > 0) IS NOT NULL | ||||
|     supports_comparing_boolean_expr = True | ||||
|  | ||||
|     # Does the backend support JSONField? | ||||
|     supports_json_field = True | ||||
|     # Can the backend introspect a JSONField? | ||||
|     can_introspect_json_field = True | ||||
|     # Does the backend support primitives in JSONField? | ||||
|     supports_primitives_in_json_field = True | ||||
|     # Is there a true datatype for JSON? | ||||
|     has_native_json_field = False | ||||
|     # Does the backend use PostgreSQL-style JSON operators like '->'? | ||||
|     has_json_operators = False | ||||
|     # Does the backend support __contains and __contained_by lookups for | ||||
|     # a JSONField? | ||||
|     supports_json_field_contains = True | ||||
|     # Does value__d__contains={'f': 'g'} (without a list around the dict) match | ||||
|     # {'d': [{'f': 'g'}]}? | ||||
|     json_key_contains_list_matching_requires_list = False | ||||
|     # Does the backend support JSONObject() database function? | ||||
|     has_json_object_function = True | ||||
|  | ||||
|     # Does the backend support column collations? | ||||
|     supports_collation_on_charfield = True | ||||
|     supports_collation_on_textfield = True | ||||
|     # Does the backend support non-deterministic collations? | ||||
|     supports_non_deterministic_collations = True | ||||
|  | ||||
|     # Does the backend support column and table comments? | ||||
|     supports_comments = False | ||||
|     # Does the backend support column comments in ADD COLUMN statements? | ||||
|     supports_comments_inline = False | ||||
|  | ||||
|     # Does the backend support the logical XOR operator? | ||||
|     supports_logical_xor = False | ||||
|  | ||||
|     # Set to (exception, message) if null characters in text are disallowed. | ||||
|     prohibits_null_characters_in_text_exception = None | ||||
|  | ||||
|     # Does the backend support unlimited character columns? | ||||
|     supports_unlimited_charfield = False | ||||
|  | ||||
|     # Collation names for use by the Django test suite. | ||||
|     test_collations = { | ||||
|         "ci": None,  # Case-insensitive. | ||||
|         "cs": None,  # Case-sensitive. | ||||
|         "non_default": None,  # Non-default. | ||||
|         "swedish_ci": None,  # Swedish case-insensitive. | ||||
|     } | ||||
|     # SQL template override for tests.aggregation.tests.NowUTC | ||||
|     test_now_utc_template = None | ||||
|  | ||||
|     # A set of dotted paths to tests in Django's test suite that are expected | ||||
|     # to fail on this database. | ||||
|     django_test_expected_failures = set() | ||||
|     # A map of reasons to sets of dotted paths to tests in Django's test suite | ||||
|     # that should be skipped for this database. | ||||
|     django_test_skips = {} | ||||
|  | ||||
|     def __init__(self, connection): | ||||
|         self.connection = connection | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_explaining_query_execution(self): | ||||
|         """Does this backend support explaining query execution?""" | ||||
|         return self.connection.ops.explain_prefix is not None | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_transactions(self): | ||||
|         """Confirm support for transactions.""" | ||||
|         with self.connection.cursor() as cursor: | ||||
|             cursor.execute("CREATE TABLE ROLLBACK_TEST (X INT)") | ||||
|             self.connection.set_autocommit(False) | ||||
|             cursor.execute("INSERT INTO ROLLBACK_TEST (X) VALUES (8)") | ||||
|             self.connection.rollback() | ||||
|             self.connection.set_autocommit(True) | ||||
|             cursor.execute("SELECT COUNT(X) FROM ROLLBACK_TEST") | ||||
|             (count,) = cursor.fetchone() | ||||
|             cursor.execute("DROP TABLE ROLLBACK_TEST") | ||||
|         return count == 0 | ||||
|  | ||||
|     def allows_group_by_selected_pks_on_model(self, model): | ||||
|         if not self.allows_group_by_selected_pks: | ||||
|             return False | ||||
|         return model._meta.managed | ||||
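|  | ||||
|  | ||||
| # A minimal sketch (not part of Django): third-party backends typically | ||||
| # subclass BaseDatabaseFeatures and override only the flags that differ from | ||||
| # the defaults above. The backend and its capabilities here are hypothetical. | ||||
| class HypotheticalBackendFeatures(BaseDatabaseFeatures): | ||||
|     # This hypothetical backend lacks window functions, can't defer FK | ||||
|     # checks, and caps the number of parameters per statement. | ||||
|     supports_over_clause = False | ||||
|     can_defer_constraint_checks = False | ||||
|     max_query_params = 999 | ||||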
| @ -0,0 +1,212 @@ | ||||
| from collections import namedtuple | ||||
|  | ||||
| # Structure returned by DatabaseIntrospection.get_table_list() | ||||
| TableInfo = namedtuple("TableInfo", ["name", "type"]) | ||||
|  | ||||
| # Structure returned by the DB-API cursor.description interface (PEP 249) | ||||
| FieldInfo = namedtuple( | ||||
|     "FieldInfo", | ||||
|     "name type_code display_size internal_size precision scale null_ok " | ||||
|     "default collation", | ||||
| ) | ||||
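|  | ||||
| # Illustrative example (hypothetical values): a backend's | ||||
| # get_table_description() yields one FieldInfo per column, e.g. for a | ||||
| # nullable VARCHAR(50) column: | ||||
| #     FieldInfo( | ||||
| #         name="title", type_code=253, display_size=None, internal_size=50, | ||||
| #         precision=None, scale=None, null_ok=True, default=None, | ||||
| #         collation=None, | ||||
| #     ) | ||||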
|  | ||||
|  | ||||
| class BaseDatabaseIntrospection: | ||||
|     """Encapsulate backend-specific introspection utilities.""" | ||||
|  | ||||
|     data_types_reverse = {} | ||||
|  | ||||
|     def __init__(self, connection): | ||||
|         self.connection = connection | ||||
|  | ||||
|     def get_field_type(self, data_type, description): | ||||
|         """ | ||||
|         Hook for a database backend to use the cursor description to | ||||
|         match a Django field type to a database column. | ||||
|  | ||||
|         For Oracle, the column data_type on its own is insufficient to | ||||
|         distinguish between a FloatField and IntegerField, for example. | ||||
|         """ | ||||
|         return self.data_types_reverse[data_type] | ||||
|  | ||||
|     def identifier_converter(self, name): | ||||
|         """ | ||||
|         Apply a conversion to the identifier for the purposes of comparison. | ||||
|  | ||||
|         The default identifier converter is for case-sensitive comparison. | ||||
|         """ | ||||
|         return name | ||||
|  | ||||
|     def table_names(self, cursor=None, include_views=False): | ||||
|         """ | ||||
|         Return a list of names of all tables that exist in the database. | ||||
|         Sort the returned table list by Python's default sorting. Do NOT use | ||||
|         the database's ORDER BY here to avoid subtle differences in sorting | ||||
|         order between databases. | ||||
|         """ | ||||
|  | ||||
|         def get_names(cursor): | ||||
|             return sorted( | ||||
|                 ti.name | ||||
|                 for ti in self.get_table_list(cursor) | ||||
|                 if include_views or ti.type == "t" | ||||
|             ) | ||||
|  | ||||
|         if cursor is None: | ||||
|             with self.connection.cursor() as cursor: | ||||
|                 return get_names(cursor) | ||||
|         return get_names(cursor) | ||||
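|  | ||||
|     # Usage sketch (assumes an open connection object named `connection`): | ||||
|     #     connection.introspection.table_names(include_views=True) | ||||
|     # returns all table and view names sorted by Python's default sorting. | ||||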
|  | ||||
|     def get_table_list(self, cursor): | ||||
|         """ | ||||
|         Return an unsorted list of TableInfo named tuples of all tables and | ||||
|         views that exist in the database. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseIntrospection may require a get_table_list() " | ||||
|             "method" | ||||
|         ) | ||||
|  | ||||
|     def get_table_description(self, cursor, table_name): | ||||
|         """ | ||||
|         Return a description of the table with the DB-API cursor.description | ||||
|         interface. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseIntrospection may require a " | ||||
|             "get_table_description() method." | ||||
|         ) | ||||
|  | ||||
|     def get_migratable_models(self): | ||||
|         from django.apps import apps | ||||
|         from django.db import router | ||||
|  | ||||
|         return ( | ||||
|             model | ||||
|             for app_config in apps.get_app_configs() | ||||
|             for model in router.get_migratable_models(app_config, self.connection.alias) | ||||
|             if model._meta.can_migrate(self.connection) | ||||
|         ) | ||||
|  | ||||
|     def django_table_names(self, only_existing=False, include_views=True): | ||||
|         """ | ||||
|         Return a list of all table names that have associated Django models and | ||||
|         are in INSTALLED_APPS. | ||||
|  | ||||
|         If only_existing is True, include only the tables in the database. | ||||
|         """ | ||||
|         tables = set() | ||||
|         for model in self.get_migratable_models(): | ||||
|             if not model._meta.managed: | ||||
|                 continue | ||||
|             tables.add(model._meta.db_table) | ||||
|             tables.update( | ||||
|                 f.m2m_db_table() | ||||
|                 for f in model._meta.local_many_to_many | ||||
|                 if f.remote_field.through._meta.managed | ||||
|             ) | ||||
|         tables = list(tables) | ||||
|         if only_existing: | ||||
|             existing_tables = set(self.table_names(include_views=include_views)) | ||||
|             tables = [ | ||||
|                 t for t in tables if self.identifier_converter(t) in existing_tables | ||||
|             ] | ||||
|         return tables | ||||
|  | ||||
|     def installed_models(self, tables): | ||||
|         """ | ||||
|         Return a set of all models represented by the provided list of table | ||||
|         names. | ||||
|         """ | ||||
|         tables = set(map(self.identifier_converter, tables)) | ||||
|         return { | ||||
|             m | ||||
|             for m in self.get_migratable_models() | ||||
|             if self.identifier_converter(m._meta.db_table) in tables | ||||
|         } | ||||
|  | ||||
|     def sequence_list(self): | ||||
|         """ | ||||
|         Return a list of information about all DB sequences for all models in | ||||
|         all apps. | ||||
|         """ | ||||
|         sequence_list = [] | ||||
|         with self.connection.cursor() as cursor: | ||||
|             for model in self.get_migratable_models(): | ||||
|                 if not model._meta.managed: | ||||
|                     continue | ||||
|                 if model._meta.swapped: | ||||
|                     continue | ||||
|                 sequence_list.extend( | ||||
|                     self.get_sequences( | ||||
|                         cursor, model._meta.db_table, model._meta.local_fields | ||||
|                     ) | ||||
|                 ) | ||||
|                 for f in model._meta.local_many_to_many: | ||||
|                     # If this is an m2m using an intermediate table, | ||||
|                     # we don't need to reset the sequence. | ||||
|                     if f.remote_field.through._meta.auto_created: | ||||
|                         sequence = self.get_sequences(cursor, f.m2m_db_table()) | ||||
|                         sequence_list.extend( | ||||
|                             sequence or [{"table": f.m2m_db_table(), "column": None}] | ||||
|                         ) | ||||
|         return sequence_list | ||||
|  | ||||
|     def get_sequences(self, cursor, table_name, table_fields=()): | ||||
|         """ | ||||
|         Return a list of introspected sequences for table_name. Each sequence | ||||
|         is a dict: {'table': <table_name>, 'column': <column_name>}. An optional | ||||
|         'name' key can be added if the backend supports named sequences. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseIntrospection may require a get_sequences() " | ||||
|             "method" | ||||
|         ) | ||||
|  | ||||
|     def get_relations(self, cursor, table_name): | ||||
|         """ | ||||
|         Return a dictionary of {field_name: (field_name_other_table, other_table)} | ||||
|         representing all foreign keys in the given table. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseIntrospection may require a " | ||||
|             "get_relations() method." | ||||
|         ) | ||||
|  | ||||
|     def get_primary_key_column(self, cursor, table_name): | ||||
|         """ | ||||
|         Return the name of the primary key column for the given table. | ||||
|         """ | ||||
|         columns = self.get_primary_key_columns(cursor, table_name) | ||||
|         return columns[0] if columns else None | ||||
|  | ||||
|     def get_primary_key_columns(self, cursor, table_name): | ||||
|         """Return a list of primary key columns for the given table.""" | ||||
|         for constraint in self.get_constraints(cursor, table_name).values(): | ||||
|             if constraint["primary_key"]: | ||||
|                 return constraint["columns"] | ||||
|         return None | ||||
|  | ||||
|     def get_constraints(self, cursor, table_name): | ||||
|         """ | ||||
|         Retrieve any constraints or keys (unique, pk, fk, check, index) | ||||
|         across one or more columns. | ||||
|  | ||||
|         Return a dict mapping constraint names to their attributes, | ||||
|         where attributes is a dict with keys: | ||||
|          * columns: List of columns this covers | ||||
|          * primary_key: True if primary key, False otherwise | ||||
|          * unique: True if this is a unique constraint, False otherwise | ||||
|          * foreign_key: (table, column) of target, or None | ||||
|          * check: True if check constraint, False otherwise | ||||
|          * index: True if index, False otherwise. | ||||
|          * orders: The order (ASC/DESC) defined for the columns of indexes | ||||
|          * type: The type of the index (btree, hash, etc.) | ||||
|  | ||||
|         Some backends may return special constraint names that don't exist | ||||
|         if they don't name constraints of a certain type (e.g. SQLite). | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseIntrospection may require a get_constraints() " | ||||
|             "method" | ||||
|         ) | ||||
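|  | ||||
| # Illustrative example of the mapping described in get_constraints() | ||||
| # (hypothetical names and values): | ||||
| #     { | ||||
| #         "myapp_book_pkey": { | ||||
| #             "columns": ["id"], "primary_key": True, "unique": True, | ||||
| #             "foreign_key": None, "check": False, "index": False, | ||||
| #             "orders": [], "type": None, | ||||
| #         }, | ||||
| #     } | ||||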
| @ -0,0 +1,778 @@ | ||||
| import datetime | ||||
| import decimal | ||||
| import json | ||||
| from importlib import import_module | ||||
|  | ||||
| import sqlparse | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db import NotSupportedError, transaction | ||||
| from django.db.backends import utils | ||||
| from django.utils import timezone | ||||
| from django.utils.encoding import force_str | ||||
|  | ||||
|  | ||||
| class BaseDatabaseOperations: | ||||
|     """ | ||||
|     Encapsulate backend-specific differences, such as the way a backend | ||||
|     performs ordering or calculates the ID of a recently-inserted row. | ||||
|     """ | ||||
|  | ||||
|     compiler_module = "django.db.models.sql.compiler" | ||||
|  | ||||
|     # Integer field safe ranges by `internal_type` as documented | ||||
|     # in docs/ref/models/fields.txt. | ||||
|     integer_field_ranges = { | ||||
|         "SmallIntegerField": (-32768, 32767), | ||||
|         "IntegerField": (-2147483648, 2147483647), | ||||
|         "BigIntegerField": (-9223372036854775808, 9223372036854775807), | ||||
|         "PositiveBigIntegerField": (0, 9223372036854775807), | ||||
|         "PositiveSmallIntegerField": (0, 32767), | ||||
|         "PositiveIntegerField": (0, 2147483647), | ||||
|         "SmallAutoField": (-32768, 32767), | ||||
|         "AutoField": (-2147483648, 2147483647), | ||||
|         "BigAutoField": (-9223372036854775808, 9223372036854775807), | ||||
|     } | ||||
|     set_operators = { | ||||
|         "union": "UNION", | ||||
|         "intersection": "INTERSECT", | ||||
|         "difference": "EXCEPT", | ||||
|     } | ||||
|     # Mapping of Field.get_internal_type() (typically the model field's class | ||||
|     # name) to the data type to use for the Cast() function, if different from | ||||
|     # DatabaseWrapper.data_types. | ||||
|     cast_data_types = {} | ||||
|     # CharField data type if the max_length argument isn't provided. | ||||
|     cast_char_field_without_max_length = None | ||||
|  | ||||
|     # Start and end points for window expressions. | ||||
|     PRECEDING = "PRECEDING" | ||||
|     FOLLOWING = "FOLLOWING" | ||||
|     UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING | ||||
|     UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING | ||||
|     CURRENT_ROW = "CURRENT ROW" | ||||
|  | ||||
|     # Prefix for EXPLAIN queries, or None if EXPLAIN isn't supported. | ||||
|     explain_prefix = None | ||||
|  | ||||
|     def __init__(self, connection): | ||||
|         self.connection = connection | ||||
|         self._cache = None | ||||
|  | ||||
|     def autoinc_sql(self, table, column): | ||||
|         """ | ||||
|         Return any SQL needed to support auto-incrementing primary keys, or | ||||
|         None if no SQL is necessary. | ||||
|  | ||||
|         This SQL is executed when a table is created. | ||||
|         """ | ||||
|         return None | ||||
|  | ||||
|     def bulk_batch_size(self, fields, objs): | ||||
|         """ | ||||
|         Return the maximum allowed batch size for the backend. The fields | ||||
|         are the fields going to be inserted in the batch, the objs contains | ||||
|         all the objects to be inserted. | ||||
|         """ | ||||
|         return len(objs) | ||||
|  | ||||
|     def format_for_duration_arithmetic(self, sql): | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations may require a " | ||||
|             "format_for_duration_arithmetic() method." | ||||
|         ) | ||||
|  | ||||
|     def cache_key_culling_sql(self): | ||||
|         """ | ||||
|         Return an SQL query that retrieves the first cache key greater than the | ||||
|         n smallest. | ||||
|  | ||||
|         This is used by the 'db' cache backend to determine where to start | ||||
|         culling. | ||||
|         """ | ||||
|         cache_key = self.quote_name("cache_key") | ||||
|         return f"SELECT {cache_key} FROM %s ORDER BY {cache_key} LIMIT 1 OFFSET %%s" | ||||
|  | ||||
|     def unification_cast_sql(self, output_field): | ||||
|         """ | ||||
|         Given a field instance, return the SQL that casts the result of a union | ||||
|         to that type. The resulting string should contain a '%s' placeholder | ||||
|         for the expression being cast. | ||||
|         """ | ||||
|         return "%s" | ||||
|  | ||||
|     def date_extract_sql(self, lookup_type, sql, params): | ||||
|         """ | ||||
|         Given a lookup_type of 'year', 'month', or 'day', return the SQL that | ||||
|         extracts a value from the given date field field_name. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations may require a date_extract_sql() " | ||||
|             "method" | ||||
|         ) | ||||
|  | ||||
|     def date_trunc_sql(self, lookup_type, sql, params, tzname=None): | ||||
|         """ | ||||
|         Given a lookup_type of 'year', 'month', or 'day', return the SQL that | ||||
|         truncates the given date or datetime field field_name to a date object | ||||
|         with only the given specificity. | ||||
|  | ||||
|         If `tzname` is provided, the given value is truncated in a specific | ||||
|         timezone. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations may require a date_trunc_sql() " | ||||
|             "method." | ||||
|         ) | ||||
|  | ||||
|     def datetime_cast_date_sql(self, sql, params, tzname): | ||||
|         """ | ||||
|         Return the SQL to cast a datetime value to a date value. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations may require a " | ||||
|             "datetime_cast_date_sql() method." | ||||
|         ) | ||||
|  | ||||
|     def datetime_cast_time_sql(self, sql, params, tzname): | ||||
|         """ | ||||
|         Return the SQL to cast a datetime value to a time value. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations may require a " | ||||
|             "datetime_cast_time_sql() method" | ||||
|         ) | ||||
|  | ||||
|     def datetime_extract_sql(self, lookup_type, sql, params, tzname): | ||||
|         """ | ||||
|         Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or | ||||
|         'second', return the SQL that extracts a value from the given | ||||
|         datetime field field_name. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations may require a datetime_extract_sql() " | ||||
|             "method" | ||||
|         ) | ||||
|  | ||||
|     def datetime_trunc_sql(self, lookup_type, sql, params, tzname): | ||||
|         """ | ||||
|         Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or | ||||
|         'second', return the SQL that truncates the given datetime field | ||||
|         field_name to a datetime object with only the given specificity. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() " | ||||
|             "method" | ||||
|         ) | ||||
|  | ||||
|     def time_trunc_sql(self, lookup_type, sql, params, tzname=None): | ||||
|         """ | ||||
|         Given a lookup_type of 'hour', 'minute' or 'second', return the SQL | ||||
|         that truncates the given time or datetime field field_name to a time | ||||
|         object with only the given specificity. | ||||
|  | ||||
|         If `tzname` is provided, the given value is truncated in a specific | ||||
|         timezone. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method" | ||||
|         ) | ||||
|  | ||||
|     def time_extract_sql(self, lookup_type, sql, params): | ||||
|         """ | ||||
|         Given a lookup_type of 'hour', 'minute', or 'second', return the SQL | ||||
|         that extracts a value from the given time field field_name. | ||||
|         """ | ||||
|         return self.date_extract_sql(lookup_type, sql, params) | ||||
|  | ||||
|     def deferrable_sql(self): | ||||
|         """ | ||||
|         Return the SQL to make a constraint "initially deferred" during a | ||||
|         CREATE TABLE statement. | ||||
|         """ | ||||
|         return "" | ||||
|  | ||||
|     def distinct_sql(self, fields, params): | ||||
|         """ | ||||
|         Return an SQL DISTINCT clause which removes duplicate rows from the | ||||
|         result set. If any fields are given, only check the given fields for | ||||
|         duplicates. | ||||
|         """ | ||||
|         if fields: | ||||
|             raise NotSupportedError( | ||||
|                 "DISTINCT ON fields is not supported by this database backend" | ||||
|             ) | ||||
|         else: | ||||
|             return ["DISTINCT"], [] | ||||
|  | ||||
|     def fetch_returned_insert_columns(self, cursor, returning_params): | ||||
|         """ | ||||
|         Given a cursor object that has just performed an INSERT...RETURNING | ||||
|         statement into a table, return the newly created data. | ||||
|         """ | ||||
|         return cursor.fetchone() | ||||
|  | ||||
|     def field_cast_sql(self, db_type, internal_type): | ||||
|         """ | ||||
|         Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type | ||||
|         (e.g. 'GenericIPAddressField'), return the SQL to cast it before using | ||||
|         it in a WHERE statement. The resulting string should contain a '%s' | ||||
|         placeholder for the column being searched against. | ||||
|         """ | ||||
|         return "%s" | ||||
|  | ||||
|     def force_no_ordering(self): | ||||
|         """ | ||||
|         Return a list used in the "ORDER BY" clause to force no ordering at | ||||
|         all. Return an empty list to include nothing in the ordering. | ||||
|         """ | ||||
|         return [] | ||||
|  | ||||
|     def for_update_sql(self, nowait=False, skip_locked=False, of=(), no_key=False): | ||||
|         """ | ||||
|         Return the FOR UPDATE SQL clause to lock rows for an update operation. | ||||
|         """ | ||||
|         return "FOR%s UPDATE%s%s%s" % ( | ||||
|             " NO KEY" if no_key else "", | ||||
|             " OF %s" % ", ".join(of) if of else "", | ||||
|             " NOWAIT" if nowait else "", | ||||
|             " SKIP LOCKED" if skip_locked else "", | ||||
|         ) | ||||
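|  | ||||
|     # Example: for_update_sql(nowait=True, of=("self",)) returns | ||||
|     # "FOR UPDATE OF self NOWAIT" with this default implementation. | ||||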
|  | ||||
|     def _get_limit_offset_params(self, low_mark, high_mark): | ||||
|         offset = low_mark or 0 | ||||
|         if high_mark is not None: | ||||
|             return (high_mark - offset), offset | ||||
|         elif offset: | ||||
|             return self.connection.ops.no_limit_value(), offset | ||||
|         return None, offset | ||||
|  | ||||
|     def limit_offset_sql(self, low_mark, high_mark): | ||||
|         """Return LIMIT/OFFSET SQL clause.""" | ||||
|         limit, offset = self._get_limit_offset_params(low_mark, high_mark) | ||||
|         return " ".join( | ||||
|             sql | ||||
|             for sql in ( | ||||
|                 ("LIMIT %d" % limit) if limit else None, | ||||
|                 ("OFFSET %d" % offset) if offset else None, | ||||
|             ) | ||||
|             if sql | ||||
|         ) | ||||
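|  | ||||
|     # Example: limit_offset_sql(5, 10) returns "LIMIT 5 OFFSET 5" (rows 5-9) | ||||
|     # and limit_offset_sql(0, 10) returns "LIMIT 10". | ||||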
|  | ||||
|     def last_executed_query(self, cursor, sql, params): | ||||
|         """ | ||||
|         Return a string of the query last executed by the given cursor, with | ||||
|         placeholders replaced with actual values. | ||||
|  | ||||
|         `sql` is the raw query containing placeholders and `params` is the | ||||
|         sequence of parameters. These are used by default, but this method | ||||
|         exists for database backends to provide a better implementation | ||||
|         according to their own quoting schemes. | ||||
|         """ | ||||
|  | ||||
|         # Convert params to contain string values. | ||||
|         def to_string(s): | ||||
|             return force_str(s, strings_only=True, errors="replace") | ||||
|  | ||||
|         if isinstance(params, (list, tuple)): | ||||
|             u_params = tuple(to_string(val) for val in params) | ||||
|         elif params is None: | ||||
|             u_params = () | ||||
|         else: | ||||
|             u_params = {to_string(k): to_string(v) for k, v in params.items()} | ||||
|  | ||||
|         return "QUERY = %r - PARAMS = %r" % (sql, u_params) | ||||
|  | ||||
|     def last_insert_id(self, cursor, table_name, pk_name): | ||||
|         """ | ||||
|         Given a cursor object that has just performed an INSERT statement into | ||||
|         a table that has an auto-incrementing ID, return the newly created ID. | ||||
|  | ||||
|         `pk_name` is the name of the primary-key column. | ||||
|         """ | ||||
|         return cursor.lastrowid | ||||
|  | ||||
|     def lookup_cast(self, lookup_type, internal_type=None): | ||||
|         """ | ||||
|         Return the string to use in a query when performing lookups | ||||
|         ("contains", "like", etc.). It should contain a '%s' placeholder for | ||||
|         the column being searched against. | ||||
|         """ | ||||
|         return "%s" | ||||
|  | ||||
|     def max_in_list_size(self): | ||||
|         """ | ||||
|         Return the maximum number of items that can be passed in a single 'IN' | ||||
|         list condition, or None if the backend does not impose a limit. | ||||
|         """ | ||||
|         return None | ||||
|  | ||||
|     def max_name_length(self): | ||||
|         """ | ||||
|         Return the maximum length of table and column names, or None if there | ||||
|         is no limit. | ||||
|         """ | ||||
|         return None | ||||
|  | ||||
|     def no_limit_value(self): | ||||
|         """ | ||||
|         Return the value to use for the LIMIT when we want "LIMIT infinity". | ||||
|         Return None if the limit clause can be omitted in this case. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations may require a no_limit_value() method" | ||||
|         ) | ||||
|  | ||||
|     def pk_default_value(self): | ||||
|         """ | ||||
|         Return the value to use during an INSERT statement to specify that | ||||
|         the field should use its default value. | ||||
|         """ | ||||
|         return "DEFAULT" | ||||
|  | ||||
|     def prepare_sql_script(self, sql): | ||||
|         """ | ||||
|         Take an SQL script that may contain multiple lines and return a list | ||||
|         of statements to feed to successive cursor.execute() calls. | ||||
|  | ||||
|         Since few databases are able to process raw SQL scripts in a single | ||||
|         cursor.execute() call and PEP 249 doesn't talk about this use case, | ||||
|         the default implementation is conservative. | ||||
|         """ | ||||
|         return [ | ||||
|             sqlparse.format(statement, strip_comments=True) | ||||
|             for statement in sqlparse.split(sql) | ||||
|             if statement | ||||
|         ] | ||||
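|  | ||||
|     # Example: prepare_sql_script("CREATE TABLE a (i INT); DROP TABLE a;") | ||||
|     # returns ["CREATE TABLE a (i INT);", "DROP TABLE a;"]. | ||||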
|  | ||||
|     def process_clob(self, value): | ||||
|         """ | ||||
|         Return the value of a CLOB column, for backends that return a locator | ||||
|         object that requires additional processing. | ||||
|         """ | ||||
|         return value | ||||
|  | ||||
|     def return_insert_columns(self, fields): | ||||
|         """ | ||||
|         For backends that support returning columns as part of an insert query, | ||||
|         return the SQL and params to append to the INSERT query. The returned | ||||
|         fragment should contain a format string to hold the appropriate column. | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def compiler(self, compiler_name): | ||||
|         """ | ||||
|         Return the SQLCompiler class corresponding to the given name, | ||||
|         in the namespace corresponding to the `compiler_module` attribute | ||||
|         on this backend. | ||||
|         """ | ||||
|         if self._cache is None: | ||||
|             self._cache = import_module(self.compiler_module) | ||||
|         return getattr(self._cache, compiler_name) | ||||
|  | ||||
|     def quote_name(self, name): | ||||
|         """ | ||||
|         Return a quoted version of the given table, index, or column name. Do | ||||
|         not quote the given name if it's already been quoted. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations may require a quote_name() method" | ||||
|         ) | ||||
|  | ||||
|     def regex_lookup(self, lookup_type): | ||||
|         """ | ||||
|         Return the string to use in a query when performing regular expression | ||||
|         lookups (using "regex" or "iregex"). It should contain a '%s' | ||||
|         placeholder for the column being searched against. | ||||
|  | ||||
|         If the feature is not supported (or part of it is not supported), raise | ||||
|         NotImplementedError. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations may require a regex_lookup() method" | ||||
|         ) | ||||
|  | ||||
|     def savepoint_create_sql(self, sid): | ||||
|         """ | ||||
|         Return the SQL for starting a new savepoint. Only required if the | ||||
|         "uses_savepoints" feature is True. The "sid" parameter is a string | ||||
|         for the savepoint id. | ||||
|         """ | ||||
|         return "SAVEPOINT %s" % self.quote_name(sid) | ||||
|  | ||||
|     def savepoint_commit_sql(self, sid): | ||||
|         """ | ||||
|         Return the SQL for committing the given savepoint. | ||||
|         """ | ||||
|         return "RELEASE SAVEPOINT %s" % self.quote_name(sid) | ||||
|  | ||||
|     def savepoint_rollback_sql(self, sid): | ||||
|         """ | ||||
|         Return the SQL for rolling back the given savepoint. | ||||
|         """ | ||||
|         return "ROLLBACK TO SAVEPOINT %s" % self.quote_name(sid) | ||||
|  | ||||
|     def set_time_zone_sql(self): | ||||
|         """ | ||||
|         Return the SQL that will set the connection's time zone. | ||||
|  | ||||
|         Return '' if the backend doesn't support time zones. | ||||
|         """ | ||||
|         return "" | ||||
|  | ||||
|     def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False): | ||||
|         """ | ||||
|         Return a list of SQL statements required to remove all data from | ||||
|         the given database tables (without actually removing the tables | ||||
|         themselves). | ||||
|  | ||||
|         The `style` argument is a Style object as returned by either | ||||
|         color_style() or no_style() in django.core.management.color. | ||||
|  | ||||
|         If `reset_sequences` is True, the list includes SQL statements required | ||||
|         to reset the sequences. | ||||
|  | ||||
|         The `allow_cascade` argument determines whether truncation may cascade | ||||
|         to tables with foreign keys pointing to the tables being truncated. | ||||
|         PostgreSQL requires a cascade even if these tables are empty. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "subclasses of BaseDatabaseOperations must provide an sql_flush() method" | ||||
|         ) | ||||
|  | ||||
|     def execute_sql_flush(self, sql_list): | ||||
|         """Execute a list of SQL statements to flush the database.""" | ||||
|         with transaction.atomic( | ||||
|             using=self.connection.alias, | ||||
|             savepoint=self.connection.features.can_rollback_ddl, | ||||
|         ): | ||||
|             with self.connection.cursor() as cursor: | ||||
|                 for sql in sql_list: | ||||
|                     cursor.execute(sql) | ||||
|  | ||||
|     def sequence_reset_by_name_sql(self, style, sequences): | ||||
|         """ | ||||
|         Return a list of the SQL statements required to reset sequences | ||||
|         passed in `sequences`. | ||||
|  | ||||
|         The `style` argument is a Style object as returned by either | ||||
|         color_style() or no_style() in django.core.management.color. | ||||
|         """ | ||||
|         return [] | ||||
|  | ||||
|     def sequence_reset_sql(self, style, model_list): | ||||
|         """ | ||||
|         Return a list of the SQL statements required to reset sequences for | ||||
|         the given models. | ||||
|  | ||||
|         The `style` argument is a Style object as returned by either | ||||
|         color_style() or no_style() in django.core.management.color. | ||||
|         """ | ||||
|         return []  # No sequence reset required by default. | ||||
|  | ||||
|     def start_transaction_sql(self): | ||||
|         """Return the SQL statement required to start a transaction.""" | ||||
|         return "BEGIN;" | ||||
|  | ||||
|     def end_transaction_sql(self, success=True): | ||||
|         """Return the SQL statement required to end a transaction.""" | ||||
|         if not success: | ||||
|             return "ROLLBACK;" | ||||
|         return "COMMIT;" | ||||
|  | ||||
|     def tablespace_sql(self, tablespace, inline=False): | ||||
|         """ | ||||
|         Return the SQL that will be used in a query to define the tablespace. | ||||
|  | ||||
|         Return '' if the backend doesn't support tablespaces. | ||||
|  | ||||
|         If `inline` is True, append the SQL to a row; otherwise append it to | ||||
|         the entire CREATE TABLE or CREATE INDEX statement. | ||||
|         """ | ||||
|         return "" | ||||
|  | ||||
|     def prep_for_like_query(self, x): | ||||
|         """Prepare a value for use in a LIKE query.""" | ||||
|         return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_") | ||||
|  | ||||
|     # Same as prep_for_like_query(), but called for "iexact" matches, which | ||||
|     # need not necessarily be implemented using "LIKE" in the backend. | ||||
|     prep_for_iexact_query = prep_for_like_query | ||||
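|  | ||||
|     # Example: prep_for_like_query("50%_off") returns r"50\%\_off", escaping | ||||
|     # LIKE wildcards so they match literally. | ||||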
|  | ||||
|     def validate_autopk_value(self, value): | ||||
|         """ | ||||
|         Certain backends do not accept some values for "serial" fields | ||||
|         (for example zero in MySQL). Raise a ValueError if the value is | ||||
|         invalid, otherwise return the validated value. | ||||
|         """ | ||||
|         return value | ||||
|  | ||||
|     def adapt_unknown_value(self, value): | ||||
|         """ | ||||
|         Transform a value to something compatible with the backend driver. | ||||
|  | ||||
|         This method only depends on the type of the value. It's designed for | ||||
|         cases where the target type isn't known, such as .raw() SQL queries. | ||||
|         As a consequence it may not work perfectly in all circumstances. | ||||
|         """ | ||||
|         if isinstance(value, datetime.datetime):  # must be before date | ||||
|             return self.adapt_datetimefield_value(value) | ||||
|         elif isinstance(value, datetime.date): | ||||
|             return self.adapt_datefield_value(value) | ||||
|         elif isinstance(value, datetime.time): | ||||
|             return self.adapt_timefield_value(value) | ||||
|         elif isinstance(value, decimal.Decimal): | ||||
|             return self.adapt_decimalfield_value(value) | ||||
|         else: | ||||
|             return value | ||||
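|  | ||||
|     # Example: adapt_unknown_value(datetime.date(2024, 1, 1)) returns | ||||
|     # "2024-01-01" via adapt_datefield_value() by default. | ||||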
|  | ||||
|     def adapt_integerfield_value(self, value, internal_type): | ||||
|         return value | ||||
|  | ||||
|     def adapt_datefield_value(self, value): | ||||
|         """ | ||||
|         Transform a date value to an object compatible with what is expected | ||||
|         by the backend driver for date columns. | ||||
|         """ | ||||
|         if value is None: | ||||
|             return None | ||||
|         return str(value) | ||||
|  | ||||
|     def adapt_datetimefield_value(self, value): | ||||
|         """ | ||||
|         Transform a datetime value to an object compatible with what is expected | ||||
|         by the backend driver for datetime columns. | ||||
|         """ | ||||
|         if value is None: | ||||
|             return None | ||||
|         # Expression values are adapted by the database. | ||||
|         if hasattr(value, "resolve_expression"): | ||||
|             return value | ||||
|  | ||||
|         return str(value) | ||||
|  | ||||
|     def adapt_timefield_value(self, value): | ||||
|         """ | ||||
|         Transform a time value to an object compatible with what is expected | ||||
|         by the backend driver for time columns. | ||||
|         """ | ||||
|         if value is None: | ||||
|             return None | ||||
|         # Expression values are adapted by the database. | ||||
|         if hasattr(value, "resolve_expression"): | ||||
|             return value | ||||
|  | ||||
|         if timezone.is_aware(value): | ||||
|             raise ValueError("Django does not support timezone-aware times.") | ||||
|         return str(value) | ||||
|  | ||||
|     def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None): | ||||
|         """ | ||||
|         Transform a decimal.Decimal value to an object compatible with what is | ||||
|         expected by the backend driver for decimal (numeric) columns. | ||||
|         """ | ||||
|         return utils.format_number(value, max_digits, decimal_places) | ||||
|  | ||||
|     def adapt_ipaddressfield_value(self, value): | ||||
|         """ | ||||
|         Transform a string representation of an IP address into the expected | ||||
|         type for the backend driver. | ||||
|         """ | ||||
|         return value or None | ||||
|  | ||||
|     def adapt_json_value(self, value, encoder): | ||||
|         return json.dumps(value, cls=encoder) | ||||
|  | ||||
|     def year_lookup_bounds_for_date_field(self, value, iso_year=False): | ||||
|         """ | ||||
|         Return a two-element list with the lower and upper bound to be used | ||||
|         with a BETWEEN operator to query a DateField value using a year | ||||
|         lookup. | ||||
|  | ||||
|         `value` is an int, containing the looked-up year. | ||||
|         If `iso_year` is True, return bounds for ISO-8601 week-numbering years. | ||||
|         """ | ||||
|         if iso_year: | ||||
|             first = datetime.date.fromisocalendar(value, 1, 1) | ||||
|             second = datetime.date.fromisocalendar( | ||||
|                 value + 1, 1, 1 | ||||
|             ) - datetime.timedelta(days=1) | ||||
|         else: | ||||
|             first = datetime.date(value, 1, 1) | ||||
|             second = datetime.date(value, 12, 31) | ||||
|         first = self.adapt_datefield_value(first) | ||||
|         second = self.adapt_datefield_value(second) | ||||
|         return [first, second] | ||||
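|  | ||||
|     # Example: year_lookup_bounds_for_date_field(2023) returns | ||||
|     # ["2023-01-01", "2023-12-31"] with the default date adaptation. | ||||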
|  | ||||
|     def year_lookup_bounds_for_datetime_field(self, value, iso_year=False): | ||||
|         """ | ||||
|         Return a two-element list with the lower and upper bound to be used | ||||
|         with a BETWEEN operator to query a DateTimeField value using a year | ||||
|         lookup. | ||||
|  | ||||
|         `value` is an int, containing the looked-up year. | ||||
|         If `iso_year` is True, return bounds for ISO-8601 week-numbering years. | ||||
|         """ | ||||
|         if iso_year: | ||||
|             first = datetime.datetime.fromisocalendar(value, 1, 1) | ||||
|             second = datetime.datetime.fromisocalendar( | ||||
|                 value + 1, 1, 1 | ||||
|             ) - datetime.timedelta(microseconds=1) | ||||
|         else: | ||||
|             first = datetime.datetime(value, 1, 1) | ||||
|             second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999) | ||||
|         if settings.USE_TZ: | ||||
|             tz = timezone.get_current_timezone() | ||||
|             first = timezone.make_aware(first, tz) | ||||
|             second = timezone.make_aware(second, tz) | ||||
|         first = self.adapt_datetimefield_value(first) | ||||
|         second = self.adapt_datetimefield_value(second) | ||||
|         return [first, second] | ||||
|  | ||||
|     def get_db_converters(self, expression): | ||||
|         """ | ||||
|         Return a list of functions needed to convert field data. | ||||
|  | ||||
|         Some field types on some backends do not provide data in the correct | ||||
|         format; this is the hook for converter functions. | ||||
|         """ | ||||
|         return [] | ||||
|  | ||||
|     def convert_durationfield_value(self, value, expression, connection): | ||||
|         if value is not None: | ||||
|             return datetime.timedelta(0, 0, value) | ||||
|  | ||||
|     def check_expression_support(self, expression): | ||||
|         """ | ||||
|         Check that the backend supports the provided expression. | ||||
|  | ||||
|         This is used on specific backends to rule out known expressions | ||||
|         that have problematic or nonexistent implementations. If the | ||||
|         expression has a known problem, the backend should raise | ||||
|         NotSupportedError. | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def conditional_expression_supported_in_where_clause(self, expression): | ||||
|         """ | ||||
|         Return True, if the conditional expression is supported in the WHERE | ||||
|         clause. | ||||
|         """ | ||||
|         return True | ||||
|  | ||||
|     def combine_expression(self, connector, sub_expressions): | ||||
|         """ | ||||
|         Combine a list of subexpressions into a single expression, using | ||||
|         the provided connecting operator. This is required because operators | ||||
|         can vary between backends (e.g., Oracle with %% and &) and between | ||||
|         subexpression types (e.g., date expressions). | ||||
|         """ | ||||
|         conn = " %s " % connector | ||||
|         return conn.join(sub_expressions) | ||||
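|  | ||||
|     # A minimal hedged sketch of the default behavior: | ||||
|     # | ||||
|     #     >>> ops.combine_expression("+", ["price", "tax"]) | ||||
|     #     'price + tax' | ||||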
|  | ||||
|     def combine_duration_expression(self, connector, sub_expressions): | ||||
|         return self.combine_expression(connector, sub_expressions) | ||||
|  | ||||
|     def binary_placeholder_sql(self, value): | ||||
|         """ | ||||
|         Some backends require special syntax to insert binary content (MySQL | ||||
|         for example uses '_binary %s'). | ||||
|         """ | ||||
|         return "%s" | ||||
|  | ||||
|     def modify_insert_params(self, placeholder, params): | ||||
|         """ | ||||
|         Allow modification of insert parameters. Needed for Oracle Spatial | ||||
|         backend due to #10888. | ||||
|         """ | ||||
|         return params | ||||
|  | ||||
|     def integer_field_range(self, internal_type): | ||||
|         """ | ||||
|         Given an integer field internal type (e.g. 'PositiveIntegerField'), | ||||
|         return a tuple of the (min_value, max_value) form representing the | ||||
|         range of the column type bound to the field. | ||||
|         """ | ||||
|         return self.integer_field_ranges[internal_type] | ||||
|  | ||||
|     def subtract_temporals(self, internal_type, lhs, rhs): | ||||
|         if self.connection.features.supports_temporal_subtraction: | ||||
|             lhs_sql, lhs_params = lhs | ||||
|             rhs_sql, rhs_params = rhs | ||||
|             return "(%s - %s)" % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params) | ||||
|         raise NotSupportedError( | ||||
|             "This backend does not support %s subtraction." % internal_type | ||||
|         ) | ||||
|  | ||||
|     def window_frame_start(self, start): | ||||
|         if isinstance(start, int): | ||||
|             if start < 0: | ||||
|                 return "%d %s" % (abs(start), self.PRECEDING) | ||||
|             elif start == 0: | ||||
|                 return self.CURRENT_ROW | ||||
|         elif start is None: | ||||
|             return self.UNBOUNDED_PRECEDING | ||||
|         raise ValueError( | ||||
|             "start argument must be a negative integer, zero, or None, but got '%s'." | ||||
|             % start | ||||
|         ) | ||||
|  | ||||
|     def window_frame_end(self, end): | ||||
|         if isinstance(end, int): | ||||
|             if end == 0: | ||||
|                 return self.CURRENT_ROW | ||||
|             elif end > 0: | ||||
|                 return "%d %s" % (end, self.FOLLOWING) | ||||
|         elif end is None: | ||||
|             return self.UNBOUNDED_FOLLOWING | ||||
|         raise ValueError( | ||||
|             "end argument must be a positive integer, zero, or None, but got '%s'." | ||||
|             % end | ||||
|         ) | ||||
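|  | ||||
|     # A hedged sketch of the mapping, assuming the usual class-level frame | ||||
|     # constants (PRECEDING, FOLLOWING, CURRENT_ROW and the UNBOUNDED_* variants): | ||||
|     # | ||||
|     #     >>> ops.window_frame_start(-3), ops.window_frame_end(0) | ||||
|     #     ('3 PRECEDING', 'CURRENT ROW') | ||||
|     #     >>> ops.window_frame_start(None), ops.window_frame_end(None) | ||||
|     #     ('UNBOUNDED PRECEDING', 'UNBOUNDED FOLLOWING') | ||||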
|  | ||||
|     def window_frame_rows_start_end(self, start=None, end=None): | ||||
|         """ | ||||
|         Return SQL for start and end points in an OVER clause window frame. | ||||
|         """ | ||||
|         if not self.connection.features.supports_over_clause: | ||||
|             raise NotSupportedError("This backend does not support window expressions.") | ||||
|         return self.window_frame_start(start), self.window_frame_end(end) | ||||
|  | ||||
|     def window_frame_range_start_end(self, start=None, end=None): | ||||
|         start_, end_ = self.window_frame_rows_start_end(start, end) | ||||
|         features = self.connection.features | ||||
|         if features.only_supports_unbounded_with_preceding_and_following and ( | ||||
|             (start and start < 0) or (end and end > 0) | ||||
|         ): | ||||
|             raise NotSupportedError( | ||||
|                 "%s only supports UNBOUNDED together with PRECEDING and " | ||||
|                 "FOLLOWING." % self.connection.display_name | ||||
|             ) | ||||
|         return start_, end_ | ||||
|  | ||||
|     def explain_query_prefix(self, format=None, **options): | ||||
|         if not self.connection.features.supports_explaining_query_execution: | ||||
|             raise NotSupportedError( | ||||
|                 "This backend does not support explaining query execution." | ||||
|             ) | ||||
|         if format: | ||||
|             supported_formats = self.connection.features.supported_explain_formats | ||||
|             normalized_format = format.upper() | ||||
|             if normalized_format not in supported_formats: | ||||
|                 msg = "%s is not a recognized format." % normalized_format | ||||
|                 if supported_formats: | ||||
|                     msg += " Allowed formats: %s" % ", ".join(sorted(supported_formats)) | ||||
|                 else: | ||||
|                     msg += ( | ||||
|                         f" {self.connection.display_name} does not support any formats." | ||||
|                     ) | ||||
|                 raise ValueError(msg) | ||||
|         if options: | ||||
|             raise ValueError("Unknown options: %s" % ", ".join(sorted(options.keys()))) | ||||
|         return self.explain_prefix | ||||
|  | ||||
|     def insert_statement(self, on_conflict=None): | ||||
|         return "INSERT INTO" | ||||
|  | ||||
|     def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): | ||||
|         return "" | ||||
										
											
(File diff suppressed because it is too large)
							| @ -0,0 +1,29 @@ | ||||
| class BaseDatabaseValidation: | ||||
|     """Encapsulate backend-specific validation.""" | ||||
|  | ||||
|     def __init__(self, connection): | ||||
|         self.connection = connection | ||||
|  | ||||
|     def check(self, **kwargs): | ||||
|         return [] | ||||
|  | ||||
|     def check_field(self, field, **kwargs): | ||||
|         errors = [] | ||||
|         # Backends may implement a check_field_type() method. | ||||
|         if ( | ||||
|             hasattr(self, "check_field_type") | ||||
|             and | ||||
|             # Ignore any related fields. | ||||
|             not getattr(field, "remote_field", None) | ||||
|         ): | ||||
|             # Ignore fields with unsupported features. | ||||
|             db_supports_all_required_features = all( | ||||
|                 getattr(self.connection.features, feature, False) | ||||
|                 for feature in field.model._meta.required_db_features | ||||
|             ) | ||||
|             if db_supports_all_required_features: | ||||
|                 field_type = field.db_type(self.connection) | ||||
|                 # Ignore non-concrete fields. | ||||
|                 if field_type is not None: | ||||
|                     errors.extend(self.check_field_type(field, field_type)) | ||||
|         return errors | ||||
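|  | ||||
|  | ||||
| # A hedged sketch of the optional hook: a backend-specific subclass (the class | ||||
| # name and the varchar rule below are illustrative, not part of this module) | ||||
| # only needs to define check_field_type(); check_field() above wires it in. | ||||
| # | ||||
| #     from django.core import checks | ||||
| # | ||||
| #     class MyDatabaseValidation(BaseDatabaseValidation): | ||||
| #         def check_field_type(self, field, field_type): | ||||
| #             errors = [] | ||||
| #             if field_type.startswith("varchar") and field.db_index: | ||||
| #                 errors.append(checks.Warning("Index may be truncated.", obj=field)) | ||||
| #             return errors | ||||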
| @ -0,0 +1,254 @@ | ||||
| """ | ||||
| Helpers to manipulate deferred DDL statements that might need to be adjusted or | ||||
| discarded when executing a migration. | ||||
| """ | ||||
| from copy import deepcopy | ||||
|  | ||||
|  | ||||
| class Reference: | ||||
|     """Base class that defines the reference interface.""" | ||||
|  | ||||
|     def references_table(self, table): | ||||
|         """ | ||||
|         Return whether or not this instance references the specified table. | ||||
|         """ | ||||
|         return False | ||||
|  | ||||
|     def references_column(self, table, column): | ||||
|         """ | ||||
|         Return whether or not this instance references the specified column. | ||||
|         """ | ||||
|         return False | ||||
|  | ||||
|     def rename_table_references(self, old_table, new_table): | ||||
|         """ | ||||
|         Rename all references to the old_table to the new_table. | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def rename_column_references(self, table, old_column, new_column): | ||||
|         """ | ||||
|         Rename all references to the old_column to the new_column. | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "<%s %r>" % (self.__class__.__name__, str(self)) | ||||
|  | ||||
|     def __str__(self): | ||||
|         raise NotImplementedError( | ||||
|             "Subclasses must define how they should be converted to string." | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class Table(Reference): | ||||
|     """Hold a reference to a table.""" | ||||
|  | ||||
|     def __init__(self, table, quote_name): | ||||
|         self.table = table | ||||
|         self.quote_name = quote_name | ||||
|  | ||||
|     def references_table(self, table): | ||||
|         return self.table == table | ||||
|  | ||||
|     def rename_table_references(self, old_table, new_table): | ||||
|         if self.table == old_table: | ||||
|             self.table = new_table | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.quote_name(self.table) | ||||
|  | ||||
|  | ||||
| class TableColumns(Table): | ||||
|     """Base class for references to multiple columns of a table.""" | ||||
|  | ||||
|     def __init__(self, table, columns): | ||||
|         self.table = table | ||||
|         self.columns = columns | ||||
|  | ||||
|     def references_column(self, table, column): | ||||
|         return self.table == table and column in self.columns | ||||
|  | ||||
|     def rename_column_references(self, table, old_column, new_column): | ||||
|         if self.table == table: | ||||
|             for index, column in enumerate(self.columns): | ||||
|                 if column == old_column: | ||||
|                     self.columns[index] = new_column | ||||
|  | ||||
|  | ||||
| class Columns(TableColumns): | ||||
|     """Hold a reference to one or many columns.""" | ||||
|  | ||||
|     def __init__(self, table, columns, quote_name, col_suffixes=()): | ||||
|         self.quote_name = quote_name | ||||
|         self.col_suffixes = col_suffixes | ||||
|         super().__init__(table, columns) | ||||
|  | ||||
|     def __str__(self): | ||||
|         def col_str(column, idx): | ||||
|             col = self.quote_name(column) | ||||
|             try: | ||||
|                 suffix = self.col_suffixes[idx] | ||||
|                 if suffix: | ||||
|                     col = "{} {}".format(col, suffix) | ||||
|             except IndexError: | ||||
|                 pass | ||||
|             return col | ||||
|  | ||||
|         return ", ".join( | ||||
|             col_str(column, idx) for idx, column in enumerate(self.columns) | ||||
|         ) | ||||
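|  | ||||
|     # A hedged sketch with a hypothetical quote_name; a missing suffix simply | ||||
|     # leaves that column bare: | ||||
|     # | ||||
|     #     >>> str(Columns("t", ["a", "b"], lambda n: "`%s`" % n, ("DESC",))) | ||||
|     #     '`a` DESC, `b`' | ||||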
|  | ||||
|  | ||||
| class IndexName(TableColumns): | ||||
|     """Hold a reference to an index name.""" | ||||
|  | ||||
|     def __init__(self, table, columns, suffix, create_index_name): | ||||
|         self.suffix = suffix | ||||
|         self.create_index_name = create_index_name | ||||
|         super().__init__(table, columns) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.create_index_name(self.table, self.columns, self.suffix) | ||||
|  | ||||
|  | ||||
| class IndexColumns(Columns): | ||||
|     def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()): | ||||
|         self.opclasses = opclasses | ||||
|         super().__init__(table, columns, quote_name, col_suffixes) | ||||
|  | ||||
|     def __str__(self): | ||||
|         def col_str(column, idx): | ||||
|             # Index.__init__() guarantees that self.opclasses is the same | ||||
|             # length as self.columns. | ||||
|             col = "{} {}".format(self.quote_name(column), self.opclasses[idx]) | ||||
|             try: | ||||
|                 suffix = self.col_suffixes[idx] | ||||
|                 if suffix: | ||||
|                     col = "{} {}".format(col, suffix) | ||||
|             except IndexError: | ||||
|                 pass | ||||
|             return col | ||||
|  | ||||
|         return ", ".join( | ||||
|             col_str(column, idx) for idx, column in enumerate(self.columns) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class ForeignKeyName(TableColumns): | ||||
|     """Hold a reference to a foreign key name.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         from_table, | ||||
|         from_columns, | ||||
|         to_table, | ||||
|         to_columns, | ||||
|         suffix_template, | ||||
|         create_fk_name, | ||||
|     ): | ||||
|         self.to_reference = TableColumns(to_table, to_columns) | ||||
|         self.suffix_template = suffix_template | ||||
|         self.create_fk_name = create_fk_name | ||||
|         super().__init__( | ||||
|             from_table, | ||||
|             from_columns, | ||||
|         ) | ||||
|  | ||||
|     def references_table(self, table): | ||||
|         return super().references_table(table) or self.to_reference.references_table( | ||||
|             table | ||||
|         ) | ||||
|  | ||||
|     def references_column(self, table, column): | ||||
|         return super().references_column( | ||||
|             table, column | ||||
|         ) or self.to_reference.references_column(table, column) | ||||
|  | ||||
|     def rename_table_references(self, old_table, new_table): | ||||
|         super().rename_table_references(old_table, new_table) | ||||
|         self.to_reference.rename_table_references(old_table, new_table) | ||||
|  | ||||
|     def rename_column_references(self, table, old_column, new_column): | ||||
|         super().rename_column_references(table, old_column, new_column) | ||||
|         self.to_reference.rename_column_references(table, old_column, new_column) | ||||
|  | ||||
|     def __str__(self): | ||||
|         suffix = self.suffix_template % { | ||||
|             "to_table": self.to_reference.table, | ||||
|             "to_column": self.to_reference.columns[0], | ||||
|         } | ||||
|         return self.create_fk_name(self.table, self.columns, suffix) | ||||
|  | ||||
|  | ||||
| class Statement(Reference): | ||||
|     """ | ||||
|     Statement template and formatting parameters container. | ||||
|  | ||||
|     Allows keeping a reference to a statement without interpolating identifiers | ||||
|     that might have to be adjusted if they're referencing a table or column | ||||
|     that is renamed or removed. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, template, **parts): | ||||
|         self.template = template | ||||
|         self.parts = parts | ||||
|  | ||||
|     def references_table(self, table): | ||||
|         return any( | ||||
|             hasattr(part, "references_table") and part.references_table(table) | ||||
|             for part in self.parts.values() | ||||
|         ) | ||||
|  | ||||
|     def references_column(self, table, column): | ||||
|         return any( | ||||
|             hasattr(part, "references_column") and part.references_column(table, column) | ||||
|             for part in self.parts.values() | ||||
|         ) | ||||
|  | ||||
|     def rename_table_references(self, old_table, new_table): | ||||
|         for part in self.parts.values(): | ||||
|             if hasattr(part, "rename_table_references"): | ||||
|                 part.rename_table_references(old_table, new_table) | ||||
|  | ||||
|     def rename_column_references(self, table, old_column, new_column): | ||||
|         for part in self.parts.values(): | ||||
|             if hasattr(part, "rename_column_references"): | ||||
|                 part.rename_column_references(table, old_column, new_column) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.template % self.parts | ||||
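|  | ||||
|     # A hedged sketch of deferring a statement and repointing it later: | ||||
|     # | ||||
|     #     >>> table = Table("app_old", quote_name=lambda n: '"%s"' % n) | ||||
|     #     >>> stmt = Statement("ALTER TABLE %(table)s DROP COLUMN x", table=table) | ||||
|     #     >>> stmt.rename_table_references("app_old", "app_new") | ||||
|     #     >>> str(stmt) | ||||
|     #     'ALTER TABLE "app_new" DROP COLUMN x' | ||||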
|  | ||||
|  | ||||
| class Expressions(TableColumns): | ||||
|     def __init__(self, table, expressions, compiler, quote_value): | ||||
|         self.compiler = compiler | ||||
|         self.expressions = expressions | ||||
|         self.quote_value = quote_value | ||||
|         columns = [ | ||||
|             col.target.column | ||||
|             for col in self.compiler.query._gen_cols([self.expressions]) | ||||
|         ] | ||||
|         super().__init__(table, columns) | ||||
|  | ||||
|     def rename_table_references(self, old_table, new_table): | ||||
|         if self.table != old_table: | ||||
|             return | ||||
|         self.expressions = self.expressions.relabeled_clone({old_table: new_table}) | ||||
|         super().rename_table_references(old_table, new_table) | ||||
|  | ||||
|     def rename_column_references(self, table, old_column, new_column): | ||||
|         if self.table != table: | ||||
|             return | ||||
|         expressions = deepcopy(self.expressions) | ||||
|         self.columns = [] | ||||
|         for col in self.compiler.query._gen_cols([expressions]): | ||||
|             if col.target.column == old_column: | ||||
|                 col.target.column = new_column | ||||
|             self.columns.append(col.target.column) | ||||
|         self.expressions = expressions | ||||
|  | ||||
|     def __str__(self): | ||||
|         sql, params = self.compiler.compile(self.expressions) | ||||
|         params = map(self.quote_value, params) | ||||
|         return sql % tuple(params) | ||||
| @ -0,0 +1,74 @@ | ||||
| """ | ||||
| Dummy database backend for Django. | ||||
|  | ||||
| Django uses this if the database ENGINE setting is empty (None or empty string). | ||||
|  | ||||
| Each of these API functions, except connection.close(), raises | ||||
| ImproperlyConfigured. | ||||
| """ | ||||
|  | ||||
| from django.core.exceptions import ImproperlyConfigured | ||||
| from django.db.backends.base.base import BaseDatabaseWrapper | ||||
| from django.db.backends.base.client import BaseDatabaseClient | ||||
| from django.db.backends.base.creation import BaseDatabaseCreation | ||||
| from django.db.backends.base.introspection import BaseDatabaseIntrospection | ||||
| from django.db.backends.base.operations import BaseDatabaseOperations | ||||
| from django.db.backends.dummy.features import DummyDatabaseFeatures | ||||
|  | ||||
|  | ||||
| def complain(*args, **kwargs): | ||||
|     raise ImproperlyConfigured( | ||||
|         "settings.DATABASES is improperly configured. " | ||||
|         "Please supply the ENGINE value. Check " | ||||
|         "settings documentation for more details." | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def ignore(*args, **kwargs): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class DatabaseOperations(BaseDatabaseOperations): | ||||
|     quote_name = complain | ||||
|  | ||||
|  | ||||
| class DatabaseClient(BaseDatabaseClient): | ||||
|     runshell = complain | ||||
|  | ||||
|  | ||||
| class DatabaseCreation(BaseDatabaseCreation): | ||||
|     create_test_db = ignore | ||||
|     destroy_test_db = ignore | ||||
|  | ||||
|  | ||||
| class DatabaseIntrospection(BaseDatabaseIntrospection): | ||||
|     get_table_list = complain | ||||
|     get_table_description = complain | ||||
|     get_relations = complain | ||||
|     get_indexes = complain | ||||
|  | ||||
|  | ||||
| class DatabaseWrapper(BaseDatabaseWrapper): | ||||
|     operators = {} | ||||
|     # Override the base class implementations with null | ||||
|     # implementations. Anything that tries to actually | ||||
|     # do something raises complain; anything that tries | ||||
|     # to roll back or undo something is silently ignored. | ||||
|     _cursor = complain | ||||
|     ensure_connection = complain | ||||
|     _commit = complain | ||||
|     _rollback = ignore | ||||
|     _close = ignore | ||||
|     _savepoint = ignore | ||||
|     _savepoint_commit = complain | ||||
|     _savepoint_rollback = ignore | ||||
|     _set_autocommit = complain | ||||
|     # Classes instantiated in __init__(). | ||||
|     client_class = DatabaseClient | ||||
|     creation_class = DatabaseCreation | ||||
|     features_class = DummyDatabaseFeatures | ||||
|     introspection_class = DatabaseIntrospection | ||||
|     ops_class = DatabaseOperations | ||||
|  | ||||
|     def is_usable(self): | ||||
|         return True | ||||
| @ -0,0 +1,6 @@ | ||||
| from django.db.backends.base.features import BaseDatabaseFeatures | ||||
|  | ||||
|  | ||||
| class DummyDatabaseFeatures(BaseDatabaseFeatures): | ||||
|     supports_transactions = False | ||||
|     uses_savepoints = False | ||||
| @ -0,0 +1,444 @@ | ||||
| """ | ||||
| MySQL database backend for Django. | ||||
|  | ||||
| Requires mysqlclient: https://pypi.org/project/mysqlclient/ | ||||
| """ | ||||
| from django.core.exceptions import ImproperlyConfigured | ||||
| from django.db import IntegrityError | ||||
| from django.db.backends import utils as backend_utils | ||||
| from django.db.backends.base.base import BaseDatabaseWrapper | ||||
| from django.utils.asyncio import async_unsafe | ||||
| from django.utils.functional import cached_property | ||||
| from django.utils.regex_helper import _lazy_re_compile | ||||
|  | ||||
| try: | ||||
|     import MySQLdb as Database | ||||
| except ImportError as err: | ||||
|     raise ImproperlyConfigured( | ||||
|         "Error loading MySQLdb module.\nDid you install mysqlclient?" | ||||
|     ) from err | ||||
|  | ||||
| from MySQLdb.constants import CLIENT, FIELD_TYPE | ||||
| from MySQLdb.converters import conversions | ||||
|  | ||||
| # Some of these import MySQLdb, so import them after checking if it's installed. | ||||
| from .client import DatabaseClient | ||||
| from .creation import DatabaseCreation | ||||
| from .features import DatabaseFeatures | ||||
| from .introspection import DatabaseIntrospection | ||||
| from .operations import DatabaseOperations | ||||
| from .schema import DatabaseSchemaEditor | ||||
| from .validation import DatabaseValidation | ||||
|  | ||||
| version = Database.version_info | ||||
| if version < (1, 4, 3): | ||||
|     raise ImproperlyConfigured( | ||||
|         "mysqlclient 1.4.3 or newer is required; you have %s." % Database.__version__ | ||||
|     ) | ||||
|  | ||||
|  | ||||
| # MySQLdb returns TIME columns as timedelta -- they are more like timedelta in | ||||
| # terms of actual behavior as they are signed and include days -- and Django | ||||
| # expects time. | ||||
| django_conversions = { | ||||
|     **conversions, | ||||
|     **{FIELD_TYPE.TIME: backend_utils.typecast_time}, | ||||
| } | ||||
|  | ||||
| # This should match the numerical portion of the version numbers (we can treat | ||||
| # versions like 5.0.24 and 5.0.24a as the same). | ||||
| server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})") | ||||
|  | ||||
|  | ||||
| class CursorWrapper: | ||||
|     """ | ||||
|     A thin wrapper around MySQLdb's normal cursor class that catches particular | ||||
|     exception instances and reraises them with the correct types. | ||||
|  | ||||
|     Implemented as a wrapper, rather than a subclass, so that it isn't tied | ||||
|     to the particular underlying representation returned by Connection.cursor(). | ||||
|     """ | ||||
|  | ||||
|     codes_for_integrityerror = ( | ||||
|         1048,  # Column cannot be null | ||||
|         1690,  # BIGINT UNSIGNED value is out of range | ||||
|         3819,  # CHECK constraint is violated | ||||
|         4025,  # CHECK constraint failed | ||||
|     ) | ||||
|  | ||||
|     def __init__(self, cursor): | ||||
|         self.cursor = cursor | ||||
|  | ||||
|     def execute(self, query, args=None): | ||||
|         try: | ||||
|             # args being None means no string interpolation is performed | ||||
|             return self.cursor.execute(query, args) | ||||
|         except Database.OperationalError as e: | ||||
|             # Map some error codes to IntegrityError, since they seem to be | ||||
|             # misclassified and Django would prefer the more logical place. | ||||
|             if e.args[0] in self.codes_for_integrityerror: | ||||
|                 raise IntegrityError(*tuple(e.args)) | ||||
|             raise | ||||
|  | ||||
|     def executemany(self, query, args): | ||||
|         try: | ||||
|             return self.cursor.executemany(query, args) | ||||
|         except Database.OperationalError as e: | ||||
|             # Map some error codes to IntegrityError, since they seem to be | ||||
|             # misclassified and Django would prefer the more logical place. | ||||
|             if e.args[0] in self.codes_for_integrityerror: | ||||
|                 raise IntegrityError(*tuple(e.args)) | ||||
|             raise | ||||
|  | ||||
|     def __getattr__(self, attr): | ||||
|         return getattr(self.cursor, attr) | ||||
|  | ||||
|     def __iter__(self): | ||||
|         return iter(self.cursor) | ||||
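|  | ||||
|     # A hedged sketch (table and column names are illustrative): inserting NULL | ||||
|     # into a NOT NULL column makes MySQLdb raise OperationalError(1048, ...), | ||||
|     # which execute() reraises as django.db.IntegrityError: | ||||
|     # | ||||
|     #     try: | ||||
|     #         cursor.execute("INSERT INTO t (name) VALUES (%s)", [None]) | ||||
|     #     except IntegrityError: | ||||
|     #         ...  # handled like any other integrity violation | ||||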
|  | ||||
|  | ||||
| class DatabaseWrapper(BaseDatabaseWrapper): | ||||
|     vendor = "mysql" | ||||
|     # This dictionary maps Field objects to their associated MySQL column | ||||
|     # types, as strings. Column-type strings can contain format strings; they'll | ||||
|     # be interpolated against the values of Field.__dict__ before being output. | ||||
|     # If a column type is set to None, it won't be included in the output. | ||||
|     data_types = { | ||||
|         "AutoField": "integer AUTO_INCREMENT", | ||||
|         "BigAutoField": "bigint AUTO_INCREMENT", | ||||
|         "BinaryField": "longblob", | ||||
|         "BooleanField": "bool", | ||||
|         "CharField": "varchar(%(max_length)s)", | ||||
|         "DateField": "date", | ||||
|         "DateTimeField": "datetime(6)", | ||||
|         "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)", | ||||
|         "DurationField": "bigint", | ||||
|         "FileField": "varchar(%(max_length)s)", | ||||
|         "FilePathField": "varchar(%(max_length)s)", | ||||
|         "FloatField": "double precision", | ||||
|         "IntegerField": "integer", | ||||
|         "BigIntegerField": "bigint", | ||||
|         "IPAddressField": "char(15)", | ||||
|         "GenericIPAddressField": "char(39)", | ||||
|         "JSONField": "json", | ||||
|         "OneToOneField": "integer", | ||||
|         "PositiveBigIntegerField": "bigint UNSIGNED", | ||||
|         "PositiveIntegerField": "integer UNSIGNED", | ||||
|         "PositiveSmallIntegerField": "smallint UNSIGNED", | ||||
|         "SlugField": "varchar(%(max_length)s)", | ||||
|         "SmallAutoField": "smallint AUTO_INCREMENT", | ||||
|         "SmallIntegerField": "smallint", | ||||
|         "TextField": "longtext", | ||||
|         "TimeField": "time(6)", | ||||
|         "UUIDField": "char(32)", | ||||
|     } | ||||
|  | ||||
|     # For these data types: | ||||
|     # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them | ||||
|     #   as nullable | ||||
|     # - all versions of MySQL and MariaDB don't support full width database | ||||
|     #   indexes | ||||
|     _limited_data_types = ( | ||||
|         "tinyblob", | ||||
|         "blob", | ||||
|         "mediumblob", | ||||
|         "longblob", | ||||
|         "tinytext", | ||||
|         "text", | ||||
|         "mediumtext", | ||||
|         "longtext", | ||||
|         "json", | ||||
|     ) | ||||
|  | ||||
|     operators = { | ||||
|         "exact": "= %s", | ||||
|         "iexact": "LIKE %s", | ||||
|         "contains": "LIKE BINARY %s", | ||||
|         "icontains": "LIKE %s", | ||||
|         "gt": "> %s", | ||||
|         "gte": ">= %s", | ||||
|         "lt": "< %s", | ||||
|         "lte": "<= %s", | ||||
|         "startswith": "LIKE BINARY %s", | ||||
|         "endswith": "LIKE BINARY %s", | ||||
|         "istartswith": "LIKE %s", | ||||
|         "iendswith": "LIKE %s", | ||||
|     } | ||||
|  | ||||
|     # The patterns below are used to generate SQL pattern lookup clauses when | ||||
|     # the right-hand side of the lookup isn't a raw string (it might be an expression | ||||
|     # or the result of a bilateral transformation). | ||||
|     # In those cases, special characters for LIKE operators (e.g. \, *, _) should be | ||||
|     # escaped on database side. | ||||
|     # | ||||
|     # Note: we use str.format() here for readability as '%' is used as a wildcard for | ||||
|     # the LIKE operator. | ||||
|     pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')" | ||||
|     pattern_ops = { | ||||
|         "contains": "LIKE BINARY CONCAT('%%', {}, '%%')", | ||||
|         "icontains": "LIKE CONCAT('%%', {}, '%%')", | ||||
|         "startswith": "LIKE BINARY CONCAT({}, '%%')", | ||||
|         "istartswith": "LIKE CONCAT({}, '%%')", | ||||
|         "endswith": "LIKE BINARY CONCAT('%%', {})", | ||||
|         "iendswith": "LIKE CONCAT('%%', {})", | ||||
|     } | ||||
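|  | ||||
|     # A hedged illustration: for a column expression `F`, composing the two | ||||
|     # templates for 'startswith' yields (before parameters are bound) roughly: | ||||
|     #   LIKE BINARY CONCAT(REPLACE(REPLACE(REPLACE(`F`, '\\', '\\\\'), | ||||
|     #                              '%%', '\%%'), '_', '\_'), '%%') | ||||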
|  | ||||
|     isolation_levels = { | ||||
|         "read uncommitted", | ||||
|         "read committed", | ||||
|         "repeatable read", | ||||
|         "serializable", | ||||
|     } | ||||
|  | ||||
|     Database = Database | ||||
|     SchemaEditorClass = DatabaseSchemaEditor | ||||
|     # Classes instantiated in __init__(). | ||||
|     client_class = DatabaseClient | ||||
|     creation_class = DatabaseCreation | ||||
|     features_class = DatabaseFeatures | ||||
|     introspection_class = DatabaseIntrospection | ||||
|     ops_class = DatabaseOperations | ||||
|     validation_class = DatabaseValidation | ||||
|  | ||||
|     def get_database_version(self): | ||||
|         return self.mysql_version | ||||
|  | ||||
|     def get_connection_params(self): | ||||
|         kwargs = { | ||||
|             "conv": django_conversions, | ||||
|             "charset": "utf8", | ||||
|         } | ||||
|         settings_dict = self.settings_dict | ||||
|         if settings_dict["USER"]: | ||||
|             kwargs["user"] = settings_dict["USER"] | ||||
|         if settings_dict["NAME"]: | ||||
|             kwargs["database"] = settings_dict["NAME"] | ||||
|         if settings_dict["PASSWORD"]: | ||||
|             kwargs["password"] = settings_dict["PASSWORD"] | ||||
|         if settings_dict["HOST"].startswith("/"): | ||||
|             kwargs["unix_socket"] = settings_dict["HOST"] | ||||
|         elif settings_dict["HOST"]: | ||||
|             kwargs["host"] = settings_dict["HOST"] | ||||
|         if settings_dict["PORT"]: | ||||
|             kwargs["port"] = int(settings_dict["PORT"]) | ||||
|         # We need the number of potentially affected rows after an | ||||
|         # "UPDATE", not the number of changed rows. | ||||
|         kwargs["client_flag"] = CLIENT.FOUND_ROWS | ||||
|         # Validate the transaction isolation level, if specified. | ||||
|         options = settings_dict["OPTIONS"].copy() | ||||
|         isolation_level = options.pop("isolation_level", "read committed") | ||||
|         if isolation_level: | ||||
|             isolation_level = isolation_level.lower() | ||||
|             if isolation_level not in self.isolation_levels: | ||||
|                 raise ImproperlyConfigured( | ||||
|                     "Invalid transaction isolation level '%s' specified.\n" | ||||
|                     "Use one of %s, or None." | ||||
|                     % ( | ||||
|                         isolation_level, | ||||
|                         ", ".join("'%s'" % s for s in sorted(self.isolation_levels)), | ||||
|                     ) | ||||
|                 ) | ||||
|         self.isolation_level = isolation_level | ||||
|         kwargs.update(options) | ||||
|         return kwargs | ||||
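|  | ||||
|     # A hedged example (names are illustrative): NAME="mydb", USER="app", and | ||||
|     # HOST="/var/run/mysqld.sock" (a leading "/" selects a Unix socket) yield | ||||
|     # roughly: | ||||
|     #     {"conv": django_conversions, "charset": "utf8", "user": "app", | ||||
|     #      "database": "mydb", "unix_socket": "/var/run/mysqld.sock", | ||||
|     #      "client_flag": CLIENT.FOUND_ROWS} | ||||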
|  | ||||
|     @async_unsafe | ||||
|     def get_new_connection(self, conn_params): | ||||
|         connection = Database.connect(**conn_params) | ||||
|         # bytes encoder in mysqlclient doesn't work and was added only to | ||||
|         # prevent KeyErrors in Django < 2.0. We can remove this workaround when | ||||
|         # mysqlclient 2.1 becomes the minimum version supported by Django. | ||||
|         # See https://github.com/PyMySQL/mysqlclient/issues/489 | ||||
|         if connection.encoders.get(bytes) is bytes: | ||||
|             connection.encoders.pop(bytes) | ||||
|         return connection | ||||
|  | ||||
|     def init_connection_state(self): | ||||
|         super().init_connection_state() | ||||
|         assignments = [] | ||||
|         if self.features.is_sql_auto_is_null_enabled: | ||||
|             # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on | ||||
|             # a recently inserted row will return when the field is tested | ||||
|             # for NULL. Disabling this brings this aspect of MySQL in line | ||||
|             # with SQL standards. | ||||
|             assignments.append("SET SQL_AUTO_IS_NULL = 0") | ||||
|  | ||||
|         if self.isolation_level: | ||||
|             assignments.append( | ||||
|                 "SET SESSION TRANSACTION ISOLATION LEVEL %s" | ||||
|                 % self.isolation_level.upper() | ||||
|             ) | ||||
|  | ||||
|         if assignments: | ||||
|             with self.cursor() as cursor: | ||||
|                 cursor.execute("; ".join(assignments)) | ||||
|  | ||||
|     @async_unsafe | ||||
|     def create_cursor(self, name=None): | ||||
|         cursor = self.connection.cursor() | ||||
|         return CursorWrapper(cursor) | ||||
|  | ||||
|     def _rollback(self): | ||||
|         try: | ||||
|             BaseDatabaseWrapper._rollback(self) | ||||
|         except Database.NotSupportedError: | ||||
|             pass | ||||
|  | ||||
|     def _set_autocommit(self, autocommit): | ||||
|         with self.wrap_database_errors: | ||||
|             self.connection.autocommit(autocommit) | ||||
|  | ||||
|     def disable_constraint_checking(self): | ||||
|         """ | ||||
|         Disable foreign key checks, primarily for use in adding rows with | ||||
|         forward references. Always return True to indicate constraint checks | ||||
|         need to be re-enabled. | ||||
|         """ | ||||
|         with self.cursor() as cursor: | ||||
|             cursor.execute("SET foreign_key_checks=0") | ||||
|         return True | ||||
|  | ||||
|     def enable_constraint_checking(self): | ||||
|         """ | ||||
|         Re-enable foreign key checks after they have been disabled. | ||||
|         """ | ||||
|         # Override needs_rollback in case constraint_checks_disabled is | ||||
|         # nested inside transaction.atomic. | ||||
|         self.needs_rollback, needs_rollback = False, self.needs_rollback | ||||
|         try: | ||||
|             with self.cursor() as cursor: | ||||
|                 cursor.execute("SET foreign_key_checks=1") | ||||
|         finally: | ||||
|             self.needs_rollback = needs_rollback | ||||
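|  | ||||
|     # Typical pairing via BaseDatabaseWrapper.constraint_checks_disabled(), a | ||||
|     # hedged usage sketch (the loader function is hypothetical): | ||||
|     # | ||||
|     #     with connection.constraint_checks_disabled(): | ||||
|     #         load_rows_with_forward_references() | ||||
|     #     connection.check_constraints() | ||||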
|  | ||||
|     def check_constraints(self, table_names=None): | ||||
|         """ | ||||
|         Check each table name in `table_names` for rows with invalid foreign | ||||
|         key references. This method is intended to be used in conjunction with | ||||
|         `disable_constraint_checking()` and `enable_constraint_checking()`, to | ||||
|         determine if rows with invalid references were entered while constraint | ||||
|         checks were off. | ||||
|         """ | ||||
|         with self.cursor() as cursor: | ||||
|             if table_names is None: | ||||
|                 table_names = self.introspection.table_names(cursor) | ||||
|             for table_name in table_names: | ||||
|                 primary_key_column_name = self.introspection.get_primary_key_column( | ||||
|                     cursor, table_name | ||||
|                 ) | ||||
|                 if not primary_key_column_name: | ||||
|                     continue | ||||
|                 relations = self.introspection.get_relations(cursor, table_name) | ||||
|                 for column_name, ( | ||||
|                     referenced_column_name, | ||||
|                     referenced_table_name, | ||||
|                 ) in relations.items(): | ||||
|                     cursor.execute( | ||||
|                         """ | ||||
|                         SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING | ||||
|                         LEFT JOIN `%s` as REFERRED | ||||
|                         ON (REFERRING.`%s` = REFERRED.`%s`) | ||||
|                         WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL | ||||
|                         """ | ||||
|                         % ( | ||||
|                             primary_key_column_name, | ||||
|                             column_name, | ||||
|                             table_name, | ||||
|                             referenced_table_name, | ||||
|                             column_name, | ||||
|                             referenced_column_name, | ||||
|                             column_name, | ||||
|                             referenced_column_name, | ||||
|                         ) | ||||
|                     ) | ||||
|                     for bad_row in cursor.fetchall(): | ||||
|                         raise IntegrityError( | ||||
|                             "The row in table '%s' with primary key '%s' has an " | ||||
|                             "invalid foreign key: %s.%s contains a value '%s' that " | ||||
|                             "does not have a corresponding value in %s.%s." | ||||
|                             % ( | ||||
|                                 table_name, | ||||
|                                 bad_row[0], | ||||
|                                 table_name, | ||||
|                                 column_name, | ||||
|                                 bad_row[1], | ||||
|                                 referenced_table_name, | ||||
|                                 referenced_column_name, | ||||
|                             ) | ||||
|                         ) | ||||
|  | ||||
|     def is_usable(self): | ||||
|         try: | ||||
|             self.connection.ping() | ||||
|         except Database.Error: | ||||
|             return False | ||||
|         else: | ||||
|             return True | ||||
|  | ||||
|     @cached_property | ||||
|     def display_name(self): | ||||
|         return "MariaDB" if self.mysql_is_mariadb else "MySQL" | ||||
|  | ||||
|     @cached_property | ||||
|     def data_type_check_constraints(self): | ||||
|         if self.features.supports_column_check_constraints: | ||||
|             check_constraints = { | ||||
|                 "PositiveBigIntegerField": "`%(column)s` >= 0", | ||||
|                 "PositiveIntegerField": "`%(column)s` >= 0", | ||||
|                 "PositiveSmallIntegerField": "`%(column)s` >= 0", | ||||
|             } | ||||
|             if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3): | ||||
|                 # MariaDB < 10.4.3 doesn't automatically use JSON_VALID as | ||||
|                 # a CHECK constraint. | ||||
|                 check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)" | ||||
|             return check_constraints | ||||
|         return {} | ||||
|  | ||||
|     @cached_property | ||||
|     def mysql_server_data(self): | ||||
|         with self.temporary_connection() as cursor: | ||||
|             # Select some server variables and test if the time zone | ||||
|             # definitions are installed. CONVERT_TZ returns NULL if 'UTC' | ||||
|             # timezone isn't loaded into the mysql.time_zone table. | ||||
|             cursor.execute( | ||||
|                 """ | ||||
|                 SELECT VERSION(), | ||||
|                        @@sql_mode, | ||||
|                        @@default_storage_engine, | ||||
|                        @@sql_auto_is_null, | ||||
|                        @@lower_case_table_names, | ||||
|                        CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL | ||||
|             """ | ||||
|             ) | ||||
|             row = cursor.fetchone() | ||||
|         return { | ||||
|             "version": row[0], | ||||
|             "sql_mode": row[1], | ||||
|             "default_storage_engine": row[2], | ||||
|             "sql_auto_is_null": bool(row[3]), | ||||
|             "lower_case_table_names": bool(row[4]), | ||||
|             "has_zoneinfo_database": bool(row[5]), | ||||
|         } | ||||
|  | ||||
|     @cached_property | ||||
|     def mysql_server_info(self): | ||||
|         return self.mysql_server_data["version"] | ||||
|  | ||||
|     @cached_property | ||||
|     def mysql_version(self): | ||||
|         match = server_version_re.match(self.mysql_server_info) | ||||
|         if not match: | ||||
|             raise Exception( | ||||
|                 "Unable to determine MySQL version from version string %r" | ||||
|                 % self.mysql_server_info | ||||
|             ) | ||||
|         return tuple(int(x) for x in match.groups()) | ||||
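|  | ||||
|     # A hedged sketch: both vendors' version strings match on the leading | ||||
|     # numeric portion, e.g.: | ||||
|     # | ||||
|     #     >>> server_version_re.match("10.5.13-MariaDB").groups() | ||||
|     #     ('10', '5', '13') | ||||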
|  | ||||
|     @cached_property | ||||
|     def mysql_is_mariadb(self): | ||||
|         return "mariadb" in self.mysql_server_info.lower() | ||||
|  | ||||
|     @cached_property | ||||
|     def sql_mode(self): | ||||
|         sql_mode = self.mysql_server_data["sql_mode"] | ||||
|         return set(sql_mode.split(",") if sql_mode else ()) | ||||
| @ -0,0 +1,72 @@ | ||||
| import signal | ||||
|  | ||||
| from django.db.backends.base.client import BaseDatabaseClient | ||||
|  | ||||
|  | ||||
| class DatabaseClient(BaseDatabaseClient): | ||||
|     executable_name = "mysql" | ||||
|  | ||||
|     @classmethod | ||||
|     def settings_to_cmd_args_env(cls, settings_dict, parameters): | ||||
|         args = [cls.executable_name] | ||||
|         env = None | ||||
|         database = settings_dict["OPTIONS"].get( | ||||
|             "database", | ||||
|             settings_dict["OPTIONS"].get("db", settings_dict["NAME"]), | ||||
|         ) | ||||
|         user = settings_dict["OPTIONS"].get("user", settings_dict["USER"]) | ||||
|         password = settings_dict["OPTIONS"].get( | ||||
|             "password", | ||||
|             settings_dict["OPTIONS"].get("passwd", settings_dict["PASSWORD"]), | ||||
|         ) | ||||
|         host = settings_dict["OPTIONS"].get("host", settings_dict["HOST"]) | ||||
|         port = settings_dict["OPTIONS"].get("port", settings_dict["PORT"]) | ||||
|         server_ca = settings_dict["OPTIONS"].get("ssl", {}).get("ca") | ||||
|         client_cert = settings_dict["OPTIONS"].get("ssl", {}).get("cert") | ||||
|         client_key = settings_dict["OPTIONS"].get("ssl", {}).get("key") | ||||
|         defaults_file = settings_dict["OPTIONS"].get("read_default_file") | ||||
|         charset = settings_dict["OPTIONS"].get("charset") | ||||
|         # There seems to be no good way to set sql_mode via the CLI. | ||||
|  | ||||
|         if defaults_file: | ||||
|             args += ["--defaults-file=%s" % defaults_file] | ||||
|         if user: | ||||
|             args += ["--user=%s" % user] | ||||
|         if password: | ||||
|             # The MYSQL_PWD environment variable usage is discouraged per | ||||
|             # MySQL's documentation due to the possibility of exposure through | ||||
|             # `ps` on old Unix flavors but --password suffers from the same | ||||
|             # flaw on even more systems. Usage of an environment variable also | ||||
|             # prevents password exposure if the subprocess.run(check=True) call | ||||
|             # raises a CalledProcessError since the string representation of | ||||
|             # the latter includes all of the provided `args`. | ||||
|             env = {"MYSQL_PWD": password} | ||||
|         if host: | ||||
|             if "/" in host: | ||||
|                 args += ["--socket=%s" % host] | ||||
|             else: | ||||
|                 args += ["--host=%s" % host] | ||||
|         if port: | ||||
|             args += ["--port=%s" % port] | ||||
|         if server_ca: | ||||
|             args += ["--ssl-ca=%s" % server_ca] | ||||
|         if client_cert: | ||||
|             args += ["--ssl-cert=%s" % client_cert] | ||||
|         if client_key: | ||||
|             args += ["--ssl-key=%s" % client_key] | ||||
|         if charset: | ||||
|             args += ["--default-character-set=%s" % charset] | ||||
|         if database: | ||||
|             args += [database] | ||||
|         args.extend(parameters) | ||||
|         return args, env | ||||
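|  | ||||
|     # A hedged example (values are illustrative): NAME="mydb", USER="app", | ||||
|     # PASSWORD="s3cret", HOST="db.example.com", PORT="3306" produce roughly: | ||||
|     #     args = ["mysql", "--user=app", "--host=db.example.com", | ||||
|     #             "--port=3306", "mydb"] | ||||
|     #     env = {"MYSQL_PWD": "s3cret"} | ||||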
|  | ||||
|     def runshell(self, parameters): | ||||
|         sigint_handler = signal.getsignal(signal.SIGINT) | ||||
|         try: | ||||
|             # Allow SIGINT to pass to mysql to abort queries. | ||||
|             signal.signal(signal.SIGINT, signal.SIG_IGN) | ||||
|             super().runshell(parameters) | ||||
|         finally: | ||||
|             # Restore the original SIGINT handler. | ||||
|             signal.signal(signal.SIGINT, sigint_handler) | ||||
| @ -0,0 +1,84 @@ | ||||
| from django.core.exceptions import FieldError, FullResultSet | ||||
| from django.db.models.expressions import Col | ||||
| from django.db.models.sql import compiler | ||||
|  | ||||
|  | ||||
| class SQLCompiler(compiler.SQLCompiler): | ||||
|     def as_subquery_condition(self, alias, columns, compiler): | ||||
|         qn = compiler.quote_name_unless_alias | ||||
|         qn2 = self.connection.ops.quote_name | ||||
|         sql, params = self.as_sql() | ||||
|         return ( | ||||
|             "(%s) IN (%s)" | ||||
|             % ( | ||||
|                 ", ".join("%s.%s" % (qn(alias), qn2(column)) for column in columns), | ||||
|                 sql, | ||||
|             ), | ||||
|             params, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): | ||||
|     def as_sql(self): | ||||
|         # Prefer the non-standard DELETE FROM syntax over the SQL generated by | ||||
|         # the SQLDeleteCompiler's default implementation when multiple tables | ||||
|         # are involved since MySQL/MariaDB will generate a more efficient query | ||||
|         # plan than when using a subquery. | ||||
|         where, having, qualify = self.query.where.split_having_qualify( | ||||
|             must_group_by=self.query.group_by is not None | ||||
|         ) | ||||
|         if self.single_alias or having or qualify: | ||||
|             # DELETE FROM cannot be used when filtering against aggregates or | ||||
|             # window functions as it doesn't allow for GROUP BY/HAVING clauses | ||||
|             # and the subquery wrapping (necessary to emulate QUALIFY). | ||||
|             return super().as_sql() | ||||
|         result = [ | ||||
|             "DELETE %s FROM" | ||||
|             % self.quote_name_unless_alias(self.query.get_initial_alias()) | ||||
|         ] | ||||
|         from_sql, params = self.get_from_clause() | ||||
|         result.extend(from_sql) | ||||
|         try: | ||||
|             where_sql, where_params = self.compile(where) | ||||
|         except FullResultSet: | ||||
|             pass | ||||
|         else: | ||||
|             result.append("WHERE %s" % where_sql) | ||||
|             params.extend(where_params) | ||||
|         return " ".join(result), tuple(params) | ||||
|  | ||||
|  | ||||
| class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): | ||||
|     def as_sql(self): | ||||
|         update_query, update_params = super().as_sql() | ||||
|         # MySQL and MariaDB support UPDATE ... ORDER BY syntax. | ||||
|         if self.query.order_by: | ||||
|             order_by_sql = [] | ||||
|             order_by_params = [] | ||||
|             db_table = self.query.get_meta().db_table | ||||
|             try: | ||||
|                 for resolved, (sql, params, _) in self.get_order_by(): | ||||
|                     if ( | ||||
|                         isinstance(resolved.expression, Col) | ||||
|                         and resolved.expression.alias != db_table | ||||
|                     ): | ||||
|                         # Ignore ordering if it contains joined fields, because | ||||
|                         # they cannot be used in the ORDER BY clause. | ||||
|                         raise FieldError | ||||
|                     order_by_sql.append(sql) | ||||
|                     order_by_params.extend(params) | ||||
|                 update_query += " ORDER BY " + ", ".join(order_by_sql) | ||||
|                 update_params += tuple(order_by_params) | ||||
|             except FieldError: | ||||
|                 # Ignore ordering if it contains annotations, because they're | ||||
|                 # removed in .update() and cannot be resolved. | ||||
|                 pass | ||||
|         return update_query, update_params | ||||
|  | ||||
|  | ||||
| class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): | ||||
|     pass | ||||
| @ -0,0 +1,87 @@ | ||||
| import os | ||||
| import subprocess | ||||
| import sys | ||||
|  | ||||
| from django.db.backends.base.creation import BaseDatabaseCreation | ||||
|  | ||||
| from .client import DatabaseClient | ||||
|  | ||||
|  | ||||
| class DatabaseCreation(BaseDatabaseCreation): | ||||
|     def sql_table_creation_suffix(self): | ||||
|         suffix = [] | ||||
|         test_settings = self.connection.settings_dict["TEST"] | ||||
|         if test_settings["CHARSET"]: | ||||
|             suffix.append("CHARACTER SET %s" % test_settings["CHARSET"]) | ||||
|         if test_settings["COLLATION"]: | ||||
|             suffix.append("COLLATE %s" % test_settings["COLLATION"]) | ||||
|         return " ".join(suffix) | ||||
|  | ||||
|     def _execute_create_test_db(self, cursor, parameters, keepdb=False): | ||||
|         try: | ||||
|             super()._execute_create_test_db(cursor, parameters, keepdb) | ||||
|         except Exception as e: | ||||
|             if len(e.args) < 1 or e.args[0] != 1007: | ||||
|                 # All errors except "database exists" (1007) cancel tests. | ||||
|                 self.log("Got an error creating the test database: %s" % e) | ||||
|                 sys.exit(2) | ||||
|             else: | ||||
|                 raise | ||||
|  | ||||
|     def _clone_test_db(self, suffix, verbosity, keepdb=False): | ||||
|         source_database_name = self.connection.settings_dict["NAME"] | ||||
|         target_database_name = self.get_test_db_clone_settings(suffix)["NAME"] | ||||
|         test_db_params = { | ||||
|             "dbname": self.connection.ops.quote_name(target_database_name), | ||||
|             "suffix": self.sql_table_creation_suffix(), | ||||
|         } | ||||
|         with self._nodb_cursor() as cursor: | ||||
|             try: | ||||
|                 self._execute_create_test_db(cursor, test_db_params, keepdb) | ||||
|             except Exception: | ||||
|                 if keepdb: | ||||
|                     # If the database should be kept, skip everything else. | ||||
|                     return | ||||
|                 try: | ||||
|                     if verbosity >= 1: | ||||
|                         self.log( | ||||
|                             "Destroying old test database for alias %s..." | ||||
|                             % ( | ||||
|                                 self._get_database_display_str( | ||||
|                                     verbosity, target_database_name | ||||
|                                 ), | ||||
|                             ) | ||||
|                         ) | ||||
|                     cursor.execute("DROP DATABASE %(dbname)s" % test_db_params) | ||||
|                     self._execute_create_test_db(cursor, test_db_params, keepdb) | ||||
|                 except Exception as e: | ||||
|                     self.log("Got an error recreating the test database: %s" % e) | ||||
|                     sys.exit(2) | ||||
|         self._clone_db(source_database_name, target_database_name) | ||||
|  | ||||
|     def _clone_db(self, source_database_name, target_database_name): | ||||
|         cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env( | ||||
|             self.connection.settings_dict, [] | ||||
|         ) | ||||
|         dump_cmd = [ | ||||
|             "mysqldump", | ||||
|             *cmd_args[1:-1], | ||||
|             "--routines", | ||||
|             "--events", | ||||
|             source_database_name, | ||||
|         ] | ||||
|         dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None | ||||
|         load_cmd = cmd_args | ||||
|         load_cmd[-1] = target_database_name | ||||
|  | ||||
|         with subprocess.Popen( | ||||
|             dump_cmd, stdout=subprocess.PIPE, env=dump_env | ||||
|         ) as dump_proc: | ||||
|             with subprocess.Popen( | ||||
|                 load_cmd, | ||||
|                 stdin=dump_proc.stdout, | ||||
|                 stdout=subprocess.DEVNULL, | ||||
|                 env=load_env, | ||||
|             ): | ||||
|                 # Allow dump_proc to receive a SIGPIPE if the load process exits. | ||||
|                 dump_proc.stdout.close() | ||||
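|  | ||||
|         # The pipeline above is roughly equivalent to this shell command (a | ||||
|         # hedged sketch; MYSQL_PWD is exported to both processes when set): | ||||
|         # | ||||
|         #   mysqldump <args> --routines --events SOURCE | mysql <args> TARGET | ||||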
| @ -0,0 +1,350 @@ | ||||
| import operator | ||||
|  | ||||
| from django.db.backends.base.features import BaseDatabaseFeatures | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
|  | ||||
| class DatabaseFeatures(BaseDatabaseFeatures): | ||||
|     empty_fetchmany_value = () | ||||
|     allows_group_by_selected_pks = True | ||||
|     related_fields_match_type = True | ||||
|     # MySQL doesn't support sliced subqueries with IN/ALL/ANY/SOME. | ||||
|     allow_sliced_subqueries_with_in = False | ||||
|     has_select_for_update = True | ||||
|     supports_forward_references = False | ||||
|     supports_regex_backreferencing = False | ||||
|     supports_date_lookup_using_string = False | ||||
|     supports_timezones = False | ||||
|     requires_explicit_null_ordering_when_grouping = True | ||||
|     atomic_transactions = False | ||||
|     can_clone_databases = True | ||||
|     supports_comments = True | ||||
|     supports_comments_inline = True | ||||
|     supports_temporal_subtraction = True | ||||
|     supports_slicing_ordering_in_compound = True | ||||
|     supports_index_on_text_field = False | ||||
|     supports_update_conflicts = True | ||||
|     create_test_procedure_without_params_sql = """ | ||||
|         CREATE PROCEDURE test_procedure () | ||||
|         BEGIN | ||||
|             DECLARE V_I INTEGER; | ||||
|             SET V_I = 1; | ||||
|         END; | ||||
|     """ | ||||
|     create_test_procedure_with_int_param_sql = """ | ||||
|         CREATE PROCEDURE test_procedure (P_I INTEGER) | ||||
|         BEGIN | ||||
|             DECLARE V_I INTEGER; | ||||
|             SET V_I = P_I; | ||||
|         END; | ||||
|     """ | ||||
|     create_test_table_with_composite_primary_key = """ | ||||
|         CREATE TABLE test_table_composite_pk ( | ||||
|             column_1 INTEGER NOT NULL, | ||||
|             column_2 INTEGER NOT NULL, | ||||
|             PRIMARY KEY(column_1, column_2) | ||||
|         ) | ||||
|     """ | ||||
|     # Neither MySQL nor MariaDB support partial indexes. | ||||
|     supports_partial_indexes = False | ||||
|     # COLLATE must be wrapped in parentheses because MySQL treats COLLATE as an | ||||
|     # indexed expression. | ||||
|     collate_as_index_expression = True | ||||
|  | ||||
|     supports_order_by_nulls_modifier = False | ||||
|     order_by_nulls_first = True | ||||
|     supports_logical_xor = True | ||||
|  | ||||
|     @cached_property | ||||
|     def minimum_database_version(self): | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             return (10, 4) | ||||
|         else: | ||||
|             return (8,) | ||||
|  | ||||
|     @cached_property | ||||
|     def test_collations(self): | ||||
|         charset = "utf8" | ||||
|         if ( | ||||
|             self.connection.mysql_is_mariadb | ||||
|             and self.connection.mysql_version >= (10, 6) | ||||
|         ) or ( | ||||
|             not self.connection.mysql_is_mariadb | ||||
|             and self.connection.mysql_version >= (8, 0, 30) | ||||
|         ): | ||||
|             # utf8 is an alias for utf8mb3 in MariaDB 10.6+ and MySQL 8.0.30+. | ||||
|             charset = "utf8mb3" | ||||
|         return { | ||||
|             "ci": f"{charset}_general_ci", | ||||
|             "non_default": f"{charset}_esperanto_ci", | ||||
|             "swedish_ci": f"{charset}_swedish_ci", | ||||
|         } | ||||
|  | ||||
|     test_now_utc_template = "UTC_TIMESTAMP(6)" | ||||
|  | ||||
|     @cached_property | ||||
|     def django_test_skips(self): | ||||
|         skips = { | ||||
|             "This doesn't work on MySQL.": { | ||||
|                 "db_functions.comparison.test_greatest.GreatestTests." | ||||
|                 "test_coalesce_workaround", | ||||
|                 "db_functions.comparison.test_least.LeastTests." | ||||
|                 "test_coalesce_workaround", | ||||
|             }, | ||||
|             "Running on MySQL requires utf8mb4 encoding (#18392).": { | ||||
|                 "model_fields.test_textfield.TextFieldTests.test_emoji", | ||||
|                 "model_fields.test_charfield.TestCharField.test_emoji", | ||||
|             }, | ||||
|             "MySQL doesn't support functional indexes on a function that " | ||||
|             "returns JSON": { | ||||
|                 "schema.tests.SchemaTests.test_func_index_json_key_transform", | ||||
|             }, | ||||
|             "MySQL supports multiplying and dividing DurationFields by a " | ||||
|             "scalar value but it's not implemented (#25287).": { | ||||
|                 "expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide", | ||||
|             }, | ||||
|             "UPDATE ... ORDER BY syntax on MySQL/MariaDB does not support ordering by" | ||||
|             "related fields.": { | ||||
|                 "update.tests.AdvancedTests." | ||||
|                 "test_update_ordered_by_inline_m2m_annotation", | ||||
|                 "update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation", | ||||
|             }, | ||||
|         } | ||||
|         if self.connection.mysql_is_mariadb and ( | ||||
|             10, | ||||
|             4, | ||||
|             3, | ||||
|         ) < self.connection.mysql_version < (10, 5, 2): | ||||
|             skips.update( | ||||
|                 { | ||||
|                     "https://jira.mariadb.org/browse/MDEV-19598": { | ||||
|                         "schema.tests.SchemaTests." | ||||
|                         "test_alter_not_unique_field_to_primary_key", | ||||
|                     }, | ||||
|                 } | ||||
|             ) | ||||
|         if self.connection.mysql_is_mariadb and ( | ||||
|             10, | ||||
|             4, | ||||
|             12, | ||||
|         ) < self.connection.mysql_version < (10, 5): | ||||
|             skips.update( | ||||
|                 { | ||||
|                     "https://jira.mariadb.org/browse/MDEV-22775": { | ||||
|                         "schema.tests.SchemaTests." | ||||
|                         "test_alter_pk_with_self_referential_field", | ||||
|                     }, | ||||
|                 } | ||||
|             ) | ||||
|         if not self.supports_explain_analyze: | ||||
|             skips.update( | ||||
|                 { | ||||
|                     "MariaDB and MySQL >= 8.0.18 specific.": { | ||||
|                         "queries.test_explain.ExplainTests.test_mysql_analyze", | ||||
|                     }, | ||||
|                 } | ||||
|             ) | ||||
|         if "ONLY_FULL_GROUP_BY" in self.connection.sql_mode: | ||||
|             skips.update( | ||||
|                 { | ||||
|                     "GROUP BY cannot contain nonaggregated column when " | ||||
|                     "ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #34262.": { | ||||
|                         "aggregation.tests.AggregateTestCase." | ||||
|                         "test_group_by_nested_expression_with_params", | ||||
|                     }, | ||||
|                 } | ||||
|             ) | ||||
|         if self.connection.mysql_is_mariadb and self.connection.mysql_version >= ( | ||||
|             10, | ||||
|             5, | ||||
|             2, | ||||
|         ): | ||||
|             skips.update( | ||||
|                 { | ||||
|                     "ALTER TABLE ... RENAME COLUMN statement doesn't rename inline " | ||||
|                     "constraints on MariaDB 10.5.2+, this is fixed in Django 5.0+ " | ||||
|                     "(#34320).": { | ||||
|                         "schema.tests.SchemaTests." | ||||
|                         "test_rename_field_with_check_to_truncated_name", | ||||
|                     }, | ||||
|                 } | ||||
|             ) | ||||
|         return skips | ||||
|  | ||||
|     @cached_property | ||||
|     def _mysql_storage_engine(self): | ||||
|         "Internal method used in Django tests. Don't rely on this from your code" | ||||
|         return self.connection.mysql_server_data["default_storage_engine"] | ||||
|  | ||||
|     @cached_property | ||||
|     def allows_auto_pk_0(self): | ||||
|         """ | ||||
|         Autoincrement primary key can be set to 0 if it doesn't generate new | ||||
|         autoincrement values. | ||||
|         """ | ||||
|         return "NO_AUTO_VALUE_ON_ZERO" in self.connection.sql_mode | ||||
|  | ||||
|     @cached_property | ||||
|     def update_can_self_select(self): | ||||
|         return self.connection.mysql_is_mariadb and self.connection.mysql_version >= ( | ||||
|             10, | ||||
|             3, | ||||
|             2, | ||||
|         ) | ||||
|  | ||||
|     @cached_property | ||||
|     def can_introspect_foreign_keys(self): | ||||
|         "Confirm support for introspected foreign keys" | ||||
|         return self._mysql_storage_engine != "MyISAM" | ||||
|  | ||||
|     @cached_property | ||||
|     def introspected_field_types(self): | ||||
|         return { | ||||
|             **super().introspected_field_types, | ||||
|             "BinaryField": "TextField", | ||||
|             "BooleanField": "IntegerField", | ||||
|             "DurationField": "BigIntegerField", | ||||
|             "GenericIPAddressField": "CharField", | ||||
|         } | ||||
|  | ||||
|     @cached_property | ||||
|     def can_return_columns_from_insert(self): | ||||
|         return self.connection.mysql_is_mariadb and self.connection.mysql_version >= ( | ||||
|             10, | ||||
|             5, | ||||
|             0, | ||||
|         ) | ||||
|  | ||||
|     can_return_rows_from_bulk_insert = property( | ||||
|         operator.attrgetter("can_return_columns_from_insert") | ||||
|     ) | ||||
|  | ||||
|     @cached_property | ||||
|     def has_zoneinfo_database(self): | ||||
|         return self.connection.mysql_server_data["has_zoneinfo_database"] | ||||
|  | ||||
|     @cached_property | ||||
|     def is_sql_auto_is_null_enabled(self): | ||||
|         return self.connection.mysql_server_data["sql_auto_is_null"] | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_over_clause(self): | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             return True | ||||
|         return self.connection.mysql_version >= (8, 0, 2) | ||||
|  | ||||
|     supports_frame_range_fixed_distance = property( | ||||
|         operator.attrgetter("supports_over_clause") | ||||
|     ) | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_column_check_constraints(self): | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             return True | ||||
|         return self.connection.mysql_version >= (8, 0, 16) | ||||
|  | ||||
|     supports_table_check_constraints = property( | ||||
|         operator.attrgetter("supports_column_check_constraints") | ||||
|     ) | ||||
|  | ||||
|     @cached_property | ||||
|     def can_introspect_check_constraints(self): | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             return True | ||||
|         return self.connection.mysql_version >= (8, 0, 16) | ||||
|  | ||||
|     @cached_property | ||||
|     def has_select_for_update_skip_locked(self): | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             return self.connection.mysql_version >= (10, 6) | ||||
|         return self.connection.mysql_version >= (8, 0, 1) | ||||
|  | ||||
|     @cached_property | ||||
|     def has_select_for_update_nowait(self): | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             return True | ||||
|         return self.connection.mysql_version >= (8, 0, 1) | ||||
|  | ||||
|     @cached_property | ||||
|     def has_select_for_update_of(self): | ||||
|         return ( | ||||
|             not self.connection.mysql_is_mariadb | ||||
|             and self.connection.mysql_version >= (8, 0, 1) | ||||
|         ) | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_explain_analyze(self): | ||||
|         return self.connection.mysql_is_mariadb or self.connection.mysql_version >= ( | ||||
|             8, | ||||
|             0, | ||||
|             18, | ||||
|         ) | ||||
|  | ||||
|     @cached_property | ||||
|     def supported_explain_formats(self): | ||||
|         # Alias MySQL's TRADITIONAL to TEXT for consistency with other | ||||
|         # backends. | ||||
|         formats = {"JSON", "TEXT", "TRADITIONAL"} | ||||
|         if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= ( | ||||
|             8, | ||||
|             0, | ||||
|             16, | ||||
|         ): | ||||
|             formats.add("TREE") | ||||
|         return formats | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_transactions(self): | ||||
|         """ | ||||
|         All storage engines except MyISAM support transactions. | ||||
|         """ | ||||
|         return self._mysql_storage_engine != "MyISAM" | ||||
|  | ||||
|     uses_savepoints = property(operator.attrgetter("supports_transactions")) | ||||
|     can_release_savepoints = property(operator.attrgetter("supports_transactions")) | ||||
|  | ||||
|     @cached_property | ||||
|     def ignores_table_name_case(self): | ||||
|         return self.connection.mysql_server_data["lower_case_table_names"] | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_default_in_lead_lag(self): | ||||
|         # To be added in https://jira.mariadb.org/browse/MDEV-12981. | ||||
|         return not self.connection.mysql_is_mariadb | ||||
|  | ||||
|     @cached_property | ||||
|     def can_introspect_json_field(self): | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             return self.can_introspect_check_constraints | ||||
|         return True | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_index_column_ordering(self): | ||||
|         if self._mysql_storage_engine != "InnoDB": | ||||
|             return False | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             return self.connection.mysql_version >= (10, 8) | ||||
|         return self.connection.mysql_version >= (8, 0, 1) | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_expression_indexes(self): | ||||
|         return ( | ||||
|             not self.connection.mysql_is_mariadb | ||||
|             and self._mysql_storage_engine != "MyISAM" | ||||
|             and self.connection.mysql_version >= (8, 0, 13) | ||||
|         ) | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_select_intersection(self): | ||||
|         is_mariadb = self.connection.mysql_is_mariadb | ||||
|         return is_mariadb or self.connection.mysql_version >= (8, 0, 31) | ||||
|  | ||||
|     supports_select_difference = property( | ||||
|         operator.attrgetter("supports_select_intersection") | ||||
|     ) | ||||
|  | ||||
|     @cached_property | ||||
|     def can_rename_index(self): | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             return self.connection.mysql_version >= (10, 5, 2) | ||||
|         return True | ||||
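|  | ||||
|  | ||||
| # Illustrative sketch (not part of Django): flags like the ones above are | ||||
| # read through `connection.features`; `connection` here is an assumption, | ||||
| # e.g. django.db.connections["default"] backed by this MySQL backend. | ||||
| def _example_skip_locked_sql(connection): | ||||
|     # SKIP LOCKED needs MariaDB 10.6+ or MySQL 8.0.1+, per | ||||
|     # has_select_for_update_skip_locked above. | ||||
|     if connection.features.has_select_for_update_skip_locked: | ||||
|         return "SELECT id FROM app_job FOR UPDATE SKIP LOCKED" | ||||
|     return "SELECT id FROM app_job FOR UPDATE" | ||||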
| @ -0,0 +1,349 @@ | ||||
| from collections import namedtuple | ||||
|  | ||||
| import sqlparse | ||||
| from MySQLdb.constants import FIELD_TYPE | ||||
|  | ||||
| from django.db.backends.base.introspection import BaseDatabaseIntrospection | ||||
| from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo | ||||
| from django.db.backends.base.introspection import TableInfo as BaseTableInfo | ||||
| from django.db.models import Index | ||||
| from django.utils.datastructures import OrderedSet | ||||
|  | ||||
| FieldInfo = namedtuple( | ||||
|     "FieldInfo", | ||||
|     BaseFieldInfo._fields + ("extra", "is_unsigned", "has_json_constraint", "comment"), | ||||
| ) | ||||
| InfoLine = namedtuple( | ||||
|     "InfoLine", | ||||
|     "col_name data_type max_len num_prec num_scale extra column_default " | ||||
|     "collation is_unsigned comment", | ||||
| ) | ||||
| TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",)) | ||||
|  | ||||
|  | ||||
| class DatabaseIntrospection(BaseDatabaseIntrospection): | ||||
|     data_types_reverse = { | ||||
|         FIELD_TYPE.BLOB: "TextField", | ||||
|         FIELD_TYPE.CHAR: "CharField", | ||||
|         FIELD_TYPE.DECIMAL: "DecimalField", | ||||
|         FIELD_TYPE.NEWDECIMAL: "DecimalField", | ||||
|         FIELD_TYPE.DATE: "DateField", | ||||
|         FIELD_TYPE.DATETIME: "DateTimeField", | ||||
|         FIELD_TYPE.DOUBLE: "FloatField", | ||||
|         FIELD_TYPE.FLOAT: "FloatField", | ||||
|         FIELD_TYPE.INT24: "IntegerField", | ||||
|         FIELD_TYPE.JSON: "JSONField", | ||||
|         FIELD_TYPE.LONG: "IntegerField", | ||||
|         FIELD_TYPE.LONGLONG: "BigIntegerField", | ||||
|         FIELD_TYPE.SHORT: "SmallIntegerField", | ||||
|         FIELD_TYPE.STRING: "CharField", | ||||
|         FIELD_TYPE.TIME: "TimeField", | ||||
|         FIELD_TYPE.TIMESTAMP: "DateTimeField", | ||||
|         FIELD_TYPE.TINY: "IntegerField", | ||||
|         FIELD_TYPE.TINY_BLOB: "TextField", | ||||
|         FIELD_TYPE.MEDIUM_BLOB: "TextField", | ||||
|         FIELD_TYPE.LONG_BLOB: "TextField", | ||||
|         FIELD_TYPE.VAR_STRING: "CharField", | ||||
|     } | ||||
|  | ||||
|     def get_field_type(self, data_type, description): | ||||
|         field_type = super().get_field_type(data_type, description) | ||||
|         if "auto_increment" in description.extra: | ||||
|             if field_type == "IntegerField": | ||||
|                 return "AutoField" | ||||
|             elif field_type == "BigIntegerField": | ||||
|                 return "BigAutoField" | ||||
|             elif field_type == "SmallIntegerField": | ||||
|                 return "SmallAutoField" | ||||
|         if description.is_unsigned: | ||||
|             if field_type == "BigIntegerField": | ||||
|                 return "PositiveBigIntegerField" | ||||
|             elif field_type == "IntegerField": | ||||
|                 return "PositiveIntegerField" | ||||
|             elif field_type == "SmallIntegerField": | ||||
|                 return "PositiveSmallIntegerField" | ||||
|         # The JSON data type is an alias for LONGTEXT in MariaDB, so use | ||||
|         # check constraint clauses to introspect JSONField. | ||||
|         if description.has_json_constraint: | ||||
|             return "JSONField" | ||||
|         return field_type | ||||
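|  | ||||
|     # Example (illustrative): an `int unsigned AUTO_INCREMENT` column has | ||||
|     # extra="auto_increment" and is_unsigned=1 in its description, so the | ||||
|     # base "IntegerField" answer is upgraded to "AutoField" above, while a | ||||
|     # plain `int unsigned` column becomes "PositiveIntegerField". | ||||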
|  | ||||
|     def get_table_list(self, cursor): | ||||
|         """Return a list of table and view names in the current database.""" | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 table_name, | ||||
|                 table_type, | ||||
|                 table_comment | ||||
|             FROM information_schema.tables | ||||
|             WHERE table_schema = DATABASE() | ||||
|             """ | ||||
|         ) | ||||
|         return [ | ||||
|             TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1]), row[2]) | ||||
|             for row in cursor.fetchall() | ||||
|         ] | ||||
|  | ||||
|     def get_table_description(self, cursor, table_name): | ||||
|         """ | ||||
|         Return a description of the table with the DB-API cursor.description | ||||
|         interface." | ||||
|         """ | ||||
|         json_constraints = {} | ||||
|         if ( | ||||
|             self.connection.mysql_is_mariadb | ||||
|             and self.connection.features.can_introspect_json_field | ||||
|         ): | ||||
|             # The JSON data type is an alias for LONGTEXT in MariaDB, so | ||||
|             # select JSON_VALID() constraints to introspect JSONField. | ||||
|             cursor.execute( | ||||
|                 """ | ||||
|                 SELECT c.constraint_name AS column_name | ||||
|                 FROM information_schema.check_constraints AS c | ||||
|                 WHERE | ||||
|                     c.table_name = %s AND | ||||
|                     LOWER(c.check_clause) = | ||||
|                         'json_valid(`' + LOWER(c.constraint_name) + '`)' AND | ||||
|                     c.constraint_schema = DATABASE() | ||||
|                 """, | ||||
|                 [table_name], | ||||
|             ) | ||||
|             json_constraints = {row[0] for row in cursor.fetchall()} | ||||
|         # A default collation for the given table. | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT  table_collation | ||||
|             FROM    information_schema.tables | ||||
|             WHERE   table_schema = DATABASE() | ||||
|             AND     table_name = %s | ||||
|             """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         row = cursor.fetchone() | ||||
|         default_column_collation = row[0] if row else "" | ||||
|         # information_schema database gives more accurate results for some figures: | ||||
|         # - varchar length returned by cursor.description is an internal length, | ||||
|         #   not visible length (#5725) | ||||
|         # - precision and scale (for decimal fields) (#5014) | ||||
|         # - auto_increment is not available in cursor.description | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 column_name, data_type, character_maximum_length, | ||||
|                 numeric_precision, numeric_scale, extra, column_default, | ||||
|                 CASE | ||||
|                     WHEN collation_name = %s THEN NULL | ||||
|                     ELSE collation_name | ||||
|                 END AS collation_name, | ||||
|                 CASE | ||||
|                     WHEN column_type LIKE '%% unsigned' THEN 1 | ||||
|                     ELSE 0 | ||||
|                 END AS is_unsigned, | ||||
|                 column_comment | ||||
|             FROM information_schema.columns | ||||
|             WHERE table_name = %s AND table_schema = DATABASE() | ||||
|             """, | ||||
|             [default_column_collation, table_name], | ||||
|         ) | ||||
|         field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()} | ||||
|  | ||||
|         cursor.execute( | ||||
|             "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name) | ||||
|         ) | ||||
|  | ||||
|         def to_int(i): | ||||
|             return int(i) if i is not None else i | ||||
|  | ||||
|         fields = [] | ||||
|         for line in cursor.description: | ||||
|             info = field_info[line[0]] | ||||
|             fields.append( | ||||
|                 FieldInfo( | ||||
|                     *line[:2], | ||||
|                     to_int(info.max_len) or line[2], | ||||
|                     to_int(info.max_len) or line[3], | ||||
|                     to_int(info.num_prec) or line[4], | ||||
|                     to_int(info.num_scale) or line[5], | ||||
|                     line[6], | ||||
|                     info.column_default, | ||||
|                     info.collation, | ||||
|                     info.extra, | ||||
|                     info.is_unsigned, | ||||
|                     line[0] in json_constraints, | ||||
|                     info.comment, | ||||
|                 ) | ||||
|             ) | ||||
|         return fields | ||||
|  | ||||
|     def get_sequences(self, cursor, table_name, table_fields=()): | ||||
|         for field_info in self.get_table_description(cursor, table_name): | ||||
|             if "auto_increment" in field_info.extra: | ||||
|                 # MySQL allows only one auto-increment column per table. | ||||
|                 return [{"table": table_name, "column": field_info.name}] | ||||
|         return [] | ||||
|  | ||||
|     def get_relations(self, cursor, table_name): | ||||
|         """ | ||||
|         Return a dictionary of {field_name: (field_name_other_table, other_table)} | ||||
|         representing all foreign keys in the given table. | ||||
|         """ | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT column_name, referenced_column_name, referenced_table_name | ||||
|             FROM information_schema.key_column_usage | ||||
|             WHERE table_name = %s | ||||
|                 AND table_schema = DATABASE() | ||||
|                 AND referenced_table_name IS NOT NULL | ||||
|                 AND referenced_column_name IS NOT NULL | ||||
|             """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         return { | ||||
|             field_name: (other_field, other_table) | ||||
|             for field_name, other_field, other_table in cursor.fetchall() | ||||
|         } | ||||
|  | ||||
|     def get_storage_engine(self, cursor, table_name): | ||||
|         """ | ||||
|         Retrieve the storage engine for a given table. Return the default | ||||
|         storage engine if the table doesn't exist. | ||||
|         """ | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT engine | ||||
|             FROM information_schema.tables | ||||
|             WHERE | ||||
|                 table_name = %s AND | ||||
|                 table_schema = DATABASE() | ||||
|             """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         result = cursor.fetchone() | ||||
|         if not result: | ||||
|             return self.connection.features._mysql_storage_engine | ||||
|         return result[0] | ||||
|  | ||||
|     def _parse_constraint_columns(self, check_clause, columns): | ||||
|         check_columns = OrderedSet() | ||||
|         statement = sqlparse.parse(check_clause)[0] | ||||
|         tokens = (token for token in statement.flatten() if not token.is_whitespace) | ||||
|         for token in tokens: | ||||
|             if ( | ||||
|                 token.ttype == sqlparse.tokens.Name | ||||
|                 and self.connection.ops.quote_name(token.value) == token.value | ||||
|                 and token.value[1:-1] in columns | ||||
|             ): | ||||
|                 check_columns.add(token.value[1:-1]) | ||||
|         return check_columns | ||||
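|  | ||||
|     # Example (illustrative): for check_clause "`price` > `discounted`" and | ||||
|     # columns={"price", "discounted"}, sqlparse emits the backtick-quoted | ||||
|     # Name tokens `price` and `discounted`, so both names are collected. | ||||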
|  | ||||
|     def get_constraints(self, cursor, table_name): | ||||
|         """ | ||||
|         Retrieve any constraints or keys (unique, pk, fk, check, index) across | ||||
|         one or more columns. | ||||
|         """ | ||||
|         constraints = {} | ||||
|         # Get the actual constraint names and columns | ||||
|         name_query = """ | ||||
|             SELECT kc.`constraint_name`, kc.`column_name`, | ||||
|                 kc.`referenced_table_name`, kc.`referenced_column_name`, | ||||
|                 c.`constraint_type` | ||||
|             FROM | ||||
|                 information_schema.key_column_usage AS kc, | ||||
|                 information_schema.table_constraints AS c | ||||
|             WHERE | ||||
|                 kc.table_schema = DATABASE() AND | ||||
|                 c.table_schema = kc.table_schema AND | ||||
|                 c.constraint_name = kc.constraint_name AND | ||||
|                 c.constraint_type != 'CHECK' AND | ||||
|                 kc.table_name = %s | ||||
|             ORDER BY kc.`ordinal_position` | ||||
|         """ | ||||
|         cursor.execute(name_query, [table_name]) | ||||
|         for constraint, column, ref_table, ref_column, kind in cursor.fetchall(): | ||||
|             if constraint not in constraints: | ||||
|                 constraints[constraint] = { | ||||
|                     "columns": OrderedSet(), | ||||
|                     "primary_key": kind == "PRIMARY KEY", | ||||
|                     "unique": kind in {"PRIMARY KEY", "UNIQUE"}, | ||||
|                     "index": False, | ||||
|                     "check": False, | ||||
|                     "foreign_key": (ref_table, ref_column) if ref_column else None, | ||||
|                 } | ||||
|                 if self.connection.features.supports_index_column_ordering: | ||||
|                     constraints[constraint]["orders"] = [] | ||||
|             constraints[constraint]["columns"].add(column) | ||||
|         # Add check constraints. | ||||
|         if self.connection.features.can_introspect_check_constraints: | ||||
|             unnamed_constraints_index = 0 | ||||
|             columns = { | ||||
|                 info.name for info in self.get_table_description(cursor, table_name) | ||||
|             } | ||||
|             if self.connection.mysql_is_mariadb: | ||||
|                 type_query = """ | ||||
|                     SELECT c.constraint_name, c.check_clause | ||||
|                     FROM information_schema.check_constraints AS c | ||||
|                     WHERE | ||||
|                         c.constraint_schema = DATABASE() AND | ||||
|                         c.table_name = %s | ||||
|                 """ | ||||
|             else: | ||||
|                 type_query = """ | ||||
|                     SELECT cc.constraint_name, cc.check_clause | ||||
|                     FROM | ||||
|                         information_schema.check_constraints AS cc, | ||||
|                         information_schema.table_constraints AS tc | ||||
|                     WHERE | ||||
|                         cc.constraint_schema = DATABASE() AND | ||||
|                         tc.table_schema = cc.constraint_schema AND | ||||
|                         cc.constraint_name = tc.constraint_name AND | ||||
|                         tc.constraint_type = 'CHECK' AND | ||||
|                         tc.table_name = %s | ||||
|                 """ | ||||
|             cursor.execute(type_query, [table_name]) | ||||
|             for constraint, check_clause in cursor.fetchall(): | ||||
|                 constraint_columns = self._parse_constraint_columns( | ||||
|                     check_clause, columns | ||||
|                 ) | ||||
|                 # Ensure uniqueness of unnamed constraints. Unnamed unique | ||||
|                 # and column check constraints have the same name as a | ||||
|                 # column. | ||||
|                 if set(constraint_columns) == {constraint}: | ||||
|                     unnamed_constraints_index += 1 | ||||
|                     constraint = "__unnamed_constraint_%s__" % unnamed_constraints_index | ||||
|                 constraints[constraint] = { | ||||
|                     "columns": constraint_columns, | ||||
|                     "primary_key": False, | ||||
|                     "unique": False, | ||||
|                     "index": False, | ||||
|                     "check": True, | ||||
|                     "foreign_key": None, | ||||
|                 } | ||||
|         # Now add in the indexes | ||||
|         cursor.execute( | ||||
|             "SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name) | ||||
|         ) | ||||
|         for table, non_unique, index, colseq, column, order, type_ in [ | ||||
|             x[:6] + (x[10],) for x in cursor.fetchall() | ||||
|         ]: | ||||
|             if index not in constraints: | ||||
|                 constraints[index] = { | ||||
|                     "columns": OrderedSet(), | ||||
|                     "primary_key": False, | ||||
|                     "unique": not non_unique, | ||||
|                     "check": False, | ||||
|                     "foreign_key": None, | ||||
|                 } | ||||
|                 if self.connection.features.supports_index_column_ordering: | ||||
|                     constraints[index]["orders"] = [] | ||||
|             constraints[index]["index"] = True | ||||
|             constraints[index]["type"] = ( | ||||
|                 Index.suffix if type_ == "BTREE" else type_.lower() | ||||
|             ) | ||||
|             constraints[index]["columns"].add(column) | ||||
|             if self.connection.features.supports_index_column_ordering: | ||||
|                 constraints[index]["orders"].append("DESC" if order == "D" else "ASC") | ||||
|         # Convert the sorted sets to lists | ||||
|         for constraint in constraints.values(): | ||||
|             constraint["columns"] = list(constraint["columns"]) | ||||
|         return constraints | ||||
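|  | ||||
|  | ||||
| # Illustrative sketch (not part of Django): the shape of the mapping that | ||||
| # get_constraints() returns for a table with a primary key and one BTREE | ||||
| # index (values abridged; "orders" is present only when the backend | ||||
| # supports index column ordering): | ||||
| # | ||||
| #     { | ||||
| #         "PRIMARY": {"columns": ["id"], "primary_key": True, | ||||
| #                     "unique": True, "index": True, "check": False, | ||||
| #                     "foreign_key": None, "type": "idx"}, | ||||
| #         "app_tbl_col_idx": {"columns": ["col"], "primary_key": False, | ||||
| #                             "unique": False, "index": True, | ||||
| #                             "check": False, "foreign_key": None, | ||||
| #                             "type": "idx"}, | ||||
| #     } | ||||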
| @ -0,0 +1,464 @@ | ||||
| import uuid | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db.backends.base.operations import BaseDatabaseOperations | ||||
| from django.db.backends.utils import split_tzname_delta | ||||
| from django.db.models import Exists, ExpressionWrapper, Lookup | ||||
| from django.db.models.constants import OnConflict | ||||
| from django.utils import timezone | ||||
| from django.utils.encoding import force_str | ||||
| from django.utils.regex_helper import _lazy_re_compile | ||||
|  | ||||
|  | ||||
| class DatabaseOperations(BaseDatabaseOperations): | ||||
|     compiler_module = "django.db.backends.mysql.compiler" | ||||
|  | ||||
|     # MySQL stores positive fields as UNSIGNED ints. | ||||
|     integer_field_ranges = { | ||||
|         **BaseDatabaseOperations.integer_field_ranges, | ||||
|         "PositiveSmallIntegerField": (0, 65535), | ||||
|         "PositiveIntegerField": (0, 4294967295), | ||||
|         "PositiveBigIntegerField": (0, 18446744073709551615), | ||||
|     } | ||||
|     cast_data_types = { | ||||
|         "AutoField": "signed integer", | ||||
|         "BigAutoField": "signed integer", | ||||
|         "SmallAutoField": "signed integer", | ||||
|         "CharField": "char(%(max_length)s)", | ||||
|         "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)", | ||||
|         "TextField": "char", | ||||
|         "IntegerField": "signed integer", | ||||
|         "BigIntegerField": "signed integer", | ||||
|         "SmallIntegerField": "signed integer", | ||||
|         "PositiveBigIntegerField": "unsigned integer", | ||||
|         "PositiveIntegerField": "unsigned integer", | ||||
|         "PositiveSmallIntegerField": "unsigned integer", | ||||
|         "DurationField": "signed integer", | ||||
|     } | ||||
|     cast_char_field_without_max_length = "char" | ||||
|     explain_prefix = "EXPLAIN" | ||||
|  | ||||
|     # EXTRACT format cannot be passed in parameters. | ||||
|     _extract_format_re = _lazy_re_compile(r"[A-Z_]+") | ||||
|  | ||||
|     def date_extract_sql(self, lookup_type, sql, params): | ||||
|         # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html | ||||
|         if lookup_type == "week_day": | ||||
|             # DAYOFWEEK() returns an integer, 1-7, Sunday=1. | ||||
|             return f"DAYOFWEEK({sql})", params | ||||
|         elif lookup_type == "iso_week_day": | ||||
|             # WEEKDAY() returns an integer, 0-6, Monday=0. | ||||
|             return f"WEEKDAY({sql}) + 1", params | ||||
|         elif lookup_type == "week": | ||||
|             # Override the value of default_week_format for consistency with | ||||
|             # other database backends. | ||||
|             # Mode 3: Monday, 1-53, with 4 or more days this year. | ||||
|             return f"WEEK({sql}, 3)", params | ||||
|         elif lookup_type == "iso_year": | ||||
|             # Get the year part from the YEARWEEK function, which returns a | ||||
|             # number as year * 100 + week. | ||||
|             return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params | ||||
|         else: | ||||
|             # EXTRACT returns 1-53 based on ISO-8601 for the week number. | ||||
|             lookup_type = lookup_type.upper() | ||||
|             if not self._extract_format_re.fullmatch(lookup_type): | ||||
|                 raise ValueError(f"Invalid loookup type: {lookup_type!r}") | ||||
|             return f"EXTRACT({lookup_type} FROM {sql})", params | ||||
|  | ||||
|     def date_trunc_sql(self, lookup_type, sql, params, tzname=None): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         fields = { | ||||
|             "year": "%Y-01-01", | ||||
|             "month": "%Y-%m-01", | ||||
|         } | ||||
|         if lookup_type in fields: | ||||
|             format_str = fields[lookup_type] | ||||
|             return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str) | ||||
|         elif lookup_type == "quarter": | ||||
|             return ( | ||||
|                 f"MAKEDATE(YEAR({sql}), 1) + " | ||||
|                 f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER", | ||||
|                 (*params, *params), | ||||
|             ) | ||||
|         elif lookup_type == "week": | ||||
|             return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params) | ||||
|         else: | ||||
|             return f"DATE({sql})", params | ||||
|  | ||||
|     def _prepare_tzname_delta(self, tzname): | ||||
|         tzname, sign, offset = split_tzname_delta(tzname) | ||||
|         return f"{sign}{offset}" if offset else tzname | ||||
|  | ||||
|     def _convert_sql_to_tz(self, sql, params, tzname): | ||||
|         if tzname and settings.USE_TZ and self.connection.timezone_name != tzname: | ||||
|             return f"CONVERT_TZ({sql}, %s, %s)", ( | ||||
|                 *params, | ||||
|                 self.connection.timezone_name, | ||||
|                 self._prepare_tzname_delta(tzname), | ||||
|             ) | ||||
|         return sql, params | ||||
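|  | ||||
|     # Example (illustrative): with USE_TZ=True, a connection timezone of | ||||
|     # "UTC", and tzname="Europe/Paris", the expression `created` becomes | ||||
|     # "CONVERT_TZ(created, %s, %s)" with ("UTC", "Europe/Paris") appended | ||||
|     # to the params. | ||||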
|  | ||||
|     def datetime_cast_date_sql(self, sql, params, tzname): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         return f"DATE({sql})", params | ||||
|  | ||||
|     def datetime_cast_time_sql(self, sql, params, tzname): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         return f"TIME({sql})", params | ||||
|  | ||||
|     def datetime_extract_sql(self, lookup_type, sql, params, tzname): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         return self.date_extract_sql(lookup_type, sql, params) | ||||
|  | ||||
|     def datetime_trunc_sql(self, lookup_type, sql, params, tzname): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         fields = ["year", "month", "day", "hour", "minute", "second"] | ||||
|         format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s") | ||||
|         format_def = ("0000-", "01", "-01", " 00:", "00", ":00") | ||||
|         if lookup_type == "quarter": | ||||
|             return ( | ||||
|                 f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + " | ||||
|                 f"INTERVAL QUARTER({sql}) QUARTER - " | ||||
|                 f"INTERVAL 1 QUARTER, %s) AS DATETIME)" | ||||
|             ), (*params, *params, "%Y-%m-01 00:00:00") | ||||
|         if lookup_type == "week": | ||||
|             return ( | ||||
|                 f"CAST(DATE_FORMAT(" | ||||
|                 f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)" | ||||
|             ), (*params, *params, "%Y-%m-%d 00:00:00") | ||||
|         try: | ||||
|             i = fields.index(lookup_type) + 1 | ||||
|         except ValueError: | ||||
|             pass | ||||
|         else: | ||||
|             format_str = "".join(format[:i] + format_def[i:]) | ||||
|             return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str) | ||||
|         return sql, params | ||||
|  | ||||
|     def time_trunc_sql(self, lookup_type, sql, params, tzname=None): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         fields = { | ||||
|             "hour": "%H:00:00", | ||||
|             "minute": "%H:%i:00", | ||||
|             "second": "%H:%i:%s", | ||||
|         } | ||||
|         if lookup_type in fields: | ||||
|             format_str = fields[lookup_type] | ||||
|             return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str) | ||||
|         else: | ||||
|             return f"TIME({sql})", params | ||||
|  | ||||
|     def fetch_returned_insert_rows(self, cursor): | ||||
|         """ | ||||
|         Given a cursor object that has just performed an INSERT...RETURNING | ||||
|         statement into a table, return the tuple of returned data. | ||||
|         """ | ||||
|         return cursor.fetchall() | ||||
|  | ||||
|     def format_for_duration_arithmetic(self, sql): | ||||
|         return "INTERVAL %s MICROSECOND" % sql | ||||
|  | ||||
|     def force_no_ordering(self): | ||||
|         """ | ||||
|         "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped | ||||
|         columns. If no ordering would otherwise be applied, we don't want any | ||||
|         implicit sorting going on. | ||||
|         """ | ||||
|         return [(None, ("NULL", [], False))] | ||||
|  | ||||
|     def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None): | ||||
|         return value | ||||
|  | ||||
|     def last_executed_query(self, cursor, sql, params): | ||||
|         # With MySQLdb, cursor objects have an (undocumented) "_executed" | ||||
|         # attribute where the exact query sent to the database is saved. | ||||
|         # See MySQLdb/cursors.py in the source distribution. | ||||
|         # MySQLdb returns string, PyMySQL bytes. | ||||
|         return force_str(getattr(cursor, "_executed", None), errors="replace") | ||||
|  | ||||
|     def no_limit_value(self): | ||||
|         # 2**64 - 1, as recommended by the MySQL documentation | ||||
|         return 18446744073709551615 | ||||
|  | ||||
|     def quote_name(self, name): | ||||
|         if name.startswith("`") and name.endswith("`"): | ||||
|             return name  # Quoting once is enough. | ||||
|         return "`%s`" % name | ||||
|  | ||||
|     def return_insert_columns(self, fields): | ||||
|         # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING | ||||
|         # statement. | ||||
|         if not fields: | ||||
|             return "", () | ||||
|         columns = [ | ||||
|             "%s.%s" | ||||
|             % ( | ||||
|                 self.quote_name(field.model._meta.db_table), | ||||
|                 self.quote_name(field.column), | ||||
|             ) | ||||
|             for field in fields | ||||
|         ] | ||||
|         return "RETURNING %s" % ", ".join(columns), () | ||||
|  | ||||
|     def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False): | ||||
|         if not tables: | ||||
|             return [] | ||||
|  | ||||
|         sql = ["SET FOREIGN_KEY_CHECKS = 0;"] | ||||
|         if reset_sequences: | ||||
|             # It's faster to TRUNCATE tables that require a sequence reset | ||||
|             # since ALTER TABLE AUTO_INCREMENT is slower than TRUNCATE. | ||||
|             sql.extend( | ||||
|                 "%s %s;" | ||||
|                 % ( | ||||
|                     style.SQL_KEYWORD("TRUNCATE"), | ||||
|                     style.SQL_FIELD(self.quote_name(table_name)), | ||||
|                 ) | ||||
|                 for table_name in tables | ||||
|             ) | ||||
|         else: | ||||
|             # Otherwise issue a simple DELETE since it's faster than TRUNCATE | ||||
|             # and preserves sequences. | ||||
|             sql.extend( | ||||
|                 "%s %s %s;" | ||||
|                 % ( | ||||
|                     style.SQL_KEYWORD("DELETE"), | ||||
|                     style.SQL_KEYWORD("FROM"), | ||||
|                     style.SQL_FIELD(self.quote_name(table_name)), | ||||
|                 ) | ||||
|                 for table_name in tables | ||||
|             ) | ||||
|         sql.append("SET FOREIGN_KEY_CHECKS = 1;") | ||||
|         return sql | ||||
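|  | ||||
|     # Example (illustrative, using django.core.management.color.no_style): | ||||
|     # sql_flush(no_style(), ["app_a"]) returns | ||||
|     # ["SET FOREIGN_KEY_CHECKS = 0;", "DELETE FROM `app_a`;", | ||||
|     #  "SET FOREIGN_KEY_CHECKS = 1;"], while reset_sequences=True swaps the | ||||
|     # DELETE for "TRUNCATE `app_a`;". | ||||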
|  | ||||
|     def sequence_reset_by_name_sql(self, style, sequences): | ||||
|         return [ | ||||
|             "%s %s %s %s = 1;" | ||||
|             % ( | ||||
|                 style.SQL_KEYWORD("ALTER"), | ||||
|                 style.SQL_KEYWORD("TABLE"), | ||||
|                 style.SQL_FIELD(self.quote_name(sequence_info["table"])), | ||||
|                 style.SQL_FIELD("AUTO_INCREMENT"), | ||||
|             ) | ||||
|             for sequence_info in sequences | ||||
|         ] | ||||
|  | ||||
|     def validate_autopk_value(self, value): | ||||
|         # Zero in AUTO_INCREMENT field does not work without the | ||||
|         # NO_AUTO_VALUE_ON_ZERO SQL mode. | ||||
|         if value == 0 and not self.connection.features.allows_auto_pk_0: | ||||
|             raise ValueError( | ||||
|                 "The database backend does not accept 0 as a value for AutoField." | ||||
|             ) | ||||
|         return value | ||||
|  | ||||
|     def adapt_datetimefield_value(self, value): | ||||
|         if value is None: | ||||
|             return None | ||||
|  | ||||
|         # Expression values are adapted by the database. | ||||
|         if hasattr(value, "resolve_expression"): | ||||
|             return value | ||||
|  | ||||
|         # MySQL doesn't support tz-aware datetimes | ||||
|         if timezone.is_aware(value): | ||||
|             if settings.USE_TZ: | ||||
|                 value = timezone.make_naive(value, self.connection.timezone) | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     "MySQL backend does not support timezone-aware datetimes when " | ||||
|                     "USE_TZ is False." | ||||
|                 ) | ||||
|         return str(value) | ||||
|  | ||||
|     def adapt_timefield_value(self, value): | ||||
|         if value is None: | ||||
|             return None | ||||
|  | ||||
|         # Expression values are adapted by the database. | ||||
|         if hasattr(value, "resolve_expression"): | ||||
|             return value | ||||
|  | ||||
|         # MySQL doesn't support tz-aware times | ||||
|         if timezone.is_aware(value): | ||||
|             raise ValueError("MySQL backend does not support timezone-aware times.") | ||||
|  | ||||
|         return value.isoformat(timespec="microseconds") | ||||
|  | ||||
|     def max_name_length(self): | ||||
|         return 64 | ||||
|  | ||||
|     def pk_default_value(self): | ||||
|         return "NULL" | ||||
|  | ||||
|     def bulk_insert_sql(self, fields, placeholder_rows): | ||||
|         placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) | ||||
|         values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql) | ||||
|         return "VALUES " + values_sql | ||||
|  | ||||
|     def combine_expression(self, connector, sub_expressions): | ||||
|         if connector == "^": | ||||
|             return "POW(%s)" % ",".join(sub_expressions) | ||||
|         # Convert the result to a signed integer since MySQL's binary operators | ||||
|         # return an unsigned integer. | ||||
|         elif connector in ("&", "|", "<<", "#"): | ||||
|             connector = "^" if connector == "#" else connector | ||||
|             return "CONVERT(%s, SIGNED)" % connector.join(sub_expressions) | ||||
|         elif connector == ">>": | ||||
|             lhs, rhs = sub_expressions | ||||
|             return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs} | ||||
|         return super().combine_expression(connector, sub_expressions) | ||||
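|  | ||||
|     # Example (illustrative): combine_expression(">>", ["col", "2"]) | ||||
|     # returns "FLOOR(col / POW(2, 2))", and combine_expression("&", | ||||
|     # ["a", "b"]) returns "CONVERT(a&b, SIGNED)" to force a signed result. | ||||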
|  | ||||
|     def get_db_converters(self, expression): | ||||
|         converters = super().get_db_converters(expression) | ||||
|         internal_type = expression.output_field.get_internal_type() | ||||
|         if internal_type == "BooleanField": | ||||
|             converters.append(self.convert_booleanfield_value) | ||||
|         elif internal_type == "DateTimeField": | ||||
|             if settings.USE_TZ: | ||||
|                 converters.append(self.convert_datetimefield_value) | ||||
|         elif internal_type == "UUIDField": | ||||
|             converters.append(self.convert_uuidfield_value) | ||||
|         return converters | ||||
|  | ||||
|     def convert_booleanfield_value(self, value, expression, connection): | ||||
|         if value in (0, 1): | ||||
|             value = bool(value) | ||||
|         return value | ||||
|  | ||||
|     def convert_datetimefield_value(self, value, expression, connection): | ||||
|         if value is not None: | ||||
|             value = timezone.make_aware(value, self.connection.timezone) | ||||
|         return value | ||||
|  | ||||
|     def convert_uuidfield_value(self, value, expression, connection): | ||||
|         if value is not None: | ||||
|             value = uuid.UUID(value) | ||||
|         return value | ||||
|  | ||||
|     def binary_placeholder_sql(self, value): | ||||
|         return ( | ||||
|             "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s" | ||||
|         ) | ||||
|  | ||||
|     def subtract_temporals(self, internal_type, lhs, rhs): | ||||
|         lhs_sql, lhs_params = lhs | ||||
|         rhs_sql, rhs_params = rhs | ||||
|         if internal_type == "TimeField": | ||||
|             if self.connection.mysql_is_mariadb: | ||||
|                 # MariaDB includes the microsecond component in TIME_TO_SEC as | ||||
|                 # a decimal. MySQL returns an integer without microseconds. | ||||
|                 return ( | ||||
|                     "CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) " | ||||
|                     "* 1000000 AS SIGNED)" | ||||
|                 ) % { | ||||
|                     "lhs": lhs_sql, | ||||
|                     "rhs": rhs_sql, | ||||
|                 }, ( | ||||
|                     *lhs_params, | ||||
|                     *rhs_params, | ||||
|                 ) | ||||
|             return ( | ||||
|                 "((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -" | ||||
|                 " (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))" | ||||
|             ) % {"lhs": lhs_sql, "rhs": rhs_sql}, tuple(lhs_params) * 2 + tuple( | ||||
|                 rhs_params | ||||
|             ) * 2 | ||||
|         params = (*rhs_params, *lhs_params) | ||||
|         return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params | ||||
|  | ||||
|     def explain_query_prefix(self, format=None, **options): | ||||
|         # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends. | ||||
|         if format and format.upper() == "TEXT": | ||||
|             format = "TRADITIONAL" | ||||
|         elif ( | ||||
|             not format and "TREE" in self.connection.features.supported_explain_formats | ||||
|         ): | ||||
|             # Use TREE by default (if supported) as it's more informative. | ||||
|             format = "TREE" | ||||
|         analyze = options.pop("analyze", False) | ||||
|         prefix = super().explain_query_prefix(format, **options) | ||||
|         if analyze and self.connection.features.supports_explain_analyze: | ||||
|             # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE. | ||||
|             prefix = ( | ||||
|                 "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE" | ||||
|             ) | ||||
|         if format and not (analyze and not self.connection.mysql_is_mariadb): | ||||
|             # Only MariaDB supports the analyze option with formats. | ||||
|             prefix += " FORMAT=%s" % format | ||||
|         return prefix | ||||
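|  | ||||
|     # Example (illustrative): on MySQL >= 8.0.18, analyze=True yields | ||||
|     # "EXPLAIN ANALYZE" (MySQL doesn't combine ANALYZE with a FORMAT | ||||
|     # clause), a plain explain on MySQL >= 8.0.16 defaults to | ||||
|     # "EXPLAIN FORMAT=TREE", and MariaDB with analyze=True yields | ||||
|     # "ANALYZE". | ||||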
|  | ||||
|     def regex_lookup(self, lookup_type): | ||||
|         # REGEXP_LIKE doesn't exist in MariaDB. | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             if lookup_type == "regex": | ||||
|                 return "%s REGEXP BINARY %s" | ||||
|             return "%s REGEXP %s" | ||||
|  | ||||
|         match_option = "c" if lookup_type == "regex" else "i" | ||||
|         return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option | ||||
|  | ||||
|     def insert_statement(self, on_conflict=None): | ||||
|         if on_conflict == OnConflict.IGNORE: | ||||
|             return "INSERT IGNORE INTO" | ||||
|         return super().insert_statement(on_conflict=on_conflict) | ||||
|  | ||||
|     def lookup_cast(self, lookup_type, internal_type=None): | ||||
|         lookup = "%s" | ||||
|         if internal_type == "JSONField": | ||||
|             if self.connection.mysql_is_mariadb or lookup_type in ( | ||||
|                 "iexact", | ||||
|                 "contains", | ||||
|                 "icontains", | ||||
|                 "startswith", | ||||
|                 "istartswith", | ||||
|                 "endswith", | ||||
|                 "iendswith", | ||||
|                 "regex", | ||||
|                 "iregex", | ||||
|             ): | ||||
|                 lookup = "JSON_UNQUOTE(%s)" | ||||
|         return lookup | ||||
|  | ||||
|     def conditional_expression_supported_in_where_clause(self, expression): | ||||
|         # MySQL ignores indexes with boolean fields unless they're compared | ||||
|         # directly to a boolean value. | ||||
|         if isinstance(expression, (Exists, Lookup)): | ||||
|             return True | ||||
|         if isinstance(expression, ExpressionWrapper) and expression.conditional: | ||||
|             return self.conditional_expression_supported_in_where_clause( | ||||
|                 expression.expression | ||||
|             ) | ||||
|         if getattr(expression, "conditional", False): | ||||
|             return False | ||||
|         return super().conditional_expression_supported_in_where_clause(expression) | ||||
|  | ||||
|     def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): | ||||
|         if on_conflict == OnConflict.UPDATE: | ||||
|             conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s" | ||||
|             # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use | ||||
|             # aliases for the new row and its columns available in MySQL | ||||
|             # 8.0.19+. | ||||
|             if not self.connection.mysql_is_mariadb: | ||||
|                 if self.connection.mysql_version >= (8, 0, 19): | ||||
|                     conflict_suffix_sql = f"AS new {conflict_suffix_sql}" | ||||
|                     field_sql = "%(field)s = new.%(field)s" | ||||
|                 else: | ||||
|                     field_sql = "%(field)s = VALUES(%(field)s)" | ||||
|             # Use VALUE() on MariaDB. | ||||
|             else: | ||||
|                 field_sql = "%(field)s = VALUE(%(field)s)" | ||||
|  | ||||
|             fields = ", ".join( | ||||
|                 [ | ||||
|                     field_sql % {"field": field} | ||||
|                     for field in map(self.quote_name, update_fields) | ||||
|                 ] | ||||
|             ) | ||||
|             return conflict_suffix_sql % {"fields": fields} | ||||
|         return super().on_conflict_suffix_sql( | ||||
|             fields, | ||||
|             on_conflict, | ||||
|             update_fields, | ||||
|             unique_fields, | ||||
|         ) | ||||
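|  | ||||
|  | ||||
| # Illustrative sketch (not part of Django): the upsert suffix generated by | ||||
| # on_conflict_suffix_sql() for update_fields=["name"], per backend flavor: | ||||
| # | ||||
| #   MySQL >= 8.0.19: AS new ON DUPLICATE KEY UPDATE `name` = new.`name` | ||||
| #   older MySQL:     ON DUPLICATE KEY UPDATE `name` = VALUES(`name`) | ||||
| #   MariaDB:         ON DUPLICATE KEY UPDATE `name` = VALUE(`name`) | ||||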
| @ -0,0 +1,243 @@ | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
| from django.db.models import NOT_PROVIDED, F, UniqueConstraint | ||||
| from django.db.models.constants import LOOKUP_SEP | ||||
|  | ||||
|  | ||||
| class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): | ||||
|     sql_rename_table = "RENAME TABLE %(old_table)s TO %(new_table)s" | ||||
|  | ||||
|     sql_alter_column_null = "MODIFY %(column)s %(type)s NULL" | ||||
|     sql_alter_column_not_null = "MODIFY %(column)s %(type)s NOT NULL" | ||||
|     sql_alter_column_type = "MODIFY %(column)s %(type)s%(collation)s%(comment)s" | ||||
|     sql_alter_column_no_default_null = "ALTER COLUMN %(column)s SET DEFAULT NULL" | ||||
|  | ||||
|     # No 'CASCADE', which works as a no-op in MySQL but is undocumented. | ||||
|     sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s" | ||||
|  | ||||
|     sql_delete_unique = "ALTER TABLE %(table)s DROP INDEX %(name)s" | ||||
|     sql_create_column_inline_fk = ( | ||||
|         ", ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) " | ||||
|         "REFERENCES %(to_table)s(%(to_column)s)" | ||||
|     ) | ||||
|     sql_delete_fk = "ALTER TABLE %(table)s DROP FOREIGN KEY %(name)s" | ||||
|  | ||||
|     sql_delete_index = "DROP INDEX %(name)s ON %(table)s" | ||||
|     sql_rename_index = "ALTER TABLE %(table)s RENAME INDEX %(old_name)s TO %(new_name)s" | ||||
|  | ||||
|     sql_create_pk = ( | ||||
|         "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)" | ||||
|     ) | ||||
|     sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY" | ||||
|  | ||||
|     sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s" | ||||
|  | ||||
|     sql_alter_table_comment = "ALTER TABLE %(table)s COMMENT = %(comment)s" | ||||
|     sql_alter_column_comment = None | ||||
|  | ||||
|     @property | ||||
|     def sql_delete_check(self): | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             # On MariaDB, the name of a column check constraint is the same | ||||
|             # as the field name. Adding an IF EXISTS clause prevents | ||||
|             # migrations from crashing, since the constraint may already | ||||
|             # have been removed by a "MODIFY" column statement. | ||||
|             return "ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s" | ||||
|         return "ALTER TABLE %(table)s DROP CHECK %(name)s" | ||||
|  | ||||
|     @property | ||||
|     def sql_rename_column(self): | ||||
|         # MariaDB >= 10.5.2 and MySQL >= 8.0.4 support an | ||||
|         # "ALTER TABLE ... RENAME COLUMN" statement. | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             if self.connection.mysql_version >= (10, 5, 2): | ||||
|                 return super().sql_rename_column | ||||
|         elif self.connection.mysql_version >= (8, 0, 4): | ||||
|             return super().sql_rename_column | ||||
|         return "ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s" | ||||
|  | ||||
|     def quote_value(self, value): | ||||
|         self.connection.ensure_connection() | ||||
|         if isinstance(value, str): | ||||
|             value = value.replace("%", "%%") | ||||
|         # MySQLdb escapes to string, PyMySQL to bytes. | ||||
|         quoted = self.connection.connection.escape( | ||||
|             value, self.connection.connection.encoders | ||||
|         ) | ||||
|         if isinstance(value, str) and isinstance(quoted, bytes): | ||||
|             quoted = quoted.decode() | ||||
|         return quoted | ||||
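|  | ||||
|     # Illustrative sketch (not part of the original module): the result is a | ||||
|     # ready-to-embed SQL literal with '%' doubled so it survives later | ||||
|     # interpolation. Values are approximate: | ||||
|     #   editor.quote_value("50% off")  ->  "'50%% off'" | ||||
|     #   editor.quote_value(10)         ->  "10" | ||||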
|  | ||||
|     def _is_limited_data_type(self, field): | ||||
|         db_type = field.db_type(self.connection) | ||||
|         return ( | ||||
|             db_type is not None | ||||
|             and db_type.lower() in self.connection._limited_data_types | ||||
|         ) | ||||
|  | ||||
|     def skip_default(self, field): | ||||
|         if not self._supports_limited_data_type_defaults: | ||||
|             return self._is_limited_data_type(field) | ||||
|         return False | ||||
|  | ||||
|     def skip_default_on_alter(self, field): | ||||
|         if self._is_limited_data_type(field) and not self.connection.mysql_is_mariadb: | ||||
|             # MySQL doesn't support defaults for BLOB and TEXT in the | ||||
|             # ALTER COLUMN statement. | ||||
|             return True | ||||
|         return False | ||||
|  | ||||
|     @property | ||||
|     def _supports_limited_data_type_defaults(self): | ||||
|         # MariaDB and MySQL >= 8.0.13 support defaults for BLOB and TEXT. | ||||
|         if self.connection.mysql_is_mariadb: | ||||
|             return True | ||||
|         return self.connection.mysql_version >= (8, 0, 13) | ||||
|  | ||||
|     def _column_default_sql(self, field): | ||||
|         if ( | ||||
|             not self.connection.mysql_is_mariadb | ||||
|             and self._supports_limited_data_type_defaults | ||||
|             and self._is_limited_data_type(field) | ||||
|         ): | ||||
|             # MySQL supports defaults for BLOB and TEXT columns only if the | ||||
|             # default value is written as an expression i.e. in parentheses. | ||||
|             return "(%s)" | ||||
|         return super()._column_default_sql(field) | ||||
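|  | ||||
|     # Hypothetical illustration: on MySQL >= 8.0.13 a TextField default must | ||||
|     # be written as an expression, so the generated DDL looks roughly like | ||||
|     #   ALTER TABLE `app_model` ADD COLUMN `body` longtext NOT NULL DEFAULT ('') | ||||
|     # whereas a plain DEFAULT '' is rejected for TEXT/BLOB columns. | ||||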
|  | ||||
|     def add_field(self, model, field): | ||||
|         super().add_field(model, field) | ||||
|  | ||||
|         # Simulate the effect of a one-off default. | ||||
|         # field.default may be unhashable, so a set isn't used for the "in" check. | ||||
|         if self.skip_default(field) and field.default not in (None, NOT_PROVIDED): | ||||
|             effective_default = self.effective_default(field) | ||||
|             self.execute( | ||||
|                 "UPDATE %(table)s SET %(column)s = %%s" | ||||
|                 % { | ||||
|                     "table": self.quote_name(model._meta.db_table), | ||||
|                     "column": self.quote_name(field.column), | ||||
|                 }, | ||||
|                 [effective_default], | ||||
|             ) | ||||
|  | ||||
|     def remove_constraint(self, model, constraint): | ||||
|         if ( | ||||
|             isinstance(constraint, UniqueConstraint) | ||||
|             and constraint.create_sql(model, self) is not None | ||||
|         ): | ||||
|             self._create_missing_fk_index( | ||||
|                 model, | ||||
|                 fields=constraint.fields, | ||||
|                 expressions=constraint.expressions, | ||||
|             ) | ||||
|         super().remove_constraint(model, constraint) | ||||
|  | ||||
|     def remove_index(self, model, index): | ||||
|         self._create_missing_fk_index( | ||||
|             model, | ||||
|             fields=[field_name for field_name, _ in index.fields_orders], | ||||
|             expressions=index.expressions, | ||||
|         ) | ||||
|         super().remove_index(model, index) | ||||
|  | ||||
|     def _field_should_be_indexed(self, model, field): | ||||
|         if not super()._field_should_be_indexed(model, field): | ||||
|             return False | ||||
|  | ||||
|         storage = self.connection.introspection.get_storage_engine( | ||||
|             self.connection.cursor(), model._meta.db_table | ||||
|         ) | ||||
|         # There's no need to create an index for ForeignKey fields unless | ||||
|         # db_constraint=False, because without the constraint the implicit | ||||
|         # index won't be created. | ||||
|         if ( | ||||
|             storage == "InnoDB" | ||||
|             and field.get_internal_type() == "ForeignKey" | ||||
|             and field.db_constraint | ||||
|         ): | ||||
|             return False | ||||
|         return not self._is_limited_data_type(field) | ||||
|  | ||||
|     def _create_missing_fk_index( | ||||
|         self, | ||||
|         model, | ||||
|         *, | ||||
|         fields, | ||||
|         expressions=None, | ||||
|     ): | ||||
|         """ | ||||
|         MySQL can remove an implicit FK index on a field when that field is | ||||
|         covered by another index like a unique_together. "covered" here means | ||||
|         that the more complex index has the FK field as its first field (see | ||||
|         https://bugs.mysql.com/bug.php?id=37910). | ||||
|  | ||||
|         Manually create an implicit FK index to make it possible to remove the | ||||
|         composed index. | ||||
|         """ | ||||
|         first_field_name = None | ||||
|         if fields: | ||||
|             first_field_name = fields[0] | ||||
|         elif ( | ||||
|             expressions | ||||
|             and self.connection.features.supports_expression_indexes | ||||
|             and isinstance(expressions[0], F) | ||||
|             and LOOKUP_SEP not in expressions[0].name | ||||
|         ): | ||||
|             first_field_name = expressions[0].name | ||||
|  | ||||
|         if not first_field_name: | ||||
|             return | ||||
|  | ||||
|         first_field = model._meta.get_field(first_field_name) | ||||
|         if first_field.get_internal_type() == "ForeignKey": | ||||
|             column = self.connection.introspection.identifier_converter( | ||||
|                 first_field.column | ||||
|             ) | ||||
|             with self.connection.cursor() as cursor: | ||||
|                 constraint_names = [ | ||||
|                     name | ||||
|                     for name, infodict in self.connection.introspection.get_constraints( | ||||
|                         cursor, model._meta.db_table | ||||
|                     ).items() | ||||
|                     if infodict["index"] and infodict["columns"][0] == column | ||||
|                 ] | ||||
|             # There are no other indexes that start with the FK field, only | ||||
|             # the index that is expected to be deleted. | ||||
|             if len(constraint_names) == 1: | ||||
|                 self.execute( | ||||
|                     self._create_index_sql(model, fields=[first_field], suffix="") | ||||
|                 ) | ||||
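|  | ||||
|     # Hypothetical scenario this guards against (model names are | ||||
|     # illustrative): | ||||
|     # | ||||
|     #   class Book(models.Model): | ||||
|     #       author = models.ForeignKey(Author, models.CASCADE) | ||||
|     #       title = models.CharField(max_length=100) | ||||
|     # | ||||
|     #       class Meta: | ||||
|     #           unique_together = [("author", "title")] | ||||
|     # | ||||
|     # MySQL may satisfy the FK's index requirement with the composed | ||||
|     # (author, title) index; dropping unique_together can then fail with | ||||
|     # "needed in a foreign key constraint" (errno 1553) unless an index on | ||||
|     # "author" is recreated first, which is what this method does. | ||||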
|  | ||||
|     def _delete_composed_index(self, model, fields, *args): | ||||
|         self._create_missing_fk_index(model, fields=fields) | ||||
|         return super()._delete_composed_index(model, fields, *args) | ||||
|  | ||||
|     def _set_field_new_type_null_status(self, field, new_type): | ||||
|         """ | ||||
|         Keep the null property of the old field. If it has changed, it will be | ||||
|         handled separately. | ||||
|         """ | ||||
|         if field.null: | ||||
|             new_type += " NULL" | ||||
|         else: | ||||
|             new_type += " NOT NULL" | ||||
|         return new_type | ||||
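|  | ||||
|     # For example (illustrative values): | ||||
|     #   _set_field_new_type_null_status(nullable_field, "varchar(20)") | ||||
|     #   -> "varchar(20) NULL"   (or "varchar(20) NOT NULL" when null=False) | ||||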
|  | ||||
|     def _alter_column_type_sql( | ||||
|         self, model, old_field, new_field, new_type, old_collation, new_collation | ||||
|     ): | ||||
|         new_type = self._set_field_new_type_null_status(old_field, new_type) | ||||
|         return super()._alter_column_type_sql( | ||||
|             model, old_field, new_field, new_type, old_collation, new_collation | ||||
|         ) | ||||
|  | ||||
|     def _rename_field_sql(self, table, old_field, new_field, new_type): | ||||
|         new_type = self._set_field_new_type_null_status(old_field, new_type) | ||||
|         return super()._rename_field_sql(table, old_field, new_field, new_type) | ||||
|  | ||||
|     def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment): | ||||
|         # The comment is altered when altering the column type. | ||||
|         return "", [] | ||||
|  | ||||
|     def _comment_sql(self, comment): | ||||
|         comment_sql = super()._comment_sql(comment) | ||||
|         return f" COMMENT {comment_sql}" | ||||
| @ -0,0 +1,77 @@ | ||||
| from django.core import checks | ||||
| from django.db.backends.base.validation import BaseDatabaseValidation | ||||
| from django.utils.version import get_docs_version | ||||
|  | ||||
|  | ||||
| class DatabaseValidation(BaseDatabaseValidation): | ||||
|     def check(self, **kwargs): | ||||
|         issues = super().check(**kwargs) | ||||
|         issues.extend(self._check_sql_mode(**kwargs)) | ||||
|         return issues | ||||
|  | ||||
|     def _check_sql_mode(self, **kwargs): | ||||
|         if not ( | ||||
|             self.connection.sql_mode & {"STRICT_TRANS_TABLES", "STRICT_ALL_TABLES"} | ||||
|         ): | ||||
|             return [ | ||||
|                 checks.Warning( | ||||
|                     "%s Strict Mode is not set for database connection '%s'" | ||||
|                     % (self.connection.display_name, self.connection.alias), | ||||
|                     hint=( | ||||
|                         "%s's Strict Mode fixes many data integrity problems in " | ||||
|                         "%s, such as data truncation upon insertion, by " | ||||
|                         "escalating warnings into errors. It is strongly " | ||||
|                         "recommended you activate it. See: " | ||||
|                         "https://docs.djangoproject.com/en/%s/ref/databases/" | ||||
|                         "#mysql-sql-mode" | ||||
|                         % ( | ||||
|                             self.connection.display_name, | ||||
|                             self.connection.display_name, | ||||
|                             get_docs_version(), | ||||
|                         ), | ||||
|                     ), | ||||
|                     id="mysql.W002", | ||||
|                 ) | ||||
|             ] | ||||
|         return [] | ||||
|  | ||||
|     def check_field_type(self, field, field_type): | ||||
|         """ | ||||
|         MySQL has the following field length restriction: | ||||
|         No character (varchar) fields can have a length exceeding 255 | ||||
|         characters if they have a unique index on them. | ||||
|         MySQL doesn't support a database index on some data types. | ||||
|         """ | ||||
|         errors = [] | ||||
|         if ( | ||||
|             field_type.startswith("varchar") | ||||
|             and field.unique | ||||
|             and (field.max_length is None or int(field.max_length) > 255) | ||||
|         ): | ||||
|             errors.append( | ||||
|                 checks.Warning( | ||||
|                     "%s may not allow unique CharFields to have a max_length " | ||||
|                     "> 255." % self.connection.display_name, | ||||
|                     obj=field, | ||||
|                     hint=( | ||||
|                         "See: https://docs.djangoproject.com/en/%s/ref/" | ||||
|                         "databases/#mysql-character-fields" % get_docs_version() | ||||
|                     ), | ||||
|                     id="mysql.W003", | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|         if field.db_index and field_type.lower() in self.connection._limited_data_types: | ||||
|             errors.append( | ||||
|                 checks.Warning( | ||||
|                     "%s does not support a database index on %s columns." | ||||
|                     % (self.connection.display_name, field_type), | ||||
|                     hint=( | ||||
|                         "An index won't be created. Silence this warning if " | ||||
|                         "you don't care about it." | ||||
|                     ), | ||||
|                     obj=field, | ||||
|                     id="fields.W162", | ||||
|                 ) | ||||
|             ) | ||||
|         return errors | ||||
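|  | ||||
|  | ||||
| # Hypothetical models that would trigger the warnings above (an illustrative | ||||
| # sketch, not part of this module): | ||||
| # | ||||
| #   class Tag(models.Model): | ||||
| #       name = models.CharField(max_length=300, unique=True)  # mysql.W003 | ||||
| #       data = models.BinaryField(db_index=True)  # fields.W162 | ||||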
| @ -0,0 +1,592 @@ | ||||
| """ | ||||
| Oracle database backend for Django. | ||||
|  | ||||
| Requires cx_Oracle: https://oracle.github.io/python-cx_Oracle/ | ||||
| """ | ||||
| import datetime | ||||
| import decimal | ||||
| import os | ||||
| import platform | ||||
| from contextlib import contextmanager | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.core.exceptions import ImproperlyConfigured | ||||
| from django.db import IntegrityError | ||||
| from django.db.backends.base.base import BaseDatabaseWrapper | ||||
| from django.db.backends.utils import debug_transaction | ||||
| from django.utils.asyncio import async_unsafe | ||||
| from django.utils.encoding import force_bytes, force_str | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
|  | ||||
| def _setup_environment(environ): | ||||
|     # Cygwin requires some special voodoo to set the environment variables | ||||
|     # properly so that Oracle will see them. | ||||
|     if platform.system().upper().startswith("CYGWIN"): | ||||
|         try: | ||||
|             import ctypes | ||||
|         except ImportError as e: | ||||
|             raise ImproperlyConfigured( | ||||
|                 "Error loading ctypes: %s; " | ||||
|                 "the Oracle backend requires ctypes to " | ||||
|                 "operate correctly under Cygwin." % e | ||||
|             ) | ||||
|         kernel32 = ctypes.CDLL("kernel32") | ||||
|         for name, value in environ: | ||||
|             kernel32.SetEnvironmentVariableA(name, value) | ||||
|     else: | ||||
|         os.environ.update(environ) | ||||
|  | ||||
|  | ||||
| _setup_environment( | ||||
|     [ | ||||
|         # Oracle takes client-side character set encoding from the environment. | ||||
|         ("NLS_LANG", ".AL32UTF8"), | ||||
|         # This prevents Unicode from getting mangled by getting encoded into the | ||||
|         # potentially non-Unicode database character set. | ||||
|         ("ORA_NCHAR_LITERAL_REPLACE", "TRUE"), | ||||
|     ] | ||||
| ) | ||||
|  | ||||
|  | ||||
| try: | ||||
|     import cx_Oracle as Database | ||||
| except ImportError as e: | ||||
|     raise ImproperlyConfigured("Error loading cx_Oracle module: %s" % e) | ||||
|  | ||||
| # Some of these import cx_Oracle, so import them after checking if it's installed. | ||||
| from .client import DatabaseClient  # NOQA | ||||
| from .creation import DatabaseCreation  # NOQA | ||||
| from .features import DatabaseFeatures  # NOQA | ||||
| from .introspection import DatabaseIntrospection  # NOQA | ||||
| from .operations import DatabaseOperations  # NOQA | ||||
| from .schema import DatabaseSchemaEditor  # NOQA | ||||
| from .utils import Oracle_datetime, dsn  # NOQA | ||||
| from .validation import DatabaseValidation  # NOQA | ||||
|  | ||||
|  | ||||
| @contextmanager | ||||
| def wrap_oracle_errors(): | ||||
|     try: | ||||
|         yield | ||||
|     except Database.DatabaseError as e: | ||||
|         # cx_Oracle raises a cx_Oracle.DatabaseError exception with the | ||||
|         # following attributes and values: | ||||
|         #  code = 2091 | ||||
|         #  message = 'ORA-02091: transaction rolled back | ||||
|         #            ORA-02291: integrity constraint (TEST_DJANGOTEST.SYS | ||||
|         #               _C00102056) violated - parent key not found' | ||||
|         #            or: | ||||
|         #            'ORA-00001: unique constraint (DJANGOTEST.DEFERRABLE_ | ||||
|         #               PINK_CONSTRAINT) violated' | ||||
|         # Convert that case to Django's IntegrityError exception. | ||||
|         x = e.args[0] | ||||
|         if ( | ||||
|             hasattr(x, "code") | ||||
|             and hasattr(x, "message") | ||||
|             and x.code == 2091 | ||||
|             and ("ORA-02291" in x.message or "ORA-00001" in x.message) | ||||
|         ): | ||||
|             raise IntegrityError(*tuple(e.args)) | ||||
|         raise | ||||
|  | ||||
|  | ||||
| class _UninitializedOperatorsDescriptor: | ||||
|     def __get__(self, instance, cls=None): | ||||
|         # If connection.operators is looked up before a connection has been | ||||
|         # created, transparently initialize connection.operators to avert an | ||||
|         # AttributeError. | ||||
|         if instance is None: | ||||
|             raise AttributeError("operators not available as class attribute") | ||||
|         # Creating a cursor will initialize the operators. | ||||
|         instance.cursor().close() | ||||
|         return instance.__dict__["operators"] | ||||
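|  | ||||
| # Illustrative effect of the descriptor above (hypothetical session): the | ||||
| # first attribute access on a DatabaseWrapper instance opens and closes a | ||||
| # cursor, which populates instance.__dict__["operators"] via | ||||
| # init_connection_state(): | ||||
| # | ||||
| #   connection.operators["exact"]  # -> "= %s" after a one-off cursor round trip | ||||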
|  | ||||
|  | ||||
| class DatabaseWrapper(BaseDatabaseWrapper): | ||||
|     vendor = "oracle" | ||||
|     display_name = "Oracle" | ||||
|     # This dictionary maps Field objects to their associated Oracle column | ||||
|     # types, as strings. Column-type strings can contain format strings; they'll | ||||
|     # be interpolated against the values of Field.__dict__ before being output. | ||||
|     # If a column type is set to None, it won't be included in the output. | ||||
|     # | ||||
|     # Any format strings starting with "qn_" are quoted before being used in the | ||||
|     # output (the "qn_" prefix is stripped before the lookup is performed. | ||||
|     data_types = { | ||||
|         "AutoField": "NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY", | ||||
|         "BigAutoField": "NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY", | ||||
|         "BinaryField": "BLOB", | ||||
|         "BooleanField": "NUMBER(1)", | ||||
|         "CharField": "NVARCHAR2(%(max_length)s)", | ||||
|         "DateField": "DATE", | ||||
|         "DateTimeField": "TIMESTAMP", | ||||
|         "DecimalField": "NUMBER(%(max_digits)s, %(decimal_places)s)", | ||||
|         "DurationField": "INTERVAL DAY(9) TO SECOND(6)", | ||||
|         "FileField": "NVARCHAR2(%(max_length)s)", | ||||
|         "FilePathField": "NVARCHAR2(%(max_length)s)", | ||||
|         "FloatField": "DOUBLE PRECISION", | ||||
|         "IntegerField": "NUMBER(11)", | ||||
|         "JSONField": "NCLOB", | ||||
|         "BigIntegerField": "NUMBER(19)", | ||||
|         "IPAddressField": "VARCHAR2(15)", | ||||
|         "GenericIPAddressField": "VARCHAR2(39)", | ||||
|         "OneToOneField": "NUMBER(11)", | ||||
|         "PositiveBigIntegerField": "NUMBER(19)", | ||||
|         "PositiveIntegerField": "NUMBER(11)", | ||||
|         "PositiveSmallIntegerField": "NUMBER(11)", | ||||
|         "SlugField": "NVARCHAR2(%(max_length)s)", | ||||
|         "SmallAutoField": "NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY", | ||||
|         "SmallIntegerField": "NUMBER(11)", | ||||
|         "TextField": "NCLOB", | ||||
|         "TimeField": "TIMESTAMP", | ||||
|         "URLField": "VARCHAR2(%(max_length)s)", | ||||
|         "UUIDField": "VARCHAR2(32)", | ||||
|     } | ||||
|     data_type_check_constraints = { | ||||
|         "BooleanField": "%(qn_column)s IN (0,1)", | ||||
|         "JSONField": "%(qn_column)s IS JSON", | ||||
|         "PositiveBigIntegerField": "%(qn_column)s >= 0", | ||||
|         "PositiveIntegerField": "%(qn_column)s >= 0", | ||||
|         "PositiveSmallIntegerField": "%(qn_column)s >= 0", | ||||
|     } | ||||
|  | ||||
|     # Oracle doesn't support a database index on these columns. | ||||
|     _limited_data_types = ("clob", "nclob", "blob") | ||||
|  | ||||
|     operators = _UninitializedOperatorsDescriptor() | ||||
|  | ||||
|     _standard_operators = { | ||||
|         "exact": "= %s", | ||||
|         "iexact": "= UPPER(%s)", | ||||
|         "contains": ( | ||||
|             "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)" | ||||
|         ), | ||||
|         "icontains": ( | ||||
|             "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) " | ||||
|             "ESCAPE TRANSLATE('\\' USING NCHAR_CS)" | ||||
|         ), | ||||
|         "gt": "> %s", | ||||
|         "gte": ">= %s", | ||||
|         "lt": "< %s", | ||||
|         "lte": "<= %s", | ||||
|         "startswith": ( | ||||
|             "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)" | ||||
|         ), | ||||
|         "endswith": ( | ||||
|             "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)" | ||||
|         ), | ||||
|         "istartswith": ( | ||||
|             "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) " | ||||
|             "ESCAPE TRANSLATE('\\' USING NCHAR_CS)" | ||||
|         ), | ||||
|         "iendswith": ( | ||||
|             "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) " | ||||
|             "ESCAPE TRANSLATE('\\' USING NCHAR_CS)" | ||||
|         ), | ||||
|     } | ||||
|  | ||||
|     _likec_operators = { | ||||
|         **_standard_operators, | ||||
|         "contains": "LIKEC %s ESCAPE '\\'", | ||||
|         "icontains": "LIKEC UPPER(%s) ESCAPE '\\'", | ||||
|         "startswith": "LIKEC %s ESCAPE '\\'", | ||||
|         "endswith": "LIKEC %s ESCAPE '\\'", | ||||
|         "istartswith": "LIKEC UPPER(%s) ESCAPE '\\'", | ||||
|         "iendswith": "LIKEC UPPER(%s) ESCAPE '\\'", | ||||
|     } | ||||
|  | ||||
|     # The patterns below are used to generate SQL pattern lookup clauses when | ||||
|     # the right-hand side of the lookup isn't a raw string (it might be an expression | ||||
|     # or the result of a bilateral transformation). | ||||
|     # In those cases, special characters for LIKE operators (e.g. \, %, _) | ||||
|     # should be escaped on the database side. | ||||
|     # | ||||
|     # Note: we use str.format() here for readability as '%' is used as a wildcard for | ||||
|     # the LIKE operator. | ||||
|     pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')" | ||||
|     _pattern_ops = { | ||||
|         "contains": "'%%' || {} || '%%'", | ||||
|         "icontains": "'%%' || UPPER({}) || '%%'", | ||||
|         "startswith": "{} || '%%'", | ||||
|         "istartswith": "UPPER({}) || '%%'", | ||||
|         "endswith": "'%%' || {}", | ||||
|         "iendswith": "'%%' || UPPER({})", | ||||
|     } | ||||
|  | ||||
|     _standard_pattern_ops = { | ||||
|         k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)" | ||||
|         " ESCAPE TRANSLATE('\\' USING NCHAR_CS)" | ||||
|         for k, v in _pattern_ops.items() | ||||
|     } | ||||
|     _likec_pattern_ops = { | ||||
|         k: "LIKEC " + v + " ESCAPE '\\'" for k, v in _pattern_ops.items() | ||||
|     } | ||||
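|  | ||||
|     # Hedged illustration of how the templates above compose for a | ||||
|     # case-sensitive "contains" against an expression rhs (final SQL shown | ||||
|     # approximately; '%%' in the templates becomes a literal '%'): | ||||
|     # | ||||
|     #   ... LIKE TRANSLATE('%' || REPLACE(REPLACE(REPLACE(rhs, '\', '\\'), | ||||
|     #       '%', '\%'), '_', '\_') || '%' USING NCHAR_CS) | ||||
|     #       ESCAPE TRANSLATE('\' USING NCHAR_CS) | ||||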
|  | ||||
|     Database = Database | ||||
|     SchemaEditorClass = DatabaseSchemaEditor | ||||
|     # Classes instantiated in __init__(). | ||||
|     client_class = DatabaseClient | ||||
|     creation_class = DatabaseCreation | ||||
|     features_class = DatabaseFeatures | ||||
|     introspection_class = DatabaseIntrospection | ||||
|     ops_class = DatabaseOperations | ||||
|     validation_class = DatabaseValidation | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         use_returning_into = self.settings_dict["OPTIONS"].get( | ||||
|             "use_returning_into", True | ||||
|         ) | ||||
|         self.features.can_return_columns_from_insert = use_returning_into | ||||
|  | ||||
|     def get_database_version(self): | ||||
|         return self.oracle_version | ||||
|  | ||||
|     def get_connection_params(self): | ||||
|         conn_params = self.settings_dict["OPTIONS"].copy() | ||||
|         if "use_returning_into" in conn_params: | ||||
|             del conn_params["use_returning_into"] | ||||
|         return conn_params | ||||
|  | ||||
|     @async_unsafe | ||||
|     def get_new_connection(self, conn_params): | ||||
|         return Database.connect( | ||||
|             user=self.settings_dict["USER"], | ||||
|             password=self.settings_dict["PASSWORD"], | ||||
|             dsn=dsn(self.settings_dict), | ||||
|             **conn_params, | ||||
|         ) | ||||
|  | ||||
|     def init_connection_state(self): | ||||
|         super().init_connection_state() | ||||
|         cursor = self.create_cursor() | ||||
|         # Set the territory first. The territory overrides NLS_DATE_FORMAT | ||||
|         # and NLS_TIMESTAMP_FORMAT to the territory default. When all of | ||||
|         # these are set in a single statement, it isn't clear what is | ||||
|         # supposed to happen. | ||||
|         cursor.execute("ALTER SESSION SET NLS_TERRITORY = 'AMERICA'") | ||||
|         # Set Oracle date to ANSI date format.  This only needs to execute | ||||
|         # once when we create a new connection. We also set the Territory | ||||
|         # to 'AMERICA' which forces Sunday to evaluate to a '1' in | ||||
|         # TO_CHAR(). | ||||
|         cursor.execute( | ||||
|             "ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'" | ||||
|             " NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'" | ||||
|             + (" TIME_ZONE = 'UTC'" if settings.USE_TZ else "") | ||||
|         ) | ||||
|         cursor.close() | ||||
|         if "operators" not in self.__dict__: | ||||
|             # Ticket #14149: Check whether our LIKE implementation will | ||||
|             # work for this connection or we need to fall back on LIKEC. | ||||
|             # This check is performed only once per DatabaseWrapper | ||||
|             # instance per thread, since subsequent connections will use | ||||
|             # the same settings. | ||||
|             cursor = self.create_cursor() | ||||
|             try: | ||||
|                 cursor.execute( | ||||
|                     "SELECT 1 FROM DUAL WHERE DUMMY %s" | ||||
|                     % self._standard_operators["contains"], | ||||
|                     ["X"], | ||||
|                 ) | ||||
|             except Database.DatabaseError: | ||||
|                 self.operators = self._likec_operators | ||||
|                 self.pattern_ops = self._likec_pattern_ops | ||||
|             else: | ||||
|                 self.operators = self._standard_operators | ||||
|                 self.pattern_ops = self._standard_pattern_ops | ||||
|             cursor.close() | ||||
|         self.connection.stmtcachesize = 20 | ||||
|         # Ensure all changes are preserved even when AUTOCOMMIT is False. | ||||
|         if not self.get_autocommit(): | ||||
|             self.commit() | ||||
|  | ||||
|     @async_unsafe | ||||
|     def create_cursor(self, name=None): | ||||
|         return FormatStylePlaceholderCursor(self.connection) | ||||
|  | ||||
|     def _commit(self): | ||||
|         if self.connection is not None: | ||||
|             with debug_transaction(self, "COMMIT"), wrap_oracle_errors(): | ||||
|                 return self.connection.commit() | ||||
|  | ||||
|     # Oracle doesn't support releasing savepoints. But we fake them when query | ||||
|     # logging is enabled to keep query counts consistent with other backends. | ||||
|     def _savepoint_commit(self, sid): | ||||
|         if self.queries_logged: | ||||
|             self.queries_log.append( | ||||
|                 { | ||||
|                     "sql": "-- RELEASE SAVEPOINT %s (faked)" % self.ops.quote_name(sid), | ||||
|                     "time": "0.000", | ||||
|                 } | ||||
|             ) | ||||
|  | ||||
|     def _set_autocommit(self, autocommit): | ||||
|         with self.wrap_database_errors: | ||||
|             self.connection.autocommit = autocommit | ||||
|  | ||||
|     def check_constraints(self, table_names=None): | ||||
|         """ | ||||
|         Check constraints by setting them to immediate. Return them to deferred | ||||
|         afterward. | ||||
|         """ | ||||
|         with self.cursor() as cursor: | ||||
|             cursor.execute("SET CONSTRAINTS ALL IMMEDIATE") | ||||
|             cursor.execute("SET CONSTRAINTS ALL DEFERRED") | ||||
|  | ||||
|     def is_usable(self): | ||||
|         try: | ||||
|             self.connection.ping() | ||||
|         except Database.Error: | ||||
|             return False | ||||
|         else: | ||||
|             return True | ||||
|  | ||||
|     @cached_property | ||||
|     def cx_oracle_version(self): | ||||
|         return tuple(int(x) for x in Database.version.split(".")) | ||||
|  | ||||
|     @cached_property | ||||
|     def oracle_version(self): | ||||
|         with self.temporary_connection(): | ||||
|             return tuple(int(x) for x in self.connection.version.split(".")) | ||||
|  | ||||
|  | ||||
| class OracleParam: | ||||
|     """ | ||||
|     Wrapper object for formatting parameters for Oracle. If the string | ||||
|     representation of the value is large enough (greater than 4000 characters) | ||||
|     the input size needs to be set as CLOB. Alternatively, if the parameter | ||||
|     has an `input_size` attribute, then the value of the `input_size` attribute | ||||
|     will be used instead. Otherwise, no input size will be set for the | ||||
|     parameter when executing the query. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, param, cursor, strings_only=False): | ||||
|         # With raw SQL queries, datetimes can reach this function | ||||
|         # without being converted by DateTimeField.get_db_prep_value. | ||||
|         if settings.USE_TZ and ( | ||||
|             isinstance(param, datetime.datetime) | ||||
|             and not isinstance(param, Oracle_datetime) | ||||
|         ): | ||||
|             param = Oracle_datetime.from_datetime(param) | ||||
|  | ||||
|         string_size = 0 | ||||
|         # Oracle has no boolean type, so map True/False to 1/0. | ||||
|         if param is True: | ||||
|             param = 1 | ||||
|         elif param is False: | ||||
|             param = 0 | ||||
|         if hasattr(param, "bind_parameter"): | ||||
|             self.force_bytes = param.bind_parameter(cursor) | ||||
|         elif isinstance(param, (Database.Binary, datetime.timedelta)): | ||||
|             self.force_bytes = param | ||||
|         else: | ||||
|             # To transmit to the database, we need Unicode if supported. | ||||
|             # To get the size right, we must consider bytes. | ||||
|             self.force_bytes = force_str(param, cursor.charset, strings_only) | ||||
|             if isinstance(self.force_bytes, str): | ||||
|                 # We could optimize by only converting up to 4000 bytes here | ||||
|                 string_size = len(force_bytes(param, cursor.charset, strings_only)) | ||||
|         if hasattr(param, "input_size"): | ||||
|             # If parameter has `input_size` attribute, use that. | ||||
|             self.input_size = param.input_size | ||||
|         elif string_size > 4000: | ||||
|             # Mark any string param greater than 4000 characters as a CLOB. | ||||
|             self.input_size = Database.CLOB | ||||
|         elif isinstance(param, datetime.datetime): | ||||
|             self.input_size = Database.TIMESTAMP | ||||
|         else: | ||||
|             self.input_size = None | ||||
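|  | ||||
|     # Illustrative sizing behavior (hypothetical values): | ||||
|     #   OracleParam("abc", cursor).input_size            -> None | ||||
|     #   OracleParam("x" * 5000, cursor).input_size       -> Database.CLOB | ||||
|     #   OracleParam(datetime.datetime.now(), cursor).input_size | ||||
|     #                                                    -> Database.TIMESTAMP | ||||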
|  | ||||
|  | ||||
| class VariableWrapper: | ||||
|     """ | ||||
|     An adapter class for cursor variables that prevents the wrapped object | ||||
|     from being converted into a string when used to instantiate an OracleParam. | ||||
|     This can be used generally for any other object that should be passed into | ||||
|     Cursor.execute as-is. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, var): | ||||
|         self.var = var | ||||
|  | ||||
|     def bind_parameter(self, cursor): | ||||
|         return self.var | ||||
|  | ||||
|     def __getattr__(self, key): | ||||
|         return getattr(self.var, key) | ||||
|  | ||||
|     def __setattr__(self, key, value): | ||||
|         if key == "var": | ||||
|             self.__dict__[key] = value | ||||
|         else: | ||||
|             setattr(self.var, key, value) | ||||
|  | ||||
|  | ||||
| class FormatStylePlaceholderCursor: | ||||
|     """ | ||||
|     Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var" | ||||
|     style. This fixes it -- but note that if you want to use a literal "%s" in | ||||
|     a query, you'll need to use "%%s". | ||||
|     """ | ||||
|  | ||||
|     charset = "utf-8" | ||||
|  | ||||
|     def __init__(self, connection): | ||||
|         self.cursor = connection.cursor() | ||||
|         self.cursor.outputtypehandler = self._output_type_handler | ||||
|  | ||||
|     @staticmethod | ||||
|     def _output_number_converter(value): | ||||
|         return decimal.Decimal(value) if "." in value else int(value) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _get_decimal_converter(precision, scale): | ||||
|         if scale == 0: | ||||
|             return int | ||||
|         context = decimal.Context(prec=precision) | ||||
|         quantize_value = decimal.Decimal(1).scaleb(-scale) | ||||
|         return lambda v: decimal.Decimal(v).quantize(quantize_value, context=context) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _output_type_handler(cursor, name, defaultType, length, precision, scale): | ||||
|         """ | ||||
|         Called for each db column fetched from cursors. Return numbers as the | ||||
|         appropriate Python type. | ||||
|         """ | ||||
|         if defaultType == Database.NUMBER: | ||||
|             if scale == -127: | ||||
|                 if precision == 0: | ||||
|                     # NUMBER column: decimal-precision floating point. | ||||
|                     # This will normally be an integer from a sequence, | ||||
|                     # but it could be a decimal value. | ||||
|                     outconverter = FormatStylePlaceholderCursor._output_number_converter | ||||
|                 else: | ||||
|                     # FLOAT column: binary-precision floating point. | ||||
|                     # This comes from FloatField columns. | ||||
|                     outconverter = float | ||||
|             elif precision > 0: | ||||
|                 # NUMBER(p,s) column: decimal-precision fixed point. | ||||
|                 # This comes from IntegerField and DecimalField columns. | ||||
|                 outconverter = FormatStylePlaceholderCursor._get_decimal_converter( | ||||
|                     precision, scale | ||||
|                 ) | ||||
|             else: | ||||
|                 # No type information. This normally comes from a | ||||
|                 # mathematical expression in the SELECT list. Guess int | ||||
|                 # or Decimal based on whether it has a decimal point. | ||||
|                 outconverter = FormatStylePlaceholderCursor._output_number_converter | ||||
|             return cursor.var( | ||||
|                 Database.STRING, | ||||
|                 size=255, | ||||
|                 arraysize=cursor.arraysize, | ||||
|                 outconverter=outconverter, | ||||
|             ) | ||||
|  | ||||
|     def _format_params(self, params): | ||||
|         try: | ||||
|             return {k: OracleParam(v, self, True) for k, v in params.items()} | ||||
|         except AttributeError: | ||||
|             return tuple(OracleParam(p, self, True) for p in params) | ||||
|  | ||||
|     def _guess_input_sizes(self, params_list): | ||||
|         # Try dict handling; if that fails, treat as sequence | ||||
|         if hasattr(params_list[0], "keys"): | ||||
|             sizes = {} | ||||
|             for params in params_list: | ||||
|                 for k, value in params.items(): | ||||
|                     if value.input_size: | ||||
|                         sizes[k] = value.input_size | ||||
|             if sizes: | ||||
|                 self.setinputsizes(**sizes) | ||||
|         else: | ||||
|             # It's not a list of dicts; it's a list of sequences | ||||
|             sizes = [None] * len(params_list[0]) | ||||
|             for params in params_list: | ||||
|                 for i, value in enumerate(params): | ||||
|                     if value.input_size: | ||||
|                         sizes[i] = value.input_size | ||||
|             if sizes: | ||||
|                 self.setinputsizes(*sizes) | ||||
|  | ||||
|     def _param_generator(self, params): | ||||
|         # Try dict handling; if that fails, treat as sequence | ||||
|         if hasattr(params, "items"): | ||||
|             return {k: v.force_bytes for k, v in params.items()} | ||||
|         else: | ||||
|             return [p.force_bytes for p in params] | ||||
|  | ||||
|     def _fix_for_params(self, query, params, unify_by_values=False): | ||||
|         # cx_Oracle wants no trailing ';' for SQL statements. For PL/SQL, it | ||||
|         # does want a trailing ';' but not a trailing '/'. However, these | ||||
|         # characters must be included in the original query in case the query | ||||
|         # is being passed to SQL*Plus. | ||||
|         if query.endswith(";") or query.endswith("/"): | ||||
|             query = query[:-1] | ||||
|         if params is None: | ||||
|             params = [] | ||||
|         elif hasattr(params, "keys"): | ||||
|             # Handle params as dict | ||||
|             args = {k: ":%s" % k for k in params} | ||||
|             query %= args | ||||
|         elif unify_by_values and params: | ||||
|             # Handle params as a dict with unified query parameters by their | ||||
|             # values. It can be used only in single query execute() because | ||||
|             # executemany() shares the formatted query with each of the params | ||||
|             # list. e.g. for input params = [0.75, 2, 0.75, 'sth', 0.75] | ||||
|             # params_dict = {0.75: ':arg0', 2: ':arg1', 'sth': ':arg2'} | ||||
|             # args = [':arg0', ':arg1', ':arg0', ':arg2', ':arg0'] | ||||
|             # params = {':arg0': 0.75, ':arg1': 2, ':arg2': 'sth'} | ||||
|             params_dict = { | ||||
|                 param: ":arg%d" % i for i, param in enumerate(dict.fromkeys(params)) | ||||
|             } | ||||
|             args = [params_dict[param] for param in params] | ||||
|             params = {value: key for key, value in params_dict.items()} | ||||
|             query %= tuple(args) | ||||
|         else: | ||||
|             # Handle params as sequence | ||||
|             args = [(":arg%d" % i) for i in range(len(params))] | ||||
|             query %= tuple(args) | ||||
|         return query, self._format_params(params) | ||||
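|  | ||||
|     # Illustrative transformation for the plain-sequence case (hypothetical | ||||
|     # query): | ||||
|     #   _fix_for_params("SELECT * FROM t WHERE a = %s AND b = %s", [1, 2]) | ||||
|     # rewrites the query to | ||||
|     #   "SELECT * FROM t WHERE a = :arg0 AND b = :arg1" | ||||
|     # and wraps each value in an OracleParam. | ||||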
|  | ||||
|     def execute(self, query, params=None): | ||||
|         query, params = self._fix_for_params(query, params, unify_by_values=True) | ||||
|         self._guess_input_sizes([params]) | ||||
|         with wrap_oracle_errors(): | ||||
|             return self.cursor.execute(query, self._param_generator(params)) | ||||
|  | ||||
|     def executemany(self, query, params=None): | ||||
|         if not params: | ||||
|             # No params given, nothing to do | ||||
|             return None | ||||
|         # uniform treatment for sequences and iterables | ||||
|         params_iter = iter(params) | ||||
|         query, firstparams = self._fix_for_params(query, next(params_iter)) | ||||
|         # we build a list of formatted params; as we're going to traverse it | ||||
|         # more than once, we can't make it lazy by using a generator | ||||
|         formatted = [firstparams] + [self._format_params(p) for p in params_iter] | ||||
|         self._guess_input_sizes(formatted) | ||||
|         with wrap_oracle_errors(): | ||||
|             return self.cursor.executemany( | ||||
|                 query, [self._param_generator(p) for p in formatted] | ||||
|             ) | ||||
|  | ||||
|     def close(self): | ||||
|         try: | ||||
|             self.cursor.close() | ||||
|         except Database.InterfaceError: | ||||
|             # already closed | ||||
|             pass | ||||
|  | ||||
|     def var(self, *args): | ||||
|         return VariableWrapper(self.cursor.var(*args)) | ||||
|  | ||||
|     def arrayvar(self, *args): | ||||
|         return VariableWrapper(self.cursor.arrayvar(*args)) | ||||
|  | ||||
|     def __getattr__(self, attr): | ||||
|         return getattr(self.cursor, attr) | ||||
|  | ||||
|     def __iter__(self): | ||||
|         return iter(self.cursor) | ||||
| @ -0,0 +1,27 @@ | ||||
| import shutil | ||||
|  | ||||
| from django.db.backends.base.client import BaseDatabaseClient | ||||
|  | ||||
|  | ||||
| class DatabaseClient(BaseDatabaseClient): | ||||
|     executable_name = "sqlplus" | ||||
|     wrapper_name = "rlwrap" | ||||
|  | ||||
|     @staticmethod | ||||
|     def connect_string(settings_dict): | ||||
|         from django.db.backends.oracle.utils import dsn | ||||
|  | ||||
|         return '%s/"%s"@%s' % ( | ||||
|             settings_dict["USER"], | ||||
|             settings_dict["PASSWORD"], | ||||
|             dsn(settings_dict), | ||||
|         ) | ||||
|  | ||||
|     @classmethod | ||||
|     def settings_to_cmd_args_env(cls, settings_dict, parameters): | ||||
|         args = [cls.executable_name, "-L", cls.connect_string(settings_dict)] | ||||
|         wrapper_path = shutil.which(cls.wrapper_name) | ||||
|         if wrapper_path: | ||||
|             args = [wrapper_path, *args] | ||||
|         args.extend(parameters) | ||||
|         return args, None | ||||
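|  | ||||
|  | ||||
| # Illustrative result (hypothetical settings): with USER="scott", | ||||
| # PASSWORD="tiger", NAME="xe" and rlwrap on PATH, the command is roughly: | ||||
| #   (['/usr/bin/rlwrap', 'sqlplus', '-L', 'scott/"tiger"@xe'], None) | ||||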
| @ -0,0 +1,464 @@ | ||||
| import sys | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db import DatabaseError | ||||
| from django.db.backends.base.creation import BaseDatabaseCreation | ||||
| from django.utils.crypto import get_random_string | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
| TEST_DATABASE_PREFIX = "test_" | ||||
|  | ||||
|  | ||||
| class DatabaseCreation(BaseDatabaseCreation): | ||||
|     @cached_property | ||||
|     def _maindb_connection(self): | ||||
|         """ | ||||
|         This is analogous to other backends' `_nodb_connection` property, | ||||
|         which allows access to an "administrative" connection which can | ||||
|         be used to manage the test databases. | ||||
|         For Oracle, the only connection that can be used for that purpose | ||||
|         is the main (non-test) connection. | ||||
|         """ | ||||
|         settings_dict = settings.DATABASES[self.connection.alias] | ||||
|         user = settings_dict.get("SAVED_USER") or settings_dict["USER"] | ||||
|         password = settings_dict.get("SAVED_PASSWORD") or settings_dict["PASSWORD"] | ||||
|         settings_dict = {**settings_dict, "USER": user, "PASSWORD": password} | ||||
|         DatabaseWrapper = type(self.connection) | ||||
|         return DatabaseWrapper(settings_dict, alias=self.connection.alias) | ||||
|  | ||||
|     def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False): | ||||
|         parameters = self._get_test_db_params() | ||||
|         with self._maindb_connection.cursor() as cursor: | ||||
|             if self._test_database_create(): | ||||
|                 try: | ||||
|                     self._execute_test_db_creation( | ||||
|                         cursor, parameters, verbosity, keepdb | ||||
|                     ) | ||||
|                 except Exception as e: | ||||
|                     if "ORA-01543" not in str(e): | ||||
|                         # All errors except "tablespace already exists" cancel tests | ||||
|                         self.log("Got an error creating the test database: %s" % e) | ||||
|                         sys.exit(2) | ||||
|                     if not autoclobber: | ||||
|                         confirm = input( | ||||
|                             "It appears the test database, %s, already exists. " | ||||
|                             "Type 'yes' to delete it, or 'no' to cancel: " | ||||
|                             % parameters["user"] | ||||
|                         ) | ||||
|                     if autoclobber or confirm == "yes": | ||||
|                         if verbosity >= 1: | ||||
|                             self.log( | ||||
|                                 "Destroying old test database for alias '%s'..." | ||||
|                                 % self.connection.alias | ||||
|                             ) | ||||
|                         try: | ||||
|                             self._execute_test_db_destruction( | ||||
|                                 cursor, parameters, verbosity | ||||
|                             ) | ||||
|                         except DatabaseError as e: | ||||
|                             if "ORA-29857" in str(e): | ||||
|                                 self._handle_objects_preventing_db_destruction( | ||||
|                                     cursor, parameters, verbosity, autoclobber | ||||
|                                 ) | ||||
|                             else: | ||||
|                                 # Ran into a database error that isn't about | ||||
|                                 # leftover objects in the tablespace. | ||||
|                                 self.log( | ||||
|                                     "Got an error destroying the old test database: %s" | ||||
|                                     % e | ||||
|                                 ) | ||||
|                                 sys.exit(2) | ||||
|                         except Exception as e: | ||||
|                             self.log( | ||||
|                                 "Got an error destroying the old test database: %s" % e | ||||
|                             ) | ||||
|                             sys.exit(2) | ||||
|                         try: | ||||
|                             self._execute_test_db_creation( | ||||
|                                 cursor, parameters, verbosity, keepdb | ||||
|                             ) | ||||
|                         except Exception as e: | ||||
|                             self.log( | ||||
|                                 "Got an error recreating the test database: %s" % e | ||||
|                             ) | ||||
|                             sys.exit(2) | ||||
|                     else: | ||||
|                         self.log("Tests cancelled.") | ||||
|                         sys.exit(1) | ||||
|  | ||||
|             if self._test_user_create(): | ||||
|                 if verbosity >= 1: | ||||
|                     self.log("Creating test user...") | ||||
|                 try: | ||||
|                     self._create_test_user(cursor, parameters, verbosity, keepdb) | ||||
|                 except Exception as e: | ||||
|                     if "ORA-01920" not in str(e): | ||||
|                         # All errors except "user already exists" cancel tests | ||||
|                         self.log("Got an error creating the test user: %s" % e) | ||||
|                         sys.exit(2) | ||||
|                     if not autoclobber: | ||||
|                         confirm = input( | ||||
|                             "It appears the test user, %s, already exists. Type " | ||||
|                             "'yes' to delete it, or 'no' to cancel: " | ||||
|                             % parameters["user"] | ||||
|                         ) | ||||
|                     if autoclobber or confirm == "yes": | ||||
|                         try: | ||||
|                             if verbosity >= 1: | ||||
|                                 self.log("Destroying old test user...") | ||||
|                             self._destroy_test_user(cursor, parameters, verbosity) | ||||
|                             if verbosity >= 1: | ||||
|                                 self.log("Creating test user...") | ||||
|                             self._create_test_user( | ||||
|                                 cursor, parameters, verbosity, keepdb | ||||
|                             ) | ||||
|                         except Exception as e: | ||||
|                             self.log("Got an error recreating the test user: %s" % e) | ||||
|                             sys.exit(2) | ||||
|                     else: | ||||
|                         self.log("Tests cancelled.") | ||||
|                         sys.exit(1) | ||||
|         # Done with main user -- test user and tablespaces created. | ||||
|         self._maindb_connection.close() | ||||
|         self._switch_to_test_user(parameters) | ||||
|         return self.connection.settings_dict["NAME"] | ||||
|  | ||||
|     def _switch_to_test_user(self, parameters): | ||||
|         """ | ||||
|         Switch to the user that's used for creating the test database. | ||||
|  | ||||
|         Oracle doesn't have the concept of separate databases under the same | ||||
|         user, so a separate user is used; see _create_test_db(). The main user | ||||
|         is also needed for cleanup when testing is completed, so save its | ||||
|         credentials in the SAVED_USER/SAVED_PASSWORD keys of the settings dict. | ||||
|         """ | ||||
|         real_settings = settings.DATABASES[self.connection.alias] | ||||
|         real_settings["SAVED_USER"] = self.connection.settings_dict[ | ||||
|             "SAVED_USER" | ||||
|         ] = self.connection.settings_dict["USER"] | ||||
|         real_settings["SAVED_PASSWORD"] = self.connection.settings_dict[ | ||||
|             "SAVED_PASSWORD" | ||||
|         ] = self.connection.settings_dict["PASSWORD"] | ||||
|         real_test_settings = real_settings["TEST"] | ||||
|         test_settings = self.connection.settings_dict["TEST"] | ||||
|         real_test_settings["USER"] = real_settings["USER"] = test_settings[ | ||||
|             "USER" | ||||
|         ] = self.connection.settings_dict["USER"] = parameters["user"] | ||||
|         real_settings["PASSWORD"] = self.connection.settings_dict[ | ||||
|             "PASSWORD" | ||||
|         ] = parameters["password"] | ||||
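|  | ||||
|         # Sketch of the credential shuffle above (hypothetical values): | ||||
|         #   before: USER="scott",        PASSWORD="tiger" | ||||
|         #   after:  USER="test_scott",   PASSWORD=<test password>, | ||||
|         #           SAVED_USER="scott",  SAVED_PASSWORD="tiger" | ||||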
|  | ||||
|     def set_as_test_mirror(self, primary_settings_dict): | ||||
|         """ | ||||
|         Set this database up to be used in testing as a mirror of a primary | ||||
|         database whose settings are given. | ||||
|         """ | ||||
|         self.connection.settings_dict["USER"] = primary_settings_dict["USER"] | ||||
|         self.connection.settings_dict["PASSWORD"] = primary_settings_dict["PASSWORD"] | ||||
|  | ||||
|     def _handle_objects_preventing_db_destruction( | ||||
|         self, cursor, parameters, verbosity, autoclobber | ||||
|     ): | ||||
|         # There are objects in the test tablespace which prevent dropping it. | ||||
|         # The easy fix is to drop the test user -- but are we allowed to do so? | ||||
|         self.log( | ||||
|             "There are objects in the old test database which prevent its destruction." | ||||
|             "\nIf they belong to the test user, deleting the user will allow the test " | ||||
|             "database to be recreated.\n" | ||||
|             "Otherwise, you will need to find and remove each of these objects, " | ||||
|             "or use a different tablespace.\n" | ||||
|         ) | ||||
|         if self._test_user_create(): | ||||
|             if not autoclobber: | ||||
|                 confirm = input("Type 'yes' to delete user %s: " % parameters["user"]) | ||||
|             if autoclobber or confirm == "yes": | ||||
|                 try: | ||||
|                     if verbosity >= 1: | ||||
|                         self.log("Destroying old test user...") | ||||
|                     self._destroy_test_user(cursor, parameters, verbosity) | ||||
|                 except Exception as e: | ||||
|                     self.log("Got an error destroying the test user: %s" % e) | ||||
|                     sys.exit(2) | ||||
|                 try: | ||||
|                     if verbosity >= 1: | ||||
|                         self.log( | ||||
|                             "Destroying old test database for alias '%s'..." | ||||
|                             % self.connection.alias | ||||
|                         ) | ||||
|                     self._execute_test_db_destruction(cursor, parameters, verbosity) | ||||
|                 except Exception as e: | ||||
|                     self.log("Got an error destroying the test database: %s" % e) | ||||
|                     sys.exit(2) | ||||
|             else: | ||||
|                 self.log("Tests cancelled -- test database cannot be recreated.") | ||||
|                 sys.exit(1) | ||||
|         else: | ||||
|             self.log( | ||||
|                 "Django is configured to use pre-existing test user '%s'," | ||||
|                 " and will not attempt to delete it." % parameters["user"] | ||||
|             ) | ||||
|             self.log("Tests cancelled -- test database cannot be recreated.") | ||||
|             sys.exit(1) | ||||
|  | ||||
|     def _destroy_test_db(self, test_database_name, verbosity=1): | ||||
|         """ | ||||
|         Destroy the test user and the test tablespaces. The saved credentials | ||||
|         of the main (non-test) user are restored first, since the cleanup has | ||||
|         to run on the main connection. | ||||
|         """ | ||||
|         self.connection.settings_dict["USER"] = self.connection.settings_dict[ | ||||
|             "SAVED_USER" | ||||
|         ] | ||||
|         self.connection.settings_dict["PASSWORD"] = self.connection.settings_dict[ | ||||
|             "SAVED_PASSWORD" | ||||
|         ] | ||||
|         self.connection.close() | ||||
|         parameters = self._get_test_db_params() | ||||
|         with self._maindb_connection.cursor() as cursor: | ||||
|             if self._test_user_create(): | ||||
|                 if verbosity >= 1: | ||||
|                     self.log("Destroying test user...") | ||||
|                 self._destroy_test_user(cursor, parameters, verbosity) | ||||
|             if self._test_database_create(): | ||||
|                 if verbosity >= 1: | ||||
|                     self.log("Destroying test database tables...") | ||||
|                 self._execute_test_db_destruction(cursor, parameters, verbosity) | ||||
|         self._maindb_connection.close() | ||||
|  | ||||
|     def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False): | ||||
|         if verbosity >= 2: | ||||
|             self.log("_create_test_db(): dbname = %s" % parameters["user"]) | ||||
|         if self._test_database_oracle_managed_files(): | ||||
|             statements = [ | ||||
|                 """ | ||||
|                 CREATE TABLESPACE %(tblspace)s | ||||
|                 DATAFILE SIZE %(size)s | ||||
|                 AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s | ||||
|                 """, | ||||
|                 """ | ||||
|                 CREATE TEMPORARY TABLESPACE %(tblspace_temp)s | ||||
|                 TEMPFILE SIZE %(size_tmp)s | ||||
|                 AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s | ||||
|                 """, | ||||
|             ] | ||||
|         else: | ||||
|             statements = [ | ||||
|                 """ | ||||
|                 CREATE TABLESPACE %(tblspace)s | ||||
|                 DATAFILE '%(datafile)s' SIZE %(size)s REUSE | ||||
|                 AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s | ||||
|                 """, | ||||
|                 """ | ||||
|                 CREATE TEMPORARY TABLESPACE %(tblspace_temp)s | ||||
|                 TEMPFILE '%(datafile_tmp)s' SIZE %(size_tmp)s REUSE | ||||
|                 AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s | ||||
|                 """, | ||||
|             ] | ||||
|         # Ignore "tablespace already exists" error when keepdb is on. | ||||
|         acceptable_ora_err = "ORA-01543" if keepdb else None | ||||
|         self._execute_allow_fail_statements( | ||||
|             cursor, statements, parameters, verbosity, acceptable_ora_err | ||||
|         ) | ||||
|  | ||||
|     def _create_test_user(self, cursor, parameters, verbosity, keepdb=False): | ||||
|         if verbosity >= 2: | ||||
|             self.log("_create_test_user(): username = %s" % parameters["user"]) | ||||
|         statements = [ | ||||
|             """CREATE USER %(user)s | ||||
|                IDENTIFIED BY "%(password)s" | ||||
|                DEFAULT TABLESPACE %(tblspace)s | ||||
|                TEMPORARY TABLESPACE %(tblspace_temp)s | ||||
|                QUOTA UNLIMITED ON %(tblspace)s | ||||
|             """, | ||||
|             """GRANT CREATE SESSION, | ||||
|                      CREATE TABLE, | ||||
|                      CREATE SEQUENCE, | ||||
|                      CREATE PROCEDURE, | ||||
|                      CREATE TRIGGER | ||||
|                TO %(user)s""", | ||||
|         ] | ||||
|         # Ignore "user already exists" error when keepdb is on | ||||
|         acceptable_ora_err = "ORA-01920" if keepdb else None | ||||
|         success = self._execute_allow_fail_statements( | ||||
|             cursor, statements, parameters, verbosity, acceptable_ora_err | ||||
|         ) | ||||
|         # If the password was randomly generated, change the user accordingly. | ||||
|         if not success and self._test_settings_get("PASSWORD") is None: | ||||
|             set_password = 'ALTER USER %(user)s IDENTIFIED BY "%(password)s"' | ||||
|             self._execute_statements(cursor, [set_password], parameters, verbosity) | ||||
|         # Most test suites can be run without "create view" and | ||||
|         # "create materialized view" privileges. But some need it. | ||||
|         for object_type in ("VIEW", "MATERIALIZED VIEW"): | ||||
|             extra = "GRANT CREATE %(object_type)s TO %(user)s" | ||||
|             parameters["object_type"] = object_type | ||||
|             success = self._execute_allow_fail_statements( | ||||
|                 cursor, [extra], parameters, verbosity, "ORA-01031" | ||||
|             ) | ||||
|             if not success and verbosity >= 2: | ||||
|                 self.log( | ||||
|                     "Failed to grant CREATE %s permission to test user. This may be ok." | ||||
|                     % object_type | ||||
|                 ) | ||||
|  | ||||
|     def _execute_test_db_destruction(self, cursor, parameters, verbosity): | ||||
|         if verbosity >= 2: | ||||
|             self.log("_execute_test_db_destruction(): dbname=%s" % parameters["user"]) | ||||
|         statements = [ | ||||
|             "DROP TABLESPACE %(tblspace)s " | ||||
|             "INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS", | ||||
|             "DROP TABLESPACE %(tblspace_temp)s " | ||||
|             "INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS", | ||||
|         ] | ||||
|         self._execute_statements(cursor, statements, parameters, verbosity) | ||||
|  | ||||
|     def _destroy_test_user(self, cursor, parameters, verbosity): | ||||
|         if verbosity >= 2: | ||||
|             self.log("_destroy_test_user(): user=%s" % parameters["user"]) | ||||
|             self.log("Be patient. This can take some time...") | ||||
|         statements = [ | ||||
|             "DROP USER %(user)s CASCADE", | ||||
|         ] | ||||
|         self._execute_statements(cursor, statements, parameters, verbosity) | ||||
|  | ||||
|     def _execute_statements( | ||||
|         self, cursor, statements, parameters, verbosity, allow_quiet_fail=False | ||||
|     ): | ||||
|         for template in statements: | ||||
|             stmt = template % parameters | ||||
|             if verbosity >= 2: | ||||
|                 print(stmt) | ||||
|             try: | ||||
|                 cursor.execute(stmt) | ||||
|             except Exception as err: | ||||
|                 if (not allow_quiet_fail) or verbosity >= 2: | ||||
|                     self.log("Failed (%s)" % (err)) | ||||
|                 raise | ||||
|  | ||||
|     def _execute_allow_fail_statements( | ||||
|         self, cursor, statements, parameters, verbosity, acceptable_ora_err | ||||
|     ): | ||||
|         """ | ||||
|         Execute statements which are allowed to fail silently if the Oracle | ||||
|         error code given by `acceptable_ora_err` is raised. Return True if the | ||||
|         statements execute without an exception, or False otherwise. | ||||
|         """ | ||||
|         try: | ||||
|             # Statement can fail when acceptable_ora_err is not None | ||||
|             allow_quiet_fail = ( | ||||
|                 acceptable_ora_err is not None and len(acceptable_ora_err) > 0 | ||||
|             ) | ||||
|             self._execute_statements( | ||||
|                 cursor, | ||||
|                 statements, | ||||
|                 parameters, | ||||
|                 verbosity, | ||||
|                 allow_quiet_fail=allow_quiet_fail, | ||||
|             ) | ||||
|             return True | ||||
|         except DatabaseError as err: | ||||
|             description = str(err) | ||||
|             if acceptable_ora_err is None or acceptable_ora_err not in description: | ||||
|                 raise | ||||
|             return False | ||||
|  | ||||
|     def _get_test_db_params(self): | ||||
|         return { | ||||
|             "dbname": self._test_database_name(), | ||||
|             "user": self._test_database_user(), | ||||
|             "password": self._test_database_passwd(), | ||||
|             "tblspace": self._test_database_tblspace(), | ||||
|             "tblspace_temp": self._test_database_tblspace_tmp(), | ||||
|             "datafile": self._test_database_tblspace_datafile(), | ||||
|             "datafile_tmp": self._test_database_tblspace_tmp_datafile(), | ||||
|             "maxsize": self._test_database_tblspace_maxsize(), | ||||
|             "maxsize_tmp": self._test_database_tblspace_tmp_maxsize(), | ||||
|             "size": self._test_database_tblspace_size(), | ||||
|             "size_tmp": self._test_database_tblspace_tmp_size(), | ||||
|             "extsize": self._test_database_tblspace_extsize(), | ||||
|             "extsize_tmp": self._test_database_tblspace_tmp_extsize(), | ||||
|         } | ||||
|  | ||||
|     def _test_settings_get(self, key, default=None, prefixed=None): | ||||
|         """ | ||||
|         Return a value from the test settings dict, or a given default, or a | ||||
|         prefixed entry from the main settings dict. | ||||
|         """ | ||||
|         settings_dict = self.connection.settings_dict | ||||
|         val = settings_dict["TEST"].get(key, default) | ||||
|         if val is None and prefixed: | ||||
|             val = TEST_DATABASE_PREFIX + settings_dict[prefixed] | ||||
|         return val | ||||
|  | ||||
|     def _test_database_name(self): | ||||
|         return self._test_settings_get("NAME", prefixed="NAME") | ||||
|  | ||||
|     def _test_database_create(self): | ||||
|         return self._test_settings_get("CREATE_DB", default=True) | ||||
|  | ||||
|     def _test_user_create(self): | ||||
|         return self._test_settings_get("CREATE_USER", default=True) | ||||
|  | ||||
|     def _test_database_user(self): | ||||
|         return self._test_settings_get("USER", prefixed="USER") | ||||
|  | ||||
|     def _test_database_passwd(self): | ||||
|         password = self._test_settings_get("PASSWORD") | ||||
|         if password is None and self._test_user_create(): | ||||
|             # Oracle passwords are limited to 30 chars and can't contain symbols. | ||||
|             password = get_random_string(30) | ||||
|         return password | ||||
|  | ||||
|     def _test_database_tblspace(self): | ||||
|         return self._test_settings_get("TBLSPACE", prefixed="USER") | ||||
|  | ||||
|     def _test_database_tblspace_tmp(self): | ||||
|         settings_dict = self.connection.settings_dict | ||||
|         return settings_dict["TEST"].get( | ||||
|             "TBLSPACE_TMP", TEST_DATABASE_PREFIX + settings_dict["USER"] + "_temp" | ||||
|         ) | ||||
|  | ||||
|     def _test_database_tblspace_datafile(self): | ||||
|         tblspace = "%s.dbf" % self._test_database_tblspace() | ||||
|         return self._test_settings_get("DATAFILE", default=tblspace) | ||||
|  | ||||
|     def _test_database_tblspace_tmp_datafile(self): | ||||
|         tblspace = "%s.dbf" % self._test_database_tblspace_tmp() | ||||
|         return self._test_settings_get("DATAFILE_TMP", default=tblspace) | ||||
|  | ||||
|     def _test_database_tblspace_maxsize(self): | ||||
|         return self._test_settings_get("DATAFILE_MAXSIZE", default="500M") | ||||
|  | ||||
|     def _test_database_tblspace_tmp_maxsize(self): | ||||
|         return self._test_settings_get("DATAFILE_TMP_MAXSIZE", default="500M") | ||||
|  | ||||
|     def _test_database_tblspace_size(self): | ||||
|         return self._test_settings_get("DATAFILE_SIZE", default="50M") | ||||
|  | ||||
|     def _test_database_tblspace_tmp_size(self): | ||||
|         return self._test_settings_get("DATAFILE_TMP_SIZE", default="50M") | ||||
|  | ||||
|     def _test_database_tblspace_extsize(self): | ||||
|         return self._test_settings_get("DATAFILE_EXTSIZE", default="25M") | ||||
|  | ||||
|     def _test_database_tblspace_tmp_extsize(self): | ||||
|         return self._test_settings_get("DATAFILE_TMP_EXTSIZE", default="25M") | ||||
|  | ||||
|     def _test_database_oracle_managed_files(self): | ||||
|         return self._test_settings_get("ORACLE_MANAGED_FILES", default=False) | ||||
|  | ||||
|     def _get_test_db_name(self): | ||||
|         """ | ||||
|         Return the 'production' DB name to get the test DB creation machinery | ||||
|         to work. This isn't very meaningful in Oracle, because DB names as | ||||
|         handled by Django don't have real counterparts there. | ||||
|         """ | ||||
|         return self.connection.settings_dict["NAME"] | ||||
|  | ||||
|     def test_db_signature(self): | ||||
|         settings_dict = self.connection.settings_dict | ||||
|         return ( | ||||
|             settings_dict["HOST"], | ||||
|             settings_dict["PORT"], | ||||
|             settings_dict["ENGINE"], | ||||
|             settings_dict["NAME"], | ||||
|             self._test_database_user(), | ||||
|         ) | ||||
| @ -0,0 +1,152 @@ | ||||
| from django.db import DatabaseError, InterfaceError | ||||
| from django.db.backends.base.features import BaseDatabaseFeatures | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
|  | ||||
| class DatabaseFeatures(BaseDatabaseFeatures): | ||||
|     minimum_database_version = (19,) | ||||
|     # Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got | ||||
|     # BLOB" when grouping by LOBs (#24096). | ||||
|     allows_group_by_lob = False | ||||
|     allows_group_by_select_index = False | ||||
|     interprets_empty_strings_as_nulls = True | ||||
|     has_select_for_update = True | ||||
|     has_select_for_update_nowait = True | ||||
|     has_select_for_update_skip_locked = True | ||||
|     has_select_for_update_of = True | ||||
|     select_for_update_of_column = True | ||||
|     can_return_columns_from_insert = True | ||||
|     supports_subqueries_in_group_by = False | ||||
|     ignores_unnecessary_order_by_in_subqueries = False | ||||
|     supports_transactions = True | ||||
|     supports_timezones = False | ||||
|     has_native_duration_field = True | ||||
|     can_defer_constraint_checks = True | ||||
|     supports_partially_nullable_unique_constraints = False | ||||
|     supports_deferrable_unique_constraints = True | ||||
|     truncates_names = True | ||||
|     supports_comments = True | ||||
|     supports_tablespaces = True | ||||
|     supports_sequence_reset = False | ||||
|     can_introspect_materialized_views = True | ||||
|     atomic_transactions = False | ||||
|     nulls_order_largest = True | ||||
|     requires_literal_defaults = True | ||||
|     closed_cursor_error_class = InterfaceError | ||||
|     bare_select_suffix = " FROM DUAL" | ||||
|     # Select for update with limit can be achieved on Oracle, but not with the | ||||
|     # current backend. | ||||
|     supports_select_for_update_with_limit = False | ||||
|     supports_temporal_subtraction = True | ||||
|     # Oracle itself respects the case of quoted identifiers, but the current | ||||
|     # backend effectively ignores case by uppercasing all identifiers. | ||||
|     ignores_table_name_case = True | ||||
|     supports_index_on_text_field = False | ||||
|     create_test_procedure_without_params_sql = """ | ||||
|         CREATE PROCEDURE "TEST_PROCEDURE" AS | ||||
|             V_I INTEGER; | ||||
|         BEGIN | ||||
|             V_I := 1; | ||||
|         END; | ||||
|     """ | ||||
|     create_test_procedure_with_int_param_sql = """ | ||||
|         CREATE PROCEDURE "TEST_PROCEDURE" (P_I INTEGER) AS | ||||
|             V_I INTEGER; | ||||
|         BEGIN | ||||
|             V_I := P_I; | ||||
|         END; | ||||
|     """ | ||||
|     create_test_table_with_composite_primary_key = """ | ||||
|         CREATE TABLE test_table_composite_pk ( | ||||
|             column_1 NUMBER(11) NOT NULL, | ||||
|             column_2 NUMBER(11) NOT NULL, | ||||
|             PRIMARY KEY (column_1, column_2) | ||||
|         ) | ||||
|     """ | ||||
|     supports_callproc_kwargs = True | ||||
|     supports_over_clause = True | ||||
|     supports_frame_range_fixed_distance = True | ||||
|     supports_ignore_conflicts = False | ||||
|     max_query_params = 2**16 - 1 | ||||
|     supports_partial_indexes = False | ||||
|     can_rename_index = True | ||||
|     supports_slicing_ordering_in_compound = True | ||||
|     requires_compound_order_by_subquery = True | ||||
|     allows_multiple_constraints_on_same_fields = False | ||||
|     supports_boolean_expr_in_select_clause = False | ||||
|     supports_comparing_boolean_expr = False | ||||
|     supports_primitives_in_json_field = False | ||||
|     supports_json_field_contains = False | ||||
|     supports_collation_on_textfield = False | ||||
|     test_collations = { | ||||
|         "ci": "BINARY_CI", | ||||
|         "cs": "BINARY", | ||||
|         "non_default": "SWEDISH_CI", | ||||
|         "swedish_ci": "SWEDISH_CI", | ||||
|     } | ||||
|     test_now_utc_template = "CURRENT_TIMESTAMP AT TIME ZONE 'UTC'" | ||||
|  | ||||
|     django_test_skips = { | ||||
|         "Oracle doesn't support SHA224.": { | ||||
|             "db_functions.text.test_sha224.SHA224Tests.test_basic", | ||||
|             "db_functions.text.test_sha224.SHA224Tests.test_transform", | ||||
|         }, | ||||
|         "Oracle doesn't correctly calculate ISO 8601 week numbering before " | ||||
|         "1583 (the Gregorian calendar was introduced in 1582).": { | ||||
|             "db_functions.datetime.test_extract_trunc.DateFunctionTests." | ||||
|             "test_trunc_week_before_1000", | ||||
|             "db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests." | ||||
|             "test_trunc_week_before_1000", | ||||
|         }, | ||||
|         "Oracle extracts seconds including fractional seconds (#33517).": { | ||||
|             "db_functions.datetime.test_extract_trunc.DateFunctionTests." | ||||
|             "test_extract_second_func_no_fractional", | ||||
|             "db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests." | ||||
|             "test_extract_second_func_no_fractional", | ||||
|         }, | ||||
|         "Oracle doesn't support bitwise XOR.": { | ||||
|             "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor", | ||||
|             "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_null", | ||||
|             "expressions.tests.ExpressionOperatorTests." | ||||
|             "test_lefthand_bitwise_xor_right_null", | ||||
|         }, | ||||
|         "Oracle requires ORDER BY in row_number, ANSI:SQL doesn't.": { | ||||
|             "expressions_window.tests.WindowFunctionTests.test_row_number_no_ordering", | ||||
|         }, | ||||
|         "Raises ORA-00600: internal error code.": { | ||||
|             "model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery", | ||||
|         }, | ||||
|         "Oracle doesn't support changing collations on indexed columns (#33671).": { | ||||
|             "migrations.test_operations.OperationTests." | ||||
|             "test_alter_field_pk_fk_db_collation", | ||||
|         }, | ||||
|     } | ||||
|     django_test_expected_failures = { | ||||
|         # A bug in Django/cx_Oracle with respect to string handling (#23843). | ||||
|         "annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions", | ||||
|         "annotations.tests.NonAggregateAnnotationTestCase." | ||||
|         "test_custom_functions_can_ref_other_functions", | ||||
|     } | ||||
|  | ||||
|     @cached_property | ||||
|     def introspected_field_types(self): | ||||
|         return { | ||||
|             **super().introspected_field_types, | ||||
|             "GenericIPAddressField": "CharField", | ||||
|             "PositiveBigIntegerField": "BigIntegerField", | ||||
|             "PositiveIntegerField": "IntegerField", | ||||
|             "PositiveSmallIntegerField": "IntegerField", | ||||
|             "SmallIntegerField": "IntegerField", | ||||
|             "TimeField": "DateTimeField", | ||||
|         } | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_collation_on_charfield(self): | ||||
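|         # VARCHAR2(4001) is only valid with MAX_STRING_SIZE=EXTENDED, so | ||||
|         # ORA-00910 ("specified length too long for its datatype") signals a | ||||
|         # database using the standard string size, which the backend treats | ||||
|         # as lacking CharField collation support. | ||||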
|         with self.connection.cursor() as cursor: | ||||
|             try: | ||||
|                 cursor.execute("SELECT CAST('a' AS VARCHAR2(4001)) FROM dual") | ||||
|             except DatabaseError as e: | ||||
|                 if e.args[0].code == 910: | ||||
|                     return False | ||||
|                 raise | ||||
|             return True | ||||
| @ -0,0 +1,26 @@ | ||||
| from django.db.models import DecimalField, DurationField, Func | ||||
|  | ||||
|  | ||||
| class IntervalToSeconds(Func): | ||||
|     function = "" | ||||
|     template = """ | ||||
|     EXTRACT(day from %(expressions)s) * 86400 + | ||||
|     EXTRACT(hour from %(expressions)s) * 3600 + | ||||
|     EXTRACT(minute from %(expressions)s) * 60 + | ||||
|     EXTRACT(second from %(expressions)s) | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, expression, *, output_field=None, **extra): | ||||
|         super().__init__( | ||||
|             expression, output_field=output_field or DecimalField(), **extra | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SecondsToInterval(Func): | ||||
|     function = "NUMTODSINTERVAL" | ||||
|     template = "%(function)s(%(expressions)s, 'SECOND')" | ||||
|  | ||||
|     def __init__(self, expression, *, output_field=None, **extra): | ||||
|         super().__init__( | ||||
|             expression, output_field=output_field or DurationField(), **extra | ||||
|         ) | ||||
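|  | ||||
|  | ||||
| # Illustrative usage (hypothetical model and field names): | ||||
| #   Ticket.objects.annotate(seconds=IntervalToSeconds(F("duration"))) | ||||
| #   Ticket.objects.annotate(interval=SecondsToInterval(Value(90))) | ||||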
| @ -0,0 +1,411 @@ | ||||
| from collections import namedtuple | ||||
|  | ||||
| import cx_Oracle | ||||
|  | ||||
| from django.db import models | ||||
| from django.db.backends.base.introspection import BaseDatabaseIntrospection | ||||
| from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo | ||||
| from django.db.backends.base.introspection import TableInfo as BaseTableInfo | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
| FieldInfo = namedtuple( | ||||
|     "FieldInfo", BaseFieldInfo._fields + ("is_autofield", "is_json", "comment") | ||||
| ) | ||||
| TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",)) | ||||
|  | ||||
|  | ||||
| class DatabaseIntrospection(BaseDatabaseIntrospection): | ||||
|     cache_bust_counter = 1 | ||||
|  | ||||
|     # Maps type objects to Django Field types. | ||||
|     @cached_property | ||||
|     def data_types_reverse(self): | ||||
|         if self.connection.cx_oracle_version < (8,): | ||||
|             return { | ||||
|                 cx_Oracle.BLOB: "BinaryField", | ||||
|                 cx_Oracle.CLOB: "TextField", | ||||
|                 cx_Oracle.DATETIME: "DateField", | ||||
|                 cx_Oracle.FIXED_CHAR: "CharField", | ||||
|                 cx_Oracle.FIXED_NCHAR: "CharField", | ||||
|                 cx_Oracle.INTERVAL: "DurationField", | ||||
|                 cx_Oracle.NATIVE_FLOAT: "FloatField", | ||||
|                 cx_Oracle.NCHAR: "CharField", | ||||
|                 cx_Oracle.NCLOB: "TextField", | ||||
|                 cx_Oracle.NUMBER: "DecimalField", | ||||
|                 cx_Oracle.STRING: "CharField", | ||||
|                 cx_Oracle.TIMESTAMP: "DateTimeField", | ||||
|             } | ||||
|         else: | ||||
|             return { | ||||
|                 cx_Oracle.DB_TYPE_DATE: "DateField", | ||||
|                 cx_Oracle.DB_TYPE_BINARY_DOUBLE: "FloatField", | ||||
|                 cx_Oracle.DB_TYPE_BLOB: "BinaryField", | ||||
|                 cx_Oracle.DB_TYPE_CHAR: "CharField", | ||||
|                 cx_Oracle.DB_TYPE_CLOB: "TextField", | ||||
|                 cx_Oracle.DB_TYPE_INTERVAL_DS: "DurationField", | ||||
|                 cx_Oracle.DB_TYPE_NCHAR: "CharField", | ||||
|                 cx_Oracle.DB_TYPE_NCLOB: "TextField", | ||||
|                 cx_Oracle.DB_TYPE_NVARCHAR: "CharField", | ||||
|                 cx_Oracle.DB_TYPE_NUMBER: "DecimalField", | ||||
|                 cx_Oracle.DB_TYPE_TIMESTAMP: "DateTimeField", | ||||
|                 cx_Oracle.DB_TYPE_VARCHAR: "CharField", | ||||
|             } | ||||
|  | ||||
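|     # A NUMBER column doesn't identify a Django field type on its own, so | ||||
|     # precision and scale disambiguate: e.g. NUMBER(11, 0) maps to | ||||
|     # IntegerField, NUMBER(1, 0) to BooleanField, and scale -127 marks an | ||||
|     # Oracle FLOAT. | ||||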
|     def get_field_type(self, data_type, description): | ||||
|         if data_type == cx_Oracle.NUMBER: | ||||
|             precision, scale = description[4:6] | ||||
|             if scale == 0: | ||||
|                 if precision > 11: | ||||
|                     return ( | ||||
|                         "BigAutoField" | ||||
|                         if description.is_autofield | ||||
|                         else "BigIntegerField" | ||||
|                     ) | ||||
|                 elif 1 < precision < 6 and description.is_autofield: | ||||
|                     return "SmallAutoField" | ||||
|                 elif precision == 1: | ||||
|                     return "BooleanField" | ||||
|                 elif description.is_autofield: | ||||
|                     return "AutoField" | ||||
|                 else: | ||||
|                     return "IntegerField" | ||||
|             elif scale == -127: | ||||
|                 return "FloatField" | ||||
|         elif data_type == cx_Oracle.NCLOB and description.is_json: | ||||
|             return "JSONField" | ||||
|  | ||||
|         return super().get_field_type(data_type, description) | ||||
|  | ||||
|     def get_table_list(self, cursor): | ||||
|         """Return a list of table and view names in the current database.""" | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 user_tables.table_name, | ||||
|                 't', | ||||
|                 user_tab_comments.comments | ||||
|             FROM user_tables | ||||
|             LEFT OUTER JOIN | ||||
|                 user_tab_comments | ||||
|                 ON user_tab_comments.table_name = user_tables.table_name | ||||
|             WHERE | ||||
|                 NOT EXISTS ( | ||||
|                     SELECT 1 | ||||
|                     FROM user_mviews | ||||
|                     WHERE user_mviews.mview_name = user_tables.table_name | ||||
|                 ) | ||||
|             UNION ALL | ||||
|             SELECT view_name, 'v', NULL FROM user_views | ||||
|             UNION ALL | ||||
|             SELECT mview_name, 'v', NULL FROM user_mviews | ||||
|         """ | ||||
|         ) | ||||
|         return [ | ||||
|             TableInfo(self.identifier_converter(row[0]), row[1], row[2]) | ||||
|             for row in cursor.fetchall() | ||||
|         ] | ||||
|  | ||||
|     def get_table_description(self, cursor, table_name): | ||||
|         """ | ||||
|         Return a description of the table with the DB-API cursor.description | ||||
|         interface. | ||||
|         """ | ||||
|         # user_tab_cols gives the data default for columns. | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 user_tab_cols.column_name, | ||||
|                 user_tab_cols.data_default, | ||||
|                 CASE | ||||
|                     WHEN user_tab_cols.collation = user_tables.default_collation | ||||
|                     THEN NULL | ||||
|                     ELSE user_tab_cols.collation | ||||
|                 END collation, | ||||
|                 CASE | ||||
|                     WHEN user_tab_cols.char_used IS NULL | ||||
|                     THEN user_tab_cols.data_length | ||||
|                     ELSE user_tab_cols.char_length | ||||
|                 END as display_size, | ||||
|                 CASE | ||||
|                     WHEN user_tab_cols.identity_column = 'YES' THEN 1 | ||||
|                     ELSE 0 | ||||
|                 END as is_autofield, | ||||
|                 CASE | ||||
|                     WHEN EXISTS ( | ||||
|                         SELECT  1 | ||||
|                         FROM user_json_columns | ||||
|                         WHERE | ||||
|                             user_json_columns.table_name = user_tab_cols.table_name AND | ||||
|                             user_json_columns.column_name = user_tab_cols.column_name | ||||
|                     ) | ||||
|                     THEN 1 | ||||
|                     ELSE 0 | ||||
|                 END as is_json, | ||||
|                 user_col_comments.comments as col_comment | ||||
|             FROM user_tab_cols | ||||
|             LEFT OUTER JOIN | ||||
|                 user_tables ON user_tables.table_name = user_tab_cols.table_name | ||||
|             LEFT OUTER JOIN | ||||
|                 user_col_comments ON | ||||
|                 user_col_comments.column_name = user_tab_cols.column_name AND | ||||
|                 user_col_comments.table_name = user_tab_cols.table_name | ||||
|             WHERE user_tab_cols.table_name = UPPER(%s) | ||||
|             """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         field_map = { | ||||
|             column: ( | ||||
|                 display_size, | ||||
|                 default if default != "NULL" else None, | ||||
|                 collation, | ||||
|                 is_autofield, | ||||
|                 is_json, | ||||
|                 comment, | ||||
|             ) | ||||
|             for ( | ||||
|                 column, | ||||
|                 default, | ||||
|                 collation, | ||||
|                 display_size, | ||||
|                 is_autofield, | ||||
|                 is_json, | ||||
|                 comment, | ||||
|             ) in cursor.fetchall() | ||||
|         } | ||||
|         self.cache_bust_counter += 1 | ||||
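|         # Embedding the ever-increasing counter in the query text makes each | ||||
|         # statement unique, bypassing cached cursor metadata so that | ||||
|         # cursor.description reflects the table's current columns. | ||||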
|         cursor.execute( | ||||
|             "SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format( | ||||
|                 self.connection.ops.quote_name(table_name), self.cache_bust_counter | ||||
|             ) | ||||
|         ) | ||||
|         description = [] | ||||
|         for desc in cursor.description: | ||||
|             name = desc[0] | ||||
|             ( | ||||
|                 display_size, | ||||
|                 default, | ||||
|                 collation, | ||||
|                 is_autofield, | ||||
|                 is_json, | ||||
|                 comment, | ||||
|             ) = field_map[name] | ||||
|             name %= {}  # cx_Oracle, for some reason, doubles percent signs. | ||||
|             description.append( | ||||
|                 FieldInfo( | ||||
|                     self.identifier_converter(name), | ||||
|                     desc[1], | ||||
|                     display_size, | ||||
|                     desc[3], | ||||
|                     desc[4] or 0, | ||||
|                     desc[5] or 0, | ||||
|                     *desc[6:], | ||||
|                     default, | ||||
|                     collation, | ||||
|                     is_autofield, | ||||
|                     is_json, | ||||
|                     comment, | ||||
|                 ) | ||||
|             ) | ||||
|         return description | ||||
|  | ||||
|     def identifier_converter(self, name): | ||||
|         """Identifier comparison is case insensitive under Oracle.""" | ||||
|         return name.lower() | ||||
|  | ||||
|     def get_sequences(self, cursor, table_name, table_fields=()): | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 user_tab_identity_cols.sequence_name, | ||||
|                 user_tab_identity_cols.column_name | ||||
|             FROM | ||||
|                 user_tab_identity_cols, | ||||
|                 user_constraints, | ||||
|                 user_cons_columns cols | ||||
|             WHERE | ||||
|                 user_constraints.constraint_name = cols.constraint_name | ||||
|                 AND user_constraints.table_name = user_tab_identity_cols.table_name | ||||
|                 AND cols.column_name = user_tab_identity_cols.column_name | ||||
|                 AND user_constraints.constraint_type = 'P' | ||||
|                 AND user_tab_identity_cols.table_name = UPPER(%s) | ||||
|             """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         # Oracle allows only one identity column per table. | ||||
|         row = cursor.fetchone() | ||||
|         if row: | ||||
|             return [ | ||||
|                 { | ||||
|                     "name": self.identifier_converter(row[0]), | ||||
|                     "table": self.identifier_converter(table_name), | ||||
|                     "column": self.identifier_converter(row[1]), | ||||
|                 } | ||||
|             ] | ||||
|         # To keep backward compatibility for AutoFields that aren't Oracle | ||||
|         # identity columns. | ||||
|         for f in table_fields: | ||||
|             if isinstance(f, models.AutoField): | ||||
|                 return [{"table": table_name, "column": f.column}] | ||||
|         return [] | ||||
|  | ||||
|     def get_relations(self, cursor, table_name): | ||||
|         """ | ||||
|         Return a dictionary of {field_name: (field_name_other_table, other_table)} | ||||
|         representing all foreign keys in the given table. | ||||
|         """ | ||||
|         table_name = table_name.upper() | ||||
|         cursor.execute( | ||||
|             """ | ||||
|     SELECT ca.column_name, cb.table_name, cb.column_name | ||||
|     FROM   user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb | ||||
|     WHERE  user_constraints.table_name = %s AND | ||||
|            user_constraints.constraint_name = ca.constraint_name AND | ||||
|            user_constraints.r_constraint_name = cb.constraint_name AND | ||||
|            ca.position = cb.position""", | ||||
|             [table_name], | ||||
|         ) | ||||
|  | ||||
|         return { | ||||
|             self.identifier_converter(field_name): ( | ||||
|                 self.identifier_converter(rel_field_name), | ||||
|                 self.identifier_converter(rel_table_name), | ||||
|             ) | ||||
|             for field_name, rel_table_name, rel_field_name in cursor.fetchall() | ||||
|         } | ||||
|  | ||||
|     def get_primary_key_columns(self, cursor, table_name): | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 cols.column_name | ||||
|             FROM | ||||
|                 user_constraints, | ||||
|                 user_cons_columns cols | ||||
|             WHERE | ||||
|                 user_constraints.constraint_name = cols.constraint_name AND | ||||
|                 user_constraints.constraint_type = 'P' AND | ||||
|                 user_constraints.table_name = UPPER(%s) | ||||
|             ORDER BY | ||||
|                 cols.position | ||||
|             """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         return [self.identifier_converter(row[0]) for row in cursor.fetchall()] | ||||
|  | ||||
|     def get_constraints(self, cursor, table_name): | ||||
|         """ | ||||
|         Retrieve any constraints or keys (unique, pk, fk, check, index) across | ||||
|         one or more columns. | ||||
|         """ | ||||
|         constraints = {} | ||||
|         # Loop over the constraints, getting PKs, uniques, and checks | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 user_constraints.constraint_name, | ||||
|                 LISTAGG(LOWER(cols.column_name), ',') | ||||
|                     WITHIN GROUP (ORDER BY cols.position), | ||||
|                 CASE user_constraints.constraint_type | ||||
|                     WHEN 'P' THEN 1 | ||||
|                     ELSE 0 | ||||
|                 END AS is_primary_key, | ||||
|                 CASE | ||||
|                     WHEN user_constraints.constraint_type IN ('P', 'U') THEN 1 | ||||
|                     ELSE 0 | ||||
|                 END AS is_unique, | ||||
|                 CASE user_constraints.constraint_type | ||||
|                     WHEN 'C' THEN 1 | ||||
|                     ELSE 0 | ||||
|                 END AS is_check_constraint | ||||
|             FROM | ||||
|                 user_constraints | ||||
|             LEFT OUTER JOIN | ||||
|                 user_cons_columns cols | ||||
|                 ON user_constraints.constraint_name = cols.constraint_name | ||||
|             WHERE | ||||
|                 user_constraints.constraint_type = ANY('P', 'U', 'C') | ||||
|                 AND user_constraints.table_name = UPPER(%s) | ||||
|             GROUP BY user_constraints.constraint_name, user_constraints.constraint_type | ||||
|             """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         for constraint, columns, pk, unique, check in cursor.fetchall(): | ||||
|             constraint = self.identifier_converter(constraint) | ||||
|             constraints[constraint] = { | ||||
|                 "columns": columns.split(","), | ||||
|                 "primary_key": pk, | ||||
|                 "unique": unique, | ||||
|                 "foreign_key": None, | ||||
|                 "check": check, | ||||
|                 "index": unique,  # All uniques come with an index | ||||
|             } | ||||
|         # Foreign key constraints | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 cons.constraint_name, | ||||
|                 LISTAGG(LOWER(cols.column_name), ',') | ||||
|                     WITHIN GROUP (ORDER BY cols.position), | ||||
|                 LOWER(rcols.table_name), | ||||
|                 LOWER(rcols.column_name) | ||||
|             FROM | ||||
|                 user_constraints cons | ||||
|             INNER JOIN | ||||
|                 user_cons_columns rcols | ||||
|                 ON rcols.constraint_name = cons.r_constraint_name AND rcols.position = 1 | ||||
|             LEFT OUTER JOIN | ||||
|                 user_cons_columns cols | ||||
|                 ON cons.constraint_name = cols.constraint_name | ||||
|             WHERE | ||||
|                 cons.constraint_type = 'R' AND | ||||
|                 cons.table_name = UPPER(%s) | ||||
|             GROUP BY cons.constraint_name, rcols.table_name, rcols.column_name | ||||
|             """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         for constraint, columns, other_table, other_column in cursor.fetchall(): | ||||
|             constraint = self.identifier_converter(constraint) | ||||
|             constraints[constraint] = { | ||||
|                 "primary_key": False, | ||||
|                 "unique": False, | ||||
|                 "foreign_key": (other_table, other_column), | ||||
|                 "check": False, | ||||
|                 "index": False, | ||||
|                 "columns": columns.split(","), | ||||
|             } | ||||
|         # Now get indexes | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 ind.index_name, | ||||
|                 LOWER(ind.index_type), | ||||
|                 LOWER(ind.uniqueness), | ||||
|                 LISTAGG(LOWER(cols.column_name), ',') | ||||
|                     WITHIN GROUP (ORDER BY cols.column_position), | ||||
|                 LISTAGG(cols.descend, ',') WITHIN GROUP (ORDER BY cols.column_position) | ||||
|             FROM | ||||
|                 user_ind_columns cols, user_indexes ind | ||||
|             WHERE | ||||
|                 cols.table_name = UPPER(%s) AND | ||||
|                 NOT EXISTS ( | ||||
|                     SELECT 1 | ||||
|                     FROM user_constraints cons | ||||
|                     WHERE ind.index_name = cons.index_name | ||||
|                 ) AND cols.index_name = ind.index_name | ||||
|             GROUP BY ind.index_name, ind.index_type, ind.uniqueness | ||||
|             """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         for constraint, type_, unique, columns, orders in cursor.fetchall(): | ||||
|             constraint = self.identifier_converter(constraint) | ||||
|             constraints[constraint] = { | ||||
|                 "primary_key": False, | ||||
|                 "unique": unique == "unique", | ||||
|                 "foreign_key": None, | ||||
|                 "check": False, | ||||
|                 "index": True, | ||||
|                 "type": "idx" if type_ == "normal" else type_, | ||||
|                 "columns": columns.split(","), | ||||
|                 "orders": orders.split(","), | ||||
|             } | ||||
|         return constraints | ||||
| @ -0,0 +1,722 @@ | ||||
| import datetime | ||||
| import uuid | ||||
| from functools import lru_cache | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db import DatabaseError, NotSupportedError | ||||
| from django.db.backends.base.operations import BaseDatabaseOperations | ||||
| from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name | ||||
| from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup | ||||
| from django.db.models.expressions import RawSQL | ||||
| from django.db.models.sql.where import WhereNode | ||||
| from django.utils import timezone | ||||
| from django.utils.encoding import force_bytes, force_str | ||||
| from django.utils.functional import cached_property | ||||
| from django.utils.regex_helper import _lazy_re_compile | ||||
|  | ||||
| from .base import Database | ||||
| from .utils import BulkInsertMapper, InsertVar, Oracle_datetime | ||||
|  | ||||
|  | ||||
| class DatabaseOperations(BaseDatabaseOperations): | ||||
|     # Oracle uses NUMBER(5), NUMBER(11), and NUMBER(19) for integer fields. | ||||
|     # SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by | ||||
|     # SmallAutoField, to preserve backward compatibility. | ||||
|     integer_field_ranges = { | ||||
|         "SmallIntegerField": (-99999999999, 99999999999), | ||||
|         "IntegerField": (-99999999999, 99999999999), | ||||
|         "BigIntegerField": (-9999999999999999999, 9999999999999999999), | ||||
|         "PositiveBigIntegerField": (0, 9999999999999999999), | ||||
|         "PositiveSmallIntegerField": (0, 99999999999), | ||||
|         "PositiveIntegerField": (0, 99999999999), | ||||
|         "SmallAutoField": (-99999, 99999), | ||||
|         "AutoField": (-99999999999, 99999999999), | ||||
|         "BigAutoField": (-9999999999999999999, 9999999999999999999), | ||||
|     } | ||||
|     set_operators = {**BaseDatabaseOperations.set_operators, "difference": "MINUS"} | ||||
|  | ||||
|     # TODO: colorize this SQL code with style.SQL_KEYWORD(), etc. | ||||
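|     # The PL/SQL block below resolves the identity sequence for the column | ||||
|     # (falling back to the legacy Django-style sequence name), then calls | ||||
|     # nextval in a loop until the sequence catches up with MAX(column). | ||||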
|     _sequence_reset_sql = """ | ||||
| DECLARE | ||||
|     table_value integer; | ||||
|     seq_value integer; | ||||
|     seq_name user_tab_identity_cols.sequence_name%%TYPE; | ||||
| BEGIN | ||||
|     BEGIN | ||||
|         SELECT sequence_name INTO seq_name FROM user_tab_identity_cols | ||||
|         WHERE  table_name = '%(table_name)s' AND | ||||
|                column_name = '%(column_name)s'; | ||||
|         EXCEPTION WHEN NO_DATA_FOUND THEN | ||||
|             seq_name := '%(no_autofield_sequence_name)s'; | ||||
|     END; | ||||
|  | ||||
|     SELECT NVL(MAX(%(column)s), 0) INTO table_value FROM %(table)s; | ||||
|     SELECT NVL(last_number - cache_size, 0) INTO seq_value FROM user_sequences | ||||
|            WHERE sequence_name = seq_name; | ||||
|     WHILE table_value > seq_value LOOP | ||||
|         EXECUTE IMMEDIATE 'SELECT "'||seq_name||'".nextval FROM DUAL' | ||||
|         INTO seq_value; | ||||
|     END LOOP; | ||||
| END; | ||||
| /""" | ||||
|  | ||||
|     # Oracle doesn't support string without precision; use the max string size. | ||||
|     cast_char_field_without_max_length = "NVARCHAR2(2000)" | ||||
|     cast_data_types = { | ||||
|         "AutoField": "NUMBER(11)", | ||||
|         "BigAutoField": "NUMBER(19)", | ||||
|         "SmallAutoField": "NUMBER(5)", | ||||
|         "TextField": cast_char_field_without_max_length, | ||||
|     } | ||||
|  | ||||
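|     # The cache backend later interpolates the table name into the bare %s; | ||||
|     # the doubled %%s survives that interpolation as the %s placeholder for | ||||
|     # the cull offset. | ||||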
|     def cache_key_culling_sql(self): | ||||
|         cache_key = self.quote_name("cache_key") | ||||
|         return ( | ||||
|             f"SELECT {cache_key} " | ||||
|             f"FROM %s " | ||||
|             f"ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY" | ||||
|         ) | ||||
|  | ||||
|     # EXTRACT format cannot be passed in parameters. | ||||
|     _extract_format_re = _lazy_re_compile(r"[A-Z_]+") | ||||
|  | ||||
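|     # Illustrative behavior: "week" yields ("TO_CHAR(<sql>, %s)", "IW") with | ||||
|     # the format passed as a bind parameter, while simple types such as | ||||
|     # "year" fall through to EXTRACT(YEAR FROM <sql>). | ||||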
|     def date_extract_sql(self, lookup_type, sql, params): | ||||
|         extract_sql = f"TO_CHAR({sql}, %s)" | ||||
|         extract_param = None | ||||
|         if lookup_type == "week_day": | ||||
|             # TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday. | ||||
|             extract_param = "D" | ||||
|         elif lookup_type == "iso_week_day": | ||||
|             extract_sql = f"TO_CHAR({sql} - 1, %s)" | ||||
|             extract_param = "D" | ||||
|         elif lookup_type == "week": | ||||
|             # IW = ISO week number | ||||
|             extract_param = "IW" | ||||
|         elif lookup_type == "quarter": | ||||
|             extract_param = "Q" | ||||
|         elif lookup_type == "iso_year": | ||||
|             extract_param = "IYYY" | ||||
|         else: | ||||
|             lookup_type = lookup_type.upper() | ||||
|             if not self._extract_format_re.fullmatch(lookup_type): | ||||
|                 raise ValueError(f"Invalid lookup type: {lookup_type!r}") | ||||
|             # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/EXTRACT-datetime.html | ||||
|             return f"EXTRACT({lookup_type} FROM {sql})", params | ||||
|         return extract_sql, (*params, extract_param) | ||||
|  | ||||
|     def date_trunc_sql(self, lookup_type, sql, params, tzname=None): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html | ||||
|         trunc_param = None | ||||
|         if lookup_type in ("year", "month"): | ||||
|             trunc_param = lookup_type.upper() | ||||
|         elif lookup_type == "quarter": | ||||
|             trunc_param = "Q" | ||||
|         elif lookup_type == "week": | ||||
|             trunc_param = "IW" | ||||
|         else: | ||||
|             return f"TRUNC({sql})", params | ||||
|         return f"TRUNC({sql}, %s)", (*params, trunc_param) | ||||
|  | ||||
|     # Oracle crashes with "ORA-03113: end-of-file on communication channel" | ||||
|     # if the time zone name is passed as a parameter. Use interpolation instead. | ||||
|     # https://groups.google.com/forum/#!msg/django-developers/zwQju7hbG78/9l934yelwfsJ | ||||
|     # This regexp matches all time zone names from the zoneinfo database. | ||||
|     _tzname_re = _lazy_re_compile(r"^[\w/:+-]+$") | ||||
|  | ||||
|     def _prepare_tzname_delta(self, tzname): | ||||
|         tzname, sign, offset = split_tzname_delta(tzname) | ||||
|         return f"{sign}{offset}" if offset else tzname | ||||
|  | ||||
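|     # Illustrative output, assuming a connection time zone of "UTC" and | ||||
|     # tzname "Europe/Paris": | ||||
|     #   CAST((FROM_TZ(<sql>, 'UTC') AT TIME ZONE 'Europe/Paris') AS TIMESTAMP) | ||||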
|     def _convert_sql_to_tz(self, sql, params, tzname): | ||||
|         if not (settings.USE_TZ and tzname): | ||||
|             return sql, params | ||||
|         if not self._tzname_re.match(tzname): | ||||
|             raise ValueError("Invalid time zone name: %s" % tzname) | ||||
|         # Convert from connection timezone to the local time, returning | ||||
|         # TIMESTAMP WITH TIME ZONE and cast it back to TIMESTAMP to strip the | ||||
|         # TIME ZONE details. | ||||
|         if self.connection.timezone_name != tzname: | ||||
|             from_timezone_name = self.connection.timezone_name | ||||
|             to_timezone_name = self._prepare_tzname_delta(tzname) | ||||
|             return ( | ||||
|                 f"CAST((FROM_TZ({sql}, '{from_timezone_name}') AT TIME ZONE " | ||||
|                 f"'{to_timezone_name}') AS TIMESTAMP)", | ||||
|                 params, | ||||
|             ) | ||||
|         return sql, params | ||||
|  | ||||
|     def datetime_cast_date_sql(self, sql, params, tzname): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         return f"TRUNC({sql})", params | ||||
|  | ||||
|     def datetime_cast_time_sql(self, sql, params, tzname): | ||||
|         # Since `TimeField` values are stored as TIMESTAMP, change the date | ||||
|         # to a fixed default and convert the field to the specified timezone. | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         convert_datetime_sql = ( | ||||
|             f"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR({sql}, 'HH24:MI:SS.FF')), " | ||||
|             f"'YYYY-MM-DD HH24:MI:SS.FF')" | ||||
|         ) | ||||
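|         # <sql> appears twice in the returned CASE expression (in the NULL | ||||
|         # check and inside convert_datetime_sql), so the bind parameters are | ||||
|         # duplicated to match. | ||||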
|         return ( | ||||
|             f"CASE WHEN {sql} IS NOT NULL THEN {convert_datetime_sql} ELSE NULL END", | ||||
|             (*params, *params), | ||||
|         ) | ||||
|  | ||||
|     def datetime_extract_sql(self, lookup_type, sql, params, tzname): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         return self.date_extract_sql(lookup_type, sql, params) | ||||
|  | ||||
|     def datetime_trunc_sql(self, lookup_type, sql, params, tzname): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html | ||||
|         trunc_param = None | ||||
|         if lookup_type in ("year", "month"): | ||||
|             trunc_param = lookup_type.upper() | ||||
|         elif lookup_type == "quarter": | ||||
|             trunc_param = "Q" | ||||
|         elif lookup_type == "week": | ||||
|             trunc_param = "IW" | ||||
|         elif lookup_type == "hour": | ||||
|             trunc_param = "HH24" | ||||
|         elif lookup_type == "minute": | ||||
|             trunc_param = "MI" | ||||
|         elif lookup_type == "day": | ||||
|             return f"TRUNC({sql})", params | ||||
|         else: | ||||
|             # Cast to DATE removes sub-second precision. | ||||
|             return f"CAST({sql} AS DATE)", params | ||||
|         return f"TRUNC({sql}, %s)", (*params, trunc_param) | ||||
|  | ||||
|     def time_trunc_sql(self, lookup_type, sql, params, tzname=None): | ||||
|         # The implementation is similar to `datetime_trunc_sql` as both | ||||
|         # `DateTimeField` and `TimeField` are stored as TIMESTAMP where | ||||
|         # the date part of the latter is ignored. | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         trunc_param = None | ||||
|         if lookup_type == "hour": | ||||
|             trunc_param = "HH24" | ||||
|         elif lookup_type == "minute": | ||||
|             trunc_param = "MI" | ||||
|         elif lookup_type == "second": | ||||
|             # Cast to DATE removes sub-second precision. | ||||
|             return f"CAST({sql} AS DATE)", params | ||||
|         return f"TRUNC({sql}, %s)", (*params, trunc_param) | ||||
|  | ||||
|     def get_db_converters(self, expression): | ||||
|         converters = super().get_db_converters(expression) | ||||
|         internal_type = expression.output_field.get_internal_type() | ||||
|         if internal_type in ["JSONField", "TextField"]: | ||||
|             converters.append(self.convert_textfield_value) | ||||
|         elif internal_type == "BinaryField": | ||||
|             converters.append(self.convert_binaryfield_value) | ||||
|         elif internal_type == "BooleanField": | ||||
|             converters.append(self.convert_booleanfield_value) | ||||
|         elif internal_type == "DateTimeField": | ||||
|             if settings.USE_TZ: | ||||
|                 converters.append(self.convert_datetimefield_value) | ||||
|         elif internal_type == "DateField": | ||||
|             converters.append(self.convert_datefield_value) | ||||
|         elif internal_type == "TimeField": | ||||
|             converters.append(self.convert_timefield_value) | ||||
|         elif internal_type == "UUIDField": | ||||
|             converters.append(self.convert_uuidfield_value) | ||||
|         # Oracle stores empty strings as null. If the field accepts the empty | ||||
|         # string, undo this to adhere to the Django convention of using | ||||
|         # the empty string instead of null. | ||||
|         if expression.output_field.empty_strings_allowed: | ||||
|             converters.append( | ||||
|                 self.convert_empty_bytes | ||||
|                 if internal_type == "BinaryField" | ||||
|                 else self.convert_empty_string | ||||
|             ) | ||||
|         return converters | ||||
|  | ||||
|     def convert_textfield_value(self, value, expression, connection): | ||||
|         if isinstance(value, Database.LOB): | ||||
|             value = value.read() | ||||
|         return value | ||||
|  | ||||
|     def convert_binaryfield_value(self, value, expression, connection): | ||||
|         if isinstance(value, Database.LOB): | ||||
|             value = force_bytes(value.read()) | ||||
|         return value | ||||
|  | ||||
|     def convert_booleanfield_value(self, value, expression, connection): | ||||
|         if value in (0, 1): | ||||
|             value = bool(value) | ||||
|         return value | ||||
|  | ||||
|     # cx_Oracle always returns datetime.datetime objects for | ||||
|     # DATE and TIMESTAMP columns, but Django wants to see a | ||||
|     # python datetime.date, .time, or .datetime. | ||||
|  | ||||
|     def convert_datetimefield_value(self, value, expression, connection): | ||||
|         if value is not None: | ||||
|             value = timezone.make_aware(value, self.connection.timezone) | ||||
|         return value | ||||
|  | ||||
|     def convert_datefield_value(self, value, expression, connection): | ||||
|         if isinstance(value, Database.Timestamp): | ||||
|             value = value.date() | ||||
|         return value | ||||
|  | ||||
|     def convert_timefield_value(self, value, expression, connection): | ||||
|         if isinstance(value, Database.Timestamp): | ||||
|             value = value.time() | ||||
|         return value | ||||
|  | ||||
|     def convert_uuidfield_value(self, value, expression, connection): | ||||
|         if value is not None: | ||||
|             value = uuid.UUID(value) | ||||
|         return value | ||||
|  | ||||
|     @staticmethod | ||||
|     def convert_empty_string(value, expression, connection): | ||||
|         return "" if value is None else value | ||||
|  | ||||
|     @staticmethod | ||||
|     def convert_empty_bytes(value, expression, connection): | ||||
|         return b"" if value is None else value | ||||
|  | ||||
|     def deferrable_sql(self): | ||||
|         return " DEFERRABLE INITIALLY DEFERRED" | ||||
|  | ||||
|     def fetch_returned_insert_columns(self, cursor, returning_params): | ||||
|         columns = [] | ||||
|         for param in returning_params: | ||||
|             value = param.get_value() | ||||
|             if value == []: | ||||
|                 raise DatabaseError( | ||||
|                     "The database did not return a new row id. Probably " | ||||
|                     '"ORA-1403: no data found" was raised internally but was ' | ||||
|                     "hidden by the Oracle OCI library (see " | ||||
|                     "https://code.djangoproject.com/ticket/28859)." | ||||
|                 ) | ||||
|             columns.append(value[0]) | ||||
|         return tuple(columns) | ||||
|  | ||||
|     def no_limit_value(self): | ||||
|         return None | ||||
|  | ||||
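|     # Illustrative example: limit_offset_sql(10, 20) returns | ||||
|     # "OFFSET 10 ROWS FETCH FIRST 10 ROWS ONLY". | ||||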
|     def limit_offset_sql(self, low_mark, high_mark): | ||||
|         fetch, offset = self._get_limit_offset_params(low_mark, high_mark) | ||||
|         return " ".join( | ||||
|             sql | ||||
|             for sql in ( | ||||
|                 ("OFFSET %d ROWS" % offset) if offset else None, | ||||
|                 ("FETCH FIRST %d ROWS ONLY" % fetch) if fetch else None, | ||||
|             ) | ||||
|             if sql | ||||
|         ) | ||||
|  | ||||
|     def last_executed_query(self, cursor, sql, params): | ||||
|         # https://cx-oracle.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.statement | ||||
|         # The DB API definition does not define this attribute. | ||||
|         statement = cursor.statement | ||||
|         # Unlike Psycopg's `query` and MySQLdb's `_executed`, cx_Oracle's | ||||
|         # `statement` doesn't contain the query parameters. Substitute | ||||
|         # parameters manually. | ||||
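|         # For example, a statement rendered as "... WHERE id = :arg0" and | ||||
|         # executed with params [42] comes back as "... WHERE id = 42". The | ||||
|         # result is an approximation for debugging, not the exact text sent | ||||
|         # to the server. | ||||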
|         if params: | ||||
|             if isinstance(params, (tuple, list)): | ||||
|                 params = { | ||||
|                     f":arg{i}": param for i, param in enumerate(dict.fromkeys(params)) | ||||
|                 } | ||||
|             elif isinstance(params, dict): | ||||
|                 params = {f":{key}": val for (key, val) in params.items()} | ||||
|             for key in sorted(params, key=len, reverse=True): | ||||
|                 statement = statement.replace( | ||||
|                     key, force_str(params[key], errors="replace") | ||||
|                 ) | ||||
|         return statement | ||||
|  | ||||
|     def last_insert_id(self, cursor, table_name, pk_name): | ||||
|         sq_name = self._get_sequence_name(cursor, strip_quotes(table_name), pk_name) | ||||
|         cursor.execute('"%s".currval' % sq_name) | ||||
|         return cursor.fetchone()[0] | ||||
|  | ||||
|     def lookup_cast(self, lookup_type, internal_type=None): | ||||
|         if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"): | ||||
|             return "UPPER(%s)" | ||||
|         if ( | ||||
|             lookup_type != "isnull" and internal_type in ("BinaryField", "TextField") | ||||
|         ) or (lookup_type == "exact" and internal_type == "JSONField"): | ||||
|             return "DBMS_LOB.SUBSTR(%s)" | ||||
|         return "%s" | ||||
|  | ||||
|     def max_in_list_size(self): | ||||
|         return 1000 | ||||
|  | ||||
|     def max_name_length(self): | ||||
|         return 30 | ||||
|  | ||||
|     def pk_default_value(self): | ||||
|         return "NULL" | ||||
|  | ||||
|     def prep_for_iexact_query(self, x): | ||||
|         return x | ||||
|  | ||||
|     def process_clob(self, value): | ||||
|         if value is None: | ||||
|             return "" | ||||
|         return value.read() | ||||
|  | ||||
|     def quote_name(self, name): | ||||
|         # SQL92 requires delimited (quoted) names to be case-sensitive.  When | ||||
|         # not quoted, Oracle has case-insensitive behavior for identifiers, but | ||||
|         # always defaults to uppercase. | ||||
|         # We simplify things by making Oracle identifiers always uppercase. | ||||
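|         # For example, quote_name("django_content_type") returns | ||||
|         # '"DJANGO_CONTENT_TYPE"'. | ||||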
|         if not name.startswith('"') and not name.endswith('"'): | ||||
|             name = '"%s"' % truncate_name(name, self.max_name_length()) | ||||
|         # Oracle puts the query text into a (query % args) construct, so % signs | ||||
|         # in names need to be escaped. The '%%' will be collapsed back to '%' at | ||||
|         # that stage so we aren't really making the name longer here. | ||||
|         name = name.replace("%", "%%") | ||||
|         return name.upper() | ||||
|  | ||||
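|     # The %%s placeholders collapse to %s when match_option is interpolated, | ||||
|     # e.g. regex_lookup("iregex") returns "REGEXP_LIKE(%s, %s, 'i')". | ||||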
|     def regex_lookup(self, lookup_type): | ||||
|         if lookup_type == "regex": | ||||
|             match_option = "'c'" | ||||
|         else: | ||||
|             match_option = "'i'" | ||||
|         return "REGEXP_LIKE(%%s, %%s, %s)" % match_option | ||||
|  | ||||
|     def return_insert_columns(self, fields): | ||||
|         if not fields: | ||||
|             return "", () | ||||
|         field_names = [] | ||||
|         params = [] | ||||
|         for field in fields: | ||||
|             field_names.append( | ||||
|                 "%s.%s" | ||||
|                 % ( | ||||
|                     self.quote_name(field.model._meta.db_table), | ||||
|                     self.quote_name(field.column), | ||||
|                 ) | ||||
|             ) | ||||
|             params.append(InsertVar(field)) | ||||
|         return "RETURNING %s INTO %s" % ( | ||||
|             ", ".join(field_names), | ||||
|             ", ".join(["%s"] * len(params)), | ||||
|         ), tuple(params) | ||||
|  | ||||
|     def __foreign_key_constraints(self, table_name, recursive): | ||||
|         with self.connection.cursor() as cursor: | ||||
|             if recursive: | ||||
|                 cursor.execute( | ||||
|                     """ | ||||
|                     SELECT | ||||
|                         user_tables.table_name, rcons.constraint_name | ||||
|                     FROM | ||||
|                         user_tables | ||||
|                     JOIN | ||||
|                         user_constraints cons | ||||
|                         ON (user_tables.table_name = cons.table_name | ||||
|                         AND cons.constraint_type = ANY('P', 'U')) | ||||
|                     LEFT JOIN | ||||
|                         user_constraints rcons | ||||
|                         ON (user_tables.table_name = rcons.table_name | ||||
|                         AND rcons.constraint_type = 'R') | ||||
|                     START WITH user_tables.table_name = UPPER(%s) | ||||
|                     CONNECT BY | ||||
|                         NOCYCLE PRIOR cons.constraint_name = rcons.r_constraint_name | ||||
|                     GROUP BY | ||||
|                         user_tables.table_name, rcons.constraint_name | ||||
|                     HAVING user_tables.table_name != UPPER(%s) | ||||
|                     ORDER BY MAX(level) DESC | ||||
|                     """, | ||||
|                     (table_name, table_name), | ||||
|                 ) | ||||
|             else: | ||||
|                 cursor.execute( | ||||
|                     """ | ||||
|                     SELECT | ||||
|                         cons.table_name, cons.constraint_name | ||||
|                     FROM | ||||
|                         user_constraints cons | ||||
|                     WHERE | ||||
|                         cons.constraint_type = 'R' | ||||
|                         AND cons.table_name = UPPER(%s) | ||||
|                     """, | ||||
|                     (table_name,), | ||||
|                 ) | ||||
|             return cursor.fetchall() | ||||
|  | ||||
|     @cached_property | ||||
|     def _foreign_key_constraints(self): | ||||
|         # 512 is large enough to fit the ~330 tables (as of this writing) in | ||||
|         # Django's test suite. | ||||
|         return lru_cache(maxsize=512)(self.__foreign_key_constraints) | ||||
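|  | ||||
|     # Note on the pattern above (sketch, not from the source comments): | ||||
|     # wrapping the bound method in lru_cache() inside a cached_property | ||||
|     # memoizes lookups per instance rather than per class, so cached rows | ||||
|     # die with the connection instead of leaking through a class-level | ||||
|     # @lru_cache. Roughly: | ||||
|     # | ||||
|     #   ops._foreign_key_constraints("SOME_TABLE", recursive=False) | ||||
|     #   # first call hits the database; repeats are served from the cache | ||||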
|  | ||||
|     def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False): | ||||
|         if not tables: | ||||
|             return [] | ||||
|  | ||||
|         truncated_tables = {table.upper() for table in tables} | ||||
|         constraints = set() | ||||
|         # Oracle's TRUNCATE CASCADE only works with ON DELETE CASCADE foreign | ||||
|         # keys which Django doesn't define. Emulate the PostgreSQL behavior | ||||
|         # which truncates all dependent tables by manually retrieving all | ||||
|         # foreign key constraints and resolving dependencies. | ||||
|         for table in tables: | ||||
|             for foreign_table, constraint in self._foreign_key_constraints( | ||||
|                 table, recursive=allow_cascade | ||||
|             ): | ||||
|                 if allow_cascade: | ||||
|                     truncated_tables.add(foreign_table) | ||||
|                 constraints.add((foreign_table, constraint)) | ||||
|         sql = ( | ||||
|             [ | ||||
|                 "%s %s %s %s %s %s %s %s;" | ||||
|                 % ( | ||||
|                     style.SQL_KEYWORD("ALTER"), | ||||
|                     style.SQL_KEYWORD("TABLE"), | ||||
|                     style.SQL_FIELD(self.quote_name(table)), | ||||
|                     style.SQL_KEYWORD("DISABLE"), | ||||
|                     style.SQL_KEYWORD("CONSTRAINT"), | ||||
|                     style.SQL_FIELD(self.quote_name(constraint)), | ||||
|                     style.SQL_KEYWORD("KEEP"), | ||||
|                     style.SQL_KEYWORD("INDEX"), | ||||
|                 ) | ||||
|                 for table, constraint in constraints | ||||
|             ] | ||||
|             + [ | ||||
|                 "%s %s %s;" | ||||
|                 % ( | ||||
|                     style.SQL_KEYWORD("TRUNCATE"), | ||||
|                     style.SQL_KEYWORD("TABLE"), | ||||
|                     style.SQL_FIELD(self.quote_name(table)), | ||||
|                 ) | ||||
|                 for table in truncated_tables | ||||
|             ] | ||||
|             + [ | ||||
|                 "%s %s %s %s %s %s;" | ||||
|                 % ( | ||||
|                     style.SQL_KEYWORD("ALTER"), | ||||
|                     style.SQL_KEYWORD("TABLE"), | ||||
|                     style.SQL_FIELD(self.quote_name(table)), | ||||
|                     style.SQL_KEYWORD("ENABLE"), | ||||
|                     style.SQL_KEYWORD("CONSTRAINT"), | ||||
|                     style.SQL_FIELD(self.quote_name(constraint)), | ||||
|                 ) | ||||
|                 for table, constraint in constraints | ||||
|             ] | ||||
|         ) | ||||
|         if reset_sequences: | ||||
|             sequences = [ | ||||
|                 sequence | ||||
|                 for sequence in self.connection.introspection.sequence_list() | ||||
|                 if sequence["table"].upper() in truncated_tables | ||||
|             ] | ||||
|             # Since we've just deleted all the rows, running our sequence ALTER | ||||
|             # code will reset the sequence to 0. | ||||
|             sql.extend(self.sequence_reset_by_name_sql(style, sequences)) | ||||
|         return sql | ||||
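|  | ||||
|     # Rough shape of the statements sql_flush() returns (sketch; table and | ||||
|     # constraint names are made up): | ||||
|     # | ||||
|     #   ALTER TABLE "B" DISABLE CONSTRAINT "B_A_FK" KEEP INDEX; | ||||
|     #   TRUNCATE TABLE "A"; | ||||
|     #   TRUNCATE TABLE "B"; | ||||
|     #   ALTER TABLE "B" ENABLE CONSTRAINT "B_A_FK"; | ||||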
|  | ||||
|     def sequence_reset_by_name_sql(self, style, sequences): | ||||
|         sql = [] | ||||
|         for sequence_info in sequences: | ||||
|             no_autofield_sequence_name = self._get_no_autofield_sequence_name( | ||||
|                 sequence_info["table"] | ||||
|             ) | ||||
|             table = self.quote_name(sequence_info["table"]) | ||||
|             column = self.quote_name(sequence_info["column"] or "id") | ||||
|             query = self._sequence_reset_sql % { | ||||
|                 "no_autofield_sequence_name": no_autofield_sequence_name, | ||||
|                 "table": table, | ||||
|                 "column": column, | ||||
|                 "table_name": strip_quotes(table), | ||||
|                 "column_name": strip_quotes(column), | ||||
|             } | ||||
|             sql.append(query) | ||||
|         return sql | ||||
|  | ||||
|     def sequence_reset_sql(self, style, model_list): | ||||
|         output = [] | ||||
|         query = self._sequence_reset_sql | ||||
|         for model in model_list: | ||||
|             for f in model._meta.local_fields: | ||||
|                 if isinstance(f, AutoField): | ||||
|                     no_autofield_sequence_name = self._get_no_autofield_sequence_name( | ||||
|                         model._meta.db_table | ||||
|                     ) | ||||
|                     table = self.quote_name(model._meta.db_table) | ||||
|                     column = self.quote_name(f.column) | ||||
|                     output.append( | ||||
|                         query | ||||
|                         % { | ||||
|                             "no_autofield_sequence_name": no_autofield_sequence_name, | ||||
|                             "table": table, | ||||
|                             "column": column, | ||||
|                             "table_name": strip_quotes(table), | ||||
|                             "column_name": strip_quotes(column), | ||||
|                         } | ||||
|                     ) | ||||
|                     # Only one AutoField is allowed per model, so don't | ||||
|                     # continue to loop. | ||||
|                     break | ||||
|         return output | ||||
|  | ||||
|     def start_transaction_sql(self): | ||||
|         return "" | ||||
|  | ||||
|     def tablespace_sql(self, tablespace, inline=False): | ||||
|         if inline: | ||||
|             return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace) | ||||
|         else: | ||||
|             return "TABLESPACE %s" % self.quote_name(tablespace) | ||||
|  | ||||
|     def adapt_datefield_value(self, value): | ||||
|         """ | ||||
|         Transform a date value to an object compatible with what is expected | ||||
|         by the backend driver for date columns. | ||||
|         The default implementation transforms the date to text, but that is not | ||||
|         necessary for Oracle. | ||||
|         """ | ||||
|         return value | ||||
|  | ||||
|     def adapt_datetimefield_value(self, value): | ||||
|         """ | ||||
|         Transform a datetime value to an object compatible with what is expected | ||||
|         by the backend driver for datetime columns. | ||||
|  | ||||
|         If a naive datetime is passed, assume it is in UTC. Normally | ||||
|         Django's models.DateTimeField ensures that, when USE_TZ is True, | ||||
|         the datetime passed in is timezone-aware. | ||||
|         """ | ||||
|  | ||||
|         if value is None: | ||||
|             return None | ||||
|  | ||||
|         # Expression values are adapted by the database. | ||||
|         if hasattr(value, "resolve_expression"): | ||||
|             return value | ||||
|  | ||||
|         # cx_Oracle doesn't support tz-aware datetimes | ||||
|         if timezone.is_aware(value): | ||||
|             if settings.USE_TZ: | ||||
|                 value = timezone.make_naive(value, self.connection.timezone) | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     "Oracle backend does not support timezone-aware datetimes when " | ||||
|                     "USE_TZ is False." | ||||
|                 ) | ||||
|  | ||||
|         return Oracle_datetime.from_datetime(value) | ||||
|  | ||||
|     def adapt_timefield_value(self, value): | ||||
|         if value is None: | ||||
|             return None | ||||
|  | ||||
|         # Expression values are adapted by the database. | ||||
|         if hasattr(value, "resolve_expression"): | ||||
|             return value | ||||
|  | ||||
|         if isinstance(value, str): | ||||
|             return datetime.datetime.strptime(value, "%H:%M:%S") | ||||
|  | ||||
|         # Oracle doesn't support tz-aware times | ||||
|         if timezone.is_aware(value): | ||||
|             raise ValueError("Oracle backend does not support timezone-aware times.") | ||||
|  | ||||
|         return Oracle_datetime( | ||||
|             1900, 1, 1, value.hour, value.minute, value.second, value.microsecond | ||||
|         ) | ||||
|  | ||||
|     def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None): | ||||
|         return value | ||||
|  | ||||
|     def combine_expression(self, connector, sub_expressions): | ||||
|         lhs, rhs = sub_expressions | ||||
|         if connector == "%%": | ||||
|             return "MOD(%s)" % ",".join(sub_expressions) | ||||
|         elif connector == "&": | ||||
|             return "BITAND(%s)" % ",".join(sub_expressions) | ||||
|         elif connector == "|": | ||||
|             return "BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s" % {"lhs": lhs, "rhs": rhs} | ||||
|         elif connector == "<<": | ||||
|             return "(%(lhs)s * POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs} | ||||
|         elif connector == ">>": | ||||
|             return "FLOOR(%(lhs)s / POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs} | ||||
|         elif connector == "^": | ||||
|             return "POWER(%s)" % ",".join(sub_expressions) | ||||
|         elif connector == "#": | ||||
|             raise NotSupportedError("Bitwise XOR is not supported in Oracle.") | ||||
|         return super().combine_expression(connector, sub_expressions) | ||||
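|  | ||||
|     # Why the '|' emulation above works (sketch): in two's complement, | ||||
|     # -a - 1 == ~a, so BITAND(-a-1, b) keeps the bits set in b but not in | ||||
|     # a; those bits are disjoint from a, so adding a back yields a | b: | ||||
|     # | ||||
|     #   a, b = 0b1100, 0b1010 | ||||
|     #   assert (~a & b) + a == a | b  # both equal 0b1110 | ||||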
|  | ||||
|     def _get_no_autofield_sequence_name(self, table): | ||||
|         """ | ||||
|         Manually created sequence name to keep backward compatibility for | ||||
|         AutoFields that aren't Oracle identity columns. | ||||
|         """ | ||||
|         name_length = self.max_name_length() - 3 | ||||
|         return "%s_SQ" % truncate_name(strip_quotes(table), name_length).upper() | ||||
|  | ||||
|     def _get_sequence_name(self, cursor, table, pk_name): | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT sequence_name | ||||
|             FROM user_tab_identity_cols | ||||
|             WHERE table_name = UPPER(%s) | ||||
|             AND column_name = UPPER(%s)""", | ||||
|             [table, pk_name], | ||||
|         ) | ||||
|         row = cursor.fetchone() | ||||
|         return self._get_no_autofield_sequence_name(table) if row is None else row[0] | ||||
|  | ||||
|     def bulk_insert_sql(self, fields, placeholder_rows): | ||||
|         query = [] | ||||
|         for row in placeholder_rows: | ||||
|             select = [] | ||||
|             for i, placeholder in enumerate(row): | ||||
|                 # A model without any fields has fields=[None]. | ||||
|                 if fields[i]: | ||||
|                     internal_type = getattr( | ||||
|                         fields[i], "target_field", fields[i] | ||||
|                     ).get_internal_type() | ||||
|                     placeholder = ( | ||||
|                         BulkInsertMapper.types.get(internal_type, "%s") % placeholder | ||||
|                     ) | ||||
|                 # Add column aliases to the first select to avoid "ORA-00918: | ||||
|                 # column ambiguously defined" when two or more columns in the | ||||
|                 # first select have the same value. | ||||
|                 if not query: | ||||
|                     placeholder = "%s col_%s" % (placeholder, i) | ||||
|                 select.append(placeholder) | ||||
|             query.append("SELECT %s FROM DUAL" % ", ".join(select)) | ||||
|         # Bulk inserts into tables with Oracle identity columns cause Oracle | ||||
|         # to add sequence.nextval to the query, and sequence.nextval cannot | ||||
|         # be used with the UNION operator. To prevent incorrect SQL, move | ||||
|         # the UNION into a subquery. | ||||
|         return "SELECT * FROM (%s)" % " UNION ALL ".join(query) | ||||
|  | ||||
|     def subtract_temporals(self, internal_type, lhs, rhs): | ||||
|         if internal_type == "DateField": | ||||
|             lhs_sql, lhs_params = lhs | ||||
|             rhs_sql, rhs_params = rhs | ||||
|             params = (*lhs_params, *rhs_params) | ||||
|             return ( | ||||
|                 "NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql), | ||||
|                 params, | ||||
|             ) | ||||
|         return super().subtract_temporals(internal_type, lhs, rhs) | ||||
|  | ||||
|     def bulk_batch_size(self, fields, objs): | ||||
|         """Oracle restricts the number of parameters in a query.""" | ||||
|         if fields: | ||||
|             return self.connection.features.max_query_params // len(fields) | ||||
|         return len(objs) | ||||
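|  | ||||
|     # Worked example (sketch; 65535 is an assumed limit for illustration, | ||||
|     # not a value read from this backend): with 10 fields per object and | ||||
|     # max_query_params == 65535, batches hold 65535 // 10 == 6553 objects. | ||||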
|  | ||||
|     def conditional_expression_supported_in_where_clause(self, expression): | ||||
|         """ | ||||
|         Oracle supports only EXISTS(...) or filters in the WHERE clause; | ||||
|         other expressions must be compared with True. | ||||
|         """ | ||||
|         if isinstance(expression, (Exists, Lookup, WhereNode)): | ||||
|             return True | ||||
|         if isinstance(expression, ExpressionWrapper) and expression.conditional: | ||||
|             return self.conditional_expression_supported_in_where_clause( | ||||
|                 expression.expression | ||||
|             ) | ||||
|         if isinstance(expression, RawSQL) and expression.conditional: | ||||
|             return True | ||||
|         return False | ||||
| @ -0,0 +1,250 @@ | ||||
| import copy | ||||
| import datetime | ||||
| import re | ||||
|  | ||||
| from django.db import DatabaseError | ||||
| from django.db.backends.base.schema import ( | ||||
|     BaseDatabaseSchemaEditor, | ||||
|     _related_non_m2m_objects, | ||||
| ) | ||||
| from django.utils.duration import duration_iso_string | ||||
|  | ||||
|  | ||||
| class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): | ||||
|     sql_create_column = "ALTER TABLE %(table)s ADD %(column)s %(definition)s" | ||||
|     sql_alter_column_type = "MODIFY %(column)s %(type)s%(collation)s" | ||||
|     sql_alter_column_null = "MODIFY %(column)s NULL" | ||||
|     sql_alter_column_not_null = "MODIFY %(column)s NOT NULL" | ||||
|     sql_alter_column_default = "MODIFY %(column)s DEFAULT %(default)s" | ||||
|     sql_alter_column_no_default = "MODIFY %(column)s DEFAULT NULL" | ||||
|     sql_alter_column_no_default_null = sql_alter_column_no_default | ||||
|  | ||||
|     sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s" | ||||
|     sql_create_column_inline_fk = ( | ||||
|         "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s" | ||||
|     ) | ||||
|     sql_delete_table = "DROP TABLE %(table)s CASCADE CONSTRAINTS" | ||||
|     sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s" | ||||
|  | ||||
|     def quote_value(self, value): | ||||
|         if isinstance(value, (datetime.date, datetime.time, datetime.datetime)): | ||||
|             return "'%s'" % value | ||||
|         elif isinstance(value, datetime.timedelta): | ||||
|             return "'%s'" % duration_iso_string(value) | ||||
|         elif isinstance(value, str): | ||||
|             return "'%s'" % value.replace("'", "''").replace("%", "%%") | ||||
|         elif isinstance(value, (bytes, bytearray, memoryview)): | ||||
|             return "'%s'" % value.hex() | ||||
|         elif isinstance(value, bool): | ||||
|             return "1" if value else "0" | ||||
|         else: | ||||
|             return str(value) | ||||
|  | ||||
|     def remove_field(self, model, field): | ||||
|         # If the column is an identity column, drop the identity before | ||||
|         # removing the field. | ||||
|         if self._is_identity_column(model._meta.db_table, field.column): | ||||
|             self._drop_identity(model._meta.db_table, field.column) | ||||
|         super().remove_field(model, field) | ||||
|  | ||||
|     def delete_model(self, model): | ||||
|         # Run superclass action | ||||
|         super().delete_model(model) | ||||
|         # Clean up manually created sequence. | ||||
|         self.execute( | ||||
|             """ | ||||
|             DECLARE | ||||
|                 i INTEGER; | ||||
|             BEGIN | ||||
|                 SELECT COUNT(1) INTO i FROM USER_SEQUENCES | ||||
|                     WHERE SEQUENCE_NAME = '%(sq_name)s'; | ||||
|                 IF i = 1 THEN | ||||
|                     EXECUTE IMMEDIATE 'DROP SEQUENCE "%(sq_name)s"'; | ||||
|                 END IF; | ||||
|             END; | ||||
|         /""" | ||||
|             % { | ||||
|                 "sq_name": self.connection.ops._get_no_autofield_sequence_name( | ||||
|                     model._meta.db_table | ||||
|                 ) | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     def alter_field(self, model, old_field, new_field, strict=False): | ||||
|         try: | ||||
|             super().alter_field(model, old_field, new_field, strict) | ||||
|         except DatabaseError as e: | ||||
|             description = str(e) | ||||
|             # If we're changing the type to an unsupported type, we need a | ||||
|             # SQLite-ish workaround. | ||||
|             if "ORA-22858" in description or "ORA-22859" in description: | ||||
|                 self._alter_field_type_workaround(model, old_field, new_field) | ||||
|             # If an identity column is changing to a non-numeric type, drop the | ||||
|             # identity first. | ||||
|             elif "ORA-30675" in description: | ||||
|                 self._drop_identity(model._meta.db_table, old_field.column) | ||||
|                 self.alter_field(model, old_field, new_field, strict) | ||||
|             # If a primary key column is changing to an identity column, drop | ||||
|             # the primary key first. | ||||
|             elif "ORA-30673" in description and old_field.primary_key: | ||||
|                 self._delete_primary_key(model, strict=True) | ||||
|                 self._alter_field_type_workaround(model, old_field, new_field) | ||||
|             # If a collation is changing on a primary key, drop the primary key | ||||
|             # first. | ||||
|             elif "ORA-43923" in description and old_field.primary_key: | ||||
|                 self._delete_primary_key(model, strict=True) | ||||
|                 self.alter_field(model, old_field, new_field, strict) | ||||
|                 # Restore a primary key, if needed. | ||||
|                 if new_field.primary_key: | ||||
|                     self.execute(self._create_primary_key_sql(model, new_field)) | ||||
|             else: | ||||
|                 raise | ||||
|  | ||||
|     def _alter_field_type_workaround(self, model, old_field, new_field): | ||||
|         """ | ||||
|         Oracle refuses to change a column directly from some types to | ||||
|         certain other types. What we need to do instead is: | ||||
|         - Add a nullable version of the desired field with a temporary name. If | ||||
|           the new column is an auto field, then the temporary column can't be | ||||
|           nullable. | ||||
|         - Update the table to transfer values from old to new | ||||
|         - Drop old column | ||||
|         - Rename the new column and possibly drop the nullable property | ||||
|         """ | ||||
|         # Make a new field that's like the new one but with a temporary | ||||
|         # column name. | ||||
|         new_temp_field = copy.deepcopy(new_field) | ||||
|         new_temp_field.null = new_field.get_internal_type() not in ( | ||||
|             "AutoField", | ||||
|             "BigAutoField", | ||||
|             "SmallAutoField", | ||||
|         ) | ||||
|         new_temp_field.column = self._generate_temp_name(new_field.column) | ||||
|         # Add it | ||||
|         self.add_field(model, new_temp_field) | ||||
|         # Explicit data type conversion | ||||
|         # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf | ||||
|         # /Data-Type-Comparison-Rules.html#GUID-D0C5A47E-6F93-4C2D-9E49-4F2B86B359DD | ||||
|         new_value = self.quote_name(old_field.column) | ||||
|         old_type = old_field.db_type(self.connection) | ||||
|         if re.match("^N?CLOB", old_type): | ||||
|             new_value = "TO_CHAR(%s)" % new_value | ||||
|             old_type = "VARCHAR2" | ||||
|         if re.match("^N?VARCHAR2", old_type): | ||||
|             new_internal_type = new_field.get_internal_type() | ||||
|             if new_internal_type == "DateField": | ||||
|                 new_value = "TO_DATE(%s, 'YYYY-MM-DD')" % new_value | ||||
|             elif new_internal_type == "DateTimeField": | ||||
|                 new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value | ||||
|             elif new_internal_type == "TimeField": | ||||
|                 # TimeField values are stored as TIMESTAMPs with a | ||||
|                 # 1900-01-01 date part. | ||||
|                 new_value = "CONCAT('1900-01-01 ', %s)" % new_value | ||||
|                 new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value | ||||
|         # Transfer values across | ||||
|         self.execute( | ||||
|             "UPDATE %s set %s=%s" | ||||
|             % ( | ||||
|                 self.quote_name(model._meta.db_table), | ||||
|                 self.quote_name(new_temp_field.column), | ||||
|                 new_value, | ||||
|             ) | ||||
|         ) | ||||
|         # Drop the old field | ||||
|         self.remove_field(model, old_field) | ||||
|         # Rename and possibly make the new field NOT NULL | ||||
|         super().alter_field(model, new_temp_field, new_field) | ||||
|         # Recreate foreign key (if necessary) because the old field is not | ||||
|         # passed to the alter_field() and data types of new_temp_field and | ||||
|         # new_field always match. | ||||
|         new_type = new_field.db_type(self.connection) | ||||
|         if ( | ||||
|             (old_field.primary_key and new_field.primary_key) | ||||
|             or (old_field.unique and new_field.unique) | ||||
|         ) and old_type != new_type: | ||||
|             for _, rel in _related_non_m2m_objects(new_temp_field, new_field): | ||||
|                 if rel.field.db_constraint: | ||||
|                     self.execute( | ||||
|                         self._create_fk_sql(rel.related_model, rel.field, "_fk") | ||||
|                     ) | ||||
|  | ||||
|     def _alter_column_type_sql( | ||||
|         self, model, old_field, new_field, new_type, old_collation, new_collation | ||||
|     ): | ||||
|         auto_field_types = {"AutoField", "BigAutoField", "SmallAutoField"} | ||||
|         # Drop the identity if migrating away from AutoField. | ||||
|         if ( | ||||
|             old_field.get_internal_type() in auto_field_types | ||||
|             and new_field.get_internal_type() not in auto_field_types | ||||
|             and self._is_identity_column(model._meta.db_table, new_field.column) | ||||
|         ): | ||||
|             self._drop_identity(model._meta.db_table, new_field.column) | ||||
|         return super()._alter_column_type_sql( | ||||
|             model, old_field, new_field, new_type, old_collation, new_collation | ||||
|         ) | ||||
|  | ||||
|     def normalize_name(self, name): | ||||
|         """ | ||||
|         Get the properly shortened and uppercased identifier as returned by | ||||
|         quote_name() but without the quotes. | ||||
|         """ | ||||
|         nn = self.quote_name(name) | ||||
|         if nn[0] == '"' and nn[-1] == '"': | ||||
|             nn = nn[1:-1] | ||||
|         return nn | ||||
|  | ||||
|     def _generate_temp_name(self, for_name): | ||||
|         """Generate temporary names for workarounds that need temp columns.""" | ||||
|         suffix = hex(hash(for_name)).upper()[1:] | ||||
|         return self.normalize_name(for_name + "_" + suffix) | ||||
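|  | ||||
|     # Sketch (not part of Django): the temporary column is the original | ||||
|     # name plus a hash-derived hex suffix, e.g. roughly "NAME_X7F3A...". | ||||
|     # Python's str hash is randomized per process, so the suffix is only | ||||
|     # stable within a single run. | ||||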
|  | ||||
|     def prepare_default(self, value): | ||||
|         return self.quote_value(value) | ||||
|  | ||||
|     def _field_should_be_indexed(self, model, field): | ||||
|         create_index = super()._field_should_be_indexed(model, field) | ||||
|         db_type = field.db_type(self.connection) | ||||
|         if ( | ||||
|             db_type is not None | ||||
|             and db_type.lower() in self.connection._limited_data_types | ||||
|         ): | ||||
|             return False | ||||
|         return create_index | ||||
|  | ||||
|     def _is_identity_column(self, table_name, column_name): | ||||
|         with self.connection.cursor() as cursor: | ||||
|             cursor.execute( | ||||
|                 """ | ||||
|                 SELECT | ||||
|                     CASE WHEN identity_column = 'YES' THEN 1 ELSE 0 END | ||||
|                 FROM user_tab_cols | ||||
|                 WHERE table_name = %s AND | ||||
|                       column_name = %s | ||||
|                 """, | ||||
|                 [self.normalize_name(table_name), self.normalize_name(column_name)], | ||||
|             ) | ||||
|             row = cursor.fetchone() | ||||
|             return row[0] if row else False | ||||
|  | ||||
|     def _drop_identity(self, table_name, column_name): | ||||
|         self.execute( | ||||
|             "ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY" | ||||
|             % { | ||||
|                 "table": self.quote_name(table_name), | ||||
|                 "column": self.quote_name(column_name), | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     def _get_default_collation(self, table_name): | ||||
|         with self.connection.cursor() as cursor: | ||||
|             cursor.execute( | ||||
|                 """ | ||||
|                 SELECT default_collation FROM user_tables WHERE table_name = %s | ||||
|                 """, | ||||
|                 [self.normalize_name(table_name)], | ||||
|             ) | ||||
|             return cursor.fetchone()[0] | ||||
|  | ||||
|     def _collate_sql(self, collation, old_collation=None, table_name=None): | ||||
|         if collation is None and old_collation is not None: | ||||
|             collation = self._get_default_collation(table_name) | ||||
|         return super()._collate_sql(collation, old_collation, table_name) | ||||
| @ -0,0 +1,97 @@ | ||||
| import datetime | ||||
|  | ||||
| from .base import Database | ||||
|  | ||||
|  | ||||
| class InsertVar: | ||||
|     """ | ||||
|     A late-binding cursor variable that can be passed to Cursor.execute | ||||
|     as a parameter, in order to receive the id of the row created by an | ||||
|     insert statement. | ||||
|     """ | ||||
|  | ||||
|     types = { | ||||
|         "AutoField": int, | ||||
|         "BigAutoField": int, | ||||
|         "SmallAutoField": int, | ||||
|         "IntegerField": int, | ||||
|         "BigIntegerField": int, | ||||
|         "SmallIntegerField": int, | ||||
|         "PositiveBigIntegerField": int, | ||||
|         "PositiveSmallIntegerField": int, | ||||
|         "PositiveIntegerField": int, | ||||
|         "FloatField": Database.NATIVE_FLOAT, | ||||
|         "DateTimeField": Database.TIMESTAMP, | ||||
|         "DateField": Database.Date, | ||||
|         "DecimalField": Database.NUMBER, | ||||
|     } | ||||
|  | ||||
|     def __init__(self, field): | ||||
|         internal_type = getattr(field, "target_field", field).get_internal_type() | ||||
|         self.db_type = self.types.get(internal_type, str) | ||||
|         self.bound_param = None | ||||
|  | ||||
|     def bind_parameter(self, cursor): | ||||
|         self.bound_param = cursor.cursor.var(self.db_type) | ||||
|         return self.bound_param | ||||
|  | ||||
|     def get_value(self): | ||||
|         return self.bound_param.getvalue() | ||||
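|  | ||||
| # Usage sketch (hypothetical wiring, not code from this module): the | ||||
| # operations module emits "RETURNING ... INTO %s" with InsertVar instances | ||||
| # as parameters; the backend cursor binds each via bind_parameter() so the | ||||
| # generated id can be read back afterwards: | ||||
| # | ||||
| #   iv = InsertVar(model._meta.pk) | ||||
| #   cursor.execute(sql, [*params, iv]) | ||||
| #   new_pk = iv.get_value() | ||||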
|  | ||||
|  | ||||
| class Oracle_datetime(datetime.datetime): | ||||
|     """ | ||||
|     A datetime object, with an additional class attribute | ||||
|     to tell cx_Oracle to save the microseconds too. | ||||
|     """ | ||||
|  | ||||
|     input_size = Database.TIMESTAMP | ||||
|  | ||||
|     @classmethod | ||||
|     def from_datetime(cls, dt): | ||||
|         return Oracle_datetime( | ||||
|             dt.year, | ||||
|             dt.month, | ||||
|             dt.day, | ||||
|             dt.hour, | ||||
|             dt.minute, | ||||
|             dt.second, | ||||
|             dt.microsecond, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class BulkInsertMapper: | ||||
|     BLOB = "TO_BLOB(%s)" | ||||
|     DATE = "TO_DATE(%s)" | ||||
|     INTERVAL = "CAST(%s as INTERVAL DAY(9) TO SECOND(6))" | ||||
|     NCLOB = "TO_NCLOB(%s)" | ||||
|     NUMBER = "TO_NUMBER(%s)" | ||||
|     TIMESTAMP = "TO_TIMESTAMP(%s)" | ||||
|  | ||||
|     types = { | ||||
|         "AutoField": NUMBER, | ||||
|         "BigAutoField": NUMBER, | ||||
|         "BigIntegerField": NUMBER, | ||||
|         "BinaryField": BLOB, | ||||
|         "BooleanField": NUMBER, | ||||
|         "DateField": DATE, | ||||
|         "DateTimeField": TIMESTAMP, | ||||
|         "DecimalField": NUMBER, | ||||
|         "DurationField": INTERVAL, | ||||
|         "FloatField": NUMBER, | ||||
|         "IntegerField": NUMBER, | ||||
|         "PositiveBigIntegerField": NUMBER, | ||||
|         "PositiveIntegerField": NUMBER, | ||||
|         "PositiveSmallIntegerField": NUMBER, | ||||
|         "SmallAutoField": NUMBER, | ||||
|         "SmallIntegerField": NUMBER, | ||||
|         "TextField": NCLOB, | ||||
|         "TimeField": TIMESTAMP, | ||||
|     } | ||||
|  | ||||
|  | ||||
| def dsn(settings_dict): | ||||
|     if settings_dict["PORT"]: | ||||
|         host = settings_dict["HOST"].strip() or "localhost" | ||||
|         return Database.makedsn(host, int(settings_dict["PORT"]), settings_dict["NAME"]) | ||||
|     return settings_dict["NAME"] | ||||
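|  | ||||
| # Sketch (values made up for illustration): | ||||
| # | ||||
| #   dsn({"HOST": "db", "PORT": "1521", "NAME": "xe"}) | ||||
| #   # -> Database.makedsn("db", 1521, "xe"), a full connect descriptor | ||||
| #   dsn({"HOST": "", "PORT": "", "NAME": "prod_alias"}) | ||||
| #   # -> "prod_alias", passed through as a TNS alias or full DSN | ||||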
| @ -0,0 +1,22 @@ | ||||
| from django.core import checks | ||||
| from django.db.backends.base.validation import BaseDatabaseValidation | ||||
|  | ||||
|  | ||||
| class DatabaseValidation(BaseDatabaseValidation): | ||||
|     def check_field_type(self, field, field_type): | ||||
|         """Oracle doesn't support a database index on some data types.""" | ||||
|         errors = [] | ||||
|         if field.db_index and field_type.lower() in self.connection._limited_data_types: | ||||
|             errors.append( | ||||
|                 checks.Warning( | ||||
|                     "Oracle does not support a database index on %s columns." | ||||
|                     % field_type, | ||||
|                     hint=( | ||||
|                         "An index won't be created. Silence this warning if " | ||||
|                         "you don't care about it." | ||||
|                     ), | ||||
|                     obj=field, | ||||
|                     id="fields.W162", | ||||
|                 ) | ||||
|             ) | ||||
|         return errors | ||||
| @ -0,0 +1,487 @@ | ||||
| """ | ||||
| PostgreSQL database backend for Django. | ||||
|  | ||||
| Requires psycopg2 >= 2.8.4 or psycopg >= 3.1.8 | ||||
| """ | ||||
|  | ||||
| import asyncio | ||||
| import threading | ||||
| import warnings | ||||
| from contextlib import contextmanager | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.core.exceptions import ImproperlyConfigured | ||||
| from django.db import DatabaseError as WrappedDatabaseError | ||||
| from django.db import connections | ||||
| from django.db.backends.base.base import BaseDatabaseWrapper | ||||
| from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper | ||||
| from django.utils.asyncio import async_unsafe | ||||
| from django.utils.functional import cached_property | ||||
| from django.utils.safestring import SafeString | ||||
| from django.utils.version import get_version_tuple | ||||
|  | ||||
| try: | ||||
|     try: | ||||
|         import psycopg as Database | ||||
|     except ImportError: | ||||
|         import psycopg2 as Database | ||||
| except ImportError: | ||||
|     raise ImproperlyConfigured("Error loading psycopg2 or psycopg module") | ||||
|  | ||||
|  | ||||
| def psycopg_version(): | ||||
|     version = Database.__version__.split(" ", 1)[0] | ||||
|     return get_version_tuple(version) | ||||
|  | ||||
|  | ||||
| if psycopg_version() < (2, 8, 4): | ||||
|     raise ImproperlyConfigured( | ||||
|         f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}" | ||||
|     ) | ||||
| if (3,) <= psycopg_version() < (3, 1, 8): | ||||
|     raise ImproperlyConfigured( | ||||
|         f"psycopg version 3.1.8 or newer is required; you have {Database.__version__}" | ||||
|     ) | ||||
|  | ||||
|  | ||||
| from .psycopg_any import IsolationLevel, is_psycopg3  # NOQA isort:skip | ||||
|  | ||||
| if is_psycopg3: | ||||
|     from psycopg import adapters, sql | ||||
|     from psycopg.pq import Format | ||||
|  | ||||
|     from .psycopg_any import get_adapters_template, register_tzloader | ||||
|  | ||||
|     TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid | ||||
|  | ||||
| else: | ||||
|     import psycopg2.extensions | ||||
|     import psycopg2.extras | ||||
|  | ||||
|     psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString) | ||||
|     psycopg2.extras.register_uuid() | ||||
|  | ||||
|     # Register support for inet[] manually so we don't have to handle the Inet() | ||||
|     # object on load all the time. | ||||
|     INETARRAY_OID = 1041 | ||||
|     INETARRAY = psycopg2.extensions.new_array_type( | ||||
|         (INETARRAY_OID,), | ||||
|         "INETARRAY", | ||||
|         psycopg2.extensions.UNICODE, | ||||
|     ) | ||||
|     psycopg2.extensions.register_type(INETARRAY) | ||||
|  | ||||
| # Some of these import psycopg, so import them after checking if it's installed. | ||||
| from .client import DatabaseClient  # NOQA isort:skip | ||||
| from .creation import DatabaseCreation  # NOQA isort:skip | ||||
| from .features import DatabaseFeatures  # NOQA isort:skip | ||||
| from .introspection import DatabaseIntrospection  # NOQA isort:skip | ||||
| from .operations import DatabaseOperations  # NOQA isort:skip | ||||
| from .schema import DatabaseSchemaEditor  # NOQA isort:skip | ||||
|  | ||||
|  | ||||
| def _get_varchar_column(data): | ||||
|     if data["max_length"] is None: | ||||
|         return "varchar" | ||||
|     return "varchar(%(max_length)s)" % data | ||||
|  | ||||
|  | ||||
| class DatabaseWrapper(BaseDatabaseWrapper): | ||||
|     vendor = "postgresql" | ||||
|     display_name = "PostgreSQL" | ||||
|     # This dictionary maps Field objects to their associated PostgreSQL column | ||||
|     # types, as strings. Column-type strings can contain format strings; they'll | ||||
|     # be interpolated against the values of Field.__dict__ before being output. | ||||
|     # If a column type is set to None, it won't be included in the output. | ||||
|     data_types = { | ||||
|         "AutoField": "integer", | ||||
|         "BigAutoField": "bigint", | ||||
|         "BinaryField": "bytea", | ||||
|         "BooleanField": "boolean", | ||||
|         "CharField": _get_varchar_column, | ||||
|         "DateField": "date", | ||||
|         "DateTimeField": "timestamp with time zone", | ||||
|         "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)", | ||||
|         "DurationField": "interval", | ||||
|         "FileField": "varchar(%(max_length)s)", | ||||
|         "FilePathField": "varchar(%(max_length)s)", | ||||
|         "FloatField": "double precision", | ||||
|         "IntegerField": "integer", | ||||
|         "BigIntegerField": "bigint", | ||||
|         "IPAddressField": "inet", | ||||
|         "GenericIPAddressField": "inet", | ||||
|         "JSONField": "jsonb", | ||||
|         "OneToOneField": "integer", | ||||
|         "PositiveBigIntegerField": "bigint", | ||||
|         "PositiveIntegerField": "integer", | ||||
|         "PositiveSmallIntegerField": "smallint", | ||||
|         "SlugField": "varchar(%(max_length)s)", | ||||
|         "SmallAutoField": "smallint", | ||||
|         "SmallIntegerField": "smallint", | ||||
|         "TextField": "text", | ||||
|         "TimeField": "time", | ||||
|         "UUIDField": "uuid", | ||||
|     } | ||||
|     data_type_check_constraints = { | ||||
|         "PositiveBigIntegerField": '"%(column)s" >= 0', | ||||
|         "PositiveIntegerField": '"%(column)s" >= 0', | ||||
|         "PositiveSmallIntegerField": '"%(column)s" >= 0', | ||||
|     } | ||||
|     data_types_suffix = { | ||||
|         "AutoField": "GENERATED BY DEFAULT AS IDENTITY", | ||||
|         "BigAutoField": "GENERATED BY DEFAULT AS IDENTITY", | ||||
|         "SmallAutoField": "GENERATED BY DEFAULT AS IDENTITY", | ||||
|     } | ||||
|     operators = { | ||||
|         "exact": "= %s", | ||||
|         "iexact": "= UPPER(%s)", | ||||
|         "contains": "LIKE %s", | ||||
|         "icontains": "LIKE UPPER(%s)", | ||||
|         "regex": "~ %s", | ||||
|         "iregex": "~* %s", | ||||
|         "gt": "> %s", | ||||
|         "gte": ">= %s", | ||||
|         "lt": "< %s", | ||||
|         "lte": "<= %s", | ||||
|         "startswith": "LIKE %s", | ||||
|         "endswith": "LIKE %s", | ||||
|         "istartswith": "LIKE UPPER(%s)", | ||||
|         "iendswith": "LIKE UPPER(%s)", | ||||
|     } | ||||
|  | ||||
|     # The patterns below are used to generate SQL pattern lookup clauses when | ||||
|     # the right-hand side of the lookup isn't a raw string (it might be an expression | ||||
|     # or the result of a bilateral transformation). | ||||
|     # In those cases, special characters for LIKE operators (e.g. \, *, _) should be | ||||
|     # escaped on database side. | ||||
|     # | ||||
|     # Note: we use str.format() here for readability as '%' is used as a wildcard for | ||||
|     # the LIKE operator. | ||||
|     pattern_esc = ( | ||||
|         r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')" | ||||
|     ) | ||||
|     pattern_ops = { | ||||
|         "contains": "LIKE '%%' || {} || '%%'", | ||||
|         "icontains": "LIKE '%%' || UPPER({}) || '%%'", | ||||
|         "startswith": "LIKE {} || '%%'", | ||||
|         "istartswith": "LIKE UPPER({}) || '%%'", | ||||
|         "endswith": "LIKE '%%' || {}", | ||||
|         "iendswith": "LIKE '%%' || UPPER({})", | ||||
|     } | ||||
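|  | ||||
|     # Illustrative sketch (not part of Django): when the right-hand side of | ||||
|     # a contains lookup is an expression, it's first wrapped in pattern_esc | ||||
|     # to escape LIKE metacharacters, then interpolated into | ||||
|     # pattern_ops["contains"], yielding roughly: | ||||
|     # | ||||
|     #   LIKE '%' || REPLACE(REPLACE(REPLACE(rhs, E'\\', E'\\\\'), | ||||
|     #                       E'%', E'\\%'), E'_', E'\\_') || '%' | ||||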
|  | ||||
|     Database = Database | ||||
|     SchemaEditorClass = DatabaseSchemaEditor | ||||
|     # Classes instantiated in __init__(). | ||||
|     client_class = DatabaseClient | ||||
|     creation_class = DatabaseCreation | ||||
|     features_class = DatabaseFeatures | ||||
|     introspection_class = DatabaseIntrospection | ||||
|     ops_class = DatabaseOperations | ||||
|     # PostgreSQL backend-specific attributes. | ||||
|     _named_cursor_idx = 0 | ||||
|  | ||||
|     def get_database_version(self): | ||||
|         """ | ||||
|         Return a tuple of the database's version. | ||||
|         E.g. for pg_version 120004, return (12, 4). | ||||
|         """ | ||||
|         return divmod(self.pg_version, 10000) | ||||
|  | ||||
|     def get_connection_params(self): | ||||
|         settings_dict = self.settings_dict | ||||
|         # None may be used to connect to the default 'postgres' db | ||||
|         if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get( | ||||
|             "service" | ||||
|         ): | ||||
|             raise ImproperlyConfigured( | ||||
|                 "settings.DATABASES is improperly configured. " | ||||
|                 "Please supply the NAME or OPTIONS['service'] value." | ||||
|             ) | ||||
|         if len(settings_dict["NAME"] or "") > self.ops.max_name_length(): | ||||
|             raise ImproperlyConfigured( | ||||
|                 "The database name '%s' (%d characters) is longer than " | ||||
|                 "PostgreSQL's limit of %d characters. Supply a shorter NAME " | ||||
|                 "in settings.DATABASES." | ||||
|                 % ( | ||||
|                     settings_dict["NAME"], | ||||
|                     len(settings_dict["NAME"]), | ||||
|                     self.ops.max_name_length(), | ||||
|                 ) | ||||
|             ) | ||||
|         if settings_dict["NAME"]: | ||||
|             conn_params = { | ||||
|                 "dbname": settings_dict["NAME"], | ||||
|                 **settings_dict["OPTIONS"], | ||||
|             } | ||||
|         elif settings_dict["NAME"] is None: | ||||
|             # Connect to the default 'postgres' db. | ||||
|             settings_dict.get("OPTIONS", {}).pop("service", None) | ||||
|             conn_params = {"dbname": "postgres", **settings_dict["OPTIONS"]} | ||||
|         else: | ||||
|             conn_params = {**settings_dict["OPTIONS"]} | ||||
|         conn_params["client_encoding"] = "UTF8" | ||||
|  | ||||
|         conn_params.pop("assume_role", None) | ||||
|         conn_params.pop("isolation_level", None) | ||||
|         server_side_binding = conn_params.pop("server_side_binding", None) | ||||
|         conn_params.setdefault( | ||||
|             "cursor_factory", | ||||
|             ServerBindingCursor | ||||
|             if is_psycopg3 and server_side_binding is True | ||||
|             else Cursor, | ||||
|         ) | ||||
|         if settings_dict["USER"]: | ||||
|             conn_params["user"] = settings_dict["USER"] | ||||
|         if settings_dict["PASSWORD"]: | ||||
|             conn_params["password"] = settings_dict["PASSWORD"] | ||||
|         if settings_dict["HOST"]: | ||||
|             conn_params["host"] = settings_dict["HOST"] | ||||
|         if settings_dict["PORT"]: | ||||
|             conn_params["port"] = settings_dict["PORT"] | ||||
|         if is_psycopg3: | ||||
|             conn_params["context"] = get_adapters_template( | ||||
|                 settings.USE_TZ, self.timezone | ||||
|             ) | ||||
|             # Disable prepared statements by default to keep connection poolers | ||||
|             # working. Can be reenabled via OPTIONS in the settings dict. | ||||
|             conn_params["prepare_threshold"] = conn_params.pop( | ||||
|                 "prepare_threshold", None | ||||
|             ) | ||||
|         return conn_params | ||||
|  | ||||
|     @async_unsafe | ||||
|     def get_new_connection(self, conn_params): | ||||
|         # self.isolation_level must be set: | ||||
|         # - after connecting to the database in order to obtain the database's | ||||
|         #   default when no value is explicitly specified in options. | ||||
|         # - before calling _set_autocommit() because if autocommit is on, that | ||||
|         #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT. | ||||
|         options = self.settings_dict["OPTIONS"] | ||||
|         set_isolation_level = False | ||||
|         try: | ||||
|             isolation_level_value = options["isolation_level"] | ||||
|         except KeyError: | ||||
|             self.isolation_level = IsolationLevel.READ_COMMITTED | ||||
|         else: | ||||
|             # Set the isolation level to the value from OPTIONS. | ||||
|             try: | ||||
|                 self.isolation_level = IsolationLevel(isolation_level_value) | ||||
|                 set_isolation_level = True | ||||
|             except ValueError: | ||||
|                 raise ImproperlyConfigured( | ||||
|                     f"Invalid transaction isolation level {isolation_level_value} " | ||||
|                     f"specified. Use one of the psycopg.IsolationLevel values." | ||||
|                 ) | ||||
|         connection = self.Database.connect(**conn_params) | ||||
|         if set_isolation_level: | ||||
|             connection.isolation_level = self.isolation_level | ||||
|         if not is_psycopg3: | ||||
|             # Register dummy loads() to avoid a round trip from psycopg2's | ||||
|             # decode to json.dumps() to json.loads(), when using a custom | ||||
|             # decoder in JSONField. | ||||
|             psycopg2.extras.register_default_jsonb( | ||||
|                 conn_or_curs=connection, loads=lambda x: x | ||||
|             ) | ||||
|         return connection | ||||
|  | ||||
|     def ensure_timezone(self): | ||||
|         if self.connection is None: | ||||
|             return False | ||||
|         conn_timezone_name = self.connection.info.parameter_status("TimeZone") | ||||
|         timezone_name = self.timezone_name | ||||
|         if timezone_name and conn_timezone_name != timezone_name: | ||||
|             with self.connection.cursor() as cursor: | ||||
|                 cursor.execute(self.ops.set_time_zone_sql(), [timezone_name]) | ||||
|             return True | ||||
|         return False | ||||
|  | ||||
|     def ensure_role(self): | ||||
|         if self.connection is None: | ||||
|             return False | ||||
|         if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"): | ||||
|             with self.connection.cursor() as cursor: | ||||
|                 sql = self.ops.compose_sql("SET ROLE %s", [new_role]) | ||||
|                 cursor.execute(sql) | ||||
|             return True | ||||
|         return False | ||||
|  | ||||
|     def init_connection_state(self): | ||||
|         super().init_connection_state() | ||||
|  | ||||
|         # Commit after setting the time zone. | ||||
|         commit_tz = self.ensure_timezone() | ||||
|         # Set the role on the connection. This is useful if the credential | ||||
|         # used to log in is not the same as the role that owns database | ||||
|         # resources, as can be the case when using temporary or ephemeral | ||||
|         # credentials. | ||||
|         commit_role = self.ensure_role() | ||||
|  | ||||
|         if (commit_role or commit_tz) and not self.get_autocommit(): | ||||
|             self.connection.commit() | ||||
|  | ||||
|     @async_unsafe | ||||
|     def create_cursor(self, name=None): | ||||
|         if name: | ||||
|             # In autocommit mode, the cursor will be used outside of a | ||||
|             # transaction, hence use a holdable cursor. | ||||
|             cursor = self.connection.cursor( | ||||
|                 name, scrollable=False, withhold=self.connection.autocommit | ||||
|             ) | ||||
|         else: | ||||
|             cursor = self.connection.cursor() | ||||
|  | ||||
|         if is_psycopg3: | ||||
|             # Register the cursor timezone only if the connection disagrees, to | ||||
|             # avoid copying the adapter map. | ||||
|             tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT) | ||||
|             if self.timezone != tzloader.timezone: | ||||
|                 register_tzloader(self.timezone, cursor) | ||||
|         else: | ||||
|             cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None | ||||
|         return cursor | ||||
|  | ||||
|     def tzinfo_factory(self, offset): | ||||
|         return self.timezone | ||||
|  | ||||
|     @async_unsafe | ||||
|     def chunked_cursor(self): | ||||
|         self._named_cursor_idx += 1 | ||||
|         # Get the current async task | ||||
|         # Note that right now this is behind @async_unsafe, so this is | ||||
|         # unreachable, but in future we'll start loosening this restriction. | ||||
|         # For now, it's here so that every use of "threading" is | ||||
|         # also async-compatible. | ||||
|         try: | ||||
|             current_task = asyncio.current_task() | ||||
|         except RuntimeError: | ||||
|             current_task = None | ||||
|         # The current task can be None even if current_task() didn't error. | ||||
|         if current_task: | ||||
|             task_ident = str(id(current_task)) | ||||
|         else: | ||||
|             task_ident = "sync" | ||||
|         # Use that and the thread ident to get a unique name | ||||
|         return self._cursor( | ||||
|             name="_django_curs_%d_%s_%d" | ||||
|             % ( | ||||
|                 # Avoid reusing name in other threads / tasks | ||||
|                 threading.current_thread().ident, | ||||
|                 task_ident, | ||||
|                 self._named_cursor_idx, | ||||
|             ) | ||||
|         ) | ||||
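|  | ||||
|     # Sketch: the generated server-side cursor names look like | ||||
|     # "_django_curs_140113774352128_sync_1" (thread ident, task ident or | ||||
|     # "sync", and a per-connection counter), keeping names unique across | ||||
|     # threads, async tasks, and repeated calls. | ||||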
|  | ||||
|     def _set_autocommit(self, autocommit): | ||||
|         with self.wrap_database_errors: | ||||
|             self.connection.autocommit = autocommit | ||||
|  | ||||
|     def check_constraints(self, table_names=None): | ||||
|         """ | ||||
|         Check constraints by setting them to immediate. Return them to deferred | ||||
|         afterward. | ||||
|         """ | ||||
|         with self.cursor() as cursor: | ||||
|             cursor.execute("SET CONSTRAINTS ALL IMMEDIATE") | ||||
|             cursor.execute("SET CONSTRAINTS ALL DEFERRED") | ||||
|  | ||||
|     def is_usable(self): | ||||
|         try: | ||||
|             # Use a psycopg cursor directly, bypassing Django's utilities. | ||||
|             with self.connection.cursor() as cursor: | ||||
|                 cursor.execute("SELECT 1") | ||||
|         except Database.Error: | ||||
|             return False | ||||
|         else: | ||||
|             return True | ||||
|  | ||||
|     @contextmanager | ||||
|     def _nodb_cursor(self): | ||||
|         cursor = None | ||||
|         try: | ||||
|             with super()._nodb_cursor() as cursor: | ||||
|                 yield cursor | ||||
|         except (Database.DatabaseError, WrappedDatabaseError): | ||||
|             if cursor is not None: | ||||
|                 raise | ||||
|             warnings.warn( | ||||
|                 "Normally Django will use a connection to the 'postgres' database " | ||||
|                 "to avoid running initialization queries against the production " | ||||
|                 "database when it's not needed (for example, when running tests). " | ||||
|                 "Django was unable to create a connection to the 'postgres' database " | ||||
|                 "and will use the first PostgreSQL database instead.", | ||||
|                 RuntimeWarning, | ||||
|             ) | ||||
|             for connection in connections.all(): | ||||
|                 if ( | ||||
|                     connection.vendor == "postgresql" | ||||
|                     and connection.settings_dict["NAME"] != "postgres" | ||||
|                 ): | ||||
|                     conn = self.__class__( | ||||
|                         { | ||||
|                             **self.settings_dict, | ||||
|                             "NAME": connection.settings_dict["NAME"], | ||||
|                         }, | ||||
|                         alias=self.alias, | ||||
|                     ) | ||||
|                     try: | ||||
|                         with conn.cursor() as cursor: | ||||
|                             yield cursor | ||||
|                     finally: | ||||
|                         conn.close() | ||||
|                     break | ||||
|             else: | ||||
|                 raise | ||||
|  | ||||
|     @cached_property | ||||
|     def pg_version(self): | ||||
|         with self.temporary_connection(): | ||||
|             return self.connection.info.server_version | ||||
|  | ||||
|     def make_debug_cursor(self, cursor): | ||||
|         return CursorDebugWrapper(cursor, self) | ||||
|  | ||||
|  | ||||
| if is_psycopg3: | ||||
|  | ||||
|     class CursorMixin: | ||||
|         """ | ||||
|         A subclass of psycopg cursor implementing callproc. | ||||
|         """ | ||||
|  | ||||
|         def callproc(self, name, args=None): | ||||
|             if not isinstance(name, sql.Identifier): | ||||
|                 name = sql.Identifier(name) | ||||
|  | ||||
|             qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")] | ||||
|             if args: | ||||
|                 for item in args: | ||||
|                     qparts.append(sql.Literal(item)) | ||||
|                     qparts.append(sql.SQL(",")) | ||||
|                 del qparts[-1] | ||||
|  | ||||
|             qparts.append(sql.SQL(")")) | ||||
|             stmt = sql.Composed(qparts) | ||||
|             self.execute(stmt) | ||||
|             return args | ||||
|  | ||||
|     class ServerBindingCursor(CursorMixin, Database.Cursor): | ||||
|         pass | ||||
|  | ||||
|     class Cursor(CursorMixin, Database.ClientCursor): | ||||
|         pass | ||||
|  | ||||
|     class CursorDebugWrapper(BaseCursorDebugWrapper): | ||||
|         def copy(self, statement): | ||||
|             with self.debug_sql(statement): | ||||
|                 return self.cursor.copy(statement) | ||||
|  | ||||
| else: | ||||
|     Cursor = psycopg2.extensions.cursor | ||||
|  | ||||
|     class CursorDebugWrapper(BaseCursorDebugWrapper): | ||||
|         def copy_expert(self, sql, file, *args): | ||||
|             with self.debug_sql(sql): | ||||
|                 return self.cursor.copy_expert(sql, file, *args) | ||||
|  | ||||
|         def copy_to(self, file, table, *args, **kwargs): | ||||
|             with self.debug_sql(sql="COPY %s TO STDOUT" % table): | ||||
|                 return self.cursor.copy_to(file, table, *args, **kwargs) | ||||
| @ -0,0 +1,64 @@ | ||||
| import signal | ||||
|  | ||||
| from django.db.backends.base.client import BaseDatabaseClient | ||||
|  | ||||
|  | ||||
| class DatabaseClient(BaseDatabaseClient): | ||||
|     executable_name = "psql" | ||||
|  | ||||
|     @classmethod | ||||
|     def settings_to_cmd_args_env(cls, settings_dict, parameters): | ||||
|         args = [cls.executable_name] | ||||
|         options = settings_dict.get("OPTIONS", {}) | ||||
|  | ||||
|         host = settings_dict.get("HOST") | ||||
|         port = settings_dict.get("PORT") | ||||
|         dbname = settings_dict.get("NAME") | ||||
|         user = settings_dict.get("USER") | ||||
|         passwd = settings_dict.get("PASSWORD") | ||||
|         passfile = options.get("passfile") | ||||
|         service = options.get("service") | ||||
|         sslmode = options.get("sslmode") | ||||
|         sslrootcert = options.get("sslrootcert") | ||||
|         sslcert = options.get("sslcert") | ||||
|         sslkey = options.get("sslkey") | ||||
|  | ||||
|         if not dbname and not service: | ||||
|             # Connect to the default 'postgres' db. | ||||
|             dbname = "postgres" | ||||
|         if user: | ||||
|             args += ["-U", user] | ||||
|         if host: | ||||
|             args += ["-h", host] | ||||
|         if port: | ||||
|             args += ["-p", str(port)] | ||||
|         args.extend(parameters) | ||||
|         if dbname: | ||||
|             args += [dbname] | ||||
|  | ||||
|         env = {} | ||||
|         if passwd: | ||||
|             env["PGPASSWORD"] = str(passwd) | ||||
|         if service: | ||||
|             env["PGSERVICE"] = str(service) | ||||
|         if sslmode: | ||||
|             env["PGSSLMODE"] = str(sslmode) | ||||
|         if sslrootcert: | ||||
|             env["PGSSLROOTCERT"] = str(sslrootcert) | ||||
|         if sslcert: | ||||
|             env["PGSSLCERT"] = str(sslcert) | ||||
|         if sslkey: | ||||
|             env["PGSSLKEY"] = str(sslkey) | ||||
|         if passfile: | ||||
|             env["PGPASSFILE"] = str(passfile) | ||||
|         return args, (env or None) | ||||
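|     # For illustration with hypothetical values: NAME="app", USER="u", | ||||
|     # HOST="db", PORT=5432, and PASSWORD="s3cret" yield | ||||
|     # args ["psql", "-U", "u", "-h", "db", "-p", "5432", "app"] and | ||||
|     # env {"PGPASSWORD": "s3cret"}. | ||||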
|  | ||||
|     def runshell(self, parameters): | ||||
|         sigint_handler = signal.getsignal(signal.SIGINT) | ||||
|         try: | ||||
|             # Allow SIGINT to pass to psql to abort queries. | ||||
|             signal.signal(signal.SIGINT, signal.SIG_IGN) | ||||
|             super().runshell(parameters) | ||||
|         finally: | ||||
|             # Restore the original SIGINT handler. | ||||
|             signal.signal(signal.SIGINT, sigint_handler) | ||||
| @ -0,0 +1,86 @@ | ||||
| import sys | ||||
|  | ||||
| from django.core.exceptions import ImproperlyConfigured | ||||
| from django.db.backends.base.creation import BaseDatabaseCreation | ||||
| from django.db.backends.postgresql.psycopg_any import errors | ||||
| from django.db.backends.utils import strip_quotes | ||||
|  | ||||
|  | ||||
| class DatabaseCreation(BaseDatabaseCreation): | ||||
|     def _quote_name(self, name): | ||||
|         return self.connection.ops.quote_name(name) | ||||
|  | ||||
|     def _get_database_create_suffix(self, encoding=None, template=None): | ||||
|         suffix = "" | ||||
|         if encoding: | ||||
|             suffix += " ENCODING '{}'".format(encoding) | ||||
|         if template: | ||||
|             suffix += " TEMPLATE {}".format(self._quote_name(template)) | ||||
|         return suffix and "WITH" + suffix | ||||
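|     # Since suffix starts with a space, e.g. encoding="UTF8" and a | ||||
|     # hypothetical template="base" produce: WITH ENCODING 'UTF8' TEMPLATE "base" | ||||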
|  | ||||
|     def sql_table_creation_suffix(self): | ||||
|         test_settings = self.connection.settings_dict["TEST"] | ||||
|         if test_settings.get("COLLATION") is not None: | ||||
|             raise ImproperlyConfigured( | ||||
|                 "PostgreSQL does not support collation setting at database " | ||||
|                 "creation time." | ||||
|             ) | ||||
|         return self._get_database_create_suffix( | ||||
|             encoding=test_settings["CHARSET"], | ||||
|             template=test_settings.get("TEMPLATE"), | ||||
|         ) | ||||
|  | ||||
|     def _database_exists(self, cursor, database_name): | ||||
|         cursor.execute( | ||||
|             "SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", | ||||
|             [strip_quotes(database_name)], | ||||
|         ) | ||||
|         return cursor.fetchone() is not None | ||||
|  | ||||
|     def _execute_create_test_db(self, cursor, parameters, keepdb=False): | ||||
|         try: | ||||
|             if keepdb and self._database_exists(cursor, parameters["dbname"]): | ||||
|                 # If the database should be kept and it already exists, don't | ||||
|                 # try to create a new one. | ||||
|                 return | ||||
|             super()._execute_create_test_db(cursor, parameters, keepdb) | ||||
|         except Exception as e: | ||||
|             if not isinstance(e.__cause__, errors.DuplicateDatabase): | ||||
|                 # All errors except "database already exists" cancel tests. | ||||
|                 self.log("Got an error creating the test database: %s" % e) | ||||
|                 sys.exit(2) | ||||
|             elif not keepdb: | ||||
|                 # Re-raise "database already exists" when keepdb isn't set; | ||||
|                 # with keepdb, the existing database is simply reused. | ||||
|                 raise | ||||
|  | ||||
|     def _clone_test_db(self, suffix, verbosity, keepdb=False): | ||||
|         # CREATE DATABASE ... WITH TEMPLATE ... requires closing connections | ||||
|         # to the template database. | ||||
|         self.connection.close() | ||||
|  | ||||
|         source_database_name = self.connection.settings_dict["NAME"] | ||||
|         target_database_name = self.get_test_db_clone_settings(suffix)["NAME"] | ||||
|         test_db_params = { | ||||
|             "dbname": self._quote_name(target_database_name), | ||||
|             "suffix": self._get_database_create_suffix(template=source_database_name), | ||||
|         } | ||||
|         with self._nodb_cursor() as cursor: | ||||
|             try: | ||||
|                 self._execute_create_test_db(cursor, test_db_params, keepdb) | ||||
|             except Exception: | ||||
|                 try: | ||||
|                     if verbosity >= 1: | ||||
|                         self.log( | ||||
|                             "Destroying old test database for alias %s..." | ||||
|                             % ( | ||||
|                                 self._get_database_display_str( | ||||
|                                     verbosity, target_database_name | ||||
|                                 ), | ||||
|                             ) | ||||
|                         ) | ||||
|                     cursor.execute("DROP DATABASE %(dbname)s" % test_db_params) | ||||
|                     self._execute_create_test_db(cursor, test_db_params, keepdb) | ||||
|                 except Exception as e: | ||||
|                     self.log("Got an error cloning the test database: %s" % e) | ||||
|                     sys.exit(2) | ||||
| @ -0,0 +1,136 @@ | ||||
| import operator | ||||
|  | ||||
| from django.db import DataError, InterfaceError | ||||
| from django.db.backends.base.features import BaseDatabaseFeatures | ||||
| from django.db.backends.postgresql.psycopg_any import is_psycopg3 | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
|  | ||||
| class DatabaseFeatures(BaseDatabaseFeatures): | ||||
|     minimum_database_version = (12,) | ||||
|     allows_group_by_selected_pks = True | ||||
|     can_return_columns_from_insert = True | ||||
|     can_return_rows_from_bulk_insert = True | ||||
|     has_real_datatype = True | ||||
|     has_native_uuid_field = True | ||||
|     has_native_duration_field = True | ||||
|     has_native_json_field = True | ||||
|     can_defer_constraint_checks = True | ||||
|     has_select_for_update = True | ||||
|     has_select_for_update_nowait = True | ||||
|     has_select_for_update_of = True | ||||
|     has_select_for_update_skip_locked = True | ||||
|     has_select_for_no_key_update = True | ||||
|     can_release_savepoints = True | ||||
|     supports_comments = True | ||||
|     supports_tablespaces = True | ||||
|     supports_transactions = True | ||||
|     can_introspect_materialized_views = True | ||||
|     can_distinct_on_fields = True | ||||
|     can_rollback_ddl = True | ||||
|     schema_editor_uses_clientside_param_binding = True | ||||
|     supports_combined_alters = True | ||||
|     nulls_order_largest = True | ||||
|     closed_cursor_error_class = InterfaceError | ||||
|     greatest_least_ignores_nulls = True | ||||
|     can_clone_databases = True | ||||
|     supports_temporal_subtraction = True | ||||
|     supports_slicing_ordering_in_compound = True | ||||
|     create_test_procedure_without_params_sql = """ | ||||
|         CREATE FUNCTION test_procedure () RETURNS void AS $$ | ||||
|         DECLARE | ||||
|             V_I INTEGER; | ||||
|         BEGIN | ||||
|             V_I := 1; | ||||
|         END; | ||||
|     $$ LANGUAGE plpgsql;""" | ||||
|     create_test_procedure_with_int_param_sql = """ | ||||
|         CREATE FUNCTION test_procedure (P_I INTEGER) RETURNS void AS $$ | ||||
|         DECLARE | ||||
|             V_I INTEGER; | ||||
|         BEGIN | ||||
|             V_I := P_I; | ||||
|         END; | ||||
|     $$ LANGUAGE plpgsql;""" | ||||
|     create_test_table_with_composite_primary_key = """ | ||||
|         CREATE TABLE test_table_composite_pk ( | ||||
|             column_1 INTEGER NOT NULL, | ||||
|             column_2 INTEGER NOT NULL, | ||||
|             PRIMARY KEY(column_1, column_2) | ||||
|         ) | ||||
|     """ | ||||
|     requires_casted_case_in_updates = True | ||||
|     supports_over_clause = True | ||||
|     only_supports_unbounded_with_preceding_and_following = True | ||||
|     supports_aggregate_filter_clause = True | ||||
|     supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"} | ||||
|     supports_deferrable_unique_constraints = True | ||||
|     has_json_operators = True | ||||
|     json_key_contains_list_matching_requires_list = True | ||||
|     supports_update_conflicts = True | ||||
|     supports_update_conflicts_with_target = True | ||||
|     supports_covering_indexes = True | ||||
|     can_rename_index = True | ||||
|     test_collations = { | ||||
|         "non_default": "sv-x-icu", | ||||
|         "swedish_ci": "sv-x-icu", | ||||
|     } | ||||
|     test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'" | ||||
|  | ||||
|     django_test_skips = { | ||||
|         "opclasses are PostgreSQL only.": { | ||||
|             "indexes.tests.SchemaIndexesNotPostgreSQLTests." | ||||
|             "test_create_index_ignores_opclasses", | ||||
|         }, | ||||
|         "PostgreSQL requires casting to text.": { | ||||
|             "lookup.tests.LookupTests.test_textfield_exact_null", | ||||
|         }, | ||||
|     } | ||||
|  | ||||
|     @cached_property | ||||
|     def django_test_expected_failures(self): | ||||
|         expected_failures = set() | ||||
|         if self.uses_server_side_binding: | ||||
|             expected_failures.update( | ||||
|                 { | ||||
|                     # Parameters passed to expressions in SELECT and GROUP BY | ||||
|                     # clauses are not recognized as the same values when using | ||||
|                     # server-side binding cursors (#34255). | ||||
|                     "aggregation.tests.AggregateTestCase." | ||||
|                     "test_group_by_nested_expression_with_params", | ||||
|                 } | ||||
|             ) | ||||
|         return expected_failures | ||||
|  | ||||
|     @cached_property | ||||
|     def uses_server_side_binding(self): | ||||
|         options = self.connection.settings_dict["OPTIONS"] | ||||
|         return is_psycopg3 and options.get("server_side_binding") is True | ||||
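|     # For illustration: with psycopg 3 installed, putting | ||||
|     # OPTIONS = {"server_side_binding": True} in the DATABASES entry enables | ||||
|     # server-side parameter binding (the ServerBindingCursor class). | ||||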
|  | ||||
|     @cached_property | ||||
|     def prohibits_null_characters_in_text_exception(self): | ||||
|         if is_psycopg3: | ||||
|             return DataError, "PostgreSQL text fields cannot contain NUL (0x00) bytes" | ||||
|         else: | ||||
|             return ValueError, "A string literal cannot contain NUL (0x00) characters." | ||||
|  | ||||
|     @cached_property | ||||
|     def introspected_field_types(self): | ||||
|         return { | ||||
|             **super().introspected_field_types, | ||||
|             "PositiveBigIntegerField": "BigIntegerField", | ||||
|             "PositiveIntegerField": "IntegerField", | ||||
|             "PositiveSmallIntegerField": "SmallIntegerField", | ||||
|         } | ||||
|  | ||||
|     @cached_property | ||||
|     def is_postgresql_13(self): | ||||
|         return self.connection.pg_version >= 130000 | ||||
|  | ||||
|     @cached_property | ||||
|     def is_postgresql_14(self): | ||||
|         return self.connection.pg_version >= 140000 | ||||
|  | ||||
|     has_bit_xor = property(operator.attrgetter("is_postgresql_14")) | ||||
|     supports_covering_spgist_indexes = property(operator.attrgetter("is_postgresql_14")) | ||||
|     supports_unlimited_charfield = True | ||||
| @ -0,0 +1,299 @@ | ||||
| from collections import namedtuple | ||||
|  | ||||
| from django.db.backends.base.introspection import BaseDatabaseIntrospection | ||||
| from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo | ||||
| from django.db.backends.base.introspection import TableInfo as BaseTableInfo | ||||
| from django.db.models import Index | ||||
|  | ||||
| FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("is_autofield", "comment")) | ||||
| TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",)) | ||||
|  | ||||
|  | ||||
| class DatabaseIntrospection(BaseDatabaseIntrospection): | ||||
|     # Maps type codes to Django Field types. | ||||
|     data_types_reverse = { | ||||
|         16: "BooleanField", | ||||
|         17: "BinaryField", | ||||
|         20: "BigIntegerField", | ||||
|         21: "SmallIntegerField", | ||||
|         23: "IntegerField", | ||||
|         25: "TextField", | ||||
|         700: "FloatField", | ||||
|         701: "FloatField", | ||||
|         869: "GenericIPAddressField", | ||||
|         1042: "CharField",  # blank-padded | ||||
|         1043: "CharField", | ||||
|         1082: "DateField", | ||||
|         1083: "TimeField", | ||||
|         1114: "DateTimeField", | ||||
|         1184: "DateTimeField", | ||||
|         1186: "DurationField", | ||||
|         1266: "TimeField", | ||||
|         1700: "DecimalField", | ||||
|         2950: "UUIDField", | ||||
|         3802: "JSONField", | ||||
|     } | ||||
|     # A hook for subclasses. | ||||
|     index_default_access_method = "btree" | ||||
|  | ||||
|     ignored_tables = [] | ||||
|  | ||||
|     def get_field_type(self, data_type, description): | ||||
|         field_type = super().get_field_type(data_type, description) | ||||
|         if description.is_autofield or ( | ||||
|             # Required for pre-Django 4.1 serial columns. | ||||
|             description.default | ||||
|             and "nextval" in description.default | ||||
|         ): | ||||
|             if field_type == "IntegerField": | ||||
|                 return "AutoField" | ||||
|             elif field_type == "BigIntegerField": | ||||
|                 return "BigAutoField" | ||||
|             elif field_type == "SmallIntegerField": | ||||
|                 return "SmallAutoField" | ||||
|         return field_type | ||||
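|     # For illustration: an integer column whose default is | ||||
|     # "nextval('foo_id_seq'::regclass)" (a pre-Django 4.1 serial column, | ||||
|     # hypothetical name) is introspected as AutoField rather than IntegerField. | ||||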
|  | ||||
|     def get_table_list(self, cursor): | ||||
|         """Return a list of table and view names in the current database.""" | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 c.relname, | ||||
|                 CASE | ||||
|                     WHEN c.relispartition THEN 'p' | ||||
|                     WHEN c.relkind IN ('m', 'v') THEN 'v' | ||||
|                     ELSE 't' | ||||
|                 END, | ||||
|                 obj_description(c.oid, 'pg_class') | ||||
|             FROM pg_catalog.pg_class c | ||||
|             LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace | ||||
|             WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v') | ||||
|                 AND n.nspname NOT IN ('pg_catalog', 'pg_toast') | ||||
|                 AND pg_catalog.pg_table_is_visible(c.oid) | ||||
|         """ | ||||
|         ) | ||||
|         return [ | ||||
|             TableInfo(*row) | ||||
|             for row in cursor.fetchall() | ||||
|             if row[0] not in self.ignored_tables | ||||
|         ] | ||||
|  | ||||
|     def get_table_description(self, cursor, table_name): | ||||
|         """ | ||||
|         Return a description of the table with the DB-API cursor.description | ||||
|         interface. | ||||
|         """ | ||||
|         # Query the pg_catalog tables as cursor.description does not reliably | ||||
|         # return the nullable property and information_schema.columns does not | ||||
|         # contain details of materialized views. | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 a.attname AS column_name, | ||||
|                 NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable, | ||||
|                 pg_get_expr(ad.adbin, ad.adrelid) AS column_default, | ||||
|                 CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation, | ||||
|                 a.attidentity != '' AS is_autofield, | ||||
|                 col_description(a.attrelid, a.attnum) AS column_comment | ||||
|             FROM pg_attribute a | ||||
|             LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum | ||||
|             LEFT JOIN pg_collation co ON a.attcollation = co.oid | ||||
|             JOIN pg_type t ON a.atttypid = t.oid | ||||
|             JOIN pg_class c ON a.attrelid = c.oid | ||||
|             JOIN pg_namespace n ON c.relnamespace = n.oid | ||||
|             WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v') | ||||
|                 AND c.relname = %s | ||||
|                 AND n.nspname NOT IN ('pg_catalog', 'pg_toast') | ||||
|                 AND pg_catalog.pg_table_is_visible(c.oid) | ||||
|         """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         field_map = {line[0]: line[1:] for line in cursor.fetchall()} | ||||
|         cursor.execute( | ||||
|             "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name) | ||||
|         ) | ||||
|         return [ | ||||
|             FieldInfo( | ||||
|                 line.name, | ||||
|                 line.type_code, | ||||
|                 # display_size is always None on psycopg2. | ||||
|                 line.internal_size if line.display_size is None else line.display_size, | ||||
|                 line.internal_size, | ||||
|                 line.precision, | ||||
|                 line.scale, | ||||
|                 *field_map[line.name], | ||||
|             ) | ||||
|             for line in cursor.description | ||||
|         ] | ||||
|  | ||||
|     def get_sequences(self, cursor, table_name, table_fields=()): | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 s.relname AS sequence_name, | ||||
|                 a.attname AS colname | ||||
|             FROM | ||||
|                 pg_class s | ||||
|                 JOIN pg_depend d ON d.objid = s.oid | ||||
|                     AND d.classid = 'pg_class'::regclass | ||||
|                     AND d.refclassid = 'pg_class'::regclass | ||||
|                 JOIN pg_attribute a ON d.refobjid = a.attrelid | ||||
|                     AND d.refobjsubid = a.attnum | ||||
|                 JOIN pg_class tbl ON tbl.oid = d.refobjid | ||||
|                     AND tbl.relname = %s | ||||
|                     AND pg_catalog.pg_table_is_visible(tbl.oid) | ||||
|             WHERE | ||||
|                 s.relkind = 'S'; | ||||
|         """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         return [ | ||||
|             {"name": row[0], "table": table_name, "column": row[1]} | ||||
|             for row in cursor.fetchall() | ||||
|         ] | ||||
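|     # For illustration (hypothetical names): a serial/identity "id" column on | ||||
|     # table "foo" yields [{"name": "foo_id_seq", "table": "foo", "column": "id"}]. | ||||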
|  | ||||
|     def get_relations(self, cursor, table_name): | ||||
|         """ | ||||
|         Return a dictionary of {field_name: (field_name_other_table, other_table)} | ||||
|         representing all foreign keys in the given table. | ||||
|         """ | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT a1.attname, c2.relname, a2.attname | ||||
|             FROM pg_constraint con | ||||
|             LEFT JOIN pg_class c1 ON con.conrelid = c1.oid | ||||
|             LEFT JOIN pg_class c2 ON con.confrelid = c2.oid | ||||
|             LEFT JOIN | ||||
|                 pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1] | ||||
|             LEFT JOIN | ||||
|                 pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1] | ||||
|             WHERE | ||||
|                 c1.relname = %s AND | ||||
|                 con.contype = 'f' AND | ||||
|                 c1.relnamespace = c2.relnamespace AND | ||||
|                 pg_catalog.pg_table_is_visible(c1.oid) | ||||
|         """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         return {row[0]: (row[2], row[1]) for row in cursor.fetchall()} | ||||
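|     # For illustration with a hypothetical schema: a "book" table whose | ||||
|     # author_id column references author(id) yields | ||||
|     # {"author_id": ("id", "author")}. | ||||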
|  | ||||
|     def get_constraints(self, cursor, table_name): | ||||
|         """ | ||||
|         Retrieve any constraints or keys (unique, pk, fk, check, index) across | ||||
|         one or more columns. Also retrieve the definition of expression-based | ||||
|         indexes. | ||||
|         """ | ||||
|         constraints = {} | ||||
|         # Loop over the key table, collecting things as constraints. The column | ||||
|         # array must return column names in the same order in which they were | ||||
|         # created. | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 c.conname, | ||||
|                 array( | ||||
|                     SELECT attname | ||||
|                     FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx) | ||||
|                     JOIN pg_attribute AS ca ON cols.colid = ca.attnum | ||||
|                     WHERE ca.attrelid = c.conrelid | ||||
|                     ORDER BY cols.arridx | ||||
|                 ), | ||||
|                 c.contype, | ||||
|                 (SELECT fkc.relname || '.' || fka.attname | ||||
|                 FROM pg_attribute AS fka | ||||
|                 JOIN pg_class AS fkc ON fka.attrelid = fkc.oid | ||||
|                 WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]), | ||||
|                 cl.reloptions | ||||
|             FROM pg_constraint AS c | ||||
|             JOIN pg_class AS cl ON c.conrelid = cl.oid | ||||
|             WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid) | ||||
|         """, | ||||
|             [table_name], | ||||
|         ) | ||||
|         for constraint, columns, kind, used_cols, options in cursor.fetchall(): | ||||
|             constraints[constraint] = { | ||||
|                 "columns": columns, | ||||
|                 "primary_key": kind == "p", | ||||
|                 "unique": kind in ["p", "u"], | ||||
|                 "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None, | ||||
|                 "check": kind == "c", | ||||
|                 "index": False, | ||||
|                 "definition": None, | ||||
|                 "options": options, | ||||
|             } | ||||
|         # Now get indexes | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT | ||||
|                 indexname, | ||||
|                 array_agg(attname ORDER BY arridx), | ||||
|                 indisunique, | ||||
|                 indisprimary, | ||||
|                 array_agg(ordering ORDER BY arridx), | ||||
|                 amname, | ||||
|                 exprdef, | ||||
|                 s2.attoptions | ||||
|             FROM ( | ||||
|                 SELECT | ||||
|                     c2.relname as indexname, idx.*, attr.attname, am.amname, | ||||
|                     CASE | ||||
|                         WHEN idx.indexprs IS NOT NULL THEN | ||||
|                             pg_get_indexdef(idx.indexrelid) | ||||
|                     END AS exprdef, | ||||
|                     CASE am.amname | ||||
|                         WHEN %s THEN | ||||
|                             CASE (option & 1) | ||||
|                                 WHEN 1 THEN 'DESC' ELSE 'ASC' | ||||
|                             END | ||||
|                     END as ordering, | ||||
|                     c2.reloptions as attoptions | ||||
|                 FROM ( | ||||
|                     SELECT * | ||||
|                     FROM | ||||
|                         pg_index i, | ||||
|                         unnest(i.indkey, i.indoption) | ||||
|                             WITH ORDINALITY koi(key, option, arridx) | ||||
|                 ) idx | ||||
|                 LEFT JOIN pg_class c ON idx.indrelid = c.oid | ||||
|                 LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid | ||||
|                 LEFT JOIN pg_am am ON c2.relam = am.oid | ||||
|                 LEFT JOIN | ||||
|                     pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key | ||||
|                 WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid) | ||||
|             ) s2 | ||||
|             GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions; | ||||
|         """, | ||||
|             [self.index_default_access_method, table_name], | ||||
|         ) | ||||
|         for ( | ||||
|             index, | ||||
|             columns, | ||||
|             unique, | ||||
|             primary, | ||||
|             orders, | ||||
|             type_, | ||||
|             definition, | ||||
|             options, | ||||
|         ) in cursor.fetchall(): | ||||
|             if index not in constraints: | ||||
|                 basic_index = ( | ||||
|                     type_ == self.index_default_access_method | ||||
|                     and | ||||
|                     # '_btree' references | ||||
|                     # django.contrib.postgres.indexes.BTreeIndex.suffix. | ||||
|                     not index.endswith("_btree") | ||||
|                     and options is None | ||||
|                 ) | ||||
|                 constraints[index] = { | ||||
|                     "columns": columns if columns != [None] else [], | ||||
|                     "orders": orders if orders != [None] else [], | ||||
|                     "primary_key": primary, | ||||
|                     "unique": unique, | ||||
|                     "foreign_key": None, | ||||
|                     "check": False, | ||||
|                     "index": True, | ||||
|                     "type": Index.suffix if basic_index else type_, | ||||
|                     "definition": definition, | ||||
|                     "options": options, | ||||
|                 } | ||||
|         return constraints | ||||
| @ -0,0 +1,404 @@ | ||||
| import json | ||||
| from functools import lru_cache, partial | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db.backends.base.operations import BaseDatabaseOperations | ||||
| from django.db.backends.postgresql.psycopg_any import ( | ||||
|     Inet, | ||||
|     Jsonb, | ||||
|     errors, | ||||
|     is_psycopg3, | ||||
|     mogrify, | ||||
| ) | ||||
| from django.db.backends.utils import split_tzname_delta | ||||
| from django.db.models.constants import OnConflict | ||||
| from django.utils.regex_helper import _lazy_re_compile | ||||
|  | ||||
|  | ||||
| @lru_cache | ||||
| def get_json_dumps(encoder): | ||||
|     if encoder is None: | ||||
|         return json.dumps | ||||
|     return partial(json.dumps, cls=encoder) | ||||
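| # get_json_dumps(None) returns json.dumps itself; a custom encoder class | ||||
| # (e.g. django.core.serializers.json.DjangoJSONEncoder) gets a cached | ||||
| # partial bound to that encoder. | ||||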
|  | ||||
|  | ||||
| class DatabaseOperations(BaseDatabaseOperations): | ||||
|     cast_char_field_without_max_length = "varchar" | ||||
|     explain_prefix = "EXPLAIN" | ||||
|     explain_options = frozenset( | ||||
|         [ | ||||
|             "ANALYZE", | ||||
|             "BUFFERS", | ||||
|             "COSTS", | ||||
|             "SETTINGS", | ||||
|             "SUMMARY", | ||||
|             "TIMING", | ||||
|             "VERBOSE", | ||||
|             "WAL", | ||||
|         ] | ||||
|     ) | ||||
|     cast_data_types = { | ||||
|         "AutoField": "integer", | ||||
|         "BigAutoField": "bigint", | ||||
|         "SmallAutoField": "smallint", | ||||
|     } | ||||
|  | ||||
|     if is_psycopg3: | ||||
|         from psycopg.types import numeric | ||||
|  | ||||
|         integerfield_type_map = { | ||||
|             "SmallIntegerField": numeric.Int2, | ||||
|             "IntegerField": numeric.Int4, | ||||
|             "BigIntegerField": numeric.Int8, | ||||
|             "PositiveSmallIntegerField": numeric.Int2, | ||||
|             "PositiveIntegerField": numeric.Int4, | ||||
|             "PositiveBigIntegerField": numeric.Int8, | ||||
|         } | ||||
|  | ||||
|     def unification_cast_sql(self, output_field): | ||||
|         internal_type = output_field.get_internal_type() | ||||
|         if internal_type in ( | ||||
|             "GenericIPAddressField", | ||||
|             "IPAddressField", | ||||
|             "TimeField", | ||||
|             "UUIDField", | ||||
|         ): | ||||
|             # PostgreSQL will resolve a union as type 'text' if input types are | ||||
|             # 'unknown'. | ||||
|             # https://www.postgresql.org/docs/current/typeconv-union-case.html | ||||
|             # These fields cannot be implicitly cast back in the default | ||||
|             # PostgreSQL configuration so we need to explicitly cast them. | ||||
|             # We must also remove components of the type within brackets: | ||||
|             # varchar(255) -> varchar. | ||||
|             return ( | ||||
|                 "CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0] | ||||
|             ) | ||||
|         return "%s" | ||||
|  | ||||
|     # The EXTRACT field name cannot be passed as a query parameter. | ||||
|     _extract_format_re = _lazy_re_compile(r"[A-Z_]+") | ||||
|  | ||||
|     def date_extract_sql(self, lookup_type, sql, params): | ||||
|         # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT | ||||
|         if lookup_type == "week_day": | ||||
|             # For consistency across backends, we return Sunday=1, Saturday=7. | ||||
|             return f"EXTRACT(DOW FROM {sql}) + 1", params | ||||
|         elif lookup_type == "iso_week_day": | ||||
|             return f"EXTRACT(ISODOW FROM {sql})", params | ||||
|         elif lookup_type == "iso_year": | ||||
|             return f"EXTRACT(ISOYEAR FROM {sql})", params | ||||
|  | ||||
|         lookup_type = lookup_type.upper() | ||||
|         if not self._extract_format_re.fullmatch(lookup_type): | ||||
|             raise ValueError(f"Invalid lookup type: {lookup_type!r}") | ||||
|         return f"EXTRACT({lookup_type} FROM {sql})", params | ||||
|  | ||||
|     def date_trunc_sql(self, lookup_type, sql, params, tzname=None): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC | ||||
|         return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params) | ||||
|  | ||||
|     def _prepare_tzname_delta(self, tzname): | ||||
|         tzname, sign, offset = split_tzname_delta(tzname) | ||||
|         if offset: | ||||
|             sign = "-" if sign == "+" else "+" | ||||
|             return f"{tzname}{sign}{offset}" | ||||
|         return tzname | ||||
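|     # For illustration: "UTC+05:00" becomes "UTC-05:00", since PostgreSQL | ||||
|     # interprets POSIX-style time zone offsets with the sign inverted | ||||
|     # relative to the ISO 8601 convention Django uses. | ||||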
|  | ||||
|     def _convert_sql_to_tz(self, sql, params, tzname): | ||||
|         if tzname and settings.USE_TZ: | ||||
|             tzname_param = self._prepare_tzname_delta(tzname) | ||||
|             return f"{sql} AT TIME ZONE %s", (*params, tzname_param) | ||||
|         return sql, params | ||||
|  | ||||
|     def datetime_cast_date_sql(self, sql, params, tzname): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         return f"({sql})::date", params | ||||
|  | ||||
|     def datetime_cast_time_sql(self, sql, params, tzname): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         return f"({sql})::time", params | ||||
|  | ||||
|     def datetime_extract_sql(self, lookup_type, sql, params, tzname): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         if lookup_type == "second": | ||||
|             # Truncate fractional seconds. | ||||
|             return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params) | ||||
|         return self.date_extract_sql(lookup_type, sql, params) | ||||
|  | ||||
|     def datetime_trunc_sql(self, lookup_type, sql, params, tzname): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC | ||||
|         return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params) | ||||
|  | ||||
|     def time_extract_sql(self, lookup_type, sql, params): | ||||
|         if lookup_type == "second": | ||||
|             # Truncate fractional seconds. | ||||
|             return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params) | ||||
|         return self.date_extract_sql(lookup_type, sql, params) | ||||
|  | ||||
|     def time_trunc_sql(self, lookup_type, sql, params, tzname=None): | ||||
|         sql, params = self._convert_sql_to_tz(sql, params, tzname) | ||||
|         return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params) | ||||
|  | ||||
|     def deferrable_sql(self): | ||||
|         return " DEFERRABLE INITIALLY DEFERRED" | ||||
|  | ||||
|     def fetch_returned_insert_rows(self, cursor): | ||||
|         """ | ||||
|         Given a cursor object that has just performed an INSERT...RETURNING | ||||
|         statement into a table, return the tuple of returned data. | ||||
|         """ | ||||
|         return cursor.fetchall() | ||||
|  | ||||
|     def lookup_cast(self, lookup_type, internal_type=None): | ||||
|         lookup = "%s" | ||||
|         # Cast text lookups to text to allow things like filter(x__contains=4) | ||||
|         if lookup_type in ( | ||||
|             "iexact", | ||||
|             "contains", | ||||
|             "icontains", | ||||
|             "startswith", | ||||
|             "istartswith", | ||||
|             "endswith", | ||||
|             "iendswith", | ||||
|             "regex", | ||||
|             "iregex", | ||||
|         ): | ||||
|             if internal_type in ("IPAddressField", "GenericIPAddressField"): | ||||
|                 lookup = "HOST(%s)" | ||||
|             # RemovedInDjango51Warning. | ||||
|             elif internal_type in ("CICharField", "CIEmailField", "CITextField"): | ||||
|                 lookup = "%s::citext" | ||||
|             else: | ||||
|                 lookup = "%s::text" | ||||
|  | ||||
|         # Use UPPER(x) for case-insensitive lookups; it's faster. | ||||
|         if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"): | ||||
|             lookup = "UPPER(%s)" % lookup | ||||
|  | ||||
|         return lookup | ||||
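|     # For illustration: an "icontains" lookup on a GenericIPAddressField | ||||
|     # yields "UPPER(HOST(%s))"; on most other fields it yields | ||||
|     # "UPPER(%s::text)". | ||||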
|  | ||||
|     def no_limit_value(self): | ||||
|         return None | ||||
|  | ||||
|     def prepare_sql_script(self, sql): | ||||
|         return [sql] | ||||
|  | ||||
|     def quote_name(self, name): | ||||
|         if name.startswith('"') and name.endswith('"'): | ||||
|             return name  # Quoting once is enough. | ||||
|         return '"%s"' % name | ||||
|  | ||||
|     def compose_sql(self, sql, params): | ||||
|         return mogrify(sql, params, self.connection) | ||||
|  | ||||
|     def set_time_zone_sql(self): | ||||
|         return "SELECT set_config('TimeZone', %s, false)" | ||||
|  | ||||
|     def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False): | ||||
|         if not tables: | ||||
|             return [] | ||||
|  | ||||
|         # Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us | ||||
|         # to truncate tables referenced by a foreign key in any other table. | ||||
|         sql_parts = [ | ||||
|             style.SQL_KEYWORD("TRUNCATE"), | ||||
|             ", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables), | ||||
|         ] | ||||
|         if reset_sequences: | ||||
|             sql_parts.append(style.SQL_KEYWORD("RESTART IDENTITY")) | ||||
|         if allow_cascade: | ||||
|             sql_parts.append(style.SQL_KEYWORD("CASCADE")) | ||||
|         return ["%s;" % " ".join(sql_parts)] | ||||
|  | ||||
|     def sequence_reset_by_name_sql(self, style, sequences): | ||||
|         # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements | ||||
|         # to reset sequence indices | ||||
|         sql = [] | ||||
|         for sequence_info in sequences: | ||||
|             table_name = sequence_info["table"] | ||||
|             # 'id' will be the case if it's an m2m using an autogenerated | ||||
|             # intermediate table (see BaseDatabaseIntrospection.sequence_list). | ||||
|             column_name = sequence_info["column"] or "id" | ||||
|             sql.append( | ||||
|                 "%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" | ||||
|                 % ( | ||||
|                     style.SQL_KEYWORD("SELECT"), | ||||
|                     style.SQL_TABLE(self.quote_name(table_name)), | ||||
|                     style.SQL_FIELD(column_name), | ||||
|                 ) | ||||
|             ) | ||||
|         return sql | ||||
|  | ||||
|     def tablespace_sql(self, tablespace, inline=False): | ||||
|         if inline: | ||||
|             return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace) | ||||
|         else: | ||||
|             return "TABLESPACE %s" % self.quote_name(tablespace) | ||||
|  | ||||
|     def sequence_reset_sql(self, style, model_list): | ||||
|         from django.db import models | ||||
|  | ||||
|         output = [] | ||||
|         qn = self.quote_name | ||||
|         for model in model_list: | ||||
|             # Use `coalesce` to set the sequence for each model to the max pk | ||||
|             # value if there are records, or 1 if there are none. Set the | ||||
|             # `is_called` property (the third argument to `setval`) to true if | ||||
|             # there are records (as the max pk value is already in use), | ||||
|             # otherwise set it to false. Use pg_get_serial_sequence to get the | ||||
|             # underlying sequence name from the table name and column name. | ||||
|  | ||||
|             for f in model._meta.local_fields: | ||||
|                 if isinstance(f, models.AutoField): | ||||
|                     output.append( | ||||
|                         "%s setval(pg_get_serial_sequence('%s','%s'), " | ||||
|                         "coalesce(max(%s), 1), max(%s) %s null) %s %s;" | ||||
|                         % ( | ||||
|                             style.SQL_KEYWORD("SELECT"), | ||||
|                             style.SQL_TABLE(qn(model._meta.db_table)), | ||||
|                             style.SQL_FIELD(f.column), | ||||
|                             style.SQL_FIELD(qn(f.column)), | ||||
|                             style.SQL_FIELD(qn(f.column)), | ||||
|                             style.SQL_KEYWORD("IS NOT"), | ||||
|                             style.SQL_KEYWORD("FROM"), | ||||
|                             style.SQL_TABLE(qn(model._meta.db_table)), | ||||
|                         ) | ||||
|                     ) | ||||
|                     # Only one AutoField is allowed per model, so don't bother | ||||
|                     # continuing. | ||||
|                     break | ||||
|         return output | ||||
|  | ||||
|     def prep_for_iexact_query(self, x): | ||||
|         return x | ||||
|  | ||||
|     def max_name_length(self): | ||||
|         """ | ||||
|         Return the maximum length of an identifier. | ||||
|  | ||||
|         The maximum length of an identifier is 63 by default, but can be | ||||
|         changed by recompiling PostgreSQL after editing the NAMEDATALEN | ||||
|         macro in src/include/pg_config_manual.h. | ||||
|  | ||||
|         This implementation returns 63, but can be overridden by a custom | ||||
|         database backend that inherits most of its behavior from this one. | ||||
|         """ | ||||
|         return 63 | ||||
|  | ||||
|     def distinct_sql(self, fields, params): | ||||
|         if fields: | ||||
|             params = [param for param_list in params for param in param_list] | ||||
|             return (["DISTINCT ON (%s)" % ", ".join(fields)], params) | ||||
|         else: | ||||
|             return ["DISTINCT"], [] | ||||
|  | ||||
|     if is_psycopg3: | ||||
|  | ||||
|         def last_executed_query(self, cursor, sql, params): | ||||
|             try: | ||||
|                 return self.compose_sql(sql, params) | ||||
|             except errors.DataError: | ||||
|                 return None | ||||
|  | ||||
|     else: | ||||
|  | ||||
|         def last_executed_query(self, cursor, sql, params): | ||||
|             # https://www.psycopg.org/docs/cursor.html#cursor.query | ||||
|             # The query attribute is a Psycopg extension to the DB API 2.0. | ||||
|             if cursor.query is not None: | ||||
|                 return cursor.query.decode() | ||||
|             return None | ||||
|  | ||||
|     def return_insert_columns(self, fields): | ||||
|         if not fields: | ||||
|             return "", () | ||||
|         columns = [ | ||||
|             "%s.%s" | ||||
|             % ( | ||||
|                 self.quote_name(field.model._meta.db_table), | ||||
|                 self.quote_name(field.column), | ||||
|             ) | ||||
|             for field in fields | ||||
|         ] | ||||
|         return "RETURNING %s" % ", ".join(columns), () | ||||
|  | ||||
|     def bulk_insert_sql(self, fields, placeholder_rows): | ||||
|         placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) | ||||
|         values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql) | ||||
|         return "VALUES " + values_sql | ||||
|  | ||||
|     if is_psycopg3: | ||||
|  | ||||
|         def adapt_integerfield_value(self, value, internal_type): | ||||
|             if value is None or hasattr(value, "resolve_expression"): | ||||
|                 return value | ||||
|             return self.integerfield_type_map[internal_type](value) | ||||
|  | ||||
|     def adapt_datefield_value(self, value): | ||||
|         return value | ||||
|  | ||||
|     def adapt_datetimefield_value(self, value): | ||||
|         return value | ||||
|  | ||||
|     def adapt_timefield_value(self, value): | ||||
|         return value | ||||
|  | ||||
|     def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None): | ||||
|         return value | ||||
|  | ||||
|     def adapt_ipaddressfield_value(self, value): | ||||
|         if value: | ||||
|             return Inet(value) | ||||
|         return None | ||||
|  | ||||
|     def adapt_json_value(self, value, encoder): | ||||
|         return Jsonb(value, dumps=get_json_dumps(encoder)) | ||||
|  | ||||
|     def subtract_temporals(self, internal_type, lhs, rhs): | ||||
|         if internal_type == "DateField": | ||||
|             lhs_sql, lhs_params = lhs | ||||
|             rhs_sql, rhs_params = rhs | ||||
|             params = (*lhs_params, *rhs_params) | ||||
|             return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), params | ||||
|         return super().subtract_temporals(internal_type, lhs, rhs) | ||||
|  | ||||
|     def explain_query_prefix(self, format=None, **options): | ||||
|         extra = {} | ||||
|         # Normalize options. | ||||
|         if options: | ||||
|             options = { | ||||
|                 name.upper(): "true" if value else "false" | ||||
|                 for name, value in options.items() | ||||
|             } | ||||
|             for valid_option in self.explain_options: | ||||
|                 value = options.pop(valid_option, None) | ||||
|                 if value is not None: | ||||
|                     extra[valid_option] = value | ||||
|         prefix = super().explain_query_prefix(format, **options) | ||||
|         if format: | ||||
|             extra["FORMAT"] = format | ||||
|         if extra: | ||||
|             prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items()) | ||||
|         return prefix | ||||
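|     # For illustration: explain_query_prefix("JSON", analyze=True) returns | ||||
|     # "EXPLAIN (ANALYZE true, FORMAT JSON)". | ||||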
|  | ||||
|     def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): | ||||
|         if on_conflict == OnConflict.IGNORE: | ||||
|             return "ON CONFLICT DO NOTHING" | ||||
|         if on_conflict == OnConflict.UPDATE: | ||||
|             return "ON CONFLICT(%s) DO UPDATE SET %s" % ( | ||||
|                 ", ".join(map(self.quote_name, unique_fields)), | ||||
|                 ", ".join( | ||||
|                     [ | ||||
|                         f"{field} = EXCLUDED.{field}" | ||||
|                         for field in map(self.quote_name, update_fields) | ||||
|                     ] | ||||
|                 ), | ||||
|             ) | ||||
|         return super().on_conflict_suffix_sql( | ||||
|             fields, | ||||
|             on_conflict, | ||||
|             update_fields, | ||||
|             unique_fields, | ||||
|         ) | ||||
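|     # For illustration: OnConflict.UPDATE with unique_fields=["id"] and | ||||
|     # update_fields=["name"] yields: | ||||
|     # ON CONFLICT("id") DO UPDATE SET "name" = EXCLUDED."name" | ||||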
| @ -0,0 +1,103 @@ | ||||
| import ipaddress | ||||
| from functools import lru_cache | ||||
|  | ||||
| try: | ||||
|     from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors, sql | ||||
|     from psycopg.postgres import types | ||||
|     from psycopg.types.datetime import TimestamptzLoader | ||||
|     from psycopg.types.json import Jsonb | ||||
|     from psycopg.types.range import Range, RangeDumper | ||||
|     from psycopg.types.string import TextLoader | ||||
|  | ||||
|     Inet = ipaddress.ip_address | ||||
|  | ||||
|     DateRange = DateTimeRange = DateTimeTZRange = NumericRange = Range | ||||
|     RANGE_TYPES = (Range,) | ||||
|  | ||||
|     TSRANGE_OID = types["tsrange"].oid | ||||
|     TSTZRANGE_OID = types["tstzrange"].oid | ||||
|  | ||||
|     def mogrify(sql, params, connection): | ||||
|         with connection.cursor() as cursor: | ||||
|             return ClientCursor(cursor.connection).mogrify(sql, params) | ||||
|  | ||||
|     # Adapters. | ||||
|     class BaseTzLoader(TimestamptzLoader): | ||||
|         """ | ||||
|         Load a PostgreSQL timestamptz using a specific timezone. | ||||
|         The timezone can also be None, in which case the tzinfo is dropped. | ||||
|         """ | ||||
|  | ||||
|         timezone = None | ||||
|  | ||||
|         def load(self, data): | ||||
|             res = super().load(data) | ||||
|             return res.replace(tzinfo=self.timezone) | ||||
|  | ||||
|     def register_tzloader(tz, context): | ||||
|         class SpecificTzLoader(BaseTzLoader): | ||||
|             timezone = tz | ||||
|  | ||||
|         context.adapters.register_loader("timestamptz", SpecificTzLoader) | ||||
|  | ||||
|     class DjangoRangeDumper(RangeDumper): | ||||
|         """A Range dumper customized for Django.""" | ||||
|  | ||||
|         def upgrade(self, obj, format): | ||||
|             # Dump ranges containing naive datetimes as tstzrange, because | ||||
|             # Django doesn't use tz-aware ones. | ||||
|             dumper = super().upgrade(obj, format) | ||||
|             if dumper is not self and dumper.oid == TSRANGE_OID: | ||||
|                 dumper.oid = TSTZRANGE_OID | ||||
|             return dumper | ||||
|  | ||||
|     @lru_cache | ||||
|     def get_adapters_template(use_tz, timezone): | ||||
|         # Create an adapters map extending the base one. | ||||
|         ctx = adapt.AdaptersMap(adapters) | ||||
|         # Register a no-op loader to avoid a round trip (psycopg version 3 | ||||
|         # decoding the value, then json.dumps() and json.loads() again) when | ||||
|         # using a custom decoder in JSONField. | ||||
|         ctx.register_loader("jsonb", TextLoader) | ||||
|         # Don't convert automatically from PostgreSQL network types to Python | ||||
|         # ipaddress. | ||||
|         ctx.register_loader("inet", TextLoader) | ||||
|         ctx.register_loader("cidr", TextLoader) | ||||
|         ctx.register_dumper(Range, DjangoRangeDumper) | ||||
|         # Register a timestamptz loader configured on the given timezone. | ||||
|         # This, however, can be overridden by create_cursor. | ||||
|         register_tzloader(timezone, ctx) | ||||
|         return ctx | ||||
|  | ||||
|     is_psycopg3 = True | ||||
|  | ||||
| except ImportError: | ||||
|     from enum import IntEnum | ||||
|  | ||||
|     from psycopg2 import errors, extensions, sql  # NOQA | ||||
|     from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, Inet  # NOQA | ||||
|     from psycopg2.extras import Json as Jsonb  # NOQA | ||||
|     from psycopg2.extras import NumericRange, Range  # NOQA | ||||
|  | ||||
|     RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange) | ||||
|  | ||||
|     class IsolationLevel(IntEnum): | ||||
|         READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED | ||||
|         READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED | ||||
|         REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ | ||||
|         SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE | ||||
|  | ||||
|     def _quote(value, connection=None): | ||||
|         adapted = extensions.adapt(value) | ||||
|         if hasattr(adapted, "encoding"): | ||||
|             adapted.encoding = "utf8" | ||||
|         # getquoted() returns a quoted bytestring of the adapted value. | ||||
|         return adapted.getquoted().decode() | ||||
|  | ||||
|     sql.quote = _quote | ||||
|  | ||||
|     def mogrify(sql, params, connection): | ||||
|         with connection.cursor() as cursor: | ||||
|             return cursor.mogrify(sql, params).decode() | ||||
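|     # For illustration: mogrify("SELECT %s", [1], connection) returns the | ||||
|     # client-side merged query string "SELECT 1". | ||||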
|  | ||||
|     is_psycopg3 = False | ||||
| @ -0,0 +1,374 @@ | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
| from django.db.backends.ddl_references import IndexColumns | ||||
| from django.db.backends.postgresql.psycopg_any import sql | ||||
| from django.db.backends.utils import strip_quotes | ||||
|  | ||||
|  | ||||
| class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): | ||||
|     # Setting all constraints to IMMEDIATE to allow changing data in the same | ||||
|     # transaction. | ||||
|     sql_update_with_default = ( | ||||
|         "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL" | ||||
|         "; SET CONSTRAINTS ALL IMMEDIATE" | ||||
|     ) | ||||
|     sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s" | ||||
|     sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE" | ||||
|  | ||||
|     sql_create_index = ( | ||||
|         "CREATE INDEX %(name)s ON %(table)s%(using)s " | ||||
|         "(%(columns)s)%(include)s%(extra)s%(condition)s" | ||||
|     ) | ||||
|     sql_create_index_concurrently = ( | ||||
|         "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s " | ||||
|         "(%(columns)s)%(include)s%(extra)s%(condition)s" | ||||
|     ) | ||||
|     sql_delete_index = "DROP INDEX IF EXISTS %(name)s" | ||||
|     sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s" | ||||
|  | ||||
|     # Setting the constraint to IMMEDIATE to allow changing data in the same | ||||
|     # transaction. | ||||
|     sql_create_column_inline_fk = ( | ||||
|         "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s" | ||||
|         "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE" | ||||
|     ) | ||||
|     # Setting the constraint to IMMEDIATE runs any deferred checks to allow | ||||
|     # dropping it in the same transaction. | ||||
|     sql_delete_fk = ( | ||||
|         "SET CONSTRAINTS %(name)s IMMEDIATE; " | ||||
|         "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" | ||||
|     ) | ||||
|     sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)" | ||||
|  | ||||
|     def execute(self, sql, params=()): | ||||
|         # Merge the query client-side, as PostgreSQL won't do it server-side. | ||||
|         if params is None: | ||||
|             return super().execute(sql, params) | ||||
|         sql = self.connection.ops.compose_sql(str(sql), params) | ||||
|         # Params are already merged into sql; don't let the superclass | ||||
|         # interpolate them again. | ||||
|         return super().execute(sql, None) | ||||
|  | ||||
|     sql_add_identity = ( | ||||
|         "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD " | ||||
|         "GENERATED BY DEFAULT AS IDENTITY" | ||||
|     ) | ||||
|     sql_drop_indentity = ( | ||||
|         "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS" | ||||
|     ) | ||||
|  | ||||
|     def quote_value(self, value): | ||||
|         if isinstance(value, str): | ||||
|             value = value.replace("%", "%%") | ||||
|         return sql.quote(value, self.connection.connection) | ||||
|  | ||||
|     def _field_indexes_sql(self, model, field): | ||||
|         output = super()._field_indexes_sql(model, field) | ||||
|         like_index_statement = self._create_like_index_sql(model, field) | ||||
|         if like_index_statement is not None: | ||||
|             output.append(like_index_statement) | ||||
|         return output | ||||
|  | ||||
|     def _field_data_type(self, field): | ||||
|         if field.is_relation: | ||||
|             return field.rel_db_type(self.connection) | ||||
|         return self.connection.data_types.get( | ||||
|             field.get_internal_type(), | ||||
|             field.db_type(self.connection), | ||||
|         ) | ||||
|  | ||||
|     def _field_base_data_types(self, field): | ||||
|         # Yield base data types for array fields. | ||||
|         if field.base_field.get_internal_type() == "ArrayField": | ||||
|             yield from self._field_base_data_types(field.base_field) | ||||
|         else: | ||||
|             yield self._field_data_type(field.base_field) | ||||
|  | ||||
|     def _create_like_index_sql(self, model, field): | ||||
|         """ | ||||
|         Return the statement to create an index with varchar operator pattern | ||||
|         when the column type is 'varchar' or 'text', otherwise return None. | ||||
|         """ | ||||
|         db_type = field.db_type(connection=self.connection) | ||||
|         if db_type is not None and (field.db_index or field.unique): | ||||
|             # Fields with database column types of `varchar` and `text` need | ||||
|             # a second index that specifies their operator class, which is | ||||
|             # needed when performing correct LIKE queries outside the | ||||
|             # C locale. See #12234. | ||||
|             # | ||||
|             # The same doesn't apply to array fields such as varchar[size] | ||||
|             # and text[size], so skip them. | ||||
|             if "[" in db_type: | ||||
|                 return None | ||||
|             # Non-deterministic collations on PostgreSQL don't support | ||||
|             # indexes using the varchar_pattern_ops/text_pattern_ops | ||||
|             # operator classes. | ||||
|             if getattr(field, "db_collation", None): | ||||
|                 return None | ||||
|             if db_type.startswith("varchar"): | ||||
|                 return self._create_index_sql( | ||||
|                     model, | ||||
|                     fields=[field], | ||||
|                     suffix="_like", | ||||
|                     opclasses=["varchar_pattern_ops"], | ||||
|                 ) | ||||
|             elif db_type.startswith("text"): | ||||
|                 return self._create_index_sql( | ||||
|                     model, | ||||
|                     fields=[field], | ||||
|                     suffix="_like", | ||||
|                     opclasses=["text_pattern_ops"], | ||||
|                 ) | ||||
|         return None | ||||
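|  | ||||
|     # For example (model and field names assumed), a CharField "name" with | ||||
|     # db_index=True on a model "app.Author" yields a second statement along | ||||
|     # the lines of: | ||||
|     #   CREATE INDEX "app_author_name_like" ON "app_author" | ||||
|     #   ("name" varchar_pattern_ops) | ||||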
|  | ||||
|     def _using_sql(self, new_field, old_field): | ||||
|         using_sql = " USING %(column)s::%(type)s" | ||||
|         new_internal_type = new_field.get_internal_type() | ||||
|         old_internal_type = old_field.get_internal_type() | ||||
|         if new_internal_type == "ArrayField" and new_internal_type == old_internal_type: | ||||
|             # Compare base data types for array fields. | ||||
|             if list(self._field_base_data_types(old_field)) != list( | ||||
|                 self._field_base_data_types(new_field) | ||||
|             ): | ||||
|                 return using_sql | ||||
|         elif self._field_data_type(old_field) != self._field_data_type(new_field): | ||||
|             return using_sql | ||||
|         return "" | ||||
|  | ||||
|     def _get_sequence_name(self, table, column): | ||||
|         with self.connection.cursor() as cursor: | ||||
|             for sequence in self.connection.introspection.get_sequences(cursor, table): | ||||
|                 if sequence["column"] == column: | ||||
|                     return sequence["name"] | ||||
|         return None | ||||
|  | ||||
|     def _alter_column_type_sql( | ||||
|         self, model, old_field, new_field, new_type, old_collation, new_collation | ||||
|     ): | ||||
|         # Drop indexes on varchar/text/citext columns that are changing to a | ||||
|         # different type. | ||||
|         old_db_params = old_field.db_parameters(connection=self.connection) | ||||
|         old_type = old_db_params["type"] | ||||
|         if (old_field.db_index or old_field.unique) and ( | ||||
|             (old_type.startswith("varchar") and not new_type.startswith("varchar")) | ||||
|             or (old_type.startswith("text") and not new_type.startswith("text")) | ||||
|             or (old_type.startswith("citext") and not new_type.startswith("citext")) | ||||
|         ): | ||||
|             index_name = self._create_index_name( | ||||
|                 model._meta.db_table, [old_field.column], suffix="_like" | ||||
|             ) | ||||
|             self.execute(self._delete_index_sql(model, index_name)) | ||||
|  | ||||
|         self.sql_alter_column_type = ( | ||||
|             "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s" | ||||
|         ) | ||||
|         # Cast when data type changed. | ||||
|         if using_sql := self._using_sql(new_field, old_field): | ||||
|             self.sql_alter_column_type += using_sql | ||||
|         new_internal_type = new_field.get_internal_type() | ||||
|         old_internal_type = old_field.get_internal_type() | ||||
|         # Make ALTER TYPE with IDENTITY make sense. | ||||
|         table = strip_quotes(model._meta.db_table) | ||||
|         auto_field_types = { | ||||
|             "AutoField", | ||||
|             "BigAutoField", | ||||
|             "SmallAutoField", | ||||
|         } | ||||
|         old_is_auto = old_internal_type in auto_field_types | ||||
|         new_is_auto = new_internal_type in auto_field_types | ||||
|         if new_is_auto and not old_is_auto: | ||||
|             column = strip_quotes(new_field.column) | ||||
|             return ( | ||||
|                 ( | ||||
|                     self.sql_alter_column_type | ||||
|                     % { | ||||
|                         "column": self.quote_name(column), | ||||
|                         "type": new_type, | ||||
|                         "collation": "", | ||||
|                     }, | ||||
|                     [], | ||||
|                 ), | ||||
|                 [ | ||||
|                     ( | ||||
|                         self.sql_add_identity | ||||
|                         % { | ||||
|                             "table": self.quote_name(table), | ||||
|                             "column": self.quote_name(column), | ||||
|                         }, | ||||
|                         [], | ||||
|                     ), | ||||
|                 ], | ||||
|             ) | ||||
|         elif old_is_auto and not new_is_auto: | ||||
|             # Drop IDENTITY if exists (pre-Django 4.1 serial columns don't have | ||||
|             # it). | ||||
|             self.execute( | ||||
|                 self.sql_drop_identity | ||||
|                 % { | ||||
|                     "table": self.quote_name(table), | ||||
|                     "column": self.quote_name(strip_quotes(new_field.column)), | ||||
|                 } | ||||
|             ) | ||||
|             column = strip_quotes(new_field.column) | ||||
|             fragment, _ = super()._alter_column_type_sql( | ||||
|                 model, old_field, new_field, new_type, old_collation, new_collation | ||||
|             ) | ||||
|             # Drop the sequence if exists (Django 4.1+ identity columns don't | ||||
|             # have it). | ||||
|             other_actions = [] | ||||
|             if sequence_name := self._get_sequence_name(table, column): | ||||
|                 other_actions = [ | ||||
|                     ( | ||||
|                         self.sql_delete_sequence | ||||
|                         % { | ||||
|                             "sequence": self.quote_name(sequence_name), | ||||
|                         }, | ||||
|                         [], | ||||
|                     ) | ||||
|                 ] | ||||
|             return fragment, other_actions | ||||
|         elif new_is_auto and old_is_auto and old_internal_type != new_internal_type: | ||||
|             fragment, _ = super()._alter_column_type_sql( | ||||
|                 model, old_field, new_field, new_type, old_collation, new_collation | ||||
|             ) | ||||
|             column = strip_quotes(new_field.column) | ||||
|             db_types = { | ||||
|                 "AutoField": "integer", | ||||
|                 "BigAutoField": "bigint", | ||||
|                 "SmallAutoField": "smallint", | ||||
|             } | ||||
|             # Alter the sequence type if exists (Django 4.1+ identity columns | ||||
|             # don't have it). | ||||
|             other_actions = [] | ||||
|             if sequence_name := self._get_sequence_name(table, column): | ||||
|                 other_actions = [ | ||||
|                     ( | ||||
|                         self.sql_alter_sequence_type | ||||
|                         % { | ||||
|                             "sequence": self.quote_name(sequence_name), | ||||
|                             "type": db_types[new_internal_type], | ||||
|                         }, | ||||
|                         [], | ||||
|                     ), | ||||
|                 ] | ||||
|             return fragment, other_actions | ||||
|         else: | ||||
|             return super()._alter_column_type_sql( | ||||
|                 model, old_field, new_field, new_type, old_collation, new_collation | ||||
|             ) | ||||
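|  | ||||
|     # Sketch of the generated SQL (table and column names assumed) when an | ||||
|     # IntegerField becomes an AutoField, i.e. the new_is_auto branch above: | ||||
|     #   ALTER TABLE "app_model" ALTER COLUMN "id" TYPE integer; | ||||
|     #   ALTER TABLE "app_model" ALTER COLUMN "id" ADD | ||||
|     #       GENERATED BY DEFAULT AS IDENTITY; | ||||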
|  | ||||
|     def _alter_column_collation_sql( | ||||
|         self, model, new_field, new_type, new_collation, old_field | ||||
|     ): | ||||
|         sql = self.sql_alter_column_collate | ||||
|         # Cast when data type changed. | ||||
|         if using_sql := self._using_sql(new_field, old_field): | ||||
|             sql += using_sql | ||||
|         return ( | ||||
|             sql | ||||
|             % { | ||||
|                 "column": self.quote_name(new_field.column), | ||||
|                 "type": new_type, | ||||
|                 "collation": " " + self._collate_sql(new_collation) | ||||
|                 if new_collation | ||||
|                 else "", | ||||
|             }, | ||||
|             [], | ||||
|         ) | ||||
|  | ||||
|     def _alter_field( | ||||
|         self, | ||||
|         model, | ||||
|         old_field, | ||||
|         new_field, | ||||
|         old_type, | ||||
|         new_type, | ||||
|         old_db_params, | ||||
|         new_db_params, | ||||
|         strict=False, | ||||
|     ): | ||||
|         super()._alter_field( | ||||
|             model, | ||||
|             old_field, | ||||
|             new_field, | ||||
|             old_type, | ||||
|             new_type, | ||||
|             old_db_params, | ||||
|             new_db_params, | ||||
|             strict, | ||||
|         ) | ||||
|         # Added an index? Create any PostgreSQL-specific indexes. | ||||
|         if (not (old_field.db_index or old_field.unique) and new_field.db_index) or ( | ||||
|             not old_field.unique and new_field.unique | ||||
|         ): | ||||
|             like_index_statement = self._create_like_index_sql(model, new_field) | ||||
|             if like_index_statement is not None: | ||||
|                 self.execute(like_index_statement) | ||||
|  | ||||
|         # Removed an index? Drop any PostgreSQL-specific indexes. | ||||
|         if old_field.unique and not (new_field.db_index or new_field.unique): | ||||
|             index_to_remove = self._create_index_name( | ||||
|                 model._meta.db_table, [old_field.column], suffix="_like" | ||||
|             ) | ||||
|             self.execute(self._delete_index_sql(model, index_to_remove)) | ||||
|  | ||||
|     def _index_columns(self, table, columns, col_suffixes, opclasses): | ||||
|         if opclasses: | ||||
|             return IndexColumns( | ||||
|                 table, | ||||
|                 columns, | ||||
|                 self.quote_name, | ||||
|                 col_suffixes=col_suffixes, | ||||
|                 opclasses=opclasses, | ||||
|             ) | ||||
|         return super()._index_columns(table, columns, col_suffixes, opclasses) | ||||
|  | ||||
|     def add_index(self, model, index, concurrently=False): | ||||
|         self.execute( | ||||
|             index.create_sql(model, self, concurrently=concurrently), params=None | ||||
|         ) | ||||
|  | ||||
|     def remove_index(self, model, index, concurrently=False): | ||||
|         self.execute(index.remove_sql(model, self, concurrently=concurrently)) | ||||
|  | ||||
|     def _delete_index_sql(self, model, name, sql=None, concurrently=False): | ||||
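|         # The sql argument from the base class is overridden here so the | ||||
|         # concurrent DROP INDEX variant can be selected. | ||||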
|         sql = ( | ||||
|             self.sql_delete_index_concurrently | ||||
|             if concurrently | ||||
|             else self.sql_delete_index | ||||
|         ) | ||||
|         return super()._delete_index_sql(model, name, sql) | ||||
|  | ||||
|     def _create_index_sql( | ||||
|         self, | ||||
|         model, | ||||
|         *, | ||||
|         fields=None, | ||||
|         name=None, | ||||
|         suffix="", | ||||
|         using="", | ||||
|         db_tablespace=None, | ||||
|         col_suffixes=(), | ||||
|         sql=None, | ||||
|         opclasses=(), | ||||
|         condition=None, | ||||
|         concurrently=False, | ||||
|         include=None, | ||||
|         expressions=None, | ||||
|     ): | ||||
|         sql = sql or ( | ||||
|             self.sql_create_index | ||||
|             if not concurrently | ||||
|             else self.sql_create_index_concurrently | ||||
|         ) | ||||
|         return super()._create_index_sql( | ||||
|             model, | ||||
|             fields=fields, | ||||
|             name=name, | ||||
|             suffix=suffix, | ||||
|             using=using, | ||||
|             db_tablespace=db_tablespace, | ||||
|             col_suffixes=col_suffixes, | ||||
|             sql=sql, | ||||
|             opclasses=opclasses, | ||||
|             condition=condition, | ||||
|             include=include, | ||||
|             expressions=expressions, | ||||
|         ) | ||||
| @ -0,0 +1,3 @@ | ||||
| from django.dispatch import Signal | ||||
|  | ||||
| connection_created = Signal() | ||||
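|  | ||||
|  | ||||
| # A minimal usage sketch (handler name and SQL assumed): receivers subscribe | ||||
| # via the standard Signal API, e.g. to run per-connection setup: | ||||
| # | ||||
| #   def on_connection_created(sender, connection, **kwargs): | ||||
| #       connection.cursor().execute("SET search_path = public")  # PostgreSQL | ||||
| # | ||||
| #   connection_created.connect(on_connection_created) | ||||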
| @ -0,0 +1,512 @@ | ||||
| """ | ||||
| Implementations of SQL functions for SQLite. | ||||
| """ | ||||
| import functools | ||||
| import random | ||||
| import statistics | ||||
| from datetime import timedelta | ||||
| from hashlib import sha1, sha224, sha256, sha384, sha512 | ||||
| from math import ( | ||||
|     acos, | ||||
|     asin, | ||||
|     atan, | ||||
|     atan2, | ||||
|     ceil, | ||||
|     cos, | ||||
|     degrees, | ||||
|     exp, | ||||
|     floor, | ||||
|     fmod, | ||||
|     log, | ||||
|     pi, | ||||
|     radians, | ||||
|     sin, | ||||
|     sqrt, | ||||
|     tan, | ||||
| ) | ||||
| from re import search as re_search | ||||
|  | ||||
| from django.db.backends.base.base import timezone_constructor | ||||
| from django.db.backends.utils import ( | ||||
|     split_tzname_delta, | ||||
|     typecast_time, | ||||
|     typecast_timestamp, | ||||
| ) | ||||
| from django.utils import timezone | ||||
| from django.utils.crypto import md5 | ||||
| from django.utils.duration import duration_microseconds | ||||
|  | ||||
|  | ||||
| def register(connection): | ||||
|     create_deterministic_function = functools.partial( | ||||
|         connection.create_function, | ||||
|         deterministic=True, | ||||
|     ) | ||||
|     create_deterministic_function("django_date_extract", 2, _sqlite_datetime_extract) | ||||
|     create_deterministic_function("django_date_trunc", 4, _sqlite_date_trunc) | ||||
|     create_deterministic_function( | ||||
|         "django_datetime_cast_date", 3, _sqlite_datetime_cast_date | ||||
|     ) | ||||
|     create_deterministic_function( | ||||
|         "django_datetime_cast_time", 3, _sqlite_datetime_cast_time | ||||
|     ) | ||||
|     create_deterministic_function( | ||||
|         "django_datetime_extract", 4, _sqlite_datetime_extract | ||||
|     ) | ||||
|     create_deterministic_function("django_datetime_trunc", 4, _sqlite_datetime_trunc) | ||||
|     create_deterministic_function("django_time_extract", 2, _sqlite_time_extract) | ||||
|     create_deterministic_function("django_time_trunc", 4, _sqlite_time_trunc) | ||||
|     create_deterministic_function("django_time_diff", 2, _sqlite_time_diff) | ||||
|     create_deterministic_function("django_timestamp_diff", 2, _sqlite_timestamp_diff) | ||||
|     create_deterministic_function("django_format_dtdelta", 3, _sqlite_format_dtdelta) | ||||
|     create_deterministic_function("regexp", 2, _sqlite_regexp) | ||||
|     create_deterministic_function("BITXOR", 2, _sqlite_bitxor) | ||||
|     create_deterministic_function("COT", 1, _sqlite_cot) | ||||
|     create_deterministic_function("LPAD", 3, _sqlite_lpad) | ||||
|     create_deterministic_function("MD5", 1, _sqlite_md5) | ||||
|     create_deterministic_function("REPEAT", 2, _sqlite_repeat) | ||||
|     create_deterministic_function("REVERSE", 1, _sqlite_reverse) | ||||
|     create_deterministic_function("RPAD", 3, _sqlite_rpad) | ||||
|     create_deterministic_function("SHA1", 1, _sqlite_sha1) | ||||
|     create_deterministic_function("SHA224", 1, _sqlite_sha224) | ||||
|     create_deterministic_function("SHA256", 1, _sqlite_sha256) | ||||
|     create_deterministic_function("SHA384", 1, _sqlite_sha384) | ||||
|     create_deterministic_function("SHA512", 1, _sqlite_sha512) | ||||
|     create_deterministic_function("SIGN", 1, _sqlite_sign) | ||||
|     # Don't use the built-in RANDOM() function because it returns a value | ||||
|     # in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1). | ||||
|     connection.create_function("RAND", 0, random.random) | ||||
|     connection.create_aggregate("STDDEV_POP", 1, StdDevPop) | ||||
|     connection.create_aggregate("STDDEV_SAMP", 1, StdDevSamp) | ||||
|     connection.create_aggregate("VAR_POP", 1, VarPop) | ||||
|     connection.create_aggregate("VAR_SAMP", 1, VarSamp) | ||||
|     # Some math functions are enabled by default in SQLite 3.35+. | ||||
|     sql = "select sqlite_compileoption_used('ENABLE_MATH_FUNCTIONS')" | ||||
|     if not connection.execute(sql).fetchone()[0]: | ||||
|         create_deterministic_function("ACOS", 1, _sqlite_acos) | ||||
|         create_deterministic_function("ASIN", 1, _sqlite_asin) | ||||
|         create_deterministic_function("ATAN", 1, _sqlite_atan) | ||||
|         create_deterministic_function("ATAN2", 2, _sqlite_atan2) | ||||
|         create_deterministic_function("CEILING", 1, _sqlite_ceiling) | ||||
|         create_deterministic_function("COS", 1, _sqlite_cos) | ||||
|         create_deterministic_function("DEGREES", 1, _sqlite_degrees) | ||||
|         create_deterministic_function("EXP", 1, _sqlite_exp) | ||||
|         create_deterministic_function("FLOOR", 1, _sqlite_floor) | ||||
|         create_deterministic_function("LN", 1, _sqlite_ln) | ||||
|         create_deterministic_function("LOG", 2, _sqlite_log) | ||||
|         create_deterministic_function("MOD", 2, _sqlite_mod) | ||||
|         create_deterministic_function("PI", 0, _sqlite_pi) | ||||
|         create_deterministic_function("POWER", 2, _sqlite_power) | ||||
|         create_deterministic_function("RADIANS", 1, _sqlite_radians) | ||||
|         create_deterministic_function("SIN", 1, _sqlite_sin) | ||||
|         create_deterministic_function("SQRT", 1, _sqlite_sqrt) | ||||
|         create_deterministic_function("TAN", 1, _sqlite_tan) | ||||
|  | ||||
|  | ||||
| def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None): | ||||
|     if dt is None: | ||||
|         return None | ||||
|     try: | ||||
|         dt = typecast_timestamp(dt) | ||||
|     except (TypeError, ValueError): | ||||
|         return None | ||||
|     if conn_tzname: | ||||
|         dt = dt.replace(tzinfo=timezone_constructor(conn_tzname)) | ||||
|     if tzname is not None and tzname != conn_tzname: | ||||
|         tzname, sign, offset = split_tzname_delta(tzname) | ||||
|         if offset: | ||||
|             hours, minutes = offset.split(":") | ||||
|             offset_delta = timedelta(hours=int(hours), minutes=int(minutes)) | ||||
|             dt += offset_delta if sign == "+" else -offset_delta | ||||
|         dt = timezone.localtime(dt, timezone_constructor(tzname)) | ||||
|     return dt | ||||
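|  | ||||
| # For example (zone names assumed): | ||||
| #   _sqlite_datetime_parse("2023-01-01 12:00:00", "America/Chicago", "UTC") | ||||
| # attaches UTC as the connection zone, then converts, giving | ||||
| # 2023-01-01 06:00:00-06:00. | ||||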
|  | ||||
|  | ||||
| def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname): | ||||
|     dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) | ||||
|     if dt is None: | ||||
|         return None | ||||
|     if lookup_type == "year": | ||||
|         return f"{dt.year:04d}-01-01" | ||||
|     elif lookup_type == "quarter": | ||||
|         month_in_quarter = dt.month - (dt.month - 1) % 3 | ||||
|         return f"{dt.year:04d}-{month_in_quarter:02d}-01" | ||||
|     elif lookup_type == "month": | ||||
|         return f"{dt.year:04d}-{dt.month:02d}-01" | ||||
|     elif lookup_type == "week": | ||||
|         dt -= timedelta(days=dt.weekday()) | ||||
|         return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}" | ||||
|     elif lookup_type == "day": | ||||
|         return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}" | ||||
|     raise ValueError(f"Unsupported lookup type: {lookup_type!r}") | ||||
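|  | ||||
| # For example: _sqlite_date_trunc("quarter", "2023-05-15", None, None) | ||||
| # returns "2023-04-01" (May belongs to the quarter starting in April). | ||||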
|  | ||||
|  | ||||
| def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname): | ||||
|     if dt is None: | ||||
|         return None | ||||
|     dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname) | ||||
|     if dt_parsed is None: | ||||
|         try: | ||||
|             dt = typecast_time(dt) | ||||
|         except (ValueError, TypeError): | ||||
|             return None | ||||
|     else: | ||||
|         dt = dt_parsed | ||||
|     if lookup_type == "hour": | ||||
|         return f"{dt.hour:02d}:00:00" | ||||
|     elif lookup_type == "minute": | ||||
|         return f"{dt.hour:02d}:{dt.minute:02d}:00" | ||||
|     elif lookup_type == "second": | ||||
|         return f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}" | ||||
|     raise ValueError(f"Unsupported lookup type: {lookup_type!r}") | ||||
|  | ||||
|  | ||||
| def _sqlite_datetime_cast_date(dt, tzname, conn_tzname): | ||||
|     dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) | ||||
|     if dt is None: | ||||
|         return None | ||||
|     return dt.date().isoformat() | ||||
|  | ||||
|  | ||||
| def _sqlite_datetime_cast_time(dt, tzname, conn_tzname): | ||||
|     dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) | ||||
|     if dt is None: | ||||
|         return None | ||||
|     return dt.time().isoformat() | ||||
|  | ||||
|  | ||||
| def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None): | ||||
|     dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) | ||||
|     if dt is None: | ||||
|         return None | ||||
|     if lookup_type == "week_day": | ||||
|         return (dt.isoweekday() % 7) + 1 | ||||
|     elif lookup_type == "iso_week_day": | ||||
|         return dt.isoweekday() | ||||
|     elif lookup_type == "week": | ||||
|         return dt.isocalendar()[1] | ||||
|     elif lookup_type == "quarter": | ||||
|         return ceil(dt.month / 3) | ||||
|     elif lookup_type == "iso_year": | ||||
|         return dt.isocalendar()[0] | ||||
|     else: | ||||
|         return getattr(dt, lookup_type) | ||||
|  | ||||
|  | ||||
| def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname): | ||||
|     dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) | ||||
|     if dt is None: | ||||
|         return None | ||||
|     if lookup_type == "year": | ||||
|         return f"{dt.year:04d}-01-01 00:00:00" | ||||
|     elif lookup_type == "quarter": | ||||
|         month_in_quarter = dt.month - (dt.month - 1) % 3 | ||||
|         return f"{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00" | ||||
|     elif lookup_type == "month": | ||||
|         return f"{dt.year:04d}-{dt.month:02d}-01 00:00:00" | ||||
|     elif lookup_type == "week": | ||||
|         dt -= timedelta(days=dt.weekday()) | ||||
|         return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00" | ||||
|     elif lookup_type == "day": | ||||
|         return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00" | ||||
|     elif lookup_type == "hour": | ||||
|         return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00" | ||||
|     elif lookup_type == "minute": | ||||
|         return ( | ||||
|             f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} " | ||||
|             f"{dt.hour:02d}:{dt.minute:02d}:00" | ||||
|         ) | ||||
|     elif lookup_type == "second": | ||||
|         return ( | ||||
|             f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} " | ||||
|             f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}" | ||||
|         ) | ||||
|     raise ValueError(f"Unsupported lookup type: {lookup_type!r}") | ||||
|  | ||||
|  | ||||
| def _sqlite_time_extract(lookup_type, dt): | ||||
|     if dt is None: | ||||
|         return None | ||||
|     try: | ||||
|         dt = typecast_time(dt) | ||||
|     except (ValueError, TypeError): | ||||
|         return None | ||||
|     return getattr(dt, lookup_type) | ||||
|  | ||||
|  | ||||
| def _sqlite_prepare_dtdelta_param(connector, param): | ||||
|     if connector in ["+", "-"]: | ||||
|         if isinstance(param, int): | ||||
|             # Integers arrive as a number of microseconds. | ||||
|             return timedelta(0, 0, param) | ||||
|         else: | ||||
|             return typecast_timestamp(param) | ||||
|     return param | ||||
|  | ||||
|  | ||||
| def _sqlite_format_dtdelta(connector, lhs, rhs): | ||||
|     """ | ||||
|     LHS and RHS can be either: | ||||
|     - An integer number of microseconds | ||||
|     - A string representing a datetime | ||||
|     - A scalar value, e.g. float | ||||
|     """ | ||||
|     if connector is None or lhs is None or rhs is None: | ||||
|         return None | ||||
|     connector = connector.strip() | ||||
|     try: | ||||
|         real_lhs = _sqlite_prepare_dtdelta_param(connector, lhs) | ||||
|         real_rhs = _sqlite_prepare_dtdelta_param(connector, rhs) | ||||
|     except (ValueError, TypeError): | ||||
|         return None | ||||
|     if connector == "+": | ||||
|         # typecast_timestamp() returns a date or a datetime without timezone. | ||||
|         # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]" | ||||
|         out = str(real_lhs + real_rhs) | ||||
|     elif connector == "-": | ||||
|         out = str(real_lhs - real_rhs) | ||||
|     elif connector == "*": | ||||
|         out = real_lhs * real_rhs | ||||
|     else: | ||||
|         out = real_lhs / real_rhs | ||||
|     return out | ||||
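|  | ||||
| # For example, adding a one-hour delta (given as microseconds) to a datetime | ||||
| # string: | ||||
| #   _sqlite_format_dtdelta("+", "2023-01-01 10:00:00", 3600000000) | ||||
| # returns "2023-01-01 11:00:00". | ||||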
|  | ||||
|  | ||||
| def _sqlite_time_diff(lhs, rhs): | ||||
|     if lhs is None or rhs is None: | ||||
|         return None | ||||
|     left = typecast_time(lhs) | ||||
|     right = typecast_time(rhs) | ||||
|     return ( | ||||
|         (left.hour * 60 * 60 * 1000000) | ||||
|         + (left.minute * 60 * 1000000) | ||||
|         + (left.second * 1000000) | ||||
|         + (left.microsecond) | ||||
|         - (right.hour * 60 * 60 * 1000000) | ||||
|         - (right.minute * 60 * 1000000) | ||||
|         - (right.second * 1000000) | ||||
|         - (right.microsecond) | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _sqlite_timestamp_diff(lhs, rhs): | ||||
|     if lhs is None or rhs is None: | ||||
|         return None | ||||
|     left = typecast_timestamp(lhs) | ||||
|     right = typecast_timestamp(rhs) | ||||
|     return duration_microseconds(left - right) | ||||
|  | ||||
|  | ||||
| def _sqlite_regexp(pattern, string): | ||||
|     if pattern is None or string is None: | ||||
|         return None | ||||
|     if not isinstance(string, str): | ||||
|         string = str(string) | ||||
|     return bool(re_search(pattern, string)) | ||||
|  | ||||
|  | ||||
| def _sqlite_acos(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return acos(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_asin(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return asin(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_atan(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return atan(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_atan2(y, x): | ||||
|     if y is None or x is None: | ||||
|         return None | ||||
|     return atan2(y, x) | ||||
|  | ||||
|  | ||||
| def _sqlite_bitxor(x, y): | ||||
|     if x is None or y is None: | ||||
|         return None | ||||
|     return x ^ y | ||||
|  | ||||
|  | ||||
| def _sqlite_ceiling(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return ceil(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_cos(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return cos(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_cot(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return 1 / tan(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_degrees(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return degrees(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_exp(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return exp(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_floor(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return floor(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_ln(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return log(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_log(base, x): | ||||
|     if base is None or x is None: | ||||
|         return None | ||||
|     # Arguments reversed to match SQL standard. | ||||
|     return log(x, base) | ||||
|  | ||||
|  | ||||
| def _sqlite_lpad(text, length, fill_text): | ||||
|     if text is None or length is None or fill_text is None: | ||||
|         return None | ||||
|     delta = length - len(text) | ||||
|     if delta <= 0: | ||||
|         return text[:length] | ||||
|     return (fill_text * length)[:delta] + text | ||||
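|  | ||||
| # For example: _sqlite_lpad("abc", 6, "*") returns "***abc", while a length | ||||
| # shorter than the text truncates it: _sqlite_lpad("abcdef", 3, "*") -> "abc". | ||||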
|  | ||||
|  | ||||
| def _sqlite_md5(text): | ||||
|     if text is None: | ||||
|         return None | ||||
|     return md5(text.encode()).hexdigest() | ||||
|  | ||||
|  | ||||
| def _sqlite_mod(x, y): | ||||
|     if x is None or y is None: | ||||
|         return None | ||||
|     return fmod(x, y) | ||||
|  | ||||
|  | ||||
| def _sqlite_pi(): | ||||
|     return pi | ||||
|  | ||||
|  | ||||
| def _sqlite_power(x, y): | ||||
|     if x is None or y is None: | ||||
|         return None | ||||
|     return x**y | ||||
|  | ||||
|  | ||||
| def _sqlite_radians(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return radians(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_repeat(text, count): | ||||
|     if text is None or count is None: | ||||
|         return None | ||||
|     return text * count | ||||
|  | ||||
|  | ||||
| def _sqlite_reverse(text): | ||||
|     if text is None: | ||||
|         return None | ||||
|     return text[::-1] | ||||
|  | ||||
|  | ||||
| def _sqlite_rpad(text, length, fill_text): | ||||
|     if text is None or length is None or fill_text is None: | ||||
|         return None | ||||
|     return (text + fill_text * length)[:length] | ||||
|  | ||||
|  | ||||
| def _sqlite_sha1(text): | ||||
|     if text is None: | ||||
|         return None | ||||
|     return sha1(text.encode()).hexdigest() | ||||
|  | ||||
|  | ||||
| def _sqlite_sha224(text): | ||||
|     if text is None: | ||||
|         return None | ||||
|     return sha224(text.encode()).hexdigest() | ||||
|  | ||||
|  | ||||
| def _sqlite_sha256(text): | ||||
|     if text is None: | ||||
|         return None | ||||
|     return sha256(text.encode()).hexdigest() | ||||
|  | ||||
|  | ||||
| def _sqlite_sha384(text): | ||||
|     if text is None: | ||||
|         return None | ||||
|     return sha384(text.encode()).hexdigest() | ||||
|  | ||||
|  | ||||
| def _sqlite_sha512(text): | ||||
|     if text is None: | ||||
|         return None | ||||
|     return sha512(text.encode()).hexdigest() | ||||
|  | ||||
|  | ||||
| def _sqlite_sign(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return (x > 0) - (x < 0) | ||||
|  | ||||
|  | ||||
| def _sqlite_sin(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return sin(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_sqrt(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return sqrt(x) | ||||
|  | ||||
|  | ||||
| def _sqlite_tan(x): | ||||
|     if x is None: | ||||
|         return None | ||||
|     return tan(x) | ||||
|  | ||||
|  | ||||
| class ListAggregate(list): | ||||
|     step = list.append | ||||
|  | ||||
|  | ||||
| class StdDevPop(ListAggregate): | ||||
|     finalize = statistics.pstdev | ||||
|  | ||||
|  | ||||
| class StdDevSamp(ListAggregate): | ||||
|     finalize = statistics.stdev | ||||
|  | ||||
|  | ||||
| class VarPop(ListAggregate): | ||||
|     finalize = statistics.pvariance | ||||
|  | ||||
|  | ||||
| class VarSamp(ListAggregate): | ||||
|     finalize = statistics.variance | ||||
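|  | ||||
|  | ||||
| # sqlite3 aggregates call step() once per row and finalize() at the end, so | ||||
| # binding list.append and the statistics functions as plain class attributes | ||||
| # is enough to satisfy that protocol. A usage sketch (table name assumed): | ||||
| # | ||||
| #   conn.create_aggregate("STDDEV_POP", 1, StdDevPop) | ||||
| #   conn.execute("SELECT STDDEV_POP(price) FROM items").fetchone() | ||||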
| @ -0,0 +1,347 @@ | ||||
| """ | ||||
| SQLite backend for the sqlite3 module in the standard library. | ||||
| """ | ||||
| import datetime | ||||
| import decimal | ||||
| import warnings | ||||
| from collections.abc import Mapping | ||||
| from itertools import chain, tee | ||||
| from sqlite3 import dbapi2 as Database | ||||
|  | ||||
| from django.core.exceptions import ImproperlyConfigured | ||||
| from django.db import IntegrityError | ||||
| from django.db.backends.base.base import BaseDatabaseWrapper | ||||
| from django.utils.asyncio import async_unsafe | ||||
| from django.utils.dateparse import parse_date, parse_datetime, parse_time | ||||
| from django.utils.regex_helper import _lazy_re_compile | ||||
|  | ||||
| from ._functions import register as register_functions | ||||
| from .client import DatabaseClient | ||||
| from .creation import DatabaseCreation | ||||
| from .features import DatabaseFeatures | ||||
| from .introspection import DatabaseIntrospection | ||||
| from .operations import DatabaseOperations | ||||
| from .schema import DatabaseSchemaEditor | ||||
|  | ||||
|  | ||||
| def decoder(conv_func): | ||||
|     """ | ||||
|     Convert bytestrings from Python's sqlite3 interface to regular strings. | ||||
|     """ | ||||
|     return lambda s: conv_func(s.decode()) | ||||
|  | ||||
|  | ||||
| def adapt_date(val): | ||||
|     return val.isoformat() | ||||
|  | ||||
|  | ||||
| def adapt_datetime(val): | ||||
|     return val.isoformat(" ") | ||||
|  | ||||
|  | ||||
| Database.register_converter("bool", b"1".__eq__) | ||||
| Database.register_converter("date", decoder(parse_date)) | ||||
| Database.register_converter("time", decoder(parse_time)) | ||||
| Database.register_converter("datetime", decoder(parse_datetime)) | ||||
| Database.register_converter("timestamp", decoder(parse_datetime)) | ||||
|  | ||||
| Database.register_adapter(decimal.Decimal, str) | ||||
| Database.register_adapter(datetime.date, adapt_date) | ||||
| Database.register_adapter(datetime.datetime, adapt_datetime) | ||||
|  | ||||
|  | ||||
| class DatabaseWrapper(BaseDatabaseWrapper): | ||||
|     vendor = "sqlite" | ||||
|     display_name = "SQLite" | ||||
|     # SQLite doesn't actually support most of these types, but it "does the right | ||||
|     # thing" given more verbose field definitions, so leave them as is so that | ||||
|     # schema inspection is more useful. | ||||
|     data_types = { | ||||
|         "AutoField": "integer", | ||||
|         "BigAutoField": "integer", | ||||
|         "BinaryField": "BLOB", | ||||
|         "BooleanField": "bool", | ||||
|         "CharField": "varchar(%(max_length)s)", | ||||
|         "DateField": "date", | ||||
|         "DateTimeField": "datetime", | ||||
|         "DecimalField": "decimal", | ||||
|         "DurationField": "bigint", | ||||
|         "FileField": "varchar(%(max_length)s)", | ||||
|         "FilePathField": "varchar(%(max_length)s)", | ||||
|         "FloatField": "real", | ||||
|         "IntegerField": "integer", | ||||
|         "BigIntegerField": "bigint", | ||||
|         "IPAddressField": "char(15)", | ||||
|         "GenericIPAddressField": "char(39)", | ||||
|         "JSONField": "text", | ||||
|         "OneToOneField": "integer", | ||||
|         "PositiveBigIntegerField": "bigint unsigned", | ||||
|         "PositiveIntegerField": "integer unsigned", | ||||
|         "PositiveSmallIntegerField": "smallint unsigned", | ||||
|         "SlugField": "varchar(%(max_length)s)", | ||||
|         "SmallAutoField": "integer", | ||||
|         "SmallIntegerField": "smallint", | ||||
|         "TextField": "text", | ||||
|         "TimeField": "time", | ||||
|         "UUIDField": "char(32)", | ||||
|     } | ||||
|     data_type_check_constraints = { | ||||
|         "PositiveBigIntegerField": '"%(column)s" >= 0', | ||||
|         "JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)', | ||||
|         "PositiveIntegerField": '"%(column)s" >= 0', | ||||
|         "PositiveSmallIntegerField": '"%(column)s" >= 0', | ||||
|     } | ||||
|     data_types_suffix = { | ||||
|         "AutoField": "AUTOINCREMENT", | ||||
|         "BigAutoField": "AUTOINCREMENT", | ||||
|         "SmallAutoField": "AUTOINCREMENT", | ||||
|     } | ||||
|     # SQLite requires LIKE statements to include an ESCAPE clause if the value | ||||
|     # being escaped has a percent or underscore in it. | ||||
|     # See https://www.sqlite.org/lang_expr.html for an explanation. | ||||
|     operators = { | ||||
|         "exact": "= %s", | ||||
|         "iexact": "LIKE %s ESCAPE '\\'", | ||||
|         "contains": "LIKE %s ESCAPE '\\'", | ||||
|         "icontains": "LIKE %s ESCAPE '\\'", | ||||
|         "regex": "REGEXP %s", | ||||
|         "iregex": "REGEXP '(?i)' || %s", | ||||
|         "gt": "> %s", | ||||
|         "gte": ">= %s", | ||||
|         "lt": "< %s", | ||||
|         "lte": "<= %s", | ||||
|         "startswith": "LIKE %s ESCAPE '\\'", | ||||
|         "endswith": "LIKE %s ESCAPE '\\'", | ||||
|         "istartswith": "LIKE %s ESCAPE '\\'", | ||||
|         "iendswith": "LIKE %s ESCAPE '\\'", | ||||
|     } | ||||
|  | ||||
|     # The patterns below are used to generate SQL pattern lookup clauses when | ||||
|     # the right-hand side of the lookup isn't a raw string (it might be an expression | ||||
|     # or the result of a bilateral transformation). | ||||
|     # In those cases, special characters for LIKE operators (e.g. \, *, _) | ||||
|     # should be escaped on the database side. | ||||
|     # | ||||
|     # Note: we use str.format() here for readability as '%' is used as a wildcard for | ||||
|     # the LIKE operator. | ||||
|     pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')" | ||||
|     pattern_ops = { | ||||
|         "contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'", | ||||
|         "icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'", | ||||
|         "startswith": r"LIKE {} || '%%' ESCAPE '\'", | ||||
|         "istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'", | ||||
|         "endswith": r"LIKE '%%' || {} ESCAPE '\'", | ||||
|         "iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'", | ||||
|     } | ||||
|  | ||||
|     Database = Database | ||||
|     SchemaEditorClass = DatabaseSchemaEditor | ||||
|     # Classes instantiated in __init__(). | ||||
|     client_class = DatabaseClient | ||||
|     creation_class = DatabaseCreation | ||||
|     features_class = DatabaseFeatures | ||||
|     introspection_class = DatabaseIntrospection | ||||
|     ops_class = DatabaseOperations | ||||
|  | ||||
|     def get_connection_params(self): | ||||
|         settings_dict = self.settings_dict | ||||
|         if not settings_dict["NAME"]: | ||||
|             raise ImproperlyConfigured( | ||||
|                 "settings.DATABASES is improperly configured. " | ||||
|                 "Please supply the NAME value." | ||||
|             ) | ||||
|         kwargs = { | ||||
|             "database": settings_dict["NAME"], | ||||
|             "detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES, | ||||
|             **settings_dict["OPTIONS"], | ||||
|         } | ||||
|         # Always allow the underlying SQLite connection to be shareable | ||||
|         # between multiple threads. The safeguarding will be handled at a | ||||
|         # higher level by the `BaseDatabaseWrapper.allow_thread_sharing` | ||||
|         # property. This is necessary as the shareability is disabled by | ||||
|         # default in sqlite3 and it cannot be changed once a connection is | ||||
|         # opened. | ||||
|         if "check_same_thread" in kwargs and kwargs["check_same_thread"]: | ||||
|             warnings.warn( | ||||
|                 "The `check_same_thread` option was provided and set to " | ||||
|                 "True. It will be overridden with False. Use the " | ||||
|                 "`DatabaseWrapper.allow_thread_sharing` property instead " | ||||
|                 "for controlling thread shareability.", | ||||
|                 RuntimeWarning, | ||||
|             ) | ||||
|         kwargs.update({"check_same_thread": False, "uri": True}) | ||||
|         return kwargs | ||||
|  | ||||
|     def get_database_version(self): | ||||
|         return self.Database.sqlite_version_info | ||||
|  | ||||
|     @async_unsafe | ||||
|     def get_new_connection(self, conn_params): | ||||
|         conn = Database.connect(**conn_params) | ||||
|         register_functions(conn) | ||||
|  | ||||
|         conn.execute("PRAGMA foreign_keys = ON") | ||||
|         # The macOS bundled SQLite defaults legacy_alter_table to ON, which | ||||
|         # prevents atomic table renames (the | ||||
|         # supports_atomic_references_rename feature). | ||||
|         conn.execute("PRAGMA legacy_alter_table = OFF") | ||||
|         return conn | ||||
|  | ||||
|     def create_cursor(self, name=None): | ||||
|         return self.connection.cursor(factory=SQLiteCursorWrapper) | ||||
|  | ||||
|     @async_unsafe | ||||
|     def close(self): | ||||
|         self.validate_thread_sharing() | ||||
|         # If the database is in memory, closing the connection destroys it. | ||||
|         # To prevent accidental data loss, ignore close requests on an | ||||
|         # in-memory db. | ||||
|         if not self.is_in_memory_db(): | ||||
|             BaseDatabaseWrapper.close(self) | ||||
|  | ||||
|     def _savepoint_allowed(self): | ||||
|         # When 'isolation_level' is not None, sqlite3 commits before each | ||||
|         # savepoint; it's a bug. When it is None, savepoints don't make sense | ||||
|         # because autocommit is enabled. The only exception is inside 'atomic' | ||||
|         # blocks. To work around that bug, on SQLite, 'atomic' starts a | ||||
|         # transaction explicitly rather than simply disabling autocommit. | ||||
|         return self.in_atomic_block | ||||
|  | ||||
|     def _set_autocommit(self, autocommit): | ||||
|         if autocommit: | ||||
|             level = None | ||||
|         else: | ||||
|             # sqlite3's internal default is ''. It's different from None. | ||||
|             # See Modules/_sqlite/connection.c. | ||||
|             level = "" | ||||
|         # 'isolation_level' is a misleading API. | ||||
|         # SQLite always runs at the SERIALIZABLE isolation level. | ||||
|         with self.wrap_database_errors: | ||||
|             self.connection.isolation_level = level | ||||
|  | ||||
|     def disable_constraint_checking(self): | ||||
|         with self.cursor() as cursor: | ||||
|             cursor.execute("PRAGMA foreign_keys = OFF") | ||||
|             # Foreign key constraints cannot be turned off while in a multi- | ||||
|             # statement transaction. Fetch the current state of the pragma | ||||
|             # to determine if constraints are effectively disabled. | ||||
|             enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0] | ||||
|         return not bool(enabled) | ||||
|  | ||||
|     def enable_constraint_checking(self): | ||||
|         with self.cursor() as cursor: | ||||
|             cursor.execute("PRAGMA foreign_keys = ON") | ||||
|  | ||||
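|     # A typical sequence (sketch, using only the methods defined here): turn | ||||
|     # checks off, make bulk changes, then verify and re-enable, e.g. | ||||
|     #   if connection.disable_constraint_checking(): | ||||
|     #       try: | ||||
|     #           ...  # load data that may reference rows created later | ||||
|     #           connection.check_constraints() | ||||
|     #       finally: | ||||
|     #           connection.enable_constraint_checking() | ||||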
|     def check_constraints(self, table_names=None): | ||||
|         """ | ||||
|         Check each table name in `table_names` for rows with invalid foreign | ||||
|         key references. This method is intended to be used in conjunction with | ||||
|         `disable_constraint_checking()` and `enable_constraint_checking()`, to | ||||
|         determine if rows with invalid references were entered while constraint | ||||
|         checks were off. | ||||
|         """ | ||||
|         with self.cursor() as cursor: | ||||
|             if table_names is None: | ||||
|                 violations = cursor.execute("PRAGMA foreign_key_check").fetchall() | ||||
|             else: | ||||
|                 violations = chain.from_iterable( | ||||
|                     cursor.execute( | ||||
|                         "PRAGMA foreign_key_check(%s)" % self.ops.quote_name(table_name) | ||||
|                     ).fetchall() | ||||
|                     for table_name in table_names | ||||
|                 ) | ||||
|             # See https://www.sqlite.org/pragma.html#pragma_foreign_key_check | ||||
|             for ( | ||||
|                 table_name, | ||||
|                 rowid, | ||||
|                 referenced_table_name, | ||||
|                 foreign_key_index, | ||||
|             ) in violations: | ||||
|                 foreign_key = cursor.execute( | ||||
|                     "PRAGMA foreign_key_list(%s)" % self.ops.quote_name(table_name) | ||||
|                 ).fetchall()[foreign_key_index] | ||||
|                 column_name, referenced_column_name = foreign_key[3:5] | ||||
|                 primary_key_column_name = self.introspection.get_primary_key_column( | ||||
|                     cursor, table_name | ||||
|                 ) | ||||
|                 primary_key_value, bad_value = cursor.execute( | ||||
|                     "SELECT %s, %s FROM %s WHERE rowid = %%s" | ||||
|                     % ( | ||||
|                         self.ops.quote_name(primary_key_column_name), | ||||
|                         self.ops.quote_name(column_name), | ||||
|                         self.ops.quote_name(table_name), | ||||
|                     ), | ||||
|                     (rowid,), | ||||
|                 ).fetchone() | ||||
|                 raise IntegrityError( | ||||
|                     "The row in table '%s' with primary key '%s' has an " | ||||
|                     "invalid foreign key: %s.%s contains a value '%s' that " | ||||
|                     "does not have a corresponding value in %s.%s." | ||||
|                     % ( | ||||
|                         table_name, | ||||
|                         primary_key_value, | ||||
|                         table_name, | ||||
|                         column_name, | ||||
|                         bad_value, | ||||
|                         referenced_table_name, | ||||
|                         referenced_column_name, | ||||
|                     ) | ||||
|                 ) | ||||
|  | ||||
|     def is_usable(self): | ||||
|         return True | ||||
|  | ||||
|     def _start_transaction_under_autocommit(self): | ||||
|         """ | ||||
|         Start a transaction explicitly in autocommit mode. | ||||
|  | ||||
|         Staying in autocommit mode works around a bug of sqlite3 that breaks | ||||
|         savepoints when autocommit is disabled. | ||||
|         """ | ||||
|         self.cursor().execute("BEGIN") | ||||
|  | ||||
|     def is_in_memory_db(self): | ||||
|         return self.creation.is_in_memory_db(self.settings_dict["NAME"]) | ||||
|  | ||||
|  | ||||
| FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s") | ||||
|  | ||||
|  | ||||
| class SQLiteCursorWrapper(Database.Cursor): | ||||
|     """ | ||||
|     Django uses the "format" and "pyformat" styles, but Python's sqlite3 module | ||||
|     supports neither of these styles. | ||||
|  | ||||
|     This wrapper performs the following conversions: | ||||
|  | ||||
|     - "format" style to "qmark" style | ||||
|     - "pyformat" style to "named" style | ||||
|  | ||||
|     In both cases, if you want to use a literal "%s", you'll need to use "%%s". | ||||
|     """ | ||||
|  | ||||
|     def execute(self, query, params=None): | ||||
|         if params is None: | ||||
|             return super().execute(query) | ||||
|         # Extract names if params is a mapping, i.e. "pyformat" style is used. | ||||
|         param_names = list(params) if isinstance(params, Mapping) else None | ||||
|         query = self.convert_query(query, param_names=param_names) | ||||
|         return super().execute(query, params) | ||||
|  | ||||
|     def executemany(self, query, param_list): | ||||
|         # Extract names if params is a mapping, i.e. "pyformat" style is used. | ||||
|         # Peek carefully as a generator can be passed instead of a list/tuple. | ||||
|         peekable, param_list = tee(iter(param_list)) | ||||
|         if (params := next(peekable, None)) and isinstance(params, Mapping): | ||||
|             param_names = list(params) | ||||
|         else: | ||||
|             param_names = None | ||||
|         query = self.convert_query(query, param_names=param_names) | ||||
|         return super().executemany(query, param_list) | ||||
|  | ||||
|     def convert_query(self, query, *, param_names=None): | ||||
|         if param_names is None: | ||||
|             # Convert from "format" style to "qmark" style. | ||||
|             return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%") | ||||
|         else: | ||||
|             # Convert from "pyformat" style to "named" style. | ||||
|             return query % {name: f":{name}" for name in param_names} | ||||
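|  | ||||
|     # For example: | ||||
|     #   convert_query("SELECT %s, %s") -> "SELECT ?, ?" | ||||
|     #   convert_query("SELECT %(a)s", param_names=["a"]) -> "SELECT :a" | ||||
|     #   convert_query("LIKE '%%'") -> "LIKE '%'" | ||||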
| @ -0,0 +1,10 @@ | ||||
| from django.db.backends.base.client import BaseDatabaseClient | ||||
|  | ||||
|  | ||||
| class DatabaseClient(BaseDatabaseClient): | ||||
|     executable_name = "sqlite3" | ||||
|  | ||||
|     @classmethod | ||||
|     def settings_to_cmd_args_env(cls, settings_dict, parameters): | ||||
|         args = [cls.executable_name, settings_dict["NAME"], *parameters] | ||||
|         return args, None | ||||
| @ -0,0 +1,159 @@ | ||||
| import multiprocessing | ||||
| import os | ||||
| import shutil | ||||
| import sqlite3 | ||||
| import sys | ||||
| from pathlib import Path | ||||
|  | ||||
| from django.db import NotSupportedError | ||||
| from django.db.backends.base.creation import BaseDatabaseCreation | ||||
|  | ||||
|  | ||||
| class DatabaseCreation(BaseDatabaseCreation): | ||||
|     @staticmethod | ||||
|     def is_in_memory_db(database_name): | ||||
|         return not isinstance(database_name, Path) and ( | ||||
|             database_name == ":memory:" or "mode=memory" in database_name | ||||
|         ) | ||||
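|  | ||||
|     # For example: ":memory:" and "file:memdb1?mode=memory&cache=shared" are | ||||
|     # in-memory names; a Path always refers to an on-disk database file. | ||||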
|  | ||||
|     def _get_test_db_name(self): | ||||
|         test_database_name = self.connection.settings_dict["TEST"]["NAME"] or ":memory:" | ||||
|         if test_database_name == ":memory:": | ||||
|             return "file:memorydb_%s?mode=memory&cache=shared" % self.connection.alias | ||||
|         return test_database_name | ||||
|  | ||||
|     def _create_test_db(self, verbosity, autoclobber, keepdb=False): | ||||
|         test_database_name = self._get_test_db_name() | ||||
|  | ||||
|         if keepdb: | ||||
|             return test_database_name | ||||
|         if not self.is_in_memory_db(test_database_name): | ||||
|             # Erase the old test database | ||||
|             if verbosity >= 1: | ||||
|                 self.log( | ||||
|                     "Destroying old test database for alias %s..." | ||||
|                     % (self._get_database_display_str(verbosity, test_database_name),) | ||||
|                 ) | ||||
|             if os.access(test_database_name, os.F_OK): | ||||
|                 if not autoclobber: | ||||
|                     confirm = input( | ||||
|                         "Type 'yes' if you would like to try deleting the test " | ||||
|                         "database '%s', or 'no' to cancel: " % test_database_name | ||||
|                     ) | ||||
|                 if autoclobber or confirm == "yes": | ||||
|                     try: | ||||
|                         os.remove(test_database_name) | ||||
|                     except Exception as e: | ||||
|                         self.log("Got an error deleting the old test database: %s" % e) | ||||
|                         sys.exit(2) | ||||
|                 else: | ||||
|                     self.log("Tests cancelled.") | ||||
|                     sys.exit(1) | ||||
|         return test_database_name | ||||
|  | ||||
|     def get_test_db_clone_settings(self, suffix): | ||||
|         orig_settings_dict = self.connection.settings_dict | ||||
|         source_database_name = orig_settings_dict["NAME"] | ||||
|  | ||||
|         if not self.is_in_memory_db(source_database_name): | ||||
|             root, ext = os.path.splitext(source_database_name) | ||||
|             return {**orig_settings_dict, "NAME": f"{root}_{suffix}{ext}"} | ||||
|  | ||||
|         start_method = multiprocessing.get_start_method() | ||||
|         if start_method == "fork": | ||||
|             return orig_settings_dict | ||||
|         if start_method == "spawn": | ||||
|             return { | ||||
|                 **orig_settings_dict, | ||||
|                 "NAME": f"{self.connection.alias}_{suffix}.sqlite3", | ||||
|             } | ||||
|         raise NotSupportedError( | ||||
|             f"Cloning with start method {start_method!r} is not supported." | ||||
|         ) | ||||
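|  | ||||
|     # For example (settings assumed): with NAME "db.sqlite3" and suffix "1", | ||||
|     # the clone settings use NAME "db_1.sqlite3"; in-memory databases are | ||||
|     # shared as-is (fork) or dumped to "<alias>_1.sqlite3" (spawn). | ||||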
|  | ||||
|     def _clone_test_db(self, suffix, verbosity, keepdb=False): | ||||
|         source_database_name = self.connection.settings_dict["NAME"] | ||||
|         target_database_name = self.get_test_db_clone_settings(suffix)["NAME"] | ||||
|         if not self.is_in_memory_db(source_database_name): | ||||
|             # Erase the old test database | ||||
|             if os.access(target_database_name, os.F_OK): | ||||
|                 if keepdb: | ||||
|                     return | ||||
|                 if verbosity >= 1: | ||||
|                     self.log( | ||||
|                         "Destroying old test database for alias %s..." | ||||
|                         % ( | ||||
|                             self._get_database_display_str( | ||||
|                                 verbosity, target_database_name | ||||
|                             ), | ||||
|                         ) | ||||
|                     ) | ||||
|                 try: | ||||
|                     os.remove(target_database_name) | ||||
|                 except Exception as e: | ||||
|                     self.log("Got an error deleting the old test database: %s" % e) | ||||
|                     sys.exit(2) | ||||
|             try: | ||||
|                 shutil.copy(source_database_name, target_database_name) | ||||
|             except Exception as e: | ||||
|                 self.log("Got an error cloning the test database: %s" % e) | ||||
|                 sys.exit(2) | ||||
|         # Forking automatically makes a copy of an in-memory database. | ||||
|         # Spawn requires migrating it to disk, where it will be re-opened by | ||||
|         # setup_worker_connection(). | ||||
|         elif multiprocessing.get_start_method() == "spawn": | ||||
|             ondisk_db = sqlite3.connect(target_database_name, uri=True) | ||||
|             self.connection.connection.backup(ondisk_db) | ||||
|             ondisk_db.close() | ||||
|  | ||||
|     def _destroy_test_db(self, test_database_name, verbosity): | ||||
|         if test_database_name and not self.is_in_memory_db(test_database_name): | ||||
|             # Remove the SQLite database file | ||||
|             os.remove(test_database_name) | ||||
|  | ||||
|     def test_db_signature(self): | ||||
|         """ | ||||
|         Return a tuple that uniquely identifies a test database. | ||||
|  | ||||
|         This takes into account the special cases of ":memory:" and "" for | ||||
|         SQLite since the databases will be distinct despite having the same | ||||
|         TEST NAME. See https://www.sqlite.org/inmemorydb.html | ||||
|         """ | ||||
|         test_database_name = self._get_test_db_name() | ||||
|         sig = [self.connection.settings_dict["NAME"]] | ||||
|         if self.is_in_memory_db(test_database_name): | ||||
|             sig.append(self.connection.alias) | ||||
|         else: | ||||
|             sig.append(test_database_name) | ||||
|         return tuple(sig) | ||||
|  | ||||
|     def setup_worker_connection(self, _worker_id): | ||||
|         settings_dict = self.get_test_db_clone_settings(_worker_id) | ||||
|         # connection.settings_dict must be updated in place for changes to be | ||||
|         # reflected in django.db.connections. Otherwise new threads would | ||||
|         # connect to the default database instead of the appropriate clone. | ||||
|         start_method = multiprocessing.get_start_method() | ||||
|         if start_method == "fork": | ||||
|             # Update settings_dict in place. | ||||
|             self.connection.settings_dict.update(settings_dict) | ||||
|             self.connection.close() | ||||
|         elif start_method == "spawn": | ||||
|             alias = self.connection.alias | ||||
|             connection_str = ( | ||||
|                 f"file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared" | ||||
|             ) | ||||
|             source_db = self.connection.Database.connect( | ||||
|                 f"file:{alias}_{_worker_id}.sqlite3", uri=True | ||||
|             ) | ||||
|             target_db = sqlite3.connect(connection_str, uri=True) | ||||
|             source_db.backup(target_db) | ||||
|             source_db.close() | ||||
|             # Update settings_dict in place. | ||||
|             self.connection.settings_dict.update(settings_dict) | ||||
|             self.connection.settings_dict["NAME"] = connection_str | ||||
|             # Re-open connection to in-memory database before closing copy | ||||
|             # connection. | ||||
|             self.connection.connect() | ||||
|             target_db.close() | ||||
|             if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true": | ||||
|                 self.mark_expected_failures_and_skips() | ||||
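|  | ||||
|     # A minimal sketch of the spawn-path backup performed above | ||||
|     # (illustrative standalone usage, hypothetical file names): copy an | ||||
|     # on-disk clone into a shared-cache in-memory database with the | ||||
|     # sqlite3 backup API, then close the on-disk handle. | ||||
|     # | ||||
|     #     import sqlite3 | ||||
|     #     source = sqlite3.connect("file:default_1.sqlite3", uri=True) | ||||
|     #     target = sqlite3.connect( | ||||
|     #         "file:memorydb_default_1?mode=memory&cache=shared", uri=True | ||||
|     #     ) | ||||
|     #     source.backup(target) | ||||
|     #     source.close() | ||||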
| @ -0,0 +1,165 @@ | ||||
| import operator | ||||
|  | ||||
| from django.db import transaction | ||||
| from django.db.backends.base.features import BaseDatabaseFeatures | ||||
| from django.db.utils import OperationalError | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
| from .base import Database | ||||
|  | ||||
|  | ||||
| class DatabaseFeatures(BaseDatabaseFeatures): | ||||
|     minimum_database_version = (3, 21) | ||||
|     test_db_allows_multiple_connections = False | ||||
|     supports_unspecified_pk = True | ||||
|     supports_timezones = False | ||||
|     max_query_params = 999 | ||||
|     supports_transactions = True | ||||
|     atomic_transactions = False | ||||
|     can_rollback_ddl = True | ||||
|     can_create_inline_fk = False | ||||
|     requires_literal_defaults = True | ||||
|     can_clone_databases = True | ||||
|     supports_temporal_subtraction = True | ||||
|     ignores_table_name_case = True | ||||
|     supports_cast_with_precision = False | ||||
|     time_cast_precision = 3 | ||||
|     can_release_savepoints = True | ||||
|     has_case_insensitive_like = True | ||||
|     # Is "ALTER TABLE ... RENAME COLUMN" supported? | ||||
|     can_alter_table_rename_column = Database.sqlite_version_info >= (3, 25, 0) | ||||
|     # Is "ALTER TABLE ... DROP COLUMN" supported? | ||||
|     can_alter_table_drop_column = Database.sqlite_version_info >= (3, 35, 5) | ||||
|     supports_parentheses_in_compound = False | ||||
|     can_defer_constraint_checks = True | ||||
|     supports_over_clause = Database.sqlite_version_info >= (3, 25, 0) | ||||
|     supports_frame_range_fixed_distance = Database.sqlite_version_info >= (3, 28, 0) | ||||
|     supports_aggregate_filter_clause = Database.sqlite_version_info >= (3, 30, 1) | ||||
|     supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0) | ||||
|     # NULLS LAST/FIRST emulation on < 3.30 requires subquery wrapping. | ||||
|     requires_compound_order_by_subquery = Database.sqlite_version_info < (3, 30) | ||||
|     order_by_nulls_first = True | ||||
|     supports_json_field_contains = False | ||||
|     supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0) | ||||
|     supports_update_conflicts_with_target = supports_update_conflicts | ||||
|     test_collations = { | ||||
|         "ci": "nocase", | ||||
|         "cs": "binary", | ||||
|         "non_default": "nocase", | ||||
|     } | ||||
|     django_test_expected_failures = { | ||||
|         # The django_format_dtdelta() function doesn't properly handle mixed | ||||
|         # Date/DateTime fields and timedeltas. | ||||
|         "expressions.tests.FTimeDeltaTests.test_mixed_comparisons1", | ||||
|     } | ||||
|     create_test_table_with_composite_primary_key = """ | ||||
|         CREATE TABLE test_table_composite_pk ( | ||||
|             column_1 INTEGER NOT NULL, | ||||
|             column_2 INTEGER NOT NULL, | ||||
|             PRIMARY KEY(column_1, column_2) | ||||
|         ) | ||||
|     """ | ||||
|  | ||||
|     @cached_property | ||||
|     def django_test_skips(self): | ||||
|         skips = { | ||||
|             "SQLite stores values rounded to 15 significant digits.": { | ||||
|                 "model_fields.test_decimalfield.DecimalFieldTests." | ||||
|                 "test_fetch_from_db_without_float_rounding", | ||||
|             }, | ||||
|             "SQLite naively remakes the table on field alteration.": { | ||||
|                 "schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops", | ||||
|                 "schema.tests.SchemaTests.test_unique_and_reverse_m2m", | ||||
|                 "schema.tests.SchemaTests." | ||||
|                 "test_alter_field_default_doesnt_perform_queries", | ||||
|                 "schema.tests.SchemaTests." | ||||
|                 "test_rename_column_renames_deferred_sql_references", | ||||
|             }, | ||||
|             "SQLite doesn't support negative precision for ROUND().": { | ||||
|                 "db_functions.math.test_round.RoundTests." | ||||
|                 "test_null_with_negative_precision", | ||||
|                 "db_functions.math.test_round.RoundTests." | ||||
|                 "test_decimal_with_negative_precision", | ||||
|                 "db_functions.math.test_round.RoundTests." | ||||
|                 "test_float_with_negative_precision", | ||||
|                 "db_functions.math.test_round.RoundTests." | ||||
|                 "test_integer_with_negative_precision", | ||||
|             }, | ||||
|         } | ||||
|         if Database.sqlite_version_info < (3, 27): | ||||
|             skips.update( | ||||
|                 { | ||||
|                     "Nondeterministic failure on SQLite < 3.27.": { | ||||
|                         "expressions_window.tests.WindowFunctionTests." | ||||
|                         "test_subquery_row_range_rank", | ||||
|                     }, | ||||
|                 } | ||||
|             ) | ||||
|         if self.connection.is_in_memory_db(): | ||||
|             skips.update( | ||||
|                 { | ||||
|                     "the sqlite backend's close() method is a no-op when using an " | ||||
|                     "in-memory database": { | ||||
|                         "servers.test_liveserverthread.LiveServerThreadTest." | ||||
|                         "test_closes_connections", | ||||
|                         "servers.tests.LiveServerTestCloseConnectionTest." | ||||
|                         "test_closes_connections", | ||||
|                     }, | ||||
|                     "For SQLite in-memory tests, closing the connection destroys" | ||||
|                     "the database.": { | ||||
|                         "test_utils.tests.AssertNumQueriesUponConnectionTests." | ||||
|                         "test_ignores_connection_configuration_queries", | ||||
|                     }, | ||||
|                 } | ||||
|             ) | ||||
|         else: | ||||
|             skips.update( | ||||
|                 { | ||||
|                     "Only connections to in-memory SQLite databases are passed to the " | ||||
|                     "server thread.": { | ||||
|                         "servers.tests.LiveServerInMemoryDatabaseLockTest." | ||||
|                         "test_in_memory_database_lock", | ||||
|                     }, | ||||
|                     "multiprocessing's start method is checked only for in-memory " | ||||
|                     "SQLite databases": { | ||||
|                         "backends.sqlite.test_creation.TestDbSignatureTests." | ||||
|                         "test_get_test_db_clone_settings_not_supported", | ||||
|                     }, | ||||
|                 } | ||||
|             ) | ||||
|         return skips | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_atomic_references_rename(self): | ||||
|         return Database.sqlite_version_info >= (3, 26, 0) | ||||
|  | ||||
|     @cached_property | ||||
|     def introspected_field_types(self): | ||||
|         return { | ||||
|             **super().introspected_field_types, | ||||
|             "BigAutoField": "AutoField", | ||||
|             "DurationField": "BigIntegerField", | ||||
|             "GenericIPAddressField": "CharField", | ||||
|             "SmallAutoField": "AutoField", | ||||
|         } | ||||
|  | ||||
|     @cached_property | ||||
|     def supports_json_field(self): | ||||
|         with self.connection.cursor() as cursor: | ||||
|             try: | ||||
|                 with transaction.atomic(self.connection.alias): | ||||
|                     cursor.execute('SELECT JSON(\'{"a": "b"}\')') | ||||
|             except OperationalError: | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
|     can_introspect_json_field = property(operator.attrgetter("supports_json_field")) | ||||
|     has_json_object_function = property(operator.attrgetter("supports_json_field")) | ||||
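|     # Both aliases above re-read supports_json_field on each access via | ||||
|     # attrgetter, so they always reflect the cached result of the probe | ||||
|     # rather than duplicating it. | ||||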
|  | ||||
|     @cached_property | ||||
|     def can_return_columns_from_insert(self): | ||||
|         return Database.sqlite_version_info >= (3, 35) | ||||
|  | ||||
|     can_return_rows_from_bulk_insert = property( | ||||
|         operator.attrgetter("can_return_columns_from_insert") | ||||
|     ) | ||||
| @ -0,0 +1,434 @@ | ||||
| from collections import namedtuple | ||||
|  | ||||
| import sqlparse | ||||
|  | ||||
| from django.db import DatabaseError | ||||
| from django.db.backends.base.introspection import BaseDatabaseIntrospection | ||||
| from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo | ||||
| from django.db.backends.base.introspection import TableInfo | ||||
| from django.db.models import Index | ||||
| from django.utils.regex_helper import _lazy_re_compile | ||||
|  | ||||
| FieldInfo = namedtuple( | ||||
|     "FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint") | ||||
| ) | ||||
|  | ||||
| field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$") | ||||
|  | ||||
|  | ||||
| def get_field_size(name): | ||||
|     """Extract the size number from a "varchar(11)" type name""" | ||||
|     m = field_size_re.search(name) | ||||
|     return int(m[1]) if m else None | ||||
|  | ||||
|  | ||||
| # This light wrapper "fakes" a dictionary interface, because some SQLite data | ||||
| # types include variables in them -- e.g. "varchar(30)" -- and can't be matched | ||||
| # as a simple dictionary lookup. | ||||
| class FlexibleFieldLookupDict: | ||||
|     # Maps SQL types to Django Field types. Some of the SQL types have multiple | ||||
|     # entries here because SQLite allows for anything and doesn't normalize the | ||||
|     # field type; it uses whatever was given. | ||||
|     base_data_types_reverse = { | ||||
|         "bool": "BooleanField", | ||||
|         "boolean": "BooleanField", | ||||
|         "smallint": "SmallIntegerField", | ||||
|         "smallint unsigned": "PositiveSmallIntegerField", | ||||
|         "smallinteger": "SmallIntegerField", | ||||
|         "int": "IntegerField", | ||||
|         "integer": "IntegerField", | ||||
|         "bigint": "BigIntegerField", | ||||
|         "integer unsigned": "PositiveIntegerField", | ||||
|         "bigint unsigned": "PositiveBigIntegerField", | ||||
|         "decimal": "DecimalField", | ||||
|         "real": "FloatField", | ||||
|         "text": "TextField", | ||||
|         "char": "CharField", | ||||
|         "varchar": "CharField", | ||||
|         "blob": "BinaryField", | ||||
|         "date": "DateField", | ||||
|         "datetime": "DateTimeField", | ||||
|         "time": "TimeField", | ||||
|     } | ||||
|  | ||||
|     def __getitem__(self, key): | ||||
|         key = key.lower().split("(", 1)[0].strip() | ||||
|         return self.base_data_types_reverse[key] | ||||
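|  | ||||
|     # Example lookup (illustrative): any parenthesized size is stripped | ||||
|     # before the dictionary lookup, so parameterized types resolve to the | ||||
|     # same field type as their bare names: | ||||
|     # | ||||
|     #     >>> FlexibleFieldLookupDict()["varchar(30)"] | ||||
|     #     'CharField' | ||||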
|  | ||||
|  | ||||
| class DatabaseIntrospection(BaseDatabaseIntrospection): | ||||
|     data_types_reverse = FlexibleFieldLookupDict() | ||||
|  | ||||
|     def get_field_type(self, data_type, description): | ||||
|         field_type = super().get_field_type(data_type, description) | ||||
|         if description.pk and field_type in { | ||||
|             "BigIntegerField", | ||||
|             "IntegerField", | ||||
|             "SmallIntegerField", | ||||
|         }: | ||||
|             # No support for BigAutoField or SmallAutoField as SQLite treats | ||||
|             # all integer primary keys as signed 64-bit integers. | ||||
|             return "AutoField" | ||||
|         if description.has_json_constraint: | ||||
|             return "JSONField" | ||||
|         return field_type | ||||
|  | ||||
|     def get_table_list(self, cursor): | ||||
|         """Return a list of table and view names in the current database.""" | ||||
|         # Skip the sqlite_sequence system table used for autoincrement key | ||||
|         # generation. | ||||
|         cursor.execute( | ||||
|             """ | ||||
|             SELECT name, type FROM sqlite_master | ||||
|             WHERE type in ('table', 'view') AND NOT name='sqlite_sequence' | ||||
|             ORDER BY name""" | ||||
|         ) | ||||
|         return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()] | ||||
|  | ||||
|     def get_table_description(self, cursor, table_name): | ||||
|         """ | ||||
|         Return a description of the table with the DB-API cursor.description | ||||
|         interface. | ||||
|         """ | ||||
|         cursor.execute( | ||||
|             "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name) | ||||
|         ) | ||||
|         table_info = cursor.fetchall() | ||||
|         if not table_info: | ||||
|             raise DatabaseError(f"Table {table_name} does not exist (empty pragma).") | ||||
|         collations = self._get_column_collations(cursor, table_name) | ||||
|         json_columns = set() | ||||
|         if self.connection.features.can_introspect_json_field: | ||||
|             for line in table_info: | ||||
|                 column = line[1] | ||||
|                 json_constraint_sql = '%%json_valid("%s")%%' % column | ||||
|                 has_json_constraint = cursor.execute( | ||||
|                     """ | ||||
|                     SELECT sql | ||||
|                     FROM sqlite_master | ||||
|                     WHERE | ||||
|                         type = 'table' AND | ||||
|                         name = %s AND | ||||
|                         sql LIKE %s | ||||
|                 """, | ||||
|                     [table_name, json_constraint_sql], | ||||
|                 ).fetchone() | ||||
|                 if has_json_constraint: | ||||
|                     json_columns.add(column) | ||||
|         return [ | ||||
|             FieldInfo( | ||||
|                 name, | ||||
|                 data_type, | ||||
|                 get_field_size(data_type), | ||||
|                 None, | ||||
|                 None, | ||||
|                 None, | ||||
|                 not notnull, | ||||
|                 default, | ||||
|                 collations.get(name), | ||||
|                 pk == 1, | ||||
|                 name in json_columns, | ||||
|             ) | ||||
|             for cid, name, data_type, notnull, default, pk in table_info | ||||
|         ] | ||||
|  | ||||
|     def get_sequences(self, cursor, table_name, table_fields=()): | ||||
|         pk_col = self.get_primary_key_column(cursor, table_name) | ||||
|         return [{"table": table_name, "column": pk_col}] | ||||
|  | ||||
|     def get_relations(self, cursor, table_name): | ||||
|         """ | ||||
|         Return a dictionary of {column_name: (ref_column_name, ref_table_name)} | ||||
|         representing all foreign keys in the given table. | ||||
|         """ | ||||
|         cursor.execute( | ||||
|             "PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name) | ||||
|         ) | ||||
|         return { | ||||
|             column_name: (ref_column_name, ref_table_name) | ||||
|             for ( | ||||
|                 _, | ||||
|                 _, | ||||
|                 ref_table_name, | ||||
|                 column_name, | ||||
|                 ref_column_name, | ||||
|                 *_, | ||||
|             ) in cursor.fetchall() | ||||
|         } | ||||
|  | ||||
|     def get_primary_key_columns(self, cursor, table_name): | ||||
|         cursor.execute( | ||||
|             "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name) | ||||
|         ) | ||||
|         return [name for _, name, *_, pk in cursor.fetchall() if pk] | ||||
|  | ||||
|     def _parse_column_or_constraint_definition(self, tokens, columns): | ||||
|         token = None | ||||
|         is_constraint_definition = None | ||||
|         field_name = None | ||||
|         constraint_name = None | ||||
|         unique = False | ||||
|         unique_columns = [] | ||||
|         check = False | ||||
|         check_columns = [] | ||||
|         braces_deep = 0 | ||||
|         for token in tokens: | ||||
|             if token.match(sqlparse.tokens.Punctuation, "("): | ||||
|                 braces_deep += 1 | ||||
|             elif token.match(sqlparse.tokens.Punctuation, ")"): | ||||
|                 braces_deep -= 1 | ||||
|                 if braces_deep < 0: | ||||
|                     # End of columns and constraints for table definition. | ||||
|                     break | ||||
|             elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","): | ||||
|                 # End of current column or constraint definition. | ||||
|                 break | ||||
|             # Detect column or constraint definition by first token. | ||||
|             if is_constraint_definition is None: | ||||
|                 is_constraint_definition = token.match( | ||||
|                     sqlparse.tokens.Keyword, "CONSTRAINT" | ||||
|                 ) | ||||
|                 if is_constraint_definition: | ||||
|                     continue | ||||
|             if is_constraint_definition: | ||||
|                 # Detect constraint name by second token. | ||||
|                 if constraint_name is None: | ||||
|                     if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword): | ||||
|                         constraint_name = token.value | ||||
|                     elif token.ttype == sqlparse.tokens.Literal.String.Symbol: | ||||
|                         constraint_name = token.value[1:-1] | ||||
|                 # Start constraint columns parsing after UNIQUE keyword. | ||||
|                 if token.match(sqlparse.tokens.Keyword, "UNIQUE"): | ||||
|                     unique = True | ||||
|                     unique_braces_deep = braces_deep | ||||
|                 elif unique: | ||||
|                     if unique_braces_deep == braces_deep: | ||||
|                         if unique_columns: | ||||
|                             # Stop constraint parsing. | ||||
|                             unique = False | ||||
|                         continue | ||||
|                     if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword): | ||||
|                         unique_columns.append(token.value) | ||||
|                     elif token.ttype == sqlparse.tokens.Literal.String.Symbol: | ||||
|                         unique_columns.append(token.value[1:-1]) | ||||
|             else: | ||||
|                 # Detect field name by first token. | ||||
|                 if field_name is None: | ||||
|                     if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword): | ||||
|                         field_name = token.value | ||||
|                     elif token.ttype == sqlparse.tokens.Literal.String.Symbol: | ||||
|                         field_name = token.value[1:-1] | ||||
|                 if token.match(sqlparse.tokens.Keyword, "UNIQUE"): | ||||
|                     unique_columns = [field_name] | ||||
|             # Start constraint columns parsing after CHECK keyword. | ||||
|             if token.match(sqlparse.tokens.Keyword, "CHECK"): | ||||
|                 check = True | ||||
|                 check_braces_deep = braces_deep | ||||
|             elif check: | ||||
|                 if check_braces_deep == braces_deep: | ||||
|                     if check_columns: | ||||
|                         # Stop constraint parsing. | ||||
|                         check = False | ||||
|                     continue | ||||
|                 if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword): | ||||
|                     if token.value in columns: | ||||
|                         check_columns.append(token.value) | ||||
|                 elif token.ttype == sqlparse.tokens.Literal.String.Symbol: | ||||
|                     if token.value[1:-1] in columns: | ||||
|                         check_columns.append(token.value[1:-1]) | ||||
|         unique_constraint = ( | ||||
|             { | ||||
|                 "unique": True, | ||||
|                 "columns": unique_columns, | ||||
|                 "primary_key": False, | ||||
|                 "foreign_key": None, | ||||
|                 "check": False, | ||||
|                 "index": False, | ||||
|             } | ||||
|             if unique_columns | ||||
|             else None | ||||
|         ) | ||||
|         check_constraint = ( | ||||
|             { | ||||
|                 "check": True, | ||||
|                 "columns": check_columns, | ||||
|                 "primary_key": False, | ||||
|                 "unique": False, | ||||
|                 "foreign_key": None, | ||||
|                 "index": False, | ||||
|             } | ||||
|             if check_columns | ||||
|             else None | ||||
|         ) | ||||
|         return constraint_name, unique_constraint, check_constraint, token | ||||
|  | ||||
|     def _parse_table_constraints(self, sql, columns): | ||||
|         # Check constraint parsing is based on the SQLite syntax diagram. | ||||
|         # https://www.sqlite.org/syntaxdiagrams.html#table-constraint | ||||
|         statement = sqlparse.parse(sql)[0] | ||||
|         constraints = {} | ||||
|         unnamed_constraints_index = 0 | ||||
|         tokens = (token for token in statement.flatten() if not token.is_whitespace) | ||||
|         # Skip ahead to the column and constraint definitions. | ||||
|         for token in tokens: | ||||
|             if token.match(sqlparse.tokens.Punctuation, "("): | ||||
|                 break | ||||
|         # Parse the column and constraint definitions. | ||||
|         while True: | ||||
|             ( | ||||
|                 constraint_name, | ||||
|                 unique, | ||||
|                 check, | ||||
|                 end_token, | ||||
|             ) = self._parse_column_or_constraint_definition(tokens, columns) | ||||
|             if unique: | ||||
|                 if constraint_name: | ||||
|                     constraints[constraint_name] = unique | ||||
|                 else: | ||||
|                     unnamed_constraints_index += 1 | ||||
|                     constraints[ | ||||
|                         "__unnamed_constraint_%s__" % unnamed_constraints_index | ||||
|                     ] = unique | ||||
|             if check: | ||||
|                 if constraint_name: | ||||
|                     constraints[constraint_name] = check | ||||
|                 else: | ||||
|                     unnamed_constraints_index += 1 | ||||
|                     constraints[ | ||||
|                         "__unnamed_constraint_%s__" % unnamed_constraints_index | ||||
|                     ] = check | ||||
|             if end_token.match(sqlparse.tokens.Punctuation, ")"): | ||||
|                 break | ||||
|         return constraints | ||||
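|  | ||||
|     # Illustrative round trip for the parser above (hypothetical table | ||||
|     # and call): | ||||
|     # | ||||
|     #     >>> introspection._parse_table_constraints( | ||||
|     #     ...     'CREATE TABLE "t" ("a" int, CONSTRAINT "u" UNIQUE ("a"))', | ||||
|     #     ...     columns={"a"}, | ||||
|     #     ... ) | ||||
|     #     {'u': {'unique': True, 'columns': ['a'], 'primary_key': False, | ||||
|     #            'foreign_key': None, 'check': False, 'index': False}} | ||||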
|  | ||||
|     def get_constraints(self, cursor, table_name): | ||||
|         """ | ||||
|         Retrieve any constraints or keys (unique, pk, fk, check, index) across | ||||
|         one or more columns. | ||||
|         """ | ||||
|         constraints = {} | ||||
|         # Find inline check constraints. | ||||
|         try: | ||||
|             table_schema = cursor.execute( | ||||
|                 "SELECT sql FROM sqlite_master WHERE type='table' and name=%s" | ||||
|                 % (self.connection.ops.quote_name(table_name),) | ||||
|             ).fetchone()[0] | ||||
|         except TypeError: | ||||
|             # table_name is a view. | ||||
|             pass | ||||
|         else: | ||||
|             columns = { | ||||
|                 info.name for info in self.get_table_description(cursor, table_name) | ||||
|             } | ||||
|             constraints.update(self._parse_table_constraints(table_schema, columns)) | ||||
|  | ||||
|         # Get the index info | ||||
|         cursor.execute( | ||||
|             "PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name) | ||||
|         ) | ||||
|         for row in cursor.fetchall(): | ||||
|             # SQLite 3.8.9+ returns 5 columns; older versions return only | ||||
|             # 3. Discard the last two columns if present. | ||||
|             number, index, unique = row[:3] | ||||
|             cursor.execute( | ||||
|                 "SELECT sql FROM sqlite_master " | ||||
|                 "WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index) | ||||
|             ) | ||||
|             # There's at most one row. | ||||
|             (sql,) = cursor.fetchone() or (None,) | ||||
|             # Inline constraints are already detected in | ||||
|             # _parse_table_constraints(). The reasons to avoid fetching inline | ||||
|             # constraints from `PRAGMA index_list` are: | ||||
|             # - Inline constraints can have a different name and information | ||||
|             #   than what `PRAGMA index_list` gives. | ||||
|             # - Not all inline constraints may appear in `PRAGMA index_list`. | ||||
|             if not sql: | ||||
|                 # An inline constraint | ||||
|                 continue | ||||
|             # Get the index info for that index | ||||
|             cursor.execute( | ||||
|                 "PRAGMA index_info(%s)" % self.connection.ops.quote_name(index) | ||||
|             ) | ||||
|             for index_rank, column_rank, column in cursor.fetchall(): | ||||
|                 if index not in constraints: | ||||
|                     constraints[index] = { | ||||
|                         "columns": [], | ||||
|                         "primary_key": False, | ||||
|                         "unique": bool(unique), | ||||
|                         "foreign_key": None, | ||||
|                         "check": False, | ||||
|                         "index": True, | ||||
|                     } | ||||
|                 constraints[index]["columns"].append(column) | ||||
|             # Add type and column orders for indexes | ||||
|             if constraints[index]["index"]: | ||||
|                 # SQLite doesn't support any index type other than b-tree | ||||
|                 constraints[index]["type"] = Index.suffix | ||||
|                 orders = self._get_index_columns_orders(sql) | ||||
|                 if orders is not None: | ||||
|                     constraints[index]["orders"] = orders | ||||
|         # Get the PK | ||||
|         pk_columns = self.get_primary_key_columns(cursor, table_name) | ||||
|         if pk_columns: | ||||
|             # SQLite doesn't actually give a name to the PK constraint, | ||||
|             # so we invent one. This is fine, as the SQLite backend never | ||||
|             # deletes PK constraints by name, as you can't delete constraints | ||||
|             # in SQLite; we remake the table with a new PK instead. | ||||
|             constraints["__primary__"] = { | ||||
|                 "columns": pk_columns, | ||||
|                 "primary_key": True, | ||||
|                 "unique": False,  # It's not actually a unique constraint. | ||||
|                 "foreign_key": None, | ||||
|                 "check": False, | ||||
|                 "index": False, | ||||
|             } | ||||
|         relations = enumerate(self.get_relations(cursor, table_name).items()) | ||||
|         constraints.update( | ||||
|             { | ||||
|                 f"fk_{index}": { | ||||
|                     "columns": [column_name], | ||||
|                     "primary_key": False, | ||||
|                     "unique": False, | ||||
|                     "foreign_key": (ref_table_name, ref_column_name), | ||||
|                     "check": False, | ||||
|                     "index": False, | ||||
|                 } | ||||
|                 for index, (column_name, (ref_column_name, ref_table_name)) in relations | ||||
|             } | ||||
|         ) | ||||
|         return constraints | ||||
|  | ||||
|     def _get_index_columns_orders(self, sql): | ||||
|         tokens = sqlparse.parse(sql)[0] | ||||
|         for token in tokens: | ||||
|             if isinstance(token, sqlparse.sql.Parenthesis): | ||||
|                 columns = str(token).strip("()").split(", ") | ||||
|                 return ["DESC" if info.endswith("DESC") else "ASC" for info in columns] | ||||
|         return None | ||||
|  | ||||
|     def _get_column_collations(self, cursor, table_name): | ||||
|         row = cursor.execute( | ||||
|             """ | ||||
|             SELECT sql | ||||
|             FROM sqlite_master | ||||
|             WHERE type = 'table' AND name = %s | ||||
|         """, | ||||
|             [table_name], | ||||
|         ).fetchone() | ||||
|         if not row: | ||||
|             return {} | ||||
|  | ||||
|         sql = row[0] | ||||
|         columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ") | ||||
|         collations = {} | ||||
|         for column in columns: | ||||
|             tokens = column[1:].split() | ||||
|             column_name = tokens[0].strip('"') | ||||
|             for index, token in enumerate(tokens): | ||||
|                 if token == "COLLATE": | ||||
|                     collation = tokens[index + 1] | ||||
|                     break | ||||
|             else: | ||||
|                 collation = None | ||||
|             collations[column_name] = collation | ||||
|         return collations | ||||
| @ -0,0 +1,434 @@ | ||||
| import datetime | ||||
| import decimal | ||||
| import uuid | ||||
| from functools import lru_cache | ||||
| from itertools import chain | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.core.exceptions import FieldError | ||||
| from django.db import DatabaseError, NotSupportedError, models | ||||
| from django.db.backends.base.operations import BaseDatabaseOperations | ||||
| from django.db.models.constants import OnConflict | ||||
| from django.db.models.expressions import Col | ||||
| from django.utils import timezone | ||||
| from django.utils.dateparse import parse_date, parse_datetime, parse_time | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
|  | ||||
| class DatabaseOperations(BaseDatabaseOperations): | ||||
|     cast_char_field_without_max_length = "text" | ||||
|     cast_data_types = { | ||||
|         "DateField": "TEXT", | ||||
|         "DateTimeField": "TEXT", | ||||
|     } | ||||
|     explain_prefix = "EXPLAIN QUERY PLAN" | ||||
|     # Datatypes that cannot be extracted with JSON_EXTRACT() on SQLite. | ||||
|     # Use JSON_TYPE() instead. | ||||
|     jsonfield_datatype_values = frozenset(["null", "false", "true"]) | ||||
|  | ||||
|     def bulk_batch_size(self, fields, objs): | ||||
|         """ | ||||
|         SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of | ||||
|         999 variables per query. | ||||
|  | ||||
|         If there's only a single field to insert, the limit is 500 | ||||
|         (SQLITE_MAX_COMPOUND_SELECT). | ||||
|         """ | ||||
|         if len(fields) == 1: | ||||
|             return 500 | ||||
|         elif len(fields) > 1: | ||||
|             return self.connection.features.max_query_params // len(fields) | ||||
|         else: | ||||
|             return len(objs) | ||||
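|  | ||||
|     # Worked example (illustrative): with the default 999-variable limit, | ||||
|     # inserting rows of 4 fields each yields batches of 999 // 4 == 249 | ||||
|     # rows, while single-field inserts are capped at 500 rows per batch. | ||||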
|  | ||||
|     def check_expression_support(self, expression): | ||||
|         bad_fields = (models.DateField, models.DateTimeField, models.TimeField) | ||||
|         bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev) | ||||
|         if isinstance(expression, bad_aggregates): | ||||
|             for expr in expression.get_source_expressions(): | ||||
|                 try: | ||||
|                     output_field = expr.output_field | ||||
|                 except (AttributeError, FieldError): | ||||
|                     # Not every subexpression has an output_field; that's | ||||
|                     # fine to ignore. | ||||
|                     pass | ||||
|                 else: | ||||
|                     if isinstance(output_field, bad_fields): | ||||
|                         raise NotSupportedError( | ||||
|                             "You cannot use Sum, Avg, StdDev, and Variance " | ||||
|                             "aggregations on date/time fields in sqlite3 " | ||||
|                             "since date/time is saved as text." | ||||
|                         ) | ||||
|         if ( | ||||
|             isinstance(expression, models.Aggregate) | ||||
|             and expression.distinct | ||||
|             and len(expression.source_expressions) > 1 | ||||
|         ): | ||||
|             raise NotSupportedError( | ||||
|                 "SQLite doesn't support DISTINCT on aggregate functions " | ||||
|                 "accepting multiple arguments." | ||||
|             ) | ||||
|  | ||||
|     def date_extract_sql(self, lookup_type, sql, params): | ||||
|         """ | ||||
|         Support EXTRACT with a user-defined function django_date_extract() | ||||
|         that's registered in connect(). Use single quotes because this is a | ||||
|         string and could otherwise cause a collision with a field name. | ||||
|         """ | ||||
|         return f"django_date_extract(%s, {sql})", (lookup_type.lower(), *params) | ||||
|  | ||||
|     def fetch_returned_insert_rows(self, cursor): | ||||
|         """ | ||||
|         Given a cursor object that has just performed an INSERT...RETURNING | ||||
|         statement into a table, return the list of returned data. | ||||
|         """ | ||||
|         return cursor.fetchall() | ||||
|  | ||||
|     def format_for_duration_arithmetic(self, sql): | ||||
|         """Do nothing since formatting is handled in the custom function.""" | ||||
|         return sql | ||||
|  | ||||
|     def date_trunc_sql(self, lookup_type, sql, params, tzname=None): | ||||
|         return f"django_date_trunc(%s, {sql}, %s, %s)", ( | ||||
|             lookup_type.lower(), | ||||
|             *params, | ||||
|             *self._convert_tznames_to_sql(tzname), | ||||
|         ) | ||||
|  | ||||
|     def time_trunc_sql(self, lookup_type, sql, params, tzname=None): | ||||
|         return f"django_time_trunc(%s, {sql}, %s, %s)", ( | ||||
|             lookup_type.lower(), | ||||
|             *params, | ||||
|             *self._convert_tznames_to_sql(tzname), | ||||
|         ) | ||||
|  | ||||
|     def _convert_tznames_to_sql(self, tzname): | ||||
|         if tzname and settings.USE_TZ: | ||||
|             return tzname, self.connection.timezone_name | ||||
|         return None, None | ||||
|  | ||||
|     def datetime_cast_date_sql(self, sql, params, tzname): | ||||
|         return f"django_datetime_cast_date({sql}, %s, %s)", ( | ||||
|             *params, | ||||
|             *self._convert_tznames_to_sql(tzname), | ||||
|         ) | ||||
|  | ||||
|     def datetime_cast_time_sql(self, sql, params, tzname): | ||||
|         return f"django_datetime_cast_time({sql}, %s, %s)", ( | ||||
|             *params, | ||||
|             *self._convert_tznames_to_sql(tzname), | ||||
|         ) | ||||
|  | ||||
|     def datetime_extract_sql(self, lookup_type, sql, params, tzname): | ||||
|         return f"django_datetime_extract(%s, {sql}, %s, %s)", ( | ||||
|             lookup_type.lower(), | ||||
|             *params, | ||||
|             *self._convert_tznames_to_sql(tzname), | ||||
|         ) | ||||
|  | ||||
|     def datetime_trunc_sql(self, lookup_type, sql, params, tzname): | ||||
|         return f"django_datetime_trunc(%s, {sql}, %s, %s)", ( | ||||
|             lookup_type.lower(), | ||||
|             *params, | ||||
|             *self._convert_tznames_to_sql(tzname), | ||||
|         ) | ||||
|  | ||||
|     def time_extract_sql(self, lookup_type, sql, params): | ||||
|         return f"django_time_extract(%s, {sql})", (lookup_type.lower(), *params) | ||||
|  | ||||
|     def pk_default_value(self): | ||||
|         return "NULL" | ||||
|  | ||||
|     def _quote_params_for_last_executed_query(self, params): | ||||
|         """ | ||||
|         Only for last_executed_query! Don't use this to execute SQL queries! | ||||
|         """ | ||||
|         # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the | ||||
|         # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the | ||||
|         # number of return values, default = 2000). Since Python's sqlite3 | ||||
|         # module doesn't expose the get_limit() C API, assume the default | ||||
|         # limits are in effect and split the work in batches if needed. | ||||
|         BATCH_SIZE = 999 | ||||
|         if len(params) > BATCH_SIZE: | ||||
|             results = () | ||||
|             for index in range(0, len(params), BATCH_SIZE): | ||||
|                 chunk = params[index : index + BATCH_SIZE] | ||||
|                 results += self._quote_params_for_last_executed_query(chunk) | ||||
|             return results | ||||
|  | ||||
|         sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params)) | ||||
|         # Bypass Django's wrappers and use the underlying sqlite3 connection | ||||
|         # to avoid logging this query - it would trigger infinite recursion. | ||||
|         cursor = self.connection.connection.cursor() | ||||
|         # Native sqlite3 cursors cannot be used as context managers. | ||||
|         try: | ||||
|             return cursor.execute(sql, params).fetchone() | ||||
|         finally: | ||||
|             cursor.close() | ||||
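|  | ||||
|     # Illustrative QUOTE() round trip (hypothetical parameters): for | ||||
|     # params ("it's", None) the statement built above is | ||||
|     # | ||||
|     #     SELECT QUOTE(?), QUOTE(?) | ||||
|     # | ||||
|     # and fetchone() returns ("'it''s'", "NULL"), i.e. ready-made SQL | ||||
|     # literals for last_executed_query(). | ||||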
|  | ||||
|     def last_executed_query(self, cursor, sql, params): | ||||
|         # Python substitutes parameters in Modules/_sqlite/cursor.c with: | ||||
|         # bind_parameters(state, self->statement, parameters); | ||||
|         # Unfortunately there is no way to reach self->statement from Python, | ||||
|         # so we quote and substitute parameters manually. | ||||
|         if params: | ||||
|             if isinstance(params, (list, tuple)): | ||||
|                 params = self._quote_params_for_last_executed_query(params) | ||||
|             else: | ||||
|                 values = tuple(params.values()) | ||||
|                 values = self._quote_params_for_last_executed_query(values) | ||||
|                 params = dict(zip(params, values)) | ||||
|             return sql % params | ||||
|         # For consistency with SQLiteCursorWrapper.execute(), just return sql | ||||
|         # when there are no parameters. See #13648 and #17158. | ||||
|         else: | ||||
|             return sql | ||||
|  | ||||
|     def quote_name(self, name): | ||||
|         if name.startswith('"') and name.endswith('"'): | ||||
|             return name  # Quoting once is enough. | ||||
|         return '"%s"' % name | ||||
|  | ||||
|     def no_limit_value(self): | ||||
|         return -1 | ||||
|  | ||||
|     def __references_graph(self, table_name): | ||||
|         query = """ | ||||
|         WITH tables AS ( | ||||
|             SELECT %s name | ||||
|             UNION | ||||
|             SELECT sqlite_master.name | ||||
|             FROM sqlite_master | ||||
|             JOIN tables ON (sql REGEXP %s || tables.name || %s) | ||||
|         ) SELECT name FROM tables; | ||||
|         """ | ||||
|         params = ( | ||||
|             table_name, | ||||
|             r'(?i)\s+references\s+("|\')?', | ||||
|             r'("|\')?\s*\(', | ||||
|         ) | ||||
|         with self.connection.cursor() as cursor: | ||||
|             results = cursor.execute(query, params) | ||||
|             return [row[0] for row in results.fetchall()] | ||||
|  | ||||
|     @cached_property | ||||
|     def _references_graph(self): | ||||
|         # 512 is large enough to fit the ~330 tables (as of this writing) in | ||||
|         # Django's test suite. | ||||
|         return lru_cache(maxsize=512)(self.__references_graph) | ||||
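|  | ||||
|     # Illustrative expansion (hypothetical schema): if "book" was created | ||||
|     # with a column declared 'REFERENCES "author" ("id")', then | ||||
|     # self._references_graph("author") contains both "author" and "book", | ||||
|     # so sql_flush() below also deletes from the referencing table. | ||||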
|  | ||||
|     def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False): | ||||
|         if tables and allow_cascade: | ||||
|             # Simulate TRUNCATE CASCADE by recursively collecting the tables | ||||
|             # referencing the tables to be flushed. | ||||
|             tables = set( | ||||
|                 chain.from_iterable(self._references_graph(table) for table in tables) | ||||
|             ) | ||||
|         sql = [ | ||||
|             "%s %s %s;" | ||||
|             % ( | ||||
|                 style.SQL_KEYWORD("DELETE"), | ||||
|                 style.SQL_KEYWORD("FROM"), | ||||
|                 style.SQL_FIELD(self.quote_name(table)), | ||||
|             ) | ||||
|             for table in tables | ||||
|         ] | ||||
|         if reset_sequences: | ||||
|             sequences = [{"table": table} for table in tables] | ||||
|             sql.extend(self.sequence_reset_by_name_sql(style, sequences)) | ||||
|         return sql | ||||
|  | ||||
|     def sequence_reset_by_name_sql(self, style, sequences): | ||||
|         if not sequences: | ||||
|             return [] | ||||
|         return [ | ||||
|             "%s %s %s %s = 0 %s %s %s (%s);" | ||||
|             % ( | ||||
|                 style.SQL_KEYWORD("UPDATE"), | ||||
|                 style.SQL_TABLE(self.quote_name("sqlite_sequence")), | ||||
|                 style.SQL_KEYWORD("SET"), | ||||
|                 style.SQL_FIELD(self.quote_name("seq")), | ||||
|                 style.SQL_KEYWORD("WHERE"), | ||||
|                 style.SQL_FIELD(self.quote_name("name")), | ||||
|                 style.SQL_KEYWORD("IN"), | ||||
|                 ", ".join( | ||||
|                     ["'%s'" % sequence_info["table"] for sequence_info in sequences] | ||||
|                 ), | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
|     def adapt_datetimefield_value(self, value): | ||||
|         if value is None: | ||||
|             return None | ||||
|  | ||||
|         # Expression values are adapted by the database. | ||||
|         if hasattr(value, "resolve_expression"): | ||||
|             return value | ||||
|  | ||||
|         # SQLite doesn't support tz-aware datetimes | ||||
|         if timezone.is_aware(value): | ||||
|             if settings.USE_TZ: | ||||
|                 value = timezone.make_naive(value, self.connection.timezone) | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     "SQLite backend does not support timezone-aware datetimes when " | ||||
|                     "USE_TZ is False." | ||||
|                 ) | ||||
|  | ||||
|         return str(value) | ||||
|  | ||||
|     def adapt_timefield_value(self, value): | ||||
|         if value is None: | ||||
|             return None | ||||
|  | ||||
|         # Expression values are adapted by the database. | ||||
|         if hasattr(value, "resolve_expression"): | ||||
|             return value | ||||
|  | ||||
|         # SQLite doesn't support tz-aware times. | ||||
|         if timezone.is_aware(value): | ||||
|             raise ValueError("SQLite backend does not support timezone-aware times.") | ||||
|  | ||||
|         return str(value) | ||||
|  | ||||
|     def get_db_converters(self, expression): | ||||
|         converters = super().get_db_converters(expression) | ||||
|         internal_type = expression.output_field.get_internal_type() | ||||
|         if internal_type == "DateTimeField": | ||||
|             converters.append(self.convert_datetimefield_value) | ||||
|         elif internal_type == "DateField": | ||||
|             converters.append(self.convert_datefield_value) | ||||
|         elif internal_type == "TimeField": | ||||
|             converters.append(self.convert_timefield_value) | ||||
|         elif internal_type == "DecimalField": | ||||
|             converters.append(self.get_decimalfield_converter(expression)) | ||||
|         elif internal_type == "UUIDField": | ||||
|             converters.append(self.convert_uuidfield_value) | ||||
|         elif internal_type == "BooleanField": | ||||
|             converters.append(self.convert_booleanfield_value) | ||||
|         return converters | ||||
|  | ||||
|     def convert_datetimefield_value(self, value, expression, connection): | ||||
|         if value is not None: | ||||
|             if not isinstance(value, datetime.datetime): | ||||
|                 value = parse_datetime(value) | ||||
|             if settings.USE_TZ and not timezone.is_aware(value): | ||||
|                 value = timezone.make_aware(value, self.connection.timezone) | ||||
|         return value | ||||
|  | ||||
|     def convert_datefield_value(self, value, expression, connection): | ||||
|         if value is not None: | ||||
|             if not isinstance(value, datetime.date): | ||||
|                 value = parse_date(value) | ||||
|         return value | ||||
|  | ||||
|     def convert_timefield_value(self, value, expression, connection): | ||||
|         if value is not None: | ||||
|             if not isinstance(value, datetime.time): | ||||
|                 value = parse_time(value) | ||||
|         return value | ||||
|  | ||||
|     def get_decimalfield_converter(self, expression): | ||||
|         # SQLite stores only 15 significant digits. Digits coming from | ||||
|         # float inaccuracy must be removed. | ||||
|         create_decimal = decimal.Context(prec=15).create_decimal_from_float | ||||
|         if isinstance(expression, Col): | ||||
|             quantize_value = decimal.Decimal(1).scaleb( | ||||
|                 -expression.output_field.decimal_places | ||||
|             ) | ||||
|  | ||||
|             def converter(value, expression, connection): | ||||
|                 if value is not None: | ||||
|                     return create_decimal(value).quantize( | ||||
|                         quantize_value, context=expression.output_field.context | ||||
|                     ) | ||||
|  | ||||
|         else: | ||||
|  | ||||
|             def converter(value, expression, connection): | ||||
|                 if value is not None: | ||||
|                     return create_decimal(value) | ||||
|  | ||||
|         return converter | ||||
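|  | ||||
|     # Illustrative effect (hypothetical column with decimal_places=2): a | ||||
|     # REAL value 1.2345 read back from SQLite becomes Decimal("1.23"): | ||||
|     # first reduced to 15 significant digits by create_decimal(), then | ||||
|     # quantized to the field's decimal_places. | ||||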
|  | ||||
|     def convert_uuidfield_value(self, value, expression, connection): | ||||
|         if value is not None: | ||||
|             value = uuid.UUID(value) | ||||
|         return value | ||||
|  | ||||
|     def convert_booleanfield_value(self, value, expression, connection): | ||||
|         return bool(value) if value in (1, 0) else value | ||||
|  | ||||
|     def bulk_insert_sql(self, fields, placeholder_rows): | ||||
|         placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) | ||||
|         values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql) | ||||
|         return f"VALUES {values_sql}" | ||||
|  | ||||
|     def combine_expression(self, connector, sub_expressions): | ||||
|         # SQLite doesn't have a ^ operator, so use the user-defined POWER | ||||
|         # function that's registered in connect(). | ||||
|         if connector == "^": | ||||
|             return "POWER(%s)" % ",".join(sub_expressions) | ||||
|         elif connector == "#": | ||||
|             return "BITXOR(%s)" % ",".join(sub_expressions) | ||||
|         return super().combine_expression(connector, sub_expressions) | ||||
|  | ||||
|     def combine_duration_expression(self, connector, sub_expressions): | ||||
|         if connector not in ["+", "-", "*", "/"]: | ||||
|             raise DatabaseError("Invalid connector for timedelta: %s." % connector) | ||||
|         fn_params = ["'%s'" % connector] + sub_expressions | ||||
|         if len(fn_params) > 3: | ||||
|             raise ValueError("Too many params for timedelta operations.") | ||||
|         return "django_format_dtdelta(%s)" % ", ".join(fn_params) | ||||
|  | ||||
|     def integer_field_range(self, internal_type): | ||||
|         # SQLite doesn't enforce any integer constraints | ||||
|         return (None, None) | ||||
|  | ||||
|     def subtract_temporals(self, internal_type, lhs, rhs): | ||||
|         lhs_sql, lhs_params = lhs | ||||
|         rhs_sql, rhs_params = rhs | ||||
|         params = (*lhs_params, *rhs_params) | ||||
|         if internal_type == "TimeField": | ||||
|             return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params | ||||
|         return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params | ||||
|  | ||||
|     def insert_statement(self, on_conflict=None): | ||||
|         if on_conflict == OnConflict.IGNORE: | ||||
|             return "INSERT OR IGNORE INTO" | ||||
|         return super().insert_statement(on_conflict=on_conflict) | ||||
|  | ||||
|     def return_insert_columns(self, fields): | ||||
|         # SQLite < 3.35 doesn't support an INSERT...RETURNING statement. | ||||
|         if not fields: | ||||
|             return "", () | ||||
|         columns = [ | ||||
|             "%s.%s" | ||||
|             % ( | ||||
|                 self.quote_name(field.model._meta.db_table), | ||||
|                 self.quote_name(field.column), | ||||
|             ) | ||||
|             for field in fields | ||||
|         ] | ||||
|         return "RETURNING %s" % ", ".join(columns), () | ||||
|  | ||||
|     def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): | ||||
|         if ( | ||||
|             on_conflict == OnConflict.UPDATE | ||||
|             and self.connection.features.supports_update_conflicts_with_target | ||||
|         ): | ||||
|             return "ON CONFLICT(%s) DO UPDATE SET %s" % ( | ||||
|                 ", ".join(map(self.quote_name, unique_fields)), | ||||
|                 ", ".join( | ||||
|                     [ | ||||
|                         f"{field} = EXCLUDED.{field}" | ||||
|                         for field in map(self.quote_name, update_fields) | ||||
|                     ] | ||||
|                 ), | ||||
|             ) | ||||
|         return super().on_conflict_suffix_sql( | ||||
|             fields, | ||||
|             on_conflict, | ||||
|             update_fields, | ||||
|             unique_fields, | ||||
|         ) | ||||
| @ -0,0 +1,576 @@ | ||||
| import copy | ||||
| from decimal import Decimal | ||||
|  | ||||
| from django.apps.registry import Apps | ||||
| from django.db import NotSupportedError | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
| from django.db.backends.ddl_references import Statement | ||||
| from django.db.backends.utils import strip_quotes | ||||
| from django.db.models import UniqueConstraint | ||||
| from django.db.transaction import atomic | ||||
|  | ||||
|  | ||||
| class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): | ||||
|     sql_delete_table = "DROP TABLE %(table)s" | ||||
|     sql_create_fk = None | ||||
|     sql_create_inline_fk = ( | ||||
|         "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED" | ||||
|     ) | ||||
|     sql_create_column_inline_fk = sql_create_inline_fk | ||||
|     sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s" | ||||
|     sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)" | ||||
|     sql_delete_unique = "DROP INDEX %(name)s" | ||||
|  | ||||
|     def __enter__(self): | ||||
|         # Some SQLite schema alterations need foreign key constraints to be | ||||
|         # disabled. Enforce that here for the duration of the schema edit. | ||||
|         if not self.connection.disable_constraint_checking(): | ||||
|             raise NotSupportedError( | ||||
|                 "SQLite schema editor cannot be used while foreign key " | ||||
|                 "constraint checks are enabled. Make sure to disable them " | ||||
|                 "before entering a transaction.atomic() context because " | ||||
|                 "SQLite does not support disabling them in the middle of " | ||||
|                 "a multi-statement transaction." | ||||
|             ) | ||||
|         return super().__enter__() | ||||
|  | ||||
|     def __exit__(self, exc_type, exc_value, traceback): | ||||
|         self.connection.check_constraints() | ||||
|         super().__exit__(exc_type, exc_value, traceback) | ||||
|         self.connection.enable_constraint_checking() | ||||
|  | ||||
|     def quote_value(self, value): | ||||
|         # The backend "mostly works" without this function and there are use | ||||
|         # cases for compiling Python without the sqlite3 libraries (e.g. | ||||
|         # security hardening). | ||||
|         try: | ||||
|             import sqlite3 | ||||
|  | ||||
|             value = sqlite3.adapt(value) | ||||
|         except ImportError: | ||||
|             pass | ||||
|         except sqlite3.ProgrammingError: | ||||
|             pass | ||||
|         # Manual emulation of SQLite parameter quoting | ||||
|         if isinstance(value, bool): | ||||
|             return str(int(value)) | ||||
|         elif isinstance(value, (Decimal, float, int)): | ||||
|             return str(value) | ||||
|         elif isinstance(value, str): | ||||
|             return "'%s'" % value.replace("'", "''") | ||||
|         elif value is None: | ||||
|             return "NULL" | ||||
|         elif isinstance(value, (bytes, bytearray, memoryview)): | ||||
|             # Bytes are only allowed for BLOB fields, encoded as string | ||||
|             # literals containing hexadecimal data and preceded by a single "X" | ||||
|             # character. | ||||
|             return "X'%s'" % value.hex() | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "Cannot quote parameter value %r of type %s" % (value, type(value)) | ||||
|             ) | ||||
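|  | ||||
|     # Illustrative literals produced above (hypothetical values): | ||||
|     # | ||||
|     #     quote_value(True)         -> "1" | ||||
|     #     quote_value("O'Hara")     -> "'O''Hara'" | ||||
|     #     quote_value(b"\x00\xff")  -> "X'00ff'" | ||||
|     #     quote_value(None)         -> "NULL" | ||||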
|  | ||||
|     def prepare_default(self, value): | ||||
|         return self.quote_value(value) | ||||
|  | ||||
|     def _is_referenced_by_fk_constraint( | ||||
|         self, table_name, column_name=None, ignore_self=False | ||||
|     ): | ||||
|         """ | ||||
|         Return whether or not the provided table name is referenced by another | ||||
|         one. If `column_name` is specified, only references pointing to that | ||||
|         column are considered. If `ignore_self` is True, self-referential | ||||
|         constraints are ignored. | ||||
|         """ | ||||
|         with self.connection.cursor() as cursor: | ||||
|             for other_table in self.connection.introspection.get_table_list(cursor): | ||||
|                 if ignore_self and other_table.name == table_name: | ||||
|                     continue | ||||
|                 relations = self.connection.introspection.get_relations( | ||||
|                     cursor, other_table.name | ||||
|                 ) | ||||
|                 for constraint_column, constraint_table in relations.values(): | ||||
|                     if constraint_table == table_name and ( | ||||
|                         column_name is None or constraint_column == column_name | ||||
|                     ): | ||||
|                         return True | ||||
|         return False | ||||
|  | ||||
|     def alter_db_table( | ||||
|         self, model, old_db_table, new_db_table, disable_constraints=True | ||||
|     ): | ||||
|         if ( | ||||
|             not self.connection.features.supports_atomic_references_rename | ||||
|             and disable_constraints | ||||
|             and self._is_referenced_by_fk_constraint(old_db_table) | ||||
|         ): | ||||
|             if self.connection.in_atomic_block: | ||||
|                 raise NotSupportedError( | ||||
|                     ( | ||||
|                         "Renaming the %r table while in a transaction is not " | ||||
|                         "supported on SQLite < 3.26 because it would break referential " | ||||
|                         "integrity. Try adding `atomic = False` to the Migration class." | ||||
|                     ) | ||||
|                     % old_db_table | ||||
|                 ) | ||||
|             self.connection.enable_constraint_checking() | ||||
|             super().alter_db_table(model, old_db_table, new_db_table) | ||||
|             self.connection.disable_constraint_checking() | ||||
|         else: | ||||
|             super().alter_db_table(model, old_db_table, new_db_table) | ||||
|  | ||||
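|     # Editor's note: the NotSupportedError above names the fix; a minimal, | ||||
|     # hypothetical migration that opts out of the wrapping transaction: | ||||
|     # | ||||
|     #     from django.db import migrations | ||||
|     # | ||||
|     #     class Migration(migrations.Migration): | ||||
|     #         atomic = False  # allow the rename on SQLite < 3.26 | ||||
|     # | ||||
|     #         dependencies = [("app", "0001_initial")] | ||||
|     #         operations = [ | ||||
|     #             migrations.AlterModelTable("book", "legacy_book"), | ||||
|     #         ] | ||||
|  | ||||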
|     def alter_field(self, model, old_field, new_field, strict=False): | ||||
|         if not self._field_should_be_altered(old_field, new_field): | ||||
|             return | ||||
|         old_field_name = old_field.name | ||||
|         table_name = model._meta.db_table | ||||
|         _, old_column_name = old_field.get_attname_column() | ||||
|         if ( | ||||
|             new_field.name != old_field_name | ||||
|             and not self.connection.features.supports_atomic_references_rename | ||||
|             and self._is_referenced_by_fk_constraint( | ||||
|                 table_name, old_column_name, ignore_self=True | ||||
|             ) | ||||
|         ): | ||||
|             if self.connection.in_atomic_block: | ||||
|                 raise NotSupportedError( | ||||
|                     ( | ||||
|                         "Renaming the %r.%r column while in a transaction is not " | ||||
|                         "supported on SQLite < 3.26 because it would break referential " | ||||
|                         "integrity. Try adding `atomic = False` to the Migration class." | ||||
|                     ) | ||||
|                     % (model._meta.db_table, old_field_name) | ||||
|                 ) | ||||
|             with atomic(self.connection.alias): | ||||
|                 super().alter_field(model, old_field, new_field, strict=strict) | ||||
|                 # Follow SQLite's documented procedure for performing changes | ||||
|                 # that don't affect the on-disk content. | ||||
|                 # https://sqlite.org/lang_altertable.html#otheralter | ||||
|                 with self.connection.cursor() as cursor: | ||||
|                     schema_version = cursor.execute("PRAGMA schema_version").fetchone()[ | ||||
|                         0 | ||||
|                     ] | ||||
|                     cursor.execute("PRAGMA writable_schema = 1") | ||||
|                     references_template = ' REFERENCES "%s" ("%%s") ' % table_name | ||||
|                     new_column_name = new_field.get_attname_column()[1] | ||||
|                     search = references_template % old_column_name | ||||
|                     replacement = references_template % new_column_name | ||||
|                     cursor.execute( | ||||
|                         "UPDATE sqlite_master SET sql = replace(sql, %s, %s)", | ||||
|                         (search, replacement), | ||||
|                     ) | ||||
|                     cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1)) | ||||
|                     cursor.execute("PRAGMA writable_schema = 0") | ||||
|                     # The integrity check will raise an exception and rollback | ||||
|                     # the transaction if the sqlite_master updates corrupt the | ||||
|                     # database. | ||||
|                     cursor.execute("PRAGMA integrity_check") | ||||
|             # Perform a VACUUM to refresh the database representation from | ||||
|             # the sqlite_master table. | ||||
|             with self.connection.cursor() as cursor: | ||||
|                 cursor.execute("VACUUM") | ||||
|         else: | ||||
|             super().alter_field(model, old_field, new_field, strict=strict) | ||||
|  | ||||
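|     # Editor's note: the same documented procedure sketched against the | ||||
|     # stdlib sqlite3 module (illustrative file/table names; Django runs it | ||||
|     # through its own cursor wrapper above): | ||||
|     # | ||||
|     #     import sqlite3 | ||||
|     # | ||||
|     #     conn = sqlite3.connect("db.sqlite3") | ||||
|     #     version = conn.execute("PRAGMA schema_version").fetchone()[0] | ||||
|     #     conn.execute("PRAGMA writable_schema = 1") | ||||
|     #     conn.execute( | ||||
|     #         "UPDATE sqlite_master SET sql = replace(sql, ?, ?)", | ||||
|     #         (' REFERENCES "author" ("old_id") ', | ||||
|     #          ' REFERENCES "author" ("new_id") '), | ||||
|     #     ) | ||||
|     #     conn.execute("PRAGMA schema_version = %d" % (version + 1)) | ||||
|     #     conn.execute("PRAGMA writable_schema = 0") | ||||
|     #     conn.execute("PRAGMA integrity_check") | ||||
|  | ||||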
|     def _remake_table( | ||||
|         self, model, create_field=None, delete_field=None, alter_fields=None | ||||
|     ): | ||||
|         """ | ||||
|         Shortcut to transform a model from old_model into new_model. | ||||
|  | ||||
|         This follows the correct procedure to perform non-rename or column | ||||
|         addition operations based on SQLite's documentation: | ||||
|  | ||||
|         https://www.sqlite.org/lang_altertable.html#caution | ||||
|  | ||||
|         The essential steps are: | ||||
|           1. Create a table with the updated definition called "new__app_model" | ||||
|           2. Copy the data from the existing "app_model" table to the new table | ||||
|           3. Drop the "app_model" table | ||||
|           4. Rename the "new__app_model" table to "app_model" | ||||
|           5. Restore any index of the previous "app_model" table. | ||||
|         """ | ||||
|  | ||||
|         # Self-referential fields must be recreated rather than copied from | ||||
|         # the old model to ensure their remote_field.field_name doesn't refer | ||||
|         # to an altered field. | ||||
|         def is_self_referential(f): | ||||
|             return f.is_relation and f.remote_field.model is model | ||||
|  | ||||
|         # Work out the new fields dict / mapping | ||||
|         body = { | ||||
|             f.name: f.clone() if is_self_referential(f) else f | ||||
|             for f in model._meta.local_concrete_fields | ||||
|         } | ||||
|         # Since mapping might mix column names and default values, | ||||
|         # its values must be already quoted. | ||||
|         mapping = { | ||||
|             f.column: self.quote_name(f.column) | ||||
|             for f in model._meta.local_concrete_fields | ||||
|         } | ||||
|         # This maps field names (not columns) for things like unique_together | ||||
|         rename_mapping = {} | ||||
|         # If any of the new or altered fields is introducing a new PK, | ||||
|         # remove the old one | ||||
|         restore_pk_field = None | ||||
|         alter_fields = alter_fields or [] | ||||
|         if getattr(create_field, "primary_key", False) or any( | ||||
|             getattr(new_field, "primary_key", False) for _, new_field in alter_fields | ||||
|         ): | ||||
|             for name, field in list(body.items()): | ||||
|                 if field.primary_key and not any( | ||||
|                     # Do not remove the old primary key when an altered field | ||||
|                     # that introduces a primary key is the same field. | ||||
|                     name == new_field.name | ||||
|                     for _, new_field in alter_fields | ||||
|                 ): | ||||
|                     field.primary_key = False | ||||
|                     restore_pk_field = field | ||||
|                     if field.auto_created: | ||||
|                         del body[name] | ||||
|                         del mapping[field.column] | ||||
|         # Add in any created fields | ||||
|         if create_field: | ||||
|             body[create_field.name] = create_field | ||||
|             # Choose a default and insert it into the copy map | ||||
|             if not create_field.many_to_many and create_field.concrete: | ||||
|                 mapping[create_field.column] = self.prepare_default( | ||||
|                     self.effective_default(create_field), | ||||
|                 ) | ||||
|         # Add in any altered fields | ||||
|         for alter_field in alter_fields: | ||||
|             old_field, new_field = alter_field | ||||
|             body.pop(old_field.name, None) | ||||
|             mapping.pop(old_field.column, None) | ||||
|             body[new_field.name] = new_field | ||||
|             if old_field.null and not new_field.null: | ||||
|                 case_sql = "coalesce(%(col)s, %(default)s)" % { | ||||
|                     "col": self.quote_name(old_field.column), | ||||
|                     "default": self.prepare_default(self.effective_default(new_field)), | ||||
|                 } | ||||
|                 mapping[new_field.column] = case_sql | ||||
|             else: | ||||
|                 mapping[new_field.column] = self.quote_name(old_field.column) | ||||
|             rename_mapping[old_field.name] = new_field.name | ||||
|         # Remove any deleted fields | ||||
|         if delete_field: | ||||
|             del body[delete_field.name] | ||||
|             del mapping[delete_field.column] | ||||
|             # Remove any implicit M2M tables | ||||
|             if ( | ||||
|                 delete_field.many_to_many | ||||
|                 and delete_field.remote_field.through._meta.auto_created | ||||
|             ): | ||||
|                 return self.delete_model(delete_field.remote_field.through) | ||||
|         # Work inside a new app registry | ||||
|         apps = Apps() | ||||
|  | ||||
|         # Work out the new value of unique_together, taking renames into | ||||
|         # account | ||||
|         unique_together = [ | ||||
|             [rename_mapping.get(n, n) for n in unique] | ||||
|             for unique in model._meta.unique_together | ||||
|         ] | ||||
|  | ||||
|         # Work out the new value for index_together, taking renames into | ||||
|         # account | ||||
|         index_together = [ | ||||
|             [rename_mapping.get(n, n) for n in index] | ||||
|             for index in model._meta.index_together | ||||
|         ] | ||||
|  | ||||
|         indexes = model._meta.indexes | ||||
|         if delete_field: | ||||
|             indexes = [ | ||||
|                 index for index in indexes if delete_field.name not in index.fields | ||||
|             ] | ||||
|  | ||||
|         constraints = list(model._meta.constraints) | ||||
|  | ||||
|         # Provide isolated instances of the fields to the new model body so | ||||
|         # that the existing model's internals aren't interfered with when | ||||
|         # the dummy model is constructed. | ||||
|         body_copy = copy.deepcopy(body) | ||||
|  | ||||
|         # Construct a new model with the new fields to allow self referential | ||||
|         # primary key to resolve to. This model won't ever be materialized as a | ||||
|         # table and solely exists for foreign key reference resolution purposes. | ||||
|         # This wouldn't be required if the schema editor was operating on model | ||||
|         # states instead of rendered models. | ||||
|         meta_contents = { | ||||
|             "app_label": model._meta.app_label, | ||||
|             "db_table": model._meta.db_table, | ||||
|             "unique_together": unique_together, | ||||
|             "index_together": index_together, | ||||
|             "indexes": indexes, | ||||
|             "constraints": constraints, | ||||
|             "apps": apps, | ||||
|         } | ||||
|         meta = type("Meta", (), meta_contents) | ||||
|         body_copy["Meta"] = meta | ||||
|         body_copy["__module__"] = model.__module__ | ||||
|         type(model._meta.object_name, model.__bases__, body_copy) | ||||
|  | ||||
|         # Construct a model with a renamed table name. | ||||
|         body_copy = copy.deepcopy(body) | ||||
|         meta_contents = { | ||||
|             "app_label": model._meta.app_label, | ||||
|             "db_table": "new__%s" % strip_quotes(model._meta.db_table), | ||||
|             "unique_together": unique_together, | ||||
|             "index_together": index_together, | ||||
|             "indexes": indexes, | ||||
|             "constraints": constraints, | ||||
|             "apps": apps, | ||||
|         } | ||||
|         meta = type("Meta", (), meta_contents) | ||||
|         body_copy["Meta"] = meta | ||||
|         body_copy["__module__"] = model.__module__ | ||||
|         new_model = type("New%s" % model._meta.object_name, model.__bases__, body_copy) | ||||
|  | ||||
|         # Create a new table with the updated schema. | ||||
|         self.create_model(new_model) | ||||
|  | ||||
|         # Copy data from the old table into the new table | ||||
|         self.execute( | ||||
|             "INSERT INTO %s (%s) SELECT %s FROM %s" | ||||
|             % ( | ||||
|                 self.quote_name(new_model._meta.db_table), | ||||
|                 ", ".join(self.quote_name(x) for x in mapping), | ||||
|                 ", ".join(mapping.values()), | ||||
|                 self.quote_name(model._meta.db_table), | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         # Delete the old table to make way for the new | ||||
|         self.delete_model(model, handle_autom2m=False) | ||||
|  | ||||
|         # Rename the new table to take the place of the old one | ||||
|         self.alter_db_table( | ||||
|             new_model, | ||||
|             new_model._meta.db_table, | ||||
|             model._meta.db_table, | ||||
|             disable_constraints=False, | ||||
|         ) | ||||
|  | ||||
|         # Run deferred SQL on correct table | ||||
|         for sql in self.deferred_sql: | ||||
|             self.execute(sql) | ||||
|         self.deferred_sql = [] | ||||
|         # Fix any PK-removed field | ||||
|         if restore_pk_field: | ||||
|             restore_pk_field.primary_key = True | ||||
|  | ||||
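|     # Editor's note: the five steps above, written out as the raw SQL that | ||||
|     # SQLite's documentation prescribes (illustrative table/column names): | ||||
|     # | ||||
|     #     CREATE TABLE "new__app_book" ("id" integer PRIMARY KEY, "title" text); | ||||
|     #     INSERT INTO "new__app_book" ("id", "title") | ||||
|     #         SELECT "id", "title" FROM "app_book"; | ||||
|     #     DROP TABLE "app_book"; | ||||
|     #     ALTER TABLE "new__app_book" RENAME TO "app_book"; | ||||
|     #     CREATE INDEX "app_book_title_idx" ON "app_book" ("title"); | ||||
|  | ||||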
|     def delete_model(self, model, handle_autom2m=True): | ||||
|         if handle_autom2m: | ||||
|             super().delete_model(model) | ||||
|         else: | ||||
|             # Delete the table (and only that) | ||||
|             self.execute( | ||||
|                 self.sql_delete_table | ||||
|                 % { | ||||
|                     "table": self.quote_name(model._meta.db_table), | ||||
|                 } | ||||
|             ) | ||||
|             # Remove all deferred statements referencing the deleted table. | ||||
|             for sql in list(self.deferred_sql): | ||||
|                 if isinstance(sql, Statement) and sql.references_table( | ||||
|                     model._meta.db_table | ||||
|                 ): | ||||
|                     self.deferred_sql.remove(sql) | ||||
|  | ||||
|     def add_field(self, model, field): | ||||
|         """Create a field on a model.""" | ||||
|         # Special-case implicit M2M tables. | ||||
|         if field.many_to_many and field.remote_field.through._meta.auto_created: | ||||
|             self.create_model(field.remote_field.through) | ||||
|         elif ( | ||||
|             # Primary keys and unique fields are not supported in ALTER TABLE | ||||
|             # ADD COLUMN. | ||||
|             field.primary_key | ||||
|             or field.unique | ||||
|             or | ||||
|             # Fields with default values cannot be handled by ALTER TABLE ADD | ||||
|             # COLUMN statement because DROP DEFAULT is not supported in | ||||
|             # ALTER TABLE. | ||||
|             not field.null | ||||
|             or self.effective_default(field) is not None | ||||
|         ): | ||||
|             self._remake_table(model, create_field=field) | ||||
|         else: | ||||
|             super().add_field(model, field) | ||||
|  | ||||
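|     # Editor's note: why the remake fallback above exists -- a quick stdlib | ||||
|     # check (illustrative) of SQLite's ALTER TABLE ADD COLUMN limits: | ||||
|     # | ||||
|     #     import sqlite3 | ||||
|     # | ||||
|     #     conn = sqlite3.connect(":memory:") | ||||
|     #     conn.execute("CREATE TABLE t (id integer PRIMARY KEY)") | ||||
|     #     conn.execute("ALTER TABLE t ADD COLUMN a text")  # nullable: fine | ||||
|     #     try: | ||||
|     #         conn.execute("ALTER TABLE t ADD COLUMN b text NOT NULL") | ||||
|     #     except sqlite3.OperationalError as exc: | ||||
|     #         print(exc)  # Cannot add a NOT NULL column with default value NULL | ||||
|  | ||||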
|     def remove_field(self, model, field): | ||||
|         """ | ||||
|         Remove a field from a model. Usually involves deleting a column, | ||||
|         but for M2Ms may involve deleting a table. | ||||
|         """ | ||||
|         # M2M fields are a special case | ||||
|         if field.many_to_many: | ||||
|             # For implicit M2M tables, delete the auto-created table | ||||
|             if field.remote_field.through._meta.auto_created: | ||||
|                 self.delete_model(field.remote_field.through) | ||||
|             # For explicit "through" M2M fields, do nothing | ||||
|         elif ( | ||||
|             self.connection.features.can_alter_table_drop_column | ||||
|             # Primary keys, unique fields, indexed fields, and foreign keys are | ||||
|             # not supported in ALTER TABLE DROP COLUMN. | ||||
|             and not field.primary_key | ||||
|             and not field.unique | ||||
|             and not field.db_index | ||||
|             and not (field.remote_field and field.db_constraint) | ||||
|         ): | ||||
|             super().remove_field(model, field) | ||||
|         # For everything else, remake. | ||||
|         else: | ||||
|             # It might not actually have a column behind it | ||||
|             if field.db_parameters(connection=self.connection)["type"] is None: | ||||
|                 return | ||||
|             self._remake_table(model, delete_field=field) | ||||
|  | ||||
|     def _alter_field( | ||||
|         self, | ||||
|         model, | ||||
|         old_field, | ||||
|         new_field, | ||||
|         old_type, | ||||
|         new_type, | ||||
|         old_db_params, | ||||
|         new_db_params, | ||||
|         strict=False, | ||||
|     ): | ||||
|         """Perform a "physical" (non-ManyToMany) field update.""" | ||||
|         # Use "ALTER TABLE ... RENAME COLUMN" if only the column name | ||||
|         # changed and there aren't any constraints. | ||||
|         if ( | ||||
|             self.connection.features.can_alter_table_rename_column | ||||
|             and old_field.column != new_field.column | ||||
|             and self.column_sql(model, old_field) == self.column_sql(model, new_field) | ||||
|             and not ( | ||||
|                 old_field.remote_field | ||||
|                 and old_field.db_constraint | ||||
|                 or new_field.remote_field | ||||
|                 and new_field.db_constraint | ||||
|             ) | ||||
|         ): | ||||
|             return self.execute( | ||||
|                 self._rename_field_sql( | ||||
|                     model._meta.db_table, old_field, new_field, new_type | ||||
|                 ) | ||||
|             ) | ||||
|         # Alter by remaking table | ||||
|         self._remake_table(model, alter_fields=[(old_field, new_field)]) | ||||
|         # Rebuild tables with FKs pointing to this field. | ||||
|         old_collation = old_db_params.get("collation") | ||||
|         new_collation = new_db_params.get("collation") | ||||
|         if new_field.unique and ( | ||||
|             old_type != new_type or old_collation != new_collation | ||||
|         ): | ||||
|             related_models = set() | ||||
|             opts = new_field.model._meta | ||||
|             for remote_field in opts.related_objects: | ||||
|                 # Ignore self-relationship since the table was already rebuilt. | ||||
|                 if remote_field.related_model == model: | ||||
|                     continue | ||||
|                 if not remote_field.many_to_many: | ||||
|                     if remote_field.field_name == new_field.name: | ||||
|                         related_models.add(remote_field.related_model) | ||||
|                 elif new_field.primary_key and remote_field.through._meta.auto_created: | ||||
|                     related_models.add(remote_field.through) | ||||
|             if new_field.primary_key: | ||||
|                 for many_to_many in opts.many_to_many: | ||||
|                     # Ignore self-relationship since the table was already rebuilt. | ||||
|                     if many_to_many.related_model == model: | ||||
|                         continue | ||||
|                     if many_to_many.remote_field.through._meta.auto_created: | ||||
|                         related_models.add(many_to_many.remote_field.through) | ||||
|             for related_model in related_models: | ||||
|                 self._remake_table(related_model) | ||||
|  | ||||
|     def _alter_many_to_many(self, model, old_field, new_field, strict): | ||||
|         """Alter M2Ms to repoint their to= endpoints.""" | ||||
|         if ( | ||||
|             old_field.remote_field.through._meta.db_table | ||||
|             == new_field.remote_field.through._meta.db_table | ||||
|         ): | ||||
|             # The field name didn't change, but some options did, so we have to | ||||
|             # propagate this altering. | ||||
|             self._remake_table( | ||||
|                 old_field.remote_field.through, | ||||
|                 alter_fields=[ | ||||
|                     ( | ||||
|                         # The field that points to the target model is needed, | ||||
|                         # so that table can be remade with the new m2m field - | ||||
|                         # this is m2m_reverse_field_name(). | ||||
|                         old_field.remote_field.through._meta.get_field( | ||||
|                             old_field.m2m_reverse_field_name() | ||||
|                         ), | ||||
|                         new_field.remote_field.through._meta.get_field( | ||||
|                             new_field.m2m_reverse_field_name() | ||||
|                         ), | ||||
|                     ), | ||||
|                     ( | ||||
|                         # The field that points to the model itself is needed, | ||||
|                         # so that table can be remade with the new self field - | ||||
|                         # this is m2m_field_name(). | ||||
|                         old_field.remote_field.through._meta.get_field( | ||||
|                             old_field.m2m_field_name() | ||||
|                         ), | ||||
|                         new_field.remote_field.through._meta.get_field( | ||||
|                             new_field.m2m_field_name() | ||||
|                         ), | ||||
|                     ), | ||||
|                 ], | ||||
|             ) | ||||
|             return | ||||
|  | ||||
|         # Make a new through table | ||||
|         self.create_model(new_field.remote_field.through) | ||||
|         # Copy the data across | ||||
|         self.execute( | ||||
|             "INSERT INTO %s (%s) SELECT %s FROM %s" | ||||
|             % ( | ||||
|                 self.quote_name(new_field.remote_field.through._meta.db_table), | ||||
|                 ", ".join( | ||||
|                     [ | ||||
|                         "id", | ||||
|                         new_field.m2m_column_name(), | ||||
|                         new_field.m2m_reverse_name(), | ||||
|                     ] | ||||
|                 ), | ||||
|                 ", ".join( | ||||
|                     [ | ||||
|                         "id", | ||||
|                         old_field.m2m_column_name(), | ||||
|                         old_field.m2m_reverse_name(), | ||||
|                     ] | ||||
|                 ), | ||||
|                 self.quote_name(old_field.remote_field.through._meta.db_table), | ||||
|             ) | ||||
|         ) | ||||
|         # Delete the old through table | ||||
|         self.delete_model(old_field.remote_field.through) | ||||
|  | ||||
|     def add_constraint(self, model, constraint): | ||||
|         if isinstance(constraint, UniqueConstraint) and ( | ||||
|             constraint.condition | ||||
|             or constraint.contains_expressions | ||||
|             or constraint.include | ||||
|             or constraint.deferrable | ||||
|         ): | ||||
|             super().add_constraint(model, constraint) | ||||
|         else: | ||||
|             self._remake_table(model) | ||||
|  | ||||
|     def remove_constraint(self, model, constraint): | ||||
|         if isinstance(constraint, UniqueConstraint) and ( | ||||
|             constraint.condition | ||||
|             or constraint.contains_expressions | ||||
|             or constraint.include | ||||
|             or constraint.deferrable | ||||
|         ): | ||||
|             super().remove_constraint(model, constraint) | ||||
|         else: | ||||
|             self._remake_table(model) | ||||
|  | ||||
|     def _collate_sql(self, collation): | ||||
|         return "COLLATE " + collation | ||||
| @ -0,0 +1,320 @@ | ||||
| import datetime | ||||
| import decimal | ||||
| import functools | ||||
| import logging | ||||
| import time | ||||
| from contextlib import contextmanager | ||||
|  | ||||
| from django.db import NotSupportedError | ||||
| from django.utils.crypto import md5 | ||||
| from django.utils.dateparse import parse_time | ||||
|  | ||||
| logger = logging.getLogger("django.db.backends") | ||||
|  | ||||
|  | ||||
| class CursorWrapper: | ||||
|     def __init__(self, cursor, db): | ||||
|         self.cursor = cursor | ||||
|         self.db = db | ||||
|  | ||||
|     WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"]) | ||||
|  | ||||
|     def __getattr__(self, attr): | ||||
|         cursor_attr = getattr(self.cursor, attr) | ||||
|         if attr in CursorWrapper.WRAP_ERROR_ATTRS: | ||||
|             return self.db.wrap_database_errors(cursor_attr) | ||||
|         else: | ||||
|             return cursor_attr | ||||
|  | ||||
|     def __iter__(self): | ||||
|         with self.db.wrap_database_errors: | ||||
|             yield from self.cursor | ||||
|  | ||||
|     def __enter__(self): | ||||
|         return self | ||||
|  | ||||
|     def __exit__(self, type, value, traceback): | ||||
|         # Close instead of passing through to avoid backend-specific behavior | ||||
|         # (#17671). Catch errors liberally because errors in cleanup code | ||||
|         # aren't useful. | ||||
|         try: | ||||
|             self.close() | ||||
|         except self.db.Database.Error: | ||||
|             pass | ||||
|  | ||||
|     # The following methods cannot be implemented in __getattr__, because the | ||||
|     # code must run when the method is invoked, not just when it is accessed. | ||||
|  | ||||
|     def callproc(self, procname, params=None, kparams=None): | ||||
|         # Keyword parameters for callproc aren't supported in PEP 249, but the | ||||
|         # database driver may support them (e.g. cx_Oracle). | ||||
|         if kparams is not None and not self.db.features.supports_callproc_kwargs: | ||||
|             raise NotSupportedError( | ||||
|                 "Keyword parameters for callproc are not supported on this " | ||||
|                 "database backend." | ||||
|             ) | ||||
|         self.db.validate_no_broken_transaction() | ||||
|         with self.db.wrap_database_errors: | ||||
|             if params is None and kparams is None: | ||||
|                 return self.cursor.callproc(procname) | ||||
|             elif kparams is None: | ||||
|                 return self.cursor.callproc(procname, params) | ||||
|             else: | ||||
|                 params = params or () | ||||
|                 return self.cursor.callproc(procname, params, kparams) | ||||
|  | ||||
|     def execute(self, sql, params=None): | ||||
|         return self._execute_with_wrappers( | ||||
|             sql, params, many=False, executor=self._execute | ||||
|         ) | ||||
|  | ||||
|     def executemany(self, sql, param_list): | ||||
|         return self._execute_with_wrappers( | ||||
|             sql, param_list, many=True, executor=self._executemany | ||||
|         ) | ||||
|  | ||||
|     def _execute_with_wrappers(self, sql, params, many, executor): | ||||
|         context = {"connection": self.db, "cursor": self} | ||||
|         for wrapper in reversed(self.db.execute_wrappers): | ||||
|             executor = functools.partial(wrapper, executor) | ||||
|         return executor(sql, params, many, context) | ||||
|  | ||||
|     def _execute(self, sql, params, *ignored_wrapper_args): | ||||
|         self.db.validate_no_broken_transaction() | ||||
|         with self.db.wrap_database_errors: | ||||
|             if params is None: | ||||
|                 # params default might be backend specific. | ||||
|                 return self.cursor.execute(sql) | ||||
|             else: | ||||
|                 return self.cursor.execute(sql, params) | ||||
|  | ||||
|     def _executemany(self, sql, param_list, *ignored_wrapper_args): | ||||
|         self.db.validate_no_broken_transaction() | ||||
|         with self.db.wrap_database_errors: | ||||
|             return self.cursor.executemany(sql, param_list) | ||||
|  | ||||
|  | ||||
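| # Editor's note: the execute_wrappers list consumed above is the hook behind | ||||
| # Django's public connection.execute_wrapper() API; functools.partial layers | ||||
| # the wrappers so the first one registered runs outermost. A minimal logging | ||||
| # wrapper (illustrative): | ||||
| # | ||||
| #     from django.db import connection | ||||
| # | ||||
| #     def log_sql(execute, sql, params, many, context): | ||||
| #         print("about to run:", sql) | ||||
| #         return execute(sql, params, many, context) | ||||
| # | ||||
| #     with connection.execute_wrapper(log_sql): | ||||
| #         ...  # any ORM query here passes through log_sql first | ||||
|  | ||||
|  | ||||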
| class CursorDebugWrapper(CursorWrapper): | ||||
|     # XXX callproc isn't instrumented at this time. | ||||
|  | ||||
|     def execute(self, sql, params=None): | ||||
|         with self.debug_sql(sql, params, use_last_executed_query=True): | ||||
|             return super().execute(sql, params) | ||||
|  | ||||
|     def executemany(self, sql, param_list): | ||||
|         with self.debug_sql(sql, param_list, many=True): | ||||
|             return super().executemany(sql, param_list) | ||||
|  | ||||
|     @contextmanager | ||||
|     def debug_sql( | ||||
|         self, sql=None, params=None, use_last_executed_query=False, many=False | ||||
|     ): | ||||
|         start = time.monotonic() | ||||
|         try: | ||||
|             yield | ||||
|         finally: | ||||
|             stop = time.monotonic() | ||||
|             duration = stop - start | ||||
|             if use_last_executed_query: | ||||
|                 sql = self.db.ops.last_executed_query(self.cursor, sql, params) | ||||
|             try: | ||||
|                 times = len(params) if many else "" | ||||
|             except TypeError: | ||||
|                 # params could be an iterator. | ||||
|                 times = "?" | ||||
|             self.db.queries_log.append( | ||||
|                 { | ||||
|                     "sql": "%s times: %s" % (times, sql) if many else sql, | ||||
|                     "time": "%.3f" % duration, | ||||
|                 } | ||||
|             ) | ||||
|             logger.debug( | ||||
|                 "(%.3f) %s; args=%s; alias=%s", | ||||
|                 duration, | ||||
|                 sql, | ||||
|                 params, | ||||
|                 self.db.alias, | ||||
|                 extra={ | ||||
|                     "duration": duration, | ||||
|                     "sql": sql, | ||||
|                     "params": params, | ||||
|                     "alias": self.db.alias, | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
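| # Editor's note: entries appended to queries_log above surface as | ||||
| # connection.queries when DEBUG is True (or when force_debug_cursor is set). | ||||
| # Illustrative, with a hypothetical model: | ||||
| # | ||||
| #     from django.db import connection | ||||
| # | ||||
| #     list(Book.objects.all()) | ||||
| #     print(connection.queries[-1])  # {'sql': 'SELECT ...', 'time': '0.002'} | ||||
|  | ||||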
|  | ||||
| @contextmanager | ||||
| def debug_transaction(connection, sql): | ||||
|     start = time.monotonic() | ||||
|     try: | ||||
|         yield | ||||
|     finally: | ||||
|         if connection.queries_logged: | ||||
|             stop = time.monotonic() | ||||
|             duration = stop - start | ||||
|             connection.queries_log.append( | ||||
|                 { | ||||
|                     "sql": "%s" % sql, | ||||
|                     "time": "%.3f" % duration, | ||||
|                 } | ||||
|             ) | ||||
|             logger.debug( | ||||
|                 "(%.3f) %s; args=%s; alias=%s", | ||||
|                 duration, | ||||
|                 sql, | ||||
|                 None, | ||||
|                 connection.alias, | ||||
|                 extra={ | ||||
|                     "duration": duration, | ||||
|                     "sql": sql, | ||||
|                     "alias": connection.alias, | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|  | ||||
| def split_tzname_delta(tzname): | ||||
|     """ | ||||
|     Split a time zone name into a 3-tuple of (name, sign, offset). | ||||
|     """ | ||||
|     for sign in ["+", "-"]: | ||||
|         if sign in tzname: | ||||
|             name, offset = tzname.rsplit(sign, 1) | ||||
|             if offset and parse_time(offset): | ||||
|                 return name, sign, offset | ||||
|     return tzname, None, None | ||||
|  | ||||
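| # Editor's note: illustrative outputs for split_tzname_delta(): | ||||
| # | ||||
| #     split_tzname_delta("UTC+01:00")      # ("UTC", "+", "01:00") | ||||
| #     split_tzname_delta("Etc/GMT-05:30")  # ("Etc/GMT", "-", "05:30") | ||||
| #     split_tzname_delta("Europe/Paris")   # ("Europe/Paris", None, None) | ||||
|  | ||||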
|  | ||||
| ############################################### | ||||
| # Converters from database (string) to Python # | ||||
| ############################################### | ||||
|  | ||||
|  | ||||
| def typecast_date(s): | ||||
|     return ( | ||||
|         datetime.date(*map(int, s.split("-"))) if s else None | ||||
|     )  # return None if s is null | ||||
|  | ||||
|  | ||||
| def typecast_time(s):  # does NOT store time zone information | ||||
|     if not s: | ||||
|         return None | ||||
|     hour, minutes, seconds = s.split(":") | ||||
|     if "." in seconds:  # check whether seconds have a fractional part | ||||
|         seconds, microseconds = seconds.split(".") | ||||
|     else: | ||||
|         microseconds = "0" | ||||
|     return datetime.time( | ||||
|         int(hour), int(minutes), int(seconds), int((microseconds + "000000")[:6]) | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def typecast_timestamp(s):  # does NOT store time zone information | ||||
|     # "2005-07-29 15:48:00.590358-05" | ||||
|     # "2005-07-29 09:56:00-05" | ||||
|     if not s: | ||||
|         return None | ||||
|     if " " not in s: | ||||
|         return typecast_date(s) | ||||
|     d, t = s.split() | ||||
|     # Remove timezone information. | ||||
|     if "-" in t: | ||||
|         t, _ = t.split("-", 1) | ||||
|     elif "+" in t: | ||||
|         t, _ = t.split("+", 1) | ||||
|     dates = d.split("-") | ||||
|     times = t.split(":") | ||||
|     seconds = times[2] | ||||
|     if "." in seconds:  # check whether seconds have a fractional part | ||||
|         seconds, microseconds = seconds.split(".") | ||||
|     else: | ||||
|         microseconds = "0" | ||||
|     return datetime.datetime( | ||||
|         int(dates[0]), | ||||
|         int(dates[1]), | ||||
|         int(dates[2]), | ||||
|         int(times[0]), | ||||
|         int(times[1]), | ||||
|         int(seconds), | ||||
|         int((microseconds + "000000")[:6]), | ||||
|     ) | ||||
|  | ||||
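| # Editor's note: illustrative round-trips for the converters above: | ||||
| # | ||||
| #     typecast_date("2005-07-29")       # datetime.date(2005, 7, 29) | ||||
| #     typecast_time("15:48:00.590358")  # datetime.time(15, 48, 0, 590358) | ||||
| #     typecast_timestamp("2005-07-29 15:48:00.590358-05") | ||||
| #     # -> datetime.datetime(2005, 7, 29, 15, 48, 0, 590358); offset dropped | ||||
|  | ||||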
|  | ||||
| ############################################### | ||||
| # Converters from Python to database (string) # | ||||
| ############################################### | ||||
|  | ||||
|  | ||||
| def split_identifier(identifier): | ||||
|     """ | ||||
|     Split an SQL identifier into a two-element tuple of (namespace, name). | ||||
|  | ||||
|     The identifier could be a table, column, or sequence name, and it might be | ||||
|     prefixed by a namespace. | ||||
|     """ | ||||
|     try: | ||||
|         namespace, name = identifier.split('"."') | ||||
|     except ValueError: | ||||
|         namespace, name = "", identifier | ||||
|     return namespace.strip('"'), name.strip('"') | ||||
|  | ||||
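| # Editor's note: illustrative outputs for split_identifier(): | ||||
| # | ||||
| #     split_identifier('"USER"."TABLE"')  # ("USER", "TABLE") | ||||
| #     split_identifier('"TABLE"')         # ("", "TABLE") | ||||
| #     split_identifier("plain_name")      # ("", "plain_name") | ||||
|  | ||||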
|  | ||||
| def truncate_name(identifier, length=None, hash_len=4): | ||||
|     """ | ||||
|     Shorten an SQL identifier to a repeatable mangled version with the given | ||||
|     length. | ||||
|  | ||||
|     If a quote-stripped name contains a namespace, e.g. USERNAME"."TABLE, | ||||
|     truncate the table portion only. | ||||
|     """ | ||||
|     namespace, name = split_identifier(identifier) | ||||
|  | ||||
|     if length is None or len(name) <= length: | ||||
|         return identifier | ||||
|  | ||||
|     digest = names_digest(name, length=hash_len) | ||||
|     return "%s%s%s" % ( | ||||
|         '%s"."' % namespace if namespace else "", | ||||
|         name[: length - hash_len], | ||||
|         digest, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def names_digest(*args, length): | ||||
|     """ | ||||
|     Generate a short, deterministic digest (a truncated md5) of a set of | ||||
|     arguments that can be used to shorten identifying names. | ||||
|     """ | ||||
|     h = md5(usedforsecurity=False) | ||||
|     for arg in args: | ||||
|         h.update(arg.encode()) | ||||
|     return h.hexdigest()[:length] | ||||
|  | ||||
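| # Editor's note: illustrative behaviour of truncate_name(); the 4-character | ||||
| # tail is the md5-based value from names_digest(), shown as "abcd" purely | ||||
| # for the sake of the example: | ||||
| # | ||||
| #     truncate_name("some_very_long_table_name", length=10) | ||||
| #     # -> "some_vabcd" (6 leading chars + 4 digest chars) | ||||
| #     truncate_name('"USER"."some_very_long_table_name"', length=10) | ||||
| #     # -> 'USER"."some_vabcd' -- only the table portion is truncated | ||||
| #     truncate_name("short", length=30)  # "short", returned unchanged | ||||
|  | ||||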
|  | ||||
| def format_number(value, max_digits, decimal_places): | ||||
|     """ | ||||
|     Format a number into a string with the requisite number of digits and | ||||
|     decimal places. | ||||
|     """ | ||||
|     if value is None: | ||||
|         return None | ||||
|     context = decimal.getcontext().copy() | ||||
|     if max_digits is not None: | ||||
|         context.prec = max_digits | ||||
|     if decimal_places is not None: | ||||
|         value = value.quantize( | ||||
|             decimal.Decimal(1).scaleb(-decimal_places), context=context | ||||
|         ) | ||||
|     else: | ||||
|         context.traps[decimal.Rounded] = 1 | ||||
|         value = context.create_decimal(value) | ||||
|     return "{:f}".format(value) | ||||
|  | ||||
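| # Editor's note: illustrative outputs for format_number(): | ||||
| # | ||||
| #     from decimal import Decimal | ||||
| # | ||||
| #     format_number(Decimal("1234.5678"), 6, 2)   # "1234.57" | ||||
| #     format_number(Decimal("3.14159"), None, 2)  # "3.14" | ||||
| #     format_number(None, 6, 2)                   # None | ||||
|  | ||||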
|  | ||||
| def strip_quotes(table_name): | ||||
|     """ | ||||
|     Strip quotes off of quoted table names to make them safe for use in index | ||||
|     names, sequence names, etc. For example '"USER"."TABLE"' (an Oracle naming | ||||
|     scheme) becomes 'USER"."TABLE'. | ||||
|     """ | ||||
|     has_quotes = table_name.startswith('"') and table_name.endswith('"') | ||||
|     return table_name[1:-1] if has_quotes else table_name | ||||