docker setup
This commit is contained in:
		| @ -0,0 +1,115 @@ | ||||
| from django.core.exceptions import ObjectDoesNotExist | ||||
| from django.db.models import signals | ||||
| from django.db.models.aggregates import *  # NOQA | ||||
| from django.db.models.aggregates import __all__ as aggregates_all | ||||
| from django.db.models.constraints import *  # NOQA | ||||
| from django.db.models.constraints import __all__ as constraints_all | ||||
| from django.db.models.deletion import ( | ||||
|     CASCADE, | ||||
|     DO_NOTHING, | ||||
|     PROTECT, | ||||
|     RESTRICT, | ||||
|     SET, | ||||
|     SET_DEFAULT, | ||||
|     SET_NULL, | ||||
|     ProtectedError, | ||||
|     RestrictedError, | ||||
| ) | ||||
| from django.db.models.enums import *  # NOQA | ||||
| from django.db.models.enums import __all__ as enums_all | ||||
| from django.db.models.expressions import ( | ||||
|     Case, | ||||
|     Exists, | ||||
|     Expression, | ||||
|     ExpressionList, | ||||
|     ExpressionWrapper, | ||||
|     F, | ||||
|     Func, | ||||
|     OrderBy, | ||||
|     OuterRef, | ||||
|     RowRange, | ||||
|     Subquery, | ||||
|     Value, | ||||
|     ValueRange, | ||||
|     When, | ||||
|     Window, | ||||
|     WindowFrame, | ||||
| ) | ||||
| from django.db.models.fields import *  # NOQA | ||||
| from django.db.models.fields import __all__ as fields_all | ||||
| from django.db.models.fields.files import FileField, ImageField | ||||
| from django.db.models.fields.json import JSONField | ||||
| from django.db.models.fields.proxy import OrderWrt | ||||
| from django.db.models.indexes import *  # NOQA | ||||
| from django.db.models.indexes import __all__ as indexes_all | ||||
| from django.db.models.lookups import Lookup, Transform | ||||
| from django.db.models.manager import Manager | ||||
| from django.db.models.query import Prefetch, QuerySet, prefetch_related_objects | ||||
| from django.db.models.query_utils import FilteredRelation, Q | ||||
|  | ||||
| # Imports that would create circular imports if sorted | ||||
| from django.db.models.base import DEFERRED, Model  # isort:skip | ||||
| from django.db.models.fields.related import (  # isort:skip | ||||
|     ForeignKey, | ||||
|     ForeignObject, | ||||
|     OneToOneField, | ||||
|     ManyToManyField, | ||||
|     ForeignObjectRel, | ||||
|     ManyToOneRel, | ||||
|     ManyToManyRel, | ||||
|     OneToOneRel, | ||||
| ) | ||||
|  | ||||
|  | ||||
| __all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all | ||||
| __all__ += [ | ||||
|     "ObjectDoesNotExist", | ||||
|     "signals", | ||||
|     "CASCADE", | ||||
|     "DO_NOTHING", | ||||
|     "PROTECT", | ||||
|     "RESTRICT", | ||||
|     "SET", | ||||
|     "SET_DEFAULT", | ||||
|     "SET_NULL", | ||||
|     "ProtectedError", | ||||
|     "RestrictedError", | ||||
|     "Case", | ||||
|     "Exists", | ||||
|     "Expression", | ||||
|     "ExpressionList", | ||||
|     "ExpressionWrapper", | ||||
|     "F", | ||||
|     "Func", | ||||
|     "OrderBy", | ||||
|     "OuterRef", | ||||
|     "RowRange", | ||||
|     "Subquery", | ||||
|     "Value", | ||||
|     "ValueRange", | ||||
|     "When", | ||||
|     "Window", | ||||
|     "WindowFrame", | ||||
|     "FileField", | ||||
|     "ImageField", | ||||
|     "JSONField", | ||||
|     "OrderWrt", | ||||
|     "Lookup", | ||||
|     "Transform", | ||||
|     "Manager", | ||||
|     "Prefetch", | ||||
|     "Q", | ||||
|     "QuerySet", | ||||
|     "prefetch_related_objects", | ||||
|     "DEFERRED", | ||||
|     "Model", | ||||
|     "FilteredRelation", | ||||
|     "ForeignKey", | ||||
|     "ForeignObject", | ||||
|     "OneToOneField", | ||||
|     "ManyToManyField", | ||||
|     "ForeignObjectRel", | ||||
|     "ManyToOneRel", | ||||
|     "ManyToManyRel", | ||||
|     "OneToOneRel", | ||||
| ] | ||||
| @ -0,0 +1,210 @@ | ||||
| """ | ||||
| Classes to represent the definitions of aggregate functions. | ||||
| """ | ||||
| from django.core.exceptions import FieldError, FullResultSet | ||||
| from django.db.models.expressions import Case, Func, Star, Value, When | ||||
| from django.db.models.fields import IntegerField | ||||
| from django.db.models.functions.comparison import Coalesce | ||||
| from django.db.models.functions.mixins import ( | ||||
|     FixDurationInputMixin, | ||||
|     NumericOutputFieldMixin, | ||||
| ) | ||||
|  | ||||
| __all__ = [ | ||||
|     "Aggregate", | ||||
|     "Avg", | ||||
|     "Count", | ||||
|     "Max", | ||||
|     "Min", | ||||
|     "StdDev", | ||||
|     "Sum", | ||||
|     "Variance", | ||||
| ] | ||||
|  | ||||
|  | ||||
class Aggregate(Func):
    """
    Base class for SQL aggregate functions (AVG, COUNT, MAX, ...).

    Extends Func with support for DISTINCT, a FILTER (WHERE ...) clause,
    and a ``default`` value emulated via COALESCE when the aggregate
    would otherwise return NULL.
    """

    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    name = None
    # Wrapped around the rendered template when self.filter is set and the
    # backend supports the SQL FILTER clause.
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    allow_distinct = False
    empty_result_set_value = None

    def __init__(
        self, *expressions, distinct=False, filter=None, default=None, **extra
    ):
        if distinct and not self.allow_distinct:
            raise TypeError("%s does not allow distinct." % self.__class__.__name__)
        if default is not None and self.empty_result_set_value is not None:
            # A non-None empty_result_set_value already provides a fallback
            # for empty result sets; combining it with default is ambiguous.
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self):
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self):
        # Expose self.filter as a trailing source expression so expression
        # traversal (e.g. resolving, relabeling) reaches it.
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs):
        # Mirror get_source_expressions(): when a filter is set, the last
        # element of exprs is the filter — pop it back off before delegating.
        self.filter = self.filter and exprs.pop()
        return super().set_source_expressions(exprs)

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        c.filter = c.filter and c.filter.resolve_expression(
            query, allow_joins, reuse, summarize
        )
        if summarize:
            # Summarized aggregates cannot refer to summarized aggregates.
            for ref in c.get_refs():
                if query.annotations[ref].is_summary:
                    raise FieldError(
                        f"Cannot compute {c.name}('{ref}'): '{ref}' is an aggregate"
                    )
        elif not self.is_summary:
            # Call Aggregate.get_source_expressions() to avoid
            # returning self.filter and including that in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        "Cannot compute %s('%s'): '%s' is an aggregate"
                        % (c.name, name, name)
                    )
        if (default := c.default) is None:
            return c
        if hasattr(default, "resolve_expression"):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        # Emulate `default` with COALESCE(aggregate, default).
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self):
        # e.g. Count("pk") annotated without an alias becomes "pk__count".
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            return "%s__%s" % (expressions[0].name, self.name.lower())
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        # Aggregates never contribute columns to the GROUP BY clause.
        return []

    def as_sql(self, compiler, connection, **extra_context):
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                except FullResultSet:
                    # A filter matching everything is a no-op; render the
                    # aggregate without a FILTER clause.
                    pass
                else:
                    template = self.filter_template % extra_context.get(
                        "template", self.template
                    )
                    sql, params = super().as_sql(
                        compiler,
                        connection,
                        template=template,
                        filter=filter_sql,
                        **extra_context,
                    )
                    return sql, (*params, *filter_params)
            else:
                # Backend lacks FILTER: emulate it by wrapping the first
                # source expression in CASE WHEN <filter> THEN <expr> END.
                copy = self.copy()
                copy.filter = None
                source_expressions = copy.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
                return super(Aggregate, copy).as_sql(
                    compiler, connection, **extra_context
                )
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self):
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options
|  | ||||
|  | ||||
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """SQL AVG() aggregate; supports AVG(DISTINCT ...)."""

    name = "Avg"
    function = "AVG"
    allow_distinct = True
|  | ||||
|  | ||||
class Count(Aggregate):
    """SQL COUNT() aggregate; accepts "*" and supports COUNT(DISTINCT ...)."""

    name = "Count"
    function = "COUNT"
    allow_distinct = True
    output_field = IntegerField()
    # COUNT over an empty result set is 0, not NULL.
    empty_result_set_value = 0

    def __init__(self, expression, filter=None, **extra):
        # Allow the conventional Count("*") spelling.
        expression = Star() if expression == "*" else expression
        if filter is not None and isinstance(expression, Star):
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(expression, filter=filter, **extra)
|  | ||||
|  | ||||
class Max(Aggregate):
    """SQL MAX() aggregate."""

    name = "Max"
    function = "MAX"
|  | ||||
|  | ||||
class Min(Aggregate):
    """SQL MIN() aggregate."""

    name = "Min"
    function = "MIN"
|  | ||||
|  | ||||
class StdDev(NumericOutputFieldMixin, Aggregate):
    """Standard deviation: population by default, sample when sample=True."""

    name = "StdDev"

    def __init__(self, expression, sample=False, **extra):
        # The chosen SQL function also encodes the *sample* flag for repr.
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        options = super()._get_repr_options()
        options["sample"] = self.function == "STDDEV_SAMP"
        return options
|  | ||||
|  | ||||
class Sum(FixDurationInputMixin, Aggregate):
    """SQL SUM() aggregate; supports SUM(DISTINCT ...)."""

    name = "Sum"
    function = "SUM"
    allow_distinct = True
|  | ||||
|  | ||||
class Variance(NumericOutputFieldMixin, Aggregate):
    """Variance: population by default, sample when sample=True."""

    name = "Variance"

    def __init__(self, expression, sample=False, **extra):
        # The chosen SQL function also encodes the *sample* flag for repr.
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        options = super()._get_repr_options()
        options["sample"] = self.function == "VAR_SAMP"
        return options
							
								
								
									
										2531
									
								
								srcs/.venv/lib/python3.11/site-packages/django/db/models/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2531
									
								
								srcs/.venv/lib/python3.11/site-packages/django/db/models/base.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -0,0 +1,12 @@ | ||||
| """ | ||||
| Constants used across the ORM in general. | ||||
| """ | ||||
| from enum import Enum | ||||
|  | ||||
| # Separator used to split filter strings apart. | ||||
| LOOKUP_SEP = "__" | ||||
|  | ||||
|  | ||||
class OnConflict(Enum):
    """Strategies for handling conflicting rows during bulk inserts."""

    IGNORE = "ignore"
    UPDATE = "update"
| @ -0,0 +1,371 @@ | ||||
| from enum import Enum | ||||
|  | ||||
| from django.core.exceptions import FieldError, ValidationError | ||||
| from django.db import connections | ||||
| from django.db.models.expressions import Exists, ExpressionList, F, OrderBy | ||||
| from django.db.models.indexes import IndexExpression | ||||
| from django.db.models.lookups import Exact | ||||
| from django.db.models.query_utils import Q | ||||
| from django.db.models.sql.query import Query | ||||
| from django.db.utils import DEFAULT_DB_ALIAS | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
|  | ||||
| __all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"] | ||||
|  | ||||
|  | ||||
class BaseConstraint:
    """
    Common interface for table-level constraints.

    Subclasses must implement the SQL-generating methods and validate();
    this base handles the name and the violation error message.
    """

    default_violation_error_message = _("Constraint “%(name)s” is violated.")
    violation_error_message = None

    def __init__(self, name, violation_error_message=None):
        self.name = name
        # Fall back to the class-level default message when no custom
        # message is supplied.
        self.violation_error_message = (
            violation_error_message
            if violation_error_message is not None
            else self.default_violation_error_message
        )

    @property
    def contains_expressions(self):
        # Overridden by subclasses that wrap database expressions.
        return False

    def constraint_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def create_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def remove_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def get_violation_error_message(self):
        # Interpolate the constraint name into the (possibly lazy) message.
        return self.violation_error_message % {"name": self.name}

    def deconstruct(self):
        cls = self.__class__
        path = "%s.%s" % (cls.__module__, cls.__name__)
        # Constraints are re-exported from django.db.models; deconstruct to
        # the public path so migrations reference the documented location.
        path = path.replace("django.db.models.constraints", "django.db.models")
        kwargs = {"name": self.name}
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            kwargs["violation_error_message"] = message
        return (path, (), kwargs)

    def clone(self):
        # Rebuild an equivalent constraint from its deconstructed form.
        _path, args, kwargs = self.deconstruct()
        return type(self)(*args, **kwargs)
|  | ||||
|  | ||||
class CheckConstraint(BaseConstraint):
    """A CHECK constraint defined by a Q object or boolean expression."""

    def __init__(self, *, check, name, violation_error_message=None):
        if not getattr(check, "conditional", False):
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        self.check = check
        super().__init__(name, violation_error_message=violation_error_message)

    def _get_check_sql(self, model, schema_editor):
        # Compile the condition to raw SQL with values inlined, since DDL
        # statements cannot carry query parameters.
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.check)
        connection = schema_editor.connection
        compiler = query.get_compiler(connection=connection)
        sql, params = where.as_sql(compiler, connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def constraint_sql(self, model, schema_editor):
        return schema_editor._check_sql(
            self.name, self._get_check_sql(model, schema_editor)
        )

    def create_sql(self, model, schema_editor):
        return schema_editor._create_check_sql(
            model, self.name, self._get_check_sql(model, schema_editor)
        )

    def remove_sql(self, model, schema_editor):
        return schema_editor._delete_check_sql(model, self.name)

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        field_values = instance._get_field_value_map(meta=model._meta, exclude=exclude)
        try:
            satisfied = Q(self.check).check(field_values, using=using)
        except FieldError:
            # The check references a field excluded from validation; skip.
            return
        if not satisfied:
            raise ValidationError(self.get_violation_error_message())

    def __repr__(self):
        return f"<{self.__class__.__qualname__}: check={self.check} name={self.name!r}>"

    def __eq__(self, other):
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.check == other.check
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self):
        path, args, kwargs = super().deconstruct()
        return path, args, {**kwargs, "check": self.check}
|  | ||||
|  | ||||
class Deferrable(Enum):
    """Whether enforcement of a constraint may be deferred to commit time."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    # A similar format was proposed for Python 3.10.
    def __repr__(self):
        return "%s.%s" % (self.__class__.__qualname__, self._name_)
|  | ||||
|  | ||||
class UniqueConstraint(BaseConstraint):
    """
    Enforce uniqueness over a set of fields or expressions, optionally
    restricted to rows matching *condition* and extended with covering
    (INCLUDE) columns, deferrability, and index operator classes.
    """

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        condition=None,
        deferrable=None,
        include=None,
        opclasses=(),
        violation_error_message=None,
    ):
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, (type(None), Q)):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        # Deferrable constraints are implemented without an index on some
        # backends, so the index-only options below are incompatible.
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use django.contrib.postgres.indexes.OpClass() instead."
            )
        if not isinstance(deferrable, (type(None), Deferrable)):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, (type(None), list, tuple)):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, (list, tuple)):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        # Normalize string expressions ("lower_name") to F() references.
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(name, violation_error_message=violation_error_message)

    @property
    def contains_expressions(self):
        return bool(self.expressions)

    def _get_condition_sql(self, model, schema_editor):
        # Compile self.condition to raw SQL with values inlined (DDL cannot
        # use query parameters). Returns None when there is no condition.
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(self, model, schema_editor):
        # Wrap each expression for index usage and resolve the whole list
        # against a throwaway Query. Returns None for field-based constraints.
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(self, model, schema_editor):
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def create_sql(self, model, schema_editor):
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def remove_sql(self, model, schema_editor):
        condition = self._get_condition_sql(model, schema_editor)
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def __repr__(self):
        return "<%s:%s%s%s%s%s%s%s>" % (
            self.__class__.__qualname__,
            "" if not self.fields else " fields=%s" % repr(self.fields),
            "" if not self.expressions else " expressions=%s" % repr(self.expressions),
            " name=%s" % repr(self.name),
            "" if self.condition is None else " condition=%s" % self.condition,
            "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
            "" if not self.include else " include=%s" % repr(self.include),
            "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
        )

    def __eq__(self, other):
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        path, args, kwargs = super().deconstruct()
        # Only emit non-default options to keep migrations minimal.
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        # Validate by querying for an existing conflicting row.
        queryset = model._default_manager.using(using)
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                # Skip the whole constraint if any of its fields is excluded
                # from validation.
                if exclude and field_name in exclude:
                    return
                field = model._meta.get_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None or (
                    lookup_value == ""
                    and connections[using].features.interprets_empty_strings_as_nulls
                ):
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Map each F(field) to the instance's current value so the
            # expressions can be compared against concrete data.
            replacements = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_pk = instance._get_pk_val(model._meta)
        if not instance._state.adding and model_class_pk is not None:
            # Updating an existing row must not conflict with itself.
            queryset = queryset.exclude(pk=model_class_pk)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(self.get_violation_error_message())
                # When fields are defined, use the unique_error_message() for
                # backward compatibility.
                for model, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(model, self.fields)
                            )
        else:
            against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
            try:
                # Only enforce the constraint when the instance itself matches
                # the condition AND a conflicting conditioned row exists.
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against, using=using
                ):
                    raise ValidationError(self.get_violation_error_message())
            except FieldError:
                # Condition references an excluded field; skip validation.
                pass
| @ -0,0 +1,522 @@ | ||||
| from collections import Counter, defaultdict | ||||
| from functools import partial, reduce | ||||
| from itertools import chain | ||||
| from operator import attrgetter, or_ | ||||
|  | ||||
| from django.db import IntegrityError, connections, models, transaction | ||||
| from django.db.models import query_utils, signals, sql | ||||
|  | ||||
|  | ||||
class ProtectedError(IntegrityError):
    """Raised when a PROTECT on_delete handler blocks a cascading delete."""

    def __init__(self, msg, protected_objects):
        super().__init__(msg, protected_objects)
        # Keep the blocking objects available to callers for error reporting.
        self.protected_objects = protected_objects
|  | ||||
|  | ||||
class RestrictedError(IntegrityError):
    """Raised when a RESTRICT on_delete handler blocks a cascading delete."""

    def __init__(self, msg, restricted_objects):
        super().__init__(msg, restricted_objects)
        # Keep the restricting objects available to callers for error reporting.
        self.restricted_objects = restricted_objects
|  | ||||
|  | ||||
def CASCADE(collector, field, sub_objs, using):
    """on_delete handler: collect the related objects for deletion too."""
    related_model = field.remote_field.model
    collector.collect(
        sub_objs,
        source=related_model,
        source_attr=field.name,
        nullable=field.null,
        fail_on_restricted=False,
    )
    # When the backend can't defer constraint checks, null out the nullable
    # FK up front so deletion order can't trip the constraint.
    if field.null and not connections[using].features.can_defer_constraint_checks:
        collector.add_field_update(field, None, sub_objs)
|  | ||||
|  | ||||
def PROTECT(collector, field, sub_objs, using):
    """on_delete handler: refuse the deletion by raising ProtectedError."""
    message = (
        "Cannot delete some instances of model '%s' because they are "
        "referenced through a protected foreign key: '%s.%s'"
        % (
            field.remote_field.model.__name__,
            sub_objs[0].__class__.__name__,
            field.name,
        )
    )
    raise ProtectedError(message, sub_objs)
|  | ||||
|  | ||||
def RESTRICT(collector, field, sub_objs, using):
    """
    on_delete handler: defer the decision. Record the objects as restricted
    and add an ordering dependency; the collector errors out later unless a
    CASCADE also collects them for deletion.
    """
    collector.add_dependency(field.remote_field.model, field.model)
    collector.add_restricted_objects(field, sub_objs)
|  | ||||
|  | ||||
def SET(value):
    """
    Build an on_delete handler that replaces the FK with *value* instead of
    deleting. If *value* is callable, it is invoked at deletion time so each
    deletion gets a freshly computed value.
    """
    if callable(value):

        def set_on_delete(collector, field, sub_objs, using):
            collector.add_field_update(field, value(), sub_objs)

    else:

        def set_on_delete(collector, field, sub_objs, using):
            collector.add_field_update(field, value, sub_objs)

    # Sub-objects need not be fetched up front; the update resolves them.
    set_on_delete.lazy_sub_objs = True
    # Support migration serialization of the handler.
    set_on_delete.deconstruct = lambda: ("django.db.models.SET", (value,), {})
    return set_on_delete
|  | ||||
|  | ||||
def SET_NULL(collector, field, sub_objs, using):
    """on_delete handler: set the foreign key to NULL instead of deleting."""
    collector.add_field_update(field, None, sub_objs)


# Sub-objects need not be fetched up front; the update resolves them.
SET_NULL.lazy_sub_objs = True
|  | ||||
|  | ||||
def SET_DEFAULT(collector, field, sub_objs, using):
    """on_delete handler: reset the foreign key to the field's default."""
    collector.add_field_update(field, field.get_default(), sub_objs)


# Sub-objects need not be fetched up front; the update resolves them.
SET_DEFAULT.lazy_sub_objs = True
|  | ||||
|  | ||||
def DO_NOTHING(collector, field, sub_objs, using):
    """on_delete handler: leave the related rows untouched."""
|  | ||||
|  | ||||
def get_candidate_relations_to_delete(opts):
    """
    Return an iterable of the reverse relations on *opts* that deletion must
    consider: auto-created, non-concrete 1-1 and 1-N relations. N-N
    (many-to-many) relations are never candidates for deletion.
    """

    def _is_delete_candidate(field):
        return (
            field.auto_created
            and not field.concrete
            and (field.one_to_one or field.one_to_many)
        )

    # get_fields() is evaluated eagerly (as in a generator expression's
    # outermost iterable); only the filtering itself is lazy.
    return filter(_is_delete_candidate, opts.get_fields(include_hidden=True))
