docker setup

This commit is contained in:
AdrienLSH
2023-11-23 16:43:30 +01:00
parent fd19180e1d
commit f29003c66a
5410 changed files with 869440 additions and 0 deletions

View File

@ -0,0 +1,115 @@
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import signals
from django.db.models.aggregates import * # NOQA
from django.db.models.aggregates import __all__ as aggregates_all
from django.db.models.constraints import * # NOQA
from django.db.models.constraints import __all__ as constraints_all
from django.db.models.deletion import (
CASCADE,
DO_NOTHING,
PROTECT,
RESTRICT,
SET,
SET_DEFAULT,
SET_NULL,
ProtectedError,
RestrictedError,
)
from django.db.models.enums import * # NOQA
from django.db.models.enums import __all__ as enums_all
from django.db.models.expressions import (
Case,
Exists,
Expression,
ExpressionList,
ExpressionWrapper,
F,
Func,
OrderBy,
OuterRef,
RowRange,
Subquery,
Value,
ValueRange,
When,
Window,
WindowFrame,
)
from django.db.models.fields import * # NOQA
from django.db.models.fields import __all__ as fields_all
from django.db.models.fields.files import FileField, ImageField
from django.db.models.fields.json import JSONField
from django.db.models.fields.proxy import OrderWrt
from django.db.models.indexes import * # NOQA
from django.db.models.indexes import __all__ as indexes_all
from django.db.models.lookups import Lookup, Transform
from django.db.models.manager import Manager
from django.db.models.query import Prefetch, QuerySet, prefetch_related_objects
from django.db.models.query_utils import FilteredRelation, Q
# Imports that would create circular imports if sorted
from django.db.models.base import DEFERRED, Model # isort:skip
from django.db.models.fields.related import ( # isort:skip
ForeignKey,
ForeignObject,
OneToOneField,
ManyToManyField,
ForeignObjectRel,
ManyToOneRel,
ManyToManyRel,
OneToOneRel,
)
__all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all
__all__ += [
"ObjectDoesNotExist",
"signals",
"CASCADE",
"DO_NOTHING",
"PROTECT",
"RESTRICT",
"SET",
"SET_DEFAULT",
"SET_NULL",
"ProtectedError",
"RestrictedError",
"Case",
"Exists",
"Expression",
"ExpressionList",
"ExpressionWrapper",
"F",
"Func",
"OrderBy",
"OuterRef",
"RowRange",
"Subquery",
"Value",
"ValueRange",
"When",
"Window",
"WindowFrame",
"FileField",
"ImageField",
"JSONField",
"OrderWrt",
"Lookup",
"Transform",
"Manager",
"Prefetch",
"Q",
"QuerySet",
"prefetch_related_objects",
"DEFERRED",
"Model",
"FilteredRelation",
"ForeignKey",
"ForeignObject",
"OneToOneField",
"ManyToManyField",
"ForeignObjectRel",
"ManyToOneRel",
"ManyToManyRel",
"OneToOneRel",
]

View File

@ -0,0 +1,210 @@
"""
Classes to represent the definitions of aggregate functions.
"""
from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Case, Func, Star, Value, When
from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.mixins import (
FixDurationInputMixin,
NumericOutputFieldMixin,
)
__all__ = [
"Aggregate",
"Avg",
"Count",
"Max",
"Min",
"StdDev",
"Sum",
"Variance",
]
class Aggregate(Func):
template = "%(function)s(%(distinct)s%(expressions)s)"
contains_aggregate = True
name = None
filter_template = "%s FILTER (WHERE %%(filter)s)"
window_compatible = True
allow_distinct = False
empty_result_set_value = None
def __init__(
self, *expressions, distinct=False, filter=None, default=None, **extra
):
if distinct and not self.allow_distinct:
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
if default is not None and self.empty_result_set_value is not None:
raise TypeError(f"{self.__class__.__name__} does not allow default.")
self.distinct = distinct
self.filter = filter
self.default = default
super().__init__(*expressions, **extra)
def get_source_fields(self):
# Don't return the filter expression since it's not a source field.
return [e._output_field_or_none for e in super().get_source_expressions()]
def get_source_expressions(self):
source_expressions = super().get_source_expressions()
if self.filter:
return source_expressions + [self.filter]
return source_expressions
def set_source_expressions(self, exprs):
self.filter = self.filter and exprs.pop()
return super().set_source_expressions(exprs)
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
# Aggregates are not allowed in UPDATE queries, so ignore for_save
c = super().resolve_expression(query, allow_joins, reuse, summarize)
c.filter = c.filter and c.filter.resolve_expression(
query, allow_joins, reuse, summarize
)
if summarize:
# Summarized aggregates cannot refer to summarized aggregates.
for ref in c.get_refs():
if query.annotations[ref].is_summary:
raise FieldError(
f"Cannot compute {c.name}('{ref}'): '{ref}' is an aggregate"
)
elif not self.is_summary:
# Call Aggregate.get_source_expressions() to avoid
# returning self.filter and including that in this loop.
expressions = super(Aggregate, c).get_source_expressions()
for index, expr in enumerate(expressions):
if expr.contains_aggregate:
before_resolved = self.get_source_expressions()[index]
name = (
before_resolved.name
if hasattr(before_resolved, "name")
else repr(before_resolved)
)
raise FieldError(
"Cannot compute %s('%s'): '%s' is an aggregate"
% (c.name, name, name)
)
if (default := c.default) is None:
return c
if hasattr(default, "resolve_expression"):
default = default.resolve_expression(query, allow_joins, reuse, summarize)
if default._output_field_or_none is None:
default.output_field = c._output_field_or_none
else:
default = Value(default, c._output_field_or_none)
c.default = None # Reset the default argument before wrapping.
coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
coalesce.is_summary = c.is_summary
return coalesce
@property
def default_alias(self):
expressions = self.get_source_expressions()
if len(expressions) == 1 and hasattr(expressions[0], "name"):
return "%s__%s" % (expressions[0].name, self.name.lower())
raise TypeError("Complex expressions require an alias")
def get_group_by_cols(self):
return []
def as_sql(self, compiler, connection, **extra_context):
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
if self.filter:
if connection.features.supports_aggregate_filter_clause:
try:
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
except FullResultSet:
pass
else:
template = self.filter_template % extra_context.get(
"template", self.template
)
sql, params = super().as_sql(
compiler,
connection,
template=template,
filter=filter_sql,
**extra_context,
)
return sql, (*params, *filter_params)
else:
copy = self.copy()
copy.filter = None
source_expressions = copy.get_source_expressions()
condition = When(self.filter, then=source_expressions[0])
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
return super(Aggregate, copy).as_sql(
compiler, connection, **extra_context
)
return super().as_sql(compiler, connection, **extra_context)
def _get_repr_options(self):
options = super()._get_repr_options()
if self.distinct:
options["distinct"] = self.distinct
if self.filter:
options["filter"] = self.filter
return options
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
function = "AVG"
name = "Avg"
allow_distinct = True
class Count(Aggregate):
function = "COUNT"
name = "Count"
output_field = IntegerField()
allow_distinct = True
empty_result_set_value = 0
def __init__(self, expression, filter=None, **extra):
if expression == "*":
expression = Star()
if isinstance(expression, Star) and filter is not None:
raise ValueError("Star cannot be used with filter. Please specify a field.")
super().__init__(expression, filter=filter, **extra)
class Max(Aggregate):
function = "MAX"
name = "Max"
class Min(Aggregate):
function = "MIN"
name = "Min"
class StdDev(NumericOutputFieldMixin, Aggregate):
name = "StdDev"
def __init__(self, expression, sample=False, **extra):
self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
super().__init__(expression, **extra)
def _get_repr_options(self):
return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
class Sum(FixDurationInputMixin, Aggregate):
function = "SUM"
name = "Sum"
allow_distinct = True
class Variance(NumericOutputFieldMixin, Aggregate):
name = "Variance"
def __init__(self, expression, sample=False, **extra):
self.function = "VAR_SAMP" if sample else "VAR_POP"
super().__init__(expression, **extra)
def _get_repr_options(self):
return {**super()._get_repr_options(), "sample": self.function == "VAR_SAMP"}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,12 @@
"""
Constants used across the ORM in general.
"""
from enum import Enum
# Separator used to split filter strings apart.
LOOKUP_SEP = "__"
class OnConflict(Enum):
IGNORE = "ignore"
UPDATE = "update"

View File

@ -0,0 +1,371 @@
from enum import Enum
from django.core.exceptions import FieldError, ValidationError
from django.db import connections
from django.db.models.expressions import Exists, ExpressionList, F, OrderBy
from django.db.models.indexes import IndexExpression
from django.db.models.lookups import Exact
from django.db.models.query_utils import Q
from django.db.models.sql.query import Query
from django.db.utils import DEFAULT_DB_ALIAS
from django.utils.translation import gettext_lazy as _
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
class BaseConstraint:
default_violation_error_message = _("Constraint “%(name)s” is violated.")
violation_error_message = None
def __init__(self, name, violation_error_message=None):
self.name = name
if violation_error_message is not None:
self.violation_error_message = violation_error_message
else:
self.violation_error_message = self.default_violation_error_message
@property
def contains_expressions(self):
return False
def constraint_sql(self, model, schema_editor):
raise NotImplementedError("This method must be implemented by a subclass.")
def create_sql(self, model, schema_editor):
raise NotImplementedError("This method must be implemented by a subclass.")
def remove_sql(self, model, schema_editor):
raise NotImplementedError("This method must be implemented by a subclass.")
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
raise NotImplementedError("This method must be implemented by a subclass.")
def get_violation_error_message(self):
return self.violation_error_message % {"name": self.name}
def deconstruct(self):
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
path = path.replace("django.db.models.constraints", "django.db.models")
kwargs = {"name": self.name}
if (
self.violation_error_message is not None
and self.violation_error_message != self.default_violation_error_message
):
kwargs["violation_error_message"] = self.violation_error_message
return (path, (), kwargs)
def clone(self):
_, args, kwargs = self.deconstruct()
return self.__class__(*args, **kwargs)
class CheckConstraint(BaseConstraint):
def __init__(self, *, check, name, violation_error_message=None):
self.check = check
if not getattr(check, "conditional", False):
raise TypeError(
"CheckConstraint.check must be a Q instance or boolean expression."
)
super().__init__(name, violation_error_message=violation_error_message)
def _get_check_sql(self, model, schema_editor):
query = Query(model=model, alias_cols=False)
where = query.build_where(self.check)
compiler = query.get_compiler(connection=schema_editor.connection)
sql, params = where.as_sql(compiler, schema_editor.connection)
return sql % tuple(schema_editor.quote_value(p) for p in params)
def constraint_sql(self, model, schema_editor):
check = self._get_check_sql(model, schema_editor)
return schema_editor._check_sql(self.name, check)
def create_sql(self, model, schema_editor):
check = self._get_check_sql(model, schema_editor)
return schema_editor._create_check_sql(model, self.name, check)
def remove_sql(self, model, schema_editor):
return schema_editor._delete_check_sql(model, self.name)
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
try:
if not Q(self.check).check(against, using=using):
raise ValidationError(self.get_violation_error_message())
except FieldError:
pass
def __repr__(self):
return "<%s: check=%s name=%s>" % (
self.__class__.__qualname__,
self.check,
repr(self.name),
)
def __eq__(self, other):
if isinstance(other, CheckConstraint):
return (
self.name == other.name
and self.check == other.check
and self.violation_error_message == other.violation_error_message
)
return super().__eq__(other)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
kwargs["check"] = self.check
return path, args, kwargs
class Deferrable(Enum):
DEFERRED = "deferred"
IMMEDIATE = "immediate"
# A similar format was proposed for Python 3.10.
def __repr__(self):
return f"{self.__class__.__qualname__}.{self._name_}"
class UniqueConstraint(BaseConstraint):
def __init__(
self,
*expressions,
fields=(),
name=None,
condition=None,
deferrable=None,
include=None,
opclasses=(),
violation_error_message=None,
):
if not name:
raise ValueError("A unique constraint must be named.")
if not expressions and not fields:
raise ValueError(
"At least one field or expression is required to define a "
"unique constraint."
)
if expressions and fields:
raise ValueError(
"UniqueConstraint.fields and expressions are mutually exclusive."
)
if not isinstance(condition, (type(None), Q)):
raise ValueError("UniqueConstraint.condition must be a Q instance.")
if condition and deferrable:
raise ValueError("UniqueConstraint with conditions cannot be deferred.")
if include and deferrable:
raise ValueError("UniqueConstraint with include fields cannot be deferred.")
if opclasses and deferrable:
raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
if expressions and deferrable:
raise ValueError("UniqueConstraint with expressions cannot be deferred.")
if expressions and opclasses:
raise ValueError(
"UniqueConstraint.opclasses cannot be used with expressions. "
"Use django.contrib.postgres.indexes.OpClass() instead."
)
if not isinstance(deferrable, (type(None), Deferrable)):
raise ValueError(
"UniqueConstraint.deferrable must be a Deferrable instance."
)
if not isinstance(include, (type(None), list, tuple)):
raise ValueError("UniqueConstraint.include must be a list or tuple.")
if not isinstance(opclasses, (list, tuple)):
raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
if opclasses and len(fields) != len(opclasses):
raise ValueError(
"UniqueConstraint.fields and UniqueConstraint.opclasses must "
"have the same number of elements."
)
self.fields = tuple(fields)
self.condition = condition
self.deferrable = deferrable
self.include = tuple(include) if include else ()
self.opclasses = opclasses
self.expressions = tuple(
F(expression) if isinstance(expression, str) else expression
for expression in expressions
)
super().__init__(name, violation_error_message=violation_error_message)
@property
def contains_expressions(self):
return bool(self.expressions)
def _get_condition_sql(self, model, schema_editor):
if self.condition is None:
return None
query = Query(model=model, alias_cols=False)
where = query.build_where(self.condition)
compiler = query.get_compiler(connection=schema_editor.connection)
sql, params = where.as_sql(compiler, schema_editor.connection)
return sql % tuple(schema_editor.quote_value(p) for p in params)
def _get_index_expressions(self, model, schema_editor):
if not self.expressions:
return None
index_expressions = []
for expression in self.expressions:
index_expression = IndexExpression(expression)
index_expression.set_wrapper_classes(schema_editor.connection)
index_expressions.append(index_expression)
return ExpressionList(*index_expressions).resolve_expression(
Query(model, alias_cols=False),
)
def constraint_sql(self, model, schema_editor):
fields = [model._meta.get_field(field_name) for field_name in self.fields]
include = [
model._meta.get_field(field_name).column for field_name in self.include
]
condition = self._get_condition_sql(model, schema_editor)
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._unique_sql(
model,
fields,
self.name,
condition=condition,
deferrable=self.deferrable,
include=include,
opclasses=self.opclasses,
expressions=expressions,
)
def create_sql(self, model, schema_editor):
fields = [model._meta.get_field(field_name) for field_name in self.fields]
include = [
model._meta.get_field(field_name).column for field_name in self.include
]
condition = self._get_condition_sql(model, schema_editor)
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._create_unique_sql(
model,
fields,
self.name,
condition=condition,
deferrable=self.deferrable,
include=include,
opclasses=self.opclasses,
expressions=expressions,
)
def remove_sql(self, model, schema_editor):
condition = self._get_condition_sql(model, schema_editor)
include = [
model._meta.get_field(field_name).column for field_name in self.include
]
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._delete_unique_sql(
model,
self.name,
condition=condition,
deferrable=self.deferrable,
include=include,
opclasses=self.opclasses,
expressions=expressions,
)
def __repr__(self):
return "<%s:%s%s%s%s%s%s%s>" % (
self.__class__.__qualname__,
"" if not self.fields else " fields=%s" % repr(self.fields),
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
" name=%s" % repr(self.name),
"" if self.condition is None else " condition=%s" % self.condition,
"" if self.deferrable is None else " deferrable=%r" % self.deferrable,
"" if not self.include else " include=%s" % repr(self.include),
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
)
def __eq__(self, other):
if isinstance(other, UniqueConstraint):
return (
self.name == other.name
and self.fields == other.fields
and self.condition == other.condition
and self.deferrable == other.deferrable
and self.include == other.include
and self.opclasses == other.opclasses
and self.expressions == other.expressions
and self.violation_error_message == other.violation_error_message
)
return super().__eq__(other)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
if self.fields:
kwargs["fields"] = self.fields
if self.condition:
kwargs["condition"] = self.condition
if self.deferrable:
kwargs["deferrable"] = self.deferrable
if self.include:
kwargs["include"] = self.include
if self.opclasses:
kwargs["opclasses"] = self.opclasses
return path, self.expressions, kwargs
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
queryset = model._default_manager.using(using)
if self.fields:
lookup_kwargs = {}
for field_name in self.fields:
if exclude and field_name in exclude:
return
field = model._meta.get_field(field_name)
lookup_value = getattr(instance, field.attname)
if lookup_value is None or (
lookup_value == ""
and connections[using].features.interprets_empty_strings_as_nulls
):
# A composite constraint containing NULL value cannot cause
# a violation since NULL != NULL in SQL.
return
lookup_kwargs[field.name] = lookup_value
queryset = queryset.filter(**lookup_kwargs)
else:
# Ignore constraints with excluded fields.
if exclude:
for expression in self.expressions:
if hasattr(expression, "flatten"):
for expr in expression.flatten():
if isinstance(expr, F) and expr.name in exclude:
return
elif isinstance(expression, F) and expression.name in exclude:
return
replacements = {
F(field): value
for field, value in instance._get_field_value_map(
meta=model._meta, exclude=exclude
).items()
}
expressions = []
for expr in self.expressions:
# Ignore ordering.
if isinstance(expr, OrderBy):
expr = expr.expression
expressions.append(Exact(expr, expr.replace_expressions(replacements)))
queryset = queryset.filter(*expressions)
model_class_pk = instance._get_pk_val(model._meta)
if not instance._state.adding and model_class_pk is not None:
queryset = queryset.exclude(pk=model_class_pk)
if not self.condition:
if queryset.exists():
if self.expressions:
raise ValidationError(self.get_violation_error_message())
# When fields are defined, use the unique_error_message() for
# backward compatibility.
for model, constraints in instance.get_constraints():
for constraint in constraints:
if constraint is self:
raise ValidationError(
instance.unique_error_message(model, self.fields)
)
else:
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
try:
if (self.condition & Exists(queryset.filter(self.condition))).check(
against, using=using
):
raise ValidationError(self.get_violation_error_message())
except FieldError:
pass

View File

