docker setup
This commit is contained in:
@ -0,0 +1,115 @@
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db.models import signals
|
||||
from django.db.models.aggregates import * # NOQA
|
||||
from django.db.models.aggregates import __all__ as aggregates_all
|
||||
from django.db.models.constraints import * # NOQA
|
||||
from django.db.models.constraints import __all__ as constraints_all
|
||||
from django.db.models.deletion import (
|
||||
CASCADE,
|
||||
DO_NOTHING,
|
||||
PROTECT,
|
||||
RESTRICT,
|
||||
SET,
|
||||
SET_DEFAULT,
|
||||
SET_NULL,
|
||||
ProtectedError,
|
||||
RestrictedError,
|
||||
)
|
||||
from django.db.models.enums import * # NOQA
|
||||
from django.db.models.enums import __all__ as enums_all
|
||||
from django.db.models.expressions import (
|
||||
Case,
|
||||
Exists,
|
||||
Expression,
|
||||
ExpressionList,
|
||||
ExpressionWrapper,
|
||||
F,
|
||||
Func,
|
||||
OrderBy,
|
||||
OuterRef,
|
||||
RowRange,
|
||||
Subquery,
|
||||
Value,
|
||||
ValueRange,
|
||||
When,
|
||||
Window,
|
||||
WindowFrame,
|
||||
)
|
||||
from django.db.models.fields import * # NOQA
|
||||
from django.db.models.fields import __all__ as fields_all
|
||||
from django.db.models.fields.files import FileField, ImageField
|
||||
from django.db.models.fields.json import JSONField
|
||||
from django.db.models.fields.proxy import OrderWrt
|
||||
from django.db.models.indexes import * # NOQA
|
||||
from django.db.models.indexes import __all__ as indexes_all
|
||||
from django.db.models.lookups import Lookup, Transform
|
||||
from django.db.models.manager import Manager
|
||||
from django.db.models.query import Prefetch, QuerySet, prefetch_related_objects
|
||||
from django.db.models.query_utils import FilteredRelation, Q
|
||||
|
||||
# Imports that would create circular imports if sorted
|
||||
from django.db.models.base import DEFERRED, Model # isort:skip
|
||||
from django.db.models.fields.related import ( # isort:skip
|
||||
ForeignKey,
|
||||
ForeignObject,
|
||||
OneToOneField,
|
||||
ManyToManyField,
|
||||
ForeignObjectRel,
|
||||
ManyToOneRel,
|
||||
ManyToManyRel,
|
||||
OneToOneRel,
|
||||
)
|
||||
|
||||
|
||||
__all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all
|
||||
__all__ += [
|
||||
"ObjectDoesNotExist",
|
||||
"signals",
|
||||
"CASCADE",
|
||||
"DO_NOTHING",
|
||||
"PROTECT",
|
||||
"RESTRICT",
|
||||
"SET",
|
||||
"SET_DEFAULT",
|
||||
"SET_NULL",
|
||||
"ProtectedError",
|
||||
"RestrictedError",
|
||||
"Case",
|
||||
"Exists",
|
||||
"Expression",
|
||||
"ExpressionList",
|
||||
"ExpressionWrapper",
|
||||
"F",
|
||||
"Func",
|
||||
"OrderBy",
|
||||
"OuterRef",
|
||||
"RowRange",
|
||||
"Subquery",
|
||||
"Value",
|
||||
"ValueRange",
|
||||
"When",
|
||||
"Window",
|
||||
"WindowFrame",
|
||||
"FileField",
|
||||
"ImageField",
|
||||
"JSONField",
|
||||
"OrderWrt",
|
||||
"Lookup",
|
||||
"Transform",
|
||||
"Manager",
|
||||
"Prefetch",
|
||||
"Q",
|
||||
"QuerySet",
|
||||
"prefetch_related_objects",
|
||||
"DEFERRED",
|
||||
"Model",
|
||||
"FilteredRelation",
|
||||
"ForeignKey",
|
||||
"ForeignObject",
|
||||
"OneToOneField",
|
||||
"ManyToManyField",
|
||||
"ForeignObjectRel",
|
||||
"ManyToOneRel",
|
||||
"ManyToManyRel",
|
||||
"OneToOneRel",
|
||||
]
|
@ -0,0 +1,210 @@
|
||||
"""
|
||||
Classes to represent the definitions of aggregate functions.
|
||||
"""
|
||||
from django.core.exceptions import FieldError, FullResultSet
|
||||
from django.db.models.expressions import Case, Func, Star, Value, When
|
||||
from django.db.models.fields import IntegerField
|
||||
from django.db.models.functions.comparison import Coalesce
|
||||
from django.db.models.functions.mixins import (
|
||||
FixDurationInputMixin,
|
||||
NumericOutputFieldMixin,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Aggregate",
|
||||
"Avg",
|
||||
"Count",
|
||||
"Max",
|
||||
"Min",
|
||||
"StdDev",
|
||||
"Sum",
|
||||
"Variance",
|
||||
]
|
||||
|
||||
|
||||
class Aggregate(Func):
|
||||
template = "%(function)s(%(distinct)s%(expressions)s)"
|
||||
contains_aggregate = True
|
||||
name = None
|
||||
filter_template = "%s FILTER (WHERE %%(filter)s)"
|
||||
window_compatible = True
|
||||
allow_distinct = False
|
||||
empty_result_set_value = None
|
||||
|
||||
def __init__(
|
||||
self, *expressions, distinct=False, filter=None, default=None, **extra
|
||||
):
|
||||
if distinct and not self.allow_distinct:
|
||||
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
|
||||
if default is not None and self.empty_result_set_value is not None:
|
||||
raise TypeError(f"{self.__class__.__name__} does not allow default.")
|
||||
self.distinct = distinct
|
||||
self.filter = filter
|
||||
self.default = default
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def get_source_fields(self):
|
||||
# Don't return the filter expression since it's not a source field.
|
||||
return [e._output_field_or_none for e in super().get_source_expressions()]
|
||||
|
||||
def get_source_expressions(self):
|
||||
source_expressions = super().get_source_expressions()
|
||||
if self.filter:
|
||||
return source_expressions + [self.filter]
|
||||
return source_expressions
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.filter = self.filter and exprs.pop()
|
||||
return super().set_source_expressions(exprs)
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
# Aggregates are not allowed in UPDATE queries, so ignore for_save
|
||||
c = super().resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.filter = c.filter and c.filter.resolve_expression(
|
||||
query, allow_joins, reuse, summarize
|
||||
)
|
||||
if summarize:
|
||||
# Summarized aggregates cannot refer to summarized aggregates.
|
||||
for ref in c.get_refs():
|
||||
if query.annotations[ref].is_summary:
|
||||
raise FieldError(
|
||||
f"Cannot compute {c.name}('{ref}'): '{ref}' is an aggregate"
|
||||
)
|
||||
elif not self.is_summary:
|
||||
# Call Aggregate.get_source_expressions() to avoid
|
||||
# returning self.filter and including that in this loop.
|
||||
expressions = super(Aggregate, c).get_source_expressions()
|
||||
for index, expr in enumerate(expressions):
|
||||
if expr.contains_aggregate:
|
||||
before_resolved = self.get_source_expressions()[index]
|
||||
name = (
|
||||
before_resolved.name
|
||||
if hasattr(before_resolved, "name")
|
||||
else repr(before_resolved)
|
||||
)
|
||||
raise FieldError(
|
||||
"Cannot compute %s('%s'): '%s' is an aggregate"
|
||||
% (c.name, name, name)
|
||||
)
|
||||
if (default := c.default) is None:
|
||||
return c
|
||||
if hasattr(default, "resolve_expression"):
|
||||
default = default.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
if default._output_field_or_none is None:
|
||||
default.output_field = c._output_field_or_none
|
||||
else:
|
||||
default = Value(default, c._output_field_or_none)
|
||||
c.default = None # Reset the default argument before wrapping.
|
||||
coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
|
||||
coalesce.is_summary = c.is_summary
|
||||
return coalesce
|
||||
|
||||
@property
|
||||
def default_alias(self):
|
||||
expressions = self.get_source_expressions()
|
||||
if len(expressions) == 1 and hasattr(expressions[0], "name"):
|
||||
return "%s__%s" % (expressions[0].name, self.name.lower())
|
||||
raise TypeError("Complex expressions require an alias")
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return []
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
|
||||
if self.filter:
|
||||
if connection.features.supports_aggregate_filter_clause:
|
||||
try:
|
||||
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
||||
except FullResultSet:
|
||||
pass
|
||||
else:
|
||||
template = self.filter_template % extra_context.get(
|
||||
"template", self.template
|
||||
)
|
||||
sql, params = super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
filter=filter_sql,
|
||||
**extra_context,
|
||||
)
|
||||
return sql, (*params, *filter_params)
|
||||
else:
|
||||
copy = self.copy()
|
||||
copy.filter = None
|
||||
source_expressions = copy.get_source_expressions()
|
||||
condition = When(self.filter, then=source_expressions[0])
|
||||
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
|
||||
return super(Aggregate, copy).as_sql(
|
||||
compiler, connection, **extra_context
|
||||
)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def _get_repr_options(self):
|
||||
options = super()._get_repr_options()
|
||||
if self.distinct:
|
||||
options["distinct"] = self.distinct
|
||||
if self.filter:
|
||||
options["filter"] = self.filter
|
||||
return options
|
||||
|
||||
|
||||
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
|
||||
function = "AVG"
|
||||
name = "Avg"
|
||||
allow_distinct = True
|
||||
|
||||
|
||||
class Count(Aggregate):
|
||||
function = "COUNT"
|
||||
name = "Count"
|
||||
output_field = IntegerField()
|
||||
allow_distinct = True
|
||||
empty_result_set_value = 0
|
||||
|
||||
def __init__(self, expression, filter=None, **extra):
|
||||
if expression == "*":
|
||||
expression = Star()
|
||||
if isinstance(expression, Star) and filter is not None:
|
||||
raise ValueError("Star cannot be used with filter. Please specify a field.")
|
||||
super().__init__(expression, filter=filter, **extra)
|
||||
|
||||
|
||||
class Max(Aggregate):
|
||||
function = "MAX"
|
||||
name = "Max"
|
||||
|
||||
|
||||
class Min(Aggregate):
|
||||
function = "MIN"
|
||||
name = "Min"
|
||||
|
||||
|
||||
class StdDev(NumericOutputFieldMixin, Aggregate):
|
||||
name = "StdDev"
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def _get_repr_options(self):
|
||||
return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
|
||||
|
||||
|
||||
class Sum(FixDurationInputMixin, Aggregate):
|
||||
function = "SUM"
|
||||
name = "Sum"
|
||||
allow_distinct = True
|
||||
|
||||
|
||||
class Variance(NumericOutputFieldMixin, Aggregate):
|
||||
name = "Variance"
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = "VAR_SAMP" if sample else "VAR_POP"
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def _get_repr_options(self):
|
||||
return {**super()._get_repr_options(), "sample": self.function == "VAR_SAMP"}
|
2531
srcs/.venv/lib/python3.11/site-packages/django/db/models/base.py
Normal file
2531
srcs/.venv/lib/python3.11/site-packages/django/db/models/base.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,12 @@
|
||||
"""
|
||||
Constants used across the ORM in general.
|
||||
"""
|
||||
from enum import Enum
|
||||
|
||||
# Separator used to split filter strings apart.
|
||||
LOOKUP_SEP = "__"
|
||||
|
||||
|
||||
class OnConflict(Enum):
|
||||
IGNORE = "ignore"
|
||||
UPDATE = "update"
|
@ -0,0 +1,371 @@
|
||||
from enum import Enum
|
||||
|
||||
from django.core.exceptions import FieldError, ValidationError
|
||||
from django.db import connections
|
||||
from django.db.models.expressions import Exists, ExpressionList, F, OrderBy
|
||||
from django.db.models.indexes import IndexExpression
|
||||
from django.db.models.lookups import Exact
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.sql.query import Query
|
||||
from django.db.utils import DEFAULT_DB_ALIAS
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
|
||||
|
||||
|
||||
class BaseConstraint:
|
||||
default_violation_error_message = _("Constraint “%(name)s” is violated.")
|
||||
violation_error_message = None
|
||||
|
||||
def __init__(self, name, violation_error_message=None):
|
||||
self.name = name
|
||||
if violation_error_message is not None:
|
||||
self.violation_error_message = violation_error_message
|
||||
else:
|
||||
self.violation_error_message = self.default_violation_error_message
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return False
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
def get_violation_error_message(self):
|
||||
return self.violation_error_message % {"name": self.name}
|
||||
|
||||
def deconstruct(self):
|
||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace("django.db.models.constraints", "django.db.models")
|
||||
kwargs = {"name": self.name}
|
||||
if (
|
||||
self.violation_error_message is not None
|
||||
and self.violation_error_message != self.default_violation_error_message
|
||||
):
|
||||
kwargs["violation_error_message"] = self.violation_error_message
|
||||
return (path, (), kwargs)
|
||||
|
||||
def clone(self):
|
||||
_, args, kwargs = self.deconstruct()
|
||||
return self.__class__(*args, **kwargs)
|
||||
|
||||
|
||||
class CheckConstraint(BaseConstraint):
|
||||
def __init__(self, *, check, name, violation_error_message=None):
|
||||
self.check = check
|
||||
if not getattr(check, "conditional", False):
|
||||
raise TypeError(
|
||||
"CheckConstraint.check must be a Q instance or boolean expression."
|
||||
)
|
||||
super().__init__(name, violation_error_message=violation_error_message)
|
||||
|
||||
def _get_check_sql(self, model, schema_editor):
|
||||
query = Query(model=model, alias_cols=False)
|
||||
where = query.build_where(self.check)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
check = self._get_check_sql(model, schema_editor)
|
||||
return schema_editor._check_sql(self.name, check)
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
check = self._get_check_sql(model, schema_editor)
|
||||
return schema_editor._create_check_sql(model, self.name, check)
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
return schema_editor._delete_check_sql(model, self.name)
|
||||
|
||||
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
||||
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
|
||||
try:
|
||||
if not Q(self.check).check(against, using=using):
|
||||
raise ValidationError(self.get_violation_error_message())
|
||||
except FieldError:
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: check=%s name=%s>" % (
|
||||
self.__class__.__qualname__,
|
||||
self.check,
|
||||
repr(self.name),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, CheckConstraint):
|
||||
return (
|
||||
self.name == other.name
|
||||
and self.check == other.check
|
||||
and self.violation_error_message == other.violation_error_message
|
||||
)
|
||||
return super().__eq__(other)
|
||||
|
||||
def deconstruct(self):
|
||||
path, args, kwargs = super().deconstruct()
|
||||
kwargs["check"] = self.check
|
||||
return path, args, kwargs
|
||||
|
||||
|
||||
class Deferrable(Enum):
|
||||
DEFERRED = "deferred"
|
||||
IMMEDIATE = "immediate"
|
||||
|
||||
# A similar format was proposed for Python 3.10.
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__qualname__}.{self._name_}"
|
||||
|
||||
|
||||
class UniqueConstraint(BaseConstraint):
|
||||
def __init__(
|
||||
self,
|
||||
*expressions,
|
||||
fields=(),
|
||||
name=None,
|
||||
condition=None,
|
||||
deferrable=None,
|
||||
include=None,
|
||||
opclasses=(),
|
||||
violation_error_message=None,
|
||||
):
|
||||
if not name:
|
||||
raise ValueError("A unique constraint must be named.")
|
||||
if not expressions and not fields:
|
||||
raise ValueError(
|
||||
"At least one field or expression is required to define a "
|
||||
"unique constraint."
|
||||
)
|
||||
if expressions and fields:
|
||||
raise ValueError(
|
||||
"UniqueConstraint.fields and expressions are mutually exclusive."
|
||||
)
|
||||
if not isinstance(condition, (type(None), Q)):
|
||||
raise ValueError("UniqueConstraint.condition must be a Q instance.")
|
||||
if condition and deferrable:
|
||||
raise ValueError("UniqueConstraint with conditions cannot be deferred.")
|
||||
if include and deferrable:
|
||||
raise ValueError("UniqueConstraint with include fields cannot be deferred.")
|
||||
if opclasses and deferrable:
|
||||
raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
|
||||
if expressions and deferrable:
|
||||
raise ValueError("UniqueConstraint with expressions cannot be deferred.")
|
||||
if expressions and opclasses:
|
||||
raise ValueError(
|
||||
"UniqueConstraint.opclasses cannot be used with expressions. "
|
||||
"Use django.contrib.postgres.indexes.OpClass() instead."
|
||||
)
|
||||
if not isinstance(deferrable, (type(None), Deferrable)):
|
||||
raise ValueError(
|
||||
"UniqueConstraint.deferrable must be a Deferrable instance."
|
||||
)
|
||||
if not isinstance(include, (type(None), list, tuple)):
|
||||
raise ValueError("UniqueConstraint.include must be a list or tuple.")
|
||||
if not isinstance(opclasses, (list, tuple)):
|
||||
raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
|
||||
if opclasses and len(fields) != len(opclasses):
|
||||
raise ValueError(
|
||||
"UniqueConstraint.fields and UniqueConstraint.opclasses must "
|
||||
"have the same number of elements."
|
||||
)
|
||||
self.fields = tuple(fields)
|
||||
self.condition = condition
|
||||
self.deferrable = deferrable
|
||||
self.include = tuple(include) if include else ()
|
||||
self.opclasses = opclasses
|
||||
self.expressions = tuple(
|
||||
F(expression) if isinstance(expression, str) else expression
|
||||
for expression in expressions
|
||||
)
|
||||
super().__init__(name, violation_error_message=violation_error_message)
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return bool(self.expressions)
|
||||
|
||||
def _get_condition_sql(self, model, schema_editor):
|
||||
if self.condition is None:
|
||||
return None
|
||||
query = Query(model=model, alias_cols=False)
|
||||
where = query.build_where(self.condition)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def _get_index_expressions(self, model, schema_editor):
|
||||
if not self.expressions:
|
||||
return None
|
||||
index_expressions = []
|
||||
for expression in self.expressions:
|
||||
index_expression = IndexExpression(expression)
|
||||
index_expression.set_wrapper_classes(schema_editor.connection)
|
||||
index_expressions.append(index_expression)
|
||||
return ExpressionList(*index_expressions).resolve_expression(
|
||||
Query(model, alias_cols=False),
|
||||
)
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
fields = [model._meta.get_field(field_name) for field_name in self.fields]
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._unique_sql(
|
||||
model,
|
||||
fields,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
)
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
fields = [model._meta.get_field(field_name) for field_name in self.fields]
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._create_unique_sql(
|
||||
model,
|
||||
fields,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
)
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._delete_unique_sql(
|
||||
model,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s:%s%s%s%s%s%s%s>" % (
|
||||
self.__class__.__qualname__,
|
||||
"" if not self.fields else " fields=%s" % repr(self.fields),
|
||||
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
|
||||
" name=%s" % repr(self.name),
|
||||
"" if self.condition is None else " condition=%s" % self.condition,
|
||||
"" if self.deferrable is None else " deferrable=%r" % self.deferrable,
|
||||
"" if not self.include else " include=%s" % repr(self.include),
|
||||
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, UniqueConstraint):
|
||||
return (
|
||||
self.name == other.name
|
||||
and self.fields == other.fields
|
||||
and self.condition == other.condition
|
||||
and self.deferrable == other.deferrable
|
||||
and self.include == other.include
|
||||
and self.opclasses == other.opclasses
|
||||
and self.expressions == other.expressions
|
||||
and self.violation_error_message == other.violation_error_message
|
||||
)
|
||||
return super().__eq__(other)
|
||||
|
||||
def deconstruct(self):
|
||||
path, args, kwargs = super().deconstruct()
|
||||
if self.fields:
|
||||
kwargs["fields"] = self.fields
|
||||
if self.condition:
|
||||
kwargs["condition"] = self.condition
|
||||
if self.deferrable:
|
||||
kwargs["deferrable"] = self.deferrable
|
||||
if self.include:
|
||||
kwargs["include"] = self.include
|
||||
if self.opclasses:
|
||||
kwargs["opclasses"] = self.opclasses
|
||||
return path, self.expressions, kwargs
|
||||
|
||||
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
||||
queryset = model._default_manager.using(using)
|
||||
if self.fields:
|
||||
lookup_kwargs = {}
|
||||
for field_name in self.fields:
|
||||
if exclude and field_name in exclude:
|
||||
return
|
||||
field = model._meta.get_field(field_name)
|
||||
lookup_value = getattr(instance, field.attname)
|
||||
if lookup_value is None or (
|
||||
lookup_value == ""
|
||||
and connections[using].features.interprets_empty_strings_as_nulls
|
||||
):
|
||||
# A composite constraint containing NULL value cannot cause
|
||||
# a violation since NULL != NULL in SQL.
|
||||
return
|
||||
lookup_kwargs[field.name] = lookup_value
|
||||
queryset = queryset.filter(**lookup_kwargs)
|
||||
else:
|
||||
# Ignore constraints with excluded fields.
|
||||
if exclude:
|
||||
for expression in self.expressions:
|
||||
if hasattr(expression, "flatten"):
|
||||
for expr in expression.flatten():
|
||||
if isinstance(expr, F) and expr.name in exclude:
|
||||
return
|
||||
elif isinstance(expression, F) and expression.name in exclude:
|
||||
return
|
||||
replacements = {
|
||||
F(field): value
|
||||
for field, value in instance._get_field_value_map(
|
||||
meta=model._meta, exclude=exclude
|
||||
).items()
|
||||
}
|
||||
expressions = []
|
||||
for expr in self.expressions:
|
||||
# Ignore ordering.
|
||||
if isinstance(expr, OrderBy):
|
||||
expr = expr.expression
|
||||
expressions.append(Exact(expr, expr.replace_expressions(replacements)))
|
||||
queryset = queryset.filter(*expressions)
|
||||
model_class_pk = instance._get_pk_val(model._meta)
|
||||
if not instance._state.adding and model_class_pk is not None:
|
||||
queryset = queryset.exclude(pk=model_class_pk)
|
||||
if not self.condition:
|
||||
if queryset.exists():
|
||||
if self.expressions:
|
||||
raise ValidationError(self.get_violation_error_message())
|
||||
# When fields are defined, use the unique_error_message() for
|
||||
# backward compatibility.
|
||||
for model, constraints in instance.get_constraints():
|
||||
for constraint in constraints:
|
||||
if constraint is self:
|
||||
raise ValidationError(
|
||||
instance.unique_error_message(model, self.fields)
|
||||
)
|
||||
else:
|
||||
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
|
||||
try:
|
||||
if (self.condition & Exists(queryset.filter(self.condition))).check(
|
||||
against, using=using
|
||||
):
|
||||
raise ValidationError(self.get_violation_error_message())
|
||||
except FieldError:
|
||||
pass
|
@ -0,0 +1,522 @@
|
||||
from collections import Counter, defaultdict
|
||||
from functools import partial, reduce
|
||||
from itertools import chain
|
||||
from operator import attrgetter, or_
|
||||
|
||||
from django.db import IntegrityError, connections, models, transaction
|
||||
from django.db.models import query_utils, signals, sql
|
||||
|
||||
|
||||
class ProtectedError(IntegrityError):
|
||||
def __init__(self, msg, protected_objects):
|
||||
self.protected_objects = protected_objects
|
||||
super().__init__(msg, protected_objects)
|
||||
|
||||
|
||||
class RestrictedError(IntegrityError):
|
||||
def __init__(self, msg, restricted_objects):
|
||||
self.restricted_objects = restricted_objects
|
||||
super().__init__(msg, restricted_objects)
|
||||
|
||||
|
||||
def CASCADE(collector, field, sub_objs, using):
|
||||
collector.collect(
|
||||
sub_objs,
|
||||
source=field.remote_field.model,
|
||||
source_attr=field.name,
|
||||
nullable=field.null,
|
||||
fail_on_restricted=False,
|
||||
)
|
||||
if field.null and not connections[using].features.can_defer_constraint_checks:
|
||||
collector.add_field_update(field, None, sub_objs)
|
||||
|
||||
|
||||
def PROTECT(collector, field, sub_objs, using):
|
||||
raise ProtectedError(
|
||||
"Cannot delete some instances of model '%s' because they are "
|
||||
"referenced through a protected foreign key: '%s.%s'"
|
||||
% (
|
||||
field.remote_field.model.__name__,
|
||||
sub_objs[0].__class__.__name__,
|
||||
field.name,
|
||||
),
|
||||
sub_objs,
|
||||
)
|
||||
|
||||
|
||||
def RESTRICT(collector, field, sub_objs, using):
|
||||
collector.add_restricted_objects(field, sub_objs)
|
||||
collector.add_dependency(field.remote_field.model, field.model)
|
||||
|
||||
|
||||
def SET(value):
|
||||
if callable(value):
|
||||
|
||||
def set_on_delete(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, value(), sub_objs)
|
||||
|
||||
else:
|
||||
|
||||
def set_on_delete(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, value, sub_objs)
|
||||
|
||||
set_on_delete.deconstruct = lambda: ("django.db.models.SET", (value,), {})
|
||||
set_on_delete.lazy_sub_objs = True
|
||||
return set_on_delete
|
||||
|
||||
|
||||
def SET_NULL(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, None, sub_objs)
|
||||
|
||||
|
||||
SET_NULL.lazy_sub_objs = True
|
||||
|
||||
|
||||
def SET_DEFAULT(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, field.get_default(), sub_objs)
|
||||
|
||||
|
||||
SET_DEFAULT.lazy_sub_objs = True
|
||||
|
||||
|
||||
def DO_NOTHING(collector, field, sub_objs, using):
|
||||
pass
|
||||
|
||||
|
||||
def get_candidate_relations_to_delete(opts):
|
||||
# The candidate relations are the ones that come from N-1 and 1-1 relations.
|
||||
# N-N (i.e., many-to-many) relations aren't candidates for deletion.
|
||||
return (
|
||||
f
|
||||
for f in opts.get_fields(include_hidden=True)
|
||||
if f.auto_created and not f.concrete and (f.one_to_one or f.one_to_many)
|
||||
)
|
||||
|
||||
|
||||
class Collector:
|
||||
def __init__(self, using, origin=None):
|
||||
self.using = using
|
||||
# A Model or QuerySet object.
|
||||
self.origin = origin
|
||||
# Initially, {model: {instances}}, later values become lists.
|
||||
self.data = defaultdict(set)
|
||||
# {(field, value): [instances, …]}
|
||||
self.field_updates = defaultdict(list)
|
||||
# {model: {field: {instances}}}
|
||||
self.restricted_objects = defaultdict(partial(defaultdict, set))
|
||||
# fast_deletes is a list of queryset-likes that can be deleted without
|
||||
# fetching the objects into memory.
|
||||
self.fast_deletes = []
|
||||
|
||||
# Tracks deletion-order dependency for databases without transactions
|
||||
# or ability to defer constraint checks. Only concrete model classes
|
||||
# should be included, as the dependencies exist only between actual
|
||||
# database tables; proxy models are represented here by their concrete
|
||||
# parent.
|
||||
self.dependencies = defaultdict(set) # {model: {models}}
|
||||
|
||||
def add(self, objs, source=None, nullable=False, reverse_dependency=False):
|
||||
"""
|
||||
Add 'objs' to the collection of objects to be deleted. If the call is
|
||||
the result of a cascade, 'source' should be the model that caused it,
|
||||
and 'nullable' should be set to True if the relation can be null.
|
||||
|
||||
Return a list of all objects that were not already collected.
|
||||
"""
|
||||
if not objs:
|
||||
return []
|
||||
new_objs = []
|
||||
model = objs[0].__class__
|
||||
instances = self.data[model]
|
||||
for obj in objs:
|
||||
if obj not in instances:
|
||||
new_objs.append(obj)
|
||||
instances.update(new_objs)
|
||||
# Nullable relationships can be ignored -- they are nulled out before
|
||||
# deleting, and therefore do not affect the order in which objects have
|
||||
# to be deleted.
|
||||
if source is not None and not nullable:
|
||||
self.add_dependency(source, model, reverse_dependency=reverse_dependency)
|
||||
return new_objs
|
||||
|
||||
def add_dependency(self, model, dependency, reverse_dependency=False):
|
||||
if reverse_dependency:
|
||||
model, dependency = dependency, model
|
||||
self.dependencies[model._meta.concrete_model].add(
|
||||
dependency._meta.concrete_model
|
||||
)
|
||||
self.data.setdefault(dependency, self.data.default_factory())
|
||||
|
||||
def add_field_update(self, field, value, objs):
|
||||
"""
|
||||
Schedule a field update. 'objs' must be a homogeneous iterable
|
||||
collection of model instances (e.g. a QuerySet).
|
||||
"""
|
||||
self.field_updates[field, value].append(objs)
|
||||
|
||||
def add_restricted_objects(self, field, objs):
|
||||
if objs:
|
||||
model = objs[0].__class__
|
||||
self.restricted_objects[model][field].update(objs)
|
||||
|
||||
def clear_restricted_objects_from_set(self, model, objs):
|
||||
if model in self.restricted_objects:
|
||||
self.restricted_objects[model] = {
|
||||
field: items - objs
|
||||
for field, items in self.restricted_objects[model].items()
|
||||
}
|
||||
|
||||
def clear_restricted_objects_from_queryset(self, model, qs):
|
||||
if model in self.restricted_objects:
|
||||
objs = set(
|
||||
qs.filter(
|
||||
pk__in=[
|
||||
obj.pk
|
||||
for objs in self.restricted_objects[model].values()
|
||||
for obj in objs
|
||||
]
|
||||
)
|
||||
)
|
||||
self.clear_restricted_objects_from_set(model, objs)
|
||||
|
||||
def _has_signal_listeners(self, model):
|
||||
return signals.pre_delete.has_listeners(
|
||||
model
|
||||
) or signals.post_delete.has_listeners(model)
|
||||
|
||||
def can_fast_delete(self, objs, from_field=None):
|
||||
"""
|
||||
Determine if the objects in the given queryset-like or single object
|
||||
can be fast-deleted. This can be done if there are no cascades, no
|
||||
parents and no signal listeners for the object class.
|
||||
|
||||
The 'from_field' tells where we are coming from - we need this to
|
||||
determine if the objects are in fact to be deleted. Allow also
|
||||
skipping parent -> child -> parent chain preventing fast delete of
|
||||
the child.
|
||||
"""
|
||||
if from_field and from_field.remote_field.on_delete is not CASCADE:
|
||||
return False
|
||||
if hasattr(objs, "_meta"):
|
||||
model = objs._meta.model
|
||||
elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
|
||||
model = objs.model
|
||||
else:
|
||||
return False
|
||||
if self._has_signal_listeners(model):
|
||||
return False
|
||||
# The use of from_field comes from the need to avoid cascade back to
|
||||
# parent when parent delete is cascading to child.
|
||||
opts = model._meta
|
||||
return (
|
||||
all(
|
||||
link == from_field
|
||||
for link in opts.concrete_model._meta.parents.values()
|
||||
)
|
||||
and
|
||||
# Foreign keys pointing to this model.
|
||||
all(
|
||||
related.field.remote_field.on_delete is DO_NOTHING
|
||||
for related in get_candidate_relations_to_delete(opts)
|
||||
)
|
||||
and (
|
||||
# Something like generic foreign key.
|
||||
not any(
|
||||
hasattr(field, "bulk_related_objects")
|
||||
for field in opts.private_fields
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def get_del_batches(self, objs, fields):
|
||||
"""
|
||||
Return the objs in suitably sized batches for the used connection.
|
||||
"""
|
||||
field_names = [field.name for field in fields]
|
||||
conn_batch_size = max(
|
||||
connections[self.using].ops.bulk_batch_size(field_names, objs), 1
|
||||
)
|
||||
if len(objs) > conn_batch_size:
|
||||
return [
|
||||
objs[i : i + conn_batch_size]
|
||||
for i in range(0, len(objs), conn_batch_size)
|
||||
]
|
||||
else:
|
||||
return [objs]
|
||||
|
||||
def collect(
|
||||
self,
|
||||
objs,
|
||||
source=None,
|
||||
nullable=False,
|
||||
collect_related=True,
|
||||
source_attr=None,
|
||||
reverse_dependency=False,
|
||||
keep_parents=False,
|
||||
fail_on_restricted=True,
|
||||
):
|
||||
"""
|
||||
Add 'objs' to the collection of objects to be deleted as well as all
|
||||
parent instances. 'objs' must be a homogeneous iterable collection of
|
||||
model instances (e.g. a QuerySet). If 'collect_related' is True,
|
||||
related objects will be handled by their respective on_delete handler.
|
||||
|
||||
If the call is the result of a cascade, 'source' should be the model
|
||||
that caused it and 'nullable' should be set to True, if the relation
|
||||
can be null.
|
||||
|
||||
If 'reverse_dependency' is True, 'source' will be deleted before the
|
||||
current model, rather than after. (Needed for cascading to parent
|
||||
models, the one case in which the cascade follows the forwards
|
||||
direction of an FK rather than the reverse direction.)
|
||||
|
||||
If 'keep_parents' is True, data of parent model's will be not deleted.
|
||||
|
||||
If 'fail_on_restricted' is False, error won't be raised even if it's
|
||||
prohibited to delete such objects due to RESTRICT, that defers
|
||||
restricted object checking in recursive calls where the top-level call
|
||||
may need to collect more objects to determine whether restricted ones
|
||||
can be deleted.
|
||||
"""
|
||||
if self.can_fast_delete(objs):
|
||||
self.fast_deletes.append(objs)
|
||||
return
|
||||
new_objs = self.add(
|
||||
objs, source, nullable, reverse_dependency=reverse_dependency
|
||||
)
|
||||
if not new_objs:
|
||||
return
|
||||
|
||||
model = new_objs[0].__class__
|
||||
|
||||
if not keep_parents:
|
||||
# Recursively collect concrete model's parent models, but not their
|
||||
# related objects. These will be found by meta.get_fields()
|
||||
concrete_model = model._meta.concrete_model
|
||||
for ptr in concrete_model._meta.parents.values():
|
||||
if ptr:
|
||||
parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
|
||||
self.collect(
|
||||
parent_objs,
|
||||
source=model,
|
||||
source_attr=ptr.remote_field.related_name,
|
||||
collect_related=False,
|
||||
reverse_dependency=True,
|
||||
fail_on_restricted=False,
|
||||
)
|
||||
if not collect_related:
|
||||
return
|
||||
|
||||
if keep_parents:
|
||||
parents = set(model._meta.get_parent_list())
|
||||
model_fast_deletes = defaultdict(list)
|
||||
protected_objects = defaultdict(list)
|
||||
for related in get_candidate_relations_to_delete(model._meta):
|
||||
# Preserve parent reverse relationships if keep_parents=True.
|
||||
if keep_parents and related.model in parents:
|
||||
continue
|
||||
field = related.field
|
||||
on_delete = field.remote_field.on_delete
|
||||
if on_delete == DO_NOTHING:
|
||||
continue
|
||||
related_model = related.related_model
|
||||
if self.can_fast_delete(related_model, from_field=field):
|
||||
model_fast_deletes[related_model].append(field)
|
||||
continue
|
||||
batches = self.get_del_batches(new_objs, [field])
|
||||
for batch in batches:
|
||||
sub_objs = self.related_objects(related_model, [field], batch)
|
||||
# Non-referenced fields can be deferred if no signal receivers
|
||||
# are connected for the related model as they'll never be
|
||||
# exposed to the user. Skip field deferring when some
|
||||
# relationships are select_related as interactions between both
|
||||
# features are hard to get right. This should only happen in
|
||||
# the rare cases where .related_objects is overridden anyway.
|
||||
if not (
|
||||
sub_objs.query.select_related
|
||||
or self._has_signal_listeners(related_model)
|
||||
):
|
||||
referenced_fields = set(
|
||||
chain.from_iterable(
|
||||
(rf.attname for rf in rel.field.foreign_related_fields)
|
||||
for rel in get_candidate_relations_to_delete(
|
||||
related_model._meta
|
||||
)
|
||||
)
|
||||
)
|
||||
sub_objs = sub_objs.only(*tuple(referenced_fields))
|
||||
if getattr(on_delete, "lazy_sub_objs", False) or sub_objs:
|
||||
try:
|
||||
on_delete(self, field, sub_objs, self.using)
|
||||
except ProtectedError as error:
|
||||
key = "'%s.%s'" % (field.model.__name__, field.name)
|
||||
protected_objects[key] += error.protected_objects
|
||||
if protected_objects:
|
||||
raise ProtectedError(
|
||||
"Cannot delete some instances of model %r because they are "
|
||||
"referenced through protected foreign keys: %s."
|
||||
% (
|
||||
model.__name__,
|
||||
", ".join(protected_objects),
|
||||
),
|
||||
set(chain.from_iterable(protected_objects.values())),
|
||||
)
|
||||
for related_model, related_fields in model_fast_deletes.items():
|
||||
batches = self.get_del_batches(new_objs, related_fields)
|
||||
for batch in batches:
|
||||
sub_objs = self.related_objects(related_model, related_fields, batch)
|
||||
self.fast_deletes.append(sub_objs)
|
||||
for field in model._meta.private_fields:
|
||||
if hasattr(field, "bulk_related_objects"):
|
||||
# It's something like generic foreign key.
|
||||
sub_objs = field.bulk_related_objects(new_objs, self.using)
|
||||
self.collect(
|
||||
sub_objs, source=model, nullable=True, fail_on_restricted=False
|
||||
)
|
||||
|
||||
if fail_on_restricted:
|
||||
# Raise an error if collected restricted objects (RESTRICT) aren't
|
||||
# candidates for deletion also collected via CASCADE.
|
||||
for related_model, instances in self.data.items():
|
||||
self.clear_restricted_objects_from_set(related_model, instances)
|
||||
for qs in self.fast_deletes:
|
||||
self.clear_restricted_objects_from_queryset(qs.model, qs)
|
||||
if self.restricted_objects.values():
|
||||
restricted_objects = defaultdict(list)
|
||||
for related_model, fields in self.restricted_objects.items():
|
||||
for field, objs in fields.items():
|
||||
if objs:
|
||||
key = "'%s.%s'" % (related_model.__name__, field.name)
|
||||
restricted_objects[key] += objs
|
||||
if restricted_objects:
|
||||
raise RestrictedError(
|
||||
"Cannot delete some instances of model %r because "
|
||||
"they are referenced through restricted foreign keys: "
|
||||
"%s."
|
||||
% (
|
||||
model.__name__,
|
||||
", ".join(restricted_objects),
|
||||
),
|
||||
set(chain.from_iterable(restricted_objects.values())),
|
||||
)
|
||||
|
||||
def related_objects(self, related_model, related_fields, objs):
|
||||
"""
|
||||
Get a QuerySet of the related model to objs via related fields.
|
||||
"""
|
||||
predicate = query_utils.Q.create(
|
||||
[(f"{related_field.name}__in", objs) for related_field in related_fields],
|
||||
connector=query_utils.Q.OR,
|
||||
)
|
||||
return related_model._base_manager.using(self.using).filter(predicate)
|
||||
|
||||
def instances_with_model(self):
|
||||
for model, instances in self.data.items():
|
||||
for obj in instances:
|
||||
yield model, obj
|
||||
|
||||
def sort(self):
|
||||
sorted_models = []
|
||||
concrete_models = set()
|
||||
models = list(self.data)
|
||||
while len(sorted_models) < len(models):
|
||||
found = False
|
||||
for model in models:
|
||||
if model in sorted_models:
|
||||
continue
|
||||
dependencies = self.dependencies.get(model._meta.concrete_model)
|
||||
if not (dependencies and dependencies.difference(concrete_models)):
|
||||
sorted_models.append(model)
|
||||
concrete_models.add(model._meta.concrete_model)
|
||||
found = True
|
||||
if not found:
|
||||
return
|
||||
self.data = {model: self.data[model] for model in sorted_models}
|
||||
|
||||
def delete(self):
|
||||
# sort instance collections
|
||||
for model, instances in self.data.items():
|
||||
self.data[model] = sorted(instances, key=attrgetter("pk"))
|
||||
|
||||
# if possible, bring the models in an order suitable for databases that
|
||||
# don't support transactions or cannot defer constraint checks until the
|
||||
# end of a transaction.
|
||||
self.sort()
|
||||
# number of objects deleted for each model label
|
||||
deleted_counter = Counter()
|
||||
|
||||
# Optimize for the case with a single obj and no dependencies
|
||||
if len(self.data) == 1 and len(instances) == 1:
|
||||
instance = list(instances)[0]
|
||||
if self.can_fast_delete(instance):
|
||||
with transaction.mark_for_rollback_on_error(self.using):
|
||||
count = sql.DeleteQuery(model).delete_batch(
|
||||
[instance.pk], self.using
|
||||
)
|
||||
setattr(instance, model._meta.pk.attname, None)
|
||||
return count, {model._meta.label: count}
|
||||
|
||||
with transaction.atomic(using=self.using, savepoint=False):
|
||||
# send pre_delete signals
|
||||
for model, obj in self.instances_with_model():
|
||||
if not model._meta.auto_created:
|
||||
signals.pre_delete.send(
|
||||
sender=model,
|
||||
instance=obj,
|
||||
using=self.using,
|
||||
origin=self.origin,
|
||||
)
|
||||
|
||||
# fast deletes
|
||||
for qs in self.fast_deletes:
|
||||
count = qs._raw_delete(using=self.using)
|
||||
if count:
|
||||
deleted_counter[qs.model._meta.label] += count
|
||||
|
||||
# update fields
|
||||
for (field, value), instances_list in self.field_updates.items():
|
||||
updates = []
|
||||
objs = []
|
||||
for instances in instances_list:
|
||||
if (
|
||||
isinstance(instances, models.QuerySet)
|
||||
and instances._result_cache is None
|
||||
):
|
||||
updates.append(instances)
|
||||
else:
|
||||
objs.extend(instances)
|
||||
if updates:
|
||||
combined_updates = reduce(or_, updates)
|
||||
combined_updates.update(**{field.name: value})
|
||||
if objs:
|
||||
model = objs[0].__class__
|
||||
query = sql.UpdateQuery(model)
|
||||
query.update_batch(
|
||||
list({obj.pk for obj in objs}), {field.name: value}, self.using
|
||||
)
|
||||
|
||||
# reverse instance collections
|
||||
for instances in self.data.values():
|
||||
instances.reverse()
|
||||
|
||||
# delete instances
|
||||
for model, instances in self.data.items():
|
||||
query = sql.DeleteQuery(model)
|
||||
pk_list = [obj.pk for obj in instances]
|
||||
count = query.delete_batch(pk_list, self.using)
|
||||
if count:
|
||||
deleted_counter[model._meta.label] += count
|
||||
|
||||
if not model._meta.auto_created:
|
||||
for obj in instances:
|
||||
signals.post_delete.send(
|
||||
sender=model,
|
||||
instance=obj,
|
||||
using=self.using,
|
||||
origin=self.origin,
|
||||
)
|
||||
|
||||
for model, instances in self.data.items():
|
||||
for instance in instances:
|
||||
setattr(instance, model._meta.pk.attname, None)
|
||||
return sum(deleted_counter.values()), dict(deleted_counter)
|
@ -0,0 +1,92 @@
|
||||
import enum
|
||||
from types import DynamicClassAttribute
|
||||
|
||||
from django.utils.functional import Promise
|
||||
|
||||
__all__ = ["Choices", "IntegerChoices", "TextChoices"]
|
||||
|
||||
|
||||
class ChoicesMeta(enum.EnumMeta):
|
||||
"""A metaclass for creating a enum choices."""
|
||||
|
||||
def __new__(metacls, classname, bases, classdict, **kwds):
|
||||
labels = []
|
||||
for key in classdict._member_names:
|
||||
value = classdict[key]
|
||||
if (
|
||||
isinstance(value, (list, tuple))
|
||||
and len(value) > 1
|
||||
and isinstance(value[-1], (Promise, str))
|
||||
):
|
||||
*value, label = value
|
||||
value = tuple(value)
|
||||
else:
|
||||
label = key.replace("_", " ").title()
|
||||
labels.append(label)
|
||||
# Use dict.__setitem__() to suppress defenses against double
|
||||
# assignment in enum's classdict.
|
||||
dict.__setitem__(classdict, key, value)
|
||||
cls = super().__new__(metacls, classname, bases, classdict, **kwds)
|
||||
for member, label in zip(cls.__members__.values(), labels):
|
||||
member._label_ = label
|
||||
return enum.unique(cls)
|
||||
|
||||
def __contains__(cls, member):
|
||||
if not isinstance(member, enum.Enum):
|
||||
# Allow non-enums to match against member values.
|
||||
return any(x.value == member for x in cls)
|
||||
return super().__contains__(member)
|
||||
|
||||
@property
|
||||
def names(cls):
|
||||
empty = ["__empty__"] if hasattr(cls, "__empty__") else []
|
||||
return empty + [member.name for member in cls]
|
||||
|
||||
@property
|
||||
def choices(cls):
|
||||
empty = [(None, cls.__empty__)] if hasattr(cls, "__empty__") else []
|
||||
return empty + [(member.value, member.label) for member in cls]
|
||||
|
||||
@property
|
||||
def labels(cls):
|
||||
return [label for _, label in cls.choices]
|
||||
|
||||
@property
|
||||
def values(cls):
|
||||
return [value for value, _ in cls.choices]
|
||||
|
||||
|
||||
class Choices(enum.Enum, metaclass=ChoicesMeta):
|
||||
"""Class for creating enumerated choices."""
|
||||
|
||||
@DynamicClassAttribute
|
||||
def label(self):
|
||||
return self._label_
|
||||
|
||||
@property
|
||||
def do_not_call_in_templates(self):
|
||||
return True
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
Use value when cast to str, so that Choices set as model instance
|
||||
attributes are rendered as expected in templates and similar contexts.
|
||||
"""
|
||||
return str(self.value)
|
||||
|
||||
# A similar format was proposed for Python 3.10.
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__qualname__}.{self._name_}"
|
||||
|
||||
|
||||
class IntegerChoices(int, Choices):
|
||||
"""Class for creating enumerated integer choices."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TextChoices(str, Choices):
|
||||
"""Class for creating enumerated string choices."""
|
||||
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,510 @@
|
||||
import datetime
|
||||
import posixpath
|
||||
|
||||
from django import forms
|
||||
from django.core import checks
|
||||
from django.core.files.base import File
|
||||
from django.core.files.images import ImageFile
|
||||
from django.core.files.storage import Storage, default_storage
|
||||
from django.core.files.utils import validate_file_name
|
||||
from django.db.models import signals
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.query_utils import DeferredAttribute
|
||||
from django.db.models.utils import AltersData
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class FieldFile(File, AltersData):
|
||||
def __init__(self, instance, field, name):
|
||||
super().__init__(None, name)
|
||||
self.instance = instance
|
||||
self.field = field
|
||||
self.storage = field.storage
|
||||
self._committed = True
|
||||
|
||||
def __eq__(self, other):
|
||||
# Older code may be expecting FileField values to be simple strings.
|
||||
# By overriding the == operator, it can remain backwards compatibility.
|
||||
if hasattr(other, "name"):
|
||||
return self.name == other.name
|
||||
return self.name == other
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
# The standard File contains most of the necessary properties, but
|
||||
# FieldFiles can be instantiated without a name, so that needs to
|
||||
# be checked for here.
|
||||
|
||||
def _require_file(self):
|
||||
if not self:
|
||||
raise ValueError(
|
||||
"The '%s' attribute has no file associated with it." % self.field.name
|
||||
)
|
||||
|
||||
def _get_file(self):
|
||||
self._require_file()
|
||||
if getattr(self, "_file", None) is None:
|
||||
self._file = self.storage.open(self.name, "rb")
|
||||
return self._file
|
||||
|
||||
def _set_file(self, file):
|
||||
self._file = file
|
||||
|
||||
def _del_file(self):
|
||||
del self._file
|
||||
|
||||
file = property(_get_file, _set_file, _del_file)
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
self._require_file()
|
||||
return self.storage.path(self.name)
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
self._require_file()
|
||||
return self.storage.url(self.name)
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
self._require_file()
|
||||
if not self._committed:
|
||||
return self.file.size
|
||||
return self.storage.size(self.name)
|
||||
|
||||
def open(self, mode="rb"):
|
||||
self._require_file()
|
||||
if getattr(self, "_file", None) is None:
|
||||
self.file = self.storage.open(self.name, mode)
|
||||
else:
|
||||
self.file.open(mode)
|
||||
return self
|
||||
|
||||
# open() doesn't alter the file's contents, but it does reset the pointer
|
||||
open.alters_data = True
|
||||
|
||||
# In addition to the standard File API, FieldFiles have extra methods
|
||||
# to further manipulate the underlying file, as well as update the
|
||||
# associated model instance.
|
||||
|
||||
def save(self, name, content, save=True):
|
||||
name = self.field.generate_filename(self.instance, name)
|
||||
self.name = self.storage.save(name, content, max_length=self.field.max_length)
|
||||
setattr(self.instance, self.field.attname, self.name)
|
||||
self._committed = True
|
||||
|
||||
# Save the object because it has changed, unless save is False
|
||||
if save:
|
||||
self.instance.save()
|
||||
|
||||
save.alters_data = True
|
||||
|
||||
def delete(self, save=True):
|
||||
if not self:
|
||||
return
|
||||
# Only close the file if it's already open, which we know by the
|
||||
# presence of self._file
|
||||
if hasattr(self, "_file"):
|
||||
self.close()
|
||||
del self.file
|
||||
|
||||
self.storage.delete(self.name)
|
||||
|
||||
self.name = None
|
||||
setattr(self.instance, self.field.attname, self.name)
|
||||
self._committed = False
|
||||
|
||||
if save:
|
||||
self.instance.save()
|
||||
|
||||
delete.alters_data = True
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
file = getattr(self, "_file", None)
|
||||
return file is None or file.closed
|
||||
|
||||
def close(self):
|
||||
file = getattr(self, "_file", None)
|
||||
if file is not None:
|
||||
file.close()
|
||||
|
||||
def __getstate__(self):
|
||||
# FieldFile needs access to its associated model field, an instance and
|
||||
# the file's name. Everything else will be restored later, by
|
||||
# FileDescriptor below.
|
||||
return {
|
||||
"name": self.name,
|
||||
"closed": False,
|
||||
"_committed": True,
|
||||
"_file": None,
|
||||
"instance": self.instance,
|
||||
"field": self.field,
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
self.storage = self.field.storage
|
||||
|
||||
|
||||
class FileDescriptor(DeferredAttribute):
|
||||
"""
|
||||
The descriptor for the file attribute on the model instance. Return a
|
||||
FieldFile when accessed so you can write code like::
|
||||
|
||||
>>> from myapp.models import MyModel
|
||||
>>> instance = MyModel.objects.get(pk=1)
|
||||
>>> instance.file.size
|
||||
|
||||
Assign a file object on assignment so you can do::
|
||||
|
||||
>>> with open('/path/to/hello.world') as f:
|
||||
... instance.file = File(f)
|
||||
"""
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is None:
|
||||
return self
|
||||
|
||||
# This is slightly complicated, so worth an explanation.
|
||||
# instance.file needs to ultimately return some instance of `File`,
|
||||
# probably a subclass. Additionally, this returned object needs to have
|
||||
# the FieldFile API so that users can easily do things like
|
||||
# instance.file.path and have that delegated to the file storage engine.
|
||||
# Easy enough if we're strict about assignment in __set__, but if you
|
||||
# peek below you can see that we're not. So depending on the current
|
||||
# value of the field we have to dynamically construct some sort of
|
||||
# "thing" to return.
|
||||
|
||||
# The instance dict contains whatever was originally assigned
|
||||
# in __set__.
|
||||
file = super().__get__(instance, cls)
|
||||
|
||||
# If this value is a string (instance.file = "path/to/file") or None
|
||||
# then we simply wrap it with the appropriate attribute class according
|
||||
# to the file field. [This is FieldFile for FileFields and
|
||||
# ImageFieldFile for ImageFields; it's also conceivable that user
|
||||
# subclasses might also want to subclass the attribute class]. This
|
||||
# object understands how to convert a path to a file, and also how to
|
||||
# handle None.
|
||||
if isinstance(file, str) or file is None:
|
||||
attr = self.field.attr_class(instance, self.field, file)
|
||||
instance.__dict__[self.field.attname] = attr
|
||||
|
||||
# Other types of files may be assigned as well, but they need to have
|
||||
# the FieldFile interface added to them. Thus, we wrap any other type of
|
||||
# File inside a FieldFile (well, the field's attr_class, which is
|
||||
# usually FieldFile).
|
||||
elif isinstance(file, File) and not isinstance(file, FieldFile):
|
||||
file_copy = self.field.attr_class(instance, self.field, file.name)
|
||||
file_copy.file = file
|
||||
file_copy._committed = False
|
||||
instance.__dict__[self.field.attname] = file_copy
|
||||
|
||||
# Finally, because of the (some would say boneheaded) way pickle works,
|
||||
# the underlying FieldFile might not actually itself have an associated
|
||||
# file. So we need to reset the details of the FieldFile in those cases.
|
||||
elif isinstance(file, FieldFile) and not hasattr(file, "field"):
|
||||
file.instance = instance
|
||||
file.field = self.field
|
||||
file.storage = self.field.storage
|
||||
|
||||
# Make sure that the instance is correct.
|
||||
elif isinstance(file, FieldFile) and instance is not file.instance:
|
||||
file.instance = instance
|
||||
|
||||
# That was fun, wasn't it?
|
||||
return instance.__dict__[self.field.attname]
|
||||
|
||||
def __set__(self, instance, value):
|
||||
instance.__dict__[self.field.attname] = value
|
||||
|
||||
|
||||
class FileField(Field):
|
||||
# The class to wrap instance attributes in. Accessing the file object off
|
||||
# the instance will always return an instance of attr_class.
|
||||
attr_class = FieldFile
|
||||
|
||||
# The descriptor to use for accessing the attribute off of the class.
|
||||
descriptor_class = FileDescriptor
|
||||
|
||||
description = _("File")
|
||||
|
||||
def __init__(
|
||||
self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs
|
||||
):
|
||||
self._primary_key_set_explicitly = "primary_key" in kwargs
|
||||
|
||||
self.storage = storage or default_storage
|
||||
if callable(self.storage):
|
||||
# Hold a reference to the callable for deconstruct().
|
||||
self._storage_callable = self.storage
|
||||
self.storage = self.storage()
|
||||
if not isinstance(self.storage, Storage):
|
||||
raise TypeError(
|
||||
"%s.storage must be a subclass/instance of %s.%s"
|
||||
% (
|
||||
self.__class__.__qualname__,
|
||||
Storage.__module__,
|
||||
Storage.__qualname__,
|
||||
)
|
||||
)
|
||||
self.upload_to = upload_to
|
||||
|
||||
kwargs.setdefault("max_length", 100)
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return [
|
||||
*super().check(**kwargs),
|
||||
*self._check_primary_key(),
|
||||
*self._check_upload_to(),
|
||||
]
|
||||
|
||||
def _check_primary_key(self):
|
||||
if self._primary_key_set_explicitly:
|
||||
return [
|
||||
checks.Error(
|
||||
"'primary_key' is not a valid argument for a %s."
|
||||
% self.__class__.__name__,
|
||||
obj=self,
|
||||
id="fields.E201",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _check_upload_to(self):
|
||||
if isinstance(self.upload_to, str) and self.upload_to.startswith("/"):
|
||||
return [
|
||||
checks.Error(
|
||||
"%s's 'upload_to' argument must be a relative path, not an "
|
||||
"absolute path." % self.__class__.__name__,
|
||||
obj=self,
|
||||
id="fields.E202",
|
||||
hint="Remove the leading slash.",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if kwargs.get("max_length") == 100:
|
||||
del kwargs["max_length"]
|
||||
kwargs["upload_to"] = self.upload_to
|
||||
storage = getattr(self, "_storage_callable", self.storage)
|
||||
if storage is not default_storage:
|
||||
kwargs["storage"] = storage
|
||||
return name, path, args, kwargs
|
||||
|
||||
def get_internal_type(self):
|
||||
return "FileField"
|
||||
|
||||
def get_prep_value(self, value):
|
||||
value = super().get_prep_value(value)
|
||||
# Need to convert File objects provided via a form to string for
|
||||
# database insertion.
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
file = super().pre_save(model_instance, add)
|
||||
if file and not file._committed:
|
||||
# Commit the file to storage prior to saving the model
|
||||
file.save(file.name, file.file, save=False)
|
||||
return file
|
||||
|
||||
def contribute_to_class(self, cls, name, **kwargs):
|
||||
super().contribute_to_class(cls, name, **kwargs)
|
||||
setattr(cls, self.attname, self.descriptor_class(self))
|
||||
|
||||
def generate_filename(self, instance, filename):
|
||||
"""
|
||||
Apply (if callable) or prepend (if a string) upload_to to the filename,
|
||||
then delegate further processing of the name to the storage backend.
|
||||
Until the storage layer, all file paths are expected to be Unix style
|
||||
(with forward slashes).
|
||||
"""
|
||||
if callable(self.upload_to):
|
||||
filename = self.upload_to(instance, filename)
|
||||
else:
|
||||
dirname = datetime.datetime.now().strftime(str(self.upload_to))
|
||||
filename = posixpath.join(dirname, filename)
|
||||
filename = validate_file_name(filename, allow_relative_path=True)
|
||||
return self.storage.generate_filename(filename)
|
||||
|
||||
def save_form_data(self, instance, data):
|
||||
# Important: None means "no change", other false value means "clear"
|
||||
# This subtle distinction (rather than a more explicit marker) is
|
||||
# needed because we need to consume values that are also sane for a
|
||||
# regular (non Model-) Form to find in its cleaned_data dictionary.
|
||||
if data is not None:
|
||||
# This value will be converted to str and stored in the
|
||||
# database, so leaving False as-is is not acceptable.
|
||||
setattr(instance, self.name, data or "")
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.FileField,
|
||||
"max_length": self.max_length,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ImageFileDescriptor(FileDescriptor):
|
||||
"""
|
||||
Just like the FileDescriptor, but for ImageFields. The only difference is
|
||||
assigning the width/height to the width_field/height_field, if appropriate.
|
||||
"""
|
||||
|
||||
def __set__(self, instance, value):
|
||||
previous_file = instance.__dict__.get(self.field.attname)
|
||||
super().__set__(instance, value)
|
||||
|
||||
# To prevent recalculating image dimensions when we are instantiating
|
||||
# an object from the database (bug #11084), only update dimensions if
|
||||
# the field had a value before this assignment. Since the default
|
||||
# value for FileField subclasses is an instance of field.attr_class,
|
||||
# previous_file will only be None when we are called from
|
||||
# Model.__init__(). The ImageField.update_dimension_fields method
|
||||
# hooked up to the post_init signal handles the Model.__init__() cases.
|
||||
# Assignment happening outside of Model.__init__() will trigger the
|
||||
# update right here.
|
||||
if previous_file is not None:
|
||||
self.field.update_dimension_fields(instance, force=True)
|
||||
|
||||
|
||||
class ImageFieldFile(ImageFile, FieldFile):
|
||||
def delete(self, save=True):
|
||||
# Clear the image dimensions cache
|
||||
if hasattr(self, "_dimensions_cache"):
|
||||
del self._dimensions_cache
|
||||
super().delete(save)
|
||||
|
||||
|
||||
class ImageField(FileField):
|
||||
attr_class = ImageFieldFile
|
||||
descriptor_class = ImageFileDescriptor
|
||||
description = _("Image")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
width_field=None,
|
||||
height_field=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.width_field, self.height_field = width_field, height_field
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return [
|
||||
*super().check(**kwargs),
|
||||
*self._check_image_library_installed(),
|
||||
]
|
||||
|
||||
def _check_image_library_installed(self):
|
||||
try:
|
||||
from PIL import Image # NOQA
|
||||
except ImportError:
|
||||
return [
|
||||
checks.Error(
|
||||
"Cannot use ImageField because Pillow is not installed.",
|
||||
hint=(
|
||||
"Get Pillow at https://pypi.org/project/Pillow/ "
|
||||
'or run command "python -m pip install Pillow".'
|
||||
),
|
||||
obj=self,
|
||||
id="fields.E210",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.width_field:
|
||||
kwargs["width_field"] = self.width_field
|
||||
if self.height_field:
|
||||
kwargs["height_field"] = self.height_field
|
||||
return name, path, args, kwargs
|
||||
|
||||
def contribute_to_class(self, cls, name, **kwargs):
|
||||
super().contribute_to_class(cls, name, **kwargs)
|
||||
# Attach update_dimension_fields so that dimension fields declared
|
||||
# after their corresponding image field don't stay cleared by
|
||||
# Model.__init__, see bug #11196.
|
||||
# Only run post-initialization dimension update on non-abstract models
|
||||
if not cls._meta.abstract:
|
||||
signals.post_init.connect(self.update_dimension_fields, sender=cls)
|
||||
|
||||
def update_dimension_fields(self, instance, force=False, *args, **kwargs):
|
||||
"""
|
||||
Update field's width and height fields, if defined.
|
||||
|
||||
This method is hooked up to model's post_init signal to update
|
||||
dimensions after instantiating a model instance. However, dimensions
|
||||
won't be updated if the dimensions fields are already populated. This
|
||||
avoids unnecessary recalculation when loading an object from the
|
||||
database.
|
||||
|
||||
Dimensions can be forced to update with force=True, which is how
|
||||
ImageFileDescriptor.__set__ calls this method.
|
||||
"""
|
||||
# Nothing to update if the field doesn't have dimension fields or if
|
||||
# the field is deferred.
|
||||
has_dimension_fields = self.width_field or self.height_field
|
||||
if not has_dimension_fields or self.attname not in instance.__dict__:
|
||||
return
|
||||
|
||||
# getattr will call the ImageFileDescriptor's __get__ method, which
|
||||
# coerces the assigned value into an instance of self.attr_class
|
||||
# (ImageFieldFile in this case).
|
||||
file = getattr(instance, self.attname)
|
||||
|
||||
# Nothing to update if we have no file and not being forced to update.
|
||||
if not file and not force:
|
||||
return
|
||||
|
||||
dimension_fields_filled = not (
|
||||
(self.width_field and not getattr(instance, self.width_field))
|
||||
or (self.height_field and not getattr(instance, self.height_field))
|
||||
)
|
||||
# When both dimension fields have values, we are most likely loading
|
||||
# data from the database or updating an image field that already had
|
||||
# an image stored. In the first case, we don't want to update the
|
||||
# dimension fields because we are already getting their values from the
|
||||
# database. In the second case, we do want to update the dimensions
|
||||
# fields and will skip this return because force will be True since we
|
||||
# were called from ImageFileDescriptor.__set__.
|
||||
if dimension_fields_filled and not force:
|
||||
return
|
||||
|
||||
# file should be an instance of ImageFieldFile or should be None.
|
||||
if file:
|
||||
width = file.width
|
||||
height = file.height
|
||||
else:
|
||||
# No file, so clear dimensions fields.
|
||||
width = None
|
||||
height = None
|
||||
|
||||
# Update the width and height fields.
|
||||
if self.width_field:
|
||||
setattr(instance, self.width_field, width)
|
||||
if self.height_field:
|
||||
setattr(instance, self.height_field, height)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.ImageField,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
@ -0,0 +1,638 @@
|
||||
import json
|
||||
import warnings
|
||||
|
||||
from django import forms
|
||||
from django.core import checks, exceptions
|
||||
from django.db import NotSupportedError, connections, router
|
||||
from django.db.models import expressions, lookups
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.fields import TextField
|
||||
from django.db.models.lookups import (
|
||||
FieldGetDbPrepValueMixin,
|
||||
PostgresOperatorLookup,
|
||||
Transform,
|
||||
)
|
||||
from django.utils.deprecation import RemovedInDjango51Warning
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from . import Field
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
|
||||
__all__ = ["JSONField"]
|
||||
|
||||
|
||||
class JSONField(CheckFieldDefaultMixin, Field):
|
||||
empty_strings_allowed = False
|
||||
description = _("A JSON object")
|
||||
default_error_messages = {
|
||||
"invalid": _("Value must be valid JSON."),
|
||||
}
|
||||
_default_hint = ("dict", "{}")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
encoder=None,
|
||||
decoder=None,
|
||||
**kwargs,
|
||||
):
|
||||
if encoder and not callable(encoder):
|
||||
raise ValueError("The encoder parameter must be a callable object.")
|
||||
if decoder and not callable(decoder):
|
||||
raise ValueError("The decoder parameter must be a callable object.")
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
databases = kwargs.get("databases") or []
|
||||
errors.extend(self._check_supported(databases))
|
||||
return errors
|
||||
|
||||
def _check_supported(self, databases):
|
||||
errors = []
|
||||
for db in databases:
|
||||
if not router.allow_migrate_model(db, self.model):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if (
|
||||
self.model._meta.required_db_vendor
|
||||
and self.model._meta.required_db_vendor != connection.vendor
|
||||
):
|
||||
continue
|
||||
if not (
|
||||
"supports_json_field" in self.model._meta.required_db_features
|
||||
or connection.features.supports_json_field
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"%s does not support JSONFields." % connection.display_name,
|
||||
obj=self.model,
|
||||
id="fields.E180",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.encoder is not None:
|
||||
kwargs["encoder"] = self.encoder
|
||||
if self.decoder is not None:
|
||||
kwargs["decoder"] = self.decoder
|
||||
return name, path, args, kwargs
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
if value is None:
|
||||
return value
|
||||
# Some backends (SQLite at least) extract non-string values in their
|
||||
# SQL datatypes.
|
||||
if isinstance(expression, KeyTransform) and not isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.loads(value, cls=self.decoder)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
def get_internal_type(self):
|
||||
return "JSONField"
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if not prepared:
|
||||
value = self.get_prep_value(value)
|
||||
# RemovedInDjango51Warning: When the deprecation ends, replace with:
|
||||
# if (
|
||||
# isinstance(value, expressions.Value)
|
||||
# and isinstance(value.output_field, JSONField)
|
||||
# ):
|
||||
# value = value.value
|
||||
# elif hasattr(value, "as_sql"): ...
|
||||
if isinstance(value, expressions.Value):
|
||||
if isinstance(value.value, str) and not isinstance(
|
||||
value.output_field, JSONField
|
||||
):
|
||||
try:
|
||||
value = json.loads(value.value, cls=self.decoder)
|
||||
except json.JSONDecodeError:
|
||||
value = value.value
|
||||
else:
|
||||
warnings.warn(
|
||||
"Providing an encoded JSON string via Value() is deprecated. "
|
||||
f"Use Value({value!r}, output_field=JSONField()) instead.",
|
||||
category=RemovedInDjango51Warning,
|
||||
)
|
||||
elif isinstance(value.output_field, JSONField):
|
||||
value = value.value
|
||||
else:
|
||||
return value
|
||||
elif hasattr(value, "as_sql"):
|
||||
return value
|
||||
return connection.ops.adapt_json_value(value, self.encoder)
|
||||
|
||||
def get_db_prep_save(self, value, connection):
|
||||
if value is None:
|
||||
return value
|
||||
return self.get_db_prep_value(value, connection)
|
||||
|
||||
def get_transform(self, name):
|
||||
transform = super().get_transform(name)
|
||||
if transform:
|
||||
return transform
|
||||
return KeyTransformFactory(name)
|
||||
|
||||
def validate(self, value, model_instance):
|
||||
super().validate(value, model_instance)
|
||||
try:
|
||||
json.dumps(value, cls=self.encoder)
|
||||
except TypeError:
|
||||
raise exceptions.ValidationError(
|
||||
self.error_messages["invalid"],
|
||||
code="invalid",
|
||||
params={"value": value},
|
||||
)
|
||||
|
||||
def value_to_string(self, obj):
|
||||
return self.value_from_object(obj)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.JSONField,
|
||||
"encoder": self.encoder,
|
||||
"decoder": self.decoder,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def compile_json_path(key_transforms, include_root=True):
|
||||
path = ["$"] if include_root else []
|
||||
for key_transform in key_transforms:
|
||||
try:
|
||||
num = int(key_transform)
|
||||
except ValueError: # non-integer
|
||||
path.append(".")
|
||||
path.append(json.dumps(key_transform))
|
||||
else:
|
||||
path.append("[%s]" % num)
|
||||
return "".join(path)
|
||||
|
||||
|
||||
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
|
||||
lookup_name = "contains"
|
||||
postgres_operator = "@>"
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contains lookup is not supported on this database backend."
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(lhs_params) + tuple(rhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
|
||||
|
||||
|
||||
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
|
||||
lookup_name = "contained_by"
|
||||
postgres_operator = "<@"
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contained_by lookup is not supported on this database backend."
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(rhs_params) + tuple(lhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
|
||||
|
||||
|
||||
class HasKeyLookup(PostgresOperatorLookup):
|
||||
logical_operator = None
|
||||
|
||||
def compile_json_path_final_key(self, key_transform):
|
||||
# Compile the final key without interpreting ints as array elements.
|
||||
return ".%s" % json.dumps(key_transform)
|
||||
|
||||
def as_sql(self, compiler, connection, template=None):
|
||||
# Process JSON path from the left-hand side.
|
||||
if isinstance(self.lhs, KeyTransform):
|
||||
lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
|
||||
compiler, connection
|
||||
)
|
||||
lhs_json_path = compile_json_path(lhs_key_transforms)
|
||||
else:
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
lhs_json_path = "$"
|
||||
sql = template % lhs
|
||||
# Process JSON path from the right-hand side.
|
||||
rhs = self.rhs
|
||||
rhs_params = []
|
||||
if not isinstance(rhs, (list, tuple)):
|
||||
rhs = [rhs]
|
||||
for key in rhs:
|
||||
if isinstance(key, KeyTransform):
|
||||
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
|
||||
else:
|
||||
rhs_key_transforms = [key]
|
||||
*rhs_key_transforms, final_key = rhs_key_transforms
|
||||
rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
|
||||
rhs_json_path += self.compile_json_path_final_key(final_key)
|
||||
rhs_params.append(lhs_json_path + rhs_json_path)
|
||||
# Add condition for each key.
|
||||
if self.logical_operator:
|
||||
sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
|
||||
return sql, tuple(lhs_params) + tuple(rhs_params)
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
sql, params = self.as_sql(
|
||||
compiler, connection, template="JSON_EXISTS(%s, '%%s')"
|
||||
)
|
||||
# Add paths directly into SQL because path expressions cannot be passed
|
||||
# as bind variables on Oracle.
|
||||
return sql % tuple(params), []
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
*_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
|
||||
for key in rhs_key_transforms[:-1]:
|
||||
self.lhs = KeyTransform(key, self.lhs)
|
||||
self.rhs = rhs_key_transforms[-1]
|
||||
return super().as_postgresql(compiler, connection)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
|
||||
)
|
||||
|
||||
|
||||
class HasKey(HasKeyLookup):
|
||||
lookup_name = "has_key"
|
||||
postgres_operator = "?"
|
||||
prepare_rhs = False
|
||||
|
||||
|
||||
class HasKeys(HasKeyLookup):
|
||||
lookup_name = "has_keys"
|
||||
postgres_operator = "?&"
|
||||
logical_operator = " AND "
|
||||
|
||||
def get_prep_lookup(self):
|
||||
return [str(item) for item in self.rhs]
|
||||
|
||||
|
||||
class HasAnyKeys(HasKeys):
|
||||
lookup_name = "has_any_keys"
|
||||
postgres_operator = "?|"
|
||||
logical_operator = " OR "
|
||||
|
||||
|
||||
class HasKeyOrArrayIndex(HasKey):
|
||||
def compile_json_path_final_key(self, key_transform):
|
||||
return compile_json_path([key_transform], include_root=False)
|
||||
|
||||
|
||||
class CaseInsensitiveMixin:
|
||||
"""
|
||||
Mixin to allow case-insensitive comparison of JSON values on MySQL.
|
||||
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
|
||||
Because utf8mb4_bin is a binary collation, comparison of JSON values is
|
||||
case-sensitive.
|
||||
"""
|
||||
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % lhs, lhs_params
|
||||
return lhs, lhs_params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % rhs, rhs_params
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class JSONExact(lookups.Exact):
|
||||
can_use_none_as_rhs = True
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
# Treat None lookup values as null.
|
||||
if rhs == "%s" and rhs_params == [None]:
|
||||
rhs_params = ["null"]
|
||||
if connection.vendor == "mysql":
|
||||
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
|
||||
rhs %= tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
|
||||
pass
|
||||
|
||||
|
||||
JSONField.register_lookup(DataContains)
|
||||
JSONField.register_lookup(ContainedBy)
|
||||
JSONField.register_lookup(HasKey)
|
||||
JSONField.register_lookup(HasKeys)
|
||||
JSONField.register_lookup(HasAnyKeys)
|
||||
JSONField.register_lookup(JSONExact)
|
||||
JSONField.register_lookup(JSONIContains)
|
||||
|
||||
|
||||
class KeyTransform(Transform):
|
||||
postgres_operator = "->"
|
||||
postgres_nested_operator = "#>"
|
||||
|
||||
def __init__(self, key_name, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.key_name = str(key_name)
|
||||
|
||||
def preprocess_lhs(self, compiler, connection):
|
||||
key_transforms = [self.key_name]
|
||||
previous = self.lhs
|
||||
while isinstance(previous, KeyTransform):
|
||||
key_transforms.insert(0, previous.key_name)
|
||||
previous = previous.lhs
|
||||
lhs, params = compiler.compile(previous)
|
||||
if connection.vendor == "oracle":
|
||||
# Escape string-formatting.
|
||||
key_transforms = [key.replace("%", "%%") for key in key_transforms]
|
||||
return lhs, params, key_transforms
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return (
|
||||
"COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
|
||||
% ((lhs, json_path) * 2)
|
||||
), tuple(params) * 2
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
if len(key_transforms) > 1:
|
||||
sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
|
||||
return sql, tuple(params) + (key_transforms,)
|
||||
try:
|
||||
lookup = int(self.key_name)
|
||||
except ValueError:
|
||||
lookup = self.key_name
|
||||
return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
datatype_values = ",".join(
|
||||
[repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
|
||||
)
|
||||
return (
|
||||
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
|
||||
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
|
||||
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
|
||||
|
||||
|
||||
class KeyTextTransform(KeyTransform):
|
||||
postgres_operator = "->>"
|
||||
postgres_nested_operator = "#>>"
|
||||
output_field = TextField()
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
if connection.mysql_is_mariadb:
|
||||
# MariaDB doesn't support -> and ->> operators (see MDEV-13594).
|
||||
sql, params = super().as_mysql(compiler, connection)
|
||||
return "JSON_UNQUOTE(%s)" % sql, params
|
||||
else:
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)
|
||||
|
||||
@classmethod
|
||||
def from_lookup(cls, lookup):
|
||||
transform, *keys = lookup.split(LOOKUP_SEP)
|
||||
if not keys:
|
||||
raise ValueError("Lookup must contain key or index transforms.")
|
||||
for key in keys:
|
||||
transform = cls(key, transform)
|
||||
return transform
|
||||
|
||||
|
||||
KT = KeyTextTransform.from_lookup
|
||||
|
||||
|
||||
class KeyTransformTextLookupMixin:
|
||||
"""
|
||||
Mixin for combining with a lookup expecting a text lhs from a JSONField
|
||||
key lookup. On PostgreSQL, make use of the ->> operator instead of casting
|
||||
key values to text and performing the lookup on the resulting
|
||||
representation.
|
||||
"""
|
||||
|
||||
def __init__(self, key_transform, *args, **kwargs):
|
||||
if not isinstance(key_transform, KeyTransform):
|
||||
raise TypeError(
|
||||
"Transform should be an instance of KeyTransform in order to "
|
||||
"use this lookup."
|
||||
)
|
||||
key_text_transform = KeyTextTransform(
|
||||
key_transform.key_name,
|
||||
*key_transform.source_expressions,
|
||||
**key_transform.extra,
|
||||
)
|
||||
super().__init__(key_text_transform, *args, **kwargs)
|
||||
|
||||
|
||||
class KeyTransformIsNull(lookups.IsNull):
|
||||
# key__isnull=False is the same as has_key='key'
|
||||
def as_oracle(self, compiler, connection):
|
||||
sql, params = HasKeyOrArrayIndex(
|
||||
self.lhs.lhs,
|
||||
self.lhs.key_name,
|
||||
).as_oracle(compiler, connection)
|
||||
if not self.rhs:
|
||||
return sql, params
|
||||
# Column doesn't have a key or IS NULL.
|
||||
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
|
||||
return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
template = "JSON_TYPE(%s, %%s) IS NULL"
|
||||
if not self.rhs:
|
||||
template = "JSON_TYPE(%s, %%s) IS NOT NULL"
|
||||
return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
)
|
||||
|
||||
|
||||
class KeyTransformIn(lookups.In):
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
sql, params = super().resolve_expression_parameter(
|
||||
compiler,
|
||||
connection,
|
||||
sql,
|
||||
param,
|
||||
)
|
||||
if (
|
||||
not hasattr(param, "as_sql")
|
||||
and not connection.features.has_native_json_field
|
||||
):
|
||||
if connection.vendor == "oracle":
|
||||
value = json.loads(param)
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
if isinstance(value, (list, dict)):
|
||||
sql %= "JSON_QUERY"
|
||||
else:
|
||||
sql %= "JSON_VALUE"
|
||||
elif connection.vendor == "mysql" or (
|
||||
connection.vendor == "sqlite"
|
||||
and params[0] not in connection.ops.jsonfield_datatype_values
|
||||
):
|
||||
sql = "JSON_EXTRACT(%s, '$')"
|
||||
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
|
||||
sql = "JSON_UNQUOTE(%s)" % sql
|
||||
return sql, params
|
||||
|
||||
|
||||
class KeyTransformExact(JSONExact):
|
||||
def process_rhs(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
return super(lookups.Exact, self).process_rhs(compiler, connection)
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == "oracle":
|
||||
func = []
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
for value in rhs_params:
|
||||
value = json.loads(value)
|
||||
if isinstance(value, (list, dict)):
|
||||
func.append(sql % "JSON_QUERY")
|
||||
else:
|
||||
func.append(sql % "JSON_VALUE")
|
||||
rhs %= tuple(func)
|
||||
elif connection.vendor == "sqlite":
|
||||
func = []
|
||||
for value in rhs_params:
|
||||
if value in connection.ops.jsonfield_datatype_values:
|
||||
func.append("%s")
|
||||
else:
|
||||
func.append("JSON_EXTRACT(%s, '$')")
|
||||
rhs %= tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if rhs_params == ["null"]:
|
||||
# Field has key and it's NULL.
|
||||
has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
|
||||
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
|
||||
is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
|
||||
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
|
||||
return (
|
||||
"%s AND %s" % (has_key_sql, is_null_sql),
|
||||
tuple(has_key_params) + tuple(is_null_params),
|
||||
)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class KeyTransformIExact(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIContains(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIStartsWith(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIEndsWith(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIRegex(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformNumericLookupMixin:
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if not connection.features.has_native_json_field:
|
||||
rhs_params = [json.loads(value) for value in rhs_params]
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
KeyTransform.register_lookup(KeyTransformIn)
|
||||
KeyTransform.register_lookup(KeyTransformExact)
|
||||
KeyTransform.register_lookup(KeyTransformIExact)
|
||||
KeyTransform.register_lookup(KeyTransformIsNull)
|
||||
KeyTransform.register_lookup(KeyTransformIContains)
|
||||
KeyTransform.register_lookup(KeyTransformStartsWith)
|
||||
KeyTransform.register_lookup(KeyTransformIStartsWith)
|
||||
KeyTransform.register_lookup(KeyTransformEndsWith)
|
||||
KeyTransform.register_lookup(KeyTransformIEndsWith)
|
||||
KeyTransform.register_lookup(KeyTransformRegex)
|
||||
KeyTransform.register_lookup(KeyTransformIRegex)
|
||||
|
||||
KeyTransform.register_lookup(KeyTransformLt)
|
||||
KeyTransform.register_lookup(KeyTransformLte)
|
||||
KeyTransform.register_lookup(KeyTransformGt)
|
||||
KeyTransform.register_lookup(KeyTransformGte)
|
||||
|
||||
|
||||
class KeyTransformFactory:
|
||||
def __init__(self, key_name):
|
||||
self.key_name = key_name
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return KeyTransform(self.key_name, *args, **kwargs)
|
@ -0,0 +1,59 @@
|
||||
from django.core import checks
|
||||
|
||||
NOT_PROVIDED = object()
|
||||
|
||||
|
||||
class FieldCacheMixin:
|
||||
"""Provide an API for working with the model's fields value cache."""
|
||||
|
||||
def get_cache_name(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cached_value(self, instance, default=NOT_PROVIDED):
|
||||
cache_name = self.get_cache_name()
|
||||
try:
|
||||
return instance._state.fields_cache[cache_name]
|
||||
except KeyError:
|
||||
if default is NOT_PROVIDED:
|
||||
raise
|
||||
return default
|
||||
|
||||
def is_cached(self, instance):
|
||||
return self.get_cache_name() in instance._state.fields_cache
|
||||
|
||||
def set_cached_value(self, instance, value):
|
||||
instance._state.fields_cache[self.get_cache_name()] = value
|
||||
|
||||
def delete_cached_value(self, instance):
|
||||
del instance._state.fields_cache[self.get_cache_name()]
|
||||
|
||||
|
||||
class CheckFieldDefaultMixin:
|
||||
_default_hint = ("<valid default>", "<invalid default>")
|
||||
|
||||
def _check_default(self):
|
||||
if (
|
||||
self.has_default()
|
||||
and self.default is not None
|
||||
and not callable(self.default)
|
||||
):
|
||||
return [
|
||||
checks.Warning(
|
||||
"%s default should be a callable instead of an instance "
|
||||
"so that it's not shared between all field instances."
|
||||
% (self.__class__.__name__,),
|
||||
hint=(
|
||||
"Use a callable instead, e.g., use `%s` instead of "
|
||||
"`%s`." % self._default_hint
|
||||
),
|
||||
obj=self,
|
||||
id="fields.E010",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
errors.extend(self._check_default())
|
||||
return errors
|
@ -0,0 +1,18 @@
|
||||
"""
|
||||
Field-like classes that aren't really fields. It's easier to use objects that
|
||||
have the same attributes as fields sometimes (avoids a lot of special casing).
|
||||
"""
|
||||
|
||||
from django.db.models import fields
|
||||
|
||||
|
||||
class OrderWrt(fields.IntegerField):
|
||||
"""
|
||||
A proxy for the _order database field that is used when
|
||||
Meta.order_with_respect_to is specified.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["name"] = "_order"
|
||||
kwargs["editable"] = False
|
||||
super().__init__(*args, **kwargs)
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,209 @@
|
||||
import warnings
|
||||
|
||||
from django.db.models.lookups import (
|
||||
Exact,
|
||||
GreaterThan,
|
||||
GreaterThanOrEqual,
|
||||
In,
|
||||
IsNull,
|
||||
LessThan,
|
||||
LessThanOrEqual,
|
||||
)
|
||||
from django.utils.deprecation import RemovedInDjango50Warning
|
||||
|
||||
|
||||
class MultiColSource:
|
||||
contains_aggregate = False
|
||||
contains_over_clause = False
|
||||
|
||||
def __init__(self, alias, targets, sources, field):
|
||||
self.targets, self.sources, self.field, self.alias = (
|
||||
targets,
|
||||
sources,
|
||||
field,
|
||||
alias,
|
||||
)
|
||||
self.output_field = self.field
|
||||
|
||||
def __repr__(self):
|
||||
return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field)
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self.__class__(
|
||||
relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
|
||||
)
|
||||
|
||||
def get_lookup(self, lookup):
|
||||
return self.output_field.get_lookup(lookup)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
|
||||
def get_normalized_value(value, lhs):
|
||||
from django.db.models import Model
|
||||
|
||||
if isinstance(value, Model):
|
||||
if value.pk is None:
|
||||
# When the deprecation ends, replace with:
|
||||
# raise ValueError(
|
||||
# "Model instances passed to related filters must be saved."
|
||||
# )
|
||||
warnings.warn(
|
||||
"Passing unsaved model instances to related filters is deprecated.",
|
||||
RemovedInDjango50Warning,
|
||||
)
|
||||
value_list = []
|
||||
sources = lhs.output_field.path_infos[-1].target_fields
|
||||
for source in sources:
|
||||
while not isinstance(value, source.model) and source.remote_field:
|
||||
source = source.remote_field.model._meta.get_field(
|
||||
source.remote_field.field_name
|
||||
)
|
||||
try:
|
||||
value_list.append(getattr(value, source.attname))
|
||||
except AttributeError:
|
||||
# A case like Restaurant.objects.filter(place=restaurant_instance),
|
||||
# where place is a OneToOneField and the primary key of Restaurant.
|
||||
return (value.pk,)
|
||||
return tuple(value_list)
|
||||
if not isinstance(value, tuple):
|
||||
return (value,)
|
||||
return value
|
||||
|
||||
|
||||
class RelatedIn(In):
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.lhs, MultiColSource):
|
||||
if self.rhs_is_direct_value():
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
|
||||
# We need to run the related field's get_prep_value(). Consider
|
||||
# case ForeignKey to IntegerField given value 'abc'. The
|
||||
# ForeignKey itself doesn't have validation for non-integers,
|
||||
# so we must run validation using the target field.
|
||||
if hasattr(self.lhs.output_field, "path_infos"):
|
||||
# Run the target field's get_prep_value. We can safely
|
||||
# assume there is only one as we don't get to the direct
|
||||
# value branch otherwise.
|
||||
target_field = self.lhs.output_field.path_infos[-1].target_fields[
|
||||
-1
|
||||
]
|
||||
self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
|
||||
elif not getattr(self.rhs, "has_select_fields", True) and not getattr(
|
||||
self.lhs.field.target_field, "primary_key", False
|
||||
):
|
||||
if (
|
||||
getattr(self.lhs.output_field, "primary_key", False)
|
||||
and self.lhs.output_field.model == self.rhs.model
|
||||
):
|
||||
# A case like
|
||||
# Restaurant.objects.filter(place__in=restaurant_qs), where
|
||||
# place is a OneToOneField and the primary key of
|
||||
# Restaurant.
|
||||
target_field = self.lhs.field.name
|
||||
else:
|
||||
target_field = self.lhs.field.target_field.name
|
||||
self.rhs.set_values([target_field])
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if isinstance(self.lhs, MultiColSource):
|
||||
# For multicolumn lookups we need to build a multicolumn where clause.
|
||||
# This clause is either a SubqueryConstraint (for values that need
|
||||
# to be compiled to SQL) or an OR-combined list of
|
||||
# (col1 = val1 AND col2 = val2 AND ...) clauses.
|
||||
from django.db.models.sql.where import (
|
||||
AND,
|
||||
OR,
|
||||
SubqueryConstraint,
|
||||
WhereNode,
|
||||
)
|
||||
|
||||
root_constraint = WhereNode(connector=OR)
|
||||
if self.rhs_is_direct_value():
|
||||
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
|
||||
for value in values:
|
||||
value_constraint = WhereNode()
|
||||
for source, target, val in zip(
|
||||
self.lhs.sources, self.lhs.targets, value
|
||||
):
|
||||
lookup_class = target.get_lookup("exact")
|
||||
lookup = lookup_class(
|
||||
target.get_col(self.lhs.alias, source), val
|
||||
)
|
||||
value_constraint.add(lookup, AND)
|
||||
root_constraint.add(value_constraint, OR)
|
||||
else:
|
||||
root_constraint.add(
|
||||
SubqueryConstraint(
|
||||
self.lhs.alias,
|
||||
[target.column for target in self.lhs.targets],
|
||||
[source.name for source in self.lhs.sources],
|
||||
self.rhs,
|
||||
),
|
||||
AND,
|
||||
)
|
||||
return root_constraint.as_sql(compiler, connection)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class RelatedLookupMixin:
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.lhs, MultiColSource) and not hasattr(
|
||||
self.rhs, "resolve_expression"
|
||||
):
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
|
||||
# We need to run the related field's get_prep_value(). Consider case
|
||||
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
|
||||
# doesn't have validation for non-integers, so we must run validation
|
||||
# using the target field.
|
||||
if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
|
||||
# Get the target field. We can safely assume there is only one
|
||||
# as we don't get to the direct value branch otherwise.
|
||||
target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
|
||||
self.rhs = target_field.get_prep_value(self.rhs)
|
||||
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if isinstance(self.lhs, MultiColSource):
|
||||
assert self.rhs_is_direct_value()
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)
|
||||
from django.db.models.sql.where import AND, WhereNode
|
||||
|
||||
root_constraint = WhereNode()
|
||||
for target, source, val in zip(
|
||||
self.lhs.targets, self.lhs.sources, self.rhs
|
||||
):
|
||||
lookup_class = target.get_lookup(self.lookup_name)
|
||||
root_constraint.add(
|
||||
lookup_class(target.get_col(self.lhs.alias, source), val), AND
|
||||
)
|
||||
return root_constraint.as_sql(compiler, connection)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class RelatedExact(RelatedLookupMixin, Exact):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedLessThan(RelatedLookupMixin, LessThan):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedIsNull(RelatedLookupMixin, IsNull):
|
||||
pass
|
@ -0,0 +1,402 @@
|
||||
"""
|
||||
"Rel objects" for related fields.
|
||||
|
||||
"Rel objects" (for lack of a better name) carry information about the relation
|
||||
modeled by a related field and provide some utility functions. They're stored
|
||||
in the ``remote_field`` attribute of the field.
|
||||
|
||||
They also act as reverse fields for the purposes of the Meta API because
|
||||
they're the closest concept currently available.
|
||||
"""
|
||||
|
||||
from django.core import exceptions
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
|
||||
from . import BLANK_CHOICE_DASH
|
||||
from .mixins import FieldCacheMixin
|
||||
|
||||
|
||||
class ForeignObjectRel(FieldCacheMixin):
|
||||
"""
|
||||
Used by ForeignObject to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
# Field flags
|
||||
auto_created = True
|
||||
concrete = False
|
||||
editable = False
|
||||
is_relation = True
|
||||
|
||||
# Reverse relations are always nullable (Django can't enforce that a
|
||||
# foreign key on the related model points to this model).
|
||||
null = True
|
||||
empty_strings_allowed = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
self.field = field
|
||||
self.model = to
|
||||
self.related_name = related_name
|
||||
self.related_query_name = related_query_name
|
||||
self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
|
||||
self.parent_link = parent_link
|
||||
self.on_delete = on_delete
|
||||
|
||||
self.symmetrical = False
|
||||
self.multiple = True
|
||||
|
||||
# Some of the following cached_properties can't be initialized in
|
||||
# __init__ as the field doesn't have its model yet. Calling these methods
|
||||
# before field.contribute_to_class() has been called will result in
|
||||
# AttributeError
|
||||
@cached_property
|
||||
def hidden(self):
|
||||
return self.is_hidden()
|
||||
|
||||
@cached_property
|
||||
def name(self):
|
||||
return self.field.related_query_name()
|
||||
|
||||
@property
|
||||
def remote_field(self):
|
||||
return self.field
|
||||
|
||||
@property
|
||||
def target_field(self):
|
||||
"""
|
||||
When filtering against this relation, return the field on the remote
|
||||
model against which the filtering should happen.
|
||||
"""
|
||||
target_fields = self.path_infos[-1].target_fields
|
||||
if len(target_fields) > 1:
|
||||
raise exceptions.FieldError(
|
||||
"Can't use target_field for multicolumn relations."
|
||||
)
|
||||
return target_fields[0]
|
||||
|
||||
@cached_property
|
||||
def related_model(self):
|
||||
if not self.field.model:
|
||||
raise AttributeError(
|
||||
"This property can't be accessed before self.field.contribute_to_class "
|
||||
"has been called."
|
||||
)
|
||||
return self.field.model
|
||||
|
||||
@cached_property
|
||||
def many_to_many(self):
|
||||
return self.field.many_to_many
|
||||
|
||||
@cached_property
|
||||
def many_to_one(self):
|
||||
return self.field.one_to_many
|
||||
|
||||
@cached_property
|
||||
def one_to_many(self):
|
||||
return self.field.many_to_one
|
||||
|
||||
@cached_property
|
||||
def one_to_one(self):
|
||||
return self.field.one_to_one
|
||||
|
||||
def get_lookup(self, lookup_name):
|
||||
return self.field.get_lookup(lookup_name)
|
||||
|
||||
def get_lookups(self):
|
||||
return self.field.get_lookups()
|
||||
|
||||
def get_transform(self, name):
|
||||
return self.field.get_transform(name)
|
||||
|
||||
def get_internal_type(self):
|
||||
return self.field.get_internal_type()
|
||||
|
||||
@property
|
||||
def db_type(self):
|
||||
return self.field.db_type
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: %s.%s>" % (
|
||||
type(self).__name__,
|
||||
self.related_model._meta.app_label,
|
||||
self.related_model._meta.model_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return (
|
||||
self.field,
|
||||
self.model,
|
||||
self.related_name,
|
||||
self.related_query_name,
|
||||
make_hashable(self.limit_choices_to),
|
||||
self.parent_link,
|
||||
self.on_delete,
|
||||
self.symmetrical,
|
||||
self.multiple,
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
# Delete the path_infos cached property because it can be recalculated
|
||||
# at first invocation after deserialization. The attribute must be
|
||||
# removed because subclasses like ManyToOneRel may have a PathInfo
|
||||
# which contains an intermediate M2M table that's been dynamically
|
||||
# created and doesn't exist in the .models module.
|
||||
# This is a reverse relation, so there is no reverse_path_infos to
|
||||
# delete.
|
||||
state.pop("path_infos", None)
|
||||
return state
|
||||
|
||||
def get_choices(
|
||||
self,
|
||||
include_blank=True,
|
||||
blank_choice=BLANK_CHOICE_DASH,
|
||||
limit_choices_to=None,
|
||||
ordering=(),
|
||||
):
|
||||
"""
|
||||
Return choices with a default blank choices included, for use
|
||||
as <select> choices for this field.
|
||||
|
||||
Analog of django.db.models.fields.Field.get_choices(), provided
|
||||
initially for utilization by RelatedFieldListFilter.
|
||||
"""
|
||||
limit_choices_to = limit_choices_to or self.limit_choices_to
|
||||
qs = self.related_model._default_manager.complex_filter(limit_choices_to)
|
||||
if ordering:
|
||||
qs = qs.order_by(*ordering)
|
||||
return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs]
|
||||
|
||||
def is_hidden(self):
|
||||
"""Should the related object be hidden?"""
|
||||
return bool(self.related_name) and self.related_name[-1] == "+"
|
||||
|
||||
def get_joining_columns(self):
|
||||
return self.field.get_reverse_joining_columns()
|
||||
|
||||
def get_extra_restriction(self, alias, related_alias):
|
||||
return self.field.get_extra_restriction(related_alias, alias)
|
||||
|
||||
def set_field_name(self):
|
||||
"""
|
||||
Set the related field's name, this is not available until later stages
|
||||
of app loading, so set_field_name is called from
|
||||
set_attributes_from_rel()
|
||||
"""
|
||||
# By default foreign object doesn't relate to any remote field (for
|
||||
# example custom multicolumn joins currently have no remote field).
|
||||
self.field_name = None
|
||||
|
||||
def get_accessor_name(self, model=None):
|
||||
# This method encapsulates the logic that decides what name to give an
|
||||
# accessor descriptor that retrieves related many-to-one or
|
||||
# many-to-many objects. It uses the lowercased object_name + "_set",
|
||||
# but this can be overridden with the "related_name" option. Due to
|
||||
# backwards compatibility ModelForms need to be able to provide an
|
||||
# alternate model. See BaseInlineFormSet.get_default_prefix().
|
||||
opts = model._meta if model else self.related_model._meta
|
||||
model = model or self.related_model
|
||||
if self.multiple:
|
||||
# If this is a symmetrical m2m relation on self, there is no
|
||||
# reverse accessor.
|
||||
if self.symmetrical and model == self.model:
|
||||
return None
|
||||
if self.related_name:
|
||||
return self.related_name
|
||||
return opts.model_name + ("_set" if self.multiple else "")
|
||||
|
||||
def get_path_info(self, filtered_relation=None):
|
||||
if filtered_relation:
|
||||
return self.field.get_reverse_path_info(filtered_relation)
|
||||
else:
|
||||
return self.field.reverse_path_infos
|
||||
|
||||
@cached_property
|
||||
def path_infos(self):
|
||||
return self.get_path_info()
|
||||
|
||||
def get_cache_name(self):
|
||||
"""
|
||||
Return the name of the cache key to use for storing an instance of the
|
||||
forward model on the reverse model.
|
||||
"""
|
||||
return self.get_accessor_name()
|
||||
|
||||
|
||||
class ManyToOneRel(ForeignObjectRel):
|
||||
"""
|
||||
Used by the ForeignKey field to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
|
||||
Note: Because we somewhat abuse the Rel objects by using them as reverse
|
||||
fields we get the funny situation where
|
||||
``ManyToOneRel.many_to_one == False`` and
|
||||
``ManyToOneRel.one_to_many == True``. This is unfortunate but the actual
|
||||
ManyToOneRel class is a private API and there is work underway to turn
|
||||
reverse relations into actual fields.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
parent_link=parent_link,
|
||||
on_delete=on_delete,
|
||||
)
|
||||
|
||||
self.field_name = field_name
|
||||
|
||||
def __getstate__(self):
|
||||
state = super().__getstate__()
|
||||
state.pop("related_model", None)
|
||||
return state
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return super().identity + (self.field_name,)
|
||||
|
||||
def get_related_field(self):
|
||||
"""
|
||||
Return the Field in the 'to' object to which this relationship is tied.
|
||||
"""
|
||||
field = self.model._meta.get_field(self.field_name)
|
||||
if not field.concrete:
|
||||
raise exceptions.FieldDoesNotExist(
|
||||
"No related field named '%s'" % self.field_name
|
||||
)
|
||||
return field
|
||||
|
||||
def set_field_name(self):
|
||||
self.field_name = self.field_name or self.model._meta.pk.name
|
||||
|
||||
|
||||
class OneToOneRel(ManyToOneRel):
|
||||
"""
|
||||
Used by OneToOneField to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
parent_link=parent_link,
|
||||
on_delete=on_delete,
|
||||
)
|
||||
|
||||
self.multiple = False
|
||||
|
||||
|
||||
class ManyToManyRel(ForeignObjectRel):
|
||||
"""
|
||||
Used by ManyToManyField to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
symmetrical=True,
|
||||
through=None,
|
||||
through_fields=None,
|
||||
db_constraint=True,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
)
|
||||
|
||||
if through and not db_constraint:
|
||||
raise ValueError("Can't supply a through model and db_constraint=False")
|
||||
self.through = through
|
||||
|
||||
if through_fields and not through:
|
||||
raise ValueError("Cannot specify through_fields without a through model")
|
||||
self.through_fields = through_fields
|
||||
|
||||
self.symmetrical = symmetrical
|
||||
self.db_constraint = db_constraint
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return super().identity + (
|
||||
self.through,
|
||||
make_hashable(self.through_fields),
|
||||
self.db_constraint,
|
||||
)
|
||||
|
||||
def get_related_field(self):
|
||||
"""
|
||||
Return the field in the 'to' object to which this relationship is tied.
|
||||
Provided for symmetry with ManyToOneRel.
|
||||
"""
|
||||
opts = self.through._meta
|
||||
if self.through_fields:
|
||||
field = opts.get_field(self.through_fields[0])
|
||||
else:
|
||||
for field in opts.fields:
|
||||
rel = getattr(field, "remote_field", None)
|
||||
if rel and rel.model == self.model:
|
||||
break
|
||||
return field.foreign_related_fields[0]
|
@ -0,0 +1,190 @@
|
||||
from .comparison import Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf
|
||||
from .datetime import (
|
||||
Extract,
|
||||
ExtractDay,
|
||||
ExtractHour,
|
||||
ExtractIsoWeekDay,
|
||||
ExtractIsoYear,
|
||||
ExtractMinute,
|
||||
ExtractMonth,
|
||||
ExtractQuarter,
|
||||
ExtractSecond,
|
||||
ExtractWeek,
|
||||
ExtractWeekDay,
|
||||
ExtractYear,
|
||||
Now,
|
||||
Trunc,
|
||||
TruncDate,
|
||||
TruncDay,
|
||||
TruncHour,
|
||||
TruncMinute,
|
||||
TruncMonth,
|
||||
TruncQuarter,
|
||||
TruncSecond,
|
||||
TruncTime,
|
||||
TruncWeek,
|
||||
TruncYear,
|
||||
)
|
||||
from .math import (
|
||||
Abs,
|
||||
ACos,
|
||||
ASin,
|
||||
ATan,
|
||||
ATan2,
|
||||
Ceil,
|
||||
Cos,
|
||||
Cot,
|
||||
Degrees,
|
||||
Exp,
|
||||
Floor,
|
||||
Ln,
|
||||
Log,
|
||||
Mod,
|
||||
Pi,
|
||||
Power,
|
||||
Radians,
|
||||
Random,
|
||||
Round,
|
||||
Sign,
|
||||
Sin,
|
||||
Sqrt,
|
||||
Tan,
|
||||
)
|
||||
from .text import (
|
||||
MD5,
|
||||
SHA1,
|
||||
SHA224,
|
||||
SHA256,
|
||||
SHA384,
|
||||
SHA512,
|
||||
Chr,
|
||||
Concat,
|
||||
ConcatPair,
|
||||
Left,
|
||||
Length,
|
||||
Lower,
|
||||
LPad,
|
||||
LTrim,
|
||||
Ord,
|
||||
Repeat,
|
||||
Replace,
|
||||
Reverse,
|
||||
Right,
|
||||
RPad,
|
||||
RTrim,
|
||||
StrIndex,
|
||||
Substr,
|
||||
Trim,
|
||||
Upper,
|
||||
)
|
||||
from .window import (
|
||||
CumeDist,
|
||||
DenseRank,
|
||||
FirstValue,
|
||||
Lag,
|
||||
LastValue,
|
||||
Lead,
|
||||
NthValue,
|
||||
Ntile,
|
||||
PercentRank,
|
||||
Rank,
|
||||
RowNumber,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# comparison and conversion
|
||||
"Cast",
|
||||
"Coalesce",
|
||||
"Collate",
|
||||
"Greatest",
|
||||
"JSONObject",
|
||||
"Least",
|
||||
"NullIf",
|
||||
# datetime
|
||||
"Extract",
|
||||
"ExtractDay",
|
||||
"ExtractHour",
|
||||
"ExtractMinute",
|
||||
"ExtractMonth",
|
||||
"ExtractQuarter",
|
||||
"ExtractSecond",
|
||||
"ExtractWeek",
|
||||
"ExtractIsoWeekDay",
|
||||
"ExtractWeekDay",
|
||||
"ExtractIsoYear",
|
||||
"ExtractYear",
|
||||
"Now",
|
||||
"Trunc",
|
||||
"TruncDate",
|
||||
"TruncDay",
|
||||
"TruncHour",
|
||||
"TruncMinute",
|
||||
"TruncMonth",
|
||||
"TruncQuarter",
|
||||
"TruncSecond",
|
||||
"TruncTime",
|
||||
"TruncWeek",
|
||||
"TruncYear",
|
||||
# math
|
||||
"Abs",
|
||||
"ACos",
|
||||
"ASin",
|
||||
"ATan",
|
||||
"ATan2",
|
||||
"Ceil",
|
||||
"Cos",
|
||||
"Cot",
|
||||
"Degrees",
|
||||
"Exp",
|
||||
"Floor",
|
||||
"Ln",
|
||||
"Log",
|
||||
"Mod",
|
||||
"Pi",
|
||||
"Power",
|
||||
"Radians",
|
||||
"Random",
|
||||
"Round",
|
||||
"Sign",
|
||||
"Sin",
|
||||
"Sqrt",
|
||||
"Tan",
|
||||
# text
|
||||
"MD5",
|
||||
"SHA1",
|
||||
"SHA224",
|
||||
"SHA256",
|
||||
"SHA384",
|
||||
"SHA512",
|
||||
"Chr",
|
||||
"Concat",
|
||||
"ConcatPair",
|
||||
"Left",
|
||||
"Length",
|
||||
"Lower",
|
||||
"LPad",
|
||||
"LTrim",
|
||||
"Ord",
|
||||
"Repeat",
|
||||
"Replace",
|
||||
"Reverse",
|
||||
"Right",
|
||||
"RPad",
|
||||
"RTrim",
|
||||
"StrIndex",
|
||||
"Substr",
|
||||
"Trim",
|
||||
"Upper",
|
||||
# window
|
||||
"CumeDist",
|
||||
"DenseRank",
|
||||
"FirstValue",
|
||||
"Lag",
|
||||
"LastValue",
|
||||
"Lead",
|
||||
"NthValue",
|
||||
"Ntile",
|
||||
"PercentRank",
|
||||
"Rank",
|
||||
"RowNumber",
|
||||
]
|
@ -0,0 +1,220 @@
|
||||
"""Database functions that do comparisons or type conversions."""
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import TextField
|
||||
from django.db.models.fields.json import JSONField
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
|
||||
class Cast(Func):
|
||||
"""Coerce an expression to a new field type."""
|
||||
|
||||
function = "CAST"
|
||||
template = "%(function)s(%(expressions)s AS %(db_type)s)"
|
||||
|
||||
def __init__(self, expression, output_field):
|
||||
super().__init__(expression, output_field=output_field)
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context["db_type"] = self.output_field.cast_db_type(connection)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
db_type = self.output_field.db_type(connection)
|
||||
if db_type in {"datetime", "time"}:
|
||||
# Use strftime as datetime/time don't keep fractional seconds.
|
||||
template = "strftime(%%s, %(expressions)s)"
|
||||
sql, params = super().as_sql(
|
||||
compiler, connection, template=template, **extra_context
|
||||
)
|
||||
format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
|
||||
params.insert(0, format_string)
|
||||
return sql, params
|
||||
elif db_type == "date":
|
||||
template = "date(%(expressions)s)"
|
||||
return super().as_sql(
|
||||
compiler, connection, template=template, **extra_context
|
||||
)
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
template = None
|
||||
output_type = self.output_field.get_internal_type()
|
||||
# MySQL doesn't support explicit cast to float.
|
||||
if output_type == "FloatField":
|
||||
template = "(%(expressions)s + 0.0)"
|
||||
# MariaDB doesn't support explicit cast to JSON.
|
||||
elif output_type == "JSONField" and connection.mysql_is_mariadb:
|
||||
template = "JSON_EXTRACT(%(expressions)s, '$')"
|
||||
return self.as_sql(compiler, connection, template=template, **extra_context)
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
# CAST would be valid too, but the :: shortcut syntax is more readable.
|
||||
# 'expressions' is wrapped in parentheses in case it's a complex
|
||||
# expression.
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="(%(expressions)s)::%(db_type)s",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
if self.output_field.get_internal_type() == "JSONField":
|
||||
# Oracle doesn't support explicit cast to JSON.
|
||||
template = "JSON_QUERY(%(expressions)s, '$')"
|
||||
return super().as_sql(
|
||||
compiler, connection, template=template, **extra_context
|
||||
)
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Coalesce(Func):
|
||||
"""Return, from left to right, the first non-null expression."""
|
||||
|
||||
function = "COALESCE"
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Coalesce must take at least two expressions")
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
@property
|
||||
def empty_result_set_value(self):
|
||||
for expression in self.get_source_expressions():
|
||||
result = expression.empty_result_set_value
|
||||
if result is NotImplemented or result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
|
||||
# so convert all fields to NCLOB when that type is expected.
|
||||
if self.output_field.get_internal_type() == "TextField":
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(
|
||||
[
|
||||
Func(expression, function="TO_NCLOB")
|
||||
for expression in self.get_source_expressions()
|
||||
]
|
||||
)
|
||||
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Collate(Func):
|
||||
function = "COLLATE"
|
||||
template = "%(expressions)s %(function)s %(collation)s"
|
||||
# Inspired from
|
||||
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
|
||||
collation_re = _lazy_re_compile(r"^[\w\-]+$")
|
||||
|
||||
def __init__(self, expression, collation):
|
||||
if not (collation and self.collation_re.match(collation)):
|
||||
raise ValueError("Invalid collation name: %r." % collation)
|
||||
self.collation = collation
|
||||
super().__init__(expression)
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Greatest(Func):
|
||||
"""
|
||||
Return the maximum expression.
|
||||
|
||||
If any expression is null the return value is database-specific:
|
||||
On PostgreSQL, the maximum not-null expression is returned.
|
||||
On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
|
||||
"""
|
||||
|
||||
function = "GREATEST"
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Greatest must take at least two expressions")
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
"""Use the MAX function on SQLite."""
|
||||
return super().as_sqlite(compiler, connection, function="MAX", **extra_context)
|
||||
|
||||
|
||||
class JSONObject(Func):
|
||||
function = "JSON_OBJECT"
|
||||
output_field = JSONField()
|
||||
|
||||
def __init__(self, **fields):
|
||||
expressions = []
|
||||
for key, value in fields.items():
|
||||
expressions.extend((Value(key), value))
|
||||
super().__init__(*expressions)
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if not connection.features.has_json_object_function:
|
||||
raise NotSupportedError(
|
||||
"JSONObject() is not supported on this database backend."
|
||||
)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
copy = self.copy()
|
||||
copy.set_source_expressions(
|
||||
[
|
||||
Cast(expression, TextField()) if index % 2 == 0 else expression
|
||||
for index, expression in enumerate(copy.get_source_expressions())
|
||||
]
|
||||
)
|
||||
return super(JSONObject, copy).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="JSONB_BUILD_OBJECT",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
class ArgJoiner:
|
||||
def join(self, args):
|
||||
args = [" VALUE ".join(arg) for arg in zip(args[::2], args[1::2])]
|
||||
return ", ".join(args)
|
||||
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
arg_joiner=ArgJoiner(),
|
||||
template="%(function)s(%(expressions)s RETURNING CLOB)",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class Least(Func):
|
||||
"""
|
||||
Return the minimum expression.
|
||||
|
||||
If any expression is null the return value is database-specific:
|
||||
On PostgreSQL, return the minimum not-null expression.
|
||||
On MySQL, Oracle, and SQLite, if any expression is null, return null.
|
||||
"""
|
||||
|
||||
function = "LEAST"
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Least must take at least two expressions")
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
"""Use the MIN function on SQLite."""
|
||||
return super().as_sqlite(compiler, connection, function="MIN", **extra_context)
|
||||
|
||||
|
||||
class NullIf(Func):
|
||||
function = "NULLIF"
|
||||
arity = 2
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
expression1 = self.get_source_expressions()[0]
|
||||
if isinstance(expression1, Value) and expression1.value is None:
|
||||
raise ValueError("Oracle does not allow Value(None) for expression1.")
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
@ -0,0 +1,439 @@
|
||||
from datetime import datetime
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.models.expressions import Func
|
||||
from django.db.models.fields import (
|
||||
DateField,
|
||||
DateTimeField,
|
||||
DurationField,
|
||||
Field,
|
||||
IntegerField,
|
||||
TimeField,
|
||||
)
|
||||
from django.db.models.lookups import (
|
||||
Transform,
|
||||
YearExact,
|
||||
YearGt,
|
||||
YearGte,
|
||||
YearLt,
|
||||
YearLte,
|
||||
)
|
||||
from django.utils import timezone
|
||||
|
||||
|
||||
class TimezoneMixin:
|
||||
tzinfo = None
|
||||
|
||||
def get_tzname(self):
|
||||
# Timezone conversions must happen to the input datetime *before*
|
||||
# applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
|
||||
# database as 2016-01-01 01:00:00 +00:00. Any results should be
|
||||
# based on the input datetime not the stored datetime.
|
||||
tzname = None
|
||||
if settings.USE_TZ:
|
||||
if self.tzinfo is None:
|
||||
tzname = timezone.get_current_timezone_name()
|
||||
else:
|
||||
tzname = timezone._get_timezone_name(self.tzinfo)
|
||||
return tzname
|
||||
|
||||
|
||||
class Extract(TimezoneMixin, Transform):
|
||||
lookup_name = None
|
||||
output_field = IntegerField()
|
||||
|
||||
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
|
||||
if self.lookup_name is None:
|
||||
self.lookup_name = lookup_name
|
||||
if self.lookup_name is None:
|
||||
raise ValueError("lookup_name must be provided")
|
||||
self.tzinfo = tzinfo
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
lhs_output_field = self.lhs.output_field
|
||||
if isinstance(lhs_output_field, DateTimeField):
|
||||
tzname = self.get_tzname()
|
||||
sql, params = connection.ops.datetime_extract_sql(
|
||||
self.lookup_name, sql, tuple(params), tzname
|
||||
)
|
||||
elif self.tzinfo is not None:
|
||||
raise ValueError("tzinfo can only be used with DateTimeField.")
|
||||
elif isinstance(lhs_output_field, DateField):
|
||||
sql, params = connection.ops.date_extract_sql(
|
||||
self.lookup_name, sql, tuple(params)
|
||||
)
|
||||
elif isinstance(lhs_output_field, TimeField):
|
||||
sql, params = connection.ops.time_extract_sql(
|
||||
self.lookup_name, sql, tuple(params)
|
||||
)
|
||||
elif isinstance(lhs_output_field, DurationField):
|
||||
if not connection.features.has_native_duration_field:
|
||||
raise ValueError(
|
||||
"Extract requires native DurationField database support."
|
||||
)
|
||||
sql, params = connection.ops.time_extract_sql(
|
||||
self.lookup_name, sql, tuple(params)
|
||||
)
|
||||
else:
|
||||
# resolve_expression has already validated the output_field so this
|
||||
# assert should never be hit.
|
||||
assert False, "Tried to Extract from an invalid type."
|
||||
return sql, params
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
copy = super().resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
field = getattr(copy.lhs, "output_field", None)
|
||||
if field is None:
|
||||
return copy
|
||||
if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
|
||||
raise ValueError(
|
||||
"Extract input expression must be DateField, DateTimeField, "
|
||||
"TimeField, or DurationField."
|
||||
)
|
||||
# Passing dates to functions expecting datetimes is most likely a mistake.
|
||||
if type(field) is DateField and copy.lookup_name in (
|
||||
"hour",
|
||||
"minute",
|
||||
"second",
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot extract time component '%s' from DateField '%s'."
|
||||
% (copy.lookup_name, field.name)
|
||||
)
|
||||
if isinstance(field, DurationField) and copy.lookup_name in (
|
||||
"year",
|
||||
"iso_year",
|
||||
"month",
|
||||
"week",
|
||||
"week_day",
|
||||
"iso_week_day",
|
||||
"quarter",
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot extract component '%s' from DurationField '%s'."
|
||||
% (copy.lookup_name, field.name)
|
||||
)
|
||||
return copy
|
||||
|
||||
|
||||
class ExtractYear(Extract):
|
||||
lookup_name = "year"
|
||||
|
||||
|
||||
class ExtractIsoYear(Extract):
|
||||
"""Return the ISO-8601 week-numbering year."""
|
||||
|
||||
lookup_name = "iso_year"
|
||||
|
||||
|
||||
class ExtractMonth(Extract):
|
||||
lookup_name = "month"
|
||||
|
||||
|
||||
class ExtractDay(Extract):
|
||||
lookup_name = "day"
|
||||
|
||||
|
||||
class ExtractWeek(Extract):
|
||||
"""
|
||||
Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
|
||||
week.
|
||||
"""
|
||||
|
||||
lookup_name = "week"
|
||||
|
||||
|
||||
class ExtractWeekDay(Extract):
|
||||
"""
|
||||
Return Sunday=1 through Saturday=7.
|
||||
|
||||
To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
|
||||
"""
|
||||
|
||||
lookup_name = "week_day"
|
||||
|
||||
|
||||
class ExtractIsoWeekDay(Extract):
|
||||
"""Return Monday=1 through Sunday=7, based on ISO-8601."""
|
||||
|
||||
lookup_name = "iso_week_day"
|
||||
|
||||
|
||||
class ExtractQuarter(Extract):
|
||||
lookup_name = "quarter"
|
||||
|
||||
|
||||
class ExtractHour(Extract):
|
||||
lookup_name = "hour"
|
||||
|
||||
|
||||
class ExtractMinute(Extract):
|
||||
lookup_name = "minute"
|
||||
|
||||
|
||||
class ExtractSecond(Extract):
|
||||
lookup_name = "second"
|
||||
|
||||
|
||||
DateField.register_lookup(ExtractYear)
|
||||
DateField.register_lookup(ExtractMonth)
|
||||
DateField.register_lookup(ExtractDay)
|
||||
DateField.register_lookup(ExtractWeekDay)
|
||||
DateField.register_lookup(ExtractIsoWeekDay)
|
||||
DateField.register_lookup(ExtractWeek)
|
||||
DateField.register_lookup(ExtractIsoYear)
|
||||
DateField.register_lookup(ExtractQuarter)
|
||||
|
||||
TimeField.register_lookup(ExtractHour)
|
||||
TimeField.register_lookup(ExtractMinute)
|
||||
TimeField.register_lookup(ExtractSecond)
|
||||
|
||||
DateTimeField.register_lookup(ExtractHour)
|
||||
DateTimeField.register_lookup(ExtractMinute)
|
||||
DateTimeField.register_lookup(ExtractSecond)
|
||||
|
||||
ExtractYear.register_lookup(YearExact)
|
||||
ExtractYear.register_lookup(YearGt)
|
||||
ExtractYear.register_lookup(YearGte)
|
||||
ExtractYear.register_lookup(YearLt)
|
||||
ExtractYear.register_lookup(YearLte)
|
||||
|
||||
ExtractIsoYear.register_lookup(YearExact)
|
||||
ExtractIsoYear.register_lookup(YearGt)
|
||||
ExtractIsoYear.register_lookup(YearGte)
|
||||
ExtractIsoYear.register_lookup(YearLt)
|
||||
ExtractIsoYear.register_lookup(YearLte)
|
||||
|
||||
|
||||
class Now(Func):
|
||||
template = "CURRENT_TIMESTAMP"
|
||||
output_field = DateTimeField()
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
|
||||
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
|
||||
# other databases.
|
||||
return self.as_sql(
|
||||
compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
|
||||
)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="CURRENT_TIMESTAMP(6)", **extra_context
|
||||
)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class TruncBase(TimezoneMixin, Transform):
|
||||
kind = None
|
||||
tzinfo = None
|
||||
|
||||
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
|
||||
# argument.
|
||||
def __init__(
|
||||
self,
|
||||
expression,
|
||||
output_field=None,
|
||||
tzinfo=None,
|
||||
is_dst=timezone.NOT_PASSED,
|
||||
**extra,
|
||||
):
|
||||
self.tzinfo = tzinfo
|
||||
self.is_dst = is_dst
|
||||
super().__init__(expression, output_field=output_field, **extra)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
tzname = None
|
||||
if isinstance(self.lhs.output_field, DateTimeField):
|
||||
tzname = self.get_tzname()
|
||||
elif self.tzinfo is not None:
|
||||
raise ValueError("tzinfo can only be used with DateTimeField.")
|
||||
if isinstance(self.output_field, DateTimeField):
|
||||
sql, params = connection.ops.datetime_trunc_sql(
|
||||
self.kind, sql, tuple(params), tzname
|
||||
)
|
||||
elif isinstance(self.output_field, DateField):
|
||||
sql, params = connection.ops.date_trunc_sql(
|
||||
self.kind, sql, tuple(params), tzname
|
||||
)
|
||||
elif isinstance(self.output_field, TimeField):
|
||||
sql, params = connection.ops.time_trunc_sql(
|
||||
self.kind, sql, tuple(params), tzname
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Trunc only valid on DateField, TimeField, or DateTimeField."
|
||||
)
|
||||
return sql, params
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
copy = super().resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
field = copy.lhs.output_field
|
||||
# DateTimeField is a subclass of DateField so this works for both.
|
||||
if not isinstance(field, (DateField, TimeField)):
|
||||
raise TypeError(
|
||||
"%r isn't a DateField, TimeField, or DateTimeField." % field.name
|
||||
)
|
||||
# If self.output_field was None, then accessing the field will trigger
|
||||
# the resolver to assign it to self.lhs.output_field.
|
||||
if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
|
||||
raise ValueError(
|
||||
"output_field must be either DateField, TimeField, or DateTimeField"
|
||||
)
|
||||
# Passing dates or times to functions expecting datetimes is most
|
||||
# likely a mistake.
|
||||
class_output_field = (
|
||||
self.__class__.output_field
|
||||
if isinstance(self.__class__.output_field, Field)
|
||||
else None
|
||||
)
|
||||
output_field = class_output_field or copy.output_field
|
||||
has_explicit_output_field = (
|
||||
class_output_field or field.__class__ is not copy.output_field.__class__
|
||||
)
|
||||
if type(field) is DateField and (
|
||||
isinstance(output_field, DateTimeField)
|
||||
or copy.kind in ("hour", "minute", "second", "time")
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot truncate DateField '%s' to %s."
|
||||
% (
|
||||
field.name,
|
||||
output_field.__class__.__name__
|
||||
if has_explicit_output_field
|
||||
else "DateTimeField",
|
||||
)
|
||||
)
|
||||
elif isinstance(field, TimeField) and (
|
||||
isinstance(output_field, DateTimeField)
|
||||
or copy.kind in ("year", "quarter", "month", "week", "day", "date")
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot truncate TimeField '%s' to %s."
|
||||
% (
|
||||
field.name,
|
||||
output_field.__class__.__name__
|
||||
if has_explicit_output_field
|
||||
else "DateTimeField",
|
||||
)
|
||||
)
|
||||
return copy
|
||||
|
||||
def convert_value(self, value, expression, connection):
|
||||
if isinstance(self.output_field, DateTimeField):
|
||||
if not settings.USE_TZ:
|
||||
pass
|
||||
elif value is not None:
|
||||
value = value.replace(tzinfo=None)
|
||||
value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
|
||||
elif not connection.features.has_zoneinfo_database:
|
||||
raise ValueError(
|
||||
"Database returned an invalid datetime value. Are time "
|
||||
"zone definitions for your database installed?"
|
||||
)
|
||||
elif isinstance(value, datetime):
|
||||
if value is None:
|
||||
pass
|
||||
elif isinstance(self.output_field, DateField):
|
||||
value = value.date()
|
||||
elif isinstance(self.output_field, TimeField):
|
||||
value = value.time()
|
||||
return value
|
||||
|
||||
|
||||
class Trunc(TruncBase):
|
||||
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
|
||||
# argument.
|
||||
def __init__(
|
||||
self,
|
||||
expression,
|
||||
kind,
|
||||
output_field=None,
|
||||
tzinfo=None,
|
||||
is_dst=timezone.NOT_PASSED,
|
||||
**extra,
|
||||
):
|
||||
self.kind = kind
|
||||
super().__init__(
|
||||
expression, output_field=output_field, tzinfo=tzinfo, is_dst=is_dst, **extra
|
||||
)
|
||||
|
||||
|
||||
class TruncYear(TruncBase):
|
||||
kind = "year"
|
||||
|
||||
|
||||
class TruncQuarter(TruncBase):
|
||||
kind = "quarter"
|
||||
|
||||
|
||||
class TruncMonth(TruncBase):
|
||||
kind = "month"
|
||||
|
||||
|
||||
class TruncWeek(TruncBase):
|
||||
"""Truncate to midnight on the Monday of the week."""
|
||||
|
||||
kind = "week"
|
||||
|
||||
|
||||
class TruncDay(TruncBase):
|
||||
kind = "day"
|
||||
|
||||
|
||||
class TruncDate(TruncBase):
|
||||
kind = "date"
|
||||
lookup_name = "date"
|
||||
output_field = DateField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to date rather than truncate to date.
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
tzname = self.get_tzname()
|
||||
return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
|
||||
|
||||
|
||||
class TruncTime(TruncBase):
|
||||
kind = "time"
|
||||
lookup_name = "time"
|
||||
output_field = TimeField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to time rather than truncate to time.
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
tzname = self.get_tzname()
|
||||
return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
|
||||
|
||||
|
||||
class TruncHour(TruncBase):
|
||||
kind = "hour"
|
||||
|
||||
|
||||
class TruncMinute(TruncBase):
|
||||
kind = "minute"
|
||||
|
||||
|
||||
class TruncSecond(TruncBase):
|
||||
kind = "second"
|
||||
|
||||
|
||||
DateTimeField.register_lookup(TruncDate)
|
||||
DateTimeField.register_lookup(TruncTime)
|
@ -0,0 +1,212 @@
|
||||
import math
|
||||
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import FloatField, IntegerField
|
||||
from django.db.models.functions import Cast
|
||||
from django.db.models.functions.mixins import (
|
||||
FixDecimalInputMixin,
|
||||
NumericOutputFieldMixin,
|
||||
)
|
||||
from django.db.models.lookups import Transform
|
||||
|
||||
|
||||
class Abs(Transform):
|
||||
function = "ABS"
|
||||
lookup_name = "abs"
|
||||
|
||||
|
||||
class ACos(NumericOutputFieldMixin, Transform):
|
||||
function = "ACOS"
|
||||
lookup_name = "acos"
|
||||
|
||||
|
||||
class ASin(NumericOutputFieldMixin, Transform):
|
||||
function = "ASIN"
|
||||
lookup_name = "asin"
|
||||
|
||||
|
||||
class ATan(NumericOutputFieldMixin, Transform):
|
||||
function = "ATAN"
|
||||
lookup_name = "atan"
|
||||
|
||||
|
||||
class ATan2(NumericOutputFieldMixin, Func):
|
||||
function = "ATAN2"
|
||||
arity = 2
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if not getattr(
|
||||
connection.ops, "spatialite", False
|
||||
) or connection.ops.spatial_version >= (5, 0, 0):
|
||||
return self.as_sql(compiler, connection)
|
||||
# This function is usually ATan2(y, x), returning the inverse tangent
|
||||
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
|
||||
# Cast integers to float to avoid inconsistent/buggy behavior if the
|
||||
# arguments are mixed between integer and float or decimal.
|
||||
# https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(
|
||||
[
|
||||
Cast(expression, FloatField())
|
||||
if isinstance(expression.output_field, IntegerField)
|
||||
else expression
|
||||
for expression in self.get_source_expressions()[::-1]
|
||||
]
|
||||
)
|
||||
return clone.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Ceil(Transform):
|
||||
function = "CEILING"
|
||||
lookup_name = "ceil"
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="CEIL", **extra_context)
|
||||
|
||||
|
||||
class Cos(NumericOutputFieldMixin, Transform):
|
||||
function = "COS"
|
||||
lookup_name = "cos"
|
||||
|
||||
|
||||
class Cot(NumericOutputFieldMixin, Transform):
|
||||
function = "COT"
|
||||
lookup_name = "cot"
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, template="(1 / TAN(%(expressions)s))", **extra_context
|
||||
)
|
||||
|
||||
|
||||
class Degrees(NumericOutputFieldMixin, Transform):
|
||||
function = "DEGREES"
|
||||
lookup_name = "degrees"
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="((%%(expressions)s) * 180 / %s)" % math.pi,
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class Exp(NumericOutputFieldMixin, Transform):
|
||||
function = "EXP"
|
||||
lookup_name = "exp"
|
||||
|
||||
|
||||
class Floor(Transform):
|
||||
function = "FLOOR"
|
||||
lookup_name = "floor"
|
||||
|
||||
|
||||
class Ln(NumericOutputFieldMixin, Transform):
|
||||
function = "LN"
|
||||
lookup_name = "ln"
|
||||
|
||||
|
||||
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
||||
function = "LOG"
|
||||
arity = 2
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if not getattr(connection.ops, "spatialite", False):
|
||||
return self.as_sql(compiler, connection)
|
||||
# This function is usually Log(b, x) returning the logarithm of x to
|
||||
# the base b, but on SpatiaLite it's Log(x, b).
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(self.get_source_expressions()[::-1])
|
||||
return clone.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
||||
function = "MOD"
|
||||
arity = 2
|
||||
|
||||
|
||||
class Pi(NumericOutputFieldMixin, Func):
|
||||
function = "PI"
|
||||
arity = 0
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, template=str(math.pi), **extra_context
|
||||
)
|
||||
|
||||
|
||||
class Power(NumericOutputFieldMixin, Func):
|
||||
function = "POWER"
|
||||
arity = 2
|
||||
|
||||
|
||||
class Radians(NumericOutputFieldMixin, Transform):
|
||||
function = "RADIANS"
|
||||
lookup_name = "radians"
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="((%%(expressions)s) * %s / 180)" % math.pi,
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class Random(NumericOutputFieldMixin, Func):
|
||||
function = "RANDOM"
|
||||
arity = 0
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="RAND", **extra_context)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, function="DBMS_RANDOM.VALUE", **extra_context
|
||||
)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="RAND", **extra_context)
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return []
|
||||
|
||||
|
||||
class Round(FixDecimalInputMixin, Transform):
|
||||
function = "ROUND"
|
||||
lookup_name = "round"
|
||||
arity = None # Override Transform's arity=1 to enable passing precision.
|
||||
|
||||
def __init__(self, expression, precision=0, **extra):
|
||||
super().__init__(expression, precision, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
precision = self.get_source_expressions()[1]
|
||||
if isinstance(precision, Value) and precision.value < 0:
|
||||
raise ValueError("SQLite does not support negative precision.")
|
||||
return super().as_sqlite(compiler, connection, **extra_context)
|
||||
|
||||
def _resolve_output_field(self):
|
||||
source = self.get_source_expressions()[0]
|
||||
return source.output_field
|
||||
|
||||
|
||||
class Sign(Transform):
|
||||
function = "SIGN"
|
||||
lookup_name = "sign"
|
||||
|
||||
|
||||
class Sin(NumericOutputFieldMixin, Transform):
|
||||
function = "SIN"
|
||||
lookup_name = "sin"
|
||||
|
||||
|
||||
class Sqrt(NumericOutputFieldMixin, Transform):
|
||||
function = "SQRT"
|
||||
lookup_name = "sqrt"
|
||||
|
||||
|
||||
class Tan(NumericOutputFieldMixin, Transform):
|
||||
function = "TAN"
|
||||
lookup_name = "tan"
|
@ -0,0 +1,57 @@
|
||||
import sys
|
||||
|
||||
from django.db.models.fields import DecimalField, FloatField, IntegerField
|
||||
from django.db.models.functions import Cast
|
||||
|
||||
|
||||
class FixDecimalInputMixin:
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
|
||||
# following function signatures:
|
||||
# - LOG(double, double)
|
||||
# - MOD(double, double)
|
||||
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(
|
||||
[
|
||||
Cast(expression, output_field)
|
||||
if isinstance(expression.output_field, FloatField)
|
||||
else expression
|
||||
for expression in self.get_source_expressions()
|
||||
]
|
||||
)
|
||||
return clone.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class FixDurationInputMixin:
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
sql, params = super().as_sql(compiler, connection, **extra_context)
|
||||
if self.output_field.get_internal_type() == "DurationField":
|
||||
sql = "CAST(%s AS SIGNED)" % sql
|
||||
return sql, params
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
if self.output_field.get_internal_type() == "DurationField":
|
||||
expression = self.get_source_expressions()[0]
|
||||
options = self._get_repr_options()
|
||||
from django.db.backends.oracle.functions import (
|
||||
IntervalToSeconds,
|
||||
SecondsToInterval,
|
||||
)
|
||||
|
||||
return compiler.compile(
|
||||
SecondsToInterval(
|
||||
self.__class__(IntervalToSeconds(expression), **options)
|
||||
)
|
||||
)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class NumericOutputFieldMixin:
|
||||
def _resolve_output_field(self):
|
||||
source_fields = self.get_source_fields()
|
||||
if any(isinstance(s, DecimalField) for s in source_fields):
|
||||
return DecimalField()
|
||||
if any(isinstance(s, IntegerField) for s in source_fields):
|
||||
return FloatField()
|
||||
return super()._resolve_output_field() if source_fields else FloatField()
|
@ -0,0 +1,365 @@
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import CharField, IntegerField, TextField
|
||||
from django.db.models.functions import Cast, Coalesce
|
||||
from django.db.models.lookups import Transform
|
||||
|
||||
|
||||
class MySQLSHA2Mixin:
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="SHA2(%%(expressions)s, %s)" % self.function[3:],
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class OracleHashMixin:
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=(
|
||||
"LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
|
||||
"%(expressions)s, 'AL32UTF8'), '%(function)s')))"
|
||||
),
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class PostgreSQLSHAMixin:
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
|
||||
function=self.function.lower(),
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class Chr(Transform):
|
||||
function = "CHR"
|
||||
lookup_name = "chr"
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="CHAR",
|
||||
template="%(function)s(%(expressions)s USING utf16)",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="%(function)s(%(expressions)s USING NCHAR_CS)",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="CHAR", **extra_context)
|
||||
|
||||
|
||||
class ConcatPair(Func):
|
||||
"""
|
||||
Concatenate two arguments together. This is used by `Concat` because not
|
||||
all backend databases support more than two arguments.
|
||||
"""
|
||||
|
||||
function = "CONCAT"
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
coalesced = self.coalesce()
|
||||
return super(ConcatPair, coalesced).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="%(expressions)s",
|
||||
arg_joiner=" || ",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
copy = self.copy()
|
||||
copy.set_source_expressions(
|
||||
[
|
||||
Cast(expression, TextField())
|
||||
for expression in copy.get_source_expressions()
|
||||
]
|
||||
)
|
||||
return super(ConcatPair, copy).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="CONCAT_WS",
|
||||
template="%(function)s('', %(expressions)s)",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def coalesce(self):
|
||||
# null on either side results in null for expression, wrap with coalesce
|
||||
c = self.copy()
|
||||
c.set_source_expressions(
|
||||
[
|
||||
Coalesce(expression, Value(""))
|
||||
for expression in c.get_source_expressions()
|
||||
]
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
class Concat(Func):
|
||||
"""
|
||||
Concatenate text fields together. Backends that result in an entire
|
||||
null expression when any arguments are null will wrap each argument in
|
||||
coalesce functions to ensure a non-null result.
|
||||
"""
|
||||
|
||||
function = None
|
||||
template = "%(expressions)s"
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Concat must take at least two expressions")
|
||||
paired = self._paired(expressions)
|
||||
super().__init__(paired, **extra)
|
||||
|
||||
def _paired(self, expressions):
|
||||
# wrap pairs of expressions in successive concat functions
|
||||
# exp = [a, b, c, d]
|
||||
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
|
||||
if len(expressions) == 2:
|
||||
return ConcatPair(*expressions)
|
||||
return ConcatPair(expressions[0], self._paired(expressions[1:]))
|
||||
|
||||
|
||||
class Left(Func):
|
||||
function = "LEFT"
|
||||
arity = 2
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, length, **extra):
|
||||
"""
|
||||
expression: the name of a field, or an expression returning a string
|
||||
length: the number of characters to return from the start of the string
|
||||
"""
|
||||
if not hasattr(length, "resolve_expression"):
|
||||
if length < 1:
|
||||
raise ValueError("'length' must be greater than 0.")
|
||||
super().__init__(expression, length, **extra)
|
||||
|
||||
def get_substr(self):
|
||||
return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return self.get_substr().as_oracle(compiler, connection, **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return self.get_substr().as_sqlite(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Length(Transform):
|
||||
"""Return the number of characters in the expression."""
|
||||
|
||||
function = "LENGTH"
|
||||
lookup_name = "length"
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, function="CHAR_LENGTH", **extra_context
|
||||
)
|
||||
|
||||
|
||||
class Lower(Transform):
|
||||
function = "LOWER"
|
||||
lookup_name = "lower"
|
||||
|
||||
|
||||
class LPad(Func):
|
||||
function = "LPAD"
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, length, fill_text=Value(" "), **extra):
|
||||
if (
|
||||
not hasattr(length, "resolve_expression")
|
||||
and length is not None
|
||||
and length < 0
|
||||
):
|
||||
raise ValueError("'length' must be greater or equal to 0.")
|
||||
super().__init__(expression, length, fill_text, **extra)
|
||||
|
||||
|
||||
class LTrim(Transform):
|
||||
function = "LTRIM"
|
||||
lookup_name = "ltrim"
|
||||
|
||||
|
||||
class MD5(OracleHashMixin, Transform):
|
||||
function = "MD5"
|
||||
lookup_name = "md5"
|
||||
|
||||
|
||||
class Ord(Transform):
|
||||
function = "ASCII"
|
||||
lookup_name = "ord"
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="ORD", **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="UNICODE", **extra_context)
|
||||
|
||||
|
||||
class Repeat(Func):
|
||||
function = "REPEAT"
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, number, **extra):
|
||||
if (
|
||||
not hasattr(number, "resolve_expression")
|
||||
and number is not None
|
||||
and number < 0
|
||||
):
|
||||
raise ValueError("'number' must be greater or equal to 0.")
|
||||
super().__init__(expression, number, **extra)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
expression, number = self.source_expressions
|
||||
length = None if number is None else Length(expression) * number
|
||||
rpad = RPad(expression, length, expression)
|
||||
return rpad.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Replace(Func):
|
||||
function = "REPLACE"
|
||||
|
||||
def __init__(self, expression, text, replacement=Value(""), **extra):
|
||||
super().__init__(expression, text, replacement, **extra)
|
||||
|
||||
|
||||
class Reverse(Transform):
|
||||
function = "REVERSE"
|
||||
lookup_name = "reverse"
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
# REVERSE in Oracle is undocumented and doesn't support multi-byte
|
||||
# strings. Use a special subquery instead.
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=(
|
||||
"(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
|
||||
"(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s "
|
||||
"FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
|
||||
"GROUP BY %(expressions)s)"
|
||||
),
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class Right(Left):
|
||||
function = "RIGHT"
|
||||
|
||||
def get_substr(self):
|
||||
return Substr(
|
||||
self.source_expressions[0], self.source_expressions[1] * Value(-1)
|
||||
)
|
||||
|
||||
|
||||
class RPad(LPad):
|
||||
function = "RPAD"
|
||||
|
||||
|
||||
class RTrim(Transform):
|
||||
function = "RTRIM"
|
||||
lookup_name = "rtrim"
|
||||
|
||||
|
||||
class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = "SHA1"
|
||||
lookup_name = "sha1"
|
||||
|
||||
|
||||
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
|
||||
function = "SHA224"
|
||||
lookup_name = "sha224"
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
raise NotSupportedError("SHA224 is not supported on Oracle.")
|
||||
|
||||
|
||||
class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = "SHA256"
|
||||
lookup_name = "sha256"
|
||||
|
||||
|
||||
class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = "SHA384"
|
||||
lookup_name = "sha384"
|
||||
|
||||
|
||||
class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = "SHA512"
|
||||
lookup_name = "sha512"
|
||||
|
||||
|
||||
class StrIndex(Func):
|
||||
"""
|
||||
Return a positive integer corresponding to the 1-indexed position of the
|
||||
first occurrence of a substring inside another string, or 0 if the
|
||||
substring is not found.
|
||||
"""
|
||||
|
||||
function = "INSTR"
|
||||
arity = 2
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="STRPOS", **extra_context)
|
||||
|
||||
|
||||
class Substr(Func):
|
||||
function = "SUBSTRING"
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, pos, length=None, **extra):
|
||||
"""
|
||||
expression: the name of a field, or an expression returning a string
|
||||
pos: an integer > 0, or an expression returning an integer
|
||||
length: an optional number of characters to return
|
||||
"""
|
||||
if not hasattr(pos, "resolve_expression"):
|
||||
if pos < 1:
|
||||
raise ValueError("'pos' must be greater than 0")
|
||||
expressions = [expression, pos]
|
||||
if length is not None:
|
||||
expressions.append(length)
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
|
||||
|
||||
|
||||
class Trim(Transform):
|
||||
function = "TRIM"
|
||||
lookup_name = "trim"
|
||||
|
||||
|
||||
class Upper(Transform):
|
||||
function = "UPPER"
|
||||
lookup_name = "upper"
|
@ -0,0 +1,120 @@
|
||||
from django.db.models.expressions import Func
|
||||
from django.db.models.fields import FloatField, IntegerField
|
||||
|
||||
__all__ = [
|
||||
"CumeDist",
|
||||
"DenseRank",
|
||||
"FirstValue",
|
||||
"Lag",
|
||||
"LastValue",
|
||||
"Lead",
|
||||
"NthValue",
|
||||
"Ntile",
|
||||
"PercentRank",
|
||||
"Rank",
|
||||
"RowNumber",
|
||||
]
|
||||
|
||||
|
||||
class CumeDist(Func):
|
||||
function = "CUME_DIST"
|
||||
output_field = FloatField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class DenseRank(Func):
|
||||
function = "DENSE_RANK"
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class FirstValue(Func):
|
||||
arity = 1
|
||||
function = "FIRST_VALUE"
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class LagLeadFunction(Func):
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, expression, offset=1, default=None, **extra):
|
||||
if expression is None:
|
||||
raise ValueError(
|
||||
"%s requires a non-null source expression." % self.__class__.__name__
|
||||
)
|
||||
if offset is None or offset <= 0:
|
||||
raise ValueError(
|
||||
"%s requires a positive integer for the offset."
|
||||
% self.__class__.__name__
|
||||
)
|
||||
args = (expression, offset)
|
||||
if default is not None:
|
||||
args += (default,)
|
||||
super().__init__(*args, **extra)
|
||||
|
||||
def _resolve_output_field(self):
|
||||
sources = self.get_source_expressions()
|
||||
return sources[0].output_field
|
||||
|
||||
|
||||
class Lag(LagLeadFunction):
|
||||
function = "LAG"
|
||||
|
||||
|
||||
class LastValue(Func):
|
||||
arity = 1
|
||||
function = "LAST_VALUE"
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class Lead(LagLeadFunction):
|
||||
function = "LEAD"
|
||||
|
||||
|
||||
class NthValue(Func):
|
||||
function = "NTH_VALUE"
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, expression, nth=1, **extra):
|
||||
if expression is None:
|
||||
raise ValueError(
|
||||
"%s requires a non-null source expression." % self.__class__.__name__
|
||||
)
|
||||
if nth is None or nth <= 0:
|
||||
raise ValueError(
|
||||
"%s requires a positive integer as for nth." % self.__class__.__name__
|
||||
)
|
||||
super().__init__(expression, nth, **extra)
|
||||
|
||||
def _resolve_output_field(self):
|
||||
sources = self.get_source_expressions()
|
||||
return sources[0].output_field
|
||||
|
||||
|
||||
class Ntile(Func):
|
||||
function = "NTILE"
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, num_buckets=1, **extra):
|
||||
if num_buckets <= 0:
|
||||
raise ValueError("num_buckets must be greater than 0.")
|
||||
super().__init__(num_buckets, **extra)
|
||||
|
||||
|
||||
class PercentRank(Func):
|
||||
function = "PERCENT_RANK"
|
||||
output_field = FloatField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class Rank(Func):
|
||||
function = "RANK"
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class RowNumber(Func):
|
||||
function = "ROW_NUMBER"
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
@ -0,0 +1,295 @@
|
||||
from django.db.backends.utils import names_digest, split_identifier
|
||||
from django.db.models.expressions import Col, ExpressionList, F, Func, OrderBy
|
||||
from django.db.models.functions import Collate
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.sql import Query
|
||||
from django.utils.functional import partition
|
||||
|
||||
__all__ = ["Index"]
|
||||
|
||||
|
||||
class Index:
|
||||
suffix = "idx"
|
||||
# The max length of the name of the index (restricted to 30 for
|
||||
# cross-database compatibility with Oracle)
|
||||
max_name_length = 30
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*expressions,
|
||||
fields=(),
|
||||
name=None,
|
||||
db_tablespace=None,
|
||||
opclasses=(),
|
||||
condition=None,
|
||||
include=None,
|
||||
):
|
||||
if opclasses and not name:
|
||||
raise ValueError("An index must be named to use opclasses.")
|
||||
if not isinstance(condition, (type(None), Q)):
|
||||
raise ValueError("Index.condition must be a Q instance.")
|
||||
if condition and not name:
|
||||
raise ValueError("An index must be named to use condition.")
|
||||
if not isinstance(fields, (list, tuple)):
|
||||
raise ValueError("Index.fields must be a list or tuple.")
|
||||
if not isinstance(opclasses, (list, tuple)):
|
||||
raise ValueError("Index.opclasses must be a list or tuple.")
|
||||
if not expressions and not fields:
|
||||
raise ValueError(
|
||||
"At least one field or expression is required to define an index."
|
||||
)
|
||||
if expressions and fields:
|
||||
raise ValueError(
|
||||
"Index.fields and expressions are mutually exclusive.",
|
||||
)
|
||||
if expressions and not name:
|
||||
raise ValueError("An index must be named to use expressions.")
|
||||
if expressions and opclasses:
|
||||
raise ValueError(
|
||||
"Index.opclasses cannot be used with expressions. Use "
|
||||
"django.contrib.postgres.indexes.OpClass() instead."
|
||||
)
|
||||
if opclasses and len(fields) != len(opclasses):
|
||||
raise ValueError(
|
||||
"Index.fields and Index.opclasses must have the same number of "
|
||||
"elements."
|
||||
)
|
||||
if fields and not all(isinstance(field, str) for field in fields):
|
||||
raise ValueError("Index.fields must contain only strings with field names.")
|
||||
if include and not name:
|
||||
raise ValueError("A covering index must be named.")
|
||||
if not isinstance(include, (type(None), list, tuple)):
|
||||
raise ValueError("Index.include must be a list or tuple.")
|
||||
self.fields = list(fields)
|
||||
# A list of 2-tuple with the field name and ordering ('' or 'DESC').
|
||||
self.fields_orders = [
|
||||
(field_name[1:], "DESC") if field_name.startswith("-") else (field_name, "")
|
||||
for field_name in self.fields
|
||||
]
|
||||
self.name = name or ""
|
||||
self.db_tablespace = db_tablespace
|
||||
self.opclasses = opclasses
|
||||
self.condition = condition
|
||||
self.include = tuple(include) if include else ()
|
||||
self.expressions = tuple(
|
||||
F(expression) if isinstance(expression, str) else expression
|
||||
for expression in expressions
|
||||
)
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return bool(self.expressions)
|
||||
|
||||
def _get_condition_sql(self, model, schema_editor):
|
||||
if self.condition is None:
|
||||
return None
|
||||
query = Query(model=model, alias_cols=False)
|
||||
where = query.build_where(self.condition)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def create_sql(self, model, schema_editor, using="", **kwargs):
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
if self.expressions:
|
||||
index_expressions = []
|
||||
for expression in self.expressions:
|
||||
index_expression = IndexExpression(expression)
|
||||
index_expression.set_wrapper_classes(schema_editor.connection)
|
||||
index_expressions.append(index_expression)
|
||||
expressions = ExpressionList(*index_expressions).resolve_expression(
|
||||
Query(model, alias_cols=False),
|
||||
)
|
||||
fields = None
|
||||
col_suffixes = None
|
||||
else:
|
||||
fields = [
|
||||
model._meta.get_field(field_name)
|
||||
for field_name, _ in self.fields_orders
|
||||
]
|
||||
if schema_editor.connection.features.supports_index_column_ordering:
|
||||
col_suffixes = [order[1] for order in self.fields_orders]
|
||||
else:
|
||||
col_suffixes = [""] * len(self.fields_orders)
|
||||
expressions = None
|
||||
return schema_editor._create_index_sql(
|
||||
model,
|
||||
fields=fields,
|
||||
name=self.name,
|
||||
using=using,
|
||||
db_tablespace=self.db_tablespace,
|
||||
col_suffixes=col_suffixes,
|
||||
opclasses=self.opclasses,
|
||||
condition=condition,
|
||||
include=include,
|
||||
expressions=expressions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def remove_sql(self, model, schema_editor, **kwargs):
|
||||
return schema_editor._delete_index_sql(model, self.name, **kwargs)
|
||||
|
||||
def deconstruct(self):
|
||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace("django.db.models.indexes", "django.db.models")
|
||||
kwargs = {"name": self.name}
|
||||
if self.fields:
|
||||
kwargs["fields"] = self.fields
|
||||
if self.db_tablespace is not None:
|
||||
kwargs["db_tablespace"] = self.db_tablespace
|
||||
if self.opclasses:
|
||||
kwargs["opclasses"] = self.opclasses
|
||||
if self.condition:
|
||||
kwargs["condition"] = self.condition
|
||||
if self.include:
|
||||
kwargs["include"] = self.include
|
||||
return (path, self.expressions, kwargs)
|
||||
|
||||
def clone(self):
|
||||
"""Create a copy of this Index."""
|
||||
_, args, kwargs = self.deconstruct()
|
||||
return self.__class__(*args, **kwargs)
|
||||
|
||||
def set_name_with_model(self, model):
|
||||
"""
|
||||
Generate a unique name for the index.
|
||||
|
||||
The name is divided into 3 parts - table name (12 chars), field name
|
||||
(8 chars) and unique hash + suffix (10 chars). Each part is made to
|
||||
fit its size by truncating the excess length.
|
||||
"""
|
||||
_, table_name = split_identifier(model._meta.db_table)
|
||||
column_names = [
|
||||
model._meta.get_field(field_name).column
|
||||
for field_name, order in self.fields_orders
|
||||
]
|
||||
column_names_with_order = [
|
||||
(("-%s" if order else "%s") % column_name)
|
||||
for column_name, (field_name, order) in zip(
|
||||
column_names, self.fields_orders
|
||||
)
|
||||
]
|
||||
# The length of the parts of the name is based on the default max
|
||||
# length of 30 characters.
|
||||
hash_data = [table_name] + column_names_with_order + [self.suffix]
|
||||
self.name = "%s_%s_%s" % (
|
||||
table_name[:11],
|
||||
column_names[0][:7],
|
||||
"%s_%s" % (names_digest(*hash_data, length=6), self.suffix),
|
||||
)
|
||||
if len(self.name) > self.max_name_length:
|
||||
raise ValueError(
|
||||
"Index too long for multiple database support. Is self.suffix "
|
||||
"longer than 3 characters?"
|
||||
)
|
||||
if self.name[0] == "_" or self.name[0].isdigit():
|
||||
self.name = "D%s" % self.name[1:]
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s:%s%s%s%s%s%s%s>" % (
|
||||
self.__class__.__qualname__,
|
||||
"" if not self.fields else " fields=%s" % repr(self.fields),
|
||||
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
|
||||
"" if not self.name else " name=%s" % repr(self.name),
|
||||
""
|
||||
if self.db_tablespace is None
|
||||
else " db_tablespace=%s" % repr(self.db_tablespace),
|
||||
"" if self.condition is None else " condition=%s" % self.condition,
|
||||
"" if not self.include else " include=%s" % repr(self.include),
|
||||
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.__class__ == other.__class__:
|
||||
return self.deconstruct() == other.deconstruct()
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class IndexExpression(Func):
|
||||
"""Order and wrap expressions for CREATE INDEX statements."""
|
||||
|
||||
template = "%(expressions)s"
|
||||
wrapper_classes = (OrderBy, Collate)
|
||||
|
||||
def set_wrapper_classes(self, connection=None):
|
||||
# Some databases (e.g. MySQL) treats COLLATE as an indexed expression.
|
||||
if connection and connection.features.collate_as_index_expression:
|
||||
self.wrapper_classes = tuple(
|
||||
[
|
||||
wrapper_cls
|
||||
for wrapper_cls in self.wrapper_classes
|
||||
if wrapper_cls is not Collate
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def register_wrappers(cls, *wrapper_classes):
|
||||
cls.wrapper_classes = wrapper_classes
|
||||
|
||||
def resolve_expression(
|
||||
self,
|
||||
query=None,
|
||||
allow_joins=True,
|
||||
reuse=None,
|
||||
summarize=False,
|
||||
for_save=False,
|
||||
):
|
||||
expressions = list(self.flatten())
|
||||
# Split expressions and wrappers.
|
||||
index_expressions, wrappers = partition(
|
||||
lambda e: isinstance(e, self.wrapper_classes),
|
||||
expressions,
|
||||
)
|
||||
wrapper_types = [type(wrapper) for wrapper in wrappers]
|
||||
if len(wrapper_types) != len(set(wrapper_types)):
|
||||
raise ValueError(
|
||||
"Multiple references to %s can't be used in an indexed "
|
||||
"expression."
|
||||
% ", ".join(
|
||||
[wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
|
||||
)
|
||||
)
|
||||
if expressions[1 : len(wrappers) + 1] != wrappers:
|
||||
raise ValueError(
|
||||
"%s must be topmost expressions in an indexed expression."
|
||||
% ", ".join(
|
||||
[wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
|
||||
)
|
||||
)
|
||||
# Wrap expressions in parentheses if they are not column references.
|
||||
root_expression = index_expressions[1]
|
||||
resolve_root_expression = root_expression.resolve_expression(
|
||||
query,
|
||||
allow_joins,
|
||||
reuse,
|
||||
summarize,
|
||||
for_save,
|
||||
)
|
||||
if not isinstance(resolve_root_expression, Col):
|
||||
root_expression = Func(root_expression, template="(%(expressions)s)")
|
||||
|
||||
if wrappers:
|
||||
# Order wrappers and set their expressions.
|
||||
wrappers = sorted(
|
||||
wrappers,
|
||||
key=lambda w: self.wrapper_classes.index(type(w)),
|
||||
)
|
||||
wrappers = [wrapper.copy() for wrapper in wrappers]
|
||||
for i, wrapper in enumerate(wrappers[:-1]):
|
||||
wrapper.set_source_expressions([wrappers[i + 1]])
|
||||
# Set the root expression on the deepest wrapper.
|
||||
wrappers[-1].set_source_expressions([root_expression])
|
||||
self.set_source_expressions([wrappers[0]])
|
||||
else:
|
||||
# Use the root expression, if there are no wrappers.
|
||||
self.set_source_expressions([root_expression])
|
||||
return super().resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
# Casting to numeric is unnecessary.
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
@ -0,0 +1,727 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
from django.core.exceptions import EmptyResultSet, FullResultSet
|
||||
from django.db.models.expressions import Case, Expression, Func, Value, When
|
||||
from django.db.models.fields import (
|
||||
BooleanField,
|
||||
CharField,
|
||||
DateTimeField,
|
||||
Field,
|
||||
IntegerField,
|
||||
UUIDField,
|
||||
)
|
||||
from django.db.models.query_utils import RegisterLookupMixin
|
||||
from django.utils.datastructures import OrderedSet
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
|
||||
|
||||
class Lookup(Expression):
|
||||
lookup_name = None
|
||||
prepare_rhs = True
|
||||
can_use_none_as_rhs = False
|
||||
|
||||
def __init__(self, lhs, rhs):
|
||||
self.lhs, self.rhs = lhs, rhs
|
||||
self.rhs = self.get_prep_lookup()
|
||||
self.lhs = self.get_prep_lhs()
|
||||
if hasattr(self.lhs, "get_bilateral_transforms"):
|
||||
bilateral_transforms = self.lhs.get_bilateral_transforms()
|
||||
else:
|
||||
bilateral_transforms = []
|
||||
if bilateral_transforms:
|
||||
# Warn the user as soon as possible if they are trying to apply
|
||||
# a bilateral transformation on a nested QuerySet: that won't work.
|
||||
from django.db.models.sql.query import Query # avoid circular import
|
||||
|
||||
if isinstance(rhs, Query):
|
||||
raise NotImplementedError(
|
||||
"Bilateral transformations on nested querysets are not implemented."
|
||||
)
|
||||
self.bilateral_transforms = bilateral_transforms
|
||||
|
||||
def apply_bilateral_transforms(self, value):
|
||||
for transform in self.bilateral_transforms:
|
||||
value = transform(value)
|
||||
return value
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"
|
||||
|
||||
def batch_process_rhs(self, compiler, connection, rhs=None):
|
||||
if rhs is None:
|
||||
rhs = self.rhs
|
||||
if self.bilateral_transforms:
|
||||
sqls, sqls_params = [], []
|
||||
for p in rhs:
|
||||
value = Value(p, output_field=self.lhs.output_field)
|
||||
value = self.apply_bilateral_transforms(value)
|
||||
value = value.resolve_expression(compiler.query)
|
||||
sql, sql_params = compiler.compile(value)
|
||||
sqls.append(sql)
|
||||
sqls_params.extend(sql_params)
|
||||
else:
|
||||
_, params = self.get_db_prep_lookup(rhs, connection)
|
||||
sqls, sqls_params = ["%s"] * len(params), params
|
||||
return sqls, sqls_params
|
||||
|
||||
def get_source_expressions(self):
|
||||
if self.rhs_is_direct_value():
|
||||
return [self.lhs]
|
||||
return [self.lhs, self.rhs]
|
||||
|
||||
def set_source_expressions(self, new_exprs):
|
||||
if len(new_exprs) == 1:
|
||||
self.lhs = new_exprs[0]
|
||||
else:
|
||||
self.lhs, self.rhs = new_exprs
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
|
||||
return self.rhs
|
||||
if hasattr(self.lhs, "output_field"):
|
||||
if hasattr(self.lhs.output_field, "get_prep_value"):
|
||||
return self.lhs.output_field.get_prep_value(self.rhs)
|
||||
elif self.rhs_is_direct_value():
|
||||
return Value(self.rhs)
|
||||
return self.rhs
|
||||
|
||||
def get_prep_lhs(self):
|
||||
if hasattr(self.lhs, "resolve_expression"):
|
||||
return self.lhs
|
||||
return Value(self.lhs)
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
return ("%s", [value])
|
||||
|
||||
def process_lhs(self, compiler, connection, lhs=None):
|
||||
lhs = lhs or self.lhs
|
||||
if hasattr(lhs, "resolve_expression"):
|
||||
lhs = lhs.resolve_expression(compiler.query)
|
||||
sql, params = compiler.compile(lhs)
|
||||
if isinstance(lhs, Lookup):
|
||||
# Wrapped in parentheses to respect operator precedence.
|
||||
sql = f"({sql})"
|
||||
return sql, params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
value = self.rhs
|
||||
if self.bilateral_transforms:
|
||||
if self.rhs_is_direct_value():
|
||||
# Do not call get_db_prep_lookup here as the value will be
|
||||
# transformed before being used for lookup
|
||||
value = Value(value, output_field=self.lhs.output_field)
|
||||
value = self.apply_bilateral_transforms(value)
|
||||
value = value.resolve_expression(compiler.query)
|
||||
if hasattr(value, "as_sql"):
|
||||
sql, params = compiler.compile(value)
|
||||
# Ensure expression is wrapped in parentheses to respect operator
|
||||
# precedence but avoid double wrapping as it can be misinterpreted
|
||||
# on some backends (e.g. subqueries on SQLite).
|
||||
if sql and sql[0] != "(":
|
||||
sql = "(%s)" % sql
|
||||
return sql, params
|
||||
else:
|
||||
return self.get_db_prep_lookup(value, connection)
|
||||
|
||||
def rhs_is_direct_value(self):
|
||||
return not hasattr(self.rhs, "as_sql")
|
||||
|
||||
def get_group_by_cols(self):
|
||||
cols = []
|
||||
for source in self.get_source_expressions():
|
||||
cols.extend(source.get_group_by_cols())
|
||||
return cols
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
# Oracle doesn't allow EXISTS() and filters to be compared to another
|
||||
# expression unless they're wrapped in a CASE WHEN.
|
||||
wrapped = False
|
||||
exprs = []
|
||||
for expr in (self.lhs, self.rhs):
|
||||
if connection.ops.conditional_expression_supported_in_where_clause(expr):
|
||||
expr = Case(When(expr, then=True), default=False)
|
||||
wrapped = True
|
||||
exprs.append(expr)
|
||||
lookup = type(self)(*exprs) if wrapped else self
|
||||
return lookup.as_sql(compiler, connection)
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return BooleanField()
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return self.__class__, self.lhs, self.rhs
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Lookup):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(make_hashable(self.identity))
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
c.lhs = self.lhs.resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
if hasattr(self.rhs, "resolve_expression"):
|
||||
c.rhs = self.rhs.resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
return c
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
# Wrap filters with a CASE WHEN expression if a database backend
|
||||
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
|
||||
# BY list.
|
||||
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
|
||||
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
|
||||
return sql, params
|
||||
|
||||
|
||||
class Transform(RegisterLookupMixin, Func):
|
||||
"""
|
||||
RegisterLookupMixin() is first so that get_lookup() and get_transform()
|
||||
first examine self and then check output_field.
|
||||
"""
|
||||
|
||||
bilateral = False
|
||||
arity = 1
|
||||
|
||||
@property
|
||||
def lhs(self):
|
||||
return self.get_source_expressions()[0]
|
||||
|
||||
def get_bilateral_transforms(self):
|
||||
if hasattr(self.lhs, "get_bilateral_transforms"):
|
||||
bilateral_transforms = self.lhs.get_bilateral_transforms()
|
||||
else:
|
||||
bilateral_transforms = []
|
||||
if self.bilateral:
|
||||
bilateral_transforms.append(self.__class__)
|
||||
return bilateral_transforms
|
||||
|
||||
|
||||
class BuiltinLookup(Lookup):
|
||||
def process_lhs(self, compiler, connection, lhs=None):
|
||||
lhs_sql, params = super().process_lhs(compiler, connection, lhs)
|
||||
field_internal_type = self.lhs.output_field.get_internal_type()
|
||||
db_type = self.lhs.output_field.db_type(connection=connection)
|
||||
lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
|
||||
lhs_sql = (
|
||||
connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
|
||||
)
|
||||
return lhs_sql, list(params)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs_sql, params = self.process_lhs(compiler, connection)
|
||||
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
|
||||
params.extend(rhs_params)
|
||||
rhs_sql = self.get_rhs_op(connection, rhs_sql)
|
||||
return "%s %s" % (lhs_sql, rhs_sql), params
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return connection.operators[self.lookup_name] % rhs
|
||||
|
||||
|
||||
class FieldGetDbPrepValueMixin:
|
||||
"""
|
||||
Some lookups require Field.get_db_prep_value() to be called on their
|
||||
inputs.
|
||||
"""
|
||||
|
||||
get_db_prep_lookup_value_is_iterable = False
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
# For relational fields, use the 'target_field' attribute of the
|
||||
# output_field.
|
||||
field = getattr(self.lhs.output_field, "target_field", None)
|
||||
get_db_prep_value = (
|
||||
getattr(field, "get_db_prep_value", None)
|
||||
or self.lhs.output_field.get_db_prep_value
|
||||
)
|
||||
return (
|
||||
"%s",
|
||||
[get_db_prep_value(v, connection, prepared=True) for v in value]
|
||||
if self.get_db_prep_lookup_value_is_iterable
|
||||
else [get_db_prep_value(value, connection, prepared=True)],
|
||||
)
|
||||
|
||||
|
||||
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
|
||||
"""
|
||||
Some lookups require Field.get_db_prep_value() to be called on each value
|
||||
in an iterable.
|
||||
"""
|
||||
|
||||
get_db_prep_lookup_value_is_iterable = True
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if hasattr(self.rhs, "resolve_expression"):
|
||||
return self.rhs
|
||||
prepared_values = []
|
||||
for rhs_value in self.rhs:
|
||||
if hasattr(rhs_value, "resolve_expression"):
|
||||
# An expression will be handled by the database but can coexist
|
||||
# alongside real values.
|
||||
pass
|
||||
elif self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):
|
||||
rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
|
||||
prepared_values.append(rhs_value)
|
||||
return prepared_values
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
if self.rhs_is_direct_value():
|
||||
# rhs should be an iterable of values. Use batch_process_rhs()
|
||||
# to prepare/transform those values.
|
||||
return self.batch_process_rhs(compiler, connection)
|
||||
else:
|
||||
return super().process_rhs(compiler, connection)
|
||||
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
params = [param]
|
||||
if hasattr(param, "resolve_expression"):
|
||||
param = param.resolve_expression(compiler.query)
|
||||
if hasattr(param, "as_sql"):
|
||||
sql, params = compiler.compile(param)
|
||||
return sql, params
|
||||
|
||||
def batch_process_rhs(self, compiler, connection, rhs=None):
|
||||
pre_processed = super().batch_process_rhs(compiler, connection, rhs)
|
||||
# The params list may contain expressions which compile to a
|
||||
# sql/param pair. Zip them to get sql and param pairs that refer to the
|
||||
# same argument and attempt to replace them with the result of
|
||||
# compiling the param step.
|
||||
sql, params = zip(
|
||||
*(
|
||||
self.resolve_expression_parameter(compiler, connection, sql, param)
|
||||
for sql, param in zip(*pre_processed)
|
||||
)
|
||||
)
|
||||
params = itertools.chain.from_iterable(params)
|
||||
return sql, tuple(params)
|
||||
|
||||
|
||||
class PostgresOperatorLookup(Lookup):
|
||||
"""Lookup defined by operators on PostgreSQL."""
|
||||
|
||||
postgres_operator = None
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(lhs_params) + tuple(rhs_params)
|
||||
return "%s %s %s" % (lhs, self.postgres_operator, rhs), params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "exact"
|
||||
|
||||
def get_prep_lookup(self):
|
||||
from django.db.models.sql.query import Query # avoid circular import
|
||||
|
||||
if isinstance(self.rhs, Query):
|
||||
if self.rhs.has_limit_one():
|
||||
if not self.rhs.has_select_fields:
|
||||
self.rhs.clear_select_clause()
|
||||
self.rhs.add_fields(["pk"])
|
||||
else:
|
||||
raise ValueError(
|
||||
"The QuerySet value for an exact lookup must be limited to "
|
||||
"one result using slicing."
|
||||
)
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Avoid comparison against direct rhs if lhs is a boolean value. That
|
||||
# turns "boolfield__exact=True" into "WHERE boolean_field" instead of
|
||||
# "WHERE boolean_field = True" when allowed.
|
||||
if (
|
||||
isinstance(self.rhs, bool)
|
||||
and getattr(self.lhs, "conditional", False)
|
||||
and connection.ops.conditional_expression_supported_in_where_clause(
|
||||
self.lhs
|
||||
)
|
||||
):
|
||||
lhs_sql, params = self.process_lhs(compiler, connection)
|
||||
template = "%s" if self.rhs else "NOT %s"
|
||||
return template % lhs_sql, params
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IExact(BuiltinLookup):
|
||||
lookup_name = "iexact"
|
||||
prepare_rhs = False
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
rhs, params = super().process_rhs(qn, connection)
|
||||
if params:
|
||||
params[0] = connection.ops.prep_for_iexact_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "gt"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "gte"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "lt"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "lte"
|
||||
|
||||
|
||||
class IntegerFieldFloatRounding:
|
||||
"""
|
||||
Allow floats to work as query values for IntegerField. Without this, the
|
||||
decimal portion of the float would always be discarded.
|
||||
"""
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if isinstance(self.rhs, float):
|
||||
self.rhs = math.ceil(self.rhs)
|
||||
return super().get_prep_lookup()
|
||||
|
||||
|
||||
@IntegerField.register_lookup
|
||||
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
@IntegerField.register_lookup
|
||||
class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
|
||||
pass
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
lookup_name = "in"
|
||||
|
||||
def get_prep_lookup(self):
|
||||
from django.db.models.sql.query import Query # avoid circular import
|
||||
|
||||
if isinstance(self.rhs, Query):
|
||||
self.rhs.clear_ordering(clear_default=True)
|
||||
if not self.rhs.has_select_fields:
|
||||
self.rhs.clear_select_clause()
|
||||
self.rhs.add_fields(["pk"])
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
db_rhs = getattr(self.rhs, "_db", None)
|
||||
if db_rhs is not None and db_rhs != connection.alias:
|
||||
raise ValueError(
|
||||
"Subqueries aren't allowed across different databases. Force "
|
||||
"the inner query to be evaluated using `list(inner_query)`."
|
||||
)
|
||||
|
||||
if self.rhs_is_direct_value():
|
||||
# Remove None from the list as NULL is never equal to anything.
|
||||
try:
|
||||
rhs = OrderedSet(self.rhs)
|
||||
rhs.discard(None)
|
||||
except TypeError: # Unhashable items in self.rhs
|
||||
rhs = [r for r in self.rhs if r is not None]
|
||||
|
||||
if not rhs:
|
||||
raise EmptyResultSet
|
||||
|
||||
# rhs should be an iterable; use batch_process_rhs() to
|
||||
# prepare/transform those values.
|
||||
sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
|
||||
placeholder = "(" + ", ".join(sqls) + ")"
|
||||
return (placeholder, sqls_params)
|
||||
return super().process_rhs(compiler, connection)
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return "IN %s" % rhs
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
max_in_list_size = connection.ops.max_in_list_size()
|
||||
if (
|
||||
self.rhs_is_direct_value()
|
||||
and max_in_list_size
|
||||
and len(self.rhs) > max_in_list_size
|
||||
):
|
||||
return self.split_parameter_list_as_sql(compiler, connection)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
def split_parameter_list_as_sql(self, compiler, connection):
|
||||
# This is a special case for databases which limit the number of
|
||||
# elements which can appear in an 'IN' clause.
|
||||
max_in_list_size = connection.ops.max_in_list_size()
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.batch_process_rhs(compiler, connection)
|
||||
in_clause_elements = ["("]
|
||||
params = []
|
||||
for offset in range(0, len(rhs_params), max_in_list_size):
|
||||
if offset > 0:
|
||||
in_clause_elements.append(" OR ")
|
||||
in_clause_elements.append("%s IN (" % lhs)
|
||||
params.extend(lhs_params)
|
||||
sqls = rhs[offset : offset + max_in_list_size]
|
||||
sqls_params = rhs_params[offset : offset + max_in_list_size]
|
||||
param_group = ", ".join(sqls)
|
||||
in_clause_elements.append(param_group)
|
||||
in_clause_elements.append(")")
|
||||
params.extend(sqls_params)
|
||||
in_clause_elements.append(")")
|
||||
return "".join(in_clause_elements), params
|
||||
|
||||
|
||||
class PatternLookup(BuiltinLookup):
|
||||
param_pattern = "%%%s%%"
|
||||
prepare_rhs = False
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
# Assume we are in startswith. We need to produce SQL like:
|
||||
# col LIKE %s, ['thevalue%']
|
||||
# For python values we can (and should) do that directly in Python,
|
||||
# but if the value is for example reference to other column, then
|
||||
# we need to add the % pattern match to the lookup by something like
|
||||
# col LIKE othercol || '%%'
|
||||
# So, for Python values we don't need any special pattern, but for
|
||||
# SQL reference values or SQL transformations we need the correct
|
||||
# pattern added.
|
||||
if hasattr(self.rhs, "as_sql") or self.bilateral_transforms:
|
||||
pattern = connection.pattern_ops[self.lookup_name].format(
|
||||
connection.pattern_esc
|
||||
)
|
||||
return pattern.format(rhs)
|
||||
else:
|
||||
return super().get_rhs_op(connection, rhs)
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
rhs, params = super().process_rhs(qn, connection)
|
||||
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
|
||||
params[0] = self.param_pattern % connection.ops.prep_for_like_query(
|
||||
params[0]
|
||||
)
|
||||
return rhs, params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class Contains(PatternLookup):
|
||||
lookup_name = "contains"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IContains(Contains):
|
||||
lookup_name = "icontains"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class StartsWith(PatternLookup):
|
||||
lookup_name = "startswith"
|
||||
param_pattern = "%s%%"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IStartsWith(StartsWith):
|
||||
lookup_name = "istartswith"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class EndsWith(PatternLookup):
|
||||
lookup_name = "endswith"
|
||||
param_pattern = "%%%s"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IEndsWith(EndsWith):
|
||||
lookup_name = "iendswith"
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
lookup_name = "range"
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IsNull(BuiltinLookup):
|
||||
lookup_name = "isnull"
|
||||
prepare_rhs = False
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not isinstance(self.rhs, bool):
|
||||
raise ValueError(
|
||||
"The QuerySet value for an isnull lookup must be True or False."
|
||||
)
|
||||
if isinstance(self.lhs, Value):
|
||||
if self.lhs.value is None or (
|
||||
self.lhs.value == ""
|
||||
and connection.features.interprets_empty_strings_as_nulls
|
||||
):
|
||||
result_exception = FullResultSet if self.rhs else EmptyResultSet
|
||||
else:
|
||||
result_exception = EmptyResultSet if self.rhs else FullResultSet
|
||||
raise result_exception
|
||||
sql, params = self.process_lhs(compiler, connection)
|
||||
if self.rhs:
|
||||
return "%s IS NULL" % sql, params
|
||||
else:
|
||||
return "%s IS NOT NULL" % sql, params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class Regex(BuiltinLookup):
|
||||
lookup_name = "regex"
|
||||
prepare_rhs = False
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if self.lookup_name in connection.operators:
|
||||
return super().as_sql(compiler, connection)
|
||||
else:
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
sql_template = connection.ops.regex_lookup(self.lookup_name)
|
||||
return sql_template % (lhs, rhs), lhs_params + rhs_params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IRegex(Regex):
|
||||
lookup_name = "iregex"
|
||||
|
||||
|
||||
class YearLookup(Lookup):
|
||||
def year_lookup_bounds(self, connection, year):
|
||||
from django.db.models.functions import ExtractIsoYear
|
||||
|
||||
iso_year = isinstance(self.lhs, ExtractIsoYear)
|
||||
output_field = self.lhs.lhs.output_field
|
||||
if isinstance(output_field, DateTimeField):
|
||||
bounds = connection.ops.year_lookup_bounds_for_datetime_field(
|
||||
year,
|
||||
iso_year=iso_year,
|
||||
)
|
||||
else:
|
||||
bounds = connection.ops.year_lookup_bounds_for_date_field(
|
||||
year,
|
||||
iso_year=iso_year,
|
||||
)
|
||||
return bounds
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Avoid the extract operation if the rhs is a direct value to allow
|
||||
# indexes to be used.
|
||||
if self.rhs_is_direct_value():
|
||||
# Skip the extract part by directly using the originating field,
|
||||
# that is self.lhs.lhs.
|
||||
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
|
||||
rhs_sql, _ = self.process_rhs(compiler, connection)
|
||||
rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
|
||||
start, finish = self.year_lookup_bounds(connection, self.rhs)
|
||||
params.extend(self.get_bound_params(start, finish))
|
||||
return "%s %s" % (lhs_sql, rhs_sql), params
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
def get_direct_rhs_sql(self, connection, rhs):
|
||||
return connection.operators[self.lookup_name] % rhs
|
||||
|
||||
def get_bound_params(self, start, finish):
|
||||
raise NotImplementedError(
|
||||
"subclasses of YearLookup must provide a get_bound_params() method"
|
||||
)
|
||||
|
||||
|
||||
class YearExact(YearLookup, Exact):
|
||||
def get_direct_rhs_sql(self, connection, rhs):
|
||||
return "BETWEEN %s AND %s"
|
||||
|
||||
def get_bound_params(self, start, finish):
|
||||
return (start, finish)
|
||||
|
||||
|
||||
class YearGt(YearLookup, GreaterThan):
|
||||
def get_bound_params(self, start, finish):
|
||||
return (finish,)
|
||||
|
||||
|
||||
class YearGte(YearLookup, GreaterThanOrEqual):
|
||||
def get_bound_params(self, start, finish):
|
||||
return (start,)
|
||||
|
||||
|
||||
class YearLt(YearLookup, LessThan):
|
||||
def get_bound_params(self, start, finish):
|
||||
return (start,)
|
||||
|
||||
|
||||
class YearLte(YearLookup, LessThanOrEqual):
|
||||
def get_bound_params(self, start, finish):
|
||||
return (finish,)
|
||||
|
||||
|
||||
class UUIDTextMixin:
|
||||
"""
|
||||
Strip hyphens from a value when filtering a UUIDField on backends without
|
||||
a native datatype for UUID.
|
||||
"""
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
if not connection.features.has_native_uuid_field:
|
||||
from django.db.models.functions import Replace
|
||||
|
||||
if self.rhs_is_direct_value():
|
||||
self.rhs = Value(self.rhs)
|
||||
self.rhs = Replace(
|
||||
self.rhs, Value("-"), Value(""), output_field=CharField()
|
||||
)
|
||||
rhs, params = super().process_rhs(qn, connection)
|
||||
return rhs, params
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDIExact(UUIDTextMixin, IExact):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDContains(UUIDTextMixin, Contains):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDIContains(UUIDTextMixin, IContains):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDStartsWith(UUIDTextMixin, StartsWith):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDEndsWith(UUIDTextMixin, EndsWith):
|
||||
pass
|
||||
|
||||
|
||||
@UUIDField.register_lookup
|
||||
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
|
||||
pass
|
@ -0,0 +1,213 @@
|
||||
import copy
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from importlib import import_module
|
||||
|
||||
from django.db import router
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
|
||||
class BaseManager:
|
||||
# To retain order, track each time a Manager instance is created.
|
||||
creation_counter = 0
|
||||
|
||||
# Set to True for the 'objects' managers that are automatically created.
|
||||
auto_created = False
|
||||
|
||||
#: If set to True the manager will be serialized into migrations and will
|
||||
#: thus be available in e.g. RunPython operations.
|
||||
use_in_migrations = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# Capture the arguments to make returning them trivial.
|
||||
obj = super().__new__(cls)
|
||||
obj._constructor_args = (args, kwargs)
|
||||
return obj
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._set_creation_counter()
|
||||
self.model = None
|
||||
self.name = None
|
||||
self._db = None
|
||||
self._hints = {}
|
||||
|
||||
def __str__(self):
|
||||
"""Return "app_label.model_label.manager_name"."""
|
||||
return "%s.%s" % (self.model._meta.label, self.name)
|
||||
|
||||
def __class_getitem__(cls, *args, **kwargs):
|
||||
return cls
|
||||
|
||||
def deconstruct(self):
|
||||
"""
|
||||
Return a 5-tuple of the form (as_manager (True), manager_class,
|
||||
queryset_class, args, kwargs).
|
||||
|
||||
Raise a ValueError if the manager is dynamically generated.
|
||||
"""
|
||||
qs_class = self._queryset_class
|
||||
if getattr(self, "_built_with_as_manager", False):
|
||||
# using MyQuerySet.as_manager()
|
||||
return (
|
||||
True, # as_manager
|
||||
None, # manager_class
|
||||
"%s.%s" % (qs_class.__module__, qs_class.__name__), # qs_class
|
||||
None, # args
|
||||
None, # kwargs
|
||||
)
|
||||
else:
|
||||
module_name = self.__module__
|
||||
name = self.__class__.__name__
|
||||
# Make sure it's actually there and not an inner class
|
||||
module = import_module(module_name)
|
||||
if not hasattr(module, name):
|
||||
raise ValueError(
|
||||
"Could not find manager %s in %s.\n"
|
||||
"Please note that you need to inherit from managers you "
|
||||
"dynamically generated with 'from_queryset()'."
|
||||
% (name, module_name)
|
||||
)
|
||||
return (
|
||||
False, # as_manager
|
||||
"%s.%s" % (module_name, name), # manager_class
|
||||
None, # qs_class
|
||||
self._constructor_args[0], # args
|
||||
self._constructor_args[1], # kwargs
|
||||
)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _get_queryset_methods(cls, queryset_class):
|
||||
def create_method(name, method):
|
||||
@wraps(method)
|
||||
def manager_method(self, *args, **kwargs):
|
||||
return getattr(self.get_queryset(), name)(*args, **kwargs)
|
||||
|
||||
return manager_method
|
||||
|
||||
new_methods = {}
|
||||
for name, method in inspect.getmembers(
|
||||
queryset_class, predicate=inspect.isfunction
|
||||
):
|
||||
# Only copy missing methods.
|
||||
if hasattr(cls, name):
|
||||
continue
|
||||
# Only copy public methods or methods with the attribute
|
||||
# queryset_only=False.
|
||||
queryset_only = getattr(method, "queryset_only", None)
|
||||
if queryset_only or (queryset_only is None and name.startswith("_")):
|
||||
continue
|
||||
# Copy the method onto the manager.
|
||||
new_methods[name] = create_method(name, method)
|
||||
return new_methods
|
||||
|
||||
@classmethod
|
||||
def from_queryset(cls, queryset_class, class_name=None):
|
||||
if class_name is None:
|
||||
class_name = "%sFrom%s" % (cls.__name__, queryset_class.__name__)
|
||||
return type(
|
||||
class_name,
|
||||
(cls,),
|
||||
{
|
||||
"_queryset_class": queryset_class,
|
||||
**cls._get_queryset_methods(queryset_class),
|
||||
},
|
||||
)
|
||||
|
||||
def contribute_to_class(self, cls, name):
|
||||
self.name = self.name or name
|
||||
self.model = cls
|
||||
|
||||
setattr(cls, name, ManagerDescriptor(self))
|
||||
|
||||
cls._meta.add_manager(self)
|
||||
|
||||
def _set_creation_counter(self):
|
||||
"""
|
||||
Set the creation counter value for this instance and increment the
|
||||
class-level copy.
|
||||
"""
|
||||
self.creation_counter = BaseManager.creation_counter
|
||||
BaseManager.creation_counter += 1
|
||||
|
||||
def db_manager(self, using=None, hints=None):
|
||||
obj = copy.copy(self)
|
||||
obj._db = using or self._db
|
||||
obj._hints = hints or self._hints
|
||||
return obj
|
||||
|
||||
@property
|
||||
def db(self):
|
||||
return self._db or router.db_for_read(self.model, **self._hints)
|
||||
|
||||
#######################
|
||||
# PROXIES TO QUERYSET #
|
||||
#######################
|
||||
|
||||
def get_queryset(self):
|
||||
"""
|
||||
Return a new QuerySet object. Subclasses can override this method to
|
||||
customize the behavior of the Manager.
|
||||
"""
|
||||
return self._queryset_class(model=self.model, using=self._db, hints=self._hints)
|
||||
|
||||
def all(self):
|
||||
# We can't proxy this method through the `QuerySet` like we do for the
|
||||
# rest of the `QuerySet` methods. This is because `QuerySet.all()`
|
||||
# works by creating a "copy" of the current queryset and in making said
|
||||
# copy, all the cached `prefetch_related` lookups are lost. See the
|
||||
# implementation of `RelatedManager.get_queryset()` for a better
|
||||
# understanding of how this comes into play.
|
||||
return self.get_queryset()
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self._constructor_args == other._constructor_args
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
|
||||
class Manager(BaseManager.from_queryset(QuerySet)):
|
||||
pass
|
||||
|
||||
|
||||
class ManagerDescriptor:
|
||||
def __init__(self, manager):
|
||||
self.manager = manager
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is not None:
|
||||
raise AttributeError(
|
||||
"Manager isn't accessible via %s instances" % cls.__name__
|
||||
)
|
||||
|
||||
if cls._meta.abstract:
|
||||
raise AttributeError(
|
||||
"Manager isn't available; %s is abstract" % (cls._meta.object_name,)
|
||||
)
|
||||
|
||||
if cls._meta.swapped:
|
||||
raise AttributeError(
|
||||
"Manager isn't available; '%s' has been swapped for '%s'"
|
||||
% (
|
||||
cls._meta.label,
|
||||
cls._meta.swapped,
|
||||
)
|
||||
)
|
||||
|
||||
return cls._meta.managers_map[self.manager.name]
|
||||
|
||||
|
||||
class EmptyManager(Manager):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().none()
|
1014
srcs/.venv/lib/python3.11/site-packages/django/db/models/options.py
Normal file
1014
srcs/.venv/lib/python3.11/site-packages/django/db/models/options.py
Normal file
File diff suppressed because it is too large
Load Diff
2631
srcs/.venv/lib/python3.11/site-packages/django/db/models/query.py
Normal file
2631
srcs/.venv/lib/python3.11/site-packages/django/db/models/query.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,435 @@
|
||||
"""
|
||||
Various data structures used in query construction.
|
||||
|
||||
Factored out from django.db.models.query to avoid making the main module very
|
||||
large and/or so that they can be used by other modules without getting into
|
||||
circular import difficulties.
|
||||
"""
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.utils import tree
|
||||
|
||||
logger = logging.getLogger("django.db.models")
|
||||
|
||||
# PathInfo is used when converting lookups (fk__somecol). The contents
|
||||
# describe the relation in Model terms (model Options and Fields for both
|
||||
# sides of the relation. The join_field is the field backing the relation.
|
||||
PathInfo = namedtuple(
|
||||
"PathInfo",
|
||||
"from_opts to_opts target_fields join_field m2m direct filtered_relation",
|
||||
)
|
||||
|
||||
|
||||
def subclasses(cls):
|
||||
yield cls
|
||||
for subclass in cls.__subclasses__():
|
||||
yield from subclasses(subclass)
|
||||
|
||||
|
||||
class Q(tree.Node):
|
||||
"""
|
||||
Encapsulate filters as objects that can then be combined logically (using
|
||||
`&` and `|`).
|
||||
"""
|
||||
|
||||
# Connection types
|
||||
AND = "AND"
|
||||
OR = "OR"
|
||||
XOR = "XOR"
|
||||
default = AND
|
||||
conditional = True
|
||||
|
||||
def __init__(self, *args, _connector=None, _negated=False, **kwargs):
|
||||
super().__init__(
|
||||
children=[*args, *sorted(kwargs.items())],
|
||||
connector=_connector,
|
||||
negated=_negated,
|
||||
)
|
||||
|
||||
def _combine(self, other, conn):
|
||||
if getattr(other, "conditional", False) is False:
|
||||
raise TypeError(other)
|
||||
if not self:
|
||||
return other.copy()
|
||||
if not other and isinstance(other, Q):
|
||||
return self.copy()
|
||||
|
||||
obj = self.create(connector=conn)
|
||||
obj.add(self, conn)
|
||||
obj.add(other, conn)
|
||||
return obj
|
||||
|
||||
def __or__(self, other):
|
||||
return self._combine(other, self.OR)
|
||||
|
||||
def __and__(self, other):
|
||||
return self._combine(other, self.AND)
|
||||
|
||||
def __xor__(self, other):
|
||||
return self._combine(other, self.XOR)
|
||||
|
||||
def __invert__(self):
|
||||
obj = self.copy()
|
||||
obj.negate()
|
||||
return obj
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
# We must promote any new joins to left outer joins so that when Q is
|
||||
# used as an expression, rows aren't filtered due to joins.
|
||||
clause, joins = query._add_q(
|
||||
self,
|
||||
reuse,
|
||||
allow_joins=allow_joins,
|
||||
split_subq=False,
|
||||
check_filterable=False,
|
||||
summarize=summarize,
|
||||
)
|
||||
query.promote_joins(joins)
|
||||
return clause
|
||||
|
||||
def flatten(self):
|
||||
"""
|
||||
Recursively yield this Q object and all subexpressions, in depth-first
|
||||
order.
|
||||
"""
|
||||
yield self
|
||||
for child in self.children:
|
||||
if isinstance(child, tuple):
|
||||
# Use the lookup.
|
||||
child = child[1]
|
||||
if hasattr(child, "flatten"):
|
||||
yield from child.flatten()
|
||||
else:
|
||||
yield child
|
||||
|
||||
def check(self, against, using=DEFAULT_DB_ALIAS):
|
||||
"""
|
||||
Do a database query to check if the expressions of the Q instance
|
||||
matches against the expressions.
|
||||
"""
|
||||
# Avoid circular imports.
|
||||
from django.db.models import BooleanField, Value
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.db.models.sql import Query
|
||||
from django.db.models.sql.constants import SINGLE
|
||||
|
||||
query = Query(None)
|
||||
for name, value in against.items():
|
||||
if not hasattr(value, "resolve_expression"):
|
||||
value = Value(value)
|
||||
query.add_annotation(value, name, select=False)
|
||||
query.add_annotation(Value(1), "_check")
|
||||
# This will raise a FieldError if a field is missing in "against".
|
||||
if connections[using].features.supports_comparing_boolean_expr:
|
||||
query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
|
||||
else:
|
||||
query.add_q(self)
|
||||
compiler = query.get_compiler(using=using)
|
||||
try:
|
||||
return compiler.execute_sql(SINGLE) is not None
|
||||
except DatabaseError as e:
|
||||
logger.warning("Got a database error calling check() on %r: %s", self, e)
|
||||
return True
|
||||
|
||||
def deconstruct(self):
|
||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||
if path.startswith("django.db.models.query_utils"):
|
||||
path = path.replace("django.db.models.query_utils", "django.db.models")
|
||||
args = tuple(self.children)
|
||||
kwargs = {}
|
||||
if self.connector != self.default:
|
||||
kwargs["_connector"] = self.connector
|
||||
if self.negated:
|
||||
kwargs["_negated"] = True
|
||||
return path, args, kwargs
|
||||
|
||||
|
||||
class DeferredAttribute:
|
||||
"""
|
||||
A wrapper for a deferred-loading field. When the value is read from this
|
||||
object the first time, the query is executed.
|
||||
"""
|
||||
|
||||
def __init__(self, field):
|
||||
self.field = field
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
"""
|
||||
Retrieve and caches the value from the datastore on the first lookup.
|
||||
Return the cached value.
|
||||
"""
|
||||
if instance is None:
|
||||
return self
|
||||
data = instance.__dict__
|
||||
field_name = self.field.attname
|
||||
if field_name not in data:
|
||||
# Let's see if the field is part of the parent chain. If so we
|
||||
# might be able to reuse the already loaded value. Refs #18343.
|
||||
val = self._check_parent_chain(instance)
|
||||
if val is None:
|
||||
instance.refresh_from_db(fields=[field_name])
|
||||
else:
|
||||
data[field_name] = val
|
||||
return data[field_name]
|
||||
|
||||
def _check_parent_chain(self, instance):
|
||||
"""
|
||||
Check if the field value can be fetched from a parent field already
|
||||
loaded in the instance. This can be done if the to-be fetched
|
||||
field is a primary key field.
|
||||
"""
|
||||
opts = instance._meta
|
||||
link_field = opts.get_ancestor_link(self.field.model)
|
||||
if self.field.primary_key and self.field != link_field:
|
||||
return getattr(instance, link_field.attname)
|
||||
return None
|
||||
|
||||
|
||||
class class_or_instance_method:
|
||||
"""
|
||||
Hook used in RegisterLookupMixin to return partial functions depending on
|
||||
the caller type (instance or class of models.Field).
|
||||
"""
|
||||
|
||||
def __init__(self, class_method, instance_method):
|
||||
self.class_method = class_method
|
||||
self.instance_method = instance_method
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
return functools.partial(self.class_method, owner)
|
||||
return functools.partial(self.instance_method, instance)
|
||||
|
||||
|
||||
class RegisterLookupMixin:
|
||||
def _get_lookup(self, lookup_name):
|
||||
return self.get_lookups().get(lookup_name, None)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_class_lookups(cls):
|
||||
class_lookups = [
|
||||
parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
|
||||
]
|
||||
return cls.merge_dicts(class_lookups)
|
||||
|
||||
def get_instance_lookups(self):
|
||||
class_lookups = self.get_class_lookups()
|
||||
if instance_lookups := getattr(self, "instance_lookups", None):
|
||||
return {**class_lookups, **instance_lookups}
|
||||
return class_lookups
|
||||
|
||||
get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
|
||||
get_class_lookups = classmethod(get_class_lookups)
|
||||
|
||||
def get_lookup(self, lookup_name):
|
||||
from django.db.models.lookups import Lookup
|
||||
|
||||
found = self._get_lookup(lookup_name)
|
||||
if found is None and hasattr(self, "output_field"):
|
||||
return self.output_field.get_lookup(lookup_name)
|
||||
if found is not None and not issubclass(found, Lookup):
|
||||
return None
|
||||
return found
|
||||
|
||||
def get_transform(self, lookup_name):
|
||||
from django.db.models.lookups import Transform
|
||||
|
||||
found = self._get_lookup(lookup_name)
|
||||
if found is None and hasattr(self, "output_field"):
|
||||
return self.output_field.get_transform(lookup_name)
|
||||
if found is not None and not issubclass(found, Transform):
|
||||
return None
|
||||
return found
|
||||
|
||||
@staticmethod
|
||||
def merge_dicts(dicts):
|
||||
"""
|
||||
Merge dicts in reverse to preference the order of the original list. e.g.,
|
||||
merge_dicts([a, b]) will preference the keys in 'a' over those in 'b'.
|
||||
"""
|
||||
merged = {}
|
||||
for d in reversed(dicts):
|
||||
merged.update(d)
|
||||
return merged
|
||||
|
||||
@classmethod
|
||||
def _clear_cached_class_lookups(cls):
|
||||
for subclass in subclasses(cls):
|
||||
subclass.get_class_lookups.cache_clear()
|
||||
|
||||
def register_class_lookup(cls, lookup, lookup_name=None):
|
||||
if lookup_name is None:
|
||||
lookup_name = lookup.lookup_name
|
||||
if "class_lookups" not in cls.__dict__:
|
||||
cls.class_lookups = {}
|
||||
cls.class_lookups[lookup_name] = lookup
|
||||
cls._clear_cached_class_lookups()
|
||||
return lookup
|
||||
|
||||
def register_instance_lookup(self, lookup, lookup_name=None):
|
||||
if lookup_name is None:
|
||||
lookup_name = lookup.lookup_name
|
||||
if "instance_lookups" not in self.__dict__:
|
||||
self.instance_lookups = {}
|
||||
self.instance_lookups[lookup_name] = lookup
|
||||
return lookup
|
||||
|
||||
register_lookup = class_or_instance_method(
|
||||
register_class_lookup, register_instance_lookup
|
||||
)
|
||||
register_class_lookup = classmethod(register_class_lookup)
|
||||
|
||||
def _unregister_class_lookup(cls, lookup, lookup_name=None):
|
||||
"""
|
||||
Remove given lookup from cls lookups. For use in tests only as it's
|
||||
not thread-safe.
|
||||
"""
|
||||
if lookup_name is None:
|
||||
lookup_name = lookup.lookup_name
|
||||
del cls.class_lookups[lookup_name]
|
||||
cls._clear_cached_class_lookups()
|
||||
|
||||
def _unregister_instance_lookup(self, lookup, lookup_name=None):
|
||||
"""
|
||||
Remove given lookup from instance lookups. For use in tests only as
|
||||
it's not thread-safe.
|
||||
"""
|
||||
if lookup_name is None:
|
||||
lookup_name = lookup.lookup_name
|
||||
del self.instance_lookups[lookup_name]
|
||||
|
||||
_unregister_lookup = class_or_instance_method(
|
||||
_unregister_class_lookup, _unregister_instance_lookup
|
||||
)
|
||||
_unregister_class_lookup = classmethod(_unregister_class_lookup)
|
||||
|
||||
|
||||
def select_related_descend(field, restricted, requested, select_mask, reverse=False):
|
||||
"""
|
||||
Return True if this field should be used to descend deeper for
|
||||
select_related() purposes. Used by both the query construction code
|
||||
(compiler.get_related_selections()) and the model instance creation code
|
||||
(compiler.klass_info).
|
||||
|
||||
Arguments:
|
||||
* field - the field to be checked
|
||||
* restricted - a boolean field, indicating if the field list has been
|
||||
manually restricted using a requested clause)
|
||||
* requested - The select_related() dictionary.
|
||||
* select_mask - the dictionary of selected fields.
|
||||
* reverse - boolean, True if we are checking a reverse select related
|
||||
"""
|
||||
if not field.remote_field:
|
||||
return False
|
||||
if field.remote_field.parent_link and not reverse:
|
||||
return False
|
||||
if restricted:
|
||||
if reverse and field.related_query_name() not in requested:
|
||||
return False
|
||||
if not reverse and field.name not in requested:
|
||||
return False
|
||||
if not restricted and field.null:
|
||||
return False
|
||||
if (
|
||||
restricted
|
||||
and select_mask
|
||||
and field.name in requested
|
||||
and field not in select_mask
|
||||
):
|
||||
raise FieldError(
|
||||
f"Field {field.model._meta.object_name}.{field.name} cannot be both "
|
||||
"deferred and traversed using select_related at the same time."
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def refs_expression(lookup_parts, annotations):
|
||||
"""
|
||||
Check if the lookup_parts contains references to the given annotations set.
|
||||
Because the LOOKUP_SEP is contained in the default annotation names, check
|
||||
each prefix of the lookup_parts for a match.
|
||||
"""
|
||||
for n in range(1, len(lookup_parts) + 1):
|
||||
level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
|
||||
if annotations.get(level_n_lookup):
|
||||
return level_n_lookup, lookup_parts[n:]
|
||||
return None, ()
|
||||
|
||||
|
||||
def check_rel_lookup_compatibility(model, target_opts, field):
|
||||
"""
|
||||
Check that self.model is compatible with target_opts. Compatibility
|
||||
is OK if:
|
||||
1) model and opts match (where proxy inheritance is removed)
|
||||
2) model is parent of opts' model or the other way around
|
||||
"""
|
||||
|
||||
def check(opts):
|
||||
return (
|
||||
model._meta.concrete_model == opts.concrete_model
|
||||
or opts.concrete_model in model._meta.get_parent_list()
|
||||
or model in opts.get_parent_list()
|
||||
)
|
||||
|
||||
# If the field is a primary key, then doing a query against the field's
|
||||
# model is ok, too. Consider the case:
|
||||
# class Restaurant(models.Model):
|
||||
# place = OneToOneField(Place, primary_key=True):
|
||||
# Restaurant.objects.filter(pk__in=Restaurant.objects.all()).
|
||||
# If we didn't have the primary key check, then pk__in (== place__in) would
|
||||
# give Place's opts as the target opts, but Restaurant isn't compatible
|
||||
# with that. This logic applies only to primary keys, as when doing __in=qs,
|
||||
# we are going to turn this into __in=qs.values('pk') later on.
|
||||
return check(target_opts) or (
|
||||
getattr(field, "primary_key", False) and check(field.model._meta)
|
||||
)
|
||||
|
||||
|
||||
class FilteredRelation:
|
||||
"""Specify custom filtering in the ON clause of SQL joins."""
|
||||
|
||||
def __init__(self, relation_name, *, condition=Q()):
|
||||
if not relation_name:
|
||||
raise ValueError("relation_name cannot be empty.")
|
||||
self.relation_name = relation_name
|
||||
self.alias = None
|
||||
if not isinstance(condition, Q):
|
||||
raise ValueError("condition argument must be a Q() instance.")
|
||||
self.condition = condition
|
||||
self.path = []
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.relation_name == other.relation_name
|
||||
and self.alias == other.alias
|
||||
and self.condition == other.condition
|
||||
)
|
||||
|
||||
def clone(self):
|
||||
clone = FilteredRelation(self.relation_name, condition=self.condition)
|
||||
clone.alias = self.alias
|
||||
clone.path = self.path[:]
|
||||
return clone
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
"""
|
||||
QuerySet.annotate() only accepts expression-like arguments
|
||||
(with a resolve_expression() method).
|
||||
"""
|
||||
raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Resolve the condition in Join.filtered_relation.
|
||||
query = compiler.query
|
||||
where = query.build_filtered_relation_q(self.condition, reuse=set(self.path))
|
||||
return compiler.compile(where)
|
@ -0,0 +1,54 @@
|
||||
from functools import partial
|
||||
|
||||
from django.db.models.utils import make_model_tuple
|
||||
from django.dispatch import Signal
|
||||
|
||||
class_prepared = Signal()
|
||||
|
||||
|
||||
class ModelSignal(Signal):
|
||||
"""
|
||||
Signal subclass that allows the sender to be lazily specified as a string
|
||||
of the `app_label.ModelName` form.
|
||||
"""
|
||||
|
||||
def _lazy_method(self, method, apps, receiver, sender, **kwargs):
|
||||
from django.db.models.options import Options
|
||||
|
||||
# This partial takes a single optional argument named "sender".
|
||||
partial_method = partial(method, receiver, **kwargs)
|
||||
if isinstance(sender, str):
|
||||
apps = apps or Options.default_apps
|
||||
apps.lazy_model_operation(partial_method, make_model_tuple(sender))
|
||||
else:
|
||||
return partial_method(sender)
|
||||
|
||||
def connect(self, receiver, sender=None, weak=True, dispatch_uid=None, apps=None):
|
||||
self._lazy_method(
|
||||
super().connect,
|
||||
apps,
|
||||
receiver,
|
||||
sender,
|
||||
weak=weak,
|
||||
dispatch_uid=dispatch_uid,
|
||||
)
|
||||
|
||||
def disconnect(self, receiver=None, sender=None, dispatch_uid=None, apps=None):
|
||||
return self._lazy_method(
|
||||
super().disconnect, apps, receiver, sender, dispatch_uid=dispatch_uid
|
||||
)
|
||||
|
||||
|
||||
pre_init = ModelSignal(use_caching=True)
|
||||
post_init = ModelSignal(use_caching=True)
|
||||
|
||||
pre_save = ModelSignal(use_caching=True)
|
||||
post_save = ModelSignal(use_caching=True)
|
||||
|
||||
pre_delete = ModelSignal(use_caching=True)
|
||||
post_delete = ModelSignal(use_caching=True)
|
||||
|
||||
m2m_changed = ModelSignal(use_caching=True)
|
||||
|
||||
pre_migrate = Signal()
|
||||
post_migrate = Signal()
|
@ -0,0 +1,6 @@
|
||||
from django.db.models.sql.query import * # NOQA
|
||||
from django.db.models.sql.query import Query
|
||||
from django.db.models.sql.subqueries import * # NOQA
|
||||
from django.db.models.sql.where import AND, OR, XOR
|
||||
|
||||
__all__ = ["Query", "AND", "OR", "XOR"]
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,24 @@
|
||||
"""
|
||||
Constants specific to the SQL storage portion of the ORM.
|
||||
"""
|
||||
|
||||
# Size of each "chunk" for get_iterator calls.
|
||||
# Larger values are slightly faster at the expense of more storage space.
|
||||
GET_ITERATOR_CHUNK_SIZE = 100
|
||||
|
||||
# Namedtuples for sql.* internal use.
|
||||
|
||||
# How many results to expect from a cursor.execute call
|
||||
MULTI = "multi"
|
||||
SINGLE = "single"
|
||||
CURSOR = "cursor"
|
||||
NO_RESULTS = "no results"
|
||||
|
||||
ORDER_DIR = {
|
||||
"ASC": ("ASC", "DESC"),
|
||||
"DESC": ("DESC", "ASC"),
|
||||
}
|
||||
|
||||
# SQL join types.
|
||||
INNER = "INNER JOIN"
|
||||
LOUTER = "LEFT OUTER JOIN"
|
@ -0,0 +1,224 @@
|
||||
"""
|
||||
Useful auxiliary data structures for query construction. Not useful outside
|
||||
the SQL domain.
|
||||
"""
|
||||
from django.core.exceptions import FullResultSet
|
||||
from django.db.models.sql.constants import INNER, LOUTER
|
||||
|
||||
|
||||
class MultiJoin(Exception):
|
||||
"""
|
||||
Used by join construction code to indicate the point at which a
|
||||
multi-valued join was attempted (if the caller wants to treat that
|
||||
exceptionally).
|
||||
"""
|
||||
|
||||
def __init__(self, names_pos, path_with_names):
|
||||
self.level = names_pos
|
||||
# The path travelled, this includes the path to the multijoin.
|
||||
self.names_with_path = path_with_names
|
||||
|
||||
|
||||
class Empty:
|
||||
pass
|
||||
|
||||
|
||||
class Join:
|
||||
"""
|
||||
Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
|
||||
FROM entry. For example, the SQL generated could be
|
||||
LEFT OUTER JOIN "sometable" T1
|
||||
ON ("othertable"."sometable_id" = "sometable"."id")
|
||||
|
||||
This class is primarily used in Query.alias_map. All entries in alias_map
|
||||
must be Join compatible by providing the following attributes and methods:
|
||||
- table_name (string)
|
||||
- table_alias (possible alias for the table, can be None)
|
||||
- join_type (can be None for those entries that aren't joined from
|
||||
anything)
|
||||
- parent_alias (which table is this join's parent, can be None similarly
|
||||
to join_type)
|
||||
- as_sql()
|
||||
- relabeled_clone()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
table_name,
|
||||
parent_alias,
|
||||
table_alias,
|
||||
join_type,
|
||||
join_field,
|
||||
nullable,
|
||||
filtered_relation=None,
|
||||
):
|
||||
# Join table
|
||||
self.table_name = table_name
|
||||
self.parent_alias = parent_alias
|
||||
# Note: table_alias is not necessarily known at instantiation time.
|
||||
self.table_alias = table_alias
|
||||
# LOUTER or INNER
|
||||
self.join_type = join_type
|
||||
# A list of 2-tuples to use in the ON clause of the JOIN.
|
||||
# Each 2-tuple will create one join condition in the ON clause.
|
||||
self.join_cols = join_field.get_joining_columns()
|
||||
# Along which field (or ForeignObjectRel in the reverse join case)
|
||||
self.join_field = join_field
|
||||
# Is this join nullabled?
|
||||
self.nullable = nullable
|
||||
self.filtered_relation = filtered_relation
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
"""
|
||||
Generate the full
|
||||
LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params
|
||||
clause for this join.
|
||||
"""
|
||||
join_conditions = []
|
||||
params = []
|
||||
qn = compiler.quote_name_unless_alias
|
||||
qn2 = connection.ops.quote_name
|
||||
|
||||
# Add a join condition for each pair of joining columns.
|
||||
for lhs_col, rhs_col in self.join_cols:
|
||||
join_conditions.append(
|
||||
"%s.%s = %s.%s"
|
||||
% (
|
||||
qn(self.parent_alias),
|
||||
qn2(lhs_col),
|
||||
qn(self.table_alias),
|
||||
qn2(rhs_col),
|
||||
)
|
||||
)
|
||||
|
||||
# Add a single condition inside parentheses for whatever
|
||||
# get_extra_restriction() returns.
|
||||
extra_cond = self.join_field.get_extra_restriction(
|
||||
self.table_alias, self.parent_alias
|
||||
)
|
||||
if extra_cond:
|
||||
extra_sql, extra_params = compiler.compile(extra_cond)
|
||||
join_conditions.append("(%s)" % extra_sql)
|
||||
params.extend(extra_params)
|
||||
if self.filtered_relation:
|
||||
try:
|
||||
extra_sql, extra_params = compiler.compile(self.filtered_relation)
|
||||
except FullResultSet:
|
||||
pass
|
||||
else:
|
||||
join_conditions.append("(%s)" % extra_sql)
|
||||
params.extend(extra_params)
|
||||
if not join_conditions:
|
||||
# This might be a rel on the other end of an actual declared field.
|
||||
declared_field = getattr(self.join_field, "field", self.join_field)
|
||||
raise ValueError(
|
||||
"Join generated an empty ON clause. %s did not yield either "
|
||||
"joining columns or extra restrictions." % declared_field.__class__
|
||||
)
|
||||
on_clause_sql = " AND ".join(join_conditions)
|
||||
alias_str = (
|
||||
"" if self.table_alias == self.table_name else (" %s" % self.table_alias)
|
||||
)
|
||||
sql = "%s %s%s ON (%s)" % (
|
||||
self.join_type,
|
||||
qn(self.table_name),
|
||||
alias_str,
|
||||
on_clause_sql,
|
||||
)
|
||||
return sql, params
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
|
||||
new_table_alias = change_map.get(self.table_alias, self.table_alias)
|
||||
if self.filtered_relation is not None:
|
||||
filtered_relation = self.filtered_relation.clone()
|
||||
filtered_relation.path = [
|
||||
change_map.get(p, p) for p in self.filtered_relation.path
|
||||
]
|
||||
else:
|
||||
filtered_relation = None
|
||||
return self.__class__(
|
||||
self.table_name,
|
||||
new_parent_alias,
|
||||
new_table_alias,
|
||||
self.join_type,
|
||||
self.join_field,
|
||||
self.nullable,
|
||||
filtered_relation=filtered_relation,
|
||||
)
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return (
|
||||
self.__class__,
|
||||
self.table_name,
|
||||
self.parent_alias,
|
||||
self.join_field,
|
||||
self.filtered_relation,
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Join):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def equals(self, other):
|
||||
# Ignore filtered_relation in equality check.
|
||||
return self.identity[:-1] == other.identity[:-1]
|
||||
|
||||
def demote(self):
|
||||
new = self.relabeled_clone({})
|
||||
new.join_type = INNER
|
||||
return new
|
||||
|
||||
def promote(self):
|
||||
new = self.relabeled_clone({})
|
||||
new.join_type = LOUTER
|
||||
return new
|
||||
|
||||
|
||||
class BaseTable:
|
||||
"""
|
||||
The BaseTable class is used for base table references in FROM clause. For
|
||||
example, the SQL "foo" in
|
||||
SELECT * FROM "foo" WHERE somecond
|
||||
could be generated by this class.
|
||||
"""
|
||||
|
||||
join_type = None
|
||||
parent_alias = None
|
||||
filtered_relation = None
|
||||
|
||||
def __init__(self, table_name, alias):
|
||||
self.table_name = table_name
|
||||
self.table_alias = alias
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
alias_str = (
|
||||
"" if self.table_alias == self.table_name else (" %s" % self.table_alias)
|
||||
)
|
||||
base_sql = compiler.quote_name_unless_alias(self.table_name)
|
||||
return base_sql + alias_str, []
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
return self.__class__(
|
||||
self.table_name, change_map.get(self.table_alias, self.table_alias)
|
||||
)
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return self.__class__, self.table_name, self.table_alias
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, BaseTable):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def equals(self, other):
|
||||
return self.identity == other.identity
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,171 @@
|
||||
"""
|
||||
Query subclasses which provide extra functionality beyond simple data retrieval.
|
||||
"""
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
|
||||
from django.db.models.sql.query import Query
|
||||
|
||||
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
|
||||
|
||||
|
||||
class DeleteQuery(Query):
|
||||
"""A DELETE SQL query."""
|
||||
|
||||
compiler = "SQLDeleteCompiler"
|
||||
|
||||
def do_query(self, table, where, using):
|
||||
self.alias_map = {table: self.alias_map[table]}
|
||||
self.where = where
|
||||
cursor = self.get_compiler(using).execute_sql(CURSOR)
|
||||
if cursor:
|
||||
with cursor:
|
||||
return cursor.rowcount
|
||||
return 0
|
||||
|
||||
def delete_batch(self, pk_list, using):
|
||||
"""
|
||||
Set up and execute delete queries for all the objects in pk_list.
|
||||
|
||||
More than one physical query may be executed if there are a
|
||||
lot of values in pk_list.
|
||||
"""
|
||||
# number of objects deleted
|
||||
num_deleted = 0
|
||||
field = self.get_meta().pk
|
||||
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
|
||||
self.clear_where()
|
||||
self.add_filter(
|
||||
f"{field.attname}__in",
|
||||
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
|
||||
)
|
||||
num_deleted += self.do_query(
|
||||
self.get_meta().db_table, self.where, using=using
|
||||
)
|
||||
return num_deleted
|
||||
|
||||
|
||||
class UpdateQuery(Query):
|
||||
"""An UPDATE SQL query."""
|
||||
|
||||
compiler = "SQLUpdateCompiler"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._setup_query()
|
||||
|
||||
def _setup_query(self):
|
||||
"""
|
||||
Run on initialization and at the end of chaining. Any attributes that
|
||||
would normally be set in __init__() should go here instead.
|
||||
"""
|
||||
self.values = []
|
||||
self.related_ids = None
|
||||
self.related_updates = {}
|
||||
|
||||
def clone(self):
|
||||
obj = super().clone()
|
||||
obj.related_updates = self.related_updates.copy()
|
||||
return obj
|
||||
|
||||
def update_batch(self, pk_list, values, using):
|
||||
self.add_update_values(values)
|
||||
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
|
||||
self.clear_where()
|
||||
self.add_filter(
|
||||
"pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
|
||||
)
|
||||
self.get_compiler(using).execute_sql(NO_RESULTS)
|
||||
|
||||
def add_update_values(self, values):
|
||||
"""
|
||||
Convert a dictionary of field name to value mappings into an update
|
||||
query. This is the entry point for the public update() method on
|
||||
querysets.
|
||||
"""
|
||||
values_seq = []
|
||||
for name, val in values.items():
|
||||
field = self.get_meta().get_field(name)
|
||||
direct = (
|
||||
not (field.auto_created and not field.concrete) or not field.concrete
|
||||
)
|
||||
model = field.model._meta.concrete_model
|
||||
if not direct or (field.is_relation and field.many_to_many):
|
||||
raise FieldError(
|
||||
"Cannot update model field %r (only non-relations and "
|
||||
"foreign keys permitted)." % field
|
||||
)
|
||||
if model is not self.get_meta().concrete_model:
|
||||
self.add_related_update(model, field, val)
|
||||
continue
|
||||
values_seq.append((field, model, val))
|
||||
return self.add_update_fields(values_seq)
|
||||
|
||||
def add_update_fields(self, values_seq):
|
||||
"""
|
||||
Append a sequence of (field, model, value) triples to the internal list
|
||||
that will be used to generate the UPDATE query. Might be more usefully
|
||||
called add_update_targets() to hint at the extra information here.
|
||||
"""
|
||||
for field, model, val in values_seq:
|
||||
if hasattr(val, "resolve_expression"):
|
||||
# Resolve expressions here so that annotations are no longer needed
|
||||
val = val.resolve_expression(self, allow_joins=False, for_save=True)
|
||||
self.values.append((field, model, val))
|
||||
|
||||
def add_related_update(self, model, field, value):
|
||||
"""
|
||||
Add (name, value) to an update query for an ancestor model.
|
||||
|
||||
Update are coalesced so that only one update query per ancestor is run.
|
||||
"""
|
||||
self.related_updates.setdefault(model, []).append((field, None, value))
|
||||
|
||||
def get_related_updates(self):
|
||||
"""
|
||||
Return a list of query objects: one for each update required to an
|
||||
ancestor model. Each query will have the same filtering conditions as
|
||||
the current query but will only update a single table.
|
||||
"""
|
||||
if not self.related_updates:
|
||||
return []
|
||||
result = []
|
||||
for model, values in self.related_updates.items():
|
||||
query = UpdateQuery(model)
|
||||
query.values = values
|
||||
if self.related_ids is not None:
|
||||
query.add_filter("pk__in", self.related_ids[model])
|
||||
result.append(query)
|
||||
return result
|
||||
|
||||
|
||||
class InsertQuery(Query):
|
||||
compiler = "SQLInsertCompiler"
|
||||
|
||||
def __init__(
|
||||
self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields = []
|
||||
self.objs = []
|
||||
self.on_conflict = on_conflict
|
||||
self.update_fields = update_fields or []
|
||||
self.unique_fields = unique_fields or []
|
||||
|
||||
def insert_values(self, fields, objs, raw=False):
|
||||
self.fields = fields
|
||||
self.objs = objs
|
||||
self.raw = raw
|
||||
|
||||
|
||||
class AggregateQuery(Query):
|
||||
"""
|
||||
Take another query as a parameter to the FROM clause and only select the
|
||||
elements in the provided list.
|
||||
"""
|
||||
|
||||
compiler = "SQLAggregateCompiler"
|
||||
|
||||
def __init__(self, model, inner_query):
|
||||
self.inner_query = inner_query
|
||||
super().__init__(model)
|
@ -0,0 +1,355 @@
|
||||
"""
|
||||
Code to manage the creation and SQL rendering of 'where' constraints.
|
||||
"""
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
from django.core.exceptions import EmptyResultSet, FullResultSet
|
||||
from django.db.models.expressions import Case, When
|
||||
from django.db.models.lookups import Exact
|
||||
from django.utils import tree
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
# Connection types
|
||||
AND = "AND"
|
||||
OR = "OR"
|
||||
XOR = "XOR"
|
||||
|
||||
|
||||
class WhereNode(tree.Node):
|
||||
"""
|
||||
An SQL WHERE clause.
|
||||
|
||||
The class is tied to the Query class that created it (in order to create
|
||||
the correct SQL).
|
||||
|
||||
A child is usually an expression producing boolean values. Most likely the
|
||||
expression is a Lookup instance.
|
||||
|
||||
However, a child could also be any class with as_sql() and either
|
||||
relabeled_clone() method or relabel_aliases() and clone() methods and
|
||||
contains_aggregate attribute.
|
||||
"""
|
||||
|
||||
default = AND
|
||||
resolved = False
|
||||
conditional = True
|
||||
|
||||
def split_having_qualify(self, negated=False, must_group_by=False):
|
||||
"""
|
||||
Return three possibly None nodes: one for those parts of self that
|
||||
should be included in the WHERE clause, one for those parts of self
|
||||
that must be included in the HAVING clause, and one for those parts
|
||||
that refer to window functions.
|
||||
"""
|
||||
if not self.contains_aggregate and not self.contains_over_clause:
|
||||
return self, None, None
|
||||
in_negated = negated ^ self.negated
|
||||
# Whether or not children must be connected in the same filtering
|
||||
# clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
|
||||
must_remain_connected = (
|
||||
(in_negated and self.connector == AND)
|
||||
or (not in_negated and self.connector == OR)
|
||||
or self.connector == XOR
|
||||
)
|
||||
if (
|
||||
must_remain_connected
|
||||
and self.contains_aggregate
|
||||
and not self.contains_over_clause
|
||||
):
|
||||
# It's must cheaper to short-circuit and stash everything in the
|
||||
# HAVING clause than split children if possible.
|
||||
return None, self, None
|
||||
where_parts = []
|
||||
having_parts = []
|
||||
qualify_parts = []
|
||||
for c in self.children:
|
||||
if hasattr(c, "split_having_qualify"):
|
||||
where_part, having_part, qualify_part = c.split_having_qualify(
|
||||
in_negated, must_group_by
|
||||
)
|
||||
if where_part is not None:
|
||||
where_parts.append(where_part)
|
||||
if having_part is not None:
|
||||
having_parts.append(having_part)
|
||||
if qualify_part is not None:
|
||||
qualify_parts.append(qualify_part)
|
||||
elif c.contains_over_clause:
|
||||
qualify_parts.append(c)
|
||||
elif c.contains_aggregate:
|
||||
having_parts.append(c)
|
||||
else:
|
||||
where_parts.append(c)
|
||||
if must_remain_connected and qualify_parts:
|
||||
# Disjunctive heterogeneous predicates can be pushed down to
|
||||
# qualify as long as no conditional aggregation is involved.
|
||||
if not where_parts or (where_parts and not must_group_by):
|
||||
return None, None, self
|
||||
elif where_parts:
|
||||
# In theory this should only be enforced when dealing with
|
||||
# where_parts containing predicates against multi-valued
|
||||
# relationships that could affect aggregation results but this
|
||||
# is complex to infer properly.
|
||||
raise NotImplementedError(
|
||||
"Heterogeneous disjunctive predicates against window functions are "
|
||||
"not implemented when performing conditional aggregation."
|
||||
)
|
||||
where_node = (
|
||||
self.create(where_parts, self.connector, self.negated)
|
||||
if where_parts
|
||||
else None
|
||||
)
|
||||
having_node = (
|
||||
self.create(having_parts, self.connector, self.negated)
|
||||
if having_parts
|
||||
else None
|
||||
)
|
||||
qualify_node = (
|
||||
self.create(qualify_parts, self.connector, self.negated)
|
||||
if qualify_parts
|
||||
else None
|
||||
)
|
||||
return where_node, having_node, qualify_node
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
"""
|
||||
Return the SQL version of the where clause and the value to be
|
||||
substituted in. Return '', [] if this node matches everything,
|
||||
None, [] if this node is empty, and raise EmptyResultSet if this
|
||||
node can't match anything.
|
||||
"""
|
||||
result = []
|
||||
result_params = []
|
||||
if self.connector == AND:
|
||||
full_needed, empty_needed = len(self.children), 1
|
||||
else:
|
||||
full_needed, empty_needed = 1, len(self.children)
|
||||
|
||||
if self.connector == XOR and not connection.features.supports_logical_xor:
|
||||
# Convert if the database doesn't support XOR:
|
||||
# a XOR b XOR c XOR ...
|
||||
# to:
|
||||
# (a OR b OR c OR ...) AND (a + b + c + ...) == 1
|
||||
lhs = self.__class__(self.children, OR)
|
||||
rhs_sum = reduce(
|
||||
operator.add,
|
||||
(Case(When(c, then=1), default=0) for c in self.children),
|
||||
)
|
||||
rhs = Exact(1, rhs_sum)
|
||||
return self.__class__([lhs, rhs], AND, self.negated).as_sql(
|
||||
compiler, connection
|
||||
)
|
||||
|
||||
for child in self.children:
|
||||
try:
|
||||
sql, params = compiler.compile(child)
|
||||
except EmptyResultSet:
|
||||
empty_needed -= 1
|
||||
except FullResultSet:
|
||||
full_needed -= 1
|
||||
else:
|
||||
if sql:
|
||||
result.append(sql)
|
||||
result_params.extend(params)
|
||||
else:
|
||||
full_needed -= 1
|
||||
# Check if this node matches nothing or everything.
|
||||
# First check the amount of full nodes and empty nodes
|
||||
# to make this node empty/full.
|
||||
# Now, check if this node is full/empty using the
|
||||
# counts.
|
||||
if empty_needed == 0:
|
||||
if self.negated:
|
||||
raise FullResultSet
|
||||
else:
|
||||
raise EmptyResultSet
|
||||
if full_needed == 0:
|
||||
if self.negated:
|
||||
raise EmptyResultSet
|
||||
else:
|
||||
raise FullResultSet
|
||||
conn = " %s " % self.connector
|
||||
sql_string = conn.join(result)
|
||||
if not sql_string:
|
||||
raise FullResultSet
|
||||
if self.negated:
|
||||
# Some backends (Oracle at least) need parentheses around the inner
|
||||
# SQL in the negated case, even if the inner SQL contains just a
|
||||
# single expression.
|
||||
sql_string = "NOT (%s)" % sql_string
|
||||
elif len(result) > 1 or self.resolved:
|
||||
sql_string = "(%s)" % sql_string
|
||||
return sql_string, result_params
|
||||
|
||||
def get_group_by_cols(self):
|
||||
cols = []
|
||||
for child in self.children:
|
||||
cols.extend(child.get_group_by_cols())
|
||||
return cols
|
||||
|
||||
def get_source_expressions(self):
|
||||
return self.children[:]
|
||||
|
||||
def set_source_expressions(self, children):
|
||||
assert len(children) == len(self.children)
|
||||
self.children = children
|
||||
|
||||
def relabel_aliases(self, change_map):
|
||||
"""
|
||||
Relabel the alias values of any children. 'change_map' is a dictionary
|
||||
mapping old (current) alias values to the new values.
|
||||
"""
|
||||
for pos, child in enumerate(self.children):
|
||||
if hasattr(child, "relabel_aliases"):
|
||||
# For example another WhereNode
|
||||
child.relabel_aliases(change_map)
|
||||
elif hasattr(child, "relabeled_clone"):
|
||||
self.children[pos] = child.relabeled_clone(change_map)
|
||||
|
||||
def clone(self):
|
||||
clone = self.create(connector=self.connector, negated=self.negated)
|
||||
for child in self.children:
|
||||
if hasattr(child, "clone"):
|
||||
child = child.clone()
|
||||
clone.children.append(child)
|
||||
return clone
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
clone = self.clone()
|
||||
clone.relabel_aliases(change_map)
|
||||
return clone
|
||||
|
||||
def replace_expressions(self, replacements):
|
||||
if replacement := replacements.get(self):
|
||||
return replacement
|
||||
clone = self.create(connector=self.connector, negated=self.negated)
|
||||
for child in self.children:
|
||||
clone.children.append(child.replace_expressions(replacements))
|
||||
return clone
|
||||
|
||||
def get_refs(self):
|
||||
refs = set()
|
||||
for child in self.children:
|
||||
refs |= child.get_refs()
|
||||
return refs
|
||||
|
||||
@classmethod
|
||||
def _contains_aggregate(cls, obj):
|
||||
if isinstance(obj, tree.Node):
|
||||
return any(cls._contains_aggregate(c) for c in obj.children)
|
||||
return obj.contains_aggregate
|
||||
|
||||
@cached_property
|
||||
def contains_aggregate(self):
|
||||
return self._contains_aggregate(self)
|
||||
|
||||
@classmethod
|
||||
def _contains_over_clause(cls, obj):
|
||||
if isinstance(obj, tree.Node):
|
||||
return any(cls._contains_over_clause(c) for c in obj.children)
|
||||
return obj.contains_over_clause
|
||||
|
||||
@cached_property
|
||||
def contains_over_clause(self):
|
||||
return self._contains_over_clause(self)
|
||||
|
||||
@property
|
||||
def is_summary(self):
|
||||
return any(child.is_summary for child in self.children)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_leaf(expr, query, *args, **kwargs):
|
||||
if hasattr(expr, "resolve_expression"):
|
||||
expr = expr.resolve_expression(query, *args, **kwargs)
|
||||
return expr
|
||||
|
||||
@classmethod
|
||||
def _resolve_node(cls, node, query, *args, **kwargs):
|
||||
if hasattr(node, "children"):
|
||||
for child in node.children:
|
||||
cls._resolve_node(child, query, *args, **kwargs)
|
||||
if hasattr(node, "lhs"):
|
||||
node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
|
||||
if hasattr(node, "rhs"):
|
||||
node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
clone = self.clone()
|
||||
clone._resolve_node(clone, *args, **kwargs)
|
||||
clone.resolved = True
|
||||
return clone
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
from django.db.models import BooleanField
|
||||
|
||||
return BooleanField()
|
||||
|
||||
@property
|
||||
def _output_field_or_none(self):
|
||||
return self.output_field
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
# Wrap filters with a CASE WHEN expression if a database backend
|
||||
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
|
||||
# BY list.
|
||||
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
|
||||
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
|
||||
return sql, params
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return self.output_field.get_db_converters(connection)
|
||||
|
||||
def get_lookup(self, lookup):
|
||||
return self.output_field.get_lookup(lookup)
|
||||
|
||||
def leaves(self):
|
||||
for child in self.children:
|
||||
if isinstance(child, WhereNode):
|
||||
yield from child.leaves()
|
||||
else:
|
||||
yield child
|
||||
|
||||
|
||||
class NothingNode:
|
||||
"""A node that matches nothing."""
|
||||
|
||||
contains_aggregate = False
|
||||
contains_over_clause = False
|
||||
|
||||
def as_sql(self, compiler=None, connection=None):
|
||||
raise EmptyResultSet
|
||||
|
||||
|
||||
class ExtraWhere:
|
||||
# The contents are a black box - assume no aggregates or windows are used.
|
||||
contains_aggregate = False
|
||||
contains_over_clause = False
|
||||
|
||||
def __init__(self, sqls, params):
|
||||
self.sqls = sqls
|
||||
self.params = params
|
||||
|
||||
def as_sql(self, compiler=None, connection=None):
|
||||
sqls = ["(%s)" % sql for sql in self.sqls]
|
||||
return " AND ".join(sqls), list(self.params or ())
|
||||
|
||||
|
||||
class SubqueryConstraint:
|
||||
# Even if aggregates or windows would be used in a subquery,
|
||||
# the outer query isn't interested about those.
|
||||
contains_aggregate = False
|
||||
contains_over_clause = False
|
||||
|
||||
def __init__(self, alias, columns, targets, query_object):
|
||||
self.alias = alias
|
||||
self.columns = columns
|
||||
self.targets = targets
|
||||
query_object.clear_ordering(clear_default=True)
|
||||
self.query_object = query_object
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
query = self.query_object
|
||||
query.set_values(self.targets)
|
||||
query_compiler = query.get_compiler(connection=connection)
|
||||
return query_compiler.as_subquery_condition(self.alias, self.columns, compiler)
|
@ -0,0 +1,69 @@
|
||||
import functools
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
def make_model_tuple(model):
|
||||
"""
|
||||
Take a model or a string of the form "app_label.ModelName" and return a
|
||||
corresponding ("app_label", "modelname") tuple. If a tuple is passed in,
|
||||
assume it's a valid model tuple already and return it unchanged.
|
||||
"""
|
||||
try:
|
||||
if isinstance(model, tuple):
|
||||
model_tuple = model
|
||||
elif isinstance(model, str):
|
||||
app_label, model_name = model.split(".")
|
||||
model_tuple = app_label, model_name.lower()
|
||||
else:
|
||||
model_tuple = model._meta.app_label, model._meta.model_name
|
||||
assert len(model_tuple) == 2
|
||||
return model_tuple
|
||||
except (ValueError, AssertionError):
|
||||
raise ValueError(
|
||||
"Invalid model reference '%s'. String model references "
|
||||
"must be of the form 'app_label.ModelName'." % model
|
||||
)
|
||||
|
||||
|
||||
def resolve_callables(mapping):
|
||||
"""
|
||||
Generate key/value pairs for the given mapping where the values are
|
||||
evaluated if they're callable.
|
||||
"""
|
||||
for k, v in mapping.items():
|
||||
yield k, v() if callable(v) else v
|
||||
|
||||
|
||||
def unpickle_named_row(names, values):
|
||||
return create_namedtuple_class(*names)(*values)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def create_namedtuple_class(*names):
|
||||
# Cache type() with @lru_cache since it's too slow to be called for every
|
||||
# QuerySet evaluation.
|
||||
def __reduce__(self):
|
||||
return unpickle_named_row, (names, tuple(self))
|
||||
|
||||
return type(
|
||||
"Row",
|
||||
(namedtuple("Row", names),),
|
||||
{"__reduce__": __reduce__, "__slots__": ()},
|
||||
)
|
||||
|
||||
|
||||
class AltersData:
|
||||
"""
|
||||
Make subclasses preserve the alters_data attribute on overridden methods.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
for fn_name, fn in vars(cls).items():
|
||||
if callable(fn) and not hasattr(fn, "alters_data"):
|
||||
for base in cls.__bases__:
|
||||
if base_fn := getattr(base, fn_name, None):
|
||||
if hasattr(base_fn, "alters_data"):
|
||||
fn.alters_data = base_fn.alters_data
|
||||
break
|
||||
|
||||
super().__init_subclass__(**kwargs)
|
Reference in New Issue
Block a user