|  | ||||
|  | ||||
class Collector:
    """
    Collect everything that deleting a set of model instances entails: the
    instances themselves (self.data), cascades that can run as raw queryset
    deletes (self.fast_deletes), field updates that must happen before
    deleting (self.field_updates, e.g. SET_NULL), objects protected by
    RESTRICT (self.restricted_objects), and inter-model ordering constraints
    (self.dependencies). Call collect() one or more times, then delete().
    """

    def __init__(self, using, origin=None):
        # Database alias the deletion will run against.
        self.using = using
        # A Model or QuerySet object.
        self.origin = origin
        # Initially, {model: {instances}}, later values become lists.
        self.data = defaultdict(set)
        # {(field, value): [instances, …]}
        self.field_updates = defaultdict(list)
        # {model: {field: {instances}}}
        self.restricted_objects = defaultdict(partial(defaultdict, set))
        # fast_deletes is a list of queryset-likes that can be deleted without
        # fetching the objects into memory.
        self.fast_deletes = []

        # Tracks deletion-order dependency for databases without transactions
        # or ability to defer constraint checks. Only concrete model classes
        # should be included, as the dependencies exist only between actual
        # database tables; proxy models are represented here by their concrete
        # parent.
        self.dependencies = defaultdict(set)  # {model: {models}}

    def add(self, objs, source=None, nullable=False, reverse_dependency=False):
        """
        Add 'objs' to the collection of objects to be deleted.  If the call is
        the result of a cascade, 'source' should be the model that caused it,
        and 'nullable' should be set to True if the relation can be null.

        Return a list of all objects that were not already collected.
        """
        if not objs:
            return []
        new_objs = []
        model = objs[0].__class__
        instances = self.data[model]
        for obj in objs:
            if obj not in instances:
                new_objs.append(obj)
        instances.update(new_objs)
        # Nullable relationships can be ignored -- they are nulled out before
        # deleting, and therefore do not affect the order in which objects have
        # to be deleted.
        if source is not None and not nullable:
            self.add_dependency(source, model, reverse_dependency=reverse_dependency)
        return new_objs

    def add_dependency(self, model, dependency, reverse_dependency=False):
        """
        Record that rows of `dependency` must be deleted before rows of
        `model` (the other way around when `reverse_dependency` is True), and
        ensure `dependency` has an entry in self.data so sort() considers it.
        """
        if reverse_dependency:
            model, dependency = dependency, model
        self.dependencies[model._meta.concrete_model].add(
            dependency._meta.concrete_model
        )
        self.data.setdefault(dependency, self.data.default_factory())

    def add_field_update(self, field, value, objs):
        """
        Schedule a field update. 'objs' must be a homogeneous iterable
        collection of model instances (e.g. a QuerySet).
        """
        self.field_updates[field, value].append(objs)

    def add_restricted_objects(self, field, objs):
        """
        Track `objs` as referenced through a RESTRICT `field`. Deleting them
        is an error unless they also get collected for deletion via CASCADE.
        """
        if objs:
            model = objs[0].__class__
            self.restricted_objects[model][field].update(objs)

    def clear_restricted_objects_from_set(self, model, objs):
        """
        Stop treating the restricted objects of `model` that appear in `objs`
        as restricted (they are collected for deletion anyway).
        """
        if model in self.restricted_objects:
            self.restricted_objects[model] = {
                field: items - objs
                for field, items in self.restricted_objects[model].items()
            }

    def clear_restricted_objects_from_queryset(self, model, qs):
        """
        Like clear_restricted_objects_from_set(), for restricted objects of
        `model` that will be deleted through the fast-delete queryset `qs`.
        """
        if model in self.restricted_objects:
            objs = set(
                qs.filter(
                    pk__in=[
                        obj.pk
                        for objs in self.restricted_objects[model].values()
                        for obj in objs
                    ]
                )
            )
            self.clear_restricted_objects_from_set(model, objs)

    def _has_signal_listeners(self, model):
        # Receivers force the slow path: objects must be fetched so signals
        # can be sent with each instance.
        return signals.pre_delete.has_listeners(
            model
        ) or signals.post_delete.has_listeners(model)

    def can_fast_delete(self, objs, from_field=None):
        """
        Determine if the objects in the given queryset-like or single object
        can be fast-deleted. This can be done if there are no cascades, no
        parents and no signal listeners for the object class.

        The 'from_field' tells where we are coming from - we need this to
        determine if the objects are in fact to be deleted. Allow also
        skipping parent -> child -> parent chain preventing fast delete of
        the child.
        """
        if from_field and from_field.remote_field.on_delete is not CASCADE:
            return False
        if hasattr(objs, "_meta"):
            model = objs._meta.model
        elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
            model = objs.model
        else:
            return False
        if self._has_signal_listeners(model):
            return False
        # The use of from_field comes from the need to avoid cascade back to
        # parent when parent delete is cascading to child.
        opts = model._meta
        return (
            all(
                link == from_field
                for link in opts.concrete_model._meta.parents.values()
            )
            and
            # Foreign keys pointing to this model.
            all(
                related.field.remote_field.on_delete is DO_NOTHING
                for related in get_candidate_relations_to_delete(opts)
            )
            and (
                # Something like generic foreign key.
                not any(
                    hasattr(field, "bulk_related_objects")
                    for field in opts.private_fields
                )
            )
        )

    def get_del_batches(self, objs, fields):
        """
        Return the objs in suitably sized batches for the used connection.
        """
        field_names = [field.name for field in fields]
        conn_batch_size = max(
            connections[self.using].ops.bulk_batch_size(field_names, objs), 1
        )
        if len(objs) > conn_batch_size:
            return [
                objs[i : i + conn_batch_size]
                for i in range(0, len(objs), conn_batch_size)
            ]
        else:
            return [objs]

    def collect(
        self,
        objs,
        source=None,
        nullable=False,
        collect_related=True,
        source_attr=None,
        reverse_dependency=False,
        keep_parents=False,
        fail_on_restricted=True,
    ):
        """
        Add 'objs' to the collection of objects to be deleted as well as all
        parent instances.  'objs' must be a homogeneous iterable collection of
        model instances (e.g. a QuerySet).  If 'collect_related' is True,
        related objects will be handled by their respective on_delete handler.

        If the call is the result of a cascade, 'source' should be the model
        that caused it and 'nullable' should be set to True, if the relation
        can be null.

        If 'reverse_dependency' is True, 'source' will be deleted before the
        current model, rather than after. (Needed for cascading to parent
        models, the one case in which the cascade follows the forwards
        direction of an FK rather than the reverse direction.)

        If 'keep_parents' is True, data of parent model's will be not deleted.

        If 'fail_on_restricted' is False, error won't be raised even if it's
        prohibited to delete such objects due to RESTRICT, that defers
        restricted object checking in recursive calls where the top-level call
        may need to collect more objects to determine whether restricted ones
        can be deleted.
        """
        if self.can_fast_delete(objs):
            self.fast_deletes.append(objs)
            return
        new_objs = self.add(
            objs, source, nullable, reverse_dependency=reverse_dependency
        )
        if not new_objs:
            return

        model = new_objs[0].__class__

        if not keep_parents:
            # Recursively collect concrete model's parent models, but not their
            # related objects. These will be found by meta.get_fields()
            concrete_model = model._meta.concrete_model
            for ptr in concrete_model._meta.parents.values():
                if ptr:
                    parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
                    self.collect(
                        parent_objs,
                        source=model,
                        source_attr=ptr.remote_field.related_name,
                        collect_related=False,
                        reverse_dependency=True,
                        fail_on_restricted=False,
                    )
        if not collect_related:
            return

        if keep_parents:
            # Used below to skip reverse relations that point back to parents.
            parents = set(model._meta.get_parent_list())
        model_fast_deletes = defaultdict(list)
        protected_objects = defaultdict(list)
        for related in get_candidate_relations_to_delete(model._meta):
            # Preserve parent reverse relationships if keep_parents=True.
            if keep_parents and related.model in parents:
                continue
            field = related.field
            on_delete = field.remote_field.on_delete
            if on_delete == DO_NOTHING:
                continue
            related_model = related.related_model
            if self.can_fast_delete(related_model, from_field=field):
                model_fast_deletes[related_model].append(field)
                continue
            batches = self.get_del_batches(new_objs, [field])
            for batch in batches:
                sub_objs = self.related_objects(related_model, [field], batch)
                # Non-referenced fields can be deferred if no signal receivers
                # are connected for the related model as they'll never be
                # exposed to the user. Skip field deferring when some
                # relationships are select_related as interactions between both
                # features are hard to get right. This should only happen in
                # the rare cases where .related_objects is overridden anyway.
                if not (
                    sub_objs.query.select_related
                    or self._has_signal_listeners(related_model)
                ):
                    referenced_fields = set(
                        chain.from_iterable(
                            (rf.attname for rf in rel.field.foreign_related_fields)
                            for rel in get_candidate_relations_to_delete(
                                related_model._meta
                            )
                        )
                    )
                    sub_objs = sub_objs.only(*tuple(referenced_fields))
                if getattr(on_delete, "lazy_sub_objs", False) or sub_objs:
                    try:
                        on_delete(self, field, sub_objs, self.using)
                    except ProtectedError as error:
                        key = "'%s.%s'" % (field.model.__name__, field.name)
                        protected_objects[key] += error.protected_objects
        if protected_objects:
            raise ProtectedError(
                "Cannot delete some instances of model %r because they are "
                "referenced through protected foreign keys: %s."
                % (
                    model.__name__,
                    ", ".join(protected_objects),
                ),
                set(chain.from_iterable(protected_objects.values())),
            )
        for related_model, related_fields in model_fast_deletes.items():
            batches = self.get_del_batches(new_objs, related_fields)
            for batch in batches:
                sub_objs = self.related_objects(related_model, related_fields, batch)
                self.fast_deletes.append(sub_objs)
        for field in model._meta.private_fields:
            if hasattr(field, "bulk_related_objects"):
                # It's something like generic foreign key.
                sub_objs = field.bulk_related_objects(new_objs, self.using)
                self.collect(
                    sub_objs, source=model, nullable=True, fail_on_restricted=False
                )

        if fail_on_restricted:
            # Raise an error if collected restricted objects (RESTRICT) aren't
            # candidates for deletion also collected via CASCADE.
            for related_model, instances in self.data.items():
                self.clear_restricted_objects_from_set(related_model, instances)
            for qs in self.fast_deletes:
                self.clear_restricted_objects_from_queryset(qs.model, qs)
            if self.restricted_objects.values():
                restricted_objects = defaultdict(list)
                for related_model, fields in self.restricted_objects.items():
                    for field, objs in fields.items():
                        if objs:
                            key = "'%s.%s'" % (related_model.__name__, field.name)
                            restricted_objects[key] += objs
                if restricted_objects:
                    raise RestrictedError(
                        "Cannot delete some instances of model %r because "
                        "they are referenced through restricted foreign keys: "
                        "%s."
                        % (
                            model.__name__,
                            ", ".join(restricted_objects),
                        ),
                        set(chain.from_iterable(restricted_objects.values())),
                    )

    def related_objects(self, related_model, related_fields, objs):
        """
        Get a QuerySet of the related model to objs via related fields.
        """
        predicate = query_utils.Q.create(
            [(f"{related_field.name}__in", objs) for related_field in related_fields],
            connector=query_utils.Q.OR,
        )
        return related_model._base_manager.using(self.using).filter(predicate)

    def instances_with_model(self):
        """Yield (model, instance) pairs for every collected instance."""
        for model, instances in self.data.items():
            for obj in instances:
                yield model, obj

    def sort(self):
        """
        Reorder self.data so that models whose rows must be deleted first
        come first, per self.dependencies. Give up (keeping the current
        ordering) if the dependencies contain a cycle.
        """
        sorted_models = []
        concrete_models = set()
        models = list(self.data)
        while len(sorted_models) < len(models):
            found = False
            for model in models:
                if model in sorted_models:
                    continue
                dependencies = self.dependencies.get(model._meta.concrete_model)
                if not (dependencies and dependencies.difference(concrete_models)):
                    # All of this model's prerequisites are already placed.
                    sorted_models.append(model)
                    concrete_models.add(model._meta.concrete_model)
                    found = True
            if not found:
                # No progress in a full pass => dependency cycle; bail out.
                return
        self.data = {model: self.data[model] for model in sorted_models}

    def delete(self):
        """
        Run the collected deletion plan: send pre_delete signals, perform
        fast deletes and scheduled field updates, delete the instances,
        send post_delete signals, then clear the deleted instances' pks.
        Return (total_deleted, {model_label: count}).
        """
        # sort instance collections
        for model, instances in self.data.items():
            self.data[model] = sorted(instances, key=attrgetter("pk"))

        # if possible, bring the models in an order suitable for databases that
        # don't support transactions or cannot defer constraint checks until the
        # end of a transaction.
        self.sort()
        # number of objects deleted for each model label
        deleted_counter = Counter()

        # Optimize for the case with a single obj and no dependencies.
        # (`model` and `instances` are still bound from the sort loop above,
        # which ran exactly once in this case.)
        if len(self.data) == 1 and len(instances) == 1:
            instance = list(instances)[0]
            if self.can_fast_delete(instance):
                with transaction.mark_for_rollback_on_error(self.using):
                    count = sql.DeleteQuery(model).delete_batch(
                        [instance.pk], self.using
                    )
                setattr(instance, model._meta.pk.attname, None)
                return count, {model._meta.label: count}

        with transaction.atomic(using=self.using, savepoint=False):
            # send pre_delete signals
            for model, obj in self.instances_with_model():
                if not model._meta.auto_created:
                    signals.pre_delete.send(
                        sender=model,
                        instance=obj,
                        using=self.using,
                        origin=self.origin,
                    )

            # fast deletes
            for qs in self.fast_deletes:
                count = qs._raw_delete(using=self.using)
                if count:
                    deleted_counter[qs.model._meta.label] += count

            # update fields
            for (field, value), instances_list in self.field_updates.items():
                updates = []
                objs = []
                for instances in instances_list:
                    # Unevaluated querysets can be OR-combined and updated in
                    # one query; materialized instances are batch-updated by
                    # primary key instead.
                    if (
                        isinstance(instances, models.QuerySet)
                        and instances._result_cache is None
                    ):
                        updates.append(instances)
                    else:
                        objs.extend(instances)
                if updates:
                    combined_updates = reduce(or_, updates)
                    combined_updates.update(**{field.name: value})
                if objs:
                    model = objs[0].__class__
                    query = sql.UpdateQuery(model)
                    query.update_batch(
                        list({obj.pk for obj in objs}), {field.name: value}, self.using
                    )

            # reverse instance collections
            for instances in self.data.values():
                instances.reverse()

            # delete instances
            for model, instances in self.data.items():
                query = sql.DeleteQuery(model)
                pk_list = [obj.pk for obj in instances]
                count = query.delete_batch(pk_list, self.using)
                if count:
                    deleted_counter[model._meta.label] += count

                if not model._meta.auto_created:
                    for obj in instances:
                        signals.post_delete.send(
                            sender=model,
                            instance=obj,
                            using=self.using,
                            origin=self.origin,
                        )

        # The rows are gone; clear the in-memory instances' primary keys.
        for model, instances in self.data.items():
            for instance in instances:
                setattr(instance, model._meta.pk.attname, None)
        return sum(deleted_counter.values()), dict(deleted_counter)
| @ -0,0 +1,92 @@ | ||||
| import enum | ||||
| from types import DynamicClassAttribute | ||||
|  | ||||
| from django.utils.functional import Promise | ||||
|  | ||||
| __all__ = ["Choices", "IntegerChoices", "TextChoices"] | ||||
|  | ||||
|  | ||||
class ChoicesMeta(enum.EnumMeta):
    """A metaclass for creating enum choices."""

    def __new__(metacls, classname, bases, classdict, **kwds):
        labels = []
        for key in classdict._member_names:
            value = classdict[key]
            if (
                isinstance(value, (list, tuple))
                and len(value) > 1
                and isinstance(value[-1], (Promise, str))
            ):
                # Member declared as (value..., label): split off the trailing
                # label; the remaining tuple becomes the member value (enum
                # passes a tuple value to the mixed-in type's constructor).
                *value, label = value
                value = tuple(value)
            else:
                # No explicit label: derive one from the member name
                # (e.g. "FULL_TIME" -> "Full Time").
                label = key.replace("_", " ").title()
            labels.append(label)
            # Use dict.__setitem__() to suppress defenses against double
            # assignment in enum's classdict.
            dict.__setitem__(classdict, key, value)
        cls = super().__new__(metacls, classname, bases, classdict, **kwds)
        # Attach the labels to the members; zip order matches the declaration
        # order iterated above.
        for member, label in zip(cls.__members__.values(), labels):
            member._label_ = label
        # Reject duplicate member values.
        return enum.unique(cls)

    def __contains__(cls, member):
        if not isinstance(member, enum.Enum):
            # Allow non-enums to match against member values.
            return any(x.value == member for x in cls)
        return super().__contains__(member)

    @property
    def names(cls):
        # Member names, with "__empty__" prepended when a blank choice label
        # is defined on the class.
        empty = ["__empty__"] if hasattr(cls, "__empty__") else []
        return empty + [member.name for member in cls]

    @property
    def choices(cls):
        # (value, label) pairs for all members; a blank choice maps None to
        # the "__empty__" label when one is defined.
        empty = [(None, cls.__empty__)] if hasattr(cls, "__empty__") else []
        return empty + [(member.value, member.label) for member in cls]

    @property
    def labels(cls):
        # Labels in the same order as choices (including the blank choice).
        return [label for _, label in cls.choices]

    @property
    def values(cls):
        # Values in the same order as choices (including None for the blank
        # choice, when defined).
        return [value for value, _ in cls.choices]
|  | ||||
|  | ||||
class Choices(enum.Enum, metaclass=ChoicesMeta):
    """Base class for enumerated choices with human-readable labels."""

    @DynamicClassAttribute
    def label(self):
        """The human-readable label ChoicesMeta attached to this member."""
        return self._label_

    @property
    def do_not_call_in_templates(self):
        # Truthy flag so the template engine renders the member instead of
        # calling it.
        return True

    # A similar format was proposed for Python 3.10.
    def __repr__(self):
        return f"{self.__class__.__qualname__}.{self._name_}"

    def __str__(self):
        """
        Use value when cast to str, so that Choices set as model instance
        attributes are rendered as expected in templates and similar contexts.
        """
        return str(self.value)
|  | ||||
|  | ||||
class IntegerChoices(int, Choices):
    """Class for creating enumerated integer choices."""
|  | ||||
|  | ||||
class TextChoices(str, Choices):
    """Class for creating enumerated string choices."""

    def _generate_next_value_(name, start, count, last_values):
        # auto() produces the member's own name as its string value.
        return name
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -0,0 +1,510 @@ | ||||
| import datetime | ||||
| import posixpath | ||||
|  | ||||
| from django import forms | ||||
| from django.core import checks | ||||
| from django.core.files.base import File | ||||
| from django.core.files.images import ImageFile | ||||
| from django.core.files.storage import Storage, default_storage | ||||
| from django.core.files.utils import validate_file_name | ||||
| from django.db.models import signals | ||||
| from django.db.models.fields import Field | ||||
| from django.db.models.query_utils import DeferredAttribute | ||||
| from django.db.models.utils import AltersData | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
|  | ||||
|  | ||||
class FieldFile(File, AltersData):
    """
    File-like object returned when accessing a FileField on a model instance.

    Wraps the stored file together with its model instance, field, and
    storage backend so that file operations (save/delete/open) keep the
    storage and the model attribute in sync.
    """

    def __init__(self, instance, field, name):
        super().__init__(None, name)
        self.instance = instance
        self.field = field
        self.storage = field.storage
        # True once the current name is persisted in storage.
        self._committed = True

    def __eq__(self, other):
        # Older code may be expecting FileField values to be simple strings.
        # By overriding the == operator, it can remain backwards compatible.
        if hasattr(other, "name"):
            return self.name == other.name
        return self.name == other

    def __hash__(self):
        # Hash on the name to stay consistent with __eq__.
        return hash(self.name)

    # The standard File contains most of the necessary properties, but
    # FieldFiles can be instantiated without a name, so that needs to
    # be checked for here.

    def _require_file(self):
        # Raise when no file is associated (an empty/None name is falsy).
        if not self:
            raise ValueError(
                "The '%s' attribute has no file associated with it." % self.field.name
            )

    def _get_file(self):
        """Lazily open the underlying file from storage on first access."""
        self._require_file()
        if getattr(self, "_file", None) is None:
            self._file = self.storage.open(self.name, "rb")
        return self._file

    def _set_file(self, file):
        self._file = file

    def _del_file(self):
        del self._file

    file = property(_get_file, _set_file, _del_file)

    @property
    def path(self):
        """Filesystem path of the file, as reported by the storage backend."""
        self._require_file()
        return self.storage.path(self.name)

    @property
    def url(self):
        """Public URL of the file, as reported by the storage backend."""
        self._require_file()
        return self.storage.url(self.name)

    @property
    def size(self):
        """Size in bytes; read from the in-memory file until committed."""
        self._require_file()
        if not self._committed:
            return self.file.size
        return self.storage.size(self.name)

    def open(self, mode="rb"):
        """Open (or reopen) the file in the given mode and return self."""
        self._require_file()
        if getattr(self, "_file", None) is None:
            self.file = self.storage.open(self.name, mode)
        else:
            self.file.open(mode)
        return self

    # open() doesn't alter the file's contents, but it does reset the pointer
    open.alters_data = True

    # In addition to the standard File API, FieldFiles have extra methods
    # to further manipulate the underlying file, as well as update the
    # associated model instance.

    def save(self, name, content, save=True):
        """
        Save ``content`` to storage under ``name`` and update the model's
        field attribute. Persist the model instance too unless ``save`` is
        False.
        """
        name = self.field.generate_filename(self.instance, name)
        self.name = self.storage.save(name, content, max_length=self.field.max_length)
        setattr(self.instance, self.field.attname, self.name)
        self._committed = True

        # Save the object because it has changed, unless save is False
        if save:
            self.instance.save()

    save.alters_data = True

    def delete(self, save=True):
        """
        Delete the file from storage and clear the model's field attribute.
        Persist the model instance too unless ``save`` is False.
        """
        if not self:
            return
        # Only close the file if it's already open, which we know by the
        # presence of self._file
        if hasattr(self, "_file"):
            self.close()
            del self.file

        self.storage.delete(self.name)

        self.name = None
        setattr(self.instance, self.field.attname, self.name)
        self._committed = False

        if save:
            self.instance.save()

    delete.alters_data = True

    @property
    def closed(self):
        # A never-opened FieldFile counts as closed.
        file = getattr(self, "_file", None)
        return file is None or file.closed

    def close(self):
        # Close only if a file object was actually opened/assigned.
        file = getattr(self, "_file", None)
        if file is not None:
            file.close()

    def __getstate__(self):
        # FieldFile needs access to its associated model field, an instance and
        # the file's name. Everything else will be restored later, by
        # FileDescriptor below.
        return {
            "name": self.name,
            "closed": False,
            "_committed": True,
            "_file": None,
            "instance": self.instance,
            "field": self.field,
        }

    def __setstate__(self, state):
        self.__dict__.update(state)
        # The storage isn't pickled; reattach it from the unpickled field.
        self.storage = self.field.storage
|  | ||||
|  | ||||
class FileDescriptor(DeferredAttribute):
    """
    The descriptor for the file attribute on the model instance. Return a
    FieldFile when accessed so you can write code like::

        >>> from myapp.models import MyModel
        >>> instance = MyModel.objects.get(pk=1)
        >>> instance.file.size

    Assign a file object on assignment so you can do::

        >>> with open('/path/to/hello.world') as f:
        ...     instance.file = File(f)
    """

    def __get__(self, instance, cls=None):
        # Class-level access returns the descriptor itself.
        if instance is None:
            return self

        # This is slightly complicated, so worth an explanation.
        # instance.file needs to ultimately return some instance of `File`,
        # probably a subclass. Additionally, this returned object needs to have
        # the FieldFile API so that users can easily do things like
        # instance.file.path and have that delegated to the file storage engine.
        # Easy enough if we're strict about assignment in __set__, but if you
        # peek below you can see that we're not. So depending on the current
        # value of the field we have to dynamically construct some sort of
        # "thing" to return.

        # The instance dict contains whatever was originally assigned
        # in __set__.
        file = super().__get__(instance, cls)

        # If this value is a string (instance.file = "path/to/file") or None
        # then we simply wrap it with the appropriate attribute class according
        # to the file field. [This is FieldFile for FileFields and
        # ImageFieldFile for ImageFields; it's also conceivable that user
        # subclasses might also want to subclass the attribute class]. This
        # object understands how to convert a path to a file, and also how to
        # handle None.
        if isinstance(file, str) or file is None:
            attr = self.field.attr_class(instance, self.field, file)
            instance.__dict__[self.field.attname] = attr

        # Other types of files may be assigned as well, but they need to have
        # the FieldFile interface added to them. Thus, we wrap any other type of
        # File inside a FieldFile (well, the field's attr_class, which is
        # usually FieldFile).
        elif isinstance(file, File) and not isinstance(file, FieldFile):
            file_copy = self.field.attr_class(instance, self.field, file.name)
            file_copy.file = file
            file_copy._committed = False
            instance.__dict__[self.field.attname] = file_copy

        # Finally, because of the (some would say boneheaded) way pickle works,
        # the underlying FieldFile might not actually itself have an associated
        # file. So we need to reset the details of the FieldFile in those cases.
        elif isinstance(file, FieldFile) and not hasattr(file, "field"):
            file.instance = instance
            file.field = self.field
            file.storage = self.field.storage

        # Make sure that the instance is correct.
        elif isinstance(file, FieldFile) and instance is not file.instance:
            file.instance = instance

        # That was fun, wasn't it?
        return instance.__dict__[self.field.attname]

    def __set__(self, instance, value):
        # Store the raw value; __get__ normalizes it to attr_class lazily.
        instance.__dict__[self.field.attname] = value
|  | ||||
|  | ||||
class FileField(Field):
    """
    Model field that stores a file's name/path in the database and delegates
    storage of the file's contents to a Storage backend.
    """

    # The class to wrap instance attributes in. Accessing the file object off
    # the instance will always return an instance of attr_class.
    attr_class = FieldFile

    # The descriptor to use for accessing the attribute off of the class.
    descriptor_class = FileDescriptor

    description = _("File")

    def __init__(
        self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs
    ):
        # Remember whether primary_key was passed so check() can reject it.
        self._primary_key_set_explicitly = "primary_key" in kwargs

        self.storage = storage or default_storage
        if callable(self.storage):
            # Hold a reference to the callable for deconstruct().
            self._storage_callable = self.storage
            self.storage = self.storage()
            if not isinstance(self.storage, Storage):
                raise TypeError(
                    "%s.storage must be a subclass/instance of %s.%s"
                    % (
                        self.__class__.__qualname__,
                        Storage.__module__,
                        Storage.__qualname__,
                    )
                )
        self.upload_to = upload_to

        kwargs.setdefault("max_length", 100)
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        """Run standard field checks plus FileField-specific ones."""
        return [
            *super().check(**kwargs),
            *self._check_primary_key(),
            *self._check_upload_to(),
        ]

    def _check_primary_key(self):
        # FileFields may not be primary keys (fields.E201).
        if self._primary_key_set_explicitly:
            return [
                checks.Error(
                    "'primary_key' is not a valid argument for a %s."
                    % self.__class__.__name__,
                    obj=self,
                    id="fields.E201",
                )
            ]
        else:
            return []

    def _check_upload_to(self):
        # A string upload_to must be relative to the storage root (fields.E202).
        if isinstance(self.upload_to, str) and self.upload_to.startswith("/"):
            return [
                checks.Error(
                    "%s's 'upload_to' argument must be a relative path, not an "
                    "absolute path." % self.__class__.__name__,
                    obj=self,
                    id="fields.E202",
                    hint="Remove the leading slash.",
                )
            ]
        else:
            return []

    def deconstruct(self):
        """Return (name, path, args, kwargs) for migration serialization."""
        name, path, args, kwargs = super().deconstruct()
        if kwargs.get("max_length") == 100:
            # 100 is this field's default; omit it to keep migrations minimal.
            del kwargs["max_length"]
        kwargs["upload_to"] = self.upload_to
        # Prefer the original callable over the evaluated storage instance.
        storage = getattr(self, "_storage_callable", self.storage)
        if storage is not default_storage:
            kwargs["storage"] = storage
        return name, path, args, kwargs

    def get_internal_type(self):
        return "FileField"

    def get_prep_value(self, value):
        value = super().get_prep_value(value)
        # Need to convert File objects provided via a form to string for
        # database insertion.
        if value is None:
            return None
        return str(value)

    def pre_save(self, model_instance, add):
        """Commit any uncommitted file to storage before the model is saved."""
        file = super().pre_save(model_instance, add)
        if file and not file._committed:
            # Commit the file to storage prior to saving the model
            file.save(file.name, file.file, save=False)
        return file

    def contribute_to_class(self, cls, name, **kwargs):
        super().contribute_to_class(cls, name, **kwargs)
        # Install the file descriptor on the model class for this attribute.
        setattr(cls, self.attname, self.descriptor_class(self))

    def generate_filename(self, instance, filename):
        """
        Apply (if callable) or prepend (if a string) upload_to to the filename,
        then delegate further processing of the name to the storage backend.
        Until the storage layer, all file paths are expected to be Unix style
        (with forward slashes).
        """
        if callable(self.upload_to):
            filename = self.upload_to(instance, filename)
        else:
            # A string upload_to may contain strftime() placeholders.
            dirname = datetime.datetime.now().strftime(str(self.upload_to))
            filename = posixpath.join(dirname, filename)
        filename = validate_file_name(filename, allow_relative_path=True)
        return self.storage.generate_filename(filename)

    def save_form_data(self, instance, data):
        # Important: None means "no change", other false value means "clear"
        # This subtle distinction (rather than a more explicit marker) is
        # needed because we need to consume values that are also sane for a
        # regular (non Model-) Form to find in its cleaned_data dictionary.
        if data is not None:
            # This value will be converted to str and stored in the
            # database, so leaving False as-is is not acceptable.
            setattr(instance, self.name, data or "")

    def formfield(self, **kwargs):
        """Return a forms.FileField, propagating this field's max_length."""
        return super().formfield(
            **{
                "form_class": forms.FileField,
                "max_length": self.max_length,
                **kwargs,
            }
        )
|  | ||||
|  | ||||
class ImageFileDescriptor(FileDescriptor):
    """
    FileDescriptor variant for ImageFields that also refreshes the
    width_field/height_field values when a new file is assigned.
    """

    def __set__(self, instance, value):
        prior = instance.__dict__.get(self.field.attname)
        super().__set__(instance, value)

        # Dimensions are recalculated only when the attribute already held a
        # value, i.e. when assignment happens after Model.__init__(). During
        # Model.__init__() (where prior is None, since the default value for
        # FileField subclasses is an instance of field.attr_class) the
        # post_init signal handler, ImageField.update_dimension_fields, does
        # the work instead — this avoids recomputing dimensions when
        # instantiating an object from the database (bug #11084).
        if prior is not None:
            self.field.update_dimension_fields(instance, force=True)
|  | ||||
|  | ||||
class ImageFieldFile(ImageFile, FieldFile):
    """FieldFile for images; clears cached dimensions on deletion."""

    def delete(self, save=True):
        # Invalidate any cached (width, height) before removing the file.
        try:
            del self._dimensions_cache
        except AttributeError:
            pass
        super().delete(save)
|  | ||||
|  | ||||
class ImageField(FileField):
    """
    FileField specialization for images that can mirror the image's width
    and height into sibling model fields (width_field/height_field).
    """

    attr_class = ImageFieldFile
    descriptor_class = ImageFileDescriptor
    description = _("Image")

    def __init__(
        self,
        verbose_name=None,
        name=None,
        width_field=None,
        height_field=None,
        **kwargs,
    ):
        # Names of sibling model fields that receive the image's dimensions.
        self.width_field, self.height_field = width_field, height_field
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        """Run FileField checks plus the Pillow availability check."""
        return [
            *super().check(**kwargs),
            *self._check_image_library_installed(),
        ]

    def _check_image_library_installed(self):
        # Pillow is required for image handling (fields.E210).
        try:
            from PIL import Image  # NOQA
        except ImportError:
            return [
                checks.Error(
                    "Cannot use ImageField because Pillow is not installed.",
                    hint=(
                        "Get Pillow at https://pypi.org/project/Pillow/ "
                        'or run command "python -m pip install Pillow".'
                    ),
                    obj=self,
                    id="fields.E210",
                )
            ]
        else:
            return []

    def deconstruct(self):
        """Include width_field/height_field in the deconstructed kwargs."""
        name, path, args, kwargs = super().deconstruct()
        if self.width_field:
            kwargs["width_field"] = self.width_field
        if self.height_field:
            kwargs["height_field"] = self.height_field
        return name, path, args, kwargs

    def contribute_to_class(self, cls, name, **kwargs):
        super().contribute_to_class(cls, name, **kwargs)
        # Attach update_dimension_fields so that dimension fields declared
        # after their corresponding image field don't stay cleared by
        # Model.__init__, see bug #11196.
        # Only run post-initialization dimension update on non-abstract models
        if not cls._meta.abstract:
            signals.post_init.connect(self.update_dimension_fields, sender=cls)

    def update_dimension_fields(self, instance, force=False, *args, **kwargs):
        """
        Update field's width and height fields, if defined.

        This method is hooked up to model's post_init signal to update
        dimensions after instantiating a model instance.  However, dimensions
        won't be updated if the dimensions fields are already populated.  This
        avoids unnecessary recalculation when loading an object from the
        database.

        Dimensions can be forced to update with force=True, which is how
        ImageFileDescriptor.__set__ calls this method.
        """
        # Nothing to update if the field doesn't have dimension fields or if
        # the field is deferred.
        has_dimension_fields = self.width_field or self.height_field
        if not has_dimension_fields or self.attname not in instance.__dict__:
            return

        # getattr will call the ImageFileDescriptor's __get__ method, which
        # coerces the assigned value into an instance of self.attr_class
        # (ImageFieldFile in this case).
        file = getattr(instance, self.attname)

        # Nothing to update if we have no file and not being forced to update.
        if not file and not force:
            return

        dimension_fields_filled = not (
            (self.width_field and not getattr(instance, self.width_field))
            or (self.height_field and not getattr(instance, self.height_field))
        )
        # When both dimension fields have values, we are most likely loading
        # data from the database or updating an image field that already had
        # an image stored.  In the first case, we don't want to update the
        # dimension fields because we are already getting their values from the
        # database.  In the second case, we do want to update the dimensions
        # fields and will skip this return because force will be True since we
        # were called from ImageFileDescriptor.__set__.
        if dimension_fields_filled and not force:
            return

        # file should be an instance of ImageFieldFile or should be None.
        if file:
            width = file.width
            height = file.height
        else:
            # No file, so clear dimensions fields.
            width = None
            height = None

        # Update the width and height fields.
        if self.width_field:
            setattr(instance, self.width_field, width)
        if self.height_field:
            setattr(instance, self.height_field, height)

    def formfield(self, **kwargs):
        """Return a forms.ImageField for this model field."""
        return super().formfield(
            **{
                "form_class": forms.ImageField,
                **kwargs,
            }
        )