@ -0,0 +1,522 @@
from collections import Counter, defaultdict
from functools import partial, reduce
from itertools import chain
from operator import attrgetter, or_
from django.db import IntegrityError, connections, models, transaction
from django.db.models import query_utils, signals, sql
class ProtectedError(IntegrityError):
def __init__(self, msg, protected_objects):
self.protected_objects = protected_objects
super().__init__(msg, protected_objects)
class RestrictedError(IntegrityError):
def __init__(self, msg, restricted_objects):
self.restricted_objects = restricted_objects
super().__init__(msg, restricted_objects)
def CASCADE(collector, field, sub_objs, using):
collector.collect(
sub_objs,
source=field.remote_field.model,
source_attr=field.name,
nullable=field.null,
fail_on_restricted=False,
)
if field.null and not connections[using].features.can_defer_constraint_checks:
collector.add_field_update(field, None, sub_objs)
def PROTECT(collector, field, sub_objs, using):
raise ProtectedError(
"Cannot delete some instances of model '%s' because they are "
"referenced through a protected foreign key: '%s.%s'"
% (
field.remote_field.model.__name__,
sub_objs[0].__class__.__name__,
field.name,
),
sub_objs,
)
def RESTRICT(collector, field, sub_objs, using):
collector.add_restricted_objects(field, sub_objs)
collector.add_dependency(field.remote_field.model, field.model)
def SET(value):
if callable(value):
def set_on_delete(collector, field, sub_objs, using):
collector.add_field_update(field, value(), sub_objs)
else:
def set_on_delete(collector, field, sub_objs, using):
collector.add_field_update(field, value, sub_objs)
set_on_delete.deconstruct = lambda: ("django.db.models.SET", (value,), {})
set_on_delete.lazy_sub_objs = True
return set_on_delete
def SET_NULL(collector, field, sub_objs, using):
collector.add_field_update(field, None, sub_objs)
SET_NULL.lazy_sub_objs = True
def SET_DEFAULT(collector, field, sub_objs, using):
collector.add_field_update(field, field.get_default(), sub_objs)
SET_DEFAULT.lazy_sub_objs = True
def DO_NOTHING(collector, field, sub_objs, using):
pass
def get_candidate_relations_to_delete(opts):
# The candidate relations are the ones that come from N-1 and 1-1 relations.
# N-N (i.e., many-to-many) relations aren't candidates for deletion.
return (
f
for f in opts.get_fields(include_hidden=True)
if f.auto_created and not f.concrete and (f.one_to_one or f.one_to_many)
)
class Collector:
def __init__(self, using, origin=None):
self.using = using
# A Model or QuerySet object.
self.origin = origin
# Initially, {model: {instances}}, later values become lists.
self.data = defaultdict(set)
# {(field, value): [instances, …]}
self.field_updates = defaultdict(list)
# {model: {field: {instances}}}
self.restricted_objects = defaultdict(partial(defaultdict, set))
# fast_deletes is a list of queryset-likes that can be deleted without
# fetching the objects into memory.
self.fast_deletes = []
# Tracks deletion-order dependency for databases without transactions
# or ability to defer constraint checks. Only concrete model classes
# should be included, as the dependencies exist only between actual
# database tables; proxy models are represented here by their concrete
# parent.
self.dependencies = defaultdict(set) # {model: {models}}
def add(self, objs, source=None, nullable=False, reverse_dependency=False):
"""
Add 'objs' to the collection of objects to be deleted. If the call is
the result of a cascade, 'source' should be the model that caused it,
and 'nullable' should be set to True if the relation can be null.
Return a list of all objects that were not already collected.
"""
if not objs:
return []
new_objs = []
model = objs[0].__class__
instances = self.data[model]
for obj in objs:
if obj not in instances:
new_objs.append(obj)
instances.update(new_objs)
# Nullable relationships can be ignored -- they are nulled out before
# deleting, and therefore do not affect the order in which objects have
# to be deleted.
if source is not None and not nullable:
self.add_dependency(source, model, reverse_dependency=reverse_dependency)
return new_objs
def add_dependency(self, model, dependency, reverse_dependency=False):
if reverse_dependency:
model, dependency = dependency, model
self.dependencies[model._meta.concrete_model].add(
dependency._meta.concrete_model
)
self.data.setdefault(dependency, self.data.default_factory())
def add_field_update(self, field, value, objs):
"""
Schedule a field update. 'objs' must be a homogeneous iterable
collection of model instances (e.g. a QuerySet).
"""
self.field_updates[field, value].append(objs)
def add_restricted_objects(self, field, objs):
if objs:
model = objs[0].__class__
self.restricted_objects[model][field].update(objs)
def clear_restricted_objects_from_set(self, model, objs):
if model in self.restricted_objects:
self.restricted_objects[model] = {
field: items - objs
for field, items in self.restricted_objects[model].items()
}
def clear_restricted_objects_from_queryset(self, model, qs):
if model in self.restricted_objects:
objs = set(
qs.filter(
pk__in=[
obj.pk
for objs in self.restricted_objects[model].values()
for obj in objs
]
)
)
self.clear_restricted_objects_from_set(model, objs)
def _has_signal_listeners(self, model):
return signals.pre_delete.has_listeners(
model
) or signals.post_delete.has_listeners(model)
def can_fast_delete(self, objs, from_field=None):
"""
Determine if the objects in the given queryset-like or single object
can be fast-deleted. This can be done if there are no cascades, no
parents and no signal listeners for the object class.
The 'from_field' tells where we are coming from - we need this to
determine if the objects are in fact to be deleted. Allow also
skipping parent -> child -> parent chain preventing fast delete of
the child.
"""
if from_field and from_field.remote_field.on_delete is not CASCADE:
return False
if hasattr(objs, "_meta"):
model = objs._meta.model
elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
model = objs.model
else:
return False
if self._has_signal_listeners(model):
return False
# The use of from_field comes from the need to avoid cascade back to
# parent when parent delete is cascading to child.
opts = model._meta
return (
all(
link == from_field
for link in opts.concrete_model._meta.parents.values()
)
and
# Foreign keys pointing to this model.
all(
related.field.remote_field.on_delete is DO_NOTHING
for related in get_candidate_relations_to_delete(opts)
)
and (
# Something like generic foreign key.
not any(
hasattr(field, "bulk_related_objects")
for field in opts.private_fields
)
)
)
def get_del_batches(self, objs, fields):
"""
Return the objs in suitably sized batches for the used connection.
"""
field_names = [field.name for field in fields]
conn_batch_size = max(
connections[self.using].ops.bulk_batch_size(field_names, objs), 1
)
if len(objs) > conn_batch_size:
return [
objs[i : i + conn_batch_size]
for i in range(0, len(objs), conn_batch_size)
]
else:
return [objs]
def collect(
self,
objs,
source=None,
nullable=False,
collect_related=True,
source_attr=None,
reverse_dependency=False,
keep_parents=False,
fail_on_restricted=True,
):
"""
Add 'objs' to the collection of objects to be deleted as well as all
parent instances. 'objs' must be a homogeneous iterable collection of
model instances (e.g. a QuerySet). If 'collect_related' is True,
related objects will be handled by their respective on_delete handler.
If the call is the result of a cascade, 'source' should be the model
that caused it and 'nullable' should be set to True, if the relation
can be null.
If 'reverse_dependency' is True, 'source' will be deleted before the
current model, rather than after. (Needed for cascading to parent
models, the one case in which the cascade follows the forwards
direction of an FK rather than the reverse direction.)
If 'keep_parents' is True, data of parent model's will be not deleted.
If 'fail_on_restricted' is False, error won't be raised even if it's
prohibited to delete such objects due to RESTRICT, that defers
restricted object checking in recursive calls where the top-level call
may need to collect more objects to determine whether restricted ones
can be deleted.
"""
if self.can_fast_delete(objs):
self.fast_deletes.append(objs)
return
new_objs = self.add(
objs, source, nullable, reverse_dependency=reverse_dependency
)
if not new_objs:
return
model = new_objs[0].__class__
if not keep_parents:
# Recursively collect concrete model's parent models, but not their
# related objects. These will be found by meta.get_fields()
concrete_model = model._meta.concrete_model
for ptr in concrete_model._meta.parents.values():
if ptr:
parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
self.collect(
parent_objs,
source=model,
source_attr=ptr.remote_field.related_name,
collect_related=False,
reverse_dependency=True,
fail_on_restricted=False,
)
if not collect_related:
return
if keep_parents:
parents = set(model._meta.get_parent_list())
model_fast_deletes = defaultdict(list)
protected_objects = defaultdict(list)
for related in get_candidate_relations_to_delete(model._meta):
# Preserve parent reverse relationships if keep_parents=True.
if keep_parents and related.model in parents:
continue
field = related.field
on_delete = field.remote_field.on_delete
if on_delete == DO_NOTHING:
continue
related_model = related.related_model
if self.can_fast_delete(related_model, from_field=field):
model_fast_deletes[related_model].append(field)
continue
batches = self.get_del_batches(new_objs, [field])
for batch in batches:
sub_objs = self.related_objects(related_model, [field], batch)
# Non-referenced fields can be deferred if no signal receivers
# are connected for the related model as they'll never be
# exposed to the user. Skip field deferring when some
# relationships are select_related as interactions between both
# features are hard to get right. This should only happen in
# the rare cases where .related_objects is overridden anyway.
if not (
sub_objs.query.select_related
or self._has_signal_listeners(related_model)
):
referenced_fields = set(
chain.from_iterable(
(rf.attname for rf in rel.field.foreign_related_fields)
for rel in get_candidate_relations_to_delete(
related_model._meta
)
)
)
sub_objs = sub_objs.only(*tuple(referenced_fields))
if getattr(on_delete, "lazy_sub_objs", False) or sub_objs:
try:
on_delete(self, field, sub_objs, self.using)
except ProtectedError as error:
key = "'%s.%s'" % (field.model.__name__, field.name)
protected_objects[key] += error.protected_objects
if protected_objects:
raise ProtectedError(
"Cannot delete some instances of model %r because they are "
"referenced through protected foreign keys: %s."
% (
model.__name__,
", ".join(protected_objects),
),
set(chain.from_iterable(protected_objects.values())),
)
for related_model, related_fields in model_fast_deletes.items():
batches = self.get_del_batches(new_objs, related_fields)
for batch in batches:
sub_objs = self.related_objects(related_model, related_fields, batch)
self.fast_deletes.append(sub_objs)
for field in model._meta.private_fields:
if hasattr(field, "bulk_related_objects"):
# It's something like generic foreign key.
sub_objs = field.bulk_related_objects(new_objs, self.using)
self.collect(
sub_objs, source=model, nullable=True, fail_on_restricted=False
)
if fail_on_restricted:
# Raise an error if collected restricted objects (RESTRICT) aren't
# candidates for deletion also collected via CASCADE.
for related_model, instances in self.data.items():
self.clear_restricted_objects_from_set(related_model, instances)
for qs in self.fast_deletes:
self.clear_restricted_objects_from_queryset(qs.model, qs)
if self.restricted_objects.values():
restricted_objects = defaultdict(list)
for related_model, fields in self.restricted_objects.items():
for field, objs in fields.items():
if objs:
key = "'%s.%s'" % (related_model.__name__, field.name)
restricted_objects[key] += objs
if restricted_objects:
raise RestrictedError(
"Cannot delete some instances of model %r because "
"they are referenced through restricted foreign keys: "
"%s."
% (
model.__name__,
", ".join(restricted_objects),
),
set(chain.from_iterable(restricted_objects.values())),
)
def related_objects(self, related_model, related_fields, objs):
"""
Get a QuerySet of the related model to objs via related fields.
"""
predicate = query_utils.Q.create(
[(f"{related_field.name}__in", objs) for related_field in related_fields],
connector=query_utils.Q.OR,
)
return related_model._base_manager.using(self.using).filter(predicate)
def instances_with_model(self):
for model, instances in self.data.items():
for obj in instances:
yield model, obj
def sort(self):
sorted_models = []
concrete_models = set()
models = list(self.data)
while len(sorted_models) < len(models):
found = False
for model in models:
if model in sorted_models:
continue
dependencies = self.dependencies.get(model._meta.concrete_model)
if not (dependencies and dependencies.difference(concrete_models)):
sorted_models.append(model)
concrete_models.add(model._meta.concrete_model)
found = True
if not found:
return
self.data = {model: self.data[model] for model in sorted_models}
def delete(self):
# sort instance collections
for model, instances in self.data.items():
self.data[model] = sorted(instances, key=attrgetter("pk"))
# if possible, bring the models in an order suitable for databases that
# don't support transactions or cannot defer constraint checks until the
# end of a transaction.
self.sort()
# number of objects deleted for each model label
deleted_counter = Counter()
# Optimize for the case with a single obj and no dependencies
if len(self.data) == 1 and len(instances) == 1:
instance = list(instances)[0]
if self.can_fast_delete(instance):
with transaction.mark_for_rollback_on_error(self.using):
count = sql.DeleteQuery(model).delete_batch(
[instance.pk], self.using
)
setattr(instance, model._meta.pk.attname, None)
return count, {model._meta.label: count}
with transaction.atomic(using=self.using, savepoint=False):
# send pre_delete signals
for model, obj in self.instances_with_model():
if not model._meta.auto_created:
signals.pre_delete.send(
sender=model,
instance=obj,
using=self.using,
origin=self.origin,
)
# fast deletes
for qs in self.fast_deletes:
count = qs._raw_delete(using=self.using)
if count:
deleted_counter[qs.model._meta.label] += count
# update fields
for (field, value), instances_list in self.field_updates.items():
updates = []
objs = []
for instances in instances_list:
if (
isinstance(instances, models.QuerySet)
and instances._result_cache is None
):
updates.append(instances)
else:
objs.extend(instances)
if updates:
combined_updates = reduce(or_, updates)
combined_updates.update(**{field.name: value})
if objs:
model = objs[0].__class__
query = sql.UpdateQuery(model)
query.update_batch(
list({obj.pk for obj in objs}), {field.name: value}, self.using
)
# reverse instance collections
for instances in self.data.values():
instances.reverse()
# delete instances
for model, instances in self.data.items():
query = sql.DeleteQuery(model)
pk_list = [obj.pk for obj in instances]
count = query.delete_batch(pk_list, self.using)
if count:
deleted_counter[model._meta.label] += count
if not model._meta.auto_created:
for obj in instances:
signals.post_delete.send(
sender=model,
instance=obj,
using=self.using,
origin=self.origin,
)
for model, instances in self.data.items():
for instance in instances:
setattr(instance, model._meta.pk.attname, None)
return sum(deleted_counter.values()), dict(deleted_counter)

View File

@ -0,0 +1,92 @@
import enum
from types import DynamicClassAttribute
from django.utils.functional import Promise
__all__ = ["Choices", "IntegerChoices", "TextChoices"]
class ChoicesMeta(enum.EnumMeta):
"""A metaclass for creating a enum choices."""
def __new__(metacls, classname, bases, classdict, **kwds):
labels = []
for key in classdict._member_names:
value = classdict[key]
if (
isinstance(value, (list, tuple))
and len(value) > 1
and isinstance(value[-1], (Promise, str))
):
*value, label = value
value = tuple(value)
else:
label = key.replace("_", " ").title()
labels.append(label)
# Use dict.__setitem__() to suppress defenses against double
# assignment in enum's classdict.
dict.__setitem__(classdict, key, value)
cls = super().__new__(metacls, classname, bases, classdict, **kwds)
for member, label in zip(cls.__members__.values(), labels):
member._label_ = label
return enum.unique(cls)
def __contains__(cls, member):
if not isinstance(member, enum.Enum):
# Allow non-enums to match against member values.
return any(x.value == member for x in cls)
return super().__contains__(member)
@property
def names(cls):
empty = ["__empty__"] if hasattr(cls, "__empty__") else []
return empty + [member.name for member in cls]
@property
def choices(cls):
empty = [(None, cls.__empty__)] if hasattr(cls, "__empty__") else []
return empty + [(member.value, member.label) for member in cls]
@property
def labels(cls):
return [label for _, label in cls.choices]
@property
def values(cls):
return [value for value, _ in cls.choices]
class Choices(enum.Enum, metaclass=ChoicesMeta):
"""Class for creating enumerated choices."""
@DynamicClassAttribute
def label(self):
return self._label_
@property
def do_not_call_in_templates(self):
return True
def __str__(self):
"""
Use value when cast to str, so that Choices set as model instance
attributes are rendered as expected in templates and similar contexts.
"""
return str(self.value)
# A similar format was proposed for Python 3.10.
def __repr__(self):
return f"{self.__class__.__qualname__}.{self._name_}"
class IntegerChoices(int, Choices):
"""Class for creating enumerated integer choices."""
pass
class TextChoices(str, Choices):
"""Class for creating enumerated string choices."""
def _generate_next_value_(name, start, count, last_values):
return name

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,510 @@
import datetime
import posixpath
from django import forms
from django.core import checks
from django.core.files.base import File
from django.core.files.images import ImageFile
from django.core.files.storage import Storage, default_storage
from django.core.files.utils import validate_file_name
from django.db.models import signals
from django.db.models.fields import Field
from django.db.models.query_utils import DeferredAttribute
from django.db.models.utils import AltersData
from django.utils.translation import gettext_lazy as _
class FieldFile(File, AltersData):
def __init__(self, instance, field, name):
super().__init__(None, name)
self.instance = instance
self.field = field
self.storage = field.storage
self._committed = True
def __eq__(self, other):
# Older code may be expecting FileField values to be simple strings.
# By overriding the == operator, it can remain backwards compatibility.
if hasattr(other, "name"):
return self.name == other.name
return self.name == other
def __hash__(self):
return hash(self.name)
# The standard File contains most of the necessary properties, but
# FieldFiles can be instantiated without a name, so that needs to
# be checked for here.
def _require_file(self):
if not self:
raise ValueError(
"The '%s' attribute has no file associated with it." % self.field.name
)
def _get_file(self):
self._require_file()
if getattr(self, "_file", None) is None:
self._file = self.storage.open(self.name, "rb")
return self._file
def _set_file(self, file):
self._file = file
def _del_file(self):
del self._file
file = property(_get_file, _set_file, _del_file)
@property
def path(self):
self._require_file()
return self.storage.path(self.name)
@property
def url(self):
self._require_file()
return self.storage.url(self.name)
@property
def size(self):
self._require_file()
if not self._committed:
return self.file.size
return self.storage.size(self.name)
def open(self, mode="rb"):
self._require_file()
if getattr(self, "_file", None) is None:
self.file = self.storage.open(self.name, mode)
else:
self.file.open(mode)
return self
# open() doesn't alter the file's contents, but it does reset the pointer
open.alters_data = True
# In addition to the standard File API, FieldFiles have extra methods
# to further manipulate the underlying file, as well as update the
# associated model instance.
def save(self, name, content, save=True):
name = self.field.generate_filename(self.instance, name)
self.name = self.storage.save(name, content, max_length=self.field.max_length)
setattr(self.instance, self.field.attname, self.name)
self._committed = True
# Save the object because it has changed, unless save is False
if save:
self.instance.save()
save.alters_data = True
def delete(self, save=True):
if not self:
return
# Only close the file if it's already open, which we know by the
# presence of self._file
if hasattr(self, "_file"):
self.close()
del self.file
self.storage.delete(self.name)
self.name = None
setattr(self.instance, self.field.attname, self.name)
self._committed = False
if save:
self.instance.save()
delete.alters_data = True
@property
def closed(self):
file = getattr(self, "_file", None)
return file is None or file.closed
def close(self):
file = getattr(self, "_file", None)
if file is not None:
file.close()
def __getstate__(self):
# FieldFile needs access to its associated model field, an instance and
# the file's name. Everything else will be restored later, by
# FileDescriptor below.
return {
"name": self.name,
"closed": False,
"_committed": True,
"_file": None,
"instance": self.instance,
"field": self.field,
}
def __setstate__(self, state):
self.__dict__.update(state)
self.storage = self.field.storage
class FileDescriptor(DeferredAttribute):
"""
The descriptor for the file attribute on the model instance. Return a
FieldFile when accessed so you can write code like::
>>> from myapp.models import MyModel
>>> instance = MyModel.objects.get(pk=1)
>>> instance.file.size
Assign a file object on assignment so you can do::
>>> with open('/path/to/hello.world') as f:
... instance.file = File(f)
"""
def __get__(self, instance, cls=None):
if instance is None:
return self
# This is slightly complicated, so worth an explanation.
# instance.file needs to ultimately return some instance of `File`,
# probably a subclass. Additionally, this returned object needs to have
# the FieldFile API so that users can easily do things like
# instance.file.path and have that delegated to the file storage engine.
# Easy enough if we're strict about assignment in __set__, but if you
# peek below you can see that we're not. So depending on the current
# value of the field we have to dynamically construct some sort of
# "thing" to return.
# The instance dict contains whatever was originally assigned
# in __set__.
file = super().__get__(instance, cls)
# If this value is a string (instance.file = "path/to/file") or None
# then we simply wrap it with the appropriate attribute class according
# to the file field. [This is FieldFile for FileFields and
# ImageFieldFile for ImageFields; it's also conceivable that user
# subclasses might also want to subclass the attribute class]. This
# object understands how to convert a path to a file, and also how to
# handle None.
if isinstance(file, str) or file is None:
attr = self.field.attr_class(instance, self.field, file)
instance.__dict__[self.field.attname] = attr
# Other types of files may be assigned as well, but they need to have
# the FieldFile interface added to them. Thus, we wrap any other type of
# File inside a FieldFile (well, the field's attr_class, which is
# usually FieldFile).
elif isinstance(file, File) and not isinstance(file, FieldFile):
file_copy = self.field.attr_class(instance, self.field, file.name)
file_copy.file = file
file_copy._committed = False
instance.__dict__[self.field.attname] = file_copy
# Finally, because of the (some would say boneheaded) way pickle works,
# the underlying FieldFile might not actually itself have an associated
# file. So we need to reset the details of the FieldFile in those cases.
elif isinstance(file, FieldFile) and not hasattr(file, "field"):
file.instance = instance
file.field = self.field
file.storage = self.field.storage
# Make sure that the instance is correct.
elif isinstance(file, FieldFile) and instance is not file.instance:
file.instance = instance
# That was fun, wasn't it?
return instance.__dict__[self.field.attname]
def __set__(self, instance, value):
instance.__dict__[self.field.attname] = value
class FileField(Field):
# The class to wrap instance attributes in. Accessing the file object off
# the instance will always return an instance of attr_class.
attr_class = FieldFile
# The descriptor to use for accessing the attribute off of the class.
descriptor_class = FileDescriptor
description = _("File")
def __init__(
self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs
):
self._primary_key_set_explicitly = "primary_key" in kwargs
self.storage = storage or default_storage
if callable(self.storage):
# Hold a reference to the callable for deconstruct().
self._storage_callable = self.storage
self.storage = self.storage()
if not isinstance(self.storage, Storage):
raise TypeError(
"%s.storage must be a subclass/instance of %s.%s"
% (
self.__class__.__qualname__,
Storage.__module__,
Storage.__qualname__,
)
)
self.upload_to = upload_to
kwargs.setdefault("max_length", 100)
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
return [
*super().check(**kwargs),
*self._check_primary_key(),
*self._check_upload_to(),
]
def _check_primary_key(self):
if self._primary_key_set_explicitly:
return [
checks.Error(
"'primary_key' is not a valid argument for a %s."
% self.__class__.__name__,
obj=self,
id="fields.E201",
)
]
else:
return []
def _check_upload_to(self):
if isinstance(self.upload_to, str) and self.upload_to.startswith("/"):
return [
checks.Error(
"%s's 'upload_to' argument must be a relative path, not an "
"absolute path." % self.__class__.__name__,
obj=self,
id="fields.E202",
hint="Remove the leading slash.",
)
]
else:
return []
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if kwargs.get("max_length") == 100:
del kwargs["max_length"]
kwargs["upload_to"] = self.upload_to
storage = getattr(self, "_storage_callable", self.storage)
if storage is not default_storage:
kwargs["storage"] = storage
return name, path, args, kwargs
def get_internal_type(self):
return "FileField"
def get_prep_value(self, value):
value = super().get_prep_value(value)
# Need to convert File objects provided via a form to string for
# database insertion.
if value is None:
return None
return str(value)
def pre_save(self, model_instance, add):
file = super().pre_save(model_instance, add)
if file and not file._committed:
# Commit the file to storage prior to saving the model
file.save(file.name, file.file, save=False)
return file
def contribute_to_class(self, cls, name, **kwargs):
super().contribute_to_class(cls, name, **kwargs)
setattr(cls, self.attname, self.descriptor_class(self))
def generate_filename(self, instance, filename):
"""
Apply (if callable) or prepend (if a string) upload_to to the filename,
then delegate further processing of the name to the storage backend.
Until the storage layer, all file paths are expected to be Unix style
(with forward slashes).
"""
if callable(self.upload_to):
filename = self.upload_to(instance, filename)
else:
dirname = datetime.datetime.now().strftime(str(self.upload_to))
filename = posixpath.join(dirname, filename)
filename = validate_file_name(filename, allow_relative_path=True)
return self.storage.generate_filename(filename)
def save_form_data(self, instance, data):
# Important: None means "no change", other false value means "clear"
# This subtle distinction (rather than a more explicit marker) is
# needed because we need to consume values that are also sane for a
# regular (non Model-) Form to find in its cleaned_data dictionary.
if data is not None:
# This value will be converted to str and stored in the
# database, so leaving False as-is is not acceptable.
setattr(instance, self.name, data or "")
def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.FileField,
"max_length": self.max_length,
**kwargs,
}
)
class ImageFileDescriptor(FileDescriptor):
"""
Just like the FileDescriptor, but for ImageFields. The only difference is
assigning the width/height to the width_field/height_field, if appropriate.
"""
def __set__(self, instance, value):
previous_file = instance.__dict__.get(self.field.attname)
super().__set__(instance, value)
# To prevent recalculating image dimensions when we are instantiating
# an object from the database (bug #11084), only update dimensions if
# the field had a value before this assignment. Since the default
# value for FileField subclasses is an instance of field.attr_class,
# previous_file will only be None when we are called from
# Model.__init__(). The ImageField.update_dimension_fields method
# hooked up to the post_init signal handles the Model.__init__() cases.
# Assignment happening outside of Model.__init__() will trigger the
# update right here.
if previous_file is not None:
self.field.update_dimension_fields(instance, force=True)
class ImageFieldFile(ImageFile, FieldFile):
def delete(self, save=True):
# Clear the image dimensions cache
if hasattr(self, "_dimensions_cache"):
del self._dimensions_cache
super().delete(save)
class ImageField(FileField):
attr_class = ImageFieldFile
descriptor_class = ImageFileDescriptor
description = _("Image")
def __init__(
self,
verbose_name=None,
name=None,
width_field=None,
height_field=None,
**kwargs,
):
self.width_field, self.height_field = width_field, height_field
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
return [
*super().check(**kwargs),
*self._check_image_library_installed(),
]
def _check_image_library_installed(self):
try:
from PIL import Image # NOQA
except ImportError:
return [
checks.Error(
"Cannot use ImageField because Pillow is not installed.",
hint=(
"Get Pillow at https://pypi.org/project/Pillow/ "
'or run command "python -m pip install Pillow".'
),
obj=self,
id="fields.E210",
)
]
else:
return []
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.width_field:
kwargs["width_field"] = self.width_field
if self.height_field:
kwargs["height_field"] = self.height_field
return name, path, args, kwargs
def contribute_to_class(self, cls, name, **kwargs):
super().contribute_to_class(cls, name, **kwargs)
# Attach update_dimension_fields so that dimension fields declared
# after their corresponding image field don't stay cleared by
# Model.__init__, see bug #11196.
# Only run post-initialization dimension update on non-abstract models
if not cls._meta.abstract:
signals.post_init.connect(self.update_dimension_fields, sender=cls)
def update_dimension_fields(self, instance, force=False, *args, **kwargs):
"""
Update field's width and height fields, if defined.
This method is hooked up to model's post_init signal to update
dimensions after instantiating a model instance. However, dimensions
won't be updated if the dimensions fields are already populated. This
avoids unnecessary recalculation when loading an object from the
database.
Dimensions can be forced to update with force=True, which is how
ImageFileDescriptor.__set__ calls this method.
"""
# Nothing to update if the field doesn't have dimension fields or if
# the field is deferred.
has_dimension_fields = self.width_field or self.height_field
if not has_dimension_fields or self.attname not in instance.__dict__:
return
# getattr will call the ImageFileDescriptor's __get__ method, which
# coerces the assigned value into an instance of self.attr_class
# (ImageFieldFile in this case).
file = getattr(instance, self.attname)
# Nothing to update if we have no file and not being forced to update.
if not file and not force:
return
dimension_fields_filled = not (
(self.width_field and not getattr(instance, self.width_field))
or (self.height_field and not getattr(instance, self.height_field))
)
# When both dimension fields have values, we are most likely loading
# data from the database or updating an image field that already had
# an image stored. In the first case, we don't want to update the
# dimension fields because we are already getting their values from the
# database. In the second case, we do want to update the dimensions
# fields and will skip this return because force will be True since we
# were called from ImageFileDescriptor.__set__.
if dimension_fields_filled and not force:
return
# file should be an instance of ImageFieldFile or should be None.
if file:
width = file.width
height = file.height
else:
# No file, so clear dimensions fields.
width = None
height = None
# Update the width and height fields.
if self.width_field:
setattr(instance, self.width_field, width)
if self.height_field:
setattr(instance, self.height_field, height)
def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.ImageField,
**kwargs,
}
)