| @ -0,0 +1,638 @@ | ||||
| import json | ||||
| import warnings | ||||
|  | ||||
| from django import forms | ||||
| from django.core import checks, exceptions | ||||
| from django.db import NotSupportedError, connections, router | ||||
| from django.db.models import expressions, lookups | ||||
| from django.db.models.constants import LOOKUP_SEP | ||||
| from django.db.models.fields import TextField | ||||
| from django.db.models.lookups import ( | ||||
|     FieldGetDbPrepValueMixin, | ||||
|     PostgresOperatorLookup, | ||||
|     Transform, | ||||
| ) | ||||
| from django.utils.deprecation import RemovedInDjango51Warning | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
|  | ||||
| from . import Field | ||||
| from .mixins import CheckFieldDefaultMixin | ||||
|  | ||||
| __all__ = ["JSONField"] | ||||
|  | ||||
|  | ||||
class JSONField(CheckFieldDefaultMixin, Field):
    """Model field for storing JSON-encoded data."""

    empty_strings_allowed = False
    description = _("A JSON object")
    default_error_messages = {
        "invalid": _("Value must be valid JSON."),
    }
    # Hint shown when a mutable default is used (dict -> suggest callable).
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        verbose_name=None,
        name=None,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        # encoder/decoder are json.JSONEncoder/JSONDecoder subclasses (or
        # other callables) used for serialization/deserialization.
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        """Run standard field checks plus per-database support checks."""
        errors = super().check(**kwargs)
        databases = kwargs.get("databases") or []
        errors.extend(self._check_supported(databases))
        return errors

    def _check_supported(self, databases):
        # Emit fields.E180 for each database that can't store JSONFields.
        errors = []
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != connection.vendor
            ):
                continue
            if not (
                "supports_json_field" in self.model._meta.required_db_features
                or connection.features.supports_json_field
            ):
                errors.append(
                    checks.Error(
                        "%s does not support JSONFields." % connection.display_name,
                        obj=self.model,
                        id="fields.E180",
                    )
                )
        return errors

    def deconstruct(self):
        """Include non-default encoder/decoder in the deconstructed kwargs."""
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """Decode a database value into a Python object."""
        if value is None:
            return value
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            # Not valid JSON; return the raw value unchanged.
            return value

    def get_internal_type(self):
        return "JSONField"

    def get_db_prep_value(self, value, connection, prepared=False):
        """Adapt a Python value for use as a database query parameter."""
        if not prepared:
            value = self.get_prep_value(value)
        # RemovedInDjango51Warning: When the deprecation ends, replace with:
        # if (
        #     isinstance(value, expressions.Value)
        #     and isinstance(value.output_field, JSONField)
        # ):
        #     value = value.value
        # elif hasattr(value, "as_sql"): ...
        if isinstance(value, expressions.Value):
            if isinstance(value.value, str) and not isinstance(
                value.output_field, JSONField
            ):
                try:
                    value = json.loads(value.value, cls=self.decoder)
                except json.JSONDecodeError:
                    value = value.value
                else:
                    warnings.warn(
                        "Providing an encoded JSON string via Value() is deprecated. "
                        f"Use Value({value!r}, output_field=JSONField()) instead.",
                        category=RemovedInDjango51Warning,
                    )
            elif isinstance(value.output_field, JSONField):
                value = value.value
            else:
                return value
        elif hasattr(value, "as_sql"):
            # Compilable expressions are passed through unchanged.
            return value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value, connection):
        """Like get_db_prep_value(), but preserve None for saves."""
        if value is None:
            return value
        return self.get_db_prep_value(value, connection)

    def get_transform(self, name):
        # Fall back to key access (data__somekey) for unknown transforms.
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def validate(self, value, model_instance):
        """Raise ValidationError if the value is not JSON-serializable."""
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except TypeError:
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj):
        return self.value_from_object(obj)

    def formfield(self, **kwargs):
        """Return a forms.JSONField sharing this field's encoder/decoder."""
        return super().formfield(
            **{
                "form_class": forms.JSONField,
                "encoder": self.encoder,
                "decoder": self.decoder,
                **kwargs,
            }
        )
|  | ||||
|  | ||||
def compile_json_path(key_transforms, include_root=True):
    """
    Build a JSON path expression (e.g. '$."key"[0]') from key transforms.

    Keys that parse as integers become array indexes; everything else becomes
    a quoted object-member accessor. With include_root=False, the leading '$'
    is omitted so the result can be appended to an existing path.
    """
    segments = ["$"] if include_root else []
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:
            # Non-numeric key: quote it as an object member.
            segments.append(".%s" % json.dumps(key))
        else:
            segments.append("[%d]" % index)
    return "".join(segments)
|  | ||||
|  | ||||
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """'contains' lookup: the column's JSON document contains the value."""

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        # Generic (non-operator) fallback using JSON_CONTAINS(lhs, rhs).
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        return "JSON_CONTAINS(%s, %s)" % (lhs_sql, rhs_sql), (
            *lhs_params,
            *rhs_params,
        )
|  | ||||
|  | ||||
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """'contained_by' lookup: the value's JSON document contains the column."""

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        # Generic fallback: "lhs contained_by rhs" == JSON_CONTAINS(rhs, lhs),
        # so the operands (and their parameters) are swapped.
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        return "JSON_CONTAINS(%s, %s)" % (rhs_sql, lhs_sql), (
            *rhs_params,
            *lhs_params,
        )
|  | ||||
|  | ||||
| class HasKeyLookup(PostgresOperatorLookup): | ||||
|     logical_operator = None | ||||
|  | ||||
|     def compile_json_path_final_key(self, key_transform): | ||||
|         # Compile the final key without interpreting ints as array elements. | ||||
|         return ".%s" % json.dumps(key_transform) | ||||
|  | ||||
    def as_sql(self, compiler, connection, template=None):
        """
        Build SQL that tests for the presence of one or more JSON keys.

        The left-hand side supplies the base JSON path; each right-hand key
        (or chain of KeyTransforms) is compiled to a path string passed as a
        parameter into `template`.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, (list, tuple)):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            # The final key is compiled separately so numeric-looking keys
            # aren't turned into array indexes (see
            # compile_json_path_final_key).
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
        return sql, tuple(lhs_params) + tuple(rhs_params)
|  | ||||
|     def as_mysql(self, compiler, connection): | ||||
|         return self.as_sql( | ||||
|             compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)" | ||||
|         ) | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|         sql, params = self.as_sql( | ||||
|             compiler, connection, template="JSON_EXISTS(%s, '%%s')" | ||||
|         ) | ||||
|         # Add paths directly into SQL because path expressions cannot be passed | ||||
|         # as bind variables on Oracle. | ||||
|         return sql % tuple(params), [] | ||||
|  | ||||
|     def as_postgresql(self, compiler, connection): | ||||
|         if isinstance(self.rhs, KeyTransform): | ||||
|             *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection) | ||||
|             for key in rhs_key_transforms[:-1]: | ||||
|                 self.lhs = KeyTransform(key, self.lhs) | ||||
|             self.rhs = rhs_key_transforms[-1] | ||||
|         return super().as_postgresql(compiler, connection) | ||||
|  | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|         return self.as_sql( | ||||
|             compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL" | ||||
|         ) | ||||
|  | ||||
|  | ||||
class HasKey(HasKeyLookup):
    """`has_key` lookup: the JSON value must contain the given top-level key."""

    lookup_name = "has_key"
    postgres_operator = "?"
    # The rhs is a raw key name, not a JSON value to be prepared.
    prepare_rhs = False
|  | ||||
|  | ||||
class HasKeys(HasKeyLookup):
    """`has_keys` lookup: every listed key must be present in the value."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self):
        # Keys are always compared as strings, whatever was passed in.
        return list(map(str, self.rhs))
|  | ||||
|  | ||||
class HasAnyKeys(HasKeys):
    """`has_any_keys` lookup: at least one listed key must be present."""

    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "
|  | ||||
|  | ||||
class HasKeyOrArrayIndex(HasKey):
    """HasKey variant where an integer final key is treated as an array index."""

    def compile_json_path_final_key(self, key_transform):
        return compile_json_path([key_transform], include_root=False)
|  | ||||
|  | ||||
class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin
    collation. Because utf8mb4_bin is a binary collation, comparison of
    JSON values is case-sensitive; wrap both sides in LOWER() to compensate.
    """

    def process_lhs(self, compiler, connection):
        lhs, lhs_params = super().process_lhs(compiler, connection)
        if connection.vendor != "mysql":
            return lhs, lhs_params
        return "LOWER(%s)" % lhs, lhs_params

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor != "mysql":
            return rhs, rhs_params
        return "LOWER(%s)" % rhs, rhs_params
|  | ||||
|  | ||||
class JSONExact(lookups.Exact):
    """Exact match comparing values in their JSON representation."""

    # None on the rhs is meaningful (JSON null) rather than disallowed.
    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            # Wrap each placeholder in JSON_EXTRACT so the comparison happens
            # on MySQL's JSON type rather than on the raw serialized string.
            func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
            rhs %= tuple(func)
        return rhs, rhs_params
|  | ||||
|  | ||||
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """Case-insensitive substring test on the JSON text, MySQL-aware."""

    pass
|  | ||||
|  | ||||
# Register the field-level JSON lookups on JSONField itself.
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
|  | ||||
|  | ||||
class KeyTransform(Transform):
    """Key/index access on a JSONField value (e.g. ``field__key``)."""

    postgres_operator = "->"
    # Used when several nested keys are collapsed into one path lookup.
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keys are stored as strings; numeric indexes are re-detected later
        # on backends that distinguish array access (see as_postgresql).
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        """
        Walk nested KeyTransforms down to the underlying expression and
        return (lhs_sql, params, key_names) with key_names ordered from the
        outermost transform to this one.
        """
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.insert(0, previous.key_name)
            previous = previous.lhs
        lhs, params = compiler.compile(previous)
        if connection.vendor == "oracle":
            # Escape string-formatting.
            key_transforms = [key.replace("%", "%%") for key in key_transforms]
        return lhs, params, key_transforms

    def as_mysql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)

    def as_oracle(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        # JSON_QUERY handles objects/arrays and JSON_VALUE scalars; COALESCE
        # picks whichever one matched, so params are duplicated.
        return (
            "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
            % ((lhs, json_path) * 2)
        ), tuple(params) * 2

    def as_postgresql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            # Several nested keys: use the path operator with a list param.
            sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
            return sql, tuple(params) + (key_transforms,)
        # Single key: an all-digit key is passed as an int so PostgreSQL
        # interprets it as an array index.
        try:
            lookup = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)

    def as_sqlite(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        datatype_values = ",".join(
            [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
        )
        # For datatypes the backend lists specially, return JSON_TYPE()
        # itself; otherwise extract the value. The same (params + path)
        # group is bound once per placeholder, hence * 3.
        return (
            "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
            "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
        ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
|  | ||||
|  | ||||
class KeyTextTransform(KeyTransform):
    """KeyTransform variant that yields the value as text (->> on PostgreSQL)."""

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        if connection.mysql_is_mariadb:
            # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
            sql, params = super().as_mysql(compiler, connection)
            return "JSON_UNQUOTE(%s)" % sql, params
        else:
            lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
            json_path = compile_json_path(key_transforms)
            return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)

    @classmethod
    def from_lookup(cls, lookup):
        """
        Build a chain of KeyTextTransforms from a lookup string such as
        "field__key1__key2". Raise ValueError when no key part is present.
        """
        transform, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        for key in keys:
            transform = cls(key, transform)
        return transform
|  | ||||
|  | ||||
| KT = KeyTextTransform.from_lookup | ||||
|  | ||||
|  | ||||
class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of casting
    key values to text and performing the lookup on the resulting
    representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        # Replace the KeyTransform lhs with its text-returning counterpart,
        # carrying over the key name, source expressions and extras.
        key_text_transform = KeyTextTransform(
            key_transform.key_name,
            *key_transform.source_expressions,
            **key_transform.extra,
        )
        super().__init__(key_text_transform, *args, **kwargs)
|  | ||||
|  | ||||
class KeyTransformIsNull(lookups.IsNull):
    # key__isnull=False is the same as has_key='key'
    def as_oracle(self, compiler, connection):
        sql, params = HasKeyOrArrayIndex(
            self.lhs.lhs,
            self.lhs.key_name,
        ).as_oracle(compiler, connection)
        if not self.rhs:
            # isnull=False: the key simply has to exist.
            return sql, params
        # Column doesn't have a key or IS NULL.
        lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
        return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)

    def as_sqlite(self, compiler, connection):
        # JSON_TYPE() yields NULL for an absent path, so isnull=True maps to
        # "IS NULL" and isnull=False to "IS NOT NULL".
        template = "JSON_TYPE(%s, %%s) IS NULL"
        if not self.rhs:
            template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
            compiler,
            connection,
            template=template,
        )
|  | ||||
|  | ||||
class KeyTransformIn(lookups.In):
    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """
        Wrap each direct-value rhs member so its SQL form matches what the
        key transform produces on the lhs for the current backend.
        """
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            if connection.vendor == "oracle":
                value = json.loads(param)
                sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
                # JSON_QUERY handles objects/arrays; JSON_VALUE scalars.
                if isinstance(value, (list, dict)):
                    sql %= "JSON_QUERY"
                else:
                    sql %= "JSON_VALUE"
            elif connection.vendor == "mysql" or (
                connection.vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
        if connection.vendor == "mysql" and connection.mysql_is_mariadb:
            sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params
|  | ||||
|  | ||||
class KeyTransformExact(JSONExact):
    def process_rhs(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            # Key-to-key comparison: skip JSONExact's null/MySQL handling and
            # use the plain Exact rhs processing.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "oracle":
            func = []
            sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
            for value in rhs_params:
                value = json.loads(value)
                # JSON_QUERY handles objects/arrays; JSON_VALUE scalars.
                if isinstance(value, (list, dict)):
                    func.append(sql % "JSON_QUERY")
                else:
                    func.append(sql % "JSON_VALUE")
            rhs %= tuple(func)
        elif connection.vendor == "sqlite":
            func = []
            for value in rhs_params:
                # Values in the backend's special datatype list are bound
                # as-is; everything else is extracted as JSON.
                if value in connection.ops.jsonfield_datatype_values:
                    func.append("%s")
                else:
                    func.append("JSON_EXTRACT(%s, '$')")
            rhs %= tuple(func)
        return rhs, rhs_params

    def as_oracle(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if rhs_params == ["null"]:
            # Field has key and it's NULL.
            has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
            has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
            is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
            is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
            return (
                "%s AND %s" % (has_key_sql, is_null_sql),
                tuple(has_key_params) + tuple(is_null_params),
            )
        return super().as_sql(compiler, connection)
|  | ||||
|  | ||||
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """Case-insensitive exact match on the text value of a JSON key."""

    pass
|  | ||||
|  | ||||
class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """Case-insensitive substring match on the text value of a JSON key."""

    pass
|  | ||||
|  | ||||
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """Prefix match on the text value of a JSON key."""

    pass
|  | ||||
|  | ||||
class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """Case-insensitive prefix match on the text value of a JSON key."""

    pass
|  | ||||
|  | ||||
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """Suffix match on the text value of a JSON key."""

    pass
|  | ||||
|  | ||||
class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """Case-insensitive suffix match on the text value of a JSON key."""

    pass
|  | ||||
|  | ||||
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """Regular-expression match on the text value of a JSON key."""

    pass
|  | ||||
|  | ||||
class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """Case-insensitive regular-expression match on a JSON key's text value."""

    pass
|  | ||||
|  | ||||
class KeyTransformNumericLookupMixin:
    """Decode JSON-string rhs params on backends without native JSON fields."""

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if not connection.features.has_native_json_field:
            # Params arrive JSON-serialized here; load them back so the
            # comparison is done on the underlying value.
            rhs_params = [json.loads(value) for value in rhs_params]
        return rhs, rhs_params
|  | ||||
|  | ||||
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """Less-than comparison on a JSON key's value."""

    pass
|  | ||||
|  | ||||
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """Less-than-or-equal comparison on a JSON key's value."""

    pass
|  | ||||
|  | ||||
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """Greater-than comparison on a JSON key's value."""

    pass
|  | ||||
|  | ||||
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """Greater-than-or-equal comparison on a JSON key's value."""

    pass
|  | ||||
|  | ||||
# Register the lookups available on individual JSON keys.
KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)

# Numeric comparisons use KeyTransformNumericLookupMixin's rhs decoding.
KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)
|  | ||||
|  | ||||
class KeyTransformFactory:
    """Callable that builds KeyTransform instances for one fixed key name."""

    def __init__(self, key_name):
        # Remember the key; every call produces a fresh transform for it.
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        # Prepend the stored key to whatever Transform arguments are given.
        return KeyTransform(self.key_name, *args, **kwargs)
| @ -0,0 +1,59 @@ | ||||
| from django.core import checks | ||||
|  | ||||
# Sentinel distinguishing "no default supplied" from an explicit None.
NOT_PROVIDED = object()


class FieldCacheMixin:
    """Provide an API for working with the model's fields value cache."""

    def get_cache_name(self):
        # Subclasses must return the key used in the fields_cache dict.
        raise NotImplementedError

    def get_cached_value(self, instance, default=NOT_PROVIDED):
        """Return the cached value, or *default*; re-raise KeyError if neither."""
        cache_key = self.get_cache_name()
        try:
            return instance._state.fields_cache[cache_key]
        except KeyError:
            if default is NOT_PROVIDED:
                raise
        return default

    def is_cached(self, instance):
        """Report whether a value for this field is already cached."""
        return self.get_cache_name() in instance._state.fields_cache

    def set_cached_value(self, instance, value):
        """Store *value* under this field's cache key."""
        instance._state.fields_cache[self.get_cache_name()] = value

    def delete_cached_value(self, instance):
        """Remove this field's entry from the cache (KeyError if absent)."""
        del instance._state.fields_cache[self.get_cache_name()]
|  | ||||
|  | ||||
class CheckFieldDefaultMixin:
    # (good example, bad example) pair interpolated into the warning hint.
    _default_hint = ("<valid default>", "<invalid default>")

    def _check_default(self):
        """Warn when a non-callable, non-None default instance is configured."""
        shared_default = (
            self.has_default()
            and self.default is not None
            and not callable(self.default)
        )
        if not shared_default:
            return []
        good, bad = self._default_hint
        return [
            checks.Warning(
                "%s default should be a callable instead of an instance "
                "so that it's not shared between all field instances."
                % (self.__class__.__name__,),
                hint=(
                    "Use a callable instead, e.g., use `%s` instead of "
                    "`%s`." % (good, bad)
                ),
                obj=self,
                id="fields.E010",
            )
        ]

    def check(self, **kwargs):
        """Extend the parent field checks with the shared-default warning."""
        errors = super().check(**kwargs)
        errors.extend(self._check_default())
        return errors
| @ -0,0 +1,18 @@ | ||||
| """ | ||||
| Field-like classes that aren't really fields. It's easier to use objects that | ||||
| have the same attributes as fields sometimes (avoids a lot of special casing). | ||||
| """ | ||||
|  | ||||
| from django.db.models import fields | ||||
|  | ||||
|  | ||||
class OrderWrt(fields.IntegerField):
    """
    A proxy for the _order database field that is used when
    Meta.order_with_respect_to is specified.
    """

    def __init__(self, *args, **kwargs):
        # The column name and non-editable status are fixed; any caller
        # supplied values for these keys are deliberately overridden.
        kwargs["name"] = "_order"
        kwargs["editable"] = False
        super().__init__(*args, **kwargs)
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -0,0 +1,209 @@ | ||||
| import warnings | ||||
|  | ||||
| from django.db.models.lookups import ( | ||||
|     Exact, | ||||
|     GreaterThan, | ||||
|     GreaterThanOrEqual, | ||||
|     In, | ||||
|     IsNull, | ||||
|     LessThan, | ||||
|     LessThanOrEqual, | ||||
| ) | ||||
| from django.utils.deprecation import RemovedInDjango50Warning | ||||
|  | ||||
|  | ||||
class MultiColSource:
    """Stand-in lhs for lookups over relations that span several columns."""

    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, alias, targets, sources, field):
        self.alias = alias
        self.targets = targets
        self.sources = sources
        self.field = field
        self.output_field = self.field

    def __repr__(self):
        return f"{self.__class__.__name__}({self.alias}, {self.field})"

    def relabeled_clone(self, relabels):
        # Swap the alias according to the relabel map; everything else is shared.
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.targets, self.sources, self.field)

    def get_lookup(self, lookup):
        # Delegate lookup resolution to the related field.
        return self.output_field.get_lookup(lookup)

    def resolve_expression(self, *args, **kwargs):
        # Already resolved; behave as a no-op expression.
        return self
|  | ||||
|  | ||||
def get_normalized_value(value, lhs):
    """
    Normalize *value* into a tuple of raw values suitable for comparing
    against the (possibly multi-column) target fields of *lhs*.
    """
    from django.db.models import Model

    if isinstance(value, Model):
        if value.pk is None:
            # When the deprecation ends, replace with:
            # raise ValueError(
            #     "Model instances passed to related filters must be saved."
            # )
            warnings.warn(
                "Passing unsaved model instances to related filters is deprecated.",
                RemovedInDjango50Warning,
            )
        value_list = []
        sources = lhs.output_field.path_infos[-1].target_fields
        for source in sources:
            # Follow remote fields until reaching the model that actually
            # carries the concrete attribute for this source.
            while not isinstance(value, source.model) and source.remote_field:
                source = source.remote_field.model._meta.get_field(
                    source.remote_field.field_name
                )
            try:
                value_list.append(getattr(value, source.attname))
            except AttributeError:
                # A case like Restaurant.objects.filter(place=restaurant_instance),
                # where place is a OneToOneField and the primary key of Restaurant.
                return (value.pk,)
        return tuple(value_list)
    if not isinstance(value, tuple):
        # Wrap single raw values so callers can treat everything uniformly.
        return (value,)
    return value
|  | ||||
|  | ||||
class RelatedIn(In):
    """``__in`` lookup across a relation, including multi-column targets."""

    def get_prep_lookup(self):
        if not isinstance(self.lhs, MultiColSource):
            if self.rhs_is_direct_value():
                # If we get here, we are dealing with single-column relations.
                self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
                # We need to run the related field's get_prep_value(). Consider
                # case ForeignKey to IntegerField given value 'abc'. The
                # ForeignKey itself doesn't have validation for non-integers,
                # so we must run validation using the target field.
                if hasattr(self.lhs.output_field, "path_infos"):
                    # Run the target field's get_prep_value. We can safely
                    # assume there is only one as we don't get to the direct
                    # value branch otherwise.
                    target_field = self.lhs.output_field.path_infos[-1].target_fields[
                        -1
                    ]
                    self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
            elif not getattr(self.rhs, "has_select_fields", True) and not getattr(
                self.lhs.field.target_field, "primary_key", False
            ):
                # The rhs is a queryset with no explicit column selection;
                # restrict it to the column the relation actually targets.
                if (
                    getattr(self.lhs.output_field, "primary_key", False)
                    and self.lhs.output_field.model == self.rhs.model
                ):
                    # A case like
                    # Restaurant.objects.filter(place__in=restaurant_qs), where
                    # place is a OneToOneField and the primary key of
                    # Restaurant.
                    target_field = self.lhs.field.name
                else:
                    target_field = self.lhs.field.target_field.name
                self.rhs.set_values([target_field])
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        if isinstance(self.lhs, MultiColSource):
            # For multicolumn lookups we need to build a multicolumn where clause.
            # This clause is either a SubqueryConstraint (for values that need
            # to be compiled to SQL) or an OR-combined list of
            # (col1 = val1 AND col2 = val2 AND ...) clauses.
            from django.db.models.sql.where import (
                AND,
                OR,
                SubqueryConstraint,
                WhereNode,
            )

            root_constraint = WhereNode(connector=OR)
            if self.rhs_is_direct_value():
                values = [get_normalized_value(value, self.lhs) for value in self.rhs]
                for value in values:
                    value_constraint = WhereNode()
                    for source, target, val in zip(
                        self.lhs.sources, self.lhs.targets, value
                    ):
                        lookup_class = target.get_lookup("exact")
                        lookup = lookup_class(
                            target.get_col(self.lhs.alias, source), val
                        )
                        value_constraint.add(lookup, AND)
                    root_constraint.add(value_constraint, OR)
            else:
                root_constraint.add(
                    SubqueryConstraint(
                        self.lhs.alias,
                        [target.column for target in self.lhs.targets],
                        [source.name for source in self.lhs.sources],
                        self.rhs,
                    ),
                    AND,
                )
            return root_constraint.as_sql(compiler, connection)
        return super().as_sql(compiler, connection)
|  | ||||
|  | ||||
class RelatedLookupMixin:
    """Shared rhs normalization and multi-column SQL for related lookups."""

    def get_prep_lookup(self):
        if not isinstance(self.lhs, MultiColSource) and not hasattr(
            self.rhs, "resolve_expression"
        ):
            # If we get here, we are dealing with single-column relations.
            self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
            # We need to run the related field's get_prep_value(). Consider case
            # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
            # doesn't have validation for non-integers, so we must run validation
            # using the target field.
            if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
                # Get the target field. We can safely assume there is only one
                # as we don't get to the direct value branch otherwise.
                target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
                self.rhs = target_field.get_prep_value(self.rhs)

        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        if isinstance(self.lhs, MultiColSource):
            # Multi-column comparison: apply this lookup per column pair and
            # AND the results together.
            assert self.rhs_is_direct_value()
            self.rhs = get_normalized_value(self.rhs, self.lhs)
            from django.db.models.sql.where import AND, WhereNode

            root_constraint = WhereNode()
            for target, source, val in zip(
                self.lhs.targets, self.lhs.sources, self.rhs
            ):
                lookup_class = target.get_lookup(self.lookup_name)
                root_constraint.add(
                    lookup_class(target.get_col(self.lhs.alias, source), val), AND
                )
            return root_constraint.as_sql(compiler, connection)
        return super().as_sql(compiler, connection)
|  | ||||
|  | ||||
class RelatedExact(RelatedLookupMixin, Exact):
    """Exact match across a relation."""

    pass
|  | ||||
|  | ||||
class RelatedLessThan(RelatedLookupMixin, LessThan):
    """Less-than comparison across a relation."""

    pass
|  | ||||
|  | ||||
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    """Greater-than comparison across a relation."""

    pass
|  | ||||
|  | ||||
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    """Greater-than-or-equal comparison across a relation."""

    pass
|  | ||||
|  | ||||
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    """Less-than-or-equal comparison across a relation."""

    pass
|  | ||||
|  | ||||
class RelatedIsNull(RelatedLookupMixin, IsNull):
    """NULL test across a relation."""

    pass
| @ -0,0 +1,402 @@ | ||||
| """ | ||||
| "Rel objects" for related fields. | ||||
|  | ||||
| "Rel objects" (for lack of a better name) carry information about the relation | ||||
| modeled by a related field and provide some utility functions. They're stored | ||||
| in the ``remote_field`` attribute of the field. | ||||
|  | ||||
| They also act as reverse fields for the purposes of the Meta API because | ||||
| they're the closest concept currently available. | ||||
| """ | ||||
|  | ||||
| from django.core import exceptions | ||||
| from django.utils.functional import cached_property | ||||
| from django.utils.hashable import make_hashable | ||||
|  | ||||
| from . import BLANK_CHOICE_DASH | ||||
| from .mixins import FieldCacheMixin | ||||
|  | ||||
|  | ||||
class ForeignObjectRel(FieldCacheMixin):
    """
    Used by ForeignObject to store information about the relation.

    ``_meta.get_fields()`` returns this class to provide access to the field
    flags for the reverse relation.
    """

    # Field flags
    auto_created = True
    concrete = False
    editable = False
    is_relation = True

    # Reverse relations are always nullable (Django can't enforce that a
    # foreign key on the related model points to this model).
    null = True
    empty_strings_allowed = False

    def __init__(
        self,
        field,
        to,
        related_name=None,
        related_query_name=None,
        limit_choices_to=None,
        parent_link=False,
        on_delete=None,
    ):
        # The forward field this rel belongs to, and the model it points to.
        self.field = field
        self.model = to
        self.related_name = related_name
        self.related_query_name = related_query_name
        # Normalize None to an empty filter dict so callers can pass the
        # value straight to complex_filter() (see get_choices()).
        self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
        self.parent_link = parent_link
        self.on_delete = on_delete

        # Defaults refined by subclasses: ManyToManyRel sets symmetrical from
        # its argument, OneToOneRel sets multiple = False.
        self.symmetrical = False
        self.multiple = True

    # Some of the following cached_properties can't be initialized in
    # __init__ as the field doesn't have its model yet. Calling these methods
    # before field.contribute_to_class() has been called will result in
    # AttributeError
    @cached_property
    def hidden(self):
        """Cached shortcut for is_hidden()."""
        return self.is_hidden()

    @cached_property
    def name(self):
        """The name used to refer to this relation in querysets."""
        return self.field.related_query_name()

    @property
    def remote_field(self):
        """The forward field on the other side of this reverse relation."""
        return self.field

    @property
    def target_field(self):
        """
        When filtering against this relation, return the field on the remote
        model against which the filtering should happen.
        """
        target_fields = self.path_infos[-1].target_fields
        if len(target_fields) > 1:
            raise exceptions.FieldError(
                "Can't use target_field for multicolumn relations."
            )
        return target_fields[0]

    @cached_property
    def related_model(self):
        if not self.field.model:
            raise AttributeError(
                "This property can't be accessed before self.field.contribute_to_class "
                "has been called."
            )
        return self.field.model

    @cached_property
    def many_to_many(self):
        return self.field.many_to_many

    # Note the deliberate swap below: the reverse of a many-to-one relation
    # is one-to-many, and vice versa.
    @cached_property
    def many_to_one(self):
        return self.field.one_to_many

    @cached_property
    def one_to_many(self):
        return self.field.many_to_one

    @cached_property
    def one_to_one(self):
        return self.field.one_to_one

    # Lookup/transform/type resolution is delegated to the forward field.
    def get_lookup(self, lookup_name):
        return self.field.get_lookup(lookup_name)

    def get_lookups(self):
        return self.field.get_lookups()

    def get_transform(self, name):
        return self.field.get_transform(name)

    def get_internal_type(self):
        return self.field.get_internal_type()

    @property
    def db_type(self):
        return self.field.db_type

    def __repr__(self):
        return "<%s: %s.%s>" % (
            type(self).__name__,
            self.related_model._meta.app_label,
            self.related_model._meta.model_name,
        )

    @property
    def identity(self):
        """
        Tuple of everything that distinguishes this rel, used by __eq__ and
        __hash__. Subclasses extend it with their extra attributes.
        """
        return (
            self.field,
            self.model,
            self.related_name,
            self.related_query_name,
            # limit_choices_to may be a dict/Q object; make it hashable.
            make_hashable(self.limit_choices_to),
            self.parent_link,
            self.on_delete,
            self.symmetrical,
            self.multiple,
        )

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.identity == other.identity

    def __hash__(self):
        return hash(self.identity)

    def __getstate__(self):
        state = self.__dict__.copy()
        # Delete the path_infos cached property because it can be recalculated
        # at first invocation after deserialization. The attribute must be
        # removed because subclasses like ManyToOneRel may have a PathInfo
        # which contains an intermediate M2M table that's been dynamically
        # created and doesn't exist in the .models module.
        # This is a reverse relation, so there is no reverse_path_infos to
        # delete.
        state.pop("path_infos", None)
        return state

    def get_choices(
        self,
        include_blank=True,
        blank_choice=BLANK_CHOICE_DASH,
        limit_choices_to=None,
        ordering=(),
    ):
        """
        Return choices with a default blank choices included, for use
        as <select> choices for this field.

        Analog of django.db.models.fields.Field.get_choices(), provided
        initially for utilization by RelatedFieldListFilter.
        """
        limit_choices_to = limit_choices_to or self.limit_choices_to
        qs = self.related_model._default_manager.complex_filter(limit_choices_to)
        if ordering:
            qs = qs.order_by(*ordering)
        return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs]

    def is_hidden(self):
        """Should the related object be hidden?"""
        # A trailing "+" in related_name is the convention for suppressing
        # the reverse accessor.
        return bool(self.related_name) and self.related_name[-1] == "+"

    def get_joining_columns(self):
        return self.field.get_reverse_joining_columns()

    def get_extra_restriction(self, alias, related_alias):
        # Aliases are swapped because this is the reverse direction of the
        # forward field's restriction.
        return self.field.get_extra_restriction(related_alias, alias)

    def set_field_name(self):
        """
        Set the related field's name, this is not available until later stages
        of app loading, so set_field_name is called from
        set_attributes_from_rel()
        """
        # By default foreign object doesn't relate to any remote field (for
        # example custom multicolumn joins currently have no remote field).
        self.field_name = None

    def get_accessor_name(self, model=None):
        # This method encapsulates the logic that decides what name to give an
        # accessor descriptor that retrieves related many-to-one or
        # many-to-many objects. It uses the lowercased object_name + "_set",
        # but this can be overridden with the "related_name" option. Due to
        # backwards compatibility ModelForms need to be able to provide an
        # alternate model. See BaseInlineFormSet.get_default_prefix().
        opts = model._meta if model else self.related_model._meta
        model = model or self.related_model
        if self.multiple:
            # If this is a symmetrical m2m relation on self, there is no
            # reverse accessor.
            if self.symmetrical and model == self.model:
                return None
        if self.related_name:
            return self.related_name
        return opts.model_name + ("_set" if self.multiple else "")

    def get_path_info(self, filtered_relation=None):
        """Return the path (list of PathInfo) to traverse this relation."""
        if filtered_relation:
            return self.field.get_reverse_path_info(filtered_relation)
        else:
            return self.field.reverse_path_infos

    @cached_property
    def path_infos(self):
        # Cached, unfiltered variant of get_path_info(); cleared on pickling
        # (see __getstate__).
        return self.get_path_info()

    def get_cache_name(self):
        """
        Return the name of the cache key to use for storing an instance of the
        forward model on the reverse model.
        """
        return self.get_accessor_name()
|  | ||||
|  | ||||
class ManyToOneRel(ForeignObjectRel):
    """
    Used by the ForeignKey field to store information about the relation.

    ``_meta.get_fields()`` returns this class to provide access to the field
    flags for the reverse relation.

    Note: Because we somewhat abuse the Rel objects by using them as reverse
    fields we get the funny situation where
    ``ManyToOneRel.many_to_one == False`` and
    ``ManyToOneRel.one_to_many == True``. This is unfortunate but the actual
    ManyToOneRel class is a private API and there is work underway to turn
    reverse relations into actual fields.
    """

    def __init__(
        self,
        field,
        to,
        field_name,
        related_name=None,
        related_query_name=None,
        limit_choices_to=None,
        parent_link=False,
        on_delete=None,
    ):
        super().__init__(
            field,
            to,
            related_name=related_name,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
            parent_link=parent_link,
            on_delete=on_delete,
        )

        # Name of the field on the 'to' model this FK targets; may be None
        # until set_field_name() resolves it to the primary key.
        self.field_name = field_name

    def __getstate__(self):
        state = super().__getstate__()
        # Drop the related_model cached_property (if it was evaluated) so it
        # is recomputed on first access after unpickling.
        state.pop("related_model", None)
        return state

    @property
    def identity(self):
        # Extend the parent identity tuple so __eq__/__hash__ also account
        # for the targeted field name.
        return super().identity + (self.field_name,)

    def get_related_field(self):
        """
        Return the Field in the 'to' object to which this relationship is tied.
        """
        field = self.model._meta.get_field(self.field_name)
        if not field.concrete:
            raise exceptions.FieldDoesNotExist(
                "No related field named '%s'" % self.field_name
            )
        return field

    def set_field_name(self):
        # Fall back to the primary key when no explicit target field was
        # given.
        self.field_name = self.field_name or self.model._meta.pk.name
|  | ||||
|  | ||||
class OneToOneRel(ManyToOneRel):
    """
    Used by OneToOneField to store information about the relation.

    ``_meta.get_fields()`` returns this class to provide access to the field
    flags for the reverse relation.
    """

    def __init__(
        self,
        field,
        to,
        field_name,
        related_name=None,
        related_query_name=None,
        limit_choices_to=None,
        parent_link=False,
        on_delete=None,
    ):
        super().__init__(
            field,
            to,
            field_name,
            related_name=related_name,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
            parent_link=parent_link,
            on_delete=on_delete,
        )

        # At most one related object exists; e.g. the reverse accessor name
        # omits the "_set" suffix (see ForeignObjectRel.get_accessor_name()).
        self.multiple = False
|  | ||||
|  | ||||
class ManyToManyRel(ForeignObjectRel):
    """
    Used by ManyToManyField to store information about the relation.

    ``_meta.get_fields()`` returns this class to provide access to the field
    flags for the reverse relation.
    """

    def __init__(
        self,
        field,
        to,
        related_name=None,
        related_query_name=None,
        limit_choices_to=None,
        symmetrical=True,
        through=None,
        through_fields=None,
        db_constraint=True,
    ):
        super().__init__(
            field,
            to,
            related_name=related_name,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
        )

        # An explicit through model can't be combined with
        # db_constraint=False.
        if through and not db_constraint:
            raise ValueError("Can't supply a through model and db_constraint=False")
        self.through = through

        # through_fields only makes sense relative to a through model.
        if through_fields and not through:
            raise ValueError("Cannot specify through_fields without a through model")
        self.through_fields = through_fields

        self.symmetrical = symmetrical
        self.db_constraint = db_constraint

    @property
    def identity(self):
        # Extend the parent identity tuple so __eq__/__hash__ also account
        # for the M2M-specific attributes.
        return super().identity + (
            self.through,
            # through_fields may be a list; make it hashable.
            make_hashable(self.through_fields),
            self.db_constraint,
        )

    def get_related_field(self):
        """
        Return the field in the 'to' object to which this relationship is tied.
        Provided for symmetry with ManyToOneRel.
        """
        opts = self.through._meta
        if self.through_fields:
            field = opts.get_field(self.through_fields[0])
        else:
            # Pick the first field on the through model whose remote side is
            # this relation's model.
            for field in opts.fields:
                rel = getattr(field, "remote_field", None)
                if rel and rel.model == self.model:
                    break
        return field.foreign_related_fields[0]