View File

@ -0,0 +1,638 @@
import json
import warnings
from django import forms
from django.core import checks, exceptions
from django.db import NotSupportedError, connections, router
from django.db.models import expressions, lookups
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import TextField
from django.db.models.lookups import (
FieldGetDbPrepValueMixin,
PostgresOperatorLookup,
Transform,
)
from django.utils.deprecation import RemovedInDjango51Warning
from django.utils.translation import gettext_lazy as _
from . import Field
from .mixins import CheckFieldDefaultMixin
__all__ = ["JSONField"]
class JSONField(CheckFieldDefaultMixin, Field):
empty_strings_allowed = False
description = _("A JSON object")
default_error_messages = {
"invalid": _("Value must be valid JSON."),
}
_default_hint = ("dict", "{}")
def __init__(
self,
verbose_name=None,
name=None,
encoder=None,
decoder=None,
**kwargs,
):
if encoder and not callable(encoder):
raise ValueError("The encoder parameter must be a callable object.")
if decoder and not callable(decoder):
raise ValueError("The decoder parameter must be a callable object.")
self.encoder = encoder
self.decoder = decoder
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
errors = super().check(**kwargs)
databases = kwargs.get("databases") or []
errors.extend(self._check_supported(databases))
return errors
def _check_supported(self, databases):
errors = []
for db in databases:
if not router.allow_migrate_model(db, self.model):
continue
connection = connections[db]
if (
self.model._meta.required_db_vendor
and self.model._meta.required_db_vendor != connection.vendor
):
continue
if not (
"supports_json_field" in self.model._meta.required_db_features
or connection.features.supports_json_field
):
errors.append(
checks.Error(
"%s does not support JSONFields." % connection.display_name,
obj=self.model,
id="fields.E180",
)
)
return errors
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.encoder is not None:
kwargs["encoder"] = self.encoder
if self.decoder is not None:
kwargs["decoder"] = self.decoder
return name, path, args, kwargs
def from_db_value(self, value, expression, connection):
if value is None:
return value
# Some backends (SQLite at least) extract non-string values in their
# SQL datatypes.
if isinstance(expression, KeyTransform) and not isinstance(value, str):
return value
try:
return json.loads(value, cls=self.decoder)
except json.JSONDecodeError:
return value
def get_internal_type(self):
return "JSONField"
def get_db_prep_value(self, value, connection, prepared=False):
if not prepared:
value = self.get_prep_value(value)
# RemovedInDjango51Warning: When the deprecation ends, replace with:
# if (
# isinstance(value, expressions.Value)
# and isinstance(value.output_field, JSONField)
# ):
# value = value.value
# elif hasattr(value, "as_sql"): ...
if isinstance(value, expressions.Value):
if isinstance(value.value, str) and not isinstance(
value.output_field, JSONField
):
try:
value = json.loads(value.value, cls=self.decoder)
except json.JSONDecodeError:
value = value.value
else:
warnings.warn(
"Providing an encoded JSON string via Value() is deprecated. "
f"Use Value({value!r}, output_field=JSONField()) instead.",
category=RemovedInDjango51Warning,
)
elif isinstance(value.output_field, JSONField):
value = value.value
else:
return value
elif hasattr(value, "as_sql"):
return value
return connection.ops.adapt_json_value(value, self.encoder)
def get_db_prep_save(self, value, connection):
if value is None:
return value
return self.get_db_prep_value(value, connection)
def get_transform(self, name):
transform = super().get_transform(name)
if transform:
return transform
return KeyTransformFactory(name)
def validate(self, value, model_instance):
super().validate(value, model_instance)
try:
json.dumps(value, cls=self.encoder)
except TypeError:
raise exceptions.ValidationError(
self.error_messages["invalid"],
code="invalid",
params={"value": value},
)
def value_to_string(self, obj):
return self.value_from_object(obj)
def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.JSONField,
"encoder": self.encoder,
"decoder": self.decoder,
**kwargs,
}
)
def compile_json_path(key_transforms, include_root=True):
path = ["$"] if include_root else []
for key_transform in key_transforms:
try:
num = int(key_transform)
except ValueError: # non-integer
path.append(".")
path.append(json.dumps(key_transform))
else:
path.append("[%s]" % num)
return "".join(path)
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
lookup_name = "contains"
postgres_operator = "@>"
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
"contains lookup is not supported on this database backend."
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(lhs_params) + tuple(rhs_params)
return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
lookup_name = "contained_by"
postgres_operator = "<@"
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
"contained_by lookup is not supported on this database backend."
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(rhs_params) + tuple(lhs_params)
return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
class HasKeyLookup(PostgresOperatorLookup):
logical_operator = None
def compile_json_path_final_key(self, key_transform):
# Compile the final key without interpreting ints as array elements.
return ".%s" % json.dumps(key_transform)
def as_sql(self, compiler, connection, template=None):
# Process JSON path from the left-hand side.
if isinstance(self.lhs, KeyTransform):
lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
compiler, connection
)
lhs_json_path = compile_json_path(lhs_key_transforms)
else:
lhs, lhs_params = self.process_lhs(compiler, connection)
lhs_json_path = "$"
sql = template % lhs
# Process JSON path from the right-hand side.
rhs = self.rhs
rhs_params = []
if not isinstance(rhs, (list, tuple)):
rhs = [rhs]
for key in rhs:
if isinstance(key, KeyTransform):
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
else:
rhs_key_transforms = [key]
*rhs_key_transforms, final_key = rhs_key_transforms
rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
rhs_json_path += self.compile_json_path_final_key(final_key)
rhs_params.append(lhs_json_path + rhs_json_path)
# Add condition for each key.
if self.logical_operator:
sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
return sql, tuple(lhs_params) + tuple(rhs_params)
def as_mysql(self, compiler, connection):
return self.as_sql(
compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
)
def as_oracle(self, compiler, connection):
sql, params = self.as_sql(
compiler, connection, template="JSON_EXISTS(%s, '%%s')"
)
# Add paths directly into SQL because path expressions cannot be passed
# as bind variables on Oracle.
return sql % tuple(params), []
def as_postgresql(self, compiler, connection):
if isinstance(self.rhs, KeyTransform):
*_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
for key in rhs_key_transforms[:-1]:
self.lhs = KeyTransform(key, self.lhs)
self.rhs = rhs_key_transforms[-1]
return super().as_postgresql(compiler, connection)
def as_sqlite(self, compiler, connection):
return self.as_sql(
compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
)
class HasKey(HasKeyLookup):
lookup_name = "has_key"
postgres_operator = "?"
prepare_rhs = False
class HasKeys(HasKeyLookup):
lookup_name = "has_keys"
postgres_operator = "?&"
logical_operator = " AND "
def get_prep_lookup(self):
return [str(item) for item in self.rhs]
class HasAnyKeys(HasKeys):
lookup_name = "has_any_keys"
postgres_operator = "?|"
logical_operator = " OR "
class HasKeyOrArrayIndex(HasKey):
def compile_json_path_final_key(self, key_transform):
return compile_json_path([key_transform], include_root=False)
class CaseInsensitiveMixin:
"""
Mixin to allow case-insensitive comparison of JSON values on MySQL.
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
Because utf8mb4_bin is a binary collation, comparison of JSON values is
case-sensitive.
"""
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == "mysql":
return "LOWER(%s)" % lhs, lhs_params
return lhs, lhs_params
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if connection.vendor == "mysql":
return "LOWER(%s)" % rhs, rhs_params
return rhs, rhs_params
class JSONExact(lookups.Exact):
can_use_none_as_rhs = True
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
# Treat None lookup values as null.
if rhs == "%s" and rhs_params == [None]:
rhs_params = ["null"]
if connection.vendor == "mysql":
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
rhs %= tuple(func)
return rhs, rhs_params
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
pass
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
class KeyTransform(Transform):
postgres_operator = "->"
postgres_nested_operator = "#>"
def __init__(self, key_name, *args, **kwargs):
super().__init__(*args, **kwargs)
self.key_name = str(key_name)
def preprocess_lhs(self, compiler, connection):
key_transforms = [self.key_name]
previous = self.lhs
while isinstance(previous, KeyTransform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs
lhs, params = compiler.compile(previous)
if connection.vendor == "oracle":
# Escape string-formatting.
key_transforms = [key.replace("%", "%%") for key in key_transforms]
return lhs, params, key_transforms
def as_mysql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
def as_oracle(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return (
"COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
% ((lhs, json_path) * 2)
), tuple(params) * 2
def as_postgresql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
if len(key_transforms) > 1:
sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
return sql, tuple(params) + (key_transforms,)
try:
lookup = int(self.key_name)
except ValueError:
lookup = self.key_name
return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
def as_sqlite(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
datatype_values = ",".join(
[repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
)
return (
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
class KeyTextTransform(KeyTransform):
postgres_operator = "->>"
postgres_nested_operator = "#>>"
output_field = TextField()
def as_mysql(self, compiler, connection):
if connection.mysql_is_mariadb:
# MariaDB doesn't support -> and ->> operators (see MDEV-13594).
sql, params = super().as_mysql(compiler, connection)
return "JSON_UNQUOTE(%s)" % sql, params
else:
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)
@classmethod
def from_lookup(cls, lookup):
transform, *keys = lookup.split(LOOKUP_SEP)
if not keys:
raise ValueError("Lookup must contain key or index transforms.")
for key in keys:
transform = cls(key, transform)
return transform
KT = KeyTextTransform.from_lookup
class KeyTransformTextLookupMixin:
"""
Mixin for combining with a lookup expecting a text lhs from a JSONField
key lookup. On PostgreSQL, make use of the ->> operator instead of casting
key values to text and performing the lookup on the resulting
representation.
"""
def __init__(self, key_transform, *args, **kwargs):
if not isinstance(key_transform, KeyTransform):
raise TypeError(
"Transform should be an instance of KeyTransform in order to "
"use this lookup."
)
key_text_transform = KeyTextTransform(
key_transform.key_name,
*key_transform.source_expressions,
**key_transform.extra,
)
super().__init__(key_text_transform, *args, **kwargs)
class KeyTransformIsNull(lookups.IsNull):
# key__isnull=False is the same as has_key='key'
def as_oracle(self, compiler, connection):
sql, params = HasKeyOrArrayIndex(
self.lhs.lhs,
self.lhs.key_name,
).as_oracle(compiler, connection)
if not self.rhs:
return sql, params
# Column doesn't have a key or IS NULL.
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
def as_sqlite(self, compiler, connection):
template = "JSON_TYPE(%s, %%s) IS NULL"
if not self.rhs:
template = "JSON_TYPE(%s, %%s) IS NOT NULL"
return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
compiler,
connection,
template=template,
)
class KeyTransformIn(lookups.In):
def resolve_expression_parameter(self, compiler, connection, sql, param):
sql, params = super().resolve_expression_parameter(
compiler,
connection,
sql,
param,
)
if (
not hasattr(param, "as_sql")
and not connection.features.has_native_json_field
):
if connection.vendor == "oracle":
value = json.loads(param)
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
if isinstance(value, (list, dict)):
sql %= "JSON_QUERY"
else:
sql %= "JSON_VALUE"
elif connection.vendor == "mysql" or (
connection.vendor == "sqlite"
and params[0] not in connection.ops.jsonfield_datatype_values
):
sql = "JSON_EXTRACT(%s, '$')"
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
sql = "JSON_UNQUOTE(%s)" % sql
return sql, params
class KeyTransformExact(JSONExact):
def process_rhs(self, compiler, connection):
if isinstance(self.rhs, KeyTransform):
return super(lookups.Exact, self).process_rhs(compiler, connection)
rhs, rhs_params = super().process_rhs(compiler, connection)
if connection.vendor == "oracle":
func = []
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
for value in rhs_params:
value = json.loads(value)
if isinstance(value, (list, dict)):
func.append(sql % "JSON_QUERY")
else:
func.append(sql % "JSON_VALUE")
rhs %= tuple(func)
elif connection.vendor == "sqlite":
func = []
for value in rhs_params:
if value in connection.ops.jsonfield_datatype_values:
func.append("%s")
else:
func.append("JSON_EXTRACT(%s, '$')")
rhs %= tuple(func)
return rhs, rhs_params
def as_oracle(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if rhs_params == ["null"]:
# Field has key and it's NULL.
has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
return (
"%s AND %s" % (has_key_sql, is_null_sql),
tuple(has_key_params) + tuple(is_null_params),
)
return super().as_sql(compiler, connection)
class KeyTransformIExact(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
pass
class KeyTransformIContains(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
pass
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
pass
class KeyTransformIStartsWith(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
pass
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
pass
class KeyTransformIEndsWith(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
pass
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
pass
class KeyTransformIRegex(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
pass
class KeyTransformNumericLookupMixin:
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if not connection.features.has_native_json_field:
rhs_params = [json.loads(value) for value in rhs_params]
return rhs, rhs_params
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
pass
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
pass
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
pass
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
pass
KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)
KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)
class KeyTransformFactory:
def __init__(self, key_name):
self.key_name = key_name
def __call__(self, *args, **kwargs):
return KeyTransform(self.key_name, *args, **kwargs)

View File

@ -0,0 +1,59 @@
from django.core import checks
NOT_PROVIDED = object()
class FieldCacheMixin:
"""Provide an API for working with the model's fields value cache."""
def get_cache_name(self):
raise NotImplementedError
def get_cached_value(self, instance, default=NOT_PROVIDED):
cache_name = self.get_cache_name()
try:
return instance._state.fields_cache[cache_name]
except KeyError:
if default is NOT_PROVIDED:
raise
return default
def is_cached(self, instance):
return self.get_cache_name() in instance._state.fields_cache
def set_cached_value(self, instance, value):
instance._state.fields_cache[self.get_cache_name()] = value
def delete_cached_value(self, instance):
del instance._state.fields_cache[self.get_cache_name()]
class CheckFieldDefaultMixin:
_default_hint = ("<valid default>", "<invalid default>")
def _check_default(self):
if (
self.has_default()
and self.default is not None
and not callable(self.default)
):
return [
checks.Warning(
"%s default should be a callable instead of an instance "
"so that it's not shared between all field instances."
% (self.__class__.__name__,),
hint=(
"Use a callable instead, e.g., use `%s` instead of "
"`%s`." % self._default_hint
),
obj=self,
id="fields.E010",
)
]
else:
return []
def check(self, **kwargs):
errors = super().check(**kwargs)
errors.extend(self._check_default())
return errors

View File

@ -0,0 +1,18 @@
"""
Field-like classes that aren't really fields. It's easier to use objects that
have the same attributes as fields sometimes (avoids a lot of special casing).
"""
from django.db.models import fields
class OrderWrt(fields.IntegerField):
"""
A proxy for the _order database field that is used when
Meta.order_with_respect_to is specified.
"""
def __init__(self, *args, **kwargs):
kwargs["name"] = "_order"
kwargs["editable"] = False
super().__init__(*args, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,209 @@
import warnings
from django.db.models.lookups import (
Exact,
GreaterThan,
GreaterThanOrEqual,
In,
IsNull,
LessThan,
LessThanOrEqual,
)
from django.utils.deprecation import RemovedInDjango50Warning
class MultiColSource:
contains_aggregate = False
contains_over_clause = False
def __init__(self, alias, targets, sources, field):
self.targets, self.sources, self.field, self.alias = (
targets,
sources,
field,
alias,
)
self.output_field = self.field
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field)
def relabeled_clone(self, relabels):
return self.__class__(
relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
)
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
def resolve_expression(self, *args, **kwargs):
return self
def get_normalized_value(value, lhs):
from django.db.models import Model
if isinstance(value, Model):
if value.pk is None:
# When the deprecation ends, replace with:
# raise ValueError(
# "Model instances passed to related filters must be saved."
# )
warnings.warn(
"Passing unsaved model instances to related filters is deprecated.",
RemovedInDjango50Warning,
)
value_list = []
sources = lhs.output_field.path_infos[-1].target_fields
for source in sources:
while not isinstance(value, source.model) and source.remote_field:
source = source.remote_field.model._meta.get_field(
source.remote_field.field_name
)
try:
value_list.append(getattr(value, source.attname))
except AttributeError:
# A case like Restaurant.objects.filter(place=restaurant_instance),
# where place is a OneToOneField and the primary key of Restaurant.
return (value.pk,)
return tuple(value_list)
if not isinstance(value, tuple):
return (value,)
return value
class RelatedIn(In):
def get_prep_lookup(self):
if not isinstance(self.lhs, MultiColSource):
if self.rhs_is_direct_value():
# If we get here, we are dealing with single-column relations.
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
# We need to run the related field's get_prep_value(). Consider
# case ForeignKey to IntegerField given value 'abc'. The
# ForeignKey itself doesn't have validation for non-integers,
# so we must run validation using the target field.
if hasattr(self.lhs.output_field, "path_infos"):
# Run the target field's get_prep_value. We can safely
# assume there is only one as we don't get to the direct
# value branch otherwise.
target_field = self.lhs.output_field.path_infos[-1].target_fields[
-1
]
self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
elif not getattr(self.rhs, "has_select_fields", True) and not getattr(
self.lhs.field.target_field, "primary_key", False
):
if (
getattr(self.lhs.output_field, "primary_key", False)
and self.lhs.output_field.model == self.rhs.model
):
# A case like
# Restaurant.objects.filter(place__in=restaurant_qs), where
# place is a OneToOneField and the primary key of
# Restaurant.
target_field = self.lhs.field.name
else:
target_field = self.lhs.field.target_field.name
self.rhs.set_values([target_field])
return super().get_prep_lookup()
def as_sql(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
# For multicolumn lookups we need to build a multicolumn where clause.
# This clause is either a SubqueryConstraint (for values that need
# to be compiled to SQL) or an OR-combined list of
# (col1 = val1 AND col2 = val2 AND ...) clauses.
from django.db.models.sql.where import (
AND,
OR,
SubqueryConstraint,
WhereNode,
)
root_constraint = WhereNode(connector=OR)
if self.rhs_is_direct_value():
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
for value in values:
value_constraint = WhereNode()
for source, target, val in zip(
self.lhs.sources, self.lhs.targets, value
):
lookup_class = target.get_lookup("exact")
lookup = lookup_class(
target.get_col(self.lhs.alias, source), val
)
value_constraint.add(lookup, AND)
root_constraint.add(value_constraint, OR)
else:
root_constraint.add(
SubqueryConstraint(
self.lhs.alias,
[target.column for target in self.lhs.targets],
[source.name for source in self.lhs.sources],
self.rhs,
),
AND,
)
return root_constraint.as_sql(compiler, connection)
return super().as_sql(compiler, connection)
class RelatedLookupMixin:
def get_prep_lookup(self):
if not isinstance(self.lhs, MultiColSource) and not hasattr(
self.rhs, "resolve_expression"
):
# If we get here, we are dealing with single-column relations.
self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
# We need to run the related field's get_prep_value(). Consider case
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
# doesn't have validation for non-integers, so we must run validation
# using the target field.
if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
# Get the target field. We can safely assume there is only one
# as we don't get to the direct value branch otherwise.
target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
self.rhs = target_field.get_prep_value(self.rhs)
return super().get_prep_lookup()
def as_sql(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
assert self.rhs_is_direct_value()
self.rhs = get_normalized_value(self.rhs, self.lhs)
from django.db.models.sql.where import AND, WhereNode
root_constraint = WhereNode()
for target, source, val in zip(
self.lhs.targets, self.lhs.sources, self.rhs
):
lookup_class = target.get_lookup(self.lookup_name)
root_constraint.add(
lookup_class(target.get_col(self.lhs.alias, source), val), AND
)
return root_constraint.as_sql(compiler, connection)
return super().as_sql(compiler, connection)
class RelatedExact(RelatedLookupMixin, Exact):
pass
class RelatedLessThan(RelatedLookupMixin, LessThan):
pass
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
pass
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
pass
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
pass
class RelatedIsNull(RelatedLookupMixin, IsNull):
pass

View File

@ -0,0 +1,402 @@
"""
"Rel objects" for related fields.
"Rel objects" (for lack of a better name) carry information about the relation
modeled by a related field and provide some utility functions. They're stored
in the ``remote_field`` attribute of the field.
They also act as reverse fields for the purposes of the Meta API because
they're the closest concept currently available.
"""
from django.core import exceptions
from django.utils.functional import cached_property
from django.utils.hashable import make_hashable
from . import BLANK_CHOICE_DASH
from .mixins import FieldCacheMixin
class ForeignObjectRel(FieldCacheMixin):
"""
Used by ForeignObject to store information about the relation.
``_meta.get_fields()`` returns this class to provide access to the field
flags for the reverse relation.
"""
# Field flags
auto_created = True
concrete = False
editable = False
is_relation = True
# Reverse relations are always nullable (Django can't enforce that a
# foreign key on the related model points to this model).
null = True
empty_strings_allowed = False
def __init__(
self,
field,
to,
related_name=None,
related_query_name=None,
limit_choices_to=None,
parent_link=False,
on_delete=None,
):
self.field = field
self.model = to
self.related_name = related_name
self.related_query_name = related_query_name
self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
self.parent_link = parent_link
self.on_delete = on_delete
self.symmetrical = False
self.multiple = True
# Some of the following cached_properties can't be initialized in
# __init__ as the field doesn't have its model yet. Calling these methods
# before field.contribute_to_class() has been called will result in
# AttributeError
@cached_property
def hidden(self):
return self.is_hidden()
@cached_property
def name(self):
return self.field.related_query_name()
@property
def remote_field(self):
return self.field
@property
def target_field(self):
"""
When filtering against this relation, return the field on the remote
model against which the filtering should happen.
"""
target_fields = self.path_infos[-1].target_fields
if len(target_fields) > 1:
raise exceptions.FieldError(
"Can't use target_field for multicolumn relations."
)
return target_fields[0]
@cached_property
def related_model(self):
if not self.field.model:
raise AttributeError(
"This property can't be accessed before self.field.contribute_to_class "
"has been called."
)
return self.field.model
@cached_property
def many_to_many(self):
return self.field.many_to_many
@cached_property
def many_to_one(self):
return self.field.one_to_many
@cached_property
def one_to_many(self):
return self.field.many_to_one
@cached_property
def one_to_one(self):
return self.field.one_to_one
def get_lookup(self, lookup_name):
return self.field.get_lookup(lookup_name)
def get_lookups(self):
return self.field.get_lookups()
def get_transform(self, name):
return self.field.get_transform(name)
def get_internal_type(self):
return self.field.get_internal_type()
@property
def db_type(self):
return self.field.db_type
def __repr__(self):
return "<%s: %s.%s>" % (
type(self).__name__,
self.related_model._meta.app_label,
self.related_model._meta.model_name,
)
@property
def identity(self):
return (
self.field,
self.model,
self.related_name,
self.related_query_name,
make_hashable(self.limit_choices_to),
self.parent_link,
self.on_delete,
self.symmetrical,
self.multiple,
)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.identity == other.identity
def __hash__(self):
return hash(self.identity)
def __getstate__(self):
state = self.__dict__.copy()
# Delete the path_infos cached property because it can be recalculated
# at first invocation after deserialization. The attribute must be
# removed because subclasses like ManyToOneRel may have a PathInfo
# which contains an intermediate M2M table that's been dynamically
# created and doesn't exist in the .models module.
# This is a reverse relation, so there is no reverse_path_infos to
# delete.
state.pop("path_infos", None)
return state
def get_choices(
self,
include_blank=True,
blank_choice=BLANK_CHOICE_DASH,
limit_choices_to=None,
ordering=(),
):
"""
Return choices with a default blank choices included, for use
as <select> choices for this field.
Analog of django.db.models.fields.Field.get_choices(), provided
initially for utilization by RelatedFieldListFilter.
"""
limit_choices_to = limit_choices_to or self.limit_choices_to
qs = self.related_model._default_manager.complex_filter(limit_choices_to)
if ordering:
qs = qs.order_by(*ordering)
return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs]
def is_hidden(self):
"""Should the related object be hidden?"""
return bool(self.related_name) and self.related_name[-1] == "+"
def get_joining_columns(self):
return self.field.get_reverse_joining_columns()
def get_extra_restriction(self, alias, related_alias):
return self.field.get_extra_restriction(related_alias, alias)
def set_field_name(self):
"""
Set the related field's name, this is not available until later stages
of app loading, so set_field_name is called from
set_attributes_from_rel()
"""
# By default foreign object doesn't relate to any remote field (for
# example custom multicolumn joins currently have no remote field).
self.field_name = None
def get_accessor_name(self, model=None):
# This method encapsulates the logic that decides what name to give an
# accessor descriptor that retrieves related many-to-one or
# many-to-many objects. It uses the lowercased object_name + "_set",
# but this can be overridden with the "related_name" option. Due to
# backwards compatibility ModelForms need to be able to provide an
# alternate model. See BaseInlineFormSet.get_default_prefix().
opts = model._meta if model else self.related_model._meta
model = model or self.related_model
if self.multiple:
# If this is a symmetrical m2m relation on self, there is no
# reverse accessor.
if self.symmetrical and model == self.model:
return None
if self.related_name:
return self.related_name
return opts.model_name + ("_set" if self.multiple else "")
def get_path_info(self, filtered_relation=None):
if filtered_relation:
return self.field.get_reverse_path_info(filtered_relation)
else:
return self.field.reverse_path_infos
@cached_property
def path_infos(self):
return self.get_path_info()
def get_cache_name(self):
"""
Return the name of the cache key to use for storing an instance of the
forward model on the reverse model.
"""
return self.get_accessor_name()
class ManyToOneRel(ForeignObjectRel):
"""
Used by the ForeignKey field to store information about the relation.
``_meta.get_fields()`` returns this class to provide access to the field
flags for the reverse relation.
Note: Because we somewhat abuse the Rel objects by using them as reverse
fields we get the funny situation where
``ManyToOneRel.many_to_one == False`` and
``ManyToOneRel.one_to_many == True``. This is unfortunate but the actual
ManyToOneRel class is a private API and there is work underway to turn
reverse relations into actual fields.
"""
def __init__(
self,
field,
to,
field_name,
related_name=None,
related_query_name=None,
limit_choices_to=None,
parent_link=False,
on_delete=None,
):
super().__init__(
field,
to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
parent_link=parent_link,
on_delete=on_delete,
)
self.field_name = field_name
def __getstate__(self):
state = super().__getstate__()
state.pop("related_model", None)
return state
@property
def identity(self):
return super().identity + (self.field_name,)
def get_related_field(self):
"""
Return the Field in the 'to' object to which this relationship is tied.
"""
field = self.model._meta.get_field(self.field_name)
if not field.concrete:
raise exceptions.FieldDoesNotExist(
"No related field named '%s'" % self.field_name
)
return field
def set_field_name(self):
self.field_name = self.field_name or self.model._meta.pk.name
class OneToOneRel(ManyToOneRel):
"""
Used by OneToOneField to store information about the relation.
``_meta.get_fields()`` returns this class to provide access to the field
flags for the reverse relation.
"""
def __init__(
self,
field,
to,
field_name,
related_name=None,
related_query_name=None,
limit_choices_to=None,
parent_link=False,
on_delete=None,
):
super().__init__(
field,
to,
field_name,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
parent_link=parent_link,
on_delete=on_delete,
)
self.multiple = False
class ManyToManyRel(ForeignObjectRel):
"""
Used by ManyToManyField to store information about the relation.
``_meta.get_fields()`` returns this class to provide access to the field
flags for the reverse relation.
"""
def __init__(
self,
field,
to,
related_name=None,
related_query_name=None,
limit_choices_to=None,
symmetrical=True,
through=None,
through_fields=None,
db_constraint=True,
):
super().__init__(
field,
to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
)
if through and not db_constraint:
raise ValueError("Can't supply a through model and db_constraint=False")
self.through = through
if through_fields and not through:
raise ValueError("Cannot specify through_fields without a through model")
self.through_fields = through_fields
self.symmetrical = symmetrical
self.db_constraint = db_constraint
@property
def identity(self):
return super().identity + (
self.through,
make_hashable(self.through_fields),
self.db_constraint,
)
def get_related_field(self):
"""
Return the field in the 'to' object to which this relationship is tied.
Provided for symmetry with ManyToOneRel.
"""
opts = self.through._meta
if self.through_fields:
field = opts.get_field(self.through_fields[0])
else:
for field in opts.fields:
rel = getattr(field, "remote_field", None)
if rel and rel.model == self.model:
break
return field.foreign_related_fields[0]

View File

@ -0,0 +1,190 @@
from .comparison import Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf
from .datetime import (
Extract,
ExtractDay,
ExtractHour,
ExtractIsoWeekDay,
ExtractIsoYear,
ExtractMinute,
ExtractMonth,
ExtractQuarter,
ExtractSecond,
ExtractWeek,
ExtractWeekDay,
ExtractYear,
Now,
Trunc,
TruncDate,
TruncDay,
TruncHour,
TruncMinute,
TruncMonth,
TruncQuarter,
TruncSecond,
TruncTime,
TruncWeek,
TruncYear,
)
from .math import (
Abs,
ACos,
ASin,
ATan,
ATan2,
Ceil,
Cos,
Cot,
Degrees,
Exp,
Floor,
Ln,
Log,
Mod,
Pi,
Power,
Radians,
Random,
Round,
Sign,
Sin,
Sqrt,
Tan,
)
from .text import (
MD5,
SHA1,
SHA224,
SHA256,
SHA384,
SHA512,
Chr,
Concat,
ConcatPair,
Left,
Length,
Lower,
LPad,
LTrim,
Ord,
Repeat,
Replace,
Reverse,
Right,
RPad,
RTrim,
StrIndex,
Substr,
Trim,
Upper,
)
from .window import (
CumeDist,
DenseRank,
FirstValue,
Lag,
LastValue,
Lead,
NthValue,
Ntile,
PercentRank,
Rank,
RowNumber,
)
__all__ = [
# comparison and conversion
"Cast",
"Coalesce",
"Collate",
"Greatest",
"JSONObject",
"Least",
"NullIf",
# datetime
"Extract",
"ExtractDay",
"ExtractHour",
"ExtractMinute",
"ExtractMonth",
"ExtractQuarter",
"ExtractSecond",
"ExtractWeek",
"ExtractIsoWeekDay",
"ExtractWeekDay",
"ExtractIsoYear",
"ExtractYear",
"Now",
"Trunc",
"TruncDate",
"TruncDay",
"TruncHour",
"TruncMinute",
"TruncMonth",
"TruncQuarter",
"TruncSecond",
"TruncTime",
"TruncWeek",
"TruncYear",
# math
"Abs",
"ACos",
"ASin",
"ATan",
"ATan2",
"Ceil",
"Cos",
"Cot",
"Degrees",
"Exp",
"Floor",
"Ln",
"Log",
"Mod",
"Pi",
"Power",
"Radians",
"Random",
"Round",
"Sign",
"Sin",
"Sqrt",
"Tan",
# text
"MD5",
"SHA1",
"SHA224",
"SHA256",
"SHA384",
"SHA512",
"Chr",
"Concat",
"ConcatPair",
"Left",
"Length",
"Lower",
"LPad",
"LTrim",
"Ord",
"Repeat",
"Replace",
"Reverse",
"Right",
"RPad",
"RTrim",
"StrIndex",
"Substr",
"Trim",
"Upper",
# window
"CumeDist",
"DenseRank",
"FirstValue",
"Lag",
"LastValue",
"Lead",
"NthValue",
"Ntile",
"PercentRank",
"Rank",
"RowNumber",
]

View File

@ -0,0 +1,220 @@
"""Database functions that do comparisons or type conversions."""
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields import TextField
from django.db.models.fields.json import JSONField
from django.utils.regex_helper import _lazy_re_compile
class Cast(Func):
"""Coerce an expression to a new field type."""
function = "CAST"
template = "%(function)s(%(expressions)s AS %(db_type)s)"
def __init__(self, expression, output_field):
super().__init__(expression, output_field=output_field)
def as_sql(self, compiler, connection, **extra_context):
extra_context["db_type"] = self.output_field.cast_db_type(connection)
return super().as_sql(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
db_type = self.output_field.db_type(connection)
if db_type in {"datetime", "time"}:
# Use strftime as datetime/time don't keep fractional seconds.
template = "strftime(%%s, %(expressions)s)"
sql, params = super().as_sql(
compiler, connection, template=template, **extra_context
)
format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
params.insert(0, format_string)
return sql, params
elif db_type == "date":
template = "date(%(expressions)s)"
return super().as_sql(
compiler, connection, template=template, **extra_context
)
return self.as_sql(compiler, connection, **extra_context)
def as_mysql(self, compiler, connection, **extra_context):
template = None
output_type = self.output_field.get_internal_type()
# MySQL doesn't support explicit cast to float.
if output_type == "FloatField":
template = "(%(expressions)s + 0.0)"
# MariaDB doesn't support explicit cast to JSON.
elif output_type == "JSONField" and connection.mysql_is_mariadb:
template = "JSON_EXTRACT(%(expressions)s, '$')"
return self.as_sql(compiler, connection, template=template, **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
# CAST would be valid too, but the :: shortcut syntax is more readable.
# 'expressions' is wrapped in parentheses in case it's a complex
# expression.
return self.as_sql(
compiler,
connection,
template="(%(expressions)s)::%(db_type)s",
**extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == "JSONField":
# Oracle doesn't support explicit cast to JSON.
template = "JSON_QUERY(%(expressions)s, '$')"
return super().as_sql(
compiler, connection, template=template, **extra_context
)
return self.as_sql(compiler, connection, **extra_context)
class Coalesce(Func):
"""Return, from left to right, the first non-null expression."""
function = "COALESCE"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Coalesce must take at least two expressions")
super().__init__(*expressions, **extra)
@property
def empty_result_set_value(self):
for expression in self.get_source_expressions():
result = expression.empty_result_set_value
if result is NotImplemented or result is not None:
return result
return None
def as_oracle(self, compiler, connection, **extra_context):
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
# so convert all fields to NCLOB when that type is expected.
if self.output_field.get_internal_type() == "TextField":
clone = self.copy()
clone.set_source_expressions(
[
Func(expression, function="TO_NCLOB")
for expression in self.get_source_expressions()
]
)
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
class Collate(Func):
function = "COLLATE"
template = "%(expressions)s %(function)s %(collation)s"
# Inspired from
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
collation_re = _lazy_re_compile(r"^[\w\-]+$")
def __init__(self, expression, collation):
if not (collation and self.collation_re.match(collation)):
raise ValueError("Invalid collation name: %r." % collation)
self.collation = collation
super().__init__(expression)
def as_sql(self, compiler, connection, **extra_context):
extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
return super().as_sql(compiler, connection, **extra_context)
class Greatest(Func):
"""
Return the maximum expression.
If any expression is null the return value is database-specific:
On PostgreSQL, the maximum not-null expression is returned.
On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
"""
function = "GREATEST"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Greatest must take at least two expressions")
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MAX function on SQLite."""
return super().as_sqlite(compiler, connection, function="MAX", **extra_context)
class JSONObject(Func):
function = "JSON_OBJECT"
output_field = JSONField()
def __init__(self, **fields):
expressions = []
for key, value in fields.items():
expressions.extend((Value(key), value))
super().__init__(*expressions)
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.has_json_object_function:
raise NotSupportedError(
"JSONObject() is not supported on this database backend."
)
return super().as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
copy = self.copy()
copy.set_source_expressions(
[
Cast(expression, TextField()) if index % 2 == 0 else expression
for index, expression in enumerate(copy.get_source_expressions())
]
)
return super(JSONObject, copy).as_sql(
compiler,
connection,
function="JSONB_BUILD_OBJECT",
**extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
class ArgJoiner:
def join(self, args):
args = [" VALUE ".join(arg) for arg in zip(args[::2], args[1::2])]
return ", ".join(args)
return self.as_sql(
compiler,
connection,
arg_joiner=ArgJoiner(),
template="%(function)s(%(expressions)s RETURNING CLOB)",
**extra_context,
)
class Least(Func):
"""
Return the minimum expression.
If any expression is null the return value is database-specific:
On PostgreSQL, return the minimum not-null expression.
On MySQL, Oracle, and SQLite, if any expression is null, return null.
"""
function = "LEAST"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Least must take at least two expressions")
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MIN function on SQLite."""
return super().as_sqlite(compiler, connection, function="MIN", **extra_context)
class NullIf(Func):
function = "NULLIF"
arity = 2
def as_oracle(self, compiler, connection, **extra_context):
expression1 = self.get_source_expressions()[0]
if isinstance(expression1, Value) and expression1.value is None:
raise ValueError("Oracle does not allow Value(None) for expression1.")
return super().as_sql(compiler, connection, **extra_context)

View File

@ -0,0 +1,439 @@
from datetime import datetime
from django.conf import settings
from django.db.models.expressions import Func
from django.db.models.fields import (
DateField,
DateTimeField,
DurationField,
Field,
IntegerField,
TimeField,
)
from django.db.models.lookups import (
Transform,
YearExact,
YearGt,
YearGte,
YearLt,
YearLte,
)
from django.utils import timezone
class TimezoneMixin:
tzinfo = None
def get_tzname(self):
# Timezone conversions must happen to the input datetime *before*
# applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
# database as 2016-01-01 01:00:00 +00:00. Any results should be
# based on the input datetime not the stored datetime.
tzname = None
if settings.USE_TZ:
if self.tzinfo is None:
tzname = timezone.get_current_timezone_name()
else:
tzname = timezone._get_timezone_name(self.tzinfo)
return tzname
class Extract(TimezoneMixin, Transform):
lookup_name = None
output_field = IntegerField()
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
if self.lookup_name is None:
self.lookup_name = lookup_name
if self.lookup_name is None:
raise ValueError("lookup_name must be provided")
self.tzinfo = tzinfo
super().__init__(expression, **extra)
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField):
tzname = self.get_tzname()
sql, params = connection.ops.datetime_extract_sql(
self.lookup_name, sql, tuple(params), tzname
)
elif self.tzinfo is not None:
raise ValueError("tzinfo can only be used with DateTimeField.")
elif isinstance(lhs_output_field, DateField):
sql, params = connection.ops.date_extract_sql(
self.lookup_name, sql, tuple(params)
)
elif isinstance(lhs_output_field, TimeField):
sql, params = connection.ops.time_extract_sql(
self.lookup_name, sql, tuple(params)
)
elif isinstance(lhs_output_field, DurationField):
if not connection.features.has_native_duration_field:
raise ValueError(
"Extract requires native DurationField database support."
)
sql, params = connection.ops.time_extract_sql(
self.lookup_name, sql, tuple(params)
)
else:
# resolve_expression has already validated the output_field so this
# assert should never be hit.
assert False, "Tried to Extract from an invalid type."
return sql, params
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
copy = super().resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
field = getattr(copy.lhs, "output_field", None)
if field is None:
return copy
if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
raise ValueError(
"Extract input expression must be DateField, DateTimeField, "
"TimeField, or DurationField."
)
# Passing dates to functions expecting datetimes is most likely a mistake.
if type(field) is DateField and copy.lookup_name in (
"hour",
"minute",
"second",
):
raise ValueError(
"Cannot extract time component '%s' from DateField '%s'."
% (copy.lookup_name, field.name)
)
if isinstance(field, DurationField) and copy.lookup_name in (
"year",
"iso_year",
"month",
"week",
"week_day",
"iso_week_day",
"quarter",
):
raise ValueError(
"Cannot extract component '%s' from DurationField '%s'."
% (copy.lookup_name, field.name)
)
return copy
class ExtractYear(Extract):
lookup_name = "year"
class ExtractIsoYear(Extract):
"""Return the ISO-8601 week-numbering year."""
lookup_name = "iso_year"
class ExtractMonth(Extract):
lookup_name = "month"
class ExtractDay(Extract):
lookup_name = "day"
class ExtractWeek(Extract):
"""
Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
week.
"""
lookup_name = "week"
class ExtractWeekDay(Extract):
"""
Return Sunday=1 through Saturday=7.
To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
"""
lookup_name = "week_day"
class ExtractIsoWeekDay(Extract):
"""Return Monday=1 through Sunday=7, based on ISO-8601."""
lookup_name = "iso_week_day"
class ExtractQuarter(Extract):
lookup_name = "quarter"
class ExtractHour(Extract):
lookup_name = "hour"
class ExtractMinute(Extract):
lookup_name = "minute"
class ExtractSecond(Extract):
lookup_name = "second"
DateField.register_lookup(ExtractYear)
DateField.register_lookup(ExtractMonth)
DateField.register_lookup(ExtractDay)
DateField.register_lookup(ExtractWeekDay)
DateField.register_lookup(ExtractIsoWeekDay)
DateField.register_lookup(ExtractWeek)
DateField.register_lookup(ExtractIsoYear)
DateField.register_lookup(ExtractQuarter)
TimeField.register_lookup(ExtractHour)
TimeField.register_lookup(ExtractMinute)
TimeField.register_lookup(ExtractSecond)
DateTimeField.register_lookup(ExtractHour)
DateTimeField.register_lookup(ExtractMinute)
DateTimeField.register_lookup(ExtractSecond)
ExtractYear.register_lookup(YearExact)
ExtractYear.register_lookup(YearGt)
ExtractYear.register_lookup(YearGte)
ExtractYear.register_lookup(YearLt)
ExtractYear.register_lookup(YearLte)
ExtractIsoYear.register_lookup(YearExact)
ExtractIsoYear.register_lookup(YearGt)
ExtractIsoYear.register_lookup(YearGte)
ExtractIsoYear.register_lookup(YearLt)
ExtractIsoYear.register_lookup(YearLte)
class Now(Func):
template = "CURRENT_TIMESTAMP"
output_field = DateTimeField()
def as_postgresql(self, compiler, connection, **extra_context):
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
# other databases.
return self.as_sql(
compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
)
def as_mysql(self, compiler, connection, **extra_context):
return self.as_sql(
compiler, connection, template="CURRENT_TIMESTAMP(6)", **extra_context
)
def as_sqlite(self, compiler, connection, **extra_context):
return self.as_sql(
compiler,
connection,
template="STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')",
**extra_context,
)
class TruncBase(TimezoneMixin, Transform):
kind = None
tzinfo = None
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
def __init__(
self,
expression,
output_field=None,
tzinfo=None,
is_dst=timezone.NOT_PASSED,
**extra,
):
self.tzinfo = tzinfo
self.is_dst = is_dst
super().__init__(expression, output_field=output_field, **extra)
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
tzname = None
if isinstance(self.lhs.output_field, DateTimeField):
tzname = self.get_tzname()
elif self.tzinfo is not None:
raise ValueError("tzinfo can only be used with DateTimeField.")
if isinstance(self.output_field, DateTimeField):
sql, params = connection.ops.datetime_trunc_sql(
self.kind, sql, tuple(params), tzname
)
elif isinstance(self.output_field, DateField):
sql, params = connection.ops.date_trunc_sql(
self.kind, sql, tuple(params), tzname
)
elif isinstance(self.output_field, TimeField):
sql, params = connection.ops.time_trunc_sql(
self.kind, sql, tuple(params), tzname
)
else:
raise ValueError(
"Trunc only valid on DateField, TimeField, or DateTimeField."
)
return sql, params
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
copy = super().resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
field = copy.lhs.output_field
# DateTimeField is a subclass of DateField so this works for both.
if not isinstance(field, (DateField, TimeField)):
raise TypeError(
"%r isn't a DateField, TimeField, or DateTimeField." % field.name
)
# If self.output_field was None, then accessing the field will trigger
# the resolver to assign it to self.lhs.output_field.
if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
raise ValueError(
"output_field must be either DateField, TimeField, or DateTimeField"
)
# Passing dates or times to functions expecting datetimes is most
# likely a mistake.
class_output_field = (
self.__class__.output_field
if isinstance(self.__class__.output_field, Field)
else None
)
output_field = class_output_field or copy.output_field
has_explicit_output_field = (
class_output_field or field.__class__ is not copy.output_field.__class__
)
if type(field) is DateField and (
isinstance(output_field, DateTimeField)
or copy.kind in ("hour", "minute", "second", "time")
):
raise ValueError(
"Cannot truncate DateField '%s' to %s."
% (
field.name,
output_field.__class__.__name__
if has_explicit_output_field
else "DateTimeField",
)
)
elif isinstance(field, TimeField) and (
isinstance(output_field, DateTimeField)
or copy.kind in ("year", "quarter", "month", "week", "day", "date")
):
raise ValueError(
"Cannot truncate TimeField '%s' to %s."
% (
field.name,
output_field.__class__.__name__
if has_explicit_output_field
else "DateTimeField",
)
)
return copy
def convert_value(self, value, expression, connection):
if isinstance(self.output_field, DateTimeField):
if not settings.USE_TZ:
pass
elif value is not None:
value = value.replace(tzinfo=None)
value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
elif not connection.features.has_zoneinfo_database:
raise ValueError(
"Database returned an invalid datetime value. Are time "
"zone definitions for your database installed?"
)
elif isinstance(value, datetime):
if value is None:
pass
elif isinstance(self.output_field, DateField):
value = value.date()
elif isinstance(self.output_field, TimeField):
value = value.time()
return value
class Trunc(TruncBase):
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
def __init__(
self,
expression,
kind,
output_field=None,
tzinfo=None,
is_dst=timezone.NOT_PASSED,
**extra,
):
self.kind = kind
super().__init__(
expression, output_field=output_field, tzinfo=tzinfo, is_dst=is_dst, **extra
)
class TruncYear(TruncBase):
kind = "year"
class TruncQuarter(TruncBase):
kind = "quarter"
class TruncMonth(TruncBase):
kind = "month"
class TruncWeek(TruncBase):
"""Truncate to midnight on the Monday of the week."""
kind = "week"
class TruncDay(TruncBase):
kind = "day"
class TruncDate(TruncBase):
kind = "date"
lookup_name = "date"
output_field = DateField()
def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date.
sql, params = compiler.compile(self.lhs)
tzname = self.get_tzname()
return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
class TruncTime(TruncBase):
kind = "time"
lookup_name = "time"
output_field = TimeField()
def as_sql(self, compiler, connection):
# Cast to time rather than truncate to time.
sql, params = compiler.compile(self.lhs)
tzname = self.get_tzname()
return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
class TruncHour(TruncBase):
kind = "hour"
class TruncMinute(TruncBase):
kind = "minute"
class TruncSecond(TruncBase):
kind = "second"
DateTimeField.register_lookup(TruncDate)
DateTimeField.register_lookup(TruncTime)

View File

@ -0,0 +1,212 @@
import math
from django.db.models.expressions import Func, Value
from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions import Cast
from django.db.models.functions.mixins import (
FixDecimalInputMixin,
NumericOutputFieldMixin,
)
from django.db.models.lookups import Transform
class Abs(Transform):
function = "ABS"
lookup_name = "abs"
class ACos(NumericOutputFieldMixin, Transform):
function = "ACOS"
lookup_name = "acos"
class ASin(NumericOutputFieldMixin, Transform):
function = "ASIN"
lookup_name = "asin"
class ATan(NumericOutputFieldMixin, Transform):
function = "ATAN"
lookup_name = "atan"
class ATan2(NumericOutputFieldMixin, Func):
function = "ATAN2"
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(
connection.ops, "spatialite", False
) or connection.ops.spatial_version >= (5, 0, 0):
return self.as_sql(compiler, connection)
# This function is usually ATan2(y, x), returning the inverse tangent
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
# Cast integers to float to avoid inconsistent/buggy behavior if the
# arguments are mixed between integer and float or decimal.
# https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
clone = self.copy()
clone.set_source_expressions(
[
Cast(expression, FloatField())
if isinstance(expression.output_field, IntegerField)
else expression
for expression in self.get_source_expressions()[::-1]
]
)
return clone.as_sql(compiler, connection, **extra_context)
class Ceil(Transform):
function = "CEILING"
lookup_name = "ceil"
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="CEIL", **extra_context)
class Cos(NumericOutputFieldMixin, Transform):
function = "COS"
lookup_name = "cos"
class Cot(NumericOutputFieldMixin, Transform):
function = "COT"
lookup_name = "cot"
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, template="(1 / TAN(%(expressions)s))", **extra_context
)
class Degrees(NumericOutputFieldMixin, Transform):
function = "DEGREES"
lookup_name = "degrees"
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template="((%%(expressions)s) * 180 / %s)" % math.pi,
**extra_context,
)
class Exp(NumericOutputFieldMixin, Transform):
function = "EXP"
lookup_name = "exp"
class Floor(Transform):
function = "FLOOR"
lookup_name = "floor"
class Ln(NumericOutputFieldMixin, Transform):
function = "LN"
lookup_name = "ln"
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = "LOG"
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(connection.ops, "spatialite", False):
return self.as_sql(compiler, connection)
# This function is usually Log(b, x) returning the logarithm of x to
# the base b, but on SpatiaLite it's Log(x, b).
clone = self.copy()
clone.set_source_expressions(self.get_source_expressions()[::-1])
return clone.as_sql(compiler, connection, **extra_context)
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = "MOD"
arity = 2
class Pi(NumericOutputFieldMixin, Func):
function = "PI"
arity = 0
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, template=str(math.pi), **extra_context
)
class Power(NumericOutputFieldMixin, Func):
function = "POWER"
arity = 2
class Radians(NumericOutputFieldMixin, Transform):
function = "RADIANS"
lookup_name = "radians"
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template="((%%(expressions)s) * %s / 180)" % math.pi,
**extra_context,
)
class Random(NumericOutputFieldMixin, Func):
function = "RANDOM"
arity = 0
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="RAND", **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, function="DBMS_RANDOM.VALUE", **extra_context
)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="RAND", **extra_context)
def get_group_by_cols(self):
return []
class Round(FixDecimalInputMixin, Transform):
function = "ROUND"
lookup_name = "round"
arity = None # Override Transform's arity=1 to enable passing precision.
def __init__(self, expression, precision=0, **extra):
super().__init__(expression, precision, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
precision = self.get_source_expressions()[1]
if isinstance(precision, Value) and precision.value < 0:
raise ValueError("SQLite does not support negative precision.")
return super().as_sqlite(compiler, connection, **extra_context)
def _resolve_output_field(self):
source = self.get_source_expressions()[0]
return source.output_field
class Sign(Transform):
function = "SIGN"
lookup_name = "sign"
class Sin(NumericOutputFieldMixin, Transform):
function = "SIN"
lookup_name = "sin"
class Sqrt(NumericOutputFieldMixin, Transform):
function = "SQRT"
lookup_name = "sqrt"
class Tan(NumericOutputFieldMixin, Transform):
function = "TAN"
lookup_name = "tan"

View File

@ -0,0 +1,57 @@
import sys
from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.functions import Cast
class FixDecimalInputMixin:
def as_postgresql(self, compiler, connection, **extra_context):
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
# following function signatures:
# - LOG(double, double)
# - MOD(double, double)
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
clone = self.copy()
clone.set_source_expressions(
[
Cast(expression, output_field)
if isinstance(expression.output_field, FloatField)
else expression
for expression in self.get_source_expressions()
]
)
return clone.as_sql(compiler, connection, **extra_context)
class FixDurationInputMixin:
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == "DurationField":
sql = "CAST(%s AS SIGNED)" % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == "DurationField":
expression = self.get_source_expressions()[0]
options = self._get_repr_options()
from django.db.backends.oracle.functions import (
IntervalToSeconds,
SecondsToInterval,
)
return compiler.compile(
SecondsToInterval(
self.__class__(IntervalToSeconds(expression), **options)
)
)
return super().as_sql(compiler, connection, **extra_context)
class NumericOutputFieldMixin:
def _resolve_output_field(self):
source_fields = self.get_source_fields()
if any(isinstance(s, DecimalField) for s in source_fields):
return DecimalField()
if any(isinstance(s, IntegerField) for s in source_fields):
return FloatField()
return super()._resolve_output_field() if source_fields else FloatField()

View File

@ -0,0 +1,365 @@
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields import CharField, IntegerField, TextField
from django.db.models.functions import Cast, Coalesce
from django.db.models.lookups import Transform
class MySQLSHA2Mixin:
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template="SHA2(%%(expressions)s, %s)" % self.function[3:],
**extra_context,
)
class OracleHashMixin:
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template=(
"LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
"%(expressions)s, 'AL32UTF8'), '%(function)s')))"
),
**extra_context,
)
class PostgreSQLSHAMixin:
def as_postgresql(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
function=self.function.lower(),
**extra_context,
)
class Chr(Transform):
function = "CHR"
lookup_name = "chr"
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
function="CHAR",
template="%(function)s(%(expressions)s USING utf16)",
**extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template="%(function)s(%(expressions)s USING NCHAR_CS)",
**extra_context,
)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="CHAR", **extra_context)
class ConcatPair(Func):
"""
Concatenate two arguments together. This is used by `Concat` because not
all backend databases support more than two arguments.
"""
function = "CONCAT"
def as_sqlite(self, compiler, connection, **extra_context):
coalesced = self.coalesce()
return super(ConcatPair, coalesced).as_sql(
compiler,
connection,
template="%(expressions)s",
arg_joiner=" || ",
**extra_context,
)
def as_postgresql(self, compiler, connection, **extra_context):
copy = self.copy()
copy.set_source_expressions(
[
Cast(expression, TextField())
for expression in copy.get_source_expressions()
]
)
return super(ConcatPair, copy).as_sql(
compiler,
connection,
**extra_context,
)
def as_mysql(self, compiler, connection, **extra_context):
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
return super().as_sql(
compiler,
connection,
function="CONCAT_WS",
template="%(function)s('', %(expressions)s)",
**extra_context,
)
def coalesce(self):
# null on either side results in null for expression, wrap with coalesce
c = self.copy()
c.set_source_expressions(
[
Coalesce(expression, Value(""))
for expression in c.get_source_expressions()
]
)
return c
class Concat(Func):
"""
Concatenate text fields together. Backends that result in an entire
null expression when any arguments are null will wrap each argument in
coalesce functions to ensure a non-null result.
"""
function = None
template = "%(expressions)s"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Concat must take at least two expressions")
paired = self._paired(expressions)
super().__init__(paired, **extra)
def _paired(self, expressions):
# wrap pairs of expressions in successive concat functions
# exp = [a, b, c, d]
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
if len(expressions) == 2:
return ConcatPair(*expressions)
return ConcatPair(expressions[0], self._paired(expressions[1:]))
class Left(Func):
function = "LEFT"
arity = 2
output_field = CharField()
def __init__(self, expression, length, **extra):
"""
expression: the name of a field, or an expression returning a string
length: the number of characters to return from the start of the string
"""
if not hasattr(length, "resolve_expression"):
if length < 1:
raise ValueError("'length' must be greater than 0.")
super().__init__(expression, length, **extra)
def get_substr(self):
return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])
def as_oracle(self, compiler, connection, **extra_context):
return self.get_substr().as_oracle(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return self.get_substr().as_sqlite(compiler, connection, **extra_context)
class Length(Transform):
"""Return the number of characters in the expression."""
function = "LENGTH"
lookup_name = "length"
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, function="CHAR_LENGTH", **extra_context
)
class Lower(Transform):
function = "LOWER"
lookup_name = "lower"
class LPad(Func):
function = "LPAD"
output_field = CharField()
def __init__(self, expression, length, fill_text=Value(" "), **extra):
if (
not hasattr(length, "resolve_expression")
and length is not None
and length < 0
):
raise ValueError("'length' must be greater or equal to 0.")
super().__init__(expression, length, fill_text, **extra)
class LTrim(Transform):
function = "LTRIM"
lookup_name = "ltrim"
class MD5(OracleHashMixin, Transform):
function = "MD5"
lookup_name = "md5"
class Ord(Transform):
function = "ASCII"
lookup_name = "ord"
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="ORD", **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="UNICODE", **extra_context)
class Repeat(Func):
function = "REPEAT"
output_field = CharField()
def __init__(self, expression, number, **extra):
if (
not hasattr(number, "resolve_expression")
and number is not None
and number < 0
):
raise ValueError("'number' must be greater or equal to 0.")
super().__init__(expression, number, **extra)
def as_oracle(self, compiler, connection, **extra_context):
expression, number = self.source_expressions
length = None if number is None else Length(expression) * number
rpad = RPad(expression, length, expression)
return rpad.as_sql(compiler, connection, **extra_context)
class Replace(Func):
function = "REPLACE"
def __init__(self, expression, text, replacement=Value(""), **extra):
super().__init__(expression, text, replacement, **extra)
class Reverse(Transform):
function = "REVERSE"
lookup_name = "reverse"
def as_oracle(self, compiler, connection, **extra_context):
# REVERSE in Oracle is undocumented and doesn't support multi-byte
# strings. Use a special subquery instead.
return super().as_sql(
compiler,
connection,
template=(
"(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
"(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s "
"FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
"GROUP BY %(expressions)s)"
),
**extra_context,
)
class Right(Left):
function = "RIGHT"
def get_substr(self):
return Substr(
self.source_expressions[0], self.source_expressions[1] * Value(-1)
)
class RPad(LPad):
function = "RPAD"
class RTrim(Transform):
function = "RTRIM"
lookup_name = "rtrim"
class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA1"
lookup_name = "sha1"
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
function = "SHA224"
lookup_name = "sha224"
def as_oracle(self, compiler, connection, **extra_context):
raise NotSupportedError("SHA224 is not supported on Oracle.")
class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA256"
lookup_name = "sha256"
class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA384"
lookup_name = "sha384"
class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA512"
lookup_name = "sha512"
class StrIndex(Func):
"""
Return a positive integer corresponding to the 1-indexed position of the
first occurrence of a substring inside another string, or 0 if the
substring is not found.
"""
function = "INSTR"
arity = 2
output_field = IntegerField()
def as_postgresql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="STRPOS", **extra_context)
class Substr(Func):
function = "SUBSTRING"
output_field = CharField()
def __init__(self, expression, pos, length=None, **extra):
"""
expression: the name of a field, or an expression returning a string
pos: an integer > 0, or an expression returning an integer
length: an optional number of characters to return
"""
if not hasattr(pos, "resolve_expression"):
if pos < 1:
raise ValueError("'pos' must be greater than 0")
expressions = [expression, pos]
if length is not None:
expressions.append(length)
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
class Trim(Transform):
function = "TRIM"
lookup_name = "trim"
class Upper(Transform):
function = "UPPER"
lookup_name = "upper"

View File

@ -0,0 +1,120 @@
from django.db.models.expressions import Func
from django.db.models.fields import FloatField, IntegerField
__all__ = [
"CumeDist",
"DenseRank",
"FirstValue",
"Lag",
"LastValue",
"Lead",
"NthValue",
"Ntile",
"PercentRank",
"Rank",
"RowNumber",
]
class CumeDist(Func):
function = "CUME_DIST"
output_field = FloatField()
window_compatible = True
class DenseRank(Func):
function = "DENSE_RANK"
output_field = IntegerField()
window_compatible = True
class FirstValue(Func):
arity = 1
function = "FIRST_VALUE"
window_compatible = True
class LagLeadFunction(Func):
window_compatible = True
def __init__(self, expression, offset=1, default=None, **extra):
if expression is None:
raise ValueError(
"%s requires a non-null source expression." % self.__class__.__name__
)
if offset is None or offset <= 0:
raise ValueError(
"%s requires a positive integer for the offset."
% self.__class__.__name__
)
args = (expression, offset)
if default is not None:
args += (default,)
super().__init__(*args, **extra)
def _resolve_output_field(self):
sources = self.get_source_expressions()
return sources[0].output_field
class Lag(LagLeadFunction):
function = "LAG"
class LastValue(Func):
arity = 1
function = "LAST_VALUE"
window_compatible = True
class Lead(LagLeadFunction):
function = "LEAD"
class NthValue(Func):
function = "NTH_VALUE"
window_compatible = True
def __init__(self, expression, nth=1, **extra):
if expression is None:
raise ValueError(
"%s requires a non-null source expression." % self.__class__.__name__
)
if nth is None or nth <= 0:
raise ValueError(
"%s requires a positive integer as for nth." % self.__class__.__name__
)
super().__init__(expression, nth, **extra)
def _resolve_output_field(self):
sources = self.get_source_expressions()
return sources[0].output_field
class Ntile(Func):
function = "NTILE"
output_field = IntegerField()
window_compatible = True
def __init__(self, num_buckets=1, **extra):
if num_buckets <= 0:
raise ValueError("num_buckets must be greater than 0.")
super().__init__(num_buckets, **extra)
class PercentRank(Func):
function = "PERCENT_RANK"
output_field = FloatField()
window_compatible = True
class Rank(Func):
function = "RANK"
output_field = IntegerField()
window_compatible = True
class RowNumber(Func):
function = "ROW_NUMBER"
output_field = IntegerField()
window_compatible = True