| @ -0,0 +1,190 @@ | ||||
| from .comparison import Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf | ||||
| from .datetime import ( | ||||
|     Extract, | ||||
|     ExtractDay, | ||||
|     ExtractHour, | ||||
|     ExtractIsoWeekDay, | ||||
|     ExtractIsoYear, | ||||
|     ExtractMinute, | ||||
|     ExtractMonth, | ||||
|     ExtractQuarter, | ||||
|     ExtractSecond, | ||||
|     ExtractWeek, | ||||
|     ExtractWeekDay, | ||||
|     ExtractYear, | ||||
|     Now, | ||||
|     Trunc, | ||||
|     TruncDate, | ||||
|     TruncDay, | ||||
|     TruncHour, | ||||
|     TruncMinute, | ||||
|     TruncMonth, | ||||
|     TruncQuarter, | ||||
|     TruncSecond, | ||||
|     TruncTime, | ||||
|     TruncWeek, | ||||
|     TruncYear, | ||||
| ) | ||||
| from .math import ( | ||||
|     Abs, | ||||
|     ACos, | ||||
|     ASin, | ||||
|     ATan, | ||||
|     ATan2, | ||||
|     Ceil, | ||||
|     Cos, | ||||
|     Cot, | ||||
|     Degrees, | ||||
|     Exp, | ||||
|     Floor, | ||||
|     Ln, | ||||
|     Log, | ||||
|     Mod, | ||||
|     Pi, | ||||
|     Power, | ||||
|     Radians, | ||||
|     Random, | ||||
|     Round, | ||||
|     Sign, | ||||
|     Sin, | ||||
|     Sqrt, | ||||
|     Tan, | ||||
| ) | ||||
| from .text import ( | ||||
|     MD5, | ||||
|     SHA1, | ||||
|     SHA224, | ||||
|     SHA256, | ||||
|     SHA384, | ||||
|     SHA512, | ||||
|     Chr, | ||||
|     Concat, | ||||
|     ConcatPair, | ||||
|     Left, | ||||
|     Length, | ||||
|     Lower, | ||||
|     LPad, | ||||
|     LTrim, | ||||
|     Ord, | ||||
|     Repeat, | ||||
|     Replace, | ||||
|     Reverse, | ||||
|     Right, | ||||
|     RPad, | ||||
|     RTrim, | ||||
|     StrIndex, | ||||
|     Substr, | ||||
|     Trim, | ||||
|     Upper, | ||||
| ) | ||||
| from .window import ( | ||||
|     CumeDist, | ||||
|     DenseRank, | ||||
|     FirstValue, | ||||
|     Lag, | ||||
|     LastValue, | ||||
|     Lead, | ||||
|     NthValue, | ||||
|     Ntile, | ||||
|     PercentRank, | ||||
|     Rank, | ||||
|     RowNumber, | ||||
| ) | ||||
|  | ||||
# Public API of this package; mirrors the submodule imports above and is
# grouped by submodule so `from ... import *` exposes exactly these names.
__all__ = [
    # comparison and conversion
    "Cast",
    "Coalesce",
    "Collate",
    "Greatest",
    "JSONObject",
    "Least",
    "NullIf",
    # datetime
    "Extract",
    "ExtractDay",
    "ExtractHour",
    "ExtractMinute",
    "ExtractMonth",
    "ExtractQuarter",
    "ExtractSecond",
    "ExtractWeek",
    "ExtractIsoWeekDay",
    "ExtractWeekDay",
    "ExtractIsoYear",
    "ExtractYear",
    "Now",
    "Trunc",
    "TruncDate",
    "TruncDay",
    "TruncHour",
    "TruncMinute",
    "TruncMonth",
    "TruncQuarter",
    "TruncSecond",
    "TruncTime",
    "TruncWeek",
    "TruncYear",
    # math
    "Abs",
    "ACos",
    "ASin",
    "ATan",
    "ATan2",
    "Ceil",
    "Cos",
    "Cot",
    "Degrees",
    "Exp",
    "Floor",
    "Ln",
    "Log",
    "Mod",
    "Pi",
    "Power",
    "Radians",
    "Random",
    "Round",
    "Sign",
    "Sin",
    "Sqrt",
    "Tan",
    # text
    "MD5",
    "SHA1",
    "SHA224",
    "SHA256",
    "SHA384",
    "SHA512",
    "Chr",
    "Concat",
    "ConcatPair",
    "Left",
    "Length",
    "Lower",
    "LPad",
    "LTrim",
    "Ord",
    "Repeat",
    "Replace",
    "Reverse",
    "Right",
    "RPad",
    "RTrim",
    "StrIndex",
    "Substr",
    "Trim",
    "Upper",
    # window
    "CumeDist",
    "DenseRank",
    "FirstValue",
    "Lag",
    "LastValue",
    "Lead",
    "NthValue",
    "Ntile",
    "PercentRank",
    "Rank",
    "RowNumber",
]
| @ -0,0 +1,220 @@ | ||||
| """Database functions that do comparisons or type conversions.""" | ||||
| from django.db import NotSupportedError | ||||
| from django.db.models.expressions import Func, Value | ||||
| from django.db.models.fields import TextField | ||||
| from django.db.models.fields.json import JSONField | ||||
| from django.utils.regex_helper import _lazy_re_compile | ||||
|  | ||||
|  | ||||
class Cast(Func):
    """Coerce an expression to a new field type."""

    function = "CAST"
    template = "%(function)s(%(expressions)s AS %(db_type)s)"

    def __init__(self, expression, output_field):
        # The target type is mandatory: it drives db_type resolution below.
        super().__init__(expression, output_field=output_field)

    def as_sql(self, compiler, connection, **extra_context):
        # Resolve the backend-specific name of the CAST target type.
        extra_context["db_type"] = self.output_field.cast_db_type(connection)
        return super().as_sql(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        target = self.output_field.db_type(connection)
        if target == "date":
            return super().as_sql(
                compiler, connection, template="date(%(expressions)s)", **extra_context
            )
        if target in ("datetime", "time"):
            # Use strftime as datetime/time don't keep fractional seconds.
            sql, params = super().as_sql(
                compiler,
                connection,
                template="strftime(%%s, %(expressions)s)",
                **extra_context,
            )
            fmt = "%H:%M:%f" if target == "time" else "%Y-%m-%d %H:%M:%f"
            params.insert(0, fmt)
            return sql, params
        return self.as_sql(compiler, connection, **extra_context)

    def as_mysql(self, compiler, connection, **extra_context):
        output_type = self.output_field.get_internal_type()
        template = None
        if output_type == "FloatField":
            # MySQL doesn't support explicit cast to float.
            template = "(%(expressions)s + 0.0)"
        elif output_type == "JSONField" and connection.mysql_is_mariadb:
            # MariaDB doesn't support explicit cast to JSON.
            template = "JSON_EXTRACT(%(expressions)s, '$')"
        return self.as_sql(compiler, connection, template=template, **extra_context)

    def as_postgresql(self, compiler, connection, **extra_context):
        # CAST would be valid too, but the :: shortcut syntax is more readable.
        # 'expressions' is wrapped in parentheses in case it's a complex
        # expression.
        return self.as_sql(
            compiler,
            connection,
            template="(%(expressions)s)::%(db_type)s",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        if self.output_field.get_internal_type() != "JSONField":
            return self.as_sql(compiler, connection, **extra_context)
        # Oracle doesn't support explicit cast to JSON.
        return super().as_sql(
            compiler,
            connection,
            template="JSON_QUERY(%(expressions)s, '$')",
            **extra_context,
        )
|  | ||||
|  | ||||
class Coalesce(Func):
    """Return, from left to right, the first non-null expression."""

    function = "COALESCE"

    def __init__(self, *expressions, **extra):
        # COALESCE with fewer than two operands is pointless.
        if len(expressions) <= 1:
            raise ValueError("Coalesce must take at least two expressions")
        super().__init__(*expressions, **extra)

    @property
    def empty_result_set_value(self):
        # Mirror COALESCE itself: scan left to right, stopping at the first
        # source whose empty-result value is non-null (or unknown).
        for source in self.get_source_expressions():
            value = source.empty_result_set_value
            if value is NotImplemented or value is not None:
                return value
        return None

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
        # so convert all fields to NCLOB when that type is expected.
        if self.output_field.get_internal_type() != "TextField":
            return self.as_sql(compiler, connection, **extra_context)
        clone = self.copy()
        wrapped = [
            Func(source, function="TO_NCLOB")
            for source in self.get_source_expressions()
        ]
        clone.set_source_expressions(wrapped)
        return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
|  | ||||
|  | ||||
class Collate(Func):
    function = "COLLATE"
    template = "%(expressions)s %(function)s %(collation)s"
    # Inspired from
    # https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
    collation_re = _lazy_re_compile(r"^[\w\-]+$")

    def __init__(self, expression, collation):
        # Reject empty names and anything outside [word chars, hyphen].
        if not collation or self.collation_re.match(collation) is None:
            raise ValueError("Invalid collation name: %r." % collation)
        self.collation = collation
        super().__init__(expression)

    def as_sql(self, compiler, connection, **extra_context):
        # Quote the collation like any other identifier, unless the caller
        # already supplied one.
        extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
        return super().as_sql(compiler, connection, **extra_context)
|  | ||||
|  | ||||
class Greatest(Func):
    """
    Return the maximum expression.

    Null handling is database-specific: PostgreSQL returns the maximum
    not-null expression, while MySQL, Oracle, and SQLite return null as soon
    as any expression is null.
    """

    function = "GREATEST"

    def __init__(self, *expressions, **extra):
        # A maximum over fewer than two operands is meaningless.
        if len(expressions) <= 1:
            raise ValueError("Greatest must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """SQLite spells the multi-argument maximum as MAX."""
        return super().as_sqlite(compiler, connection, function="MAX", **extra_context)
|  | ||||
|  | ||||
class JSONObject(Func):
    function = "JSON_OBJECT"
    output_field = JSONField()

    def __init__(self, **fields):
        # Flatten {key: expr} into the alternating sequence
        # (Value(key), expr, Value(key), expr, ...) that JSON_OBJECT expects.
        pairs = []
        for name, expression in fields.items():
            pairs.append(Value(name))
            pairs.append(expression)
        super().__init__(*pairs)

    def as_sql(self, compiler, connection, **extra_context):
        if not connection.features.has_json_object_function:
            raise NotSupportedError(
                "JSONObject() is not supported on this database backend."
            )
        return super().as_sql(compiler, connection, **extra_context)

    def as_postgresql(self, compiler, connection, **extra_context):
        # The keys (even positions) must be cast to text for
        # JSONB_BUILD_OBJECT.
        clone = self.copy()
        clone.set_source_expressions(
            [
                source if position % 2 else Cast(source, TextField())
                for position, source in enumerate(clone.get_source_expressions())
            ]
        )
        return super(JSONObject, clone).as_sql(
            compiler,
            connection,
            function="JSONB_BUILD_OBJECT",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        class KeyValueJoiner:
            # Pair each key with its value using Oracle's "key VALUE value"
            # syntax, then separate the pairs with commas.
            def join(self, args):
                pairs = zip(args[::2], args[1::2])
                return ", ".join(" VALUE ".join(pair) for pair in pairs)

        return self.as_sql(
            compiler,
            connection,
            arg_joiner=KeyValueJoiner(),
            template="%(function)s(%(expressions)s RETURNING CLOB)",
            **extra_context,
        )
|  | ||||
|  | ||||
class Least(Func):
    """
    Return the minimum expression.

    Null handling is database-specific: PostgreSQL returns the minimum
    not-null expression, while MySQL, Oracle, and SQLite return null as soon
    as any expression is null.
    """

    function = "LEAST"

    def __init__(self, *expressions, **extra):
        # A minimum over fewer than two operands is meaningless.
        if len(expressions) <= 1:
            raise ValueError("Least must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """SQLite spells the multi-argument minimum as MIN."""
        return super().as_sqlite(compiler, connection, function="MIN", **extra_context)
|  | ||||
|  | ||||
class NullIf(Func):
    function = "NULLIF"
    arity = 2

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle rejects a literal NULL as the first NULLIF argument; fail
        # early with a clearer error than the database would give.
        first = self.get_source_expressions()[0]
        if isinstance(first, Value) and first.value is None:
            raise ValueError("Oracle does not allow Value(None) for expression1.")
        return super().as_sql(compiler, connection, **extra_context)
| @ -0,0 +1,439 @@ | ||||
| from datetime import datetime | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db.models.expressions import Func | ||||
| from django.db.models.fields import ( | ||||
|     DateField, | ||||
|     DateTimeField, | ||||
|     DurationField, | ||||
|     Field, | ||||
|     IntegerField, | ||||
|     TimeField, | ||||
| ) | ||||
| from django.db.models.lookups import ( | ||||
|     Transform, | ||||
|     YearExact, | ||||
|     YearGt, | ||||
|     YearGte, | ||||
|     YearLt, | ||||
|     YearLte, | ||||
| ) | ||||
| from django.utils import timezone | ||||
|  | ||||
|  | ||||
class TimezoneMixin:
    # Optional override; None means "use the current timezone".
    tzinfo = None

    def get_tzname(self):
        # Timezone conversions must happen to the input datetime *before*
        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
        # database as 2016-01-01 01:00:00 +00:00. Any results should be
        # based on the input datetime not the stored datetime.
        if not settings.USE_TZ:
            return None
        if self.tzinfo is None:
            return timezone.get_current_timezone_name()
        return timezone._get_timezone_name(self.tzinfo)
|  | ||||
|  | ||||
class Extract(TimezoneMixin, Transform):
    """
    Extract a single named component (year, month, hour, ...) from a
    date/time expression, returning an integer.
    """

    # Set by subclasses; otherwise must be passed to __init__.
    lookup_name = None
    output_field = IntegerField()

    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
        if self.lookup_name is None:
            # Fall back to the constructor argument when the class itself
            # doesn't define a lookup_name (i.e. Extract used directly).
            self.lookup_name = lookup_name
        if self.lookup_name is None:
            raise ValueError("lookup_name must be provided")
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(self, compiler, connection):
        # Dispatch to the backend's extraction SQL based on the lhs field type.
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        if isinstance(lhs_output_field, DateTimeField):
            tzname = self.get_tzname()
            sql, params = connection.ops.datetime_extract_sql(
                self.lookup_name, sql, tuple(params), tzname
            )
        elif self.tzinfo is not None:
            # Placed after the DateTimeField branch so tzinfo is rejected
            # for every non-datetime field type.
            raise ValueError("tzinfo can only be used with DateTimeField.")
        elif isinstance(lhs_output_field, DateField):
            sql, params = connection.ops.date_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, TimeField):
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, DurationField):
            if not connection.features.has_native_duration_field:
                raise ValueError(
                    "Extract requires native DurationField database support."
                )
            # Durations reuse the time-extraction SQL on such backends.
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        else:
            # resolve_expression has already validated the output_field so this
            # assert should never be hit.
            assert False, "Tried to Extract from an invalid type."
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Validate the input field / component combination at resolve time.
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = getattr(copy.lhs, "output_field", None)
        if field is None:
            # The lhs couldn't provide an output field; skip validation.
            return copy
        if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )
        # Passing dates to functions expecting datetimes is most likely a mistake.
        if type(field) is DateField and copy.lookup_name in (
            "hour",
            "minute",
            "second",
        ):
            raise ValueError(
                "Cannot extract time component '%s' from DateField '%s'."
                % (copy.lookup_name, field.name)
            )
        # Calendar components are rejected for durations.
        if isinstance(field, DurationField) and copy.lookup_name in (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        ):
            raise ValueError(
                "Cannot extract component '%s' from DurationField '%s'."
                % (copy.lookup_name, field.name)
            )
        return copy
|  | ||||
|  | ||||
class ExtractYear(Extract):
    """Extract the calendar year."""

    lookup_name = "year"
|  | ||||
|  | ||||
class ExtractIsoYear(Extract):
    """Return the ISO-8601 week-numbering year."""

    lookup_name = "iso_year"
|  | ||||
|  | ||||
class ExtractMonth(Extract):
    """Extract the month (1-12)."""

    lookup_name = "month"
|  | ||||
|  | ||||
class ExtractDay(Extract):
    """Extract the day of the month."""

    lookup_name = "day"
|  | ||||
|  | ||||
class ExtractWeek(Extract):
    """
    Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
    week.
    """

    lookup_name = "week"
|  | ||||
|  | ||||
class ExtractWeekDay(Extract):
    """
    Return Sunday=1 through Saturday=7.

    To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
    """

    lookup_name = "week_day"
|  | ||||
|  | ||||
class ExtractIsoWeekDay(Extract):
    """Return Monday=1 through Sunday=7, based on ISO-8601."""

    lookup_name = "iso_week_day"
|  | ||||
|  | ||||
class ExtractQuarter(Extract):
    """Extract the quarter of the year (1-4)."""

    lookup_name = "quarter"
|  | ||||
|  | ||||
class ExtractHour(Extract):
    """Extract the hour of the day."""

    lookup_name = "hour"
|  | ||||
|  | ||||
class ExtractMinute(Extract):
    """Extract the minute of the hour."""

    lookup_name = "minute"
|  | ||||
|  | ||||
class ExtractSecond(Extract):
    """Extract the second of the minute."""

    lookup_name = "second"
|  | ||||
|  | ||||
# Register the Extract transforms as field lookups so they can be used in
# queries, e.g. .filter(created__year=2020).
for _extract in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_extract)

# Time components are available on both TimeField and DateTimeField.
for _extract in (ExtractHour, ExtractMinute, ExtractSecond):
    TimeField.register_lookup(_extract)
    DateTimeField.register_lookup(_extract)

# Year comparisons get optimized range-based lookups.
for _year_lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
    ExtractYear.register_lookup(_year_lookup)
    ExtractIsoYear.register_lookup(_year_lookup)
|  | ||||
|  | ||||
class Now(Func):
    """Return the database server's current timestamp."""

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
        # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
        # other databases.
        return self.as_sql(
            compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        # Request microsecond precision explicitly on MySQL.
        return self.as_sql(
            compiler, connection, template="CURRENT_TIMESTAMP(6)", **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        # Format 'NOW' with STRFTIME to get fractional seconds on SQLite.
        # NOTE(review): the quadrupled '%' presumably survives two rounds of
        # %-interpolation (template fill then paramstyle escaping) — confirm.
        return self.as_sql(
            compiler,
            connection,
            template="STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')",
            **extra_context,
        )
|  | ||||
|  | ||||
class TruncBase(TimezoneMixin, Transform):
    """Base class for date/time truncation expressions."""

    # Truncation granularity ("year", "month", ...); set by subclasses or by
    # Trunc.__init__.
    kind = None
    tzinfo = None

    # RemovedInDjango50Warning: when the deprecation ends, remove is_dst
    # argument.
    def __init__(
        self,
        expression,
        output_field=None,
        tzinfo=None,
        is_dst=timezone.NOT_PASSED,
        **extra,
    ):
        self.tzinfo = tzinfo
        self.is_dst = is_dst
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        # Choose the backend truncation SQL based on the *output* field type;
        # a timezone name only applies when the input is a datetime.
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Validate input/output field combinations at resolve time so invalid
        # truncations fail loudly rather than producing wrong SQL.
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, (DateField, TimeField)):
            raise TypeError(
                "%r isn't a DateField, TimeField, or DateTimeField." % field.name
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        # "Explicit" means either the class pins an output_field (TruncDate/
        # TruncTime) or the caller supplied one differing from the input type.
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                "Cannot truncate DateField '%s' to %s."
                % (
                    field.name,
                    output_field.__class__.__name__
                    if has_explicit_output_field
                    else "DateTimeField",
                )
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                "Cannot truncate TimeField '%s' to %s."
                % (
                    field.name,
                    output_field.__class__.__name__
                    if has_explicit_output_field
                    else "DateTimeField",
                )
            )
        return copy

    def convert_value(self, value, expression, connection):
        # Normalize the database's return value to the declared output type.
        if isinstance(self.output_field, DateTimeField):
            if not settings.USE_TZ:
                pass
            elif value is not None:
                # Reinterpret the returned wall-clock time in the target zone.
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
            elif not connection.features.has_zoneinfo_database:
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
        elif isinstance(value, datetime):
            if value is None:
                # NOTE(review): dead branch — value passed isinstance(datetime)
                # above, so it can never be None here.
                pass
            elif isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
|  | ||||
|  | ||||
class Trunc(TruncBase):
    """Truncate a date/time expression to an arbitrary `kind` given at call time."""

    # RemovedInDjango50Warning: when the deprecation ends, remove is_dst
    # argument.
    def __init__(
        self,
        expression,
        kind,
        output_field=None,
        tzinfo=None,
        is_dst=timezone.NOT_PASSED,
        **extra,
    ):
        # Unlike the Trunc* subclasses, the truncation kind is an argument.
        self.kind = kind
        super().__init__(
            expression, output_field=output_field, tzinfo=tzinfo, is_dst=is_dst, **extra
        )
|  | ||||
|  | ||||
class TruncYear(TruncBase):
    """Truncate to the first moment of the year."""

    kind = "year"
|  | ||||
|  | ||||
class TruncQuarter(TruncBase):
    """Truncate to the first moment of the quarter."""

    kind = "quarter"
|  | ||||
|  | ||||
class TruncMonth(TruncBase):
    """Truncate to the first moment of the month."""

    kind = "month"
|  | ||||
|  | ||||
class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday of the week."""

    kind = "week"
|  | ||||
|  | ||||
class TruncDay(TruncBase):
    """Truncate to midnight of the day."""

    kind = "day"
|  | ||||
|  | ||||
class TruncDate(TruncBase):
    """Cast a datetime expression to its date part."""

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(self, compiler, connection):
        # Cast to date rather than truncate to date.
        sql, params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
|  | ||||
|  | ||||
class TruncTime(TruncBase):
    """Cast a datetime expression to its time part."""

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        # Cast to time rather than truncate to time.
        sql, params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
|  | ||||
|  | ||||
class TruncHour(TruncBase):
    """Truncate to the start of the hour."""

    kind = "hour"
|  | ||||
|  | ||||
class TruncMinute(TruncBase):
    """Truncate to the start of the minute."""

    kind = "minute"
|  | ||||
|  | ||||
class TruncSecond(TruncBase):
    """Truncate to the start of the second."""

    kind = "second"
|  | ||||
|  | ||||
# Enable `__date` and `__time` lookups on DateTimeField, e.g.
# .filter(timestamp__date=some_date).
DateTimeField.register_lookup(TruncDate)
DateTimeField.register_lookup(TruncTime)
| @ -0,0 +1,212 @@ | ||||
| import math | ||||
|  | ||||
| from django.db.models.expressions import Func, Value | ||||
| from django.db.models.fields import FloatField, IntegerField | ||||
| from django.db.models.functions import Cast | ||||
| from django.db.models.functions.mixins import ( | ||||
|     FixDecimalInputMixin, | ||||
|     NumericOutputFieldMixin, | ||||
| ) | ||||
| from django.db.models.lookups import Transform | ||||
|  | ||||
|  | ||||
class Abs(Transform):
    """Return the absolute value (SQL ABS())."""

    function = "ABS"
    lookup_name = "abs"
|  | ||||
|  | ||||
class ACos(NumericOutputFieldMixin, Transform):
    """Return the arccosine (SQL ACOS())."""

    function = "ACOS"
    lookup_name = "acos"
|  | ||||
|  | ||||
class ASin(NumericOutputFieldMixin, Transform):
    """Return the arcsine (SQL ASIN())."""

    function = "ASIN"
    lookup_name = "asin"
|  | ||||
|  | ||||
class ATan(NumericOutputFieldMixin, Transform):
    """Return the arctangent (SQL ATAN())."""

    function = "ATAN"
    lookup_name = "atan"
|  | ||||
|  | ||||
class ATan2(NumericOutputFieldMixin, Func):
    """Return the arctangent of y / x (SQL ATAN2(y, x))."""

    function = "ATAN2"
    arity = 2

    def as_sqlite(self, compiler, connection, **extra_context):
        if not getattr(
            connection.ops, "spatialite", False
        ) or connection.ops.spatial_version >= (5, 0, 0):
            return self.as_sql(compiler, connection)
        # This function is usually ATan2(y, x), returning the inverse tangent
        # of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
        # Cast integers to float to avoid inconsistent/buggy behavior if the
        # arguments are mixed between integer and float or decimal.
        # https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
        clone = self.copy()
        clone.set_source_expressions(
            [
                Cast(expression, FloatField())
                if isinstance(expression.output_field, IntegerField)
                else expression
                # [::-1] swaps the arguments for SpatiaLite's reversed order.
                for expression in self.get_source_expressions()[::-1]
            ]
        )
        return clone.as_sql(compiler, connection, **extra_context)
|  | ||||
|  | ||||
class Ceil(Transform):
    """Round up to the nearest integer (SQL CEILING())."""

    function = "CEILING"
    lookup_name = "ceil"

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle spells the function CEIL rather than CEILING.
        return super().as_sql(compiler, connection, function="CEIL", **extra_context)
|  | ||||
|  | ||||
class Cos(NumericOutputFieldMixin, Transform):
    """Return the cosine (SQL COS())."""

    function = "COS"
    lookup_name = "cos"
|  | ||||
|  | ||||
class Cot(NumericOutputFieldMixin, Transform):
    """Return the cotangent (SQL COT())."""

    function = "COT"
    lookup_name = "cot"

    def as_oracle(self, compiler, connection, **extra_context):
        # Emit 1 / TAN(x) instead of calling COT on Oracle.
        return super().as_sql(
            compiler, connection, template="(1 / TAN(%(expressions)s))", **extra_context
        )
|  | ||||
|  | ||||
class Degrees(NumericOutputFieldMixin, Transform):
    """Convert radians to degrees (SQL DEGREES())."""

    function = "DEGREES"
    lookup_name = "degrees"

    def as_oracle(self, compiler, connection, **extra_context):
        # Compute via the identity deg = rad * 180 / pi on Oracle.
        return super().as_sql(
            compiler,
            connection,
            template="((%%(expressions)s) * 180 / %s)" % math.pi,
            **extra_context,
        )
|  | ||||
|  | ||||
class Exp(NumericOutputFieldMixin, Transform):
    """Return e raised to the input (SQL EXP())."""

    function = "EXP"
    lookup_name = "exp"
|  | ||||
|  | ||||
class Floor(Transform):
    """Round down to the nearest integer (SQL FLOOR())."""

    function = "FLOOR"
    lookup_name = "floor"
|  | ||||
|  | ||||
class Ln(NumericOutputFieldMixin, Transform):
    """Return the natural logarithm (SQL LN())."""

    function = "LN"
    lookup_name = "ln"
|  | ||||
|  | ||||
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """Return the logarithm of x to base b (SQL LOG(b, x))."""

    function = "LOG"
    arity = 2

    def as_sqlite(self, compiler, connection, **extra_context):
        if not getattr(connection.ops, "spatialite", False):
            return self.as_sql(compiler, connection)
        # This function is usually Log(b, x) returning the logarithm of x to
        # the base b, but on SpatiaLite it's Log(x, b).
        clone = self.copy()
        clone.set_source_expressions(self.get_source_expressions()[::-1])
        return clone.as_sql(compiler, connection, **extra_context)
|  | ||||
|  | ||||
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """Return the remainder of division (SQL MOD(a, b))."""

    function = "MOD"
    arity = 2
|  | ||||
|  | ||||
class Pi(NumericOutputFieldMixin, Func):
    """Return the constant pi (SQL PI())."""

    function = "PI"
    arity = 0

    def as_oracle(self, compiler, connection, **extra_context):
        # Inline the numeric literal instead of calling a PI() function.
        return super().as_sql(
            compiler, connection, template=str(math.pi), **extra_context
        )
|  | ||||
|  | ||||
class Power(NumericOutputFieldMixin, Func):
    """Return a raised to the power b (SQL POWER(a, b))."""

    function = "POWER"
    arity = 2
|  | ||||
|  | ||||
class Radians(NumericOutputFieldMixin, Transform):
    """Convert degrees to radians (SQL RADIANS())."""

    function = "RADIANS"
    lookup_name = "radians"

    def as_oracle(self, compiler, connection, **extra_context):
        # Compute via the identity rad = deg * pi / 180 on Oracle.
        return super().as_sql(
            compiler,
            connection,
            template="((%%(expressions)s) * %s / 180)" % math.pi,
            **extra_context,
        )
|  | ||||
|  | ||||
class Random(NumericOutputFieldMixin, Func):
    """Return a random number (backend-specific function)."""

    function = "RANDOM"
    arity = 0

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="RAND", **extra_context)

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="DBMS_RANDOM.VALUE", **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        # NOTE(review): presumably relies on a RAND() function registered on
        # the SQLite connection by the backend — confirm.
        return super().as_sql(compiler, connection, function="RAND", **extra_context)

    def get_group_by_cols(self):
        # Never include the random expression in GROUP BY.
        return []
|  | ||||
|  | ||||
class Round(FixDecimalInputMixin, Transform):
    """Round to `precision` decimal places (SQL ROUND(expr, precision))."""

    function = "ROUND"
    lookup_name = "round"
    arity = None  # Override Transform's arity=1 to enable passing precision.

    def __init__(self, expression, precision=0, **extra):
        super().__init__(expression, precision, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        # Reject literal negative precision up front (unsupported on SQLite).
        precision = self.get_source_expressions()[1]
        if isinstance(precision, Value) and precision.value < 0:
            raise ValueError("SQLite does not support negative precision.")
        return super().as_sqlite(compiler, connection, **extra_context)

    def _resolve_output_field(self):
        # Rounding preserves the input expression's field type.
        source = self.get_source_expressions()[0]
        return source.output_field
|  | ||||
|  | ||||
class Sign(Transform):
    """Return the sign of the input (SQL SIGN())."""

    function = "SIGN"
    lookup_name = "sign"
|  | ||||
|  | ||||
class Sin(NumericOutputFieldMixin, Transform):
    """Return the sine (SQL SIN())."""

    function = "SIN"
    lookup_name = "sin"
|  | ||||
|  | ||||
class Sqrt(NumericOutputFieldMixin, Transform):
    """Return the square root (SQL SQRT())."""

    function = "SQRT"
    lookup_name = "sqrt"
|  | ||||
|  | ||||
class Tan(NumericOutputFieldMixin, Transform):
    """Return the tangent (SQL TAN())."""

    function = "TAN"
    lookup_name = "tan"
| @ -0,0 +1,57 @@ | ||||
| import sys | ||||
|  | ||||
| from django.db.models.fields import DecimalField, FloatField, IntegerField | ||||
| from django.db.models.functions import Cast | ||||
|  | ||||
|  | ||||
class FixDecimalInputMixin:
    """Cast float arguments to decimals for PostgreSQL-only signatures."""

    def as_postgresql(self, compiler, connection, **extra_context):
        # Cast FloatField to DecimalField as PostgreSQL doesn't support the
        # following function signatures:
        # - LOG(double, double)
        # - MOD(double, double)
        target = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
        fixed = []
        for source in self.get_source_expressions():
            if isinstance(source.output_field, FloatField):
                source = Cast(source, target)
            fixed.append(source)
        clone = self.copy()
        clone.set_source_expressions(fixed)
        return clone.as_sql(compiler, connection, **extra_context)
|  | ||||
|  | ||||
class FixDurationInputMixin:
    """Work around backends without a native duration column type."""

    def as_mysql(self, compiler, connection, **extra_context):
        # Cast the duration result to a signed integer on MySQL.
        # NOTE(review): assumes the result is numeric microseconds — confirm
        # against the backend's duration handling.
        sql, params = super().as_sql(compiler, connection, **extra_context)
        if self.output_field.get_internal_type() == "DurationField":
            sql = "CAST(%s AS SIGNED)" % sql
        return sql, params

    def as_oracle(self, compiler, connection, **extra_context):
        # On Oracle, convert the interval to seconds, apply this function to
        # the seconds, then convert the result back to an interval.
        if self.output_field.get_internal_type() == "DurationField":
            expression = self.get_source_expressions()[0]
            options = self._get_repr_options()
            # Imported locally to avoid loading the Oracle backend elsewhere.
            from django.db.backends.oracle.functions import (
                IntervalToSeconds,
                SecondsToInterval,
            )

            return compiler.compile(
                SecondsToInterval(
                    self.__class__(IntervalToSeconds(expression), **options)
                )
            )
        return super().as_sql(compiler, connection, **extra_context)
|  | ||||
|  | ||||
class NumericOutputFieldMixin:
    """Infer a numeric output field from the source expressions' fields."""

    def _resolve_output_field(self):
        sources = self.get_source_fields()
        # No sources at all (e.g. zero-arity functions): default to float.
        if not sources:
            return FloatField()
        # Any decimal input promotes the whole result to decimal.
        if any(isinstance(field, DecimalField) for field in sources):
            return DecimalField()
        # Integer-only inputs still produce a float result in SQL.
        if any(isinstance(field, IntegerField) for field in sources):
            return FloatField()
        return super()._resolve_output_field()
| @ -0,0 +1,365 @@ | ||||
| from django.db import NotSupportedError | ||||
| from django.db.models.expressions import Func, Value | ||||
| from django.db.models.fields import CharField, IntegerField, TextField | ||||
| from django.db.models.functions import Cast, Coalesce | ||||
| from django.db.models.lookups import Transform | ||||
|  | ||||
|  | ||||
class MySQLSHA2Mixin:
    """Rewrite SHA2-family functions as MySQL's SHA2(expr, bit_length)."""

    def as_mysql(self, compiler, connection, **extra_context):
        # The digest size comes from the function name, e.g. "SHA256"[3:]
        # -> "256".
        return super().as_sql(
            compiler,
            connection,
            template="SHA2(%%(expressions)s, %s)" % self.function[3:],
            **extra_context,
        )
|  | ||||
|  | ||||
class OracleHashMixin:
    """Hash via Oracle's STANDARD_HASH, returning lowercase hex."""

    def as_oracle(self, compiler, connection, **extra_context):
        # Convert the text to UTF-8 raw bytes before hashing so results match
        # other backends, then hex-encode and lowercase the digest.
        return super().as_sql(
            compiler,
            connection,
            template=(
                "LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
                "%(expressions)s, 'AL32UTF8'), '%(function)s')))"
            ),
            **extra_context,
        )
|  | ||||
|  | ||||
class PostgreSQLSHAMixin:
    """Hash via PostgreSQL's DIGEST (pgcrypto), returning hex."""

    def as_postgresql(self, compiler, connection, **extra_context):
        # DIGEST expects a lowercase algorithm name, e.g. 'sha256'.
        return super().as_sql(
            compiler,
            connection,
            template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
            function=self.function.lower(),
            **extra_context,
        )
|  | ||||
|  | ||||
class Chr(Transform):
    """Return the character for the given code point (SQL CHR())."""

    function = "CHR"
    lookup_name = "chr"

    def as_mysql(self, compiler, connection, **extra_context):
        # MySQL spells it CHAR() and needs an explicit charset for non-ASCII
        # code points.
        return super().as_sql(
            compiler,
            connection,
            function="CHAR",
            template="%(function)s(%(expressions)s USING utf16)",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        # Use the national character set on Oracle.
        return super().as_sql(
            compiler,
            connection,
            template="%(function)s(%(expressions)s USING NCHAR_CS)",
            **extra_context,
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        # SQLite spells the function CHAR().
        return super().as_sql(compiler, connection, function="CHAR", **extra_context)
|  | ||||
|  | ||||
class ConcatPair(Func):
    """
    Concatenate two arguments together. This is used by `Concat` because not
    all backend databases support more than two arguments.
    """

    function = "CONCAT"

    def as_sqlite(self, compiler, connection, **extra_context):
        # SQLite's || operator yields NULL if either side is NULL, so
        # COALESCE both sides first, then join with ||.
        coalesced = self.coalesce()
        return super(ConcatPair, coalesced).as_sql(
            compiler,
            connection,
            template="%(expressions)s",
            arg_joiner=" || ",
            **extra_context,
        )

    def as_postgresql(self, compiler, connection, **extra_context):
        # PostgreSQL's CONCAT wants text arguments; cast both sides.
        clone = self.copy()
        clone.set_source_expressions(
            [Cast(expr, TextField()) for expr in clone.get_source_expressions()]
        )
        return super(ConcatPair, clone).as_sql(
            compiler,
            connection,
            **extra_context,
        )

    def as_mysql(self, compiler, connection, **extra_context):
        # Use CONCAT_WS with an empty separator so that NULLs are ignored.
        return super().as_sql(
            compiler,
            connection,
            function="CONCAT_WS",
            template="%(function)s('', %(expressions)s)",
            **extra_context,
        )

    def coalesce(self):
        # NULL on either side would make the whole expression NULL; return a
        # copy with each argument wrapped in COALESCE(expr, '').
        clone = self.copy()
        clone.set_source_expressions(
            [Coalesce(expr, Value("")) for expr in clone.get_source_expressions()]
        )
        return clone
|  | ||||
|  | ||||
class Concat(Func):
    """
    Concatenate text fields together. Backends that result in an entire
    null expression when any arguments are null will wrap each argument in
    coalesce functions to ensure a non-null result.
    """

    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        super().__init__(self._paired(expressions), **extra)

    def _paired(self, expressions):
        # Fold the arguments into nested two-argument pairs, e.g.
        # [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))).
        head, *tail = expressions
        if len(tail) == 1:
            return ConcatPair(head, tail[0])
        return ConcatPair(head, self._paired(tail))
|  | ||||
|  | ||||
class Left(Func):
    """Return the first `length` characters of the string expression."""

    function = "LEFT"
    arity = 2
    output_field = CharField()

    def __init__(self, expression, length, **extra):
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the string
        """
        is_expr = hasattr(length, "resolve_expression")
        if not is_expr and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self):
        # Equivalent SUBSTR(expr, 1, length) for backends without LEFT().
        expression, length = self.source_expressions
        return Substr(expression, Value(1), length)

    def as_oracle(self, compiler, connection, **extra_context):
        return self.get_substr().as_oracle(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        return self.get_substr().as_sqlite(compiler, connection, **extra_context)
|  | ||||
|  | ||||
class Length(Transform):
    """Return the number of characters in the expression."""

    function = "LENGTH"
    lookup_name = "length"
    output_field = IntegerField()

    def as_mysql(self, compiler, connection, **extra_context):
        # MySQL: CHAR_LENGTH() counts characters (LENGTH() counts bytes).
        return super().as_sql(
            compiler, connection, function="CHAR_LENGTH", **extra_context
        )
|  | ||||
|  | ||||
class Lower(Transform):
    """Lowercase the string expression (SQL LOWER())."""

    function = "LOWER"
    lookup_name = "lower"
|  | ||||
|  | ||||
class LPad(Func):
    """Left-pad the expression to `length` characters using `fill_text`."""

    function = "LPAD"
    output_field = CharField()

    def __init__(self, expression, length, fill_text=Value(" "), **extra):
        is_expr = hasattr(length, "resolve_expression")
        if not is_expr and length is not None and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)
|  | ||||
|  | ||||
class LTrim(Transform):
    """Strip leading whitespace from the expression (SQL LTRIM())."""

    function = "LTRIM"
    lookup_name = "ltrim"
|  | ||||
|  | ||||
class MD5(OracleHashMixin, Transform):
    """MD5 hash of the expression."""

    function = "MD5"
    lookup_name = "md5"
|  | ||||
|  | ||||
class Ord(Transform):
    """Return the code point of the first character of the expression."""

    function = "ASCII"
    lookup_name = "ord"
    output_field = IntegerField()

    def as_mysql(self, compiler, connection, **extra_context):
        # MySQL spells this ORD().
        return super().as_sql(compiler, connection, function="ORD", **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        # SQLite spells this UNICODE().
        return super().as_sql(compiler, connection, function="UNICODE", **extra_context)
|  | ||||
|  | ||||
class Repeat(Func):
    """Repeat the string expression `number` times."""

    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression, number, **extra):
        is_expr = hasattr(number, "resolve_expression")
        if not is_expr and number is not None and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle has no REPEAT(); emulate it by right-padding the expression
        # with itself up to LENGTH(expression) * number characters.
        expression, number = self.source_expressions
        if number is None:
            length = None
        else:
            length = Length(expression) * number
        return RPad(expression, length, expression).as_sql(
            compiler, connection, **extra_context
        )
|  | ||||
|  | ||||
class Replace(Func):
    """Replace occurrences of `text` in the expression with `replacement`."""

    function = "REPLACE"

    def __init__(self, expression, text, replacement=Value(""), **extra):
        # Default replacement is the empty string, i.e. removal.
        super().__init__(expression, text, replacement, **extra)
|  | ||||
|  | ||||
class Reverse(Transform):
    """Reverse the characters of the string expression."""

    function = "REVERSE"
    lookup_name = "reverse"

    def as_oracle(self, compiler, connection, **extra_context):
        # REVERSE in Oracle is undocumented and doesn't support multi-byte
        # strings. Use a special subquery instead: split the string into
        # single characters with SUBSTR/CONNECT BY LEVEL, then reassemble
        # them in descending position order with LISTAGG.
        return super().as_sql(
            compiler,
            connection,
            template=(
                "(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
                "(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s "
                "FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
                "GROUP BY %(expressions)s)"
            ),
            **extra_context,
        )
|  | ||||
|  | ||||
class Right(Left):
    """Return the last `length` characters of the string expression."""

    function = "RIGHT"

    def get_substr(self):
        # A negative SUBSTR position counts backward from the end.
        expression, length = self.source_expressions
        return Substr(expression, length * Value(-1))
|  | ||||
|  | ||||
class RPad(LPad):
    """Right-pad variant of LPad; same argument validation applies."""

    function = "RPAD"
|  | ||||
|  | ||||
class RTrim(Transform):
    """Strip trailing whitespace from the expression (SQL RTRIM())."""

    function = "RTRIM"
    lookup_name = "rtrim"
|  | ||||
|  | ||||
class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """SHA-1 hash of the expression."""

    function = "SHA1"
    lookup_name = "sha1"
|  | ||||
|  | ||||
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-224 hash of the expression."""

    function = "SHA224"
    lookup_name = "sha224"

    def as_oracle(self, compiler, connection, **extra_context):
        # No OracleHashMixin here: Oracle offers no SHA-224 digest.
        raise NotSupportedError("SHA224 is not supported on Oracle.")
|  | ||||
|  | ||||
class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """SHA-256 hash of the expression."""

    function = "SHA256"
    lookup_name = "sha256"
|  | ||||
|  | ||||
class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """SHA-384 hash of the expression."""

    function = "SHA384"
    lookup_name = "sha384"
|  | ||||
|  | ||||
class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """SHA-512 hash of the expression."""

    function = "SHA512"
    lookup_name = "sha512"
|  | ||||
|  | ||||
class StrIndex(Func):
    """
    Return a positive integer corresponding to the 1-indexed position of the
    first occurrence of a substring inside another string, or 0 if the
    substring is not found.
    """

    function = "INSTR"
    arity = 2
    output_field = IntegerField()

    def as_postgresql(self, compiler, connection, **extra_context):
        # PostgreSQL has no INSTR; STRPOS provides the same behavior.
        return super().as_sql(compiler, connection, function="STRPOS", **extra_context)
|  | ||||
|  | ||||
class Substr(Func):
    """Return a substring of the expression starting at `pos`."""

    function = "SUBSTRING"
    output_field = CharField()

    def __init__(self, expression, pos, length=None, **extra):
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return
        """
        if not hasattr(pos, "resolve_expression") and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        # The length argument is only passed through when given.
        args = [expression, pos] if length is None else [expression, pos, length]
        super().__init__(*args, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        # SQLite and Oracle both spell this function SUBSTR.
        return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
|  | ||||
|  | ||||
class Trim(Transform):
    """Strip leading and trailing whitespace (SQL TRIM())."""

    function = "TRIM"
    lookup_name = "trim"
|  | ||||
|  | ||||
class Upper(Transform):
    """Uppercase the string expression (SQL UPPER())."""

    function = "UPPER"
    lookup_name = "upper"
| @ -0,0 +1,120 @@ | ||||
| from django.db.models.expressions import Func | ||||
| from django.db.models.fields import FloatField, IntegerField | ||||
|  | ||||
| __all__ = [ | ||||
|     "CumeDist", | ||||
|     "DenseRank", | ||||
|     "FirstValue", | ||||
|     "Lag", | ||||
|     "LastValue", | ||||
|     "Lead", | ||||
|     "NthValue", | ||||
|     "Ntile", | ||||
|     "PercentRank", | ||||
|     "Rank", | ||||
|     "RowNumber", | ||||
| ] | ||||
|  | ||||
|  | ||||
class CumeDist(Func):
    """Cumulative distribution of a row within its window (CUME_DIST())."""

    function = "CUME_DIST"
    output_field = FloatField()
    window_compatible = True
|  | ||||
|  | ||||
class DenseRank(Func):
    """Rank of a row without gaps for ties (DENSE_RANK())."""

    function = "DENSE_RANK"
    output_field = IntegerField()
    window_compatible = True
|  | ||||
|  | ||||
class FirstValue(Func):
    """Value of the expression at the first row of the window frame."""

    arity = 1
    function = "FIRST_VALUE"
    window_compatible = True
|  | ||||
|  | ||||
class LagLeadFunction(Func):
    """Shared validation and output typing for the LAG()/LEAD() functions."""

    window_compatible = True

    def __init__(self, expression, offset=1, default=None, **extra):
        if expression is None:
            raise ValueError(
                "%s requires a non-null source expression." % self.__class__.__name__
            )
        if offset is None or offset <= 0:
            raise ValueError(
                "%s requires a positive integer for the offset."
                % self.__class__.__name__
            )
        # `default` is optional; only forward it when supplied.
        args = [expression, offset]
        if default is not None:
            args.append(default)
        super().__init__(*args, **extra)

    def _resolve_output_field(self):
        # The result has the same type as the source expression.
        return self.get_source_expressions()[0].output_field
|  | ||||
|  | ||||
class Lag(LagLeadFunction):
    """Value of the expression `offset` rows before the current row."""

    function = "LAG"
|  | ||||
|  | ||||
class LastValue(Func):
    """Value of the expression at the last row of the window frame."""

    arity = 1
    function = "LAST_VALUE"
    window_compatible = True
|  | ||||
|  | ||||
class Lead(LagLeadFunction):
    """Value of the expression `offset` rows after the current row."""

    function = "LEAD"
|  | ||||
|  | ||||
class NthValue(Func):
    """Value of the expression at the nth row of the window frame."""

    function = "NTH_VALUE"
    window_compatible = True

    def __init__(self, expression, nth=1, **extra):
        if expression is None:
            raise ValueError(
                "%s requires a non-null source expression." % self.__class__.__name__
            )
        if nth is None or nth <= 0:
            raise ValueError(
                "%s requires a positive integer as for nth." % self.__class__.__name__
            )
        super().__init__(expression, nth, **extra)

    def _resolve_output_field(self):
        # The result has the same type as the source expression.
        return self.get_source_expressions()[0].output_field
|  | ||||
|  | ||||
class Ntile(Func):
    """Distribute rows of the window into `num_buckets` ranked groups."""

    function = "NTILE"
    output_field = IntegerField()
    window_compatible = True

    def __init__(self, num_buckets=1, **extra):
        if num_buckets <= 0:
            raise ValueError("num_buckets must be greater than 0.")
        super().__init__(num_buckets, **extra)
|  | ||||
|  | ||||
class PercentRank(Func):
    """Relative rank of a row: (rank - 1) / (rows - 1) (PERCENT_RANK())."""

    function = "PERCENT_RANK"
    output_field = FloatField()
    window_compatible = True
|  | ||||
|  | ||||
class Rank(Func):
    """Rank of a row with gaps for ties (RANK())."""

    function = "RANK"
    output_field = IntegerField()
    window_compatible = True
|  | ||||
|  | ||||
class RowNumber(Func):
    """Sequential number of the row within its window (ROW_NUMBER())."""

    function = "ROW_NUMBER"
    output_field = IntegerField()
    window_compatible = True
| @ -0,0 +1,295 @@ | ||||
| from django.db.backends.utils import names_digest, split_identifier | ||||
| from django.db.models.expressions import Col, ExpressionList, F, Func, OrderBy | ||||
| from django.db.models.functions import Collate | ||||
| from django.db.models.query_utils import Q | ||||
| from django.db.models.sql import Query | ||||
| from django.utils.functional import partition | ||||
|  | ||||
| __all__ = ["Index"] | ||||
|  | ||||
|  | ||||
class Index:
    """
    Declarative database index.

    Defined either by ``fields`` (field names, each optionally prefixed
    with "-" for descending order) or by positional ``expressions``
    (strings become F() references); the two are mutually exclusive.
    """

    suffix = "idx"
    # The max length of the name of the index (restricted to 30 for
    # cross-database compatibility with Oracle)
    max_name_length = 30

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        db_tablespace=None,
        opclasses=(),
        condition=None,
        include=None,
    ):
        # Many options require a named index; validate every combination
        # up front so errors surface at definition time, not migration time.
        if opclasses and not name:
            raise ValueError("An index must be named to use opclasses.")
        if not isinstance(condition, (type(None), Q)):
            raise ValueError("Index.condition must be a Q instance.")
        if condition and not name:
            raise ValueError("An index must be named to use condition.")
        if not isinstance(fields, (list, tuple)):
            raise ValueError("Index.fields must be a list or tuple.")
        if not isinstance(opclasses, (list, tuple)):
            raise ValueError("Index.opclasses must be a list or tuple.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define an index."
            )
        if expressions and fields:
            raise ValueError(
                "Index.fields and expressions are mutually exclusive.",
            )
        if expressions and not name:
            raise ValueError("An index must be named to use expressions.")
        if expressions and opclasses:
            raise ValueError(
                "Index.opclasses cannot be used with expressions. Use "
                "django.contrib.postgres.indexes.OpClass() instead."
            )
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "Index.fields and Index.opclasses must have the same number of "
                "elements."
            )
        if fields and not all(isinstance(field, str) for field in fields):
            raise ValueError("Index.fields must contain only strings with field names.")
        if include and not name:
            raise ValueError("A covering index must be named.")
        if not isinstance(include, (type(None), list, tuple)):
            raise ValueError("Index.include must be a list or tuple.")
        self.fields = list(fields)
        # A list of 2-tuple with the field name and ordering ('' or 'DESC').
        self.fields_orders = [
            (field_name[1:], "DESC") if field_name.startswith("-") else (field_name, "")
            for field_name in self.fields
        ]
        self.name = name or ""
        self.db_tablespace = db_tablespace
        self.opclasses = opclasses
        self.condition = condition
        self.include = tuple(include) if include else ()
        # Normalize string expressions to F() references.
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )

    @property
    def contains_expressions(self):
        # True when this is an expression-based (functional) index.
        return bool(self.expressions)

    def _get_condition_sql(self, model, schema_editor):
        """Return the WHERE-clause SQL for a partial index, or None."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        # Inline the parameters: DDL statements can't use placeholders.
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def create_sql(self, model, schema_editor, using="", **kwargs):
        """Return the statement that creates this index for `model`."""
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        if self.expressions:
            # Expression-based index: wrap each expression so ordering and
            # collation wrappers are validated/arranged per backend.
            index_expressions = []
            for expression in self.expressions:
                index_expression = IndexExpression(expression)
                index_expression.set_wrapper_classes(schema_editor.connection)
                index_expressions.append(index_expression)
            expressions = ExpressionList(*index_expressions).resolve_expression(
                Query(model, alias_cols=False),
            )
            fields = None
            col_suffixes = None
        else:
            # Field-based index: resolve names to fields and per-column
            # ordering suffixes (if the backend supports column ordering).
            fields = [
                model._meta.get_field(field_name)
                for field_name, _ in self.fields_orders
            ]
            if schema_editor.connection.features.supports_index_column_ordering:
                col_suffixes = [order[1] for order in self.fields_orders]
            else:
                col_suffixes = [""] * len(self.fields_orders)
            expressions = None
        return schema_editor._create_index_sql(
            model,
            fields=fields,
            name=self.name,
            using=using,
            db_tablespace=self.db_tablespace,
            col_suffixes=col_suffixes,
            opclasses=self.opclasses,
            condition=condition,
            include=include,
            expressions=expressions,
            **kwargs,
        )

    def remove_sql(self, model, schema_editor, **kwargs):
        """Return the statement that drops this index from `model`."""
        return schema_editor._delete_index_sql(model, self.name, **kwargs)

    def deconstruct(self):
        """
        Return a 3-tuple (path, args, kwargs) from which this index can be
        reconstructed, for use in migration serialization.
        """
        path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
        # Present the public import path instead of the defining module.
        path = path.replace("django.db.models.indexes", "django.db.models")
        kwargs = {"name": self.name}
        if self.fields:
            kwargs["fields"] = self.fields
        if self.db_tablespace is not None:
            kwargs["db_tablespace"] = self.db_tablespace
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        if self.condition:
            kwargs["condition"] = self.condition
        if self.include:
            kwargs["include"] = self.include
        return (path, self.expressions, kwargs)

    def clone(self):
        """Create a copy of this Index."""
        _, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)

    def set_name_with_model(self, model):
        """
        Generate a unique name for the index.

        The name is divided into 3 parts - table name (12 chars), field name
        (8 chars) and unique hash + suffix (10 chars). Each part is made to
        fit its size by truncating the excess length.
        """
        _, table_name = split_identifier(model._meta.db_table)
        column_names = [
            model._meta.get_field(field_name).column
            for field_name, order in self.fields_orders
        ]
        column_names_with_order = [
            (("-%s" if order else "%s") % column_name)
            for column_name, (field_name, order) in zip(
                column_names, self.fields_orders
            )
        ]
        # The length of the parts of the name is based on the default max
        # length of 30 characters.
        hash_data = [table_name] + column_names_with_order + [self.suffix]
        self.name = "%s_%s_%s" % (
            table_name[:11],
            column_names[0][:7],
            "%s_%s" % (names_digest(*hash_data, length=6), self.suffix),
        )
        if len(self.name) > self.max_name_length:
            raise ValueError(
                "Index too long for multiple database support. Is self.suffix "
                "longer than 3 characters?"
            )
        # Names starting with '_' or a digit aren't portable identifiers;
        # replace the first character with 'D'.
        if self.name[0] == "_" or self.name[0].isdigit():
            self.name = "D%s" % self.name[1:]

    def __repr__(self):
        return "<%s:%s%s%s%s%s%s%s>" % (
            self.__class__.__qualname__,
            "" if not self.fields else " fields=%s" % repr(self.fields),
            "" if not self.expressions else " expressions=%s" % repr(self.expressions),
            "" if not self.name else " name=%s" % repr(self.name),
            ""
            if self.db_tablespace is None
            else " db_tablespace=%s" % repr(self.db_tablespace),
            "" if self.condition is None else " condition=%s" % self.condition,
            "" if not self.include else " include=%s" % repr(self.include),
            "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
        )

    def __eq__(self, other):
        # Two indexes are equal when they deconstruct identically.
        if self.__class__ == other.__class__:
            return self.deconstruct() == other.deconstruct()
        return NotImplemented
|  | ||||
|  | ||||
class IndexExpression(Func):
    """Order and wrap expressions for CREATE INDEX statements."""

    template = "%(expressions)s"
    # Expression types allowed to wrap the indexed expression, outermost
    # first; also used as the sort order when arranging wrappers.
    wrapper_classes = (OrderBy, Collate)

    def set_wrapper_classes(self, connection=None):
        # Some databases (e.g. MySQL) treats COLLATE as an indexed expression.
        if connection and connection.features.collate_as_index_expression:
            self.wrapper_classes = tuple(
                [
                    wrapper_cls
                    for wrapper_cls in self.wrapper_classes
                    if wrapper_cls is not Collate
                ]
            )

    @classmethod
    def register_wrappers(cls, *wrapper_classes):
        # Replace the allowed wrapper types class-wide (hook for backends
        # or subclasses that support additional wrappers).
        cls.wrapper_classes = wrapper_classes

    def resolve_expression(
        self,
        query=None,
        allow_joins=True,
        reuse=None,
        summarize=False,
        for_save=False,
    ):
        """
        Validate that each wrapper class is used at most once and that
        wrappers are the topmost expressions, then rearrange them around
        the root expression and resolve as usual.
        """
        expressions = list(self.flatten())
        # Split expressions and wrappers.
        index_expressions, wrappers = partition(
            lambda e: isinstance(e, self.wrapper_classes),
            expressions,
        )
        wrapper_types = [type(wrapper) for wrapper in wrappers]
        if len(wrapper_types) != len(set(wrapper_types)):
            raise ValueError(
                "Multiple references to %s can't be used in an indexed "
                "expression."
                % ", ".join(
                    [wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
                )
            )
        # Wrappers must directly follow self (position 0) in the flattened
        # tree; anything else means a wrapper is nested below a non-wrapper.
        if expressions[1 : len(wrappers) + 1] != wrappers:
            raise ValueError(
                "%s must be topmost expressions in an indexed expression."
                % ", ".join(
                    [wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
                )
            )
        # Wrap expressions in parentheses if they are not column references.
        root_expression = index_expressions[1]
        resolve_root_expression = root_expression.resolve_expression(
            query,
            allow_joins,
            reuse,
            summarize,
            for_save,
        )
        if not isinstance(resolve_root_expression, Col):
            root_expression = Func(root_expression, template="(%(expressions)s)")

        if wrappers:
            # Order wrappers and set their expressions.
            wrappers = sorted(
                wrappers,
                key=lambda w: self.wrapper_classes.index(type(w)),
            )
            wrappers = [wrapper.copy() for wrapper in wrappers]
            # Chain the wrappers: each wraps the next, outermost first.
            for i, wrapper in enumerate(wrappers[:-1]):
                wrapper.set_source_expressions([wrappers[i + 1]])
            # Set the root expression on the deepest wrapper.
            wrappers[-1].set_source_expressions([root_expression])
            self.set_source_expressions([wrappers[0]])
        else:
            # Use the root expression, if there are no wrappers.
            self.set_source_expressions([root_expression])
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        # Casting to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)
| @ -0,0 +1,727 @@ | ||||
| import itertools | ||||
| import math | ||||
|  | ||||
| from django.core.exceptions import EmptyResultSet, FullResultSet | ||||
| from django.db.models.expressions import Case, Expression, Func, Value, When | ||||
| from django.db.models.fields import ( | ||||
|     BooleanField, | ||||
|     CharField, | ||||
|     DateTimeField, | ||||
|     Field, | ||||
|     IntegerField, | ||||
|     UUIDField, | ||||
| ) | ||||
| from django.db.models.query_utils import RegisterLookupMixin | ||||
| from django.utils.datastructures import OrderedSet | ||||
| from django.utils.functional import cached_property | ||||
| from django.utils.hashable import make_hashable | ||||
|  | ||||
|  | ||||
| class Lookup(Expression): | ||||
|     lookup_name = None | ||||
|     prepare_rhs = True | ||||
|     can_use_none_as_rhs = False | ||||
|  | ||||
|     def __init__(self, lhs, rhs): | ||||
|         self.lhs, self.rhs = lhs, rhs | ||||
|         self.rhs = self.get_prep_lookup() | ||||
|         self.lhs = self.get_prep_lhs() | ||||
|         if hasattr(self.lhs, "get_bilateral_transforms"): | ||||
|             bilateral_transforms = self.lhs.get_bilateral_transforms() | ||||
|         else: | ||||
|             bilateral_transforms = [] | ||||
|         if bilateral_transforms: | ||||
|             # Warn the user as soon as possible if they are trying to apply | ||||
|             # a bilateral transformation on a nested QuerySet: that won't work. | ||||
|             from django.db.models.sql.query import Query  # avoid circular import | ||||
|  | ||||
|             if isinstance(rhs, Query): | ||||
|                 raise NotImplementedError( | ||||
|                     "Bilateral transformations on nested querysets are not implemented." | ||||
|                 ) | ||||
|         self.bilateral_transforms = bilateral_transforms | ||||
|  | ||||
|     def apply_bilateral_transforms(self, value): | ||||
|         for transform in self.bilateral_transforms: | ||||
|             value = transform(value) | ||||
|         return value | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})" | ||||
|  | ||||
|     def batch_process_rhs(self, compiler, connection, rhs=None): | ||||
|         if rhs is None: | ||||
|             rhs = self.rhs | ||||
|         if self.bilateral_transforms: | ||||
|             sqls, sqls_params = [], [] | ||||
|             for p in rhs: | ||||
|                 value = Value(p, output_field=self.lhs.output_field) | ||||
|                 value = self.apply_bilateral_transforms(value) | ||||
|                 value = value.resolve_expression(compiler.query) | ||||
|                 sql, sql_params = compiler.compile(value) | ||||
|                 sqls.append(sql) | ||||
|                 sqls_params.extend(sql_params) | ||||
|         else: | ||||
|             _, params = self.get_db_prep_lookup(rhs, connection) | ||||
|             sqls, sqls_params = ["%s"] * len(params), params | ||||
|         return sqls, sqls_params | ||||
|  | ||||
|     def get_source_expressions(self): | ||||
|         if self.rhs_is_direct_value(): | ||||
|             return [self.lhs] | ||||
|         return [self.lhs, self.rhs] | ||||
|  | ||||
|     def set_source_expressions(self, new_exprs): | ||||
|         if len(new_exprs) == 1: | ||||
|             self.lhs = new_exprs[0] | ||||
|         else: | ||||
|             self.lhs, self.rhs = new_exprs | ||||
|  | ||||
    def get_prep_lookup(self):
        """
        Prepare self.rhs for use in the lookup.

        Expressions (anything with resolve_expression()) pass through
        untouched, as does the rhs when prepare_rhs is False. Otherwise,
        defer to the lhs output_field's get_prep_value() when available.
        """
        if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
            return self.rhs
        if hasattr(self.lhs, "output_field"):
            if hasattr(self.lhs.output_field, "get_prep_value"):
                return self.lhs.output_field.get_prep_value(self.rhs)
        elif self.rhs_is_direct_value():
            # No output_field to prepare against; wrap the raw value so it
            # can still participate in expression compilation.
            return Value(self.rhs)
        return self.rhs
|  | ||||
|     def get_prep_lhs(self): | ||||
|         if hasattr(self.lhs, "resolve_expression"): | ||||
|             return self.lhs | ||||
|         return Value(self.lhs) | ||||
|  | ||||
|     def get_db_prep_lookup(self, value, connection): | ||||
|         return ("%s", [value]) | ||||
|  | ||||
    def process_lhs(self, compiler, connection, lhs=None):
        """
        Return a (sql, params) pair for the left-hand side.

        The lhs is resolved against the compiler's query if needed, and a
        nested Lookup is parenthesized to preserve operator precedence.
        """
        lhs = lhs or self.lhs
        if hasattr(lhs, "resolve_expression"):
            lhs = lhs.resolve_expression(compiler.query)
        sql, params = compiler.compile(lhs)
        if isinstance(lhs, Lookup):
            # Wrapped in parentheses to respect operator precedence.
            sql = f"({sql})"
        return sql, params
|  | ||||
    def process_rhs(self, compiler, connection):
        """
        Return a (sql, params) pair for the right-hand side, applying any
        bilateral transforms and compiling expression values.
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        if hasattr(value, "as_sql"):
            sql, params = compiler.compile(value)
            # Ensure expression is wrapped in parentheses to respect operator
            # precedence but avoid double wrapping as it can be misinterpreted
            # on some backends (e.g. subqueries on SQLite).
            if sql and sql[0] != "(":
                sql = "(%s)" % sql
            return sql, params
        else:
            return self.get_db_prep_lookup(value, connection)
|  | ||||
|     def rhs_is_direct_value(self): | ||||
|         return not hasattr(self.rhs, "as_sql") | ||||
|  | ||||
|     def get_group_by_cols(self): | ||||
|         cols = [] | ||||
|         for source in self.get_source_expressions(): | ||||
|             cols.extend(source.get_group_by_cols()) | ||||
|         return cols | ||||
|  | ||||
    def as_oracle(self, compiler, connection):
        # Oracle doesn't allow EXISTS() and filters to be compared to another
        # expression unless they're wrapped in a CASE WHEN.
        wrapped = False
        exprs = []
        for expr in (self.lhs, self.rhs):
            if connection.ops.conditional_expression_supported_in_where_clause(expr):
                # Wrap the conditional so it yields a comparable boolean.
                expr = Case(When(expr, then=True), default=False)
                wrapped = True
            exprs.append(expr)
        # Rebuild the lookup only if an operand was actually wrapped.
        lookup = type(self)(*exprs) if wrapped else self
        return lookup.as_sql(compiler, connection)
|  | ||||
    @cached_property
    def output_field(self):
        """A lookup resolves to a boolean result."""
        return BooleanField()
|  | ||||
    @property
    def identity(self):
        # Tuple used by __eq__ and __hash__ to compare lookup instances.
        return self.__class__, self.lhs, self.rhs
|  | ||||
    def __eq__(self, other):
        """Two lookups are equal when their class, lhs, and rhs match."""
        if not isinstance(other, Lookup):
            return NotImplemented
        return self.identity == other.identity
|  | ||||
    def __hash__(self):
        # make_hashable() converts unhashable members (e.g. lists) so the
        # identity tuple can be hashed consistently with __eq__.
        return hash(make_hashable(self.identity))
|  | ||||
    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Return a copy of this lookup with lhs (and an expression rhs)
        resolved against *query*, mirroring Expression.resolve_expression().
        """
        c = self.copy()
        c.is_summary = summarize
        c.lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if hasattr(self.rhs, "resolve_expression"):
            c.rhs = self.rhs.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return c
|  | ||||
|     def select_format(self, compiler, sql, params): | ||||
|         # Wrap filters with a CASE WHEN expression if a database backend | ||||
|         # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP | ||||
|         # BY list. | ||||
|         if not compiler.connection.features.supports_boolean_expr_in_select_clause: | ||||
|             sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END" | ||||
|         return sql, params | ||||
|  | ||||
|  | ||||
class Transform(RegisterLookupMixin, Func):
    """
    RegisterLookupMixin() is first so that get_lookup() and get_transform()
    first examine self and then check output_field.
    """

    bilateral = False
    arity = 1

    @property
    def lhs(self):
        """The single source expression this transform wraps."""
        sources = self.get_source_expressions()
        return sources[0]

    def get_bilateral_transforms(self):
        """
        Return the chain of bilateral transform classes from the innermost
        lhs outwards, appending this class when it is itself bilateral.
        """
        if hasattr(self.lhs, "get_bilateral_transforms"):
            transforms = self.lhs.get_bilateral_transforms()
        else:
            transforms = []
        if self.bilateral:
            transforms.append(self.__class__)
        return transforms
|  | ||||
|  | ||||
class BuiltinLookup(Lookup):
    """Lookup rendered as `<lhs> <backend operator> <rhs>` SQL."""

    def process_lhs(self, compiler, connection, lhs=None):
        """Compile the lhs and apply backend field/lookup casts to its SQL."""
        lhs_sql, params = super().process_lhs(compiler, connection, lhs)
        output_field = self.lhs.output_field
        internal_type = output_field.get_internal_type()
        db_type = output_field.db_type(connection=connection)
        # Backend-specific cast of the column itself ...
        lhs_sql = connection.ops.field_cast_sql(db_type, internal_type) % lhs_sql
        # ... plus an additional cast keyed on the lookup being performed.
        lhs_sql = connection.ops.lookup_cast(self.lookup_name, internal_type) % lhs_sql
        return lhs_sql, list(params)

    def as_sql(self, compiler, connection):
        """Combine the compiled lhs and rhs into the final comparison SQL."""
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        rhs_sql = self.get_rhs_op(connection, rhs_sql)
        return "%s %s" % (lhs_sql, rhs_sql), params

    def get_rhs_op(self, connection, rhs):
        """Fill the backend's operator template with the rhs SQL."""
        op_template = connection.operators[self.lookup_name]
        return op_template % rhs
|  | ||||
|  | ||||
class FieldGetDbPrepValueMixin:
    """
    Some lookups require Field.get_db_prep_value() to be called on their
    inputs.
    """

    # When True, get_db_prep_lookup() treats the value as an iterable and
    # prepares each element individually.
    get_db_prep_lookup_value_is_iterable = False

    def get_db_prep_lookup(self, value, connection):
        """Prepare the value(s) via the (target) field's get_db_prep_value()."""
        # For relational fields, use the 'target_field' attribute of the
        # output_field.
        output_field = self.lhs.output_field
        target = getattr(output_field, "target_field", None)
        prep = getattr(target, "get_db_prep_value", None)
        if prep is None:
            prep = output_field.get_db_prep_value
        if self.get_db_prep_lookup_value_is_iterable:
            prepared = [prep(v, connection, prepared=True) for v in value]
        else:
            prepared = [prep(value, connection, prepared=True)]
        return ("%s", prepared)
|  | ||||
|  | ||||
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Some lookups require Field.get_db_prep_value() to be called on each value
    in an iterable.
    """

    get_db_prep_lookup_value_is_iterable = True

    def get_prep_lookup(self):
        """
        Prepare each rhs value individually, passing expressions through
        untouched so they can coexist with plain values.
        """
        if hasattr(self.rhs, "resolve_expression"):
            return self.rhs
        prepared_values = []
        for rhs_value in self.rhs:
            if hasattr(rhs_value, "resolve_expression"):
                # An expression will be handled by the database but can coexist
                # alongside real values.
                pass
            elif self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):
                rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
            prepared_values.append(rhs_value)
        return prepared_values

    def process_rhs(self, compiler, connection):
        """Compile the rhs, batching when it is a direct iterable of values."""
        if self.rhs_is_direct_value():
            # rhs should be an iterable of values. Use batch_process_rhs()
            # to prepare/transform those values.
            return self.batch_process_rhs(compiler, connection)
        else:
            return super().process_rhs(compiler, connection)

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """
        Return (sql, params) for one rhs element, compiling it when it is an
        expression; plain values keep the given placeholder sql.
        """
        params = [param]
        if hasattr(param, "resolve_expression"):
            param = param.resolve_expression(compiler.query)
        if hasattr(param, "as_sql"):
            sql, params = compiler.compile(param)
        return sql, params

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Batch-compile the rhs, replacing any expression parameters with their
        compiled sql/params while keeping sql and params aligned.
        """
        pre_processed = super().batch_process_rhs(compiler, connection, rhs)
        # The params list may contain expressions which compile to a
        # sql/param pair. Zip them to get sql and param pairs that refer to the
        # same argument and attempt to replace them with the result of
        # compiling the param step.
        sql, params = zip(
            *(
                self.resolve_expression_parameter(compiler, connection, sql, param)
                for sql, param in zip(*pre_processed)
            )
        )
        params = itertools.chain.from_iterable(params)
        return sql, tuple(params)
|  | ||||
|  | ||||
class PostgresOperatorLookup(Lookup):
    """Lookup defined by operators on PostgreSQL."""

    # Subclasses set this to the PostgreSQL operator string (e.g. "@>").
    postgres_operator = None

    def as_postgresql(self, compiler, connection):
        """Render `<lhs> <operator> <rhs>` using the PostgreSQL operator."""
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql = "%s %s %s" % (lhs_sql, self.postgres_operator, rhs_sql)
        return sql, tuple(lhs_params) + tuple(rhs_params)
|  | ||||
|  | ||||
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    """The `exact` lookup: equality against a value or a sliced queryset."""

    lookup_name = "exact"

    def get_prep_lookup(self):
        """
        Validate and normalize a Query rhs: it must be limited to one result,
        and defaults to selecting pk when no fields were selected.
        """
        from django.db.models.sql.query import Query  # avoid circular import

        if isinstance(self.rhs, Query):
            if self.rhs.has_limit_one():
                if not self.rhs.has_select_fields:
                    self.rhs.clear_select_clause()
                    self.rhs.add_fields(["pk"])
            else:
                raise ValueError(
                    "The QuerySet value for an exact lookup must be limited to "
                    "one result using slicing."
                )
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        # Avoid comparison against direct rhs if lhs is a boolean value. That
        # turns "boolfield__exact=True" into "WHERE boolean_field" instead of
        # "WHERE boolean_field = True" when allowed.
        if (
            isinstance(self.rhs, bool)
            and getattr(self.lhs, "conditional", False)
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.lhs
            )
        ):
            lhs_sql, params = self.process_lhs(compiler, connection)
            template = "%s" if self.rhs else "NOT %s"
            return template % lhs_sql, params
        return super().as_sql(compiler, connection)
|  | ||||
|  | ||||
@Field.register_lookup
class IExact(BuiltinLookup):
    """The `iexact` lookup: case-insensitive exact match."""

    lookup_name = "iexact"
    prepare_rhs = False

    def process_rhs(self, qn, connection):
        rhs, params = super().process_rhs(qn, connection)
        if params:
            # Backend-specific preparation (e.g. escaping) of the parameter
            # for a case-insensitive comparison.
            params[0] = connection.ops.prep_for_iexact_query(params[0])
        return rhs, params
|  | ||||
|  | ||||
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """The `gt` lookup (strictly greater than)."""

    lookup_name = "gt"
|  | ||||
|  | ||||
@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """The `gte` lookup (greater than or equal)."""

    lookup_name = "gte"
|  | ||||
|  | ||||
@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """The `lt` lookup (strictly less than)."""

    lookup_name = "lt"
|  | ||||
|  | ||||
@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """The `lte` lookup (less than or equal)."""

    lookup_name = "lte"
|  | ||||
|  | ||||
class IntegerFieldFloatRounding:
    """
    Allow floats to work as query values for IntegerField. Without this, the
    decimal portion of the float would always be discarded.
    """

    def get_prep_lookup(self):
        """Round a float rhs up to an int before the usual preparation."""
        rhs = self.rhs
        if isinstance(rhs, float):
            self.rhs = math.ceil(rhs)
        return super().get_prep_lookup()
|  | ||||
|  | ||||
@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
    """`gte` for IntegerField with float values rounded up first."""

    pass
|  | ||||
|  | ||||
@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
    """`lt` for IntegerField with float values rounded up first."""

    pass
|  | ||||
|  | ||||
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """The `in` lookup: membership in a list of values or a subquery."""

    lookup_name = "in"

    def get_prep_lookup(self):
        """Normalize a Query rhs: drop ordering, default to selecting pk."""
        from django.db.models.sql.query import Query  # avoid circular import

        if isinstance(self.rhs, Query):
            self.rhs.clear_ordering(clear_default=True)
            if not self.rhs.has_select_fields:
                self.rhs.clear_select_clause()
                self.rhs.add_fields(["pk"])
        return super().get_prep_lookup()

    def process_rhs(self, compiler, connection):
        """
        Compile direct values into a "(%s, %s, ...)" placeholder list, or
        defer to the default expression handling for subqueries.
        """
        db_rhs = getattr(self.rhs, "_db", None)
        if db_rhs is not None and db_rhs != connection.alias:
            raise ValueError(
                "Subqueries aren't allowed across different databases. Force "
                "the inner query to be evaluated using `list(inner_query)`."
            )

        if self.rhs_is_direct_value():
            # Remove None from the list as NULL is never equal to anything.
            try:
                rhs = OrderedSet(self.rhs)
                rhs.discard(None)
            except TypeError:  # Unhashable items in self.rhs
                rhs = [r for r in self.rhs if r is not None]

            if not rhs:
                # An empty IN () clause can never match any row.
                raise EmptyResultSet

            # rhs should be an iterable; use batch_process_rhs() to
            # prepare/transform those values.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
            placeholder = "(" + ", ".join(sqls) + ")"
            return (placeholder, sqls_params)
        return super().process_rhs(compiler, connection)

    def get_rhs_op(self, connection, rhs):
        return "IN %s" % rhs

    def as_sql(self, compiler, connection):
        """Split oversized value lists when the backend limits IN () size."""
        max_in_list_size = connection.ops.max_in_list_size()
        if (
            self.rhs_is_direct_value()
            and max_in_list_size
            and len(self.rhs) > max_in_list_size
        ):
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)

    def split_parameter_list_as_sql(self, compiler, connection):
        # This is a special case for databases which limit the number of
        # elements which can appear in an 'IN' clause.
        max_in_list_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.batch_process_rhs(compiler, connection)
        # Build "(lhs IN (...) OR lhs IN (...) OR ...)" with at most
        # max_in_list_size placeholders per chunk.
        in_clause_elements = ["("]
        params = []
        for offset in range(0, len(rhs_params), max_in_list_size):
            if offset > 0:
                in_clause_elements.append(" OR ")
            in_clause_elements.append("%s IN (" % lhs)
            params.extend(lhs_params)
            sqls = rhs[offset : offset + max_in_list_size]
            sqls_params = rhs_params[offset : offset + max_in_list_size]
            param_group = ", ".join(sqls)
            in_clause_elements.append(param_group)
            in_clause_elements.append(")")
            params.extend(sqls_params)
        in_clause_elements.append(")")
        return "".join(in_clause_elements), params
|  | ||||
|  | ||||
class PatternLookup(BuiltinLookup):
    """
    Base class for LIKE-based lookups (contains/startswith/endswith and
    their case-insensitive variants). param_pattern is the %-template
    applied to direct (Python) rhs values.
    """

    param_pattern = "%%%s%%"
    prepare_rhs = False

    def get_rhs_op(self, connection, rhs):
        # Assume we are in startswith. We need to produce SQL like:
        #     col LIKE %s, ['thevalue%']
        # For python values we can (and should) do that directly in Python,
        # but if the value is for example reference to other column, then
        # we need to add the % pattern match to the lookup by something like
        #     col LIKE othercol || '%%'
        # So, for Python values we don't need any special pattern, but for
        # SQL reference values or SQL transformations we need the correct
        # pattern added.
        if hasattr(self.rhs, "as_sql") or self.bilateral_transforms:
            pattern = connection.pattern_ops[self.lookup_name].format(
                connection.pattern_esc
            )
            return pattern.format(rhs)
        else:
            return super().get_rhs_op(connection, rhs)

    def process_rhs(self, qn, connection):
        """Apply param_pattern and LIKE-escaping to a direct rhs value."""
        rhs, params = super().process_rhs(qn, connection)
        if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
            params[0] = self.param_pattern % connection.ops.prep_for_like_query(
                params[0]
            )
        return rhs, params
|  | ||||
|  | ||||
@Field.register_lookup
class Contains(PatternLookup):
    """The `contains` lookup: case-sensitive substring match."""

    lookup_name = "contains"
|  | ||||
|  | ||||
@Field.register_lookup
class IContains(Contains):
    """The `icontains` lookup: case-insensitive substring match."""

    lookup_name = "icontains"
|  | ||||
|  | ||||
@Field.register_lookup
class StartsWith(PatternLookup):
    """The `startswith` lookup: case-sensitive prefix match."""

    lookup_name = "startswith"
    param_pattern = "%s%%"
|  | ||||
|  | ||||
@Field.register_lookup
class IStartsWith(StartsWith):
    """The `istartswith` lookup: case-insensitive prefix match."""

    lookup_name = "istartswith"
|  | ||||
|  | ||||
@Field.register_lookup
class EndsWith(PatternLookup):
    """The `endswith` lookup: case-sensitive suffix match."""

    lookup_name = "endswith"
    param_pattern = "%%%s"
|  | ||||
|  | ||||
@Field.register_lookup
class IEndsWith(EndsWith):
    """The `iendswith` lookup: case-insensitive suffix match."""

    lookup_name = "iendswith"
|  | ||||
|  | ||||
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """The `range` lookup: SQL BETWEEN over a two-element (low, high) rhs."""

    lookup_name = "range"

    def get_rhs_op(self, connection, rhs):
        """Build the BETWEEN clause from the two compiled rhs fragments."""
        return f"BETWEEN {rhs[0]} AND {rhs[1]}"
|  | ||||
|  | ||||
@Field.register_lookup
class IsNull(BuiltinLookup):
    """The `isnull` lookup: IS NULL / IS NOT NULL, driven by a boolean rhs."""

    lookup_name = "isnull"
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        if not isinstance(self.rhs, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        if isinstance(self.lhs, Value):
            # For a literal lhs the outcome is knowable at compile time;
            # signal it via FullResultSet/EmptyResultSet instead of SQL.
            if self.lhs.value is None or (
                self.lhs.value == ""
                and connection.features.interprets_empty_strings_as_nulls
            ):
                result_exception = FullResultSet if self.rhs else EmptyResultSet
            else:
                result_exception = EmptyResultSet if self.rhs else FullResultSet
            raise result_exception
        sql, params = self.process_lhs(compiler, connection)
        if self.rhs:
            return "%s IS NULL" % sql, params
        else:
            return "%s IS NOT NULL" % sql, params
|  | ||||
|  | ||||
@Field.register_lookup
class Regex(BuiltinLookup):
    """The `regex` lookup: case-sensitive regular expression match."""

    lookup_name = "regex"
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        if self.lookup_name in connection.operators:
            # The backend exposes a native regex operator.
            return super().as_sql(compiler, connection)
        else:
            # Fall back to the backend's regex_lookup() SQL template.
            lhs, lhs_params = self.process_lhs(compiler, connection)
            rhs, rhs_params = self.process_rhs(compiler, connection)
            sql_template = connection.ops.regex_lookup(self.lookup_name)
            return sql_template % (lhs, rhs), lhs_params + rhs_params
|  | ||||
|  | ||||
@Field.register_lookup
class IRegex(Regex):
    """The `iregex` lookup: case-insensitive regular expression match."""

    lookup_name = "iregex"
|  | ||||
|  | ||||
class YearLookup(Lookup):
    """
    Base class for lookups on an extracted year that rewrite the comparison
    into a date/datetime bound check so indexes can be used.
    """

    def year_lookup_bounds(self, connection, year):
        """Return the (start, finish) date/datetime bounds for *year*."""
        from django.db.models.functions import ExtractIsoYear

        iso_year = isinstance(self.lhs, ExtractIsoYear)
        output_field = self.lhs.lhs.output_field
        if isinstance(output_field, DateTimeField):
            bounds = connection.ops.year_lookup_bounds_for_datetime_field(
                year,
                iso_year=iso_year,
            )
        else:
            bounds = connection.ops.year_lookup_bounds_for_date_field(
                year,
                iso_year=iso_year,
            )
        return bounds

    def as_sql(self, compiler, connection):
        # Avoid the extract operation if the rhs is a direct value to allow
        # indexes to be used.
        if self.rhs_is_direct_value():
            # Skip the extract part by directly using the originating field,
            # that is self.lhs.lhs.
            lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
            rhs_sql, _ = self.process_rhs(compiler, connection)
            rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
            start, finish = self.year_lookup_bounds(connection, self.rhs)
            params.extend(self.get_bound_params(start, finish))
            return "%s %s" % (lhs_sql, rhs_sql), params
        return super().as_sql(compiler, connection)

    def get_direct_rhs_sql(self, connection, rhs):
        """SQL operator fragment used when comparing against the bounds."""
        return connection.operators[self.lookup_name] % rhs

    def get_bound_params(self, start, finish):
        """Return the subset of (start, finish) the comparison needs."""
        raise NotImplementedError(
            "subclasses of YearLookup must provide a get_bound_params() method"
        )
|  | ||||
|  | ||||
class YearExact(YearLookup, Exact):
    """year=N rewritten as BETWEEN the year's start and end bounds."""

    def get_direct_rhs_sql(self, connection, rhs):
        return "BETWEEN %s AND %s"

    def get_bound_params(self, start, finish):
        return (start, finish)
|  | ||||
|  | ||||
class YearGt(YearLookup, GreaterThan):
    """year__gt=N: compare against the year's upper bound."""

    def get_bound_params(self, start, finish):
        return (finish,)
|  | ||||
|  | ||||
class YearGte(YearLookup, GreaterThanOrEqual):
    """year__gte=N: compare against the year's lower bound."""

    def get_bound_params(self, start, finish):
        return (start,)
|  | ||||
|  | ||||
class YearLt(YearLookup, LessThan):
    """year__lt=N: compare against the year's lower bound."""

    def get_bound_params(self, start, finish):
        return (start,)
|  | ||||
|  | ||||
class YearLte(YearLookup, LessThanOrEqual):
    """year__lte=N: compare against the year's upper bound."""

    def get_bound_params(self, start, finish):
        return (finish,)
|  | ||||
|  | ||||
class UUIDTextMixin:
    """
    Strip hyphens from a value when filtering a UUIDField on backends without
    a native datatype for UUID.
    """

    def process_rhs(self, qn, connection):
        if not connection.features.has_native_uuid_field:
            from django.db.models.functions import Replace

            if self.rhs_is_direct_value():
                # Wrap the raw value so Replace() can operate on it.
                self.rhs = Value(self.rhs)
            # Strip hyphens to match the text form stored by the backend.
            self.rhs = Replace(
                self.rhs, Value("-"), Value(""), output_field=CharField()
            )
        rhs, params = super().process_rhs(qn, connection)
        return rhs, params
|  | ||||
|  | ||||
@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
    """`iexact` for UUIDField with hyphen stripping on non-native backends."""

    pass
|  | ||||
|  | ||||
@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
    """`contains` for UUIDField with hyphen stripping on non-native backends."""

    pass
|  | ||||
|  | ||||
@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
    """`icontains` for UUIDField with hyphen stripping on non-native backends."""

    pass
|  | ||||
|  | ||||
@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
    """`startswith` for UUIDField with hyphen stripping on non-native backends."""

    pass
|  | ||||
|  | ||||
@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
    """`istartswith` for UUIDField with hyphen stripping on non-native backends."""

    pass
|  | ||||
|  | ||||
@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
    """`endswith` for UUIDField with hyphen stripping on non-native backends."""

    pass
|  | ||||
|  | ||||
@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
    """`iendswith` for UUIDField with hyphen stripping on non-native backends."""

    pass
| @ -0,0 +1,213 @@ | ||||
| import copy | ||||
| import inspect | ||||
| from functools import wraps | ||||
| from importlib import import_module | ||||
|  | ||||
| from django.db import router | ||||
| from django.db.models.query import QuerySet | ||||
|  | ||||
|  | ||||
| class BaseManager: | ||||
|     # To retain order, track each time a Manager instance is created. | ||||
|     creation_counter = 0 | ||||
|  | ||||
|     # Set to True for the 'objects' managers that are automatically created. | ||||
|     auto_created = False | ||||
|  | ||||
|     #: If set to True the manager will be serialized into migrations and will | ||||
|     #: thus be available in e.g. RunPython operations. | ||||
|     use_in_migrations = False | ||||
|  | ||||
    def __new__(cls, *args, **kwargs):
        # Capture the arguments to make returning them trivial.
        obj = super().__new__(cls)
        # Stored so deconstruct() can faithfully reproduce this constructor
        # call when serializing the manager into migrations.
        obj._constructor_args = (args, kwargs)
        return obj
|  | ||||
    def __init__(self):
        super().__init__()
        self._set_creation_counter()
        # model and name are filled in later by contribute_to_class().
        self.model = None
        self.name = None
        # Forced database alias and routing hints; see db_manager().
        self._db = None
        self._hints = {}
|  | ||||
|     def __str__(self): | ||||
|         """Return "app_label.model_label.manager_name".""" | ||||
|         return "%s.%s" % (self.model._meta.label, self.name) | ||||
|  | ||||
    def __class_getitem__(cls, *args, **kwargs):
        # Support generic subscription (e.g. Manager[MyModel]) for typing
        # purposes; the subscript is ignored at runtime.
        return cls
|  | ||||
    def deconstruct(self):
        """
        Return a 5-tuple of the form (as_manager (True), manager_class,
        queryset_class, args, kwargs).

        Raise a ValueError if the manager is dynamically generated.
        """
        qs_class = self._queryset_class
        if getattr(self, "_built_with_as_manager", False):
            # using MyQuerySet.as_manager()
            return (
                True,  # as_manager
                None,  # manager_class
                "%s.%s" % (qs_class.__module__, qs_class.__name__),  # qs_class
                None,  # args
                None,  # kwargs
            )
        else:
            module_name = self.__module__
            name = self.__class__.__name__
            # Make sure it's actually there and not an inner class
            module = import_module(module_name)
            if not hasattr(module, name):
                raise ValueError(
                    "Could not find manager %s in %s.\n"
                    "Please note that you need to inherit from managers you "
                    "dynamically generated with 'from_queryset()'."
                    % (name, module_name)
                )
            return (
                False,  # as_manager
                "%s.%s" % (module_name, name),  # manager_class
                None,  # qs_class
                self._constructor_args[0],  # args
                self._constructor_args[1],  # kwargs
            )
|  | ||||
    def check(self, **kwargs):
        """Hook for the system check framework; no checks by default."""
        return []
|  | ||||
|     @classmethod | ||||
|     def _get_queryset_methods(cls, queryset_class): | ||||
|         def create_method(name, method): | ||||
|             @wraps(method) | ||||
|             def manager_method(self, *args, **kwargs): | ||||
|                 return getattr(self.get_queryset(), name)(*args, **kwargs) | ||||
|  | ||||
|             return manager_method | ||||
|  | ||||
|         new_methods = {} | ||||
|         for name, method in inspect.getmembers( | ||||
|             queryset_class, predicate=inspect.isfunction | ||||
|         ): | ||||
|             # Only copy missing methods. | ||||
|             if hasattr(cls, name): | ||||
|                 continue | ||||
|             # Only copy public methods or methods with the attribute | ||||
|             # queryset_only=False. | ||||
|             queryset_only = getattr(method, "queryset_only", None) | ||||
|             if queryset_only or (queryset_only is None and name.startswith("_")): | ||||
|                 continue | ||||
|             # Copy the method onto the manager. | ||||
|             new_methods[name] = create_method(name, method) | ||||
|         return new_methods | ||||
|  | ||||
|     @classmethod | ||||
|     def from_queryset(cls, queryset_class, class_name=None): | ||||
|         if class_name is None: | ||||
|             class_name = "%sFrom%s" % (cls.__name__, queryset_class.__name__) | ||||
|         return type( | ||||
|             class_name, | ||||
|             (cls,), | ||||
|             { | ||||
|                 "_queryset_class": queryset_class, | ||||
|                 **cls._get_queryset_methods(queryset_class), | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def contribute_to_class(self, cls, name): | ||||
|         self.name = self.name or name | ||||
|         self.model = cls | ||||
|  | ||||
|         setattr(cls, name, ManagerDescriptor(self)) | ||||
|  | ||||
|         cls._meta.add_manager(self) | ||||
|  | ||||
|     def _set_creation_counter(self): | ||||
|         """ | ||||
|         Set the creation counter value for this instance and increment the | ||||
|         class-level copy. | ||||
|         """ | ||||
|         self.creation_counter = BaseManager.creation_counter | ||||
|         BaseManager.creation_counter += 1 | ||||
|  | ||||
|     def db_manager(self, using=None, hints=None): | ||||
|         obj = copy.copy(self) | ||||
|         obj._db = using or self._db | ||||
|         obj._hints = hints or self._hints | ||||
|         return obj | ||||
|  | ||||
|     @property | ||||
|     def db(self): | ||||
|         return self._db or router.db_for_read(self.model, **self._hints) | ||||
|  | ||||
|     ####################### | ||||
|     # PROXIES TO QUERYSET # | ||||
|     ####################### | ||||
|  | ||||
|     def get_queryset(self): | ||||
|         """ | ||||
|         Return a new QuerySet object. Subclasses can override this method to | ||||
|         customize the behavior of the Manager. | ||||
|         """ | ||||
|         return self._queryset_class(model=self.model, using=self._db, hints=self._hints) | ||||
|  | ||||
|     def all(self): | ||||
|         # We can't proxy this method through the `QuerySet` like we do for the | ||||
|         # rest of the `QuerySet` methods. This is because `QuerySet.all()` | ||||
|         # works by creating a "copy" of the current queryset and in making said | ||||
|         # copy, all the cached `prefetch_related` lookups are lost. See the | ||||
|         # implementation of `RelatedManager.get_queryset()` for a better | ||||
|         # understanding of how this comes into play. | ||||
|         return self.get_queryset() | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         return ( | ||||
|             isinstance(other, self.__class__) | ||||
|             and self._constructor_args == other._constructor_args | ||||
|         ) | ||||
|  | ||||
    def __hash__(self):
        # Hash by object identity rather than by value.
        return id(self)
|  | ||||
|  | ||||
class Manager(BaseManager.from_queryset(QuerySet)):
    # The default manager: BaseManager with all public QuerySet methods
    # proxied onto it via from_queryset().
    pass
|  | ||||
|  | ||||
class ManagerDescriptor:
    """
    Descriptor exposing a manager on the model class while refusing access
    from model instances, abstract models, and swapped-out models.
    """

    def __init__(self, manager):
        self.manager = manager

    def __get__(self, instance, cls=None):
        # Managers are only reachable from the class, never from instances.
        if instance is not None:
            raise AttributeError(
                "Manager isn't accessible via %s instances" % cls.__name__
            )

        opts = cls._meta
        if opts.abstract:
            raise AttributeError(
                "Manager isn't available; %s is abstract" % (opts.object_name,)
            )
        if opts.swapped:
            raise AttributeError(
                "Manager isn't available; '%s' has been swapped for '%s'"
                % (
                    opts.label,
                    opts.swapped,
                )
            )
        # Fetch the manager registered under this name on the class's Options.
        return opts.managers_map[self.manager.name]
|  | ||||
|  | ||||
class EmptyManager(Manager):
    """Manager bound to a fixed *model* whose queryset is always empty."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def get_queryset(self):
        # Empty out the parent queryset via .none().
        return super().get_queryset().none()
							
								
								
									
										1014
									
								
								srcs/.venv/lib/python3.11/site-packages/django/db/models/options.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1014
									
								
								srcs/.venv/lib/python3.11/site-packages/django/db/models/options.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										2631
									
								
								srcs/.venv/lib/python3.11/site-packages/django/db/models/query.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2631
									
								
								srcs/.venv/lib/python3.11/site-packages/django/db/models/query.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -0,0 +1,435 @@ | ||||
| """ | ||||
| Various data structures used in query construction. | ||||
|  | ||||
| Factored out from django.db.models.query to avoid making the main module very | ||||
| large and/or so that they can be used by other modules without getting into | ||||
| circular import difficulties. | ||||
| """ | ||||
| import functools | ||||
| import inspect | ||||
| import logging | ||||
| from collections import namedtuple | ||||
|  | ||||
| from django.core.exceptions import FieldError | ||||
| from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections | ||||
| from django.db.models.constants import LOOKUP_SEP | ||||
| from django.utils import tree | ||||
|  | ||||
| logger = logging.getLogger("django.db.models") | ||||
|  | ||||
# PathInfo is used when converting lookups (fk__somecol). The contents
# describe the relation in Model terms (model Options and Fields for both
# sides of the relation). The join_field is the field backing the relation.
PathInfo = namedtuple(
    "PathInfo",
    [
        "from_opts",
        "to_opts",
        "target_fields",
        "join_field",
        "m2m",
        "direct",
        "filtered_relation",
    ],
)
|  | ||||
|  | ||||
def subclasses(cls):
    """Yield *cls* and every transitive subclass, depth-first, pre-order."""
    stack = [cls]
    while stack:
        current = stack.pop()
        yield current
        # Reverse so the first-declared subclass is visited first, matching
        # pre-order depth-first traversal.
        stack.extend(reversed(current.__subclasses__()))
|  | ||||
|  | ||||
class Q(tree.Node):
    """
    Encapsulate filters as objects that can then be combined logically (using
    `&` and `|`).
    """

    # Connection types
    AND = "AND"
    OR = "OR"
    XOR = "XOR"
    default = AND
    # Checked by _combine() (and expression code) to accept Q as an operand.
    conditional = True

    def __init__(self, *args, _connector=None, _negated=False, **kwargs):
        # Positional args are child nodes/expressions; keyword lookups are
        # sorted so equal filters produce identical children ordering.
        super().__init__(
            children=[*args, *sorted(kwargs.items())],
            connector=_connector,
            negated=_negated,
        )

    def _combine(self, other, conn):
        """Return a new Q combining self and other with connector *conn*."""
        # Only conditional expressions may be combined with a Q.
        if getattr(other, "conditional", False) is False:
            raise TypeError(other)
        # Combining with an empty Q yields a copy of the other operand.
        if not self:
            return other.copy()
        if not other and isinstance(other, Q):
            return self.copy()

        obj = self.create(connector=conn)
        obj.add(self, conn)
        obj.add(other, conn)
        return obj

    def __or__(self, other):
        return self._combine(other, self.OR)

    def __and__(self, other):
        return self._combine(other, self.AND)

    def __xor__(self, other):
        return self._combine(other, self.XOR)

    def __invert__(self):
        # Negate a copy so the original Q is left untouched.
        obj = self.copy()
        obj.negate()
        return obj

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve this Q into a WHERE clause usable as an expression."""
        # We must promote any new joins to left outer joins so that when Q is
        # used as an expression, rows aren't filtered due to joins.
        clause, joins = query._add_q(
            self,
            reuse,
            allow_joins=allow_joins,
            split_subq=False,
            check_filterable=False,
            summarize=summarize,
        )
        query.promote_joins(joins)
        return clause

    def flatten(self):
        """
        Recursively yield this Q object and all subexpressions, in depth-first
        order.
        """
        yield self
        for child in self.children:
            if isinstance(child, tuple):
                # Use the lookup.
                child = child[1]
            if hasattr(child, "flatten"):
                yield from child.flatten()
            else:
                yield child

    def check(self, against, using=DEFAULT_DB_ALIAS):
        """
        Do a database query to check if the expressions of the Q instance
        matches against the expressions.
        """
        # Avoid circular imports.
        from django.db.models import BooleanField, Value
        from django.db.models.functions import Coalesce
        from django.db.models.sql import Query
        from django.db.models.sql.constants import SINGLE

        # Annotate the provided values onto an empty query so this Q's
        # references can resolve against them.
        query = Query(None)
        for name, value in against.items():
            if not hasattr(value, "resolve_expression"):
                value = Value(value)
            query.add_annotation(value, name, select=False)
        query.add_annotation(Value(1), "_check")
        # This will raise a FieldError if a field is missing in "against".
        if connections[using].features.supports_comparing_boolean_expr:
            query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
        else:
            query.add_q(self)
        compiler = query.get_compiler(using=using)
        try:
            return compiler.execute_sql(SINGLE) is not None
        except DatabaseError as e:
            # Be permissive on database errors: log and treat as a match.
            logger.warning("Got a database error calling check() on %r: %s", self, e)
            return True

    def deconstruct(self):
        """Return (path, args, kwargs) from which an equal Q can be rebuilt."""
        path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
        if path.startswith("django.db.models.query_utils"):
            # Present Q under its public import location.
            path = path.replace("django.db.models.query_utils", "django.db.models")
        args = tuple(self.children)
        kwargs = {}
        if self.connector != self.default:
            kwargs["_connector"] = self.connector
        if self.negated:
            kwargs["_negated"] = True
        return path, args, kwargs
|  | ||||
|  | ||||
class DeferredAttribute:
    """
    Descriptor for a deferred-loading field: the first attribute read
    triggers the query, after which the value is cached on the instance.
    """

    def __init__(self, field):
        self.field = field

    def __get__(self, instance, cls=None):
        """Return the field value, fetching and caching it on first access."""
        if instance is None:
            return self
        attname = self.field.attname
        cache = instance.__dict__
        if attname in cache:
            return cache[attname]
        # Try to reuse a value already loaded on a parent model before
        # hitting the database. Refs #18343.
        inherited = self._check_parent_chain(instance)
        if inherited is None:
            instance.refresh_from_db(fields=[attname])
        else:
            cache[attname] = inherited
        return cache[attname]

    def _check_parent_chain(self, instance):
        """
        Check if the field value can be fetched from a parent field already
        loaded in the instance. This can be done only when the to-be-fetched
        field is a primary key field.
        """
        opts = instance._meta
        link_field = opts.get_ancestor_link(self.field.model)
        if self.field.primary_key and self.field != link_field:
            return getattr(instance, link_field.attname)
        return None
|  | ||||
|  | ||||
class class_or_instance_method:
    """
    Descriptor used by RegisterLookupMixin that dispatches to one of two
    implementations depending on whether the attribute is reached through
    the class or through an instance.
    """

    def __init__(self, class_method, instance_method):
        self.class_method = class_method
        self.instance_method = instance_method

    def __get__(self, instance, owner):
        if instance is None:
            # Class access: bind the class-level implementation to the owner.
            func, target = self.class_method, owner
        else:
            # Instance access: bind the instance-level implementation.
            func, target = self.instance_method, instance
        return functools.partial(func, target)
|  | ||||
|  | ||||
class RegisterLookupMixin:
    """
    Mixin allowing lookups and transforms to be registered either on a class
    (shared via ``class_lookups``) or on a single instance (via
    ``instance_lookups``).
    """

    def _get_lookup(self, lookup_name):
        # Resolve a name against the merged class + instance lookups.
        return self.get_lookups().get(lookup_name, None)

    @functools.lru_cache(maxsize=None)
    def get_class_lookups(cls):
        # Written with a plain ``cls`` parameter and wrapped in classmethod()
        # below so the lru_cache is keyed on the class object itself.
        class_lookups = [
            parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
        ]
        return cls.merge_dicts(class_lookups)

    def get_instance_lookups(self):
        # Instance-level lookups take precedence over class-level ones.
        class_lookups = self.get_class_lookups()
        if instance_lookups := getattr(self, "instance_lookups", None):
            return {**class_lookups, **instance_lookups}
        return class_lookups

    # Dispatch to the class or instance variant depending on how the
    # attribute is accessed (see class_or_instance_method).
    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
    get_class_lookups = classmethod(get_class_lookups)

    def get_lookup(self, lookup_name):
        """Return the registered Lookup subclass for *lookup_name*, or None."""
        from django.db.models.lookups import Lookup

        found = self._get_lookup(lookup_name)
        if found is None and hasattr(self, "output_field"):
            # Fall back to the output field's lookups when none is registered
            # here.
            return self.output_field.get_lookup(lookup_name)
        if found is not None and not issubclass(found, Lookup):
            return None
        return found

    def get_transform(self, lookup_name):
        """Return the registered Transform subclass for *lookup_name*, or None."""
        from django.db.models.lookups import Transform

        found = self._get_lookup(lookup_name)
        if found is None and hasattr(self, "output_field"):
            # Fall back to the output field's transforms.
            return self.output_field.get_transform(lookup_name)
        if found is not None and not issubclass(found, Transform):
            return None
        return found

    @staticmethod
    def merge_dicts(dicts):
        """
        Merge dicts in reverse to preference the order of the original list. e.g.,
        merge_dicts([a, b]) will preference the keys in 'a' over those in 'b'.
        """
        merged = {}
        for d in reversed(dicts):
            merged.update(d)
        return merged

    @classmethod
    def _clear_cached_class_lookups(cls):
        # Invalidate the lru_cache on every subclass as well, since their
        # merged lookups (built over the MRO) include this class's entries.
        for subclass in subclasses(cls):
            subclass.get_class_lookups.cache_clear()

    def register_class_lookup(cls, lookup, lookup_name=None):
        # Plain function here; wrapped in classmethod() after the
        # class/instance dispatcher below is constructed.
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        if "class_lookups" not in cls.__dict__:
            # Give the class its own dict so parent classes aren't mutated.
            cls.class_lookups = {}
        cls.class_lookups[lookup_name] = lookup
        cls._clear_cached_class_lookups()
        return lookup

    def register_instance_lookup(self, lookup, lookup_name=None):
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        if "instance_lookups" not in self.__dict__:
            # Give the instance its own dict so the class isn't mutated.
            self.instance_lookups = {}
        self.instance_lookups[lookup_name] = lookup
        return lookup

    register_lookup = class_or_instance_method(
        register_class_lookup, register_instance_lookup
    )
    register_class_lookup = classmethod(register_class_lookup)

    def _unregister_class_lookup(cls, lookup, lookup_name=None):
        """
        Remove given lookup from cls lookups. For use in tests only as it's
        not thread-safe.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        del cls.class_lookups[lookup_name]
        cls._clear_cached_class_lookups()

    def _unregister_instance_lookup(self, lookup, lookup_name=None):
        """
        Remove given lookup from instance lookups. For use in tests only as
        it's not thread-safe.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        del self.instance_lookups[lookup_name]

    _unregister_lookup = class_or_instance_method(
        _unregister_class_lookup, _unregister_instance_lookup
    )
    _unregister_class_lookup = classmethod(_unregister_class_lookup)
|  | ||||
|  | ||||
def select_related_descend(field, restricted, requested, select_mask, reverse=False):
    """
    Decide whether select_related() should follow *field* one level deeper.

    Used by both the query construction code
    (compiler.get_related_selections()) and the model instance creation code
    (compiler.klass_info).

    Arguments:
     * field - the field to be checked
     * restricted - True if the field list was manually restricted via a
       requested clause
     * requested - the select_related() dictionary
     * select_mask - the dictionary of selected fields
     * reverse - True if checking a reverse select related
    """
    # Only relational fields can be descended into.
    if not field.remote_field:
        return False
    # Never follow a parent link in the forward direction.
    if field.remote_field.parent_link and not reverse:
        return False
    if restricted:
        # Honor the explicit select_related() request.
        wanted_name = field.related_query_name() if reverse else field.name
        if wanted_name not in requested:
            return False
        # A field cannot be deferred (absent from select_mask) while also
        # being traversed via select_related().
        if select_mask and field.name in requested and field not in select_mask:
            raise FieldError(
                f"Field {field.model._meta.object_name}.{field.name} cannot be both "
                "deferred and traversed using select_related at the same time."
            )
    elif field.null:
        # Unrestricted descent skips nullable relations.
        return False
    return True
|  | ||||
|  | ||||
def refs_expression(lookup_parts, annotations):
    """
    Check whether lookup_parts references one of the given annotations.
    Since LOOKUP_SEP may occur inside default annotation names, try each
    prefix of lookup_parts, shortest first, and return the first annotated
    prefix plus the remaining parts; return (None, ()) when nothing matches.
    """
    for index, _ in enumerate(lookup_parts, start=1):
        prefix = LOOKUP_SEP.join(lookup_parts[:index])
        if annotations.get(prefix):
            return prefix, lookup_parts[index:]
    return None, ()
|  | ||||
|  | ||||
def check_rel_lookup_compatibility(model, target_opts, field):
    """
    Check that self.model is compatible with target_opts. Compatibility
    is OK if:
      1) model and opts match (where proxy inheritance is removed)
      2) model is parent of opts' model or the other way around
    """

    def compatible(opts):
        return (
            model._meta.concrete_model == opts.concrete_model
            or opts.concrete_model in model._meta.get_parent_list()
            or model in opts.get_parent_list()
        )

    # A primary-key field may also be queried through its own model's opts.
    # Consider: class Restaurant(models.Model):
    #               place = OneToOneField(Place, primary_key=True)
    # Restaurant.objects.filter(pk__in=Restaurant.objects.all()) makes pk__in
    # equivalent to place__in, whose target opts belong to Place even though
    # the queryset is over Restaurant. This applies only to primary keys,
    # since __in=qs is later turned into __in=qs.values('pk').
    if compatible(target_opts):
        return True
    return getattr(field, "primary_key", False) and compatible(field.model._meta)
|  | ||||
|  | ||||
class FilteredRelation:
    """Specify custom filtering in the ON clause of SQL joins."""

    def __init__(self, relation_name, *, condition=Q()):
        # Validate inputs before storing any state.
        if not relation_name:
            raise ValueError("relation_name cannot be empty.")
        if not isinstance(condition, Q):
            raise ValueError("condition argument must be a Q() instance.")
        self.relation_name = relation_name
        self.condition = condition
        self.alias = None
        self.path = []

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        attrs = ("relation_name", "alias", "condition")
        return all(getattr(self, a) == getattr(other, a) for a in attrs)

    def clone(self):
        """Return an independent copy with the same name/alias/path."""
        duplicate = FilteredRelation(self.relation_name, condition=self.condition)
        duplicate.alias = self.alias
        duplicate.path = list(self.path)
        return duplicate

    def resolve_expression(self, *args, **kwargs):
        """
        QuerySet.annotate() only accepts expression-like arguments
        (with a resolve_expression() method).
        """
        raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")

    def as_sql(self, compiler, connection):
        """Compile this relation's condition for Join.filtered_relation."""
        query = compiler.query
        resolved = query.build_filtered_relation_q(self.condition, reuse=set(self.path))
        return compiler.compile(resolved)
| @ -0,0 +1,54 @@ | ||||
| from functools import partial | ||||
|  | ||||
| from django.db.models.utils import make_model_tuple | ||||
| from django.dispatch import Signal | ||||
|  | ||||
| class_prepared = Signal() | ||||
|  | ||||
|  | ||||
class ModelSignal(Signal):
    """
    Signal subclass that allows the sender to be lazily specified as a string
    of the `app_label.ModelName` form.
    """

    def _lazy_method(self, method, apps, receiver, sender, **kwargs):
        from django.db.models.options import Options

        # Bind everything except the sender, which may not exist yet. The
        # resulting partial takes a single optional "sender" argument.
        bound = partial(method, receiver, **kwargs)
        if not isinstance(sender, str):
            return bound(sender)
        # String sender: defer until the named model has been prepared.
        registry = apps or Options.default_apps
        registry.lazy_model_operation(bound, make_model_tuple(sender))

    def connect(self, receiver, sender=None, weak=True, dispatch_uid=None, apps=None):
        """Connect *receiver*; a string sender is resolved lazily."""
        self._lazy_method(
            super().connect,
            apps,
            receiver,
            sender,
            weak=weak,
            dispatch_uid=dispatch_uid,
        )

    def disconnect(self, receiver=None, sender=None, dispatch_uid=None, apps=None):
        """Disconnect *receiver*; returns None when the sender is lazy."""
        return self._lazy_method(
            super().disconnect, apps, receiver, sender, dispatch_uid=dispatch_uid
        )
|  | ||||
|  | ||||
# Model signals; per ModelSignal above, the sender may be given lazily as an
# "app_label.ModelName" string.
pre_init = ModelSignal(use_caching=True)
post_init = ModelSignal(use_caching=True)

pre_save = ModelSignal(use_caching=True)
post_save = ModelSignal(use_caching=True)

pre_delete = ModelSignal(use_caching=True)
post_delete = ModelSignal(use_caching=True)

m2m_changed = ModelSignal(use_caching=True)

# Plain Signals: no lazy string-sender resolution.
pre_migrate = Signal()
post_migrate = Signal()
| @ -0,0 +1,6 @@ | ||||
| from django.db.models.sql.query import *  # NOQA | ||||
| from django.db.models.sql.query import Query | ||||
| from django.db.models.sql.subqueries import *  # NOQA | ||||
| from django.db.models.sql.where import AND, OR, XOR | ||||
|  | ||||
| __all__ = ["Query", "AND", "OR", "XOR"] | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -0,0 +1,24 @@ | ||||
| """ | ||||
| Constants specific to the SQL storage portion of the ORM. | ||||
| """ | ||||
|  | ||||
| # Size of each "chunk" for get_iterator calls. | ||||
| # Larger values are slightly faster at the expense of more storage space. | ||||
| GET_ITERATOR_CHUNK_SIZE = 100 | ||||
|  | ||||
| # Namedtuples for sql.* internal use. | ||||
|  | ||||
| # How many results to expect from a cursor.execute call | ||||
| MULTI = "multi" | ||||
| SINGLE = "single" | ||||
| CURSOR = "cursor" | ||||
| NO_RESULTS = "no results" | ||||
|  | ||||
| ORDER_DIR = { | ||||
|     "ASC": ("ASC", "DESC"), | ||||
|     "DESC": ("DESC", "ASC"), | ||||
| } | ||||
|  | ||||
| # SQL join types. | ||||
| INNER = "INNER JOIN" | ||||
| LOUTER = "LEFT OUTER JOIN" | ||||
| @ -0,0 +1,224 @@ | ||||
| """ | ||||
| Useful auxiliary data structures for query construction. Not useful outside | ||||
| the SQL domain. | ||||
| """ | ||||
| from django.core.exceptions import FullResultSet | ||||
| from django.db.models.sql.constants import INNER, LOUTER | ||||
|  | ||||
|  | ||||
class MultiJoin(Exception):
    """
    Raised by join construction code to mark the point at which a
    multi-valued join was attempted, for callers that want to treat that
    case specially.
    """

    def __init__(self, names_pos, path_with_names):
        # Position in the lookup path where the multi-valued join occurred.
        self.level = names_pos
        # The path travelled, including the path up to the multijoin.
        self.names_with_path = path_with_names
|  | ||||
|  | ||||
class Empty:
    # Bare placeholder class with no state or behavior.
    pass
|  | ||||
|  | ||||
| class Join: | ||||
|     """ | ||||
|     Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the | ||||
|     FROM entry. For example, the SQL generated could be | ||||
|         LEFT OUTER JOIN "sometable" T1 | ||||
|         ON ("othertable"."sometable_id" = "sometable"."id") | ||||
|  | ||||
|     This class is primarily used in Query.alias_map. All entries in alias_map | ||||
|     must be Join compatible by providing the following attributes and methods: | ||||
|         - table_name (string) | ||||
|         - table_alias (possible alias for the table, can be None) | ||||
|         - join_type (can be None for those entries that aren't joined from | ||||
|           anything) | ||||
|         - parent_alias (which table is this join's parent, can be None similarly | ||||
|           to join_type) | ||||
|         - as_sql() | ||||
|         - relabeled_clone() | ||||
|     """ | ||||
|  | ||||
    def __init__(
        self,
        table_name,
        parent_alias,
        table_alias,
        join_type,
        join_field,
        nullable,
        filtered_relation=None,
    ):
        """
        Store the pieces needed to emit a JOIN clause; see the class
        docstring for the meaning of the alias/type attributes.
        """
        # Join table
        self.table_name = table_name
        self.parent_alias = parent_alias
        # Note: table_alias is not necessarily known at instantiation time.
        self.table_alias = table_alias
        # LOUTER or INNER
        self.join_type = join_type
        # A list of 2-tuples to use in the ON clause of the JOIN.
        # Each 2-tuple will create one join condition in the ON clause.
        self.join_cols = join_field.get_joining_columns()
        # Along which field (or ForeignObjectRel in the reverse join case)
        # the join is generated.
        self.join_field = join_field
        # Is this join nullable?
        self.nullable = nullable
        self.filtered_relation = filtered_relation
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         """ | ||||
|         Generate the full | ||||
|            LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params | ||||
|         clause for this join. | ||||
|         """ | ||||
|         join_conditions = [] | ||||
|         params = [] | ||||
|         qn = compiler.quote_name_unless_alias | ||||
|         qn2 = connection.ops.quote_name | ||||
|  | ||||
|         # Add a join condition for each pair of joining columns. | ||||
|         for lhs_col, rhs_col in self.join_cols: | ||||
|             join_conditions.append( | ||||
|                 "%s.%s = %s.%s" | ||||
|                 % ( | ||||
|                     qn(self.parent_alias), | ||||
|                     qn2(lhs_col), | ||||
|                     qn(self.table_alias), | ||||
|                     qn2(rhs_col), | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|         # Add a single condition inside parentheses for whatever | ||||
|         # get_extra_restriction() returns. | ||||
|         extra_cond = self.join_field.get_extra_restriction( | ||||
|             self.table_alias, self.parent_alias | ||||
|         ) | ||||
|         if extra_cond: | ||||
|             extra_sql, extra_params = compiler.compile(extra_cond) | ||||
|             join_conditions.append("(%s)" % extra_sql) | ||||
|             params.extend(extra_params) | ||||
|         if self.filtered_relation: | ||||
|             try: | ||||
|                 extra_sql, extra_params = compiler.compile(self.filtered_relation) | ||||
|             except FullResultSet: | ||||
|                 pass | ||||
|             else: | ||||
|                 join_conditions.append("(%s)" % extra_sql) | ||||
|                 params.extend(extra_params) | ||||
|         if not join_conditions: | ||||
|             # This might be a rel on the other end of an actual declared field. | ||||
|             declared_field = getattr(self.join_field, "field", self.join_field) | ||||
|             raise ValueError( | ||||
|                 "Join generated an empty ON clause. %s did not yield either " | ||||
|                 "joining columns or extra restrictions." % declared_field.__class__ | ||||
|             ) | ||||
|         on_clause_sql = " AND ".join(join_conditions) | ||||
|         alias_str = ( | ||||
|             "" if self.table_alias == self.table_name else (" %s" % self.table_alias) | ||||
|         ) | ||||
|         sql = "%s %s%s ON (%s)" % ( | ||||
|             self.join_type, | ||||
|             qn(self.table_name), | ||||
|             alias_str, | ||||
|             on_clause_sql, | ||||
|         ) | ||||
|         return sql, params | ||||
|  | ||||
|     def relabeled_clone(self, change_map): | ||||
|         new_parent_alias = change_map.get(self.parent_alias, self.parent_alias) | ||||
|         new_table_alias = change_map.get(self.table_alias, self.table_alias) | ||||
|         if self.filtered_relation is not None: | ||||
|             filtered_relation = self.filtered_relation.clone() | ||||
|             filtered_relation.path = [ | ||||
|                 change_map.get(p, p) for p in self.filtered_relation.path | ||||
|             ] | ||||
|         else: | ||||
|             filtered_relation = None | ||||
|         return self.__class__( | ||||
|             self.table_name, | ||||
|             new_parent_alias, | ||||
|             new_table_alias, | ||||
|             self.join_type, | ||||
|             self.join_field, | ||||
|             self.nullable, | ||||
|             filtered_relation=filtered_relation, | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def identity(self): | ||||
|         return ( | ||||
|             self.__class__, | ||||
|             self.table_name, | ||||
|             self.parent_alias, | ||||
|             self.join_field, | ||||
|             self.filtered_relation, | ||||
|         ) | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if not isinstance(other, Join): | ||||
|             return NotImplemented | ||||
|         return self.identity == other.identity | ||||
|  | ||||
    def __hash__(self):
        # Hash on the same key as __eq__ so equal joins collide in sets/dicts.
        return hash(self.identity)
|  | ||||
|     def equals(self, other): | ||||
|         # Ignore filtered_relation in equality check. | ||||
|         return self.identity[:-1] == other.identity[:-1] | ||||
|  | ||||
|     def demote(self): | ||||
|         new = self.relabeled_clone({}) | ||||
|         new.join_type = INNER | ||||
|         return new | ||||
|  | ||||
|     def promote(self): | ||||
|         new = self.relabeled_clone({}) | ||||
|         new.join_type = LOUTER | ||||
|         return new | ||||
|  | ||||
|  | ||||
class BaseTable:
    """
    A base table reference in the FROM clause. For example, the SQL "foo" in
        SELECT * FROM "foo" WHERE somecond
    could be generated by this class.
    """

    # Interface shared with Join; a base table carries no join semantics.
    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name, alias):
        self.table_name = table_name
        self.table_alias = alias

    def as_sql(self, compiler, connection):
        """Return the FROM-clause fragment and params for this table."""
        base_sql = compiler.quote_name_unless_alias(self.table_name)
        # Only emit an alias when it differs from the table name.
        if self.table_alias == self.table_name:
            return base_sql, []
        return "%s %s" % (base_sql, self.table_alias), []

    def relabeled_clone(self, change_map):
        """Return a copy with the alias remapped through change_map."""
        new_alias = change_map.get(self.table_alias, self.table_alias)
        return self.__class__(self.table_name, new_alias)

    @property
    def identity(self):
        """Hashable key used by __eq__/__hash__."""
        return self.__class__, self.table_name, self.table_alias

    def __eq__(self, other):
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

    def equals(self, other):
        # Unlike Join.equals(), there is no filtered_relation to disregard.
        return self.identity == other.identity
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -0,0 +1,171 @@ | ||||
| """ | ||||
| Query subclasses which provide extra functionality beyond simple data retrieval. | ||||
| """ | ||||
|  | ||||
| from django.core.exceptions import FieldError | ||||
| from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS | ||||
| from django.db.models.sql.query import Query | ||||
|  | ||||
| __all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"] | ||||
|  | ||||
|  | ||||
class DeleteQuery(Query):
    """A DELETE SQL query."""

    compiler = "SQLDeleteCompiler"

    def do_query(self, table, where, using):
        """Execute a DELETE restricted to one table; return the rows deleted."""
        # Restrict the alias map to the single target table.
        self.alias_map = {table: self.alias_map[table]}
        self.where = where
        cursor = self.get_compiler(using).execute_sql(CURSOR)
        if not cursor:
            return 0
        with cursor:
            return cursor.rowcount

    def delete_batch(self, pk_list, using):
        """
        Set up and execute delete queries for all the objects in pk_list.

        More than one physical query may be executed if there are a
        lot of values in pk_list. Return the total number of rows deleted.
        """
        deleted = 0
        field = self.get_meta().pk
        lookup = f"{field.attname}__in"
        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
            chunk = pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
            # Rebuild the WHERE clause from scratch for each chunk.
            self.clear_where()
            self.add_filter(lookup, chunk)
            deleted += self.do_query(self.get_meta().db_table, self.where, using=using)
        return deleted
|  | ||||
|  | ||||
class UpdateQuery(Query):
    """An UPDATE SQL query."""

    compiler = "SQLUpdateCompiler"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._setup_query()

    def _setup_query(self):
        """
        Run on initialization and at the end of chaining. Any attributes that
        would normally be set in __init__() should go here instead.
        """
        # (field, model, value) triples making up the SET clause.
        self.values = []
        # Per-ancestor-model primary keys used to restrict related updates;
        # stays None until populated externally.
        self.related_ids = None
        # Pending updates against ancestor models:
        # {model: [(field, None, value), ...]}.
        self.related_updates = {}

    def clone(self):
        """Clone the query, copying related_updates so edits don't leak back."""
        obj = super().clone()
        obj.related_updates = self.related_updates.copy()
        return obj

    def update_batch(self, pk_list, values, using):
        """
        Apply `values` to the rows whose pks are in pk_list, chunking into
        several queries when pk_list is large.
        """
        self.add_update_values(values)
        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
            # Rebuild the WHERE clause from scratch for each chunk.
            self.clear_where()
            self.add_filter(
                "pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
            )
            self.get_compiler(using).execute_sql(NO_RESULTS)

    def add_update_values(self, values):
        """
        Convert a dictionary of field name to value mappings into an update
        query. This is the entry point for the public update() method on
        querysets.
        """
        values_seq = []
        for name, val in values.items():
            field = self.get_meta().get_field(name)
            # NOTE(review): this boolean is a tautology (always True), so the
            # `not direct` half of the check below never fires; the
            # many_to_many test is what actually rejects fields. Looks like a
            # relic of an older field API — confirm before simplifying.
            direct = (
                not (field.auto_created and not field.concrete) or not field.concrete
            )
            model = field.model._meta.concrete_model
            if not direct or (field.is_relation and field.many_to_many):
                raise FieldError(
                    "Cannot update model field %r (only non-relations and "
                    "foreign keys permitted)." % field
                )
            if model is not self.get_meta().concrete_model:
                # The field lives on an ancestor model (multi-table
                # inheritance); defer it to a separate per-ancestor update.
                self.add_related_update(model, field, val)
                continue
            values_seq.append((field, model, val))
        return self.add_update_fields(values_seq)

    def add_update_fields(self, values_seq):
        """
        Append a sequence of (field, model, value) triples to the internal list
        that will be used to generate the UPDATE query. Might be more usefully
        called add_update_targets() to hint at the extra information here.
        """
        for field, model, val in values_seq:
            if hasattr(val, "resolve_expression"):
                # Resolve expressions here so that annotations are no longer needed
                val = val.resolve_expression(self, allow_joins=False, for_save=True)
            self.values.append((field, model, val))

    def add_related_update(self, model, field, value):
        """
        Add (name, value) to an update query for an ancestor model.

        Updates are coalesced so that only one update query per ancestor is run.
        """
        self.related_updates.setdefault(model, []).append((field, None, value))

    def get_related_updates(self):
        """
        Return a list of query objects: one for each update required to an
        ancestor model. Each query will have the same filtering conditions as
        the current query but will only update a single table.
        """
        if not self.related_updates:
            return []
        result = []
        for model, values in self.related_updates.items():
            query = UpdateQuery(model)
            query.values = values
            if self.related_ids is not None:
                # Restrict each ancestor update to the affected rows.
                query.add_filter("pk__in", self.related_ids[model])
            result.append(query)
        return result
|  | ||||
|  | ||||
class InsertQuery(Query):
    """An INSERT SQL query."""

    compiler = "SQLInsertCompiler"

    def __init__(
        self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.fields = []
        self.objs = []
        # Conflict-resolution configuration; the field lists default to empty.
        self.on_conflict = on_conflict
        self.update_fields = update_fields if update_fields else []
        self.unique_fields = unique_fields if unique_fields else []

    def insert_values(self, fields, objs, raw=False):
        """Record the fields and objects this INSERT will write."""
        self.fields = fields
        self.objs = objs
        self.raw = raw
|  | ||||
|  | ||||
class AggregateQuery(Query):
    """
    Take another query as a parameter to the FROM clause and only select the
    elements in the provided list.
    """

    compiler = "SQLAggregateCompiler"

    def __init__(self, model, inner_query):
        # Stash the wrapped query (the FROM subquery) before base-class setup.
        self.inner_query = inner_query
        super().__init__(model)
| @ -0,0 +1,355 @@ | ||||
| """ | ||||
| Code to manage the creation and SQL rendering of 'where' constraints. | ||||
| """ | ||||
| import operator | ||||
| from functools import reduce | ||||
|  | ||||
| from django.core.exceptions import EmptyResultSet, FullResultSet | ||||
| from django.db.models.expressions import Case, When | ||||
| from django.db.models.lookups import Exact | ||||
| from django.utils import tree | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
| # Connection types | ||||
| AND = "AND" | ||||
| OR = "OR" | ||||
| XOR = "XOR" | ||||
|  | ||||
|  | ||||
class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either
    relabeled_clone() method or relabel_aliases() and clone() methods and
    contains_aggregate attribute.
    """

    # Children are combined with AND unless a connector is given explicitly.
    default = AND
    # Set to True by resolve_expression(); affects parenthesization in as_sql().
    resolved = False
    conditional = True

    def split_having_qualify(self, negated=False, must_group_by=False):
        """
        Return three possibly None nodes: one for those parts of self that
        should be included in the WHERE clause, one for those parts of self
        that must be included in the HAVING clause (aggregates), and one for
        those parts that refer to window functions (QUALIFY).
        """
        if not self.contains_aggregate and not self.contains_over_clause:
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
        must_remain_connected = (
            (in_negated and self.connector == AND)
            or (not in_negated and self.connector == OR)
            or self.connector == XOR
        )
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts = []
        having_parts = []
        qualify_parts = []
        for c in self.children:
            if hasattr(c, "split_having_qualify"):
                # Nested WhereNode: split it recursively.
                where_part, having_part, qualify_part = c.split_having_qualify(
                    in_negated, must_group_by
                )
                if where_part is not None:
                    where_parts.append(where_part)
                if having_part is not None:
                    having_parts.append(having_part)
                if qualify_part is not None:
                    qualify_parts.append(qualify_part)
            elif c.contains_over_clause:
                qualify_parts.append(c)
            elif c.contains_aggregate:
                having_parts.append(c)
            else:
                where_parts.append(c)
        if must_remain_connected and qualify_parts:
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            if not where_parts or (where_parts and not must_group_by):
                return None, None, self
            elif where_parts:
                # In theory this should only be enforced when dealing with
                # where_parts containing predicates against multi-valued
                # relationships that could affect aggregation results but this
                # is complex to infer properly.
                raise NotImplementedError(
                    "Heterogeneous disjunctive predicates against window functions are "
                    "not implemented when performing conditional aggregation."
                )
        where_node = (
            self.create(where_parts, self.connector, self.negated)
            if where_parts
            else None
        )
        having_node = (
            self.create(having_parts, self.connector, self.negated)
            if having_parts
            else None
        )
        qualify_node = (
            self.create(qualify_parts, self.connector, self.negated)
            if qualify_parts
            else None
        )
        return where_node, having_node, qualify_node

    def as_sql(self, compiler, connection):
        """
        Return the SQL version of the where clause and the value to be
        substituted in. Return '', [] if this node matches everything,
        None, [] if this node is empty, and raise EmptyResultSet if this
        node can't match anything.
        """
        result = []
        result_params = []
        # Counters for short-circuiting: how many children must compile to a
        # full (always-true) / empty (never-true) match before this whole
        # node's truth value is decided.
        if self.connector == AND:
            full_needed, empty_needed = len(self.children), 1
        else:
            full_needed, empty_needed = 1, len(self.children)

        if self.connector == XOR and not connection.features.supports_logical_xor:
            # Convert if the database doesn't support XOR:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND (a + b + c + ...) == 1
            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            rhs = Exact(1, rhs_sum)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_needed -= 1
            except FullResultSet:
                full_needed -= 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    # An empty SQL string counts as an always-true condition.
                    full_needed -= 1
            # After each child, check whether enough children are known to be
            # empty/full that the whole node's result is already decided,
            # taking negation into account.
            if empty_needed == 0:
                if self.negated:
                    raise FullResultSet
                else:
                    raise EmptyResultSet
            if full_needed == 0:
                if self.negated:
                    raise EmptyResultSet
                else:
                    raise FullResultSet
        conn = " %s " % self.connector
        sql_string = conn.join(result)
        if not sql_string:
            raise FullResultSet
        if self.negated:
            # Some backends (Oracle at least) need parentheses around the inner
            # SQL in the negated case, even if the inner SQL contains just a
            # single expression.
            sql_string = "NOT (%s)" % sql_string
        elif len(result) > 1 or self.resolved:
            sql_string = "(%s)" % sql_string
        return sql_string, result_params

    def get_group_by_cols(self):
        """Collect GROUP BY columns from all children."""
        cols = []
        for child in self.children:
            cols.extend(child.get_group_by_cols())
        return cols

    def get_source_expressions(self):
        # Return a copy so callers can't mutate self.children in place.
        return self.children[:]

    def set_source_expressions(self, children):
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map):
        """
        Relabel the alias values of any children. 'change_map' is a dictionary
        mapping old (current) alias values to the new values.
        """
        for pos, child in enumerate(self.children):
            if hasattr(child, "relabel_aliases"):
                # For example another WhereNode
                child.relabel_aliases(change_map)
            elif hasattr(child, "relabeled_clone"):
                self.children[pos] = child.relabeled_clone(change_map)

    def clone(self):
        """Return a copy of this node; children are cloned when they support it."""
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            if hasattr(child, "clone"):
                child = child.clone()
            clone.children.append(child)
        return clone

    def relabeled_clone(self, change_map):
        """Return a clone with aliases remapped through change_map."""
        clone = self.clone()
        clone.relabel_aliases(change_map)
        return clone

    def replace_expressions(self, replacements):
        """Return a copy with expressions substituted per the replacements map."""
        if replacement := replacements.get(self):
            return replacement
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            clone.children.append(child.replace_expressions(replacements))
        return clone

    def get_refs(self):
        """Return the union of all children's references."""
        refs = set()
        for child in self.children:
            refs |= child.get_refs()
        return refs

    @classmethod
    def _contains_aggregate(cls, obj):
        # Recurse through nested tree nodes; leaves expose the attribute.
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self):
        # True if any leaf in this subtree contains an aggregate.
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj):
        # Recurse through nested tree nodes; leaves expose the attribute.
        if isinstance(obj, tree.Node):
            return any(cls._contains_over_clause(c) for c in obj.children)
        return obj.contains_over_clause

    @cached_property
    def contains_over_clause(self):
        # True if any leaf in this subtree contains a window (OVER) clause.
        return self._contains_over_clause(self)

    @property
    def is_summary(self):
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr, query, *args, **kwargs):
        # Resolve a single (non-node) expression if it supports resolution.
        if hasattr(expr, "resolve_expression"):
            expr = expr.resolve_expression(query, *args, **kwargs)
        return expr

    @classmethod
    def _resolve_node(cls, node, query, *args, **kwargs):
        # Resolve in place: recurse into children, then resolve this node's
        # lhs/rhs if present (i.e. when the node is a lookup-like leaf).
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        if hasattr(node, "lhs"):
            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
        if hasattr(node, "rhs"):
            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

    def resolve_expression(self, *args, **kwargs):
        """Return a resolved clone; the original node is left untouched."""
        clone = self.clone()
        clone._resolve_node(clone, *args, **kwargs)
        clone.resolved = True
        return clone

    @cached_property
    def output_field(self):
        # A WHERE condition is boolean-valued by definition.
        from django.db.models import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self):
        return self.output_field

    def select_format(self, compiler, sql, params):
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def get_db_converters(self, connection):
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def leaves(self):
        """Yield the non-node (leaf) conditions of this subtree, depth-first."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child
|  | ||||
|  | ||||
class NothingNode:
    """A node that matches nothing."""

    # Nothing can match, so no aggregates or window functions are involved.
    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, compiler=None, connection=None):
        # Compiling this node always short-circuits the query to no results.
        raise EmptyResultSet
|  | ||||
|  | ||||
class ExtraWhere:
    """One or more raw SQL condition snippets, ANDed together."""

    # The contents are a black box - assume no aggregates or windows are used.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, sqls, params):
        self.sqls = sqls
        self.params = params

    def as_sql(self, compiler=None, connection=None):
        """Join the snippets with AND, each wrapped in parentheses."""
        wrapped = ("(%s)" % sql for sql in self.sqls)
        return " AND ".join(wrapped), list(self.params or ())
|  | ||||
|  | ||||
class SubqueryConstraint:
    # Even if aggregates or windows would be used in a subquery,
    # the outer query isn't interested about those.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, alias, columns, targets, query_object):
        self.alias = alias
        self.columns = columns
        self.targets = targets
        # Ordering is irrelevant inside a subquery condition; drop it up front.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(self, compiler, connection):
        """Compile the stored query as a subquery condition on self.alias."""
        inner = self.query_object
        inner.set_values(self.targets)
        inner_compiler = inner.get_compiler(connection=connection)
        return inner_compiler.as_subquery_condition(
            self.alias, self.columns, compiler
        )
| @ -0,0 +1,69 @@ | ||||
| import functools | ||||
| from collections import namedtuple | ||||
|  | ||||
|  | ||||
| def make_model_tuple(model): | ||||
|     """ | ||||
|     Take a model or a string of the form "app_label.ModelName" and return a | ||||
|     corresponding ("app_label", "modelname") tuple. If a tuple is passed in, | ||||
|     assume it's a valid model tuple already and return it unchanged. | ||||
|     """ | ||||
|     try: | ||||
|         if isinstance(model, tuple): | ||||
|             model_tuple = model | ||||
|         elif isinstance(model, str): | ||||
|             app_label, model_name = model.split(".") | ||||
|             model_tuple = app_label, model_name.lower() | ||||
|         else: | ||||
|             model_tuple = model._meta.app_label, model._meta.model_name | ||||
|         assert len(model_tuple) == 2 | ||||
|         return model_tuple | ||||
|     except (ValueError, AssertionError): | ||||
|         raise ValueError( | ||||
|             "Invalid model reference '%s'. String model references " | ||||
|             "must be of the form 'app_label.ModelName'." % model | ||||
|         ) | ||||
|  | ||||
|  | ||||
def resolve_callables(mapping):
    """
    Generate key/value pairs for the given mapping where the values are
    evaluated if they're callable.
    """
    for key, value in mapping.items():
        if callable(value):
            value = value()
        yield key, value
|  | ||||
|  | ||||
def unpickle_named_row(names, values):
    # Pickle helper: rebuild the cached Row class, then the row instance.
    row_cls = create_namedtuple_class(*names)
    return row_cls(*values)
|  | ||||
|  | ||||
@functools.lru_cache
def create_namedtuple_class(*names):
    """
    Return a namedtuple-derived "Row" class for the given field names.

    Cached with @lru_cache since type() is too slow to be called for every
    QuerySet evaluation.
    """

    def __reduce__(self):
        # Delegate pickling to a module-level helper so Row stays picklable
        # despite being created dynamically.
        return unpickle_named_row, (names, tuple(self))

    attrs = {"__reduce__": __reduce__, "__slots__": ()}
    return type("Row", (namedtuple("Row", names),), attrs)
|  | ||||
|  | ||||
class AltersData:
    """
    Make subclasses preserve the alters_data attribute on overridden methods.
    """

    def __init_subclass__(cls, **kwargs):
        for name, attr in vars(cls).items():
            # Only patch callables that don't already declare alters_data.
            if not callable(attr) or hasattr(attr, "alters_data"):
                continue
            for base in cls.__bases__:
                base_attr = getattr(base, name, None)
                if base_attr:
                    # Inherit the flag from the first base that defines the
                    # attribute (whether or not it carries alters_data).
                    if hasattr(base_attr, "alters_data"):
                        attr.alters_data = base_attr.alters_data
                    break

        super().__init_subclass__(**kwargs)
		Reference in New Issue
	
	Block a user