View File

@ -0,0 +1,295 @@
from django.db.backends.utils import names_digest, split_identifier
from django.db.models.expressions import Col, ExpressionList, F, Func, OrderBy
from django.db.models.functions import Collate
from django.db.models.query_utils import Q
from django.db.models.sql import Query
from django.utils.functional import partition
__all__ = ["Index"]
class Index:
suffix = "idx"
# The max length of the name of the index (restricted to 30 for
# cross-database compatibility with Oracle)
max_name_length = 30
def __init__(
self,
*expressions,
fields=(),
name=None,
db_tablespace=None,
opclasses=(),
condition=None,
include=None,
):
if opclasses and not name:
raise ValueError("An index must be named to use opclasses.")
if not isinstance(condition, (type(None), Q)):
raise ValueError("Index.condition must be a Q instance.")
if condition and not name:
raise ValueError("An index must be named to use condition.")
if not isinstance(fields, (list, tuple)):
raise ValueError("Index.fields must be a list or tuple.")
if not isinstance(opclasses, (list, tuple)):
raise ValueError("Index.opclasses must be a list or tuple.")
if not expressions and not fields:
raise ValueError(
"At least one field or expression is required to define an index."
)
if expressions and fields:
raise ValueError(
"Index.fields and expressions are mutually exclusive.",
)
if expressions and not name:
raise ValueError("An index must be named to use expressions.")
if expressions and opclasses:
raise ValueError(
"Index.opclasses cannot be used with expressions. Use "
"django.contrib.postgres.indexes.OpClass() instead."
)
if opclasses and len(fields) != len(opclasses):
raise ValueError(
"Index.fields and Index.opclasses must have the same number of "
"elements."
)
if fields and not all(isinstance(field, str) for field in fields):
raise ValueError("Index.fields must contain only strings with field names.")
if include and not name:
raise ValueError("A covering index must be named.")
if not isinstance(include, (type(None), list, tuple)):
raise ValueError("Index.include must be a list or tuple.")
self.fields = list(fields)
# A list of 2-tuple with the field name and ordering ('' or 'DESC').
self.fields_orders = [
(field_name[1:], "DESC") if field_name.startswith("-") else (field_name, "")
for field_name in self.fields
]
self.name = name or ""
self.db_tablespace = db_tablespace
self.opclasses = opclasses
self.condition = condition
self.include = tuple(include) if include else ()
self.expressions = tuple(
F(expression) if isinstance(expression, str) else expression
for expression in expressions
)
@property
def contains_expressions(self):
return bool(self.expressions)
def _get_condition_sql(self, model, schema_editor):
if self.condition is None:
return None
query = Query(model=model, alias_cols=False)
where = query.build_where(self.condition)
compiler = query.get_compiler(connection=schema_editor.connection)
sql, params = where.as_sql(compiler, schema_editor.connection)
return sql % tuple(schema_editor.quote_value(p) for p in params)
def create_sql(self, model, schema_editor, using="", **kwargs):
include = [
model._meta.get_field(field_name).column for field_name in self.include
]
condition = self._get_condition_sql(model, schema_editor)
if self.expressions:
index_expressions = []
for expression in self.expressions:
index_expression = IndexExpression(expression)
index_expression.set_wrapper_classes(schema_editor.connection)
index_expressions.append(index_expression)
expressions = ExpressionList(*index_expressions).resolve_expression(
Query(model, alias_cols=False),
)
fields = None
col_suffixes = None
else:
fields = [
model._meta.get_field(field_name)
for field_name, _ in self.fields_orders
]
if schema_editor.connection.features.supports_index_column_ordering:
col_suffixes = [order[1] for order in self.fields_orders]
else:
col_suffixes = [""] * len(self.fields_orders)
expressions = None
return schema_editor._create_index_sql(
model,
fields=fields,
name=self.name,
using=using,
db_tablespace=self.db_tablespace,
col_suffixes=col_suffixes,
opclasses=self.opclasses,
condition=condition,
include=include,
expressions=expressions,
**kwargs,
)
def remove_sql(self, model, schema_editor, **kwargs):
return schema_editor._delete_index_sql(model, self.name, **kwargs)
def deconstruct(self):
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
path = path.replace("django.db.models.indexes", "django.db.models")
kwargs = {"name": self.name}
if self.fields:
kwargs["fields"] = self.fields
if self.db_tablespace is not None:
kwargs["db_tablespace"] = self.db_tablespace
if self.opclasses:
kwargs["opclasses"] = self.opclasses
if self.condition:
kwargs["condition"] = self.condition
if self.include:
kwargs["include"] = self.include
return (path, self.expressions, kwargs)
def clone(self):
"""Create a copy of this Index."""
_, args, kwargs = self.deconstruct()
return self.__class__(*args, **kwargs)
def set_name_with_model(self, model):
"""
Generate a unique name for the index.
The name is divided into 3 parts - table name (12 chars), field name
(8 chars) and unique hash + suffix (10 chars). Each part is made to
fit its size by truncating the excess length.
"""
_, table_name = split_identifier(model._meta.db_table)
column_names = [
model._meta.get_field(field_name).column
for field_name, order in self.fields_orders
]
column_names_with_order = [
(("-%s" if order else "%s") % column_name)
for column_name, (field_name, order) in zip(
column_names, self.fields_orders
)
]
# The length of the parts of the name is based on the default max
# length of 30 characters.
hash_data = [table_name] + column_names_with_order + [self.suffix]
self.name = "%s_%s_%s" % (
table_name[:11],
column_names[0][:7],
"%s_%s" % (names_digest(*hash_data, length=6), self.suffix),
)
if len(self.name) > self.max_name_length:
raise ValueError(
"Index too long for multiple database support. Is self.suffix "
"longer than 3 characters?"
)
if self.name[0] == "_" or self.name[0].isdigit():
self.name = "D%s" % self.name[1:]
def __repr__(self):
return "<%s:%s%s%s%s%s%s%s>" % (
self.__class__.__qualname__,
"" if not self.fields else " fields=%s" % repr(self.fields),
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
"" if not self.name else " name=%s" % repr(self.name),
""
if self.db_tablespace is None
else " db_tablespace=%s" % repr(self.db_tablespace),
"" if self.condition is None else " condition=%s" % self.condition,
"" if not self.include else " include=%s" % repr(self.include),
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
)
def __eq__(self, other):
if self.__class__ == other.__class__:
return self.deconstruct() == other.deconstruct()
return NotImplemented
class IndexExpression(Func):
"""Order and wrap expressions for CREATE INDEX statements."""
template = "%(expressions)s"
wrapper_classes = (OrderBy, Collate)
def set_wrapper_classes(self, connection=None):
# Some databases (e.g. MySQL) treats COLLATE as an indexed expression.
if connection and connection.features.collate_as_index_expression:
self.wrapper_classes = tuple(
[
wrapper_cls
for wrapper_cls in self.wrapper_classes
if wrapper_cls is not Collate
]
)
@classmethod
def register_wrappers(cls, *wrapper_classes):
cls.wrapper_classes = wrapper_classes
def resolve_expression(
self,
query=None,
allow_joins=True,
reuse=None,
summarize=False,
for_save=False,
):
expressions = list(self.flatten())
# Split expressions and wrappers.
index_expressions, wrappers = partition(
lambda e: isinstance(e, self.wrapper_classes),
expressions,
)
wrapper_types = [type(wrapper) for wrapper in wrappers]
if len(wrapper_types) != len(set(wrapper_types)):
raise ValueError(
"Multiple references to %s can't be used in an indexed "
"expression."
% ", ".join(
[wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
)
)
if expressions[1 : len(wrappers) + 1] != wrappers:
raise ValueError(
"%s must be topmost expressions in an indexed expression."
% ", ".join(
[wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
)
)
# Wrap expressions in parentheses if they are not column references.
root_expression = index_expressions[1]
resolve_root_expression = root_expression.resolve_expression(
query,
allow_joins,
reuse,
summarize,
for_save,
)
if not isinstance(resolve_root_expression, Col):
root_expression = Func(root_expression, template="(%(expressions)s)")
if wrappers:
# Order wrappers and set their expressions.
wrappers = sorted(
wrappers,
key=lambda w: self.wrapper_classes.index(type(w)),
)
wrappers = [wrapper.copy() for wrapper in wrappers]
for i, wrapper in enumerate(wrappers[:-1]):
wrapper.set_source_expressions([wrappers[i + 1]])
# Set the root expression on the deepest wrapper.
wrappers[-1].set_source_expressions([root_expression])
self.set_source_expressions([wrappers[0]])
else:
# Use the root expression, if there are no wrappers.
self.set_source_expressions([root_expression])
return super().resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
def as_sqlite(self, compiler, connection, **extra_context):
# Casting to numeric is unnecessary.
return self.as_sql(compiler, connection, **extra_context)

View File

@ -0,0 +1,727 @@
import itertools
import math
from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db.models.expressions import Case, Expression, Func, Value, When
from django.db.models.fields import (
BooleanField,
CharField,
DateTimeField,
Field,
IntegerField,
UUIDField,
)
from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet
from django.utils.functional import cached_property
from django.utils.hashable import make_hashable
class Lookup(Expression):
lookup_name = None
prepare_rhs = True
can_use_none_as_rhs = False
def __init__(self, lhs, rhs):
self.lhs, self.rhs = lhs, rhs
self.rhs = self.get_prep_lookup()
self.lhs = self.get_prep_lhs()
if hasattr(self.lhs, "get_bilateral_transforms"):
bilateral_transforms = self.lhs.get_bilateral_transforms()
else:
bilateral_transforms = []
if bilateral_transforms:
# Warn the user as soon as possible if they are trying to apply
# a bilateral transformation on a nested QuerySet: that won't work.
from django.db.models.sql.query import Query # avoid circular import
if isinstance(rhs, Query):
raise NotImplementedError(
"Bilateral transformations on nested querysets are not implemented."
)
self.bilateral_transforms = bilateral_transforms
def apply_bilateral_transforms(self, value):
for transform in self.bilateral_transforms:
value = transform(value)
return value
def __repr__(self):
return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"
def batch_process_rhs(self, compiler, connection, rhs=None):
if rhs is None:
rhs = self.rhs
if self.bilateral_transforms:
sqls, sqls_params = [], []
for p in rhs:
value = Value(p, output_field=self.lhs.output_field)
value = self.apply_bilateral_transforms(value)
value = value.resolve_expression(compiler.query)
sql, sql_params = compiler.compile(value)
sqls.append(sql)
sqls_params.extend(sql_params)
else:
_, params = self.get_db_prep_lookup(rhs, connection)
sqls, sqls_params = ["%s"] * len(params), params
return sqls, sqls_params
def get_source_expressions(self):
if self.rhs_is_direct_value():
return [self.lhs]
return [self.lhs, self.rhs]
def set_source_expressions(self, new_exprs):
if len(new_exprs) == 1:
self.lhs = new_exprs[0]
else:
self.lhs, self.rhs = new_exprs
def get_prep_lookup(self):
if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
return self.rhs
if hasattr(self.lhs, "output_field"):
if hasattr(self.lhs.output_field, "get_prep_value"):
return self.lhs.output_field.get_prep_value(self.rhs)
elif self.rhs_is_direct_value():
return Value(self.rhs)
return self.rhs
def get_prep_lhs(self):
if hasattr(self.lhs, "resolve_expression"):
return self.lhs
return Value(self.lhs)
def get_db_prep_lookup(self, value, connection):
return ("%s", [value])
def process_lhs(self, compiler, connection, lhs=None):
lhs = lhs or self.lhs
if hasattr(lhs, "resolve_expression"):
lhs = lhs.resolve_expression(compiler.query)
sql, params = compiler.compile(lhs)
if isinstance(lhs, Lookup):
# Wrapped in parentheses to respect operator precedence.
sql = f"({sql})"
return sql, params
def process_rhs(self, compiler, connection):
value = self.rhs
if self.bilateral_transforms:
if self.rhs_is_direct_value():
# Do not call get_db_prep_lookup here as the value will be
# transformed before being used for lookup
value = Value(value, output_field=self.lhs.output_field)
value = self.apply_bilateral_transforms(value)
value = value.resolve_expression(compiler.query)
if hasattr(value, "as_sql"):
sql, params = compiler.compile(value)
# Ensure expression is wrapped in parentheses to respect operator
# precedence but avoid double wrapping as it can be misinterpreted
# on some backends (e.g. subqueries on SQLite).
if sql and sql[0] != "(":
sql = "(%s)" % sql
return sql, params
else:
return self.get_db_prep_lookup(value, connection)
def rhs_is_direct_value(self):
return not hasattr(self.rhs, "as_sql")
def get_group_by_cols(self):
cols = []
for source in self.get_source_expressions():
cols.extend(source.get_group_by_cols())
return cols
def as_oracle(self, compiler, connection):
# Oracle doesn't allow EXISTS() and filters to be compared to another
# expression unless they're wrapped in a CASE WHEN.
wrapped = False
exprs = []
for expr in (self.lhs, self.rhs):
if connection.ops.conditional_expression_supported_in_where_clause(expr):
expr = Case(When(expr, then=True), default=False)
wrapped = True
exprs.append(expr)
lookup = type(self)(*exprs) if wrapped else self
return lookup.as_sql(compiler, connection)
@cached_property
def output_field(self):
return BooleanField()
@property
def identity(self):
return self.__class__, self.lhs, self.rhs
def __eq__(self, other):
if not isinstance(other, Lookup):
return NotImplemented
return self.identity == other.identity
def __hash__(self):
return hash(make_hashable(self.identity))
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
c = self.copy()
c.is_summary = summarize
c.lhs = self.lhs.resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
if hasattr(self.rhs, "resolve_expression"):
c.rhs = self.rhs.resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
return c
def select_format(self, compiler, sql, params):
# Wrap filters with a CASE WHEN expression if a database backend
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
# BY list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
return sql, params
class Transform(RegisterLookupMixin, Func):
"""
RegisterLookupMixin() is first so that get_lookup() and get_transform()
first examine self and then check output_field.
"""
bilateral = False
arity = 1
@property
def lhs(self):
return self.get_source_expressions()[0]
def get_bilateral_transforms(self):
if hasattr(self.lhs, "get_bilateral_transforms"):
bilateral_transforms = self.lhs.get_bilateral_transforms()
else:
bilateral_transforms = []
if self.bilateral:
bilateral_transforms.append(self.__class__)
return bilateral_transforms
class BuiltinLookup(Lookup):
def process_lhs(self, compiler, connection, lhs=None):
lhs_sql, params = super().process_lhs(compiler, connection, lhs)
field_internal_type = self.lhs.output_field.get_internal_type()
db_type = self.lhs.output_field.db_type(connection=connection)
lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
lhs_sql = (
connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
)
return lhs_sql, list(params)
def as_sql(self, compiler, connection):
lhs_sql, params = self.process_lhs(compiler, connection)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
rhs_sql = self.get_rhs_op(connection, rhs_sql)
return "%s %s" % (lhs_sql, rhs_sql), params
def get_rhs_op(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs
class FieldGetDbPrepValueMixin:
"""
Some lookups require Field.get_db_prep_value() to be called on their
inputs.
"""
get_db_prep_lookup_value_is_iterable = False
def get_db_prep_lookup(self, value, connection):
# For relational fields, use the 'target_field' attribute of the
# output_field.
field = getattr(self.lhs.output_field, "target_field", None)
get_db_prep_value = (
getattr(field, "get_db_prep_value", None)
or self.lhs.output_field.get_db_prep_value
)
return (
"%s",
[get_db_prep_value(v, connection, prepared=True) for v in value]
if self.get_db_prep_lookup_value_is_iterable
else [get_db_prep_value(value, connection, prepared=True)],
)
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
"""
Some lookups require Field.get_db_prep_value() to be called on each value
in an iterable.
"""
get_db_prep_lookup_value_is_iterable = True
def get_prep_lookup(self):
if hasattr(self.rhs, "resolve_expression"):
return self.rhs
prepared_values = []
for rhs_value in self.rhs:
if hasattr(rhs_value, "resolve_expression"):
# An expression will be handled by the database but can coexist
# alongside real values.
pass
elif self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):
rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
prepared_values.append(rhs_value)
return prepared_values
def process_rhs(self, compiler, connection):
if self.rhs_is_direct_value():
# rhs should be an iterable of values. Use batch_process_rhs()
# to prepare/transform those values.
return self.batch_process_rhs(compiler, connection)
else:
return super().process_rhs(compiler, connection)
def resolve_expression_parameter(self, compiler, connection, sql, param):
params = [param]
if hasattr(param, "resolve_expression"):
param = param.resolve_expression(compiler.query)
if hasattr(param, "as_sql"):
sql, params = compiler.compile(param)
return sql, params
def batch_process_rhs(self, compiler, connection, rhs=None):
pre_processed = super().batch_process_rhs(compiler, connection, rhs)
# The params list may contain expressions which compile to a
# sql/param pair. Zip them to get sql and param pairs that refer to the
# same argument and attempt to replace them with the result of
# compiling the param step.
sql, params = zip(
*(
self.resolve_expression_parameter(compiler, connection, sql, param)
for sql, param in zip(*pre_processed)
)
)
params = itertools.chain.from_iterable(params)
return sql, tuple(params)
class PostgresOperatorLookup(Lookup):
"""Lookup defined by operators on PostgreSQL."""
postgres_operator = None
def as_postgresql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(lhs_params) + tuple(rhs_params)
return "%s %s %s" % (lhs, self.postgres_operator, rhs), params
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = "exact"
def get_prep_lookup(self):
from django.db.models.sql.query import Query # avoid circular import
if isinstance(self.rhs, Query):
if self.rhs.has_limit_one():
if not self.rhs.has_select_fields:
self.rhs.clear_select_clause()
self.rhs.add_fields(["pk"])
else:
raise ValueError(
"The QuerySet value for an exact lookup must be limited to "
"one result using slicing."
)
return super().get_prep_lookup()
def as_sql(self, compiler, connection):
# Avoid comparison against direct rhs if lhs is a boolean value. That
# turns "boolfield__exact=True" into "WHERE boolean_field" instead of
# "WHERE boolean_field = True" when allowed.
if (
isinstance(self.rhs, bool)
and getattr(self.lhs, "conditional", False)
and connection.ops.conditional_expression_supported_in_where_clause(
self.lhs
)
):
lhs_sql, params = self.process_lhs(compiler, connection)
template = "%s" if self.rhs else "NOT %s"
return template % lhs_sql, params
return super().as_sql(compiler, connection)
@Field.register_lookup
class IExact(BuiltinLookup):
lookup_name = "iexact"
prepare_rhs = False
def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection)
if params:
params[0] = connection.ops.prep_for_iexact_query(params[0])
return rhs, params
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = "gt"
@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = "gte"
@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = "lt"
@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = "lte"
class IntegerFieldFloatRounding:
"""
Allow floats to work as query values for IntegerField. Without this, the
decimal portion of the float would always be discarded.
"""
def get_prep_lookup(self):
if isinstance(self.rhs, float):
self.rhs = math.ceil(self.rhs)
return super().get_prep_lookup()
@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
pass
@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
pass
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
lookup_name = "in"
def get_prep_lookup(self):
from django.db.models.sql.query import Query # avoid circular import
if isinstance(self.rhs, Query):
self.rhs.clear_ordering(clear_default=True)
if not self.rhs.has_select_fields:
self.rhs.clear_select_clause()
self.rhs.add_fields(["pk"])
return super().get_prep_lookup()
def process_rhs(self, compiler, connection):
db_rhs = getattr(self.rhs, "_db", None)
if db_rhs is not None and db_rhs != connection.alias:
raise ValueError(
"Subqueries aren't allowed across different databases. Force "
"the inner query to be evaluated using `list(inner_query)`."
)
if self.rhs_is_direct_value():
# Remove None from the list as NULL is never equal to anything.
try:
rhs = OrderedSet(self.rhs)
rhs.discard(None)
except TypeError: # Unhashable items in self.rhs
rhs = [r for r in self.rhs if r is not None]
if not rhs:
raise EmptyResultSet
# rhs should be an iterable; use batch_process_rhs() to
# prepare/transform those values.
sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
placeholder = "(" + ", ".join(sqls) + ")"
return (placeholder, sqls_params)
return super().process_rhs(compiler, connection)
def get_rhs_op(self, connection, rhs):
return "IN %s" % rhs
def as_sql(self, compiler, connection):
max_in_list_size = connection.ops.max_in_list_size()
if (
self.rhs_is_direct_value()
and max_in_list_size
and len(self.rhs) > max_in_list_size
):
return self.split_parameter_list_as_sql(compiler, connection)
return super().as_sql(compiler, connection)
def split_parameter_list_as_sql(self, compiler, connection):
# This is a special case for databases which limit the number of
# elements which can appear in an 'IN' clause.
max_in_list_size = connection.ops.max_in_list_size()
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.batch_process_rhs(compiler, connection)
in_clause_elements = ["("]
params = []
for offset in range(0, len(rhs_params), max_in_list_size):
if offset > 0:
in_clause_elements.append(" OR ")
in_clause_elements.append("%s IN (" % lhs)
params.extend(lhs_params)
sqls = rhs[offset : offset + max_in_list_size]
sqls_params = rhs_params[offset : offset + max_in_list_size]
param_group = ", ".join(sqls)
in_clause_elements.append(param_group)
in_clause_elements.append(")")
params.extend(sqls_params)
in_clause_elements.append(")")
return "".join(in_clause_elements), params
class PatternLookup(BuiltinLookup):
param_pattern = "%%%s%%"
prepare_rhs = False
def get_rhs_op(self, connection, rhs):
# Assume we are in startswith. We need to produce SQL like:
# col LIKE %s, ['thevalue%']
# For python values we can (and should) do that directly in Python,
# but if the value is for example reference to other column, then
# we need to add the % pattern match to the lookup by something like
# col LIKE othercol || '%%'
# So, for Python values we don't need any special pattern, but for
# SQL reference values or SQL transformations we need the correct
# pattern added.
if hasattr(self.rhs, "as_sql") or self.bilateral_transforms:
pattern = connection.pattern_ops[self.lookup_name].format(
connection.pattern_esc
)
return pattern.format(rhs)
else:
return super().get_rhs_op(connection, rhs)
def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection)
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
params[0] = self.param_pattern % connection.ops.prep_for_like_query(
params[0]
)
return rhs, params
@Field.register_lookup
class Contains(PatternLookup):
lookup_name = "contains"
@Field.register_lookup
class IContains(Contains):
lookup_name = "icontains"
@Field.register_lookup
class StartsWith(PatternLookup):
lookup_name = "startswith"
param_pattern = "%s%%"
@Field.register_lookup
class IStartsWith(StartsWith):
lookup_name = "istartswith"
@Field.register_lookup
class EndsWith(PatternLookup):
lookup_name = "endswith"
param_pattern = "%%%s"
@Field.register_lookup
class IEndsWith(EndsWith):
lookup_name = "iendswith"
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
lookup_name = "range"
def get_rhs_op(self, connection, rhs):
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
@Field.register_lookup
class IsNull(BuiltinLookup):
lookup_name = "isnull"
prepare_rhs = False
def as_sql(self, compiler, connection):
if not isinstance(self.rhs, bool):
raise ValueError(
"The QuerySet value for an isnull lookup must be True or False."
)
if isinstance(self.lhs, Value):
if self.lhs.value is None or (
self.lhs.value == ""
and connection.features.interprets_empty_strings_as_nulls
):
result_exception = FullResultSet if self.rhs else EmptyResultSet
else:
result_exception = EmptyResultSet if self.rhs else FullResultSet
raise result_exception
sql, params = self.process_lhs(compiler, connection)
if self.rhs:
return "%s IS NULL" % sql, params
else:
return "%s IS NOT NULL" % sql, params
@Field.register_lookup
class Regex(BuiltinLookup):
lookup_name = "regex"
prepare_rhs = False
def as_sql(self, compiler, connection):
if self.lookup_name in connection.operators:
return super().as_sql(compiler, connection)
else:
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
sql_template = connection.ops.regex_lookup(self.lookup_name)
return sql_template % (lhs, rhs), lhs_params + rhs_params
@Field.register_lookup
class IRegex(Regex):
lookup_name = "iregex"
class YearLookup(Lookup):
def year_lookup_bounds(self, connection, year):
from django.db.models.functions import ExtractIsoYear
iso_year = isinstance(self.lhs, ExtractIsoYear)
output_field = self.lhs.lhs.output_field
if isinstance(output_field, DateTimeField):
bounds = connection.ops.year_lookup_bounds_for_datetime_field(
year,
iso_year=iso_year,
)
else:
bounds = connection.ops.year_lookup_bounds_for_date_field(
year,
iso_year=iso_year,
)
return bounds
def as_sql(self, compiler, connection):
# Avoid the extract operation if the rhs is a direct value to allow
# indexes to be used.
if self.rhs_is_direct_value():
# Skip the extract part by directly using the originating field,
# that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, _ = self.process_rhs(compiler, connection)
rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
start, finish = self.year_lookup_bounds(connection, self.rhs)
params.extend(self.get_bound_params(start, finish))
return "%s %s" % (lhs_sql, rhs_sql), params
return super().as_sql(compiler, connection)
def get_direct_rhs_sql(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs
def get_bound_params(self, start, finish):
raise NotImplementedError(
"subclasses of YearLookup must provide a get_bound_params() method"
)
class YearExact(YearLookup, Exact):
def get_direct_rhs_sql(self, connection, rhs):
return "BETWEEN %s AND %s"
def get_bound_params(self, start, finish):
return (start, finish)
class YearGt(YearLookup, GreaterThan):
def get_bound_params(self, start, finish):
return (finish,)
class YearGte(YearLookup, GreaterThanOrEqual):
def get_bound_params(self, start, finish):
return (start,)
class YearLt(YearLookup, LessThan):
def get_bound_params(self, start, finish):
return (start,)
class YearLte(YearLookup, LessThanOrEqual):
def get_bound_params(self, start, finish):
return (finish,)
class UUIDTextMixin:
"""
Strip hyphens from a value when filtering a UUIDField on backends without
a native datatype for UUID.
"""
def process_rhs(self, qn, connection):
if not connection.features.has_native_uuid_field:
from django.db.models.functions import Replace
if self.rhs_is_direct_value():
self.rhs = Value(self.rhs)
self.rhs = Replace(
self.rhs, Value("-"), Value(""), output_field=CharField()
)
rhs, params = super().process_rhs(qn, connection)
return rhs, params
@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
pass
@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
pass
@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
pass
@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
pass
@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
pass
@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
pass
@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
pass

View File

@ -0,0 +1,213 @@
import copy
import inspect
from functools import wraps
from importlib import import_module
from django.db import router
from django.db.models.query import QuerySet
class BaseManager:
# To retain order, track each time a Manager instance is created.
creation_counter = 0
# Set to True for the 'objects' managers that are automatically created.
auto_created = False
#: If set to True the manager will be serialized into migrations and will
#: thus be available in e.g. RunPython operations.
use_in_migrations = False
def __new__(cls, *args, **kwargs):
# Capture the arguments to make returning them trivial.
obj = super().__new__(cls)
obj._constructor_args = (args, kwargs)
return obj
def __init__(self):
super().__init__()
self._set_creation_counter()
self.model = None
self.name = None
self._db = None
self._hints = {}
def __str__(self):
"""Return "app_label.model_label.manager_name"."""
return "%s.%s" % (self.model._meta.label, self.name)
def __class_getitem__(cls, *args, **kwargs):
return cls
def deconstruct(self):
"""
Return a 5-tuple of the form (as_manager (True), manager_class,
queryset_class, args, kwargs).
Raise a ValueError if the manager is dynamically generated.
"""
qs_class = self._queryset_class
if getattr(self, "_built_with_as_manager", False):
# using MyQuerySet.as_manager()
return (
True, # as_manager
None, # manager_class
"%s.%s" % (qs_class.__module__, qs_class.__name__), # qs_class
None, # args
None, # kwargs
)
else:
module_name = self.__module__
name = self.__class__.__name__
# Make sure it's actually there and not an inner class
module = import_module(module_name)
if not hasattr(module, name):
raise ValueError(
"Could not find manager %s in %s.\n"
"Please note that you need to inherit from managers you "
"dynamically generated with 'from_queryset()'."
% (name, module_name)
)
return (
False, # as_manager
"%s.%s" % (module_name, name), # manager_class
None, # qs_class
self._constructor_args[0], # args
self._constructor_args[1], # kwargs
)
def check(self, **kwargs):
return []
@classmethod
def _get_queryset_methods(cls, queryset_class):
def create_method(name, method):
@wraps(method)
def manager_method(self, *args, **kwargs):
return getattr(self.get_queryset(), name)(*args, **kwargs)
return manager_method
new_methods = {}
for name, method in inspect.getmembers(
queryset_class, predicate=inspect.isfunction
):
# Only copy missing methods.
if hasattr(cls, name):
continue
# Only copy public methods or methods with the attribute
# queryset_only=False.
queryset_only = getattr(method, "queryset_only", None)
if queryset_only or (queryset_only is None and name.startswith("_")):
continue
# Copy the method onto the manager.
new_methods[name] = create_method(name, method)
return new_methods
@classmethod
def from_queryset(cls, queryset_class, class_name=None):
if class_name is None:
class_name = "%sFrom%s" % (cls.__name__, queryset_class.__name__)
return type(
class_name,
(cls,),
{
"_queryset_class": queryset_class,
**cls._get_queryset_methods(queryset_class),
},
)
def contribute_to_class(self, cls, name):
self.name = self.name or name
self.model = cls
setattr(cls, name, ManagerDescriptor(self))
cls._meta.add_manager(self)
def _set_creation_counter(self):
"""
Set the creation counter value for this instance and increment the
class-level copy.
"""
self.creation_counter = BaseManager.creation_counter
BaseManager.creation_counter += 1
def db_manager(self, using=None, hints=None):
obj = copy.copy(self)
obj._db = using or self._db
obj._hints = hints or self._hints
return obj
@property
def db(self):
return self._db or router.db_for_read(self.model, **self._hints)
#######################
# PROXIES TO QUERYSET #
#######################
def get_queryset(self):
"""
Return a new QuerySet object. Subclasses can override this method to
customize the behavior of the Manager.
"""
return self._queryset_class(model=self.model, using=self._db, hints=self._hints)
def all(self):
# We can't proxy this method through the `QuerySet` like we do for the
# rest of the `QuerySet` methods. This is because `QuerySet.all()`
# works by creating a "copy" of the current queryset and in making said
# copy, all the cached `prefetch_related` lookups are lost. See the
# implementation of `RelatedManager.get_queryset()` for a better
# understanding of how this comes into play.
return self.get_queryset()
def __eq__(self, other):
return (
isinstance(other, self.__class__)
and self._constructor_args == other._constructor_args
)
def __hash__(self):
return id(self)
class Manager(BaseManager.from_queryset(QuerySet)):
pass
class ManagerDescriptor:
def __init__(self, manager):
self.manager = manager
def __get__(self, instance, cls=None):
if instance is not None:
raise AttributeError(
"Manager isn't accessible via %s instances" % cls.__name__
)
if cls._meta.abstract:
raise AttributeError(
"Manager isn't available; %s is abstract" % (cls._meta.object_name,)
)
if cls._meta.swapped:
raise AttributeError(
"Manager isn't available; '%s' has been swapped for '%s'"
% (
cls._meta.label,
cls._meta.swapped,
)
)
return cls._meta.managers_map[self.manager.name]
class EmptyManager(Manager):
def __init__(self, model):
super().__init__()
self.model = model
def get_queryset(self):
return super().get_queryset().none()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,435 @@
"""
Various data structures used in query construction.
Factored out from django.db.models.query to avoid making the main module very
large and/or so that they can be used by other modules without getting into
circular import difficulties.
"""
import functools
import inspect
import logging
from collections import namedtuple
from django.core.exceptions import FieldError
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections
from django.db.models.constants import LOOKUP_SEP
from django.utils import tree
logger = logging.getLogger("django.db.models")
# PathInfo is used when converting lookups (fk__somecol). The contents
# describe the relation in Model terms (model Options and Fields for both
# sides of the relation. The join_field is the field backing the relation.
PathInfo = namedtuple(
"PathInfo",
"from_opts to_opts target_fields join_field m2m direct filtered_relation",
)
def subclasses(cls):
yield cls
for subclass in cls.__subclasses__():
yield from subclasses(subclass)
class Q(tree.Node):
"""
Encapsulate filters as objects that can then be combined logically (using
`&` and `|`).
"""
# Connection types
AND = "AND"
OR = "OR"
XOR = "XOR"
default = AND
conditional = True
def __init__(self, *args, _connector=None, _negated=False, **kwargs):
super().__init__(
children=[*args, *sorted(kwargs.items())],
connector=_connector,
negated=_negated,
)
def _combine(self, other, conn):
if getattr(other, "conditional", False) is False:
raise TypeError(other)
if not self:
return other.copy()
if not other and isinstance(other, Q):
return self.copy()
obj = self.create(connector=conn)
obj.add(self, conn)
obj.add(other, conn)
return obj
def __or__(self, other):
return self._combine(other, self.OR)
def __and__(self, other):
return self._combine(other, self.AND)
def __xor__(self, other):
return self._combine(other, self.XOR)
def __invert__(self):
obj = self.copy()
obj.negate()
return obj
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
# We must promote any new joins to left outer joins so that when Q is
# used as an expression, rows aren't filtered due to joins.
clause, joins = query._add_q(
self,
reuse,
allow_joins=allow_joins,
split_subq=False,
check_filterable=False,
summarize=summarize,
)
query.promote_joins(joins)
return clause
def flatten(self):
"""
Recursively yield this Q object and all subexpressions, in depth-first
order.
"""
yield self
for child in self.children:
if isinstance(child, tuple):
# Use the lookup.
child = child[1]
if hasattr(child, "flatten"):
yield from child.flatten()
else:
yield child
def check(self, against, using=DEFAULT_DB_ALIAS):
"""
Do a database query to check if the expressions of the Q instance
matches against the expressions.
"""
# Avoid circular imports.
from django.db.models import BooleanField, Value
from django.db.models.functions import Coalesce
from django.db.models.sql import Query
from django.db.models.sql.constants import SINGLE
query = Query(None)
for name, value in against.items():
if not hasattr(value, "resolve_expression"):
value = Value(value)
query.add_annotation(value, name, select=False)
query.add_annotation(Value(1), "_check")
# This will raise a FieldError if a field is missing in "against".
if connections[using].features.supports_comparing_boolean_expr:
query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
else:
query.add_q(self)
compiler = query.get_compiler(using=using)
try:
return compiler.execute_sql(SINGLE) is not None
except DatabaseError as e:
logger.warning("Got a database error calling check() on %r: %s", self, e)
return True
def deconstruct(self):
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
if path.startswith("django.db.models.query_utils"):
path = path.replace("django.db.models.query_utils", "django.db.models")
args = tuple(self.children)
kwargs = {}
if self.connector != self.default:
kwargs["_connector"] = self.connector
if self.negated:
kwargs["_negated"] = True
return path, args, kwargs
class DeferredAttribute:
"""
A wrapper for a deferred-loading field. When the value is read from this
object the first time, the query is executed.
"""
def __init__(self, field):
self.field = field
def __get__(self, instance, cls=None):
"""
Retrieve and caches the value from the datastore on the first lookup.
Return the cached value.
"""
if instance is None:
return self
data = instance.__dict__
field_name = self.field.attname
if field_name not in data:
# Let's see if the field is part of the parent chain. If so we
# might be able to reuse the already loaded value. Refs #18343.
val = self._check_parent_chain(instance)
if val is None:
instance.refresh_from_db(fields=[field_name])
else:
data[field_name] = val
return data[field_name]
def _check_parent_chain(self, instance):
"""
Check if the field value can be fetched from a parent field already
loaded in the instance. This can be done if the to-be fetched
field is a primary key field.
"""
opts = instance._meta
link_field = opts.get_ancestor_link(self.field.model)
if self.field.primary_key and self.field != link_field:
return getattr(instance, link_field.attname)
return None
class class_or_instance_method:
"""
Hook used in RegisterLookupMixin to return partial functions depending on
the caller type (instance or class of models.Field).
"""
def __init__(self, class_method, instance_method):
self.class_method = class_method
self.instance_method = instance_method
def __get__(self, instance, owner):
if instance is None:
return functools.partial(self.class_method, owner)
return functools.partial(self.instance_method, instance)
class RegisterLookupMixin:
def _get_lookup(self, lookup_name):
return self.get_lookups().get(lookup_name, None)
@functools.lru_cache(maxsize=None)
def get_class_lookups(cls):
class_lookups = [
parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
]
return cls.merge_dicts(class_lookups)
def get_instance_lookups(self):
class_lookups = self.get_class_lookups()
if instance_lookups := getattr(self, "instance_lookups", None):
return {**class_lookups, **instance_lookups}
return class_lookups
get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
get_class_lookups = classmethod(get_class_lookups)
def get_lookup(self, lookup_name):
from django.db.models.lookups import Lookup
found = self._get_lookup(lookup_name)
if found is None and hasattr(self, "output_field"):
return self.output_field.get_lookup(lookup_name)
if found is not None and not issubclass(found, Lookup):
return None
return found
def get_transform(self, lookup_name):
from django.db.models.lookups import Transform
found = self._get_lookup(lookup_name)
if found is None and hasattr(self, "output_field"):
return self.output_field.get_transform(lookup_name)
if found is not None and not issubclass(found, Transform):
return None
return found
@staticmethod
def merge_dicts(dicts):
"""
Merge dicts in reverse to preference the order of the original list. e.g.,
merge_dicts([a, b]) will preference the keys in 'a' over those in 'b'.
"""
merged = {}
for d in reversed(dicts):
merged.update(d)
return merged
@classmethod
def _clear_cached_class_lookups(cls):
for subclass in subclasses(cls):
subclass.get_class_lookups.cache_clear()
def register_class_lookup(cls, lookup, lookup_name=None):
if lookup_name is None:
lookup_name = lookup.lookup_name
if "class_lookups" not in cls.__dict__:
cls.class_lookups = {}
cls.class_lookups[lookup_name] = lookup
cls._clear_cached_class_lookups()
return lookup
def register_instance_lookup(self, lookup, lookup_name=None):
if lookup_name is None:
lookup_name = lookup.lookup_name
if "instance_lookups" not in self.__dict__:
self.instance_lookups = {}
self.instance_lookups[lookup_name] = lookup
return lookup
register_lookup = class_or_instance_method(
register_class_lookup, register_instance_lookup
)
register_class_lookup = classmethod(register_class_lookup)
def _unregister_class_lookup(cls, lookup, lookup_name=None):
"""
Remove given lookup from cls lookups. For use in tests only as it's
not thread-safe.
"""
if lookup_name is None:
lookup_name = lookup.lookup_name
del cls.class_lookups[lookup_name]
cls._clear_cached_class_lookups()
def _unregister_instance_lookup(self, lookup, lookup_name=None):
"""
Remove given lookup from instance lookups. For use in tests only as
it's not thread-safe.
"""
if lookup_name is None:
lookup_name = lookup.lookup_name
del self.instance_lookups[lookup_name]
_unregister_lookup = class_or_instance_method(
_unregister_class_lookup, _unregister_instance_lookup
)
_unregister_class_lookup = classmethod(_unregister_class_lookup)
def select_related_descend(field, restricted, requested, select_mask, reverse=False):
"""
Return True if this field should be used to descend deeper for
select_related() purposes. Used by both the query construction code
(compiler.get_related_selections()) and the model instance creation code
(compiler.klass_info).
Arguments:
* field - the field to be checked
* restricted - a boolean field, indicating if the field list has been
manually restricted using a requested clause)
* requested - The select_related() dictionary.
* select_mask - the dictionary of selected fields.
* reverse - boolean, True if we are checking a reverse select related
"""
if not field.remote_field:
return False
if field.remote_field.parent_link and not reverse:
return False
if restricted:
if reverse and field.related_query_name() not in requested:
return False
if not reverse and field.name not in requested:
return False
if not restricted and field.null:
return False
if (
restricted
and select_mask
and field.name in requested
and field not in select_mask
):
raise FieldError(
f"Field {field.model._meta.object_name}.{field.name} cannot be both "
"deferred and traversed using select_related at the same time."
)
return True
def refs_expression(lookup_parts, annotations):
"""
Check if the lookup_parts contains references to the given annotations set.
Because the LOOKUP_SEP is contained in the default annotation names, check
each prefix of the lookup_parts for a match.
"""
for n in range(1, len(lookup_parts) + 1):
level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
if annotations.get(level_n_lookup):
return level_n_lookup, lookup_parts[n:]
return None, ()
def check_rel_lookup_compatibility(model, target_opts, field):
"""
Check that self.model is compatible with target_opts. Compatibility
is OK if:
1) model and opts match (where proxy inheritance is removed)
2) model is parent of opts' model or the other way around
"""
def check(opts):
return (
model._meta.concrete_model == opts.concrete_model
or opts.concrete_model in model._meta.get_parent_list()
or model in opts.get_parent_list()
)
# If the field is a primary key, then doing a query against the field's
# model is ok, too. Consider the case:
# class Restaurant(models.Model):
# place = OneToOneField(Place, primary_key=True):
# Restaurant.objects.filter(pk__in=Restaurant.objects.all()).
# If we didn't have the primary key check, then pk__in (== place__in) would
# give Place's opts as the target opts, but Restaurant isn't compatible
# with that. This logic applies only to primary keys, as when doing __in=qs,
# we are going to turn this into __in=qs.values('pk') later on.
return check(target_opts) or (
getattr(field, "primary_key", False) and check(field.model._meta)
)
class FilteredRelation:
"""Specify custom filtering in the ON clause of SQL joins."""
def __init__(self, relation_name, *, condition=Q()):
if not relation_name:
raise ValueError("relation_name cannot be empty.")
self.relation_name = relation_name
self.alias = None
if not isinstance(condition, Q):
raise ValueError("condition argument must be a Q() instance.")
self.condition = condition
self.path = []
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (
self.relation_name == other.relation_name
and self.alias == other.alias
and self.condition == other.condition
)
def clone(self):
clone = FilteredRelation(self.relation_name, condition=self.condition)
clone.alias = self.alias
clone.path = self.path[:]
return clone
def resolve_expression(self, *args, **kwargs):
"""
QuerySet.annotate() only accepts expression-like arguments
(with a resolve_expression() method).
"""
raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")
def as_sql(self, compiler, connection):
# Resolve the condition in Join.filtered_relation.
query = compiler.query
where = query.build_filtered_relation_q(self.condition, reuse=set(self.path))
return compiler.compile(where)

View File

@ -0,0 +1,54 @@
from functools import partial
from django.db.models.utils import make_model_tuple
from django.dispatch import Signal
class_prepared = Signal()
class ModelSignal(Signal):
"""
Signal subclass that allows the sender to be lazily specified as a string
of the `app_label.ModelName` form.
"""
def _lazy_method(self, method, apps, receiver, sender, **kwargs):
from django.db.models.options import Options
# This partial takes a single optional argument named "sender".
partial_method = partial(method, receiver, **kwargs)
if isinstance(sender, str):
apps = apps or Options.default_apps
apps.lazy_model_operation(partial_method, make_model_tuple(sender))
else:
return partial_method(sender)
def connect(self, receiver, sender=None, weak=True, dispatch_uid=None, apps=None):
self._lazy_method(
super().connect,
apps,
receiver,
sender,
weak=weak,
dispatch_uid=dispatch_uid,
)
def disconnect(self, receiver=None, sender=None, dispatch_uid=None, apps=None):
return self._lazy_method(
super().disconnect, apps, receiver, sender, dispatch_uid=dispatch_uid
)
pre_init = ModelSignal(use_caching=True)
post_init = ModelSignal(use_caching=True)
pre_save = ModelSignal(use_caching=True)
post_save = ModelSignal(use_caching=True)
pre_delete = ModelSignal(use_caching=True)
post_delete = ModelSignal(use_caching=True)
m2m_changed = ModelSignal(use_caching=True)
pre_migrate = Signal()
post_migrate = Signal()

View File

@ -0,0 +1,6 @@
from django.db.models.sql.query import * # NOQA
from django.db.models.sql.query import Query
from django.db.models.sql.subqueries import * # NOQA
from django.db.models.sql.where import AND, OR, XOR
__all__ = ["Query", "AND", "OR", "XOR"]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,24 @@
"""
Constants specific to the SQL storage portion of the ORM.
"""
# Size of each "chunk" for get_iterator calls.
# Larger values are slightly faster at the expense of more storage space.
GET_ITERATOR_CHUNK_SIZE = 100
# Namedtuples for sql.* internal use.
# How many results to expect from a cursor.execute call
MULTI = "multi"
SINGLE = "single"
CURSOR = "cursor"
NO_RESULTS = "no results"
ORDER_DIR = {
"ASC": ("ASC", "DESC"),
"DESC": ("DESC", "ASC"),
}
# SQL join types.
INNER = "INNER JOIN"
LOUTER = "LEFT OUTER JOIN"

View File

@ -0,0 +1,224 @@
"""
Useful auxiliary data structures for query construction. Not useful outside
the SQL domain.
"""
from django.core.exceptions import FullResultSet
from django.db.models.sql.constants import INNER, LOUTER
class MultiJoin(Exception):
"""
Used by join construction code to indicate the point at which a
multi-valued join was attempted (if the caller wants to treat that
exceptionally).
"""
def __init__(self, names_pos, path_with_names):
self.level = names_pos
# The path travelled, this includes the path to the multijoin.
self.names_with_path = path_with_names
class Empty:
pass
class Join:
"""
Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
FROM entry. For example, the SQL generated could be
LEFT OUTER JOIN "sometable" T1
ON ("othertable"."sometable_id" = "sometable"."id")
This class is primarily used in Query.alias_map. All entries in alias_map
must be Join compatible by providing the following attributes and methods:
- table_name (string)
- table_alias (possible alias for the table, can be None)
- join_type (can be None for those entries that aren't joined from
anything)
- parent_alias (which table is this join's parent, can be None similarly
to join_type)
- as_sql()
- relabeled_clone()
"""
def __init__(
self,
table_name,
parent_alias,
table_alias,
join_type,
join_field,
nullable,
filtered_relation=None,
):
# Join table
self.table_name = table_name
self.parent_alias = parent_alias
# Note: table_alias is not necessarily known at instantiation time.
self.table_alias = table_alias
# LOUTER or INNER
self.join_type = join_type
# A list of 2-tuples to use in the ON clause of the JOIN.
# Each 2-tuple will create one join condition in the ON clause.
self.join_cols = join_field.get_joining_columns()
# Along which field (or ForeignObjectRel in the reverse join case)
self.join_field = join_field
# Is this join nullabled?
self.nullable = nullable
self.filtered_relation = filtered_relation
def as_sql(self, compiler, connection):
"""
Generate the full
LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params
clause for this join.
"""
join_conditions = []
params = []
qn = compiler.quote_name_unless_alias
qn2 = connection.ops.quote_name
# Add a join condition for each pair of joining columns.
for lhs_col, rhs_col in self.join_cols:
join_conditions.append(
"%s.%s = %s.%s"
% (
qn(self.parent_alias),
qn2(lhs_col),
qn(self.table_alias),
qn2(rhs_col),
)
)
# Add a single condition inside parentheses for whatever
# get_extra_restriction() returns.
extra_cond = self.join_field.get_extra_restriction(
self.table_alias, self.parent_alias
)
if extra_cond:
extra_sql, extra_params = compiler.compile(extra_cond)
join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if self.filtered_relation:
try:
extra_sql, extra_params = compiler.compile(self.filtered_relation)
except FullResultSet:
pass
else:
join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if not join_conditions:
# This might be a rel on the other end of an actual declared field.
declared_field = getattr(self.join_field, "field", self.join_field)
raise ValueError(
"Join generated an empty ON clause. %s did not yield either "
"joining columns or extra restrictions." % declared_field.__class__
)
on_clause_sql = " AND ".join(join_conditions)
alias_str = (
"" if self.table_alias == self.table_name else (" %s" % self.table_alias)
)
sql = "%s %s%s ON (%s)" % (
self.join_type,
qn(self.table_name),
alias_str,
on_clause_sql,
)
return sql, params
def relabeled_clone(self, change_map):
new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
new_table_alias = change_map.get(self.table_alias, self.table_alias)
if self.filtered_relation is not None:
filtered_relation = self.filtered_relation.clone()
filtered_relation.path = [
change_map.get(p, p) for p in self.filtered_relation.path
]
else:
filtered_relation = None
return self.__class__(
self.table_name,
new_parent_alias,
new_table_alias,
self.join_type,
self.join_field,
self.nullable,
filtered_relation=filtered_relation,
)
@property
def identity(self):
return (
self.__class__,
self.table_name,
self.parent_alias,
self.join_field,
self.filtered_relation,
)
def __eq__(self, other):
if not isinstance(other, Join):
return NotImplemented
return self.identity == other.identity
def __hash__(self):
return hash(self.identity)
def equals(self, other):
# Ignore filtered_relation in equality check.
return self.identity[:-1] == other.identity[:-1]
def demote(self):
new = self.relabeled_clone({})
new.join_type = INNER
return new
def promote(self):
new = self.relabeled_clone({})
new.join_type = LOUTER
return new
class BaseTable:
"""
The BaseTable class is used for base table references in FROM clause. For
example, the SQL "foo" in
SELECT * FROM "foo" WHERE somecond
could be generated by this class.
"""
join_type = None
parent_alias = None
filtered_relation = None
def __init__(self, table_name, alias):
self.table_name = table_name
self.table_alias = alias
def as_sql(self, compiler, connection):
alias_str = (
"" if self.table_alias == self.table_name else (" %s" % self.table_alias)
)
base_sql = compiler.quote_name_unless_alias(self.table_name)
return base_sql + alias_str, []
def relabeled_clone(self, change_map):
return self.__class__(
self.table_name, change_map.get(self.table_alias, self.table_alias)
)
@property
def identity(self):
return self.__class__, self.table_name, self.table_alias
def __eq__(self, other):
if not isinstance(other, BaseTable):
return NotImplemented
return self.identity == other.identity
def __hash__(self):
return hash(self.identity)
def equals(self, other):
return self.identity == other.identity

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,171 @@
"""
Query subclasses which provide extra functionality beyond simple data retrieval.
"""
from django.core.exceptions import FieldError
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
from django.db.models.sql.query import Query
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
class DeleteQuery(Query):
"""A DELETE SQL query."""
compiler = "SQLDeleteCompiler"
def do_query(self, table, where, using):
self.alias_map = {table: self.alias_map[table]}
self.where = where
cursor = self.get_compiler(using).execute_sql(CURSOR)
if cursor:
with cursor:
return cursor.rowcount
return 0
def delete_batch(self, pk_list, using):
"""
Set up and execute delete queries for all the objects in pk_list.
More than one physical query may be executed if there are a
lot of values in pk_list.
"""
# number of objects deleted
num_deleted = 0
field = self.get_meta().pk
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.clear_where()
self.add_filter(
f"{field.attname}__in",
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
)
num_deleted += self.do_query(
self.get_meta().db_table, self.where, using=using
)
return num_deleted
class UpdateQuery(Query):
"""An UPDATE SQL query."""
compiler = "SQLUpdateCompiler"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._setup_query()
def _setup_query(self):
"""
Run on initialization and at the end of chaining. Any attributes that
would normally be set in __init__() should go here instead.
"""
self.values = []
self.related_ids = None
self.related_updates = {}
def clone(self):
obj = super().clone()
obj.related_updates = self.related_updates.copy()
return obj
def update_batch(self, pk_list, values, using):
self.add_update_values(values)
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.clear_where()
self.add_filter(
"pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
)
self.get_compiler(using).execute_sql(NO_RESULTS)
def add_update_values(self, values):
"""
Convert a dictionary of field name to value mappings into an update
query. This is the entry point for the public update() method on
querysets.
"""
values_seq = []
for name, val in values.items():
field = self.get_meta().get_field(name)
direct = (
not (field.auto_created and not field.concrete) or not field.concrete
)
model = field.model._meta.concrete_model
if not direct or (field.is_relation and field.many_to_many):
raise FieldError(
"Cannot update model field %r (only non-relations and "
"foreign keys permitted)." % field
)
if model is not self.get_meta().concrete_model:
self.add_related_update(model, field, val)
continue
values_seq.append((field, model, val))
return self.add_update_fields(values_seq)
def add_update_fields(self, values_seq):
"""
Append a sequence of (field, model, value) triples to the internal list
that will be used to generate the UPDATE query. Might be more usefully
called add_update_targets() to hint at the extra information here.
"""
for field, model, val in values_seq:
if hasattr(val, "resolve_expression"):
# Resolve expressions here so that annotations are no longer needed
val = val.resolve_expression(self, allow_joins=False, for_save=True)
self.values.append((field, model, val))
def add_related_update(self, model, field, value):
"""
Add (name, value) to an update query for an ancestor model.
Update are coalesced so that only one update query per ancestor is run.
"""
self.related_updates.setdefault(model, []).append((field, None, value))
def get_related_updates(self):
"""
Return a list of query objects: one for each update required to an
ancestor model. Each query will have the same filtering conditions as
the current query but will only update a single table.
"""
if not self.related_updates:
return []
result = []
for model, values in self.related_updates.items():
query = UpdateQuery(model)
query.values = values
if self.related_ids is not None:
query.add_filter("pk__in", self.related_ids[model])
result.append(query)
return result
class InsertQuery(Query):
compiler = "SQLInsertCompiler"
def __init__(
self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
):
super().__init__(*args, **kwargs)
self.fields = []
self.objs = []
self.on_conflict = on_conflict
self.update_fields = update_fields or []
self.unique_fields = unique_fields or []
def insert_values(self, fields, objs, raw=False):
self.fields = fields
self.objs = objs
self.raw = raw
class AggregateQuery(Query):
"""
Take another query as a parameter to the FROM clause and only select the
elements in the provided list.
"""
compiler = "SQLAggregateCompiler"
def __init__(self, model, inner_query):
self.inner_query = inner_query
super().__init__(model)

View File

@ -0,0 +1,355 @@
"""
Code to manage the creation and SQL rendering of 'where' constraints.
"""
import operator
from functools import reduce
from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db.models.expressions import Case, When
from django.db.models.lookups import Exact
from django.utils import tree
from django.utils.functional import cached_property
# Connection types
AND = "AND"
OR = "OR"
XOR = "XOR"
class WhereNode(tree.Node):
"""
An SQL WHERE clause.
The class is tied to the Query class that created it (in order to create
the correct SQL).
A child is usually an expression producing boolean values. Most likely the
expression is a Lookup instance.
However, a child could also be any class with as_sql() and either
relabeled_clone() method or relabel_aliases() and clone() methods and
contains_aggregate attribute.
"""
default = AND
resolved = False
conditional = True
def split_having_qualify(self, negated=False, must_group_by=False):
"""
Return three possibly None nodes: one for those parts of self that
should be included in the WHERE clause, one for those parts of self
that must be included in the HAVING clause, and one for those parts
that refer to window functions.
"""
if not self.contains_aggregate and not self.contains_over_clause:
return self, None, None
in_negated = negated ^ self.negated
# Whether or not children must be connected in the same filtering
# clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
must_remain_connected = (
(in_negated and self.connector == AND)
or (not in_negated and self.connector == OR)
or self.connector == XOR
)
if (
must_remain_connected
and self.contains_aggregate
and not self.contains_over_clause
):
# It's must cheaper to short-circuit and stash everything in the
# HAVING clause than split children if possible.
return None, self, None
where_parts = []
having_parts = []
qualify_parts = []
for c in self.children:
if hasattr(c, "split_having_qualify"):
where_part, having_part, qualify_part = c.split_having_qualify(
in_negated, must_group_by
)
if where_part is not None:
where_parts.append(where_part)
if having_part is not None:
having_parts.append(having_part)
if qualify_part is not None:
qualify_parts.append(qualify_part)
elif c.contains_over_clause:
qualify_parts.append(c)
elif c.contains_aggregate:
having_parts.append(c)
else:
where_parts.append(c)
if must_remain_connected and qualify_parts:
# Disjunctive heterogeneous predicates can be pushed down to
# qualify as long as no conditional aggregation is involved.
if not where_parts or (where_parts and not must_group_by):
return None, None, self
elif where_parts:
# In theory this should only be enforced when dealing with
# where_parts containing predicates against multi-valued
# relationships that could affect aggregation results but this
# is complex to infer properly.
raise NotImplementedError(
"Heterogeneous disjunctive predicates against window functions are "
"not implemented when performing conditional aggregation."
)
where_node = (
self.create(where_parts, self.connector, self.negated)
if where_parts
else None
)
having_node = (
self.create(having_parts, self.connector, self.negated)
if having_parts
else None
)
qualify_node = (
self.create(qualify_parts, self.connector, self.negated)
if qualify_parts
else None
)
return where_node, having_node, qualify_node
def as_sql(self, compiler, connection):
"""
Return the SQL version of the where clause and the value to be
substituted in. Return '', [] if this node matches everything,
None, [] if this node is empty, and raise EmptyResultSet if this
node can't match anything.
"""
result = []
result_params = []
if self.connector == AND:
full_needed, empty_needed = len(self.children), 1
else:
full_needed, empty_needed = 1, len(self.children)
if self.connector == XOR and not connection.features.supports_logical_xor:
# Convert if the database doesn't support XOR:
# a XOR b XOR c XOR ...
# to:
# (a OR b OR c OR ...) AND (a + b + c + ...) == 1
lhs = self.__class__(self.children, OR)
rhs_sum = reduce(
operator.add,
(Case(When(c, then=1), default=0) for c in self.children),
)
rhs = Exact(1, rhs_sum)
return self.__class__([lhs, rhs], AND, self.negated).as_sql(
compiler, connection
)
for child in self.children:
try:
sql, params = compiler.compile(child)
except EmptyResultSet:
empty_needed -= 1
except FullResultSet:
full_needed -= 1
else:
if sql:
result.append(sql)
result_params.extend(params)
else:
full_needed -= 1
# Check if this node matches nothing or everything.
# First check the amount of full nodes and empty nodes
# to make this node empty/full.
# Now, check if this node is full/empty using the
# counts.
if empty_needed == 0:
if self.negated:
raise FullResultSet
else:
raise EmptyResultSet
if full_needed == 0:
if self.negated:
raise EmptyResultSet
else:
raise FullResultSet
conn = " %s " % self.connector
sql_string = conn.join(result)
if not sql_string:
raise FullResultSet
if self.negated:
# Some backends (Oracle at least) need parentheses around the inner
# SQL in the negated case, even if the inner SQL contains just a
# single expression.
sql_string = "NOT (%s)" % sql_string
elif len(result) > 1 or self.resolved:
sql_string = "(%s)" % sql_string
return sql_string, result_params
def get_group_by_cols(self):
cols = []
for child in self.children:
cols.extend(child.get_group_by_cols())
return cols
def get_source_expressions(self):
return self.children[:]
def set_source_expressions(self, children):
assert len(children) == len(self.children)
self.children = children
def relabel_aliases(self, change_map):
"""
Relabel the alias values of any children. 'change_map' is a dictionary
mapping old (current) alias values to the new values.
"""
for pos, child in enumerate(self.children):
if hasattr(child, "relabel_aliases"):
# For example another WhereNode
child.relabel_aliases(change_map)
elif hasattr(child, "relabeled_clone"):
self.children[pos] = child.relabeled_clone(change_map)
def clone(self):
clone = self.create(connector=self.connector, negated=self.negated)
for child in self.children:
if hasattr(child, "clone"):
child = child.clone()
clone.children.append(child)
return clone
def relabeled_clone(self, change_map):
clone = self.clone()
clone.relabel_aliases(change_map)
return clone
def replace_expressions(self, replacements):
if replacement := replacements.get(self):
return replacement
clone = self.create(connector=self.connector, negated=self.negated)
for child in self.children:
clone.children.append(child.replace_expressions(replacements))
return clone
def get_refs(self):
refs = set()
for child in self.children:
refs |= child.get_refs()
return refs
@classmethod
def _contains_aggregate(cls, obj):
if isinstance(obj, tree.Node):
return any(cls._contains_aggregate(c) for c in obj.children)
return obj.contains_aggregate
@cached_property
def contains_aggregate(self):
return self._contains_aggregate(self)
@classmethod
def _contains_over_clause(cls, obj):
if isinstance(obj, tree.Node):
return any(cls._contains_over_clause(c) for c in obj.children)
return obj.contains_over_clause
@cached_property
def contains_over_clause(self):
return self._contains_over_clause(self)
@property
def is_summary(self):
return any(child.is_summary for child in self.children)
@staticmethod
def _resolve_leaf(expr, query, *args, **kwargs):
if hasattr(expr, "resolve_expression"):
expr = expr.resolve_expression(query, *args, **kwargs)
return expr
@classmethod
def _resolve_node(cls, node, query, *args, **kwargs):
if hasattr(node, "children"):
for child in node.children:
cls._resolve_node(child, query, *args, **kwargs)
if hasattr(node, "lhs"):
node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
if hasattr(node, "rhs"):
node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
def resolve_expression(self, *args, **kwargs):
clone = self.clone()
clone._resolve_node(clone, *args, **kwargs)
clone.resolved = True
return clone
@cached_property
def output_field(self):
from django.db.models import BooleanField
return BooleanField()
@property
def _output_field_or_none(self):
return self.output_field
def select_format(self, compiler, sql, params):
# Wrap filters with a CASE WHEN expression if a database backend
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
# BY list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
return sql, params
def get_db_converters(self, connection):
return self.output_field.get_db_converters(connection)
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
def leaves(self):
for child in self.children:
if isinstance(child, WhereNode):
yield from child.leaves()
else:
yield child
class NothingNode:
"""A node that matches nothing."""
contains_aggregate = False
contains_over_clause = False
def as_sql(self, compiler=None, connection=None):
raise EmptyResultSet
class ExtraWhere:
# The contents are a black box - assume no aggregates or windows are used.
contains_aggregate = False
contains_over_clause = False
def __init__(self, sqls, params):
self.sqls = sqls
self.params = params
def as_sql(self, compiler=None, connection=None):
sqls = ["(%s)" % sql for sql in self.sqls]
return " AND ".join(sqls), list(self.params or ())
class SubqueryConstraint:
# Even if aggregates or windows would be used in a subquery,
# the outer query isn't interested about those.
contains_aggregate = False
contains_over_clause = False
def __init__(self, alias, columns, targets, query_object):
self.alias = alias
self.columns = columns
self.targets = targets
query_object.clear_ordering(clear_default=True)
self.query_object = query_object
def as_sql(self, compiler, connection):
query = self.query_object
query.set_values(self.targets)
query_compiler = query.get_compiler(connection=connection)
return query_compiler.as_subquery_condition(self.alias, self.columns, compiler)

View File

@ -0,0 +1,69 @@
import functools
from collections import namedtuple
def make_model_tuple(model):
"""
Take a model or a string of the form "app_label.ModelName" and return a
corresponding ("app_label", "modelname") tuple. If a tuple is passed in,
assume it's a valid model tuple already and return it unchanged.
"""
try:
if isinstance(model, tuple):
model_tuple = model
elif isinstance(model, str):
app_label, model_name = model.split(".")
model_tuple = app_label, model_name.lower()
else:
model_tuple = model._meta.app_label, model._meta.model_name
assert len(model_tuple) == 2
return model_tuple
except (ValueError, AssertionError):
raise ValueError(
"Invalid model reference '%s'. String model references "
"must be of the form 'app_label.ModelName'." % model
)
def resolve_callables(mapping):
"""
Generate key/value pairs for the given mapping where the values are
evaluated if they're callable.
"""
for k, v in mapping.items():
yield k, v() if callable(v) else v
def unpickle_named_row(names, values):
return create_namedtuple_class(*names)(*values)
@functools.lru_cache
def create_namedtuple_class(*names):
# Cache type() with @lru_cache since it's too slow to be called for every
# QuerySet evaluation.
def __reduce__(self):
return unpickle_named_row, (names, tuple(self))
return type(
"Row",
(namedtuple("Row", names),),
{"__reduce__": __reduce__, "__slots__": ()},
)
class AltersData:
"""
Make subclasses preserve the alters_data attribute on overridden methods.
"""
def __init_subclass__(cls, **kwargs):
for fn_name, fn in vars(cls).items():
if callable(fn) and not hasattr(fn, "alters_data"):
for base in cls.__bases__:
if base_fn := getattr(base, fn_name, None):
if hasattr(base_fn, "alters_data"):
fn.alters_data = base_fn.alters_data
break
super().__init_subclass__(**kwargs)