docker setup
This commit is contained in:
@ -0,0 +1,33 @@
|
||||
r"""
|
||||
______ _____ _____ _____ __
|
||||
| ___ \ ___/ ___|_ _| / _| | |
|
||||
| |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| |__
|
||||
| /| __| `--. \ | | | _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ /
|
||||
| |\ \| |___/\__/ / | | | | | | | (_| | | | | | | __/\ V V / (_) | | | <
|
||||
\_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_|
|
||||
"""
|
||||
|
||||
import django
|
||||
|
||||
__title__ = 'Django REST framework'
|
||||
__version__ = '3.14.0'
|
||||
__author__ = 'Tom Christie'
|
||||
__license__ = 'BSD 3-Clause'
|
||||
__copyright__ = 'Copyright 2011-2019 Encode OSS Ltd'
|
||||
|
||||
# Version synonym
|
||||
VERSION = __version__
|
||||
|
||||
# Header encoding (see RFC5987)
|
||||
HTTP_HEADER_ENCODING = 'iso-8859-1'
|
||||
|
||||
# Default datetime input and output formats
|
||||
ISO_8601 = 'iso-8601'
|
||||
|
||||
|
||||
if django.VERSION < (3, 2):
|
||||
default_app_config = 'rest_framework.apps.RestFrameworkConfig'
|
||||
|
||||
|
||||
class RemovedInDRF315Warning(PendingDeprecationWarning):
|
||||
pass
|
@ -0,0 +1,10 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class RestFrameworkConfig(AppConfig):
|
||||
name = 'rest_framework'
|
||||
verbose_name = "Django REST framework"
|
||||
|
||||
def ready(self):
|
||||
# Add System checks
|
||||
from .checks import pagination_system_check # NOQA
|
@ -0,0 +1,232 @@
|
||||
"""
|
||||
Provides various authentication policies.
|
||||
"""
|
||||
import base64
|
||||
import binascii
|
||||
|
||||
from django.contrib.auth import authenticate, get_user_model
|
||||
from django.middleware.csrf import CsrfViewMiddleware
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework import HTTP_HEADER_ENCODING, exceptions
|
||||
|
||||
|
||||
def get_authorization_header(request):
|
||||
"""
|
||||
Return request's 'Authorization:' header, as a bytestring.
|
||||
|
||||
Hide some test client ickyness where the header can be unicode.
|
||||
"""
|
||||
auth = request.META.get('HTTP_AUTHORIZATION', b'')
|
||||
if isinstance(auth, str):
|
||||
# Work around django test client oddness
|
||||
auth = auth.encode(HTTP_HEADER_ENCODING)
|
||||
return auth
|
||||
|
||||
|
||||
class CSRFCheck(CsrfViewMiddleware):
|
||||
def _reject(self, request, reason):
|
||||
# Return the failure reason instead of an HttpResponse
|
||||
return reason
|
||||
|
||||
|
||||
class BaseAuthentication:
|
||||
"""
|
||||
All authentication classes should extend BaseAuthentication.
|
||||
"""
|
||||
|
||||
def authenticate(self, request):
|
||||
"""
|
||||
Authenticate the request and return a two-tuple of (user, token).
|
||||
"""
|
||||
raise NotImplementedError(".authenticate() must be overridden.")
|
||||
|
||||
def authenticate_header(self, request):
|
||||
"""
|
||||
Return a string to be used as the value of the `WWW-Authenticate`
|
||||
header in a `401 Unauthenticated` response, or `None` if the
|
||||
authentication scheme should return `403 Permission Denied` responses.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BasicAuthentication(BaseAuthentication):
|
||||
"""
|
||||
HTTP Basic authentication against username/password.
|
||||
"""
|
||||
www_authenticate_realm = 'api'
|
||||
|
||||
def authenticate(self, request):
|
||||
"""
|
||||
Returns a `User` if a correct username and password have been supplied
|
||||
using HTTP Basic authentication. Otherwise returns `None`.
|
||||
"""
|
||||
auth = get_authorization_header(request).split()
|
||||
|
||||
if not auth or auth[0].lower() != b'basic':
|
||||
return None
|
||||
|
||||
if len(auth) == 1:
|
||||
msg = _('Invalid basic header. No credentials provided.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
elif len(auth) > 2:
|
||||
msg = _('Invalid basic header. Credentials string should not contain spaces.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
try:
|
||||
try:
|
||||
auth_decoded = base64.b64decode(auth[1]).decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
auth_decoded = base64.b64decode(auth[1]).decode('latin-1')
|
||||
auth_parts = auth_decoded.partition(':')
|
||||
except (TypeError, UnicodeDecodeError, binascii.Error):
|
||||
msg = _('Invalid basic header. Credentials not correctly base64 encoded.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
userid, password = auth_parts[0], auth_parts[2]
|
||||
return self.authenticate_credentials(userid, password, request)
|
||||
|
||||
def authenticate_credentials(self, userid, password, request=None):
|
||||
"""
|
||||
Authenticate the userid and password against username and password
|
||||
with optional request for context.
|
||||
"""
|
||||
credentials = {
|
||||
get_user_model().USERNAME_FIELD: userid,
|
||||
'password': password
|
||||
}
|
||||
user = authenticate(request=request, **credentials)
|
||||
|
||||
if user is None:
|
||||
raise exceptions.AuthenticationFailed(_('Invalid username/password.'))
|
||||
|
||||
if not user.is_active:
|
||||
raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
|
||||
|
||||
return (user, None)
|
||||
|
||||
def authenticate_header(self, request):
|
||||
return 'Basic realm="%s"' % self.www_authenticate_realm
|
||||
|
||||
|
||||
class SessionAuthentication(BaseAuthentication):
|
||||
"""
|
||||
Use Django's session framework for authentication.
|
||||
"""
|
||||
|
||||
def authenticate(self, request):
|
||||
"""
|
||||
Returns a `User` if the request session currently has a logged in user.
|
||||
Otherwise returns `None`.
|
||||
"""
|
||||
|
||||
# Get the session-based user from the underlying HttpRequest object
|
||||
user = getattr(request._request, 'user', None)
|
||||
|
||||
# Unauthenticated, CSRF validation not required
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
|
||||
self.enforce_csrf(request)
|
||||
|
||||
# CSRF passed with authenticated user
|
||||
return (user, None)
|
||||
|
||||
def enforce_csrf(self, request):
|
||||
"""
|
||||
Enforce CSRF validation for session based authentication.
|
||||
"""
|
||||
def dummy_get_response(request): # pragma: no cover
|
||||
return None
|
||||
|
||||
check = CSRFCheck(dummy_get_response)
|
||||
# populates request.META['CSRF_COOKIE'], which is used in process_view()
|
||||
check.process_request(request)
|
||||
reason = check.process_view(request, None, (), {})
|
||||
if reason:
|
||||
# CSRF failed, bail with explicit error message
|
||||
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
|
||||
|
||||
|
||||
class TokenAuthentication(BaseAuthentication):
|
||||
"""
|
||||
Simple token based authentication.
|
||||
|
||||
Clients should authenticate by passing the token key in the "Authorization"
|
||||
HTTP header, prepended with the string "Token ". For example:
|
||||
|
||||
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
|
||||
"""
|
||||
|
||||
keyword = 'Token'
|
||||
model = None
|
||||
|
||||
def get_model(self):
|
||||
if self.model is not None:
|
||||
return self.model
|
||||
from rest_framework.authtoken.models import Token
|
||||
return Token
|
||||
|
||||
"""
|
||||
A custom token model may be used, but must have the following properties.
|
||||
|
||||
* key -- The string identifying the token
|
||||
* user -- The user to which the token belongs
|
||||
"""
|
||||
|
||||
def authenticate(self, request):
|
||||
auth = get_authorization_header(request).split()
|
||||
|
||||
if not auth or auth[0].lower() != self.keyword.lower().encode():
|
||||
return None
|
||||
|
||||
if len(auth) == 1:
|
||||
msg = _('Invalid token header. No credentials provided.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
elif len(auth) > 2:
|
||||
msg = _('Invalid token header. Token string should not contain spaces.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
try:
|
||||
token = auth[1].decode()
|
||||
except UnicodeError:
|
||||
msg = _('Invalid token header. Token string should not contain invalid characters.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
return self.authenticate_credentials(token)
|
||||
|
||||
def authenticate_credentials(self, key):
|
||||
model = self.get_model()
|
||||
try:
|
||||
token = model.objects.select_related('user').get(key=key)
|
||||
except model.DoesNotExist:
|
||||
raise exceptions.AuthenticationFailed(_('Invalid token.'))
|
||||
|
||||
if not token.user.is_active:
|
||||
raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
|
||||
|
||||
return (token.user, token)
|
||||
|
||||
def authenticate_header(self, request):
|
||||
return self.keyword
|
||||
|
||||
|
||||
class RemoteUserAuthentication(BaseAuthentication):
|
||||
"""
|
||||
REMOTE_USER authentication.
|
||||
|
||||
To use this, set up your web server to perform authentication, which will
|
||||
set the REMOTE_USER environment variable. You will need to have
|
||||
'django.contrib.auth.backends.RemoteUserBackend in your
|
||||
AUTHENTICATION_BACKENDS setting
|
||||
"""
|
||||
|
||||
# Name of request header to grab username from. This will be the key as
|
||||
# used in the request.META dictionary, i.e. the normalization of headers to
|
||||
# all uppercase and the addition of "HTTP_" prefix apply.
|
||||
header = "REMOTE_USER"
|
||||
|
||||
def authenticate(self, request):
|
||||
user = authenticate(request=request, remote_user=request.META.get(self.header))
|
||||
if user and user.is_active:
|
||||
return (user, None)
|
@ -0,0 +1,4 @@
|
||||
import django
|
||||
|
||||
if django.VERSION < (3, 2):
|
||||
default_app_config = 'rest_framework.authtoken.apps.AuthTokenConfig'
|
@ -0,0 +1,51 @@
|
||||
from django.contrib import admin
|
||||
from django.contrib.admin.utils import quote
|
||||
from django.contrib.admin.views.main import ChangeList
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.urls import reverse
|
||||
|
||||
from rest_framework.authtoken.models import Token, TokenProxy
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class TokenChangeList(ChangeList):
|
||||
"""Map to matching User id"""
|
||||
def url_for_result(self, result):
|
||||
pk = result.user.pk
|
||||
return reverse('admin:%s_%s_change' % (self.opts.app_label,
|
||||
self.opts.model_name),
|
||||
args=(quote(pk),),
|
||||
current_app=self.model_admin.admin_site.name)
|
||||
|
||||
|
||||
class TokenAdmin(admin.ModelAdmin):
|
||||
list_display = ('key', 'user', 'created')
|
||||
fields = ('user',)
|
||||
ordering = ('-created',)
|
||||
actions = None # Actions not compatible with mapped IDs.
|
||||
|
||||
def get_changelist(self, request, **kwargs):
|
||||
return TokenChangeList
|
||||
|
||||
def get_object(self, request, object_id, from_field=None):
|
||||
"""
|
||||
Map from User ID to matching Token.
|
||||
"""
|
||||
queryset = self.get_queryset(request)
|
||||
field = User._meta.pk
|
||||
try:
|
||||
object_id = field.to_python(object_id)
|
||||
user = User.objects.get(**{field.name: object_id})
|
||||
return queryset.get(user=user)
|
||||
except (queryset.model.DoesNotExist, User.DoesNotExist, ValidationError, ValueError):
|
||||
return None
|
||||
|
||||
def delete_model(self, request, obj):
|
||||
# Map back to actual Token, since delete() uses pk.
|
||||
token = Token.objects.get(key=obj.key)
|
||||
return super().delete_model(request, token)
|
||||
|
||||
|
||||
admin.site.register(TokenProxy, TokenAdmin)
|
@ -0,0 +1,7 @@
|
||||
from django.apps import AppConfig
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class AuthTokenConfig(AppConfig):
|
||||
name = 'rest_framework.authtoken'
|
||||
verbose_name = _("Auth Token")
|
@ -0,0 +1,45 @@
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
|
||||
from rest_framework.authtoken.models import Token
|
||||
|
||||
UserModel = get_user_model()
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = 'Create DRF Token for a given user'
|
||||
|
||||
def create_user_token(self, username, reset_token):
|
||||
user = UserModel._default_manager.get_by_natural_key(username)
|
||||
|
||||
if reset_token:
|
||||
Token.objects.filter(user=user).delete()
|
||||
|
||||
token = Token.objects.get_or_create(user=user)
|
||||
return token[0]
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument('username', type=str)
|
||||
|
||||
parser.add_argument(
|
||||
'-r',
|
||||
'--reset',
|
||||
action='store_true',
|
||||
dest='reset_token',
|
||||
default=False,
|
||||
help='Reset existing User token and create a new one',
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
username = options['username']
|
||||
reset_token = options['reset_token']
|
||||
|
||||
try:
|
||||
token = self.create_user_token(username, reset_token)
|
||||
except UserModel.DoesNotExist:
|
||||
raise CommandError(
|
||||
'Cannot create the Token: user {} does not exist'.format(
|
||||
username)
|
||||
)
|
||||
self.stdout.write(
|
||||
'Generated token {} for user {}'.format(token.key, username))
|
@ -0,0 +1,54 @@
|
||||
import binascii
|
||||
import os
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class Token(models.Model):
|
||||
"""
|
||||
The default authorization token model.
|
||||
"""
|
||||
key = models.CharField(_("Key"), max_length=40, primary_key=True)
|
||||
user = models.OneToOneField(
|
||||
settings.AUTH_USER_MODEL, related_name='auth_token',
|
||||
on_delete=models.CASCADE, verbose_name=_("User")
|
||||
)
|
||||
created = models.DateTimeField(_("Created"), auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
# Work around for a bug in Django:
|
||||
# https://code.djangoproject.com/ticket/19422
|
||||
#
|
||||
# Also see corresponding ticket:
|
||||
# https://github.com/encode/django-rest-framework/issues/705
|
||||
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
|
||||
verbose_name = _("Token")
|
||||
verbose_name_plural = _("Tokens")
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if not self.key:
|
||||
self.key = self.generate_key()
|
||||
return super().save(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def generate_key(cls):
|
||||
return binascii.hexlify(os.urandom(20)).decode()
|
||||
|
||||
def __str__(self):
|
||||
return self.key
|
||||
|
||||
|
||||
class TokenProxy(Token):
|
||||
"""
|
||||
Proxy mapping pk to user pk for use in admin.
|
||||
"""
|
||||
@property
|
||||
def pk(self):
|
||||
return self.user_id
|
||||
|
||||
class Meta:
|
||||
proxy = 'rest_framework.authtoken' in settings.INSTALLED_APPS
|
||||
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
|
||||
verbose_name = "token"
|
@ -0,0 +1,42 @@
|
||||
from django.contrib.auth import authenticate
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class AuthTokenSerializer(serializers.Serializer):
|
||||
username = serializers.CharField(
|
||||
label=_("Username"),
|
||||
write_only=True
|
||||
)
|
||||
password = serializers.CharField(
|
||||
label=_("Password"),
|
||||
style={'input_type': 'password'},
|
||||
trim_whitespace=False,
|
||||
write_only=True
|
||||
)
|
||||
token = serializers.CharField(
|
||||
label=_("Token"),
|
||||
read_only=True
|
||||
)
|
||||
|
||||
def validate(self, attrs):
|
||||
username = attrs.get('username')
|
||||
password = attrs.get('password')
|
||||
|
||||
if username and password:
|
||||
user = authenticate(request=self.context.get('request'),
|
||||
username=username, password=password)
|
||||
|
||||
# The authenticate call simply returns None for is_active=False
|
||||
# users. (Assuming the default ModelBackend authentication
|
||||
# backend.)
|
||||
if not user:
|
||||
msg = _('Unable to log in with provided credentials.')
|
||||
raise serializers.ValidationError(msg, code='authorization')
|
||||
else:
|
||||
msg = _('Must include "username" and "password".')
|
||||
raise serializers.ValidationError(msg, code='authorization')
|
||||
|
||||
attrs['user'] = user
|
||||
return attrs
|
@ -0,0 +1,62 @@
|
||||
from rest_framework import parsers, renderers
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.authtoken.serializers import AuthTokenSerializer
|
||||
from rest_framework.compat import coreapi, coreschema
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.schemas import ManualSchema
|
||||
from rest_framework.schemas import coreapi as coreapi_schema
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class ObtainAuthToken(APIView):
|
||||
throttle_classes = ()
|
||||
permission_classes = ()
|
||||
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
|
||||
renderer_classes = (renderers.JSONRenderer,)
|
||||
serializer_class = AuthTokenSerializer
|
||||
|
||||
if coreapi_schema.is_enabled():
|
||||
schema = ManualSchema(
|
||||
fields=[
|
||||
coreapi.Field(
|
||||
name="username",
|
||||
required=True,
|
||||
location='form',
|
||||
schema=coreschema.String(
|
||||
title="Username",
|
||||
description="Valid username for authentication",
|
||||
),
|
||||
),
|
||||
coreapi.Field(
|
||||
name="password",
|
||||
required=True,
|
||||
location='form',
|
||||
schema=coreschema.String(
|
||||
title="Password",
|
||||
description="Valid password for authentication",
|
||||
),
|
||||
),
|
||||
],
|
||||
encoding="application/json",
|
||||
)
|
||||
|
||||
def get_serializer_context(self):
|
||||
return {
|
||||
'request': self.request,
|
||||
'format': self.format_kwarg,
|
||||
'view': self
|
||||
}
|
||||
|
||||
def get_serializer(self, *args, **kwargs):
|
||||
kwargs['context'] = self.get_serializer_context()
|
||||
return self.serializer_class(*args, **kwargs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
user = serializer.validated_data['user']
|
||||
token, created = Token.objects.get_or_create(user=user)
|
||||
return Response({'token': token.key})
|
||||
|
||||
|
||||
obtain_auth_token = ObtainAuthToken.as_view()
|
@ -0,0 +1,21 @@
|
||||
from django.core.checks import Tags, Warning, register
|
||||
|
||||
|
||||
@register(Tags.compatibility)
|
||||
def pagination_system_check(app_configs, **kwargs):
|
||||
errors = []
|
||||
# Use of default page size setting requires a default Paginator class
|
||||
from rest_framework.settings import api_settings
|
||||
if api_settings.PAGE_SIZE and not api_settings.DEFAULT_PAGINATION_CLASS:
|
||||
errors.append(
|
||||
Warning(
|
||||
"You have specified a default PAGE_SIZE pagination rest_framework setting, "
|
||||
"without specifying also a DEFAULT_PAGINATION_CLASS.",
|
||||
hint="The default for DEFAULT_PAGINATION_CLASS is None. "
|
||||
"In previous versions this was PageNumberPagination. "
|
||||
"If you wish to define PAGE_SIZE globally whilst defining "
|
||||
"pagination_class on a per-view basis you may silence this check.",
|
||||
id="rest_framework.W001"
|
||||
)
|
||||
)
|
||||
return errors
|
184
srcs/.venv/lib/python3.11/site-packages/rest_framework/compat.py
Normal file
184
srcs/.venv/lib/python3.11/site-packages/rest_framework/compat.py
Normal file
@ -0,0 +1,184 @@
|
||||
"""
|
||||
The `compat` module provides support for backwards compatibility with older
|
||||
versions of Django/Python, and compatibility wrappers around optional packages.
|
||||
"""
|
||||
import django
|
||||
from django.conf import settings
|
||||
from django.views.generic import View
|
||||
|
||||
|
||||
def unicode_http_header(value):
|
||||
# Coerce HTTP header value to unicode.
|
||||
if isinstance(value, bytes):
|
||||
return value.decode('iso-8859-1')
|
||||
return value
|
||||
|
||||
|
||||
def distinct(queryset, base):
|
||||
if settings.DATABASES[queryset.db]["ENGINE"] == "django.db.backends.oracle":
|
||||
# distinct analogue for Oracle users
|
||||
return base.filter(pk__in=set(queryset.values_list('pk', flat=True)))
|
||||
return queryset.distinct()
|
||||
|
||||
|
||||
# django.contrib.postgres requires psycopg2
|
||||
try:
|
||||
from django.contrib.postgres import fields as postgres_fields
|
||||
except ImportError:
|
||||
postgres_fields = None
|
||||
|
||||
|
||||
# coreapi is required for CoreAPI schema generation
|
||||
try:
|
||||
import coreapi
|
||||
except ImportError:
|
||||
coreapi = None
|
||||
|
||||
# uritemplate is required for OpenAPI and CoreAPI schema generation
|
||||
try:
|
||||
import uritemplate
|
||||
except ImportError:
|
||||
uritemplate = None
|
||||
|
||||
|
||||
# coreschema is optional
|
||||
try:
|
||||
import coreschema
|
||||
except ImportError:
|
||||
coreschema = None
|
||||
|
||||
|
||||
# pyyaml is optional
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
yaml = None
|
||||
|
||||
|
||||
# requests is optional
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
|
||||
# PATCH method is not implemented by Django
|
||||
if 'patch' not in View.http_method_names:
|
||||
View.http_method_names = View.http_method_names + ['patch']
|
||||
|
||||
|
||||
# Markdown is optional (version 3.0+ required)
|
||||
try:
|
||||
import markdown
|
||||
|
||||
HEADERID_EXT_PATH = 'markdown.extensions.toc'
|
||||
LEVEL_PARAM = 'baselevel'
|
||||
|
||||
def apply_markdown(text):
|
||||
"""
|
||||
Simple wrapper around :func:`markdown.markdown` to set the base level
|
||||
of '#' style headers to <h2>.
|
||||
"""
|
||||
extensions = [HEADERID_EXT_PATH]
|
||||
extension_configs = {
|
||||
HEADERID_EXT_PATH: {
|
||||
LEVEL_PARAM: '2'
|
||||
}
|
||||
}
|
||||
md = markdown.Markdown(
|
||||
extensions=extensions, extension_configs=extension_configs
|
||||
)
|
||||
md_filter_add_syntax_highlight(md)
|
||||
return md.convert(text)
|
||||
except ImportError:
|
||||
apply_markdown = None
|
||||
markdown = None
|
||||
|
||||
|
||||
try:
|
||||
import pygments
|
||||
from pygments.formatters import HtmlFormatter
|
||||
from pygments.lexers import TextLexer, get_lexer_by_name
|
||||
|
||||
def pygments_highlight(text, lang, style):
|
||||
lexer = get_lexer_by_name(lang, stripall=False)
|
||||
formatter = HtmlFormatter(nowrap=True, style=style)
|
||||
return pygments.highlight(text, lexer, formatter)
|
||||
|
||||
def pygments_css(style):
|
||||
formatter = HtmlFormatter(style=style)
|
||||
return formatter.get_style_defs('.highlight')
|
||||
|
||||
except ImportError:
|
||||
pygments = None
|
||||
|
||||
def pygments_highlight(text, lang, style):
|
||||
return text
|
||||
|
||||
def pygments_css(style):
|
||||
return None
|
||||
|
||||
if markdown is not None and pygments is not None:
|
||||
# starting from this blogpost and modified to support current markdown extensions API
|
||||
# https://zerokspot.com/weblog/2008/06/18/syntax-highlighting-in-markdown-with-pygments/
|
||||
|
||||
import re
|
||||
|
||||
from markdown.preprocessors import Preprocessor
|
||||
|
||||
class CodeBlockPreprocessor(Preprocessor):
|
||||
pattern = re.compile(
|
||||
r'^\s*``` *([^\n]+)\n(.+?)^\s*```', re.M | re.S)
|
||||
|
||||
formatter = HtmlFormatter()
|
||||
|
||||
def run(self, lines):
|
||||
def repl(m):
|
||||
try:
|
||||
lexer = get_lexer_by_name(m.group(1))
|
||||
except (ValueError, NameError):
|
||||
lexer = TextLexer()
|
||||
code = m.group(2).replace('\t', ' ')
|
||||
code = pygments.highlight(code, lexer, self.formatter)
|
||||
code = code.replace('\n\n', '\n \n').replace('\n', '<br />').replace('\\@', '@')
|
||||
return '\n\n%s\n\n' % code
|
||||
ret = self.pattern.sub(repl, "\n".join(lines))
|
||||
return ret.split("\n")
|
||||
|
||||
def md_filter_add_syntax_highlight(md):
|
||||
md.preprocessors.register(CodeBlockPreprocessor(), 'highlight', 40)
|
||||
return True
|
||||
else:
|
||||
def md_filter_add_syntax_highlight(md):
|
||||
return False
|
||||
|
||||
|
||||
if django.VERSION >= (4, 2):
|
||||
# Django 4.2+: use the stock parse_header_parameters function
|
||||
# Note: Django 4.1 also has an implementation of parse_header_parameters
|
||||
# which is slightly different from the one in 4.2, it needs
|
||||
# the compatibility shim as well.
|
||||
from django.utils.http import parse_header_parameters
|
||||
else:
|
||||
# Django <= 4.1: create a compatibility shim for parse_header_parameters
|
||||
from django.http.multipartparser import parse_header
|
||||
|
||||
def parse_header_parameters(line):
|
||||
# parse_header works with bytes, but parse_header_parameters
|
||||
# works with strings. Call encode to convert the line to bytes.
|
||||
main_value_pair, params = parse_header(line.encode())
|
||||
return main_value_pair, {
|
||||
# parse_header will convert *some* values to string.
|
||||
# parse_header_parameters converts *all* values to string.
|
||||
# Make sure all values are converted by calling decode on
|
||||
# any remaining non-string values.
|
||||
k: v if isinstance(v, str) else v.decode()
|
||||
for k, v in params.items()
|
||||
}
|
||||
|
||||
|
||||
# `separators` argument to `json.dumps()` differs between 2.x and 3.x
|
||||
# See: https://bugs.python.org/issue22767
|
||||
SHORT_SEPARATORS = (',', ':')
|
||||
LONG_SEPARATORS = (', ', ': ')
|
||||
INDENT_SEPARATORS = (',', ': ')
|
@ -0,0 +1,233 @@
|
||||
"""
|
||||
The most important decorator in this module is `@api_view`, which is used
|
||||
for writing function-based views with REST framework.
|
||||
|
||||
There are also various decorators for setting the API policies on function
|
||||
based views, as well as the `@action` decorator, which is used to annotate
|
||||
methods on viewsets that should be included by routers.
|
||||
"""
|
||||
import types
|
||||
|
||||
from django.forms.utils import pretty_name
|
||||
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
def api_view(http_method_names=None):
|
||||
"""
|
||||
Decorator that converts a function-based view into an APIView subclass.
|
||||
Takes a list of allowed methods for the view as an argument.
|
||||
"""
|
||||
http_method_names = ['GET'] if (http_method_names is None) else http_method_names
|
||||
|
||||
def decorator(func):
|
||||
|
||||
WrappedAPIView = type(
|
||||
'WrappedAPIView',
|
||||
(APIView,),
|
||||
{'__doc__': func.__doc__}
|
||||
)
|
||||
|
||||
# Note, the above allows us to set the docstring.
|
||||
# It is the equivalent of:
|
||||
#
|
||||
# class WrappedAPIView(APIView):
|
||||
# pass
|
||||
# WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
|
||||
|
||||
# api_view applied without (method_names)
|
||||
assert not(isinstance(http_method_names, types.FunctionType)), \
|
||||
'@api_view missing list of allowed HTTP methods'
|
||||
|
||||
# api_view applied with eg. string instead of list of strings
|
||||
assert isinstance(http_method_names, (list, tuple)), \
|
||||
'@api_view expected a list of strings, received %s' % type(http_method_names).__name__
|
||||
|
||||
allowed_methods = set(http_method_names) | {'options'}
|
||||
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
|
||||
|
||||
def handler(self, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
for method in http_method_names:
|
||||
setattr(WrappedAPIView, method.lower(), handler)
|
||||
|
||||
WrappedAPIView.__name__ = func.__name__
|
||||
WrappedAPIView.__module__ = func.__module__
|
||||
|
||||
WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes',
|
||||
APIView.renderer_classes)
|
||||
|
||||
WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
|
||||
APIView.parser_classes)
|
||||
|
||||
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
|
||||
APIView.authentication_classes)
|
||||
|
||||
WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',
|
||||
APIView.throttle_classes)
|
||||
|
||||
WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
|
||||
APIView.permission_classes)
|
||||
|
||||
WrappedAPIView.schema = getattr(func, 'schema',
|
||||
APIView.schema)
|
||||
|
||||
return WrappedAPIView.as_view()
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def renderer_classes(renderer_classes):
|
||||
def decorator(func):
|
||||
func.renderer_classes = renderer_classes
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def parser_classes(parser_classes):
|
||||
def decorator(func):
|
||||
func.parser_classes = parser_classes
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def authentication_classes(authentication_classes):
|
||||
def decorator(func):
|
||||
func.authentication_classes = authentication_classes
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def throttle_classes(throttle_classes):
|
||||
def decorator(func):
|
||||
func.throttle_classes = throttle_classes
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def permission_classes(permission_classes):
|
||||
def decorator(func):
|
||||
func.permission_classes = permission_classes
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def schema(view_inspector):
|
||||
def decorator(func):
|
||||
func.schema = view_inspector
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def action(methods=None, detail=None, url_path=None, url_name=None, **kwargs):
|
||||
"""
|
||||
Mark a ViewSet method as a routable action.
|
||||
|
||||
`@action`-decorated functions will be endowed with a `mapping` property,
|
||||
a `MethodMapper` that can be used to add additional method-based behaviors
|
||||
on the routed action.
|
||||
|
||||
:param methods: A list of HTTP method names this action responds to.
|
||||
Defaults to GET only.
|
||||
:param detail: Required. Determines whether this action applies to
|
||||
instance/detail requests or collection/list requests.
|
||||
:param url_path: Define the URL segment for this action. Defaults to the
|
||||
name of the method decorated.
|
||||
:param url_name: Define the internal (`reverse`) URL name for this action.
|
||||
Defaults to the name of the method decorated with underscores
|
||||
replaced with dashes.
|
||||
:param kwargs: Additional properties to set on the view. This can be used
|
||||
to override viewset-level *_classes settings, equivalent to
|
||||
how the `@renderer_classes` etc. decorators work for function-
|
||||
based API views.
|
||||
"""
|
||||
methods = ['get'] if methods is None else methods
|
||||
methods = [method.lower() for method in methods]
|
||||
|
||||
assert detail is not None, (
|
||||
"@action() missing required argument: 'detail'"
|
||||
)
|
||||
|
||||
# name and suffix are mutually exclusive
|
||||
if 'name' in kwargs and 'suffix' in kwargs:
|
||||
raise TypeError("`name` and `suffix` are mutually exclusive arguments.")
|
||||
|
||||
def decorator(func):
|
||||
func.mapping = MethodMapper(func, methods)
|
||||
|
||||
func.detail = detail
|
||||
func.url_path = url_path if url_path else func.__name__
|
||||
func.url_name = url_name if url_name else func.__name__.replace('_', '-')
|
||||
|
||||
# These kwargs will end up being passed to `ViewSet.as_view()` within
|
||||
# the router, which eventually delegates to Django's CBV `View`,
|
||||
# which assigns them as instance attributes for each request.
|
||||
func.kwargs = kwargs
|
||||
|
||||
# Set descriptive arguments for viewsets
|
||||
if 'name' not in kwargs and 'suffix' not in kwargs:
|
||||
func.kwargs['name'] = pretty_name(func.__name__)
|
||||
func.kwargs['description'] = func.__doc__ or None
|
||||
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
class MethodMapper(dict):
|
||||
"""
|
||||
Enables mapping HTTP methods to different ViewSet methods for a single,
|
||||
logical action.
|
||||
|
||||
Example usage:
|
||||
|
||||
class MyViewSet(ViewSet):
|
||||
|
||||
@action(detail=False)
|
||||
def example(self, request, **kwargs):
|
||||
...
|
||||
|
||||
@example.mapping.post
|
||||
def create_example(self, request, **kwargs):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, action, methods):
|
||||
self.action = action
|
||||
for method in methods:
|
||||
self[method] = self.action.__name__
|
||||
|
||||
def _map(self, method, func):
|
||||
assert method not in self, (
|
||||
"Method '%s' has already been mapped to '.%s'." % (method, self[method]))
|
||||
assert func.__name__ != self.action.__name__, (
|
||||
"Method mapping does not behave like the property decorator. You "
|
||||
"cannot use the same method name for each mapping declaration.")
|
||||
|
||||
self[method] = func.__name__
|
||||
|
||||
return func
|
||||
|
||||
def get(self, func):
|
||||
return self._map('get', func)
|
||||
|
||||
def post(self, func):
|
||||
return self._map('post', func)
|
||||
|
||||
def put(self, func):
|
||||
return self._map('put', func)
|
||||
|
||||
def patch(self, func):
|
||||
return self._map('patch', func)
|
||||
|
||||
def delete(self, func):
|
||||
return self._map('delete', func)
|
||||
|
||||
def head(self, func):
|
||||
return self._map('head', func)
|
||||
|
||||
def options(self, func):
|
||||
return self._map('options', func)
|
||||
|
||||
def trace(self, func):
|
||||
return self._map('trace', func)
|
@ -0,0 +1,88 @@
|
||||
from django.urls import include, path
|
||||
|
||||
from rest_framework.renderers import (
|
||||
CoreJSONRenderer, DocumentationRenderer, SchemaJSRenderer
|
||||
)
|
||||
from rest_framework.schemas import SchemaGenerator, get_schema_view
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
def get_docs_view(
|
||||
title=None, description=None, schema_url=None, urlconf=None,
|
||||
public=True, patterns=None, generator_class=SchemaGenerator,
|
||||
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
|
||||
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
|
||||
renderer_classes=None):
|
||||
|
||||
if renderer_classes is None:
|
||||
renderer_classes = [DocumentationRenderer, CoreJSONRenderer]
|
||||
|
||||
return get_schema_view(
|
||||
title=title,
|
||||
url=schema_url,
|
||||
urlconf=urlconf,
|
||||
description=description,
|
||||
renderer_classes=renderer_classes,
|
||||
public=public,
|
||||
patterns=patterns,
|
||||
generator_class=generator_class,
|
||||
authentication_classes=authentication_classes,
|
||||
permission_classes=permission_classes,
|
||||
)
|
||||
|
||||
|
||||
def get_schemajs_view(
|
||||
title=None, description=None, schema_url=None, urlconf=None,
|
||||
public=True, patterns=None, generator_class=SchemaGenerator,
|
||||
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
|
||||
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
|
||||
renderer_classes = [SchemaJSRenderer]
|
||||
|
||||
return get_schema_view(
|
||||
title=title,
|
||||
url=schema_url,
|
||||
urlconf=urlconf,
|
||||
description=description,
|
||||
renderer_classes=renderer_classes,
|
||||
public=public,
|
||||
patterns=patterns,
|
||||
generator_class=generator_class,
|
||||
authentication_classes=authentication_classes,
|
||||
permission_classes=permission_classes,
|
||||
)
|
||||
|
||||
|
||||
def include_docs_urls(
|
||||
title=None, description=None, schema_url=None, urlconf=None,
|
||||
public=True, patterns=None, generator_class=SchemaGenerator,
|
||||
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
|
||||
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
|
||||
renderer_classes=None):
|
||||
docs_view = get_docs_view(
|
||||
title=title,
|
||||
description=description,
|
||||
schema_url=schema_url,
|
||||
urlconf=urlconf,
|
||||
public=public,
|
||||
patterns=patterns,
|
||||
generator_class=generator_class,
|
||||
authentication_classes=authentication_classes,
|
||||
renderer_classes=renderer_classes,
|
||||
permission_classes=permission_classes,
|
||||
)
|
||||
schema_js_view = get_schemajs_view(
|
||||
title=title,
|
||||
description=description,
|
||||
schema_url=schema_url,
|
||||
urlconf=urlconf,
|
||||
public=public,
|
||||
patterns=patterns,
|
||||
generator_class=generator_class,
|
||||
authentication_classes=authentication_classes,
|
||||
permission_classes=permission_classes,
|
||||
)
|
||||
urls = [
|
||||
path('', docs_view, name='docs-index'),
|
||||
path('schema.js', schema_js_view, name='schema-js')
|
||||
]
|
||||
return include((urls, 'api-docs'), namespace='api-docs')
|
@ -0,0 +1,264 @@
|
||||
"""
|
||||
Handled exceptions raised by REST framework.
|
||||
|
||||
In addition, Django's built in 403 and 404 exceptions are handled.
|
||||
(`django.http.Http404` and `django.core.exceptions.PermissionDenied`)
|
||||
"""
|
||||
import math
|
||||
|
||||
from django.http import JsonResponse
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import ngettext
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList
|
||||
|
||||
|
||||
def _get_error_details(data, default_code=None):
|
||||
"""
|
||||
Descend into a nested data structure, forcing any
|
||||
lazy translation strings or strings into `ErrorDetail`.
|
||||
"""
|
||||
if isinstance(data, (list, tuple)):
|
||||
ret = [
|
||||
_get_error_details(item, default_code) for item in data
|
||||
]
|
||||
if isinstance(data, ReturnList):
|
||||
return ReturnList(ret, serializer=data.serializer)
|
||||
return ret
|
||||
elif isinstance(data, dict):
|
||||
ret = {
|
||||
key: _get_error_details(value, default_code)
|
||||
for key, value in data.items()
|
||||
}
|
||||
if isinstance(data, ReturnDict):
|
||||
return ReturnDict(ret, serializer=data.serializer)
|
||||
return ret
|
||||
|
||||
text = force_str(data)
|
||||
code = getattr(data, 'code', default_code)
|
||||
return ErrorDetail(text, code)
|
||||
|
||||
|
||||
def _get_codes(detail):
|
||||
if isinstance(detail, list):
|
||||
return [_get_codes(item) for item in detail]
|
||||
elif isinstance(detail, dict):
|
||||
return {key: _get_codes(value) for key, value in detail.items()}
|
||||
return detail.code
|
||||
|
||||
|
||||
def _get_full_details(detail):
|
||||
if isinstance(detail, list):
|
||||
return [_get_full_details(item) for item in detail]
|
||||
elif isinstance(detail, dict):
|
||||
return {key: _get_full_details(value) for key, value in detail.items()}
|
||||
return {
|
||||
'message': detail,
|
||||
'code': detail.code
|
||||
}
|
||||
|
||||
|
||||
class ErrorDetail(str):
|
||||
"""
|
||||
A string-like object that can additionally have a code.
|
||||
"""
|
||||
code = None
|
||||
|
||||
def __new__(cls, string, code=None):
|
||||
self = super().__new__(cls, string)
|
||||
self.code = code
|
||||
return self
|
||||
|
||||
def __eq__(self, other):
|
||||
result = super().__eq__(other)
|
||||
if result is NotImplemented:
|
||||
return NotImplemented
|
||||
try:
|
||||
return result and self.code == other.code
|
||||
except AttributeError:
|
||||
return result
|
||||
|
||||
def __ne__(self, other):
|
||||
result = self.__eq__(other)
|
||||
if result is NotImplemented:
|
||||
return NotImplemented
|
||||
return not result
|
||||
|
||||
def __repr__(self):
|
||||
return 'ErrorDetail(string=%r, code=%r)' % (
|
||||
str(self),
|
||||
self.code,
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(str(self))
|
||||
|
||||
|
||||
class APIException(Exception):
|
||||
"""
|
||||
Base class for REST framework exceptions.
|
||||
Subclasses should provide `.status_code` and `.default_detail` properties.
|
||||
"""
|
||||
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
default_detail = _('A server error occurred.')
|
||||
default_code = 'error'
|
||||
|
||||
def __init__(self, detail=None, code=None):
|
||||
if detail is None:
|
||||
detail = self.default_detail
|
||||
if code is None:
|
||||
code = self.default_code
|
||||
|
||||
self.detail = _get_error_details(detail, code)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.detail)
|
||||
|
||||
def get_codes(self):
|
||||
"""
|
||||
Return only the code part of the error details.
|
||||
|
||||
Eg. {"name": ["required"]}
|
||||
"""
|
||||
return _get_codes(self.detail)
|
||||
|
||||
def get_full_details(self):
|
||||
"""
|
||||
Return both the message & code parts of the error details.
|
||||
|
||||
Eg. {"name": [{"message": "This field is required.", "code": "required"}]}
|
||||
"""
|
||||
return _get_full_details(self.detail)
|
||||
|
||||
|
||||
# The recommended style for using `ValidationError` is to keep it namespaced
|
||||
# under `serializers`, in order to minimize potential confusion with Django's
|
||||
# built in `ValidationError`. For example:
|
||||
#
|
||||
# from rest_framework import serializers
|
||||
# raise serializers.ValidationError('Value was invalid')
|
||||
|
||||
class ValidationError(APIException):
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
default_detail = _('Invalid input.')
|
||||
default_code = 'invalid'
|
||||
|
||||
def __init__(self, detail=None, code=None):
|
||||
if detail is None:
|
||||
detail = self.default_detail
|
||||
if code is None:
|
||||
code = self.default_code
|
||||
|
||||
# For validation failures, we may collect many errors together,
|
||||
# so the details should always be coerced to a list if not already.
|
||||
if isinstance(detail, tuple):
|
||||
detail = list(detail)
|
||||
elif not isinstance(detail, dict) and not isinstance(detail, list):
|
||||
detail = [detail]
|
||||
|
||||
self.detail = _get_error_details(detail, code)
|
||||
|
||||
|
||||
class ParseError(APIException):
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
default_detail = _('Malformed request.')
|
||||
default_code = 'parse_error'
|
||||
|
||||
|
||||
class AuthenticationFailed(APIException):
|
||||
status_code = status.HTTP_401_UNAUTHORIZED
|
||||
default_detail = _('Incorrect authentication credentials.')
|
||||
default_code = 'authentication_failed'
|
||||
|
||||
|
||||
class NotAuthenticated(APIException):
|
||||
status_code = status.HTTP_401_UNAUTHORIZED
|
||||
default_detail = _('Authentication credentials were not provided.')
|
||||
default_code = 'not_authenticated'
|
||||
|
||||
|
||||
class PermissionDenied(APIException):
|
||||
status_code = status.HTTP_403_FORBIDDEN
|
||||
default_detail = _('You do not have permission to perform this action.')
|
||||
default_code = 'permission_denied'
|
||||
|
||||
|
||||
class NotFound(APIException):
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
default_detail = _('Not found.')
|
||||
default_code = 'not_found'
|
||||
|
||||
|
||||
class MethodNotAllowed(APIException):
|
||||
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
|
||||
default_detail = _('Method "{method}" not allowed.')
|
||||
default_code = 'method_not_allowed'
|
||||
|
||||
def __init__(self, method, detail=None, code=None):
|
||||
if detail is None:
|
||||
detail = force_str(self.default_detail).format(method=method)
|
||||
super().__init__(detail, code)
|
||||
|
||||
|
||||
class NotAcceptable(APIException):
|
||||
status_code = status.HTTP_406_NOT_ACCEPTABLE
|
||||
default_detail = _('Could not satisfy the request Accept header.')
|
||||
default_code = 'not_acceptable'
|
||||
|
||||
def __init__(self, detail=None, code=None, available_renderers=None):
|
||||
self.available_renderers = available_renderers
|
||||
super().__init__(detail, code)
|
||||
|
||||
|
||||
class UnsupportedMediaType(APIException):
|
||||
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
|
||||
default_detail = _('Unsupported media type "{media_type}" in request.')
|
||||
default_code = 'unsupported_media_type'
|
||||
|
||||
def __init__(self, media_type, detail=None, code=None):
|
||||
if detail is None:
|
||||
detail = force_str(self.default_detail).format(media_type=media_type)
|
||||
super().__init__(detail, code)
|
||||
|
||||
|
||||
class Throttled(APIException):
|
||||
status_code = status.HTTP_429_TOO_MANY_REQUESTS
|
||||
default_detail = _('Request was throttled.')
|
||||
extra_detail_singular = _('Expected available in {wait} second.')
|
||||
extra_detail_plural = _('Expected available in {wait} seconds.')
|
||||
default_code = 'throttled'
|
||||
|
||||
def __init__(self, wait=None, detail=None, code=None):
|
||||
if detail is None:
|
||||
detail = force_str(self.default_detail)
|
||||
if wait is not None:
|
||||
wait = math.ceil(wait)
|
||||
detail = ' '.join((
|
||||
detail,
|
||||
force_str(ngettext(self.extra_detail_singular.format(wait=wait),
|
||||
self.extra_detail_plural.format(wait=wait),
|
||||
wait))))
|
||||
self.wait = wait
|
||||
super().__init__(detail, code)
|
||||
|
||||
|
||||
def server_error(request, *args, **kwargs):
|
||||
"""
|
||||
Generic 500 error handler.
|
||||
"""
|
||||
data = {
|
||||
'error': 'Server Error (500)'
|
||||
}
|
||||
return JsonResponse(data, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
def bad_request(request, exception, *args, **kwargs):
|
||||
"""
|
||||
Generic 400 error handler.
|
||||
"""
|
||||
data = {
|
||||
'error': 'Bad Request (400)'
|
||||
}
|
||||
return JsonResponse(data, status=status.HTTP_400_BAD_REQUEST)
|
1878
srcs/.venv/lib/python3.11/site-packages/rest_framework/fields.py
Normal file
1878
srcs/.venv/lib/python3.11/site-packages/rest_framework/fields.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,333 @@
|
||||
"""
|
||||
Provides generic filtering backends that can be used to filter the results
|
||||
returned by list views.
|
||||
"""
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import models
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.template import loader
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework.compat import coreapi, coreschema, distinct
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
class BaseFilterBackend:
|
||||
"""
|
||||
A base class from which all filter backend classes should inherit.
|
||||
"""
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
"""
|
||||
Return a filtered queryset.
|
||||
"""
|
||||
raise NotImplementedError(".filter_queryset() must be overridden.")
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return []
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return []
|
||||
|
||||
|
||||
class SearchFilter(BaseFilterBackend):
|
||||
# The URL query parameter used for the search.
|
||||
search_param = api_settings.SEARCH_PARAM
|
||||
template = 'rest_framework/filters/search.html'
|
||||
lookup_prefixes = {
|
||||
'^': 'istartswith',
|
||||
'=': 'iexact',
|
||||
'@': 'search',
|
||||
'$': 'iregex',
|
||||
}
|
||||
search_title = _('Search')
|
||||
search_description = _('A search term.')
|
||||
|
||||
def get_search_fields(self, view, request):
|
||||
"""
|
||||
Search fields are obtained from the view, but the request is always
|
||||
passed to this method. Sub-classes can override this method to
|
||||
dynamically change the search fields based on request content.
|
||||
"""
|
||||
return getattr(view, 'search_fields', None)
|
||||
|
||||
def get_search_terms(self, request):
|
||||
"""
|
||||
Search terms are set by a ?search=... query parameter,
|
||||
and may be comma and/or whitespace delimited.
|
||||
"""
|
||||
params = request.query_params.get(self.search_param, '')
|
||||
params = params.replace('\x00', '') # strip null characters
|
||||
params = params.replace(',', ' ')
|
||||
return params.split()
|
||||
|
||||
def construct_search(self, field_name):
|
||||
lookup = self.lookup_prefixes.get(field_name[0])
|
||||
if lookup:
|
||||
field_name = field_name[1:]
|
||||
else:
|
||||
lookup = 'icontains'
|
||||
return LOOKUP_SEP.join([field_name, lookup])
|
||||
|
||||
def must_call_distinct(self, queryset, search_fields):
|
||||
"""
|
||||
Return True if 'distinct()' should be used to query the given lookups.
|
||||
"""
|
||||
for search_field in search_fields:
|
||||
opts = queryset.model._meta
|
||||
if search_field[0] in self.lookup_prefixes:
|
||||
search_field = search_field[1:]
|
||||
# Annotated fields do not need to be distinct
|
||||
if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
|
||||
continue
|
||||
parts = search_field.split(LOOKUP_SEP)
|
||||
for part in parts:
|
||||
field = opts.get_field(part)
|
||||
if hasattr(field, 'get_path_info'):
|
||||
# This field is a relation, update opts to follow the relation
|
||||
path_info = field.get_path_info()
|
||||
opts = path_info[-1].to_opts
|
||||
if any(path.m2m for path in path_info):
|
||||
# This field is a m2m relation so we know we need to call distinct
|
||||
return True
|
||||
else:
|
||||
# This field has a custom __ query transform but is not a relational field.
|
||||
break
|
||||
return False
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
search_fields = self.get_search_fields(view, request)
|
||||
search_terms = self.get_search_terms(request)
|
||||
|
||||
if not search_fields or not search_terms:
|
||||
return queryset
|
||||
|
||||
orm_lookups = [
|
||||
self.construct_search(str(search_field))
|
||||
for search_field in search_fields
|
||||
]
|
||||
|
||||
base = queryset
|
||||
conditions = []
|
||||
for search_term in search_terms:
|
||||
queries = [
|
||||
models.Q(**{orm_lookup: search_term})
|
||||
for orm_lookup in orm_lookups
|
||||
]
|
||||
conditions.append(reduce(operator.or_, queries))
|
||||
queryset = queryset.filter(reduce(operator.and_, conditions))
|
||||
|
||||
if self.must_call_distinct(queryset, search_fields):
|
||||
# Filtering against a many-to-many field requires us to
|
||||
# call queryset.distinct() in order to avoid duplicate items
|
||||
# in the resulting queryset.
|
||||
# We try to avoid this if possible, for performance reasons.
|
||||
queryset = distinct(queryset, base)
|
||||
return queryset
|
||||
|
||||
def to_html(self, request, queryset, view):
|
||||
if not getattr(view, 'search_fields', None):
|
||||
return ''
|
||||
|
||||
term = self.get_search_terms(request)
|
||||
term = term[0] if term else ''
|
||||
context = {
|
||||
'param': self.search_param,
|
||||
'term': term
|
||||
}
|
||||
template = loader.get_template(self.template)
|
||||
return template.render(context)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return [
|
||||
coreapi.Field(
|
||||
name=self.search_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.String(
|
||||
title=force_str(self.search_title),
|
||||
description=force_str(self.search_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return [
|
||||
{
|
||||
'name': self.search_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_str(self.search_description),
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class OrderingFilter(BaseFilterBackend):
|
||||
# The URL query parameter used for the ordering.
|
||||
ordering_param = api_settings.ORDERING_PARAM
|
||||
ordering_fields = None
|
||||
ordering_title = _('Ordering')
|
||||
ordering_description = _('Which field to use when ordering the results.')
|
||||
template = 'rest_framework/filters/ordering.html'
|
||||
|
||||
def get_ordering(self, request, queryset, view):
|
||||
"""
|
||||
Ordering is set by a comma delimited ?ordering=... query parameter.
|
||||
|
||||
The `ordering` query parameter can be overridden by setting
|
||||
the `ordering_param` value on the OrderingFilter or by
|
||||
specifying an `ORDERING_PARAM` value in the API settings.
|
||||
"""
|
||||
params = request.query_params.get(self.ordering_param)
|
||||
if params:
|
||||
fields = [param.strip() for param in params.split(',')]
|
||||
ordering = self.remove_invalid_fields(queryset, fields, view, request)
|
||||
if ordering:
|
||||
return ordering
|
||||
|
||||
# No ordering was included, or all the ordering fields were invalid
|
||||
return self.get_default_ordering(view)
|
||||
|
||||
def get_default_ordering(self, view):
|
||||
ordering = getattr(view, 'ordering', None)
|
||||
if isinstance(ordering, str):
|
||||
return (ordering,)
|
||||
return ordering
|
||||
|
||||
def get_default_valid_fields(self, queryset, view, context={}):
|
||||
# If `ordering_fields` is not specified, then we determine a default
|
||||
# based on the serializer class, if one exists on the view.
|
||||
if hasattr(view, 'get_serializer_class'):
|
||||
try:
|
||||
serializer_class = view.get_serializer_class()
|
||||
except AssertionError:
|
||||
# Raised by the default implementation if
|
||||
# no serializer_class was found
|
||||
serializer_class = None
|
||||
else:
|
||||
serializer_class = getattr(view, 'serializer_class', None)
|
||||
|
||||
if serializer_class is None:
|
||||
msg = (
|
||||
"Cannot use %s on a view which does not have either a "
|
||||
"'serializer_class', an overriding 'get_serializer_class' "
|
||||
"or 'ordering_fields' attribute."
|
||||
)
|
||||
raise ImproperlyConfigured(msg % self.__class__.__name__)
|
||||
|
||||
model_class = queryset.model
|
||||
model_property_names = [
|
||||
# 'pk' is a property added in Django's Model class, however it is valid for ordering.
|
||||
attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk'
|
||||
]
|
||||
|
||||
return [
|
||||
(field.source.replace('.', '__') or field_name, field.label)
|
||||
for field_name, field in serializer_class(context=context).fields.items()
|
||||
if (
|
||||
not getattr(field, 'write_only', False) and
|
||||
not field.source == '*' and
|
||||
field.source not in model_property_names
|
||||
)
|
||||
]
|
||||
|
||||
def get_valid_fields(self, queryset, view, context={}):
|
||||
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
|
||||
|
||||
if valid_fields is None:
|
||||
# Default to allowing filtering on serializer fields
|
||||
return self.get_default_valid_fields(queryset, view, context)
|
||||
|
||||
elif valid_fields == '__all__':
|
||||
# View explicitly allows filtering on any model field
|
||||
valid_fields = [
|
||||
(field.name, field.verbose_name) for field in queryset.model._meta.fields
|
||||
]
|
||||
valid_fields += [
|
||||
(key, key.title().split('__'))
|
||||
for key in queryset.query.annotations
|
||||
]
|
||||
else:
|
||||
valid_fields = [
|
||||
(item, item) if isinstance(item, str) else item
|
||||
for item in valid_fields
|
||||
]
|
||||
|
||||
return valid_fields
|
||||
|
||||
def remove_invalid_fields(self, queryset, fields, view, request):
|
||||
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
|
||||
|
||||
def term_valid(term):
|
||||
if term.startswith("-"):
|
||||
term = term[1:]
|
||||
return term in valid_fields
|
||||
|
||||
return [term for term in fields if term_valid(term)]
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
ordering = self.get_ordering(request, queryset, view)
|
||||
|
||||
if ordering:
|
||||
return queryset.order_by(*ordering)
|
||||
|
||||
return queryset
|
||||
|
||||
def get_template_context(self, request, queryset, view):
|
||||
current = self.get_ordering(request, queryset, view)
|
||||
current = None if not current else current[0]
|
||||
options = []
|
||||
context = {
|
||||
'request': request,
|
||||
'current': current,
|
||||
'param': self.ordering_param,
|
||||
}
|
||||
for key, label in self.get_valid_fields(queryset, view, context):
|
||||
options.append((key, '%s - %s' % (label, _('ascending'))))
|
||||
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
|
||||
context['options'] = options
|
||||
return context
|
||||
|
||||
def to_html(self, request, queryset, view):
|
||||
template = loader.get_template(self.template)
|
||||
context = self.get_template_context(request, queryset, view)
|
||||
return template.render(context)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return [
|
||||
coreapi.Field(
|
||||
name=self.ordering_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.String(
|
||||
title=force_str(self.ordering_title),
|
||||
description=force_str(self.ordering_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return [
|
||||
{
|
||||
'name': self.ordering_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_str(self.ordering_description),
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
},
|
||||
]
|
@ -0,0 +1,291 @@
|
||||
"""
|
||||
Generic views that provide commonly needed behaviour.
|
||||
"""
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db.models.query import QuerySet
|
||||
from django.http import Http404
|
||||
from django.shortcuts import get_object_or_404 as _get_object_or_404
|
||||
|
||||
from rest_framework import mixins, views
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
|
||||
"""
|
||||
Same as Django's standard shortcut, but make sure to also raise 404
|
||||
if the filter_kwargs don't match the required types.
|
||||
"""
|
||||
try:
|
||||
return _get_object_or_404(queryset, *filter_args, **filter_kwargs)
|
||||
except (TypeError, ValueError, ValidationError):
|
||||
raise Http404
|
||||
|
||||
|
||||
class GenericAPIView(views.APIView):
|
||||
"""
|
||||
Base class for all other generic views.
|
||||
"""
|
||||
# You'll need to either set these attributes,
|
||||
# or override `get_queryset()`/`get_serializer_class()`.
|
||||
# If you are overriding a view method, it is important that you call
|
||||
# `get_queryset()` instead of accessing the `queryset` property directly,
|
||||
# as `queryset` will get evaluated only once, and those results are cached
|
||||
# for all subsequent requests.
|
||||
queryset = None
|
||||
serializer_class = None
|
||||
|
||||
# If you want to use object lookups other than pk, set 'lookup_field'.
|
||||
# For more complex lookup requirements override `get_object()`.
|
||||
lookup_field = 'pk'
|
||||
lookup_url_kwarg = None
|
||||
|
||||
# The filter backend classes to use for queryset filtering
|
||||
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
|
||||
|
||||
# The style to use for queryset pagination.
|
||||
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
||||
|
||||
def get_queryset(self):
|
||||
"""
|
||||
Get the list of items for this view.
|
||||
This must be an iterable, and may be a queryset.
|
||||
Defaults to using `self.queryset`.
|
||||
|
||||
This method should always be used rather than accessing `self.queryset`
|
||||
directly, as `self.queryset` gets evaluated only once, and those results
|
||||
are cached for all subsequent requests.
|
||||
|
||||
You may want to override this if you need to provide different
|
||||
querysets depending on the incoming request.
|
||||
|
||||
(Eg. return a list of items that is specific to the user)
|
||||
"""
|
||||
assert self.queryset is not None, (
|
||||
"'%s' should either include a `queryset` attribute, "
|
||||
"or override the `get_queryset()` method."
|
||||
% self.__class__.__name__
|
||||
)
|
||||
|
||||
queryset = self.queryset
|
||||
if isinstance(queryset, QuerySet):
|
||||
# Ensure queryset is re-evaluated on each request.
|
||||
queryset = queryset.all()
|
||||
return queryset
|
||||
|
||||
def get_object(self):
|
||||
"""
|
||||
Returns the object the view is displaying.
|
||||
|
||||
You may want to override this if you need to provide non-standard
|
||||
queryset lookups. Eg if objects are referenced using multiple
|
||||
keyword arguments in the url conf.
|
||||
"""
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
# Perform the lookup filtering.
|
||||
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
|
||||
|
||||
assert lookup_url_kwarg in self.kwargs, (
|
||||
'Expected view %s to be called with a URL keyword argument '
|
||||
'named "%s". Fix your URL conf, or set the `.lookup_field` '
|
||||
'attribute on the view correctly.' %
|
||||
(self.__class__.__name__, lookup_url_kwarg)
|
||||
)
|
||||
|
||||
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
|
||||
obj = get_object_or_404(queryset, **filter_kwargs)
|
||||
|
||||
# May raise a permission denied
|
||||
self.check_object_permissions(self.request, obj)
|
||||
|
||||
return obj
|
||||
|
||||
def get_serializer(self, *args, **kwargs):
|
||||
"""
|
||||
Return the serializer instance that should be used for validating and
|
||||
deserializing input, and for serializing output.
|
||||
"""
|
||||
serializer_class = self.get_serializer_class()
|
||||
kwargs.setdefault('context', self.get_serializer_context())
|
||||
return serializer_class(*args, **kwargs)
|
||||
|
||||
def get_serializer_class(self):
|
||||
"""
|
||||
Return the class to use for the serializer.
|
||||
Defaults to using `self.serializer_class`.
|
||||
|
||||
You may want to override this if you need to provide different
|
||||
serializations depending on the incoming request.
|
||||
|
||||
(Eg. admins get full serialization, others get basic serialization)
|
||||
"""
|
||||
assert self.serializer_class is not None, (
|
||||
"'%s' should either include a `serializer_class` attribute, "
|
||||
"or override the `get_serializer_class()` method."
|
||||
% self.__class__.__name__
|
||||
)
|
||||
|
||||
return self.serializer_class
|
||||
|
||||
def get_serializer_context(self):
|
||||
"""
|
||||
Extra context provided to the serializer class.
|
||||
"""
|
||||
return {
|
||||
'request': self.request,
|
||||
'format': self.format_kwarg,
|
||||
'view': self
|
||||
}
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
"""
|
||||
Given a queryset, filter it with whichever filter backend is in use.
|
||||
|
||||
You are unlikely to want to override this method, although you may need
|
||||
to call it either from a list view, or from a custom `get_object`
|
||||
method if you want to apply the configured filtering backend to the
|
||||
default queryset.
|
||||
"""
|
||||
for backend in list(self.filter_backends):
|
||||
queryset = backend().filter_queryset(self.request, queryset, self)
|
||||
return queryset
|
||||
|
||||
@property
|
||||
def paginator(self):
|
||||
"""
|
||||
The paginator instance associated with the view, or `None`.
|
||||
"""
|
||||
if not hasattr(self, '_paginator'):
|
||||
if self.pagination_class is None:
|
||||
self._paginator = None
|
||||
else:
|
||||
self._paginator = self.pagination_class()
|
||||
return self._paginator
|
||||
|
||||
def paginate_queryset(self, queryset):
|
||||
"""
|
||||
Return a single page of results, or `None` if pagination is disabled.
|
||||
"""
|
||||
if self.paginator is None:
|
||||
return None
|
||||
return self.paginator.paginate_queryset(queryset, self.request, view=self)
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
"""
|
||||
Return a paginated style `Response` object for the given output data.
|
||||
"""
|
||||
assert self.paginator is not None
|
||||
return self.paginator.get_paginated_response(data)
|
||||
|
||||
|
||||
# Concrete view classes that provide method handlers
|
||||
# by composing the mixin classes with the base view.
|
||||
|
||||
class CreateAPIView(mixins.CreateModelMixin,
|
||||
GenericAPIView):
|
||||
"""
|
||||
Concrete view for creating a model instance.
|
||||
"""
|
||||
def post(self, request, *args, **kwargs):
|
||||
return self.create(request, *args, **kwargs)
|
||||
|
||||
|
||||
class ListAPIView(mixins.ListModelMixin,
|
||||
GenericAPIView):
|
||||
"""
|
||||
Concrete view for listing a queryset.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.list(request, *args, **kwargs)
|
||||
|
||||
|
||||
class RetrieveAPIView(mixins.RetrieveModelMixin,
|
||||
GenericAPIView):
|
||||
"""
|
||||
Concrete view for retrieving a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.retrieve(request, *args, **kwargs)
|
||||
|
||||
|
||||
class DestroyAPIView(mixins.DestroyModelMixin,
|
||||
GenericAPIView):
|
||||
"""
|
||||
Concrete view for deleting a model instance.
|
||||
"""
|
||||
def delete(self, request, *args, **kwargs):
|
||||
return self.destroy(request, *args, **kwargs)
|
||||
|
||||
|
||||
class UpdateAPIView(mixins.UpdateModelMixin,
|
||||
GenericAPIView):
|
||||
"""
|
||||
Concrete view for updating a model instance.
|
||||
"""
|
||||
def put(self, request, *args, **kwargs):
|
||||
return self.update(request, *args, **kwargs)
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
return self.partial_update(request, *args, **kwargs)
|
||||
|
||||
|
||||
class ListCreateAPIView(mixins.ListModelMixin,
|
||||
mixins.CreateModelMixin,
|
||||
GenericAPIView):
|
||||
"""
|
||||
Concrete view for listing a queryset or creating a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.list(request, *args, **kwargs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
return self.create(request, *args, **kwargs)
|
||||
|
||||
|
||||
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
|
||||
mixins.UpdateModelMixin,
|
||||
GenericAPIView):
|
||||
"""
|
||||
Concrete view for retrieving, updating a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.retrieve(request, *args, **kwargs)
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
return self.update(request, *args, **kwargs)
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
return self.partial_update(request, *args, **kwargs)
|
||||
|
||||
|
||||
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
GenericAPIView):
|
||||
"""
|
||||
Concrete view for retrieving or deleting a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.retrieve(request, *args, **kwargs)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
return self.destroy(request, *args, **kwargs)
|
||||
|
||||
|
||||
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
|
||||
mixins.UpdateModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
GenericAPIView):
|
||||
"""
|
||||
Concrete view for retrieving, updating or deleting a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.retrieve(request, *args, **kwargs)
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
return self.update(request, *args, **kwargs)
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
return self.partial_update(request, *args, **kwargs)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
return self.destroy(request, *args, **kwargs)
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,71 @@
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
from rest_framework import renderers
|
||||
from rest_framework.schemas import coreapi
|
||||
from rest_framework.schemas.openapi import SchemaGenerator
|
||||
|
||||
OPENAPI_MODE = 'openapi'
|
||||
COREAPI_MODE = 'coreapi'
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Generates configured API schema for project."
|
||||
|
||||
def get_mode(self):
|
||||
return COREAPI_MODE if coreapi.is_enabled() else OPENAPI_MODE
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument('--title', dest="title", default='', type=str)
|
||||
parser.add_argument('--url', dest="url", default=None, type=str)
|
||||
parser.add_argument('--description', dest="description", default=None, type=str)
|
||||
if self.get_mode() == COREAPI_MODE:
|
||||
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
|
||||
else:
|
||||
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str)
|
||||
parser.add_argument('--urlconf', dest="urlconf", default=None, type=str)
|
||||
parser.add_argument('--generator_class', dest="generator_class", default=None, type=str)
|
||||
parser.add_argument('--file', dest="file", default=None, type=str)
|
||||
parser.add_argument('--api_version', dest="api_version", default='', type=str)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
if options['generator_class']:
|
||||
generator_class = import_string(options['generator_class'])
|
||||
else:
|
||||
generator_class = self.get_generator_class()
|
||||
generator = generator_class(
|
||||
url=options['url'],
|
||||
title=options['title'],
|
||||
description=options['description'],
|
||||
urlconf=options['urlconf'],
|
||||
version=options['api_version'],
|
||||
)
|
||||
schema = generator.get_schema(request=None, public=True)
|
||||
renderer = self.get_renderer(options['format'])
|
||||
output = renderer.render(schema, renderer_context={})
|
||||
|
||||
if options['file']:
|
||||
with open(options['file'], 'wb') as f:
|
||||
f.write(output)
|
||||
else:
|
||||
self.stdout.write(output.decode())
|
||||
|
||||
def get_renderer(self, format):
|
||||
if self.get_mode() == COREAPI_MODE:
|
||||
renderer_cls = {
|
||||
'corejson': renderers.CoreJSONRenderer,
|
||||
'openapi': renderers.CoreAPIOpenAPIRenderer,
|
||||
'openapi-json': renderers.CoreAPIJSONOpenAPIRenderer,
|
||||
}[format]
|
||||
return renderer_cls()
|
||||
|
||||
renderer_cls = {
|
||||
'openapi': renderers.OpenAPIRenderer,
|
||||
'openapi-json': renderers.JSONOpenAPIRenderer,
|
||||
}[format]
|
||||
return renderer_cls()
|
||||
|
||||
def get_generator_class(self):
|
||||
if self.get_mode() == COREAPI_MODE:
|
||||
return coreapi.SchemaGenerator
|
||||
return SchemaGenerator
|
@ -0,0 +1,150 @@
|
||||
"""
|
||||
The metadata API is used to allow customization of how `OPTIONS` requests
|
||||
are handled. We currently provide a single default implementation that returns
|
||||
some fairly ad-hoc information about the view.
|
||||
|
||||
Future implementations might use JSON schema or other definitions in order
|
||||
to return this information in a more standardized way.
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.http import Http404
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.utils.field_mapping import ClassLookupDict
|
||||
|
||||
|
||||
class BaseMetadata:
|
||||
def determine_metadata(self, request, view):
|
||||
"""
|
||||
Return a dictionary of metadata about the view.
|
||||
Used to return responses for OPTIONS requests.
|
||||
"""
|
||||
raise NotImplementedError(".determine_metadata() must be overridden.")
|
||||
|
||||
|
||||
class SimpleMetadata(BaseMetadata):
|
||||
"""
|
||||
This is the default metadata implementation.
|
||||
It returns an ad-hoc set of information about the view.
|
||||
There are not any formalized standards for `OPTIONS` responses
|
||||
for us to base this on.
|
||||
"""
|
||||
label_lookup = ClassLookupDict({
|
||||
serializers.Field: 'field',
|
||||
serializers.BooleanField: 'boolean',
|
||||
serializers.CharField: 'string',
|
||||
serializers.UUIDField: 'string',
|
||||
serializers.URLField: 'url',
|
||||
serializers.EmailField: 'email',
|
||||
serializers.RegexField: 'regex',
|
||||
serializers.SlugField: 'slug',
|
||||
serializers.IntegerField: 'integer',
|
||||
serializers.FloatField: 'float',
|
||||
serializers.DecimalField: 'decimal',
|
||||
serializers.DateField: 'date',
|
||||
serializers.DateTimeField: 'datetime',
|
||||
serializers.TimeField: 'time',
|
||||
serializers.ChoiceField: 'choice',
|
||||
serializers.MultipleChoiceField: 'multiple choice',
|
||||
serializers.FileField: 'file upload',
|
||||
serializers.ImageField: 'image upload',
|
||||
serializers.ListField: 'list',
|
||||
serializers.DictField: 'nested object',
|
||||
serializers.Serializer: 'nested object',
|
||||
})
|
||||
|
||||
def determine_metadata(self, request, view):
|
||||
metadata = OrderedDict()
|
||||
metadata['name'] = view.get_view_name()
|
||||
metadata['description'] = view.get_view_description()
|
||||
metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes]
|
||||
metadata['parses'] = [parser.media_type for parser in view.parser_classes]
|
||||
if hasattr(view, 'get_serializer'):
|
||||
actions = self.determine_actions(request, view)
|
||||
if actions:
|
||||
metadata['actions'] = actions
|
||||
return metadata
|
||||
|
||||
def determine_actions(self, request, view):
|
||||
"""
|
||||
For generic class based views we return information about
|
||||
the fields that are accepted for 'PUT' and 'POST' methods.
|
||||
"""
|
||||
actions = {}
|
||||
for method in {'PUT', 'POST'} & set(view.allowed_methods):
|
||||
view.request = clone_request(request, method)
|
||||
try:
|
||||
# Test global permissions
|
||||
if hasattr(view, 'check_permissions'):
|
||||
view.check_permissions(view.request)
|
||||
# Test object permissions
|
||||
if method == 'PUT' and hasattr(view, 'get_object'):
|
||||
view.get_object()
|
||||
except (exceptions.APIException, PermissionDenied, Http404):
|
||||
pass
|
||||
else:
|
||||
# If user has appropriate permissions for the view, include
|
||||
# appropriate metadata about the fields that should be supplied.
|
||||
serializer = view.get_serializer()
|
||||
actions[method] = self.get_serializer_info(serializer)
|
||||
finally:
|
||||
view.request = request
|
||||
|
||||
return actions
|
||||
|
||||
def get_serializer_info(self, serializer):
|
||||
"""
|
||||
Given an instance of a serializer, return a dictionary of metadata
|
||||
about its fields.
|
||||
"""
|
||||
if hasattr(serializer, 'child'):
|
||||
# If this is a `ListSerializer` then we want to examine the
|
||||
# underlying child serializer instance instead.
|
||||
serializer = serializer.child
|
||||
return OrderedDict([
|
||||
(field_name, self.get_field_info(field))
|
||||
for field_name, field in serializer.fields.items()
|
||||
if not isinstance(field, serializers.HiddenField)
|
||||
])
|
||||
|
||||
def get_field_info(self, field):
|
||||
"""
|
||||
Given an instance of a serializer field, return a dictionary
|
||||
of metadata about it.
|
||||
"""
|
||||
field_info = OrderedDict()
|
||||
field_info['type'] = self.label_lookup[field]
|
||||
field_info['required'] = getattr(field, 'required', False)
|
||||
|
||||
attrs = [
|
||||
'read_only', 'label', 'help_text',
|
||||
'min_length', 'max_length',
|
||||
'min_value', 'max_value'
|
||||
]
|
||||
|
||||
for attr in attrs:
|
||||
value = getattr(field, attr, None)
|
||||
if value is not None and value != '':
|
||||
field_info[attr] = force_str(value, strings_only=True)
|
||||
|
||||
if getattr(field, 'child', None):
|
||||
field_info['child'] = self.get_field_info(field.child)
|
||||
elif getattr(field, 'fields', None):
|
||||
field_info['children'] = self.get_serializer_info(field)
|
||||
|
||||
if (not field_info.get('read_only') and
|
||||
not isinstance(field, (serializers.RelatedField, serializers.ManyRelatedField)) and
|
||||
hasattr(field, 'choices')):
|
||||
field_info['choices'] = [
|
||||
{
|
||||
'value': choice_value,
|
||||
'display_name': force_str(choice_name, strings_only=True)
|
||||
}
|
||||
for choice_value, choice_name in field.choices.items()
|
||||
]
|
||||
|
||||
return field_info
|
@ -0,0 +1,95 @@
|
||||
"""
|
||||
Basic building blocks for generic class based views.
|
||||
|
||||
We don't bind behaviour to http method handlers yet,
|
||||
which allows mixin classes to be composed in interesting ways.
|
||||
"""
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
class CreateModelMixin:
|
||||
"""
|
||||
Create a model instance.
|
||||
"""
|
||||
def create(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_create(serializer)
|
||||
headers = self.get_success_headers(serializer.data)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save()
|
||||
|
||||
def get_success_headers(self, data):
|
||||
try:
|
||||
return {'Location': str(data[api_settings.URL_FIELD_NAME])}
|
||||
except (TypeError, KeyError):
|
||||
return {}
|
||||
|
||||
|
||||
class ListModelMixin:
|
||||
"""
|
||||
List a queryset.
|
||||
"""
|
||||
def list(self, request, *args, **kwargs):
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
page = self.paginate_queryset(queryset)
|
||||
if page is not None:
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class RetrieveModelMixin:
|
||||
"""
|
||||
Retrieve a model instance.
|
||||
"""
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
instance = self.get_object()
|
||||
serializer = self.get_serializer(instance)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class UpdateModelMixin:
|
||||
"""
|
||||
Update a model instance.
|
||||
"""
|
||||
def update(self, request, *args, **kwargs):
|
||||
partial = kwargs.pop('partial', False)
|
||||
instance = self.get_object()
|
||||
serializer = self.get_serializer(instance, data=request.data, partial=partial)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_update(serializer)
|
||||
|
||||
if getattr(instance, '_prefetched_objects_cache', None):
|
||||
# If 'prefetch_related' has been applied to a queryset, we need to
|
||||
# forcibly invalidate the prefetch cache on the instance.
|
||||
instance._prefetched_objects_cache = {}
|
||||
|
||||
return Response(serializer.data)
|
||||
|
||||
def perform_update(self, serializer):
|
||||
serializer.save()
|
||||
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
kwargs['partial'] = True
|
||||
return self.update(request, *args, **kwargs)
|
||||
|
||||
|
||||
class DestroyModelMixin:
|
||||
"""
|
||||
Destroy a model instance.
|
||||
"""
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
instance = self.get_object()
|
||||
self.perform_destroy(instance)
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
def perform_destroy(self, instance):
|
||||
instance.delete()
|
@ -0,0 +1,97 @@
|
||||
"""
|
||||
Content negotiation deals with selecting an appropriate renderer given the
|
||||
incoming request. Typically this will be based on the request's Accept header.
|
||||
"""
|
||||
from django.http import Http404
|
||||
|
||||
from rest_framework import exceptions
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils.mediatypes import (
|
||||
_MediaType, media_type_matches, order_by_precedence
|
||||
)
|
||||
|
||||
|
||||
class BaseContentNegotiation:
|
||||
def select_parser(self, request, parsers):
|
||||
raise NotImplementedError('.select_parser() must be implemented')
|
||||
|
||||
def select_renderer(self, request, renderers, format_suffix=None):
|
||||
raise NotImplementedError('.select_renderer() must be implemented')
|
||||
|
||||
|
||||
class DefaultContentNegotiation(BaseContentNegotiation):
|
||||
settings = api_settings
|
||||
|
||||
def select_parser(self, request, parsers):
|
||||
"""
|
||||
Given a list of parsers and a media type, return the appropriate
|
||||
parser to handle the incoming request.
|
||||
"""
|
||||
for parser in parsers:
|
||||
if media_type_matches(parser.media_type, request.content_type):
|
||||
return parser
|
||||
return None
|
||||
|
||||
def select_renderer(self, request, renderers, format_suffix=None):
|
||||
"""
|
||||
Given a request and a list of renderers, return a two-tuple of:
|
||||
(renderer, media type).
|
||||
"""
|
||||
# Allow URL style format override. eg. "?format=json
|
||||
format_query_param = self.settings.URL_FORMAT_OVERRIDE
|
||||
format = format_suffix or request.query_params.get(format_query_param)
|
||||
|
||||
if format:
|
||||
renderers = self.filter_renderers(renderers, format)
|
||||
|
||||
accepts = self.get_accept_list(request)
|
||||
|
||||
# Check the acceptable media types against each renderer,
|
||||
# attempting more specific media types first
|
||||
# NB. The inner loop here isn't as bad as it first looks :)
|
||||
# Worst case is we're looping over len(accept_list) * len(self.renderers)
|
||||
for media_type_set in order_by_precedence(accepts):
|
||||
for renderer in renderers:
|
||||
for media_type in media_type_set:
|
||||
if media_type_matches(renderer.media_type, media_type):
|
||||
# Return the most specific media type as accepted.
|
||||
media_type_wrapper = _MediaType(media_type)
|
||||
if (
|
||||
_MediaType(renderer.media_type).precedence >
|
||||
media_type_wrapper.precedence
|
||||
):
|
||||
# Eg client requests '*/*'
|
||||
# Accepted media type is 'application/json'
|
||||
full_media_type = ';'.join(
|
||||
(renderer.media_type,) +
|
||||
tuple(
|
||||
'{}={}'.format(key, value)
|
||||
for key, value in media_type_wrapper.params.items()
|
||||
)
|
||||
)
|
||||
return renderer, full_media_type
|
||||
else:
|
||||
# Eg client requests 'application/json; indent=8'
|
||||
# Accepted media type is 'application/json; indent=8'
|
||||
return renderer, media_type
|
||||
|
||||
raise exceptions.NotAcceptable(available_renderers=renderers)
|
||||
|
||||
def filter_renderers(self, renderers, format):
|
||||
"""
|
||||
If there is a '.json' style format suffix, filter the renderers
|
||||
so that we only negotiation against those that accept that format.
|
||||
"""
|
||||
renderers = [renderer for renderer in renderers
|
||||
if renderer.format == format]
|
||||
if not renderers:
|
||||
raise Http404
|
||||
return renderers
|
||||
|
||||
def get_accept_list(self, request):
|
||||
"""
|
||||
Given the incoming request, return a tokenized list of media
|
||||
type strings.
|
||||
"""
|
||||
header = request.META.get('HTTP_ACCEPT', '*/*')
|
||||
return [token.strip() for token in header.split(',')]
|
@ -0,0 +1,980 @@
|
||||
"""
|
||||
Pagination serializers determine the structure of the output that should
|
||||
be used for paginated responses.
|
||||
"""
|
||||
from base64 import b64decode, b64encode
|
||||
from collections import OrderedDict, namedtuple
|
||||
from urllib import parse
|
||||
|
||||
from django.core.paginator import InvalidPage
|
||||
from django.core.paginator import Paginator as DjangoPaginator
|
||||
from django.template import loader
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework.compat import coreapi, coreschema
|
||||
from rest_framework.exceptions import NotFound
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils.urls import remove_query_param, replace_query_param
|
||||
|
||||
|
||||
def _positive_int(integer_string, strict=False, cutoff=None):
|
||||
"""
|
||||
Cast a string to a strictly positive integer.
|
||||
"""
|
||||
ret = int(integer_string)
|
||||
if ret < 0 or (ret == 0 and strict):
|
||||
raise ValueError()
|
||||
if cutoff:
|
||||
return min(ret, cutoff)
|
||||
return ret
|
||||
|
||||
|
||||
def _divide_with_ceil(a, b):
|
||||
"""
|
||||
Returns 'a' divided by 'b', with any remainder rounded up.
|
||||
"""
|
||||
if a % b:
|
||||
return (a // b) + 1
|
||||
|
||||
return a // b
|
||||
|
||||
|
||||
def _get_displayed_page_numbers(current, final):
|
||||
"""
|
||||
This utility function determines a list of page numbers to display.
|
||||
This gives us a nice contextually relevant set of page numbers.
|
||||
|
||||
For example:
|
||||
current=14, final=16 -> [1, None, 13, 14, 15, 16]
|
||||
|
||||
This implementation gives one page to each side of the cursor,
|
||||
or two pages to the side when the cursor is at the edge, then
|
||||
ensures that any breaks between non-continuous page numbers never
|
||||
remove only a single page.
|
||||
|
||||
For an alternative implementation which gives two pages to each side of
|
||||
the cursor, eg. as in GitHub issue list pagination, see:
|
||||
|
||||
https://gist.github.com/tomchristie/321140cebb1c4a558b15
|
||||
"""
|
||||
assert current >= 1
|
||||
assert final >= current
|
||||
|
||||
if final <= 5:
|
||||
return list(range(1, final + 1))
|
||||
|
||||
# We always include the first two pages, last two pages, and
|
||||
# two pages either side of the current page.
|
||||
included = {1, current - 1, current, current + 1, final}
|
||||
|
||||
# If the break would only exclude a single page number then we
|
||||
# may as well include the page number instead of the break.
|
||||
if current <= 4:
|
||||
included.add(2)
|
||||
included.add(3)
|
||||
if current >= final - 3:
|
||||
included.add(final - 1)
|
||||
included.add(final - 2)
|
||||
|
||||
# Now sort the page numbers and drop anything outside the limits.
|
||||
included = [
|
||||
idx for idx in sorted(included)
|
||||
if 0 < idx <= final
|
||||
]
|
||||
|
||||
# Finally insert any `...` breaks
|
||||
if current > 4:
|
||||
included.insert(1, None)
|
||||
if current < final - 3:
|
||||
included.insert(len(included) - 1, None)
|
||||
return included
|
||||
|
||||
|
||||
def _get_page_links(page_numbers, current, url_func):
|
||||
"""
|
||||
Given a list of page numbers and `None` page breaks,
|
||||
return a list of `PageLink` objects.
|
||||
"""
|
||||
page_links = []
|
||||
for page_number in page_numbers:
|
||||
if page_number is None:
|
||||
page_link = PAGE_BREAK
|
||||
else:
|
||||
page_link = PageLink(
|
||||
url=url_func(page_number),
|
||||
number=page_number,
|
||||
is_active=(page_number == current),
|
||||
is_break=False
|
||||
)
|
||||
page_links.append(page_link)
|
||||
return page_links
|
||||
|
||||
|
||||
def _reverse_ordering(ordering_tuple):
|
||||
"""
|
||||
Given an order_by tuple such as `('-created', 'uuid')` reverse the
|
||||
ordering and return a new tuple, eg. `('created', '-uuid')`.
|
||||
"""
|
||||
def invert(x):
|
||||
return x[1:] if x.startswith('-') else '-' + x
|
||||
|
||||
return tuple([invert(item) for item in ordering_tuple])
|
||||
|
||||
|
||||
Cursor = namedtuple('Cursor', ['offset', 'reverse', 'position'])
|
||||
PageLink = namedtuple('PageLink', ['url', 'number', 'is_active', 'is_break'])
|
||||
|
||||
PAGE_BREAK = PageLink(url=None, number=None, is_active=False, is_break=True)
|
||||
|
||||
|
||||
class BasePagination:
|
||||
display_page_controls = False
|
||||
|
||||
def paginate_queryset(self, queryset, request, view=None): # pragma: no cover
|
||||
raise NotImplementedError('paginate_queryset() must be implemented.')
|
||||
|
||||
def get_paginated_response(self, data): # pragma: no cover
|
||||
raise NotImplementedError('get_paginated_response() must be implemented.')
|
||||
|
||||
def get_paginated_response_schema(self, schema):
|
||||
return schema
|
||||
|
||||
def to_html(self): # pragma: no cover
|
||||
raise NotImplementedError('to_html() must be implemented to display page controls.')
|
||||
|
||||
def get_results(self, data):
|
||||
return data['results']
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
return []
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return []
|
||||
|
||||
|
||||
class PageNumberPagination(BasePagination):
|
||||
"""
|
||||
A simple page number based style that supports page numbers as
|
||||
query parameters. For example:
|
||||
|
||||
http://api.example.org/accounts/?page=4
|
||||
http://api.example.org/accounts/?page=4&page_size=100
|
||||
"""
|
||||
# The default page size.
|
||||
# Defaults to `None`, meaning pagination is disabled.
|
||||
page_size = api_settings.PAGE_SIZE
|
||||
|
||||
django_paginator_class = DjangoPaginator
|
||||
|
||||
# Client can control the page using this query parameter.
|
||||
page_query_param = 'page'
|
||||
page_query_description = _('A page number within the paginated result set.')
|
||||
|
||||
# Client can control the page size using this query parameter.
|
||||
# Default is 'None'. Set to eg 'page_size' to enable usage.
|
||||
page_size_query_param = None
|
||||
page_size_query_description = _('Number of results to return per page.')
|
||||
|
||||
# Set to an integer to limit the maximum page size the client may request.
|
||||
# Only relevant if 'page_size_query_param' has also been set.
|
||||
max_page_size = None
|
||||
|
||||
last_page_strings = ('last',)
|
||||
|
||||
template = 'rest_framework/pagination/numbers.html'
|
||||
|
||||
invalid_page_message = _('Invalid page.')
|
||||
|
||||
def paginate_queryset(self, queryset, request, view=None):
|
||||
"""
|
||||
Paginate a queryset if required, either returning a
|
||||
page object, or `None` if pagination is not configured for this view.
|
||||
"""
|
||||
page_size = self.get_page_size(request)
|
||||
if not page_size:
|
||||
return None
|
||||
|
||||
paginator = self.django_paginator_class(queryset, page_size)
|
||||
page_number = self.get_page_number(request, paginator)
|
||||
|
||||
try:
|
||||
self.page = paginator.page(page_number)
|
||||
except InvalidPage as exc:
|
||||
msg = self.invalid_page_message.format(
|
||||
page_number=page_number, message=str(exc)
|
||||
)
|
||||
raise NotFound(msg)
|
||||
|
||||
if paginator.num_pages > 1 and self.template is not None:
|
||||
# The browsable API should display pagination controls.
|
||||
self.display_page_controls = True
|
||||
|
||||
self.request = request
|
||||
return list(self.page)
|
||||
|
||||
def get_page_number(self, request, paginator):
|
||||
page_number = request.query_params.get(self.page_query_param, 1)
|
||||
if page_number in self.last_page_strings:
|
||||
page_number = paginator.num_pages
|
||||
return page_number
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
return Response(OrderedDict([
|
||||
('count', self.page.paginator.count),
|
||||
('next', self.get_next_link()),
|
||||
('previous', self.get_previous_link()),
|
||||
('results', data)
|
||||
]))
|
||||
|
||||
def get_paginated_response_schema(self, schema):
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'count': {
|
||||
'type': 'integer',
|
||||
'example': 123,
|
||||
},
|
||||
'next': {
|
||||
'type': 'string',
|
||||
'nullable': True,
|
||||
'format': 'uri',
|
||||
'example': 'http://api.example.org/accounts/?{page_query_param}=4'.format(
|
||||
page_query_param=self.page_query_param)
|
||||
},
|
||||
'previous': {
|
||||
'type': 'string',
|
||||
'nullable': True,
|
||||
'format': 'uri',
|
||||
'example': 'http://api.example.org/accounts/?{page_query_param}=2'.format(
|
||||
page_query_param=self.page_query_param)
|
||||
},
|
||||
'results': schema,
|
||||
},
|
||||
}
|
||||
|
||||
def get_page_size(self, request):
|
||||
if self.page_size_query_param:
|
||||
try:
|
||||
return _positive_int(
|
||||
request.query_params[self.page_size_query_param],
|
||||
strict=True,
|
||||
cutoff=self.max_page_size
|
||||
)
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
|
||||
return self.page_size
|
||||
|
||||
def get_next_link(self):
|
||||
if not self.page.has_next():
|
||||
return None
|
||||
url = self.request.build_absolute_uri()
|
||||
page_number = self.page.next_page_number()
|
||||
return replace_query_param(url, self.page_query_param, page_number)
|
||||
|
||||
def get_previous_link(self):
|
||||
if not self.page.has_previous():
|
||||
return None
|
||||
url = self.request.build_absolute_uri()
|
||||
page_number = self.page.previous_page_number()
|
||||
if page_number == 1:
|
||||
return remove_query_param(url, self.page_query_param)
|
||||
return replace_query_param(url, self.page_query_param, page_number)
|
||||
|
||||
def get_html_context(self):
|
||||
base_url = self.request.build_absolute_uri()
|
||||
|
||||
def page_number_to_url(page_number):
|
||||
if page_number == 1:
|
||||
return remove_query_param(base_url, self.page_query_param)
|
||||
else:
|
||||
return replace_query_param(base_url, self.page_query_param, page_number)
|
||||
|
||||
current = self.page.number
|
||||
final = self.page.paginator.num_pages
|
||||
page_numbers = _get_displayed_page_numbers(current, final)
|
||||
page_links = _get_page_links(page_numbers, current, page_number_to_url)
|
||||
|
||||
return {
|
||||
'previous_url': self.get_previous_link(),
|
||||
'next_url': self.get_next_link(),
|
||||
'page_links': page_links
|
||||
}
|
||||
|
||||
def to_html(self):
|
||||
template = loader.get_template(self.template)
|
||||
context = self.get_html_context()
|
||||
return template.render(context)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
fields = [
|
||||
coreapi.Field(
|
||||
name=self.page_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.Integer(
|
||||
title='Page',
|
||||
description=force_str(self.page_query_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
if self.page_size_query_param is not None:
|
||||
fields.append(
|
||||
coreapi.Field(
|
||||
name=self.page_size_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.Integer(
|
||||
title='Page size',
|
||||
description=force_str(self.page_size_query_description)
|
||||
)
|
||||
)
|
||||
)
|
||||
return fields
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
'name': self.page_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_str(self.page_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
]
|
||||
if self.page_size_query_param is not None:
|
||||
parameters.append(
|
||||
{
|
||||
'name': self.page_size_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_str(self.page_size_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
)
|
||||
return parameters
|
||||
|
||||
|
||||
class LimitOffsetPagination(BasePagination):
|
||||
"""
|
||||
A limit/offset based style. For example:
|
||||
|
||||
http://api.example.org/accounts/?limit=100
|
||||
http://api.example.org/accounts/?offset=400&limit=100
|
||||
"""
|
||||
default_limit = api_settings.PAGE_SIZE
|
||||
limit_query_param = 'limit'
|
||||
limit_query_description = _('Number of results to return per page.')
|
||||
offset_query_param = 'offset'
|
||||
offset_query_description = _('The initial index from which to return the results.')
|
||||
max_limit = None
|
||||
template = 'rest_framework/pagination/numbers.html'
|
||||
|
||||
def paginate_queryset(self, queryset, request, view=None):
|
||||
self.limit = self.get_limit(request)
|
||||
if self.limit is None:
|
||||
return None
|
||||
|
||||
self.count = self.get_count(queryset)
|
||||
self.offset = self.get_offset(request)
|
||||
self.request = request
|
||||
if self.count > self.limit and self.template is not None:
|
||||
self.display_page_controls = True
|
||||
|
||||
if self.count == 0 or self.offset > self.count:
|
||||
return []
|
||||
return list(queryset[self.offset:self.offset + self.limit])
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
return Response(OrderedDict([
|
||||
('count', self.count),
|
||||
('next', self.get_next_link()),
|
||||
('previous', self.get_previous_link()),
|
||||
('results', data)
|
||||
]))
|
||||
|
||||
def get_paginated_response_schema(self, schema):
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'count': {
|
||||
'type': 'integer',
|
||||
'example': 123,
|
||||
},
|
||||
'next': {
|
||||
'type': 'string',
|
||||
'nullable': True,
|
||||
'format': 'uri',
|
||||
'example': 'http://api.example.org/accounts/?{offset_param}=400&{limit_param}=100'.format(
|
||||
offset_param=self.offset_query_param, limit_param=self.limit_query_param),
|
||||
},
|
||||
'previous': {
|
||||
'type': 'string',
|
||||
'nullable': True,
|
||||
'format': 'uri',
|
||||
'example': 'http://api.example.org/accounts/?{offset_param}=200&{limit_param}=100'.format(
|
||||
offset_param=self.offset_query_param, limit_param=self.limit_query_param),
|
||||
},
|
||||
'results': schema,
|
||||
},
|
||||
}
|
||||
|
||||
def get_limit(self, request):
|
||||
if self.limit_query_param:
|
||||
try:
|
||||
return _positive_int(
|
||||
request.query_params[self.limit_query_param],
|
||||
strict=True,
|
||||
cutoff=self.max_limit
|
||||
)
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
|
||||
return self.default_limit
|
||||
|
||||
def get_offset(self, request):
|
||||
try:
|
||||
return _positive_int(
|
||||
request.query_params[self.offset_query_param],
|
||||
)
|
||||
except (KeyError, ValueError):
|
||||
return 0
|
||||
|
||||
def get_next_link(self):
|
||||
if self.offset + self.limit >= self.count:
|
||||
return None
|
||||
|
||||
url = self.request.build_absolute_uri()
|
||||
url = replace_query_param(url, self.limit_query_param, self.limit)
|
||||
|
||||
offset = self.offset + self.limit
|
||||
return replace_query_param(url, self.offset_query_param, offset)
|
||||
|
||||
def get_previous_link(self):
|
||||
if self.offset <= 0:
|
||||
return None
|
||||
|
||||
url = self.request.build_absolute_uri()
|
||||
url = replace_query_param(url, self.limit_query_param, self.limit)
|
||||
|
||||
if self.offset - self.limit <= 0:
|
||||
return remove_query_param(url, self.offset_query_param)
|
||||
|
||||
offset = self.offset - self.limit
|
||||
return replace_query_param(url, self.offset_query_param, offset)
|
||||
|
||||
def get_html_context(self):
|
||||
base_url = self.request.build_absolute_uri()
|
||||
|
||||
if self.limit:
|
||||
current = _divide_with_ceil(self.offset, self.limit) + 1
|
||||
|
||||
# The number of pages is a little bit fiddly.
|
||||
# We need to sum both the number of pages from current offset to end
|
||||
# plus the number of pages up to the current offset.
|
||||
# When offset is not strictly divisible by the limit then we may
|
||||
# end up introducing an extra page as an artifact.
|
||||
final = (
|
||||
_divide_with_ceil(self.count - self.offset, self.limit) +
|
||||
_divide_with_ceil(self.offset, self.limit)
|
||||
)
|
||||
|
||||
final = max(final, 1)
|
||||
else:
|
||||
current = 1
|
||||
final = 1
|
||||
|
||||
if current > final:
|
||||
current = final
|
||||
|
||||
def page_number_to_url(page_number):
|
||||
if page_number == 1:
|
||||
return remove_query_param(base_url, self.offset_query_param)
|
||||
else:
|
||||
offset = self.offset + ((page_number - current) * self.limit)
|
||||
return replace_query_param(base_url, self.offset_query_param, offset)
|
||||
|
||||
page_numbers = _get_displayed_page_numbers(current, final)
|
||||
page_links = _get_page_links(page_numbers, current, page_number_to_url)
|
||||
|
||||
return {
|
||||
'previous_url': self.get_previous_link(),
|
||||
'next_url': self.get_next_link(),
|
||||
'page_links': page_links
|
||||
}
|
||||
|
||||
def to_html(self):
|
||||
template = loader.get_template(self.template)
|
||||
context = self.get_html_context()
|
||||
return template.render(context)
|
||||
|
||||
def get_count(self, queryset):
|
||||
"""
|
||||
Determine an object count, supporting either querysets or regular lists.
|
||||
"""
|
||||
try:
|
||||
return queryset.count()
|
||||
except (AttributeError, TypeError):
|
||||
return len(queryset)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return [
|
||||
coreapi.Field(
|
||||
name=self.limit_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.Integer(
|
||||
title='Limit',
|
||||
description=force_str(self.limit_query_description)
|
||||
)
|
||||
),
|
||||
coreapi.Field(
|
||||
name=self.offset_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.Integer(
|
||||
title='Offset',
|
||||
description=force_str(self.offset_query_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
'name': self.limit_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_str(self.limit_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
{
|
||||
'name': self.offset_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_str(self.offset_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
]
|
||||
return parameters
|
||||
|
||||
|
||||
class CursorPagination(BasePagination):
|
||||
"""
|
||||
The cursor pagination implementation is necessarily complex.
|
||||
For an overview of the position/offset style we use, see this post:
|
||||
https://cra.mr/2011/03/08/building-cursors-for-the-disqus-api
|
||||
"""
|
||||
cursor_query_param = 'cursor'
|
||||
cursor_query_description = _('The pagination cursor value.')
|
||||
page_size = api_settings.PAGE_SIZE
|
||||
invalid_cursor_message = _('Invalid cursor')
|
||||
ordering = '-created'
|
||||
template = 'rest_framework/pagination/previous_and_next.html'
|
||||
|
||||
# Client can control the page size using this query parameter.
|
||||
# Default is 'None'. Set to eg 'page_size' to enable usage.
|
||||
page_size_query_param = None
|
||||
page_size_query_description = _('Number of results to return per page.')
|
||||
|
||||
# Set to an integer to limit the maximum page size the client may request.
|
||||
# Only relevant if 'page_size_query_param' has also been set.
|
||||
max_page_size = None
|
||||
|
||||
# The offset in the cursor is used in situations where we have a
|
||||
# nearly-unique index. (Eg millisecond precision creation timestamps)
|
||||
# We guard against malicious users attempting to cause expensive database
|
||||
# queries, by having a hard cap on the maximum possible size of the offset.
|
||||
offset_cutoff = 1000
|
||||
|
||||
def paginate_queryset(self, queryset, request, view=None):
|
||||
self.page_size = self.get_page_size(request)
|
||||
if not self.page_size:
|
||||
return None
|
||||
|
||||
self.base_url = request.build_absolute_uri()
|
||||
self.ordering = self.get_ordering(request, queryset, view)
|
||||
|
||||
self.cursor = self.decode_cursor(request)
|
||||
if self.cursor is None:
|
||||
(offset, reverse, current_position) = (0, False, None)
|
||||
else:
|
||||
(offset, reverse, current_position) = self.cursor
|
||||
|
||||
# Cursor pagination always enforces an ordering.
|
||||
if reverse:
|
||||
queryset = queryset.order_by(*_reverse_ordering(self.ordering))
|
||||
else:
|
||||
queryset = queryset.order_by(*self.ordering)
|
||||
|
||||
# If we have a cursor with a fixed position then filter by that.
|
||||
if current_position is not None:
|
||||
order = self.ordering[0]
|
||||
is_reversed = order.startswith('-')
|
||||
order_attr = order.lstrip('-')
|
||||
|
||||
# Test for: (cursor reversed) XOR (queryset reversed)
|
||||
if self.cursor.reverse != is_reversed:
|
||||
kwargs = {order_attr + '__lt': current_position}
|
||||
else:
|
||||
kwargs = {order_attr + '__gt': current_position}
|
||||
|
||||
queryset = queryset.filter(**kwargs)
|
||||
|
||||
# If we have an offset cursor then offset the entire page by that amount.
|
||||
# We also always fetch an extra item in order to determine if there is a
|
||||
# page following on from this one.
|
||||
results = list(queryset[offset:offset + self.page_size + 1])
|
||||
self.page = list(results[:self.page_size])
|
||||
|
||||
# Determine the position of the final item following the page.
|
||||
if len(results) > len(self.page):
|
||||
has_following_position = True
|
||||
following_position = self._get_position_from_instance(results[-1], self.ordering)
|
||||
else:
|
||||
has_following_position = False
|
||||
following_position = None
|
||||
|
||||
if reverse:
|
||||
# If we have a reverse queryset, then the query ordering was in reverse
|
||||
# so we need to reverse the items again before returning them to the user.
|
||||
self.page = list(reversed(self.page))
|
||||
|
||||
# Determine next and previous positions for reverse cursors.
|
||||
self.has_next = (current_position is not None) or (offset > 0)
|
||||
self.has_previous = has_following_position
|
||||
if self.has_next:
|
||||
self.next_position = current_position
|
||||
if self.has_previous:
|
||||
self.previous_position = following_position
|
||||
else:
|
||||
# Determine next and previous positions for forward cursors.
|
||||
self.has_next = has_following_position
|
||||
self.has_previous = (current_position is not None) or (offset > 0)
|
||||
if self.has_next:
|
||||
self.next_position = following_position
|
||||
if self.has_previous:
|
||||
self.previous_position = current_position
|
||||
|
||||
# Display page controls in the browsable API if there is more
|
||||
# than one page.
|
||||
if (self.has_previous or self.has_next) and self.template is not None:
|
||||
self.display_page_controls = True
|
||||
|
||||
return self.page
|
||||
|
||||
def get_page_size(self, request):
|
||||
if self.page_size_query_param:
|
||||
try:
|
||||
return _positive_int(
|
||||
request.query_params[self.page_size_query_param],
|
||||
strict=True,
|
||||
cutoff=self.max_page_size
|
||||
)
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
|
||||
return self.page_size
|
||||
|
||||
def get_next_link(self):
|
||||
if not self.has_next:
|
||||
return None
|
||||
|
||||
if self.page and self.cursor and self.cursor.reverse and self.cursor.offset != 0:
|
||||
# If we're reversing direction and we have an offset cursor
|
||||
# then we cannot use the first position we find as a marker.
|
||||
compare = self._get_position_from_instance(self.page[-1], self.ordering)
|
||||
else:
|
||||
compare = self.next_position
|
||||
offset = 0
|
||||
|
||||
has_item_with_unique_position = False
|
||||
for item in reversed(self.page):
|
||||
position = self._get_position_from_instance(item, self.ordering)
|
||||
if position != compare:
|
||||
# The item in this position and the item following it
|
||||
# have different positions. We can use this position as
|
||||
# our marker.
|
||||
has_item_with_unique_position = True
|
||||
break
|
||||
|
||||
# The item in this position has the same position as the item
|
||||
# following it, we can't use it as a marker position, so increment
|
||||
# the offset and keep seeking to the previous item.
|
||||
compare = position
|
||||
offset += 1
|
||||
|
||||
if self.page and not has_item_with_unique_position:
|
||||
# There were no unique positions in the page.
|
||||
if not self.has_previous:
|
||||
# We are on the first page.
|
||||
# Our cursor will have an offset equal to the page size,
|
||||
# but no position to filter against yet.
|
||||
offset = self.page_size
|
||||
position = None
|
||||
elif self.cursor.reverse:
|
||||
# The change in direction will introduce a paging artifact,
|
||||
# where we end up skipping forward a few extra items.
|
||||
offset = 0
|
||||
position = self.previous_position
|
||||
else:
|
||||
# Use the position from the existing cursor and increment
|
||||
# it's offset by the page size.
|
||||
offset = self.cursor.offset + self.page_size
|
||||
position = self.previous_position
|
||||
|
||||
if not self.page:
|
||||
position = self.next_position
|
||||
|
||||
cursor = Cursor(offset=offset, reverse=False, position=position)
|
||||
return self.encode_cursor(cursor)
|
||||
|
||||
def get_previous_link(self):
|
||||
if not self.has_previous:
|
||||
return None
|
||||
|
||||
if self.page and self.cursor and not self.cursor.reverse and self.cursor.offset != 0:
|
||||
# If we're reversing direction and we have an offset cursor
|
||||
# then we cannot use the first position we find as a marker.
|
||||
compare = self._get_position_from_instance(self.page[0], self.ordering)
|
||||
else:
|
||||
compare = self.previous_position
|
||||
offset = 0
|
||||
|
||||
has_item_with_unique_position = False
|
||||
for item in self.page:
|
||||
position = self._get_position_from_instance(item, self.ordering)
|
||||
if position != compare:
|
||||
# The item in this position and the item following it
|
||||
# have different positions. We can use this position as
|
||||
# our marker.
|
||||
has_item_with_unique_position = True
|
||||
break
|
||||
|
||||
# The item in this position has the same position as the item
|
||||
# following it, we can't use it as a marker position, so increment
|
||||
# the offset and keep seeking to the previous item.
|
||||
compare = position
|
||||
offset += 1
|
||||
|
||||
if self.page and not has_item_with_unique_position:
|
||||
# There were no unique positions in the page.
|
||||
if not self.has_next:
|
||||
# We are on the final page.
|
||||
# Our cursor will have an offset equal to the page size,
|
||||
# but no position to filter against yet.
|
||||
offset = self.page_size
|
||||
position = None
|
||||
elif self.cursor.reverse:
|
||||
# Use the position from the existing cursor and increment
|
||||
# it's offset by the page size.
|
||||
offset = self.cursor.offset + self.page_size
|
||||
position = self.next_position
|
||||
else:
|
||||
# The change in direction will introduce a paging artifact,
|
||||
# where we end up skipping back a few extra items.
|
||||
offset = 0
|
||||
position = self.next_position
|
||||
|
||||
if not self.page:
|
||||
position = self.previous_position
|
||||
|
||||
cursor = Cursor(offset=offset, reverse=True, position=position)
|
||||
return self.encode_cursor(cursor)
|
||||
|
||||
def get_ordering(self, request, queryset, view):
|
||||
"""
|
||||
Return a tuple of strings, that may be used in an `order_by` method.
|
||||
"""
|
||||
ordering_filters = [
|
||||
filter_cls for filter_cls in getattr(view, 'filter_backends', [])
|
||||
if hasattr(filter_cls, 'get_ordering')
|
||||
]
|
||||
|
||||
if ordering_filters:
|
||||
# If a filter exists on the view that implements `get_ordering`
|
||||
# then we defer to that filter to determine the ordering.
|
||||
filter_cls = ordering_filters[0]
|
||||
filter_instance = filter_cls()
|
||||
ordering = filter_instance.get_ordering(request, queryset, view)
|
||||
assert ordering is not None, (
|
||||
'Using cursor pagination, but filter class {filter_cls} '
|
||||
'returned a `None` ordering.'.format(
|
||||
filter_cls=filter_cls.__name__
|
||||
)
|
||||
)
|
||||
else:
|
||||
# The default case is to check for an `ordering` attribute
|
||||
# on this pagination instance.
|
||||
ordering = self.ordering
|
||||
assert ordering is not None, (
|
||||
'Using cursor pagination, but no ordering attribute was declared '
|
||||
'on the pagination class.'
|
||||
)
|
||||
assert '__' not in ordering, (
|
||||
'Cursor pagination does not support double underscore lookups '
|
||||
'for orderings. Orderings should be an unchanging, unique or '
|
||||
'nearly-unique field on the model, such as "-created" or "pk".'
|
||||
)
|
||||
|
||||
assert isinstance(ordering, (str, list, tuple)), (
|
||||
'Invalid ordering. Expected string or tuple, but got {type}'.format(
|
||||
type=type(ordering).__name__
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(ordering, str):
|
||||
return (ordering,)
|
||||
return tuple(ordering)
|
||||
|
||||
def decode_cursor(self, request):
|
||||
"""
|
||||
Given a request with a cursor, return a `Cursor` instance.
|
||||
"""
|
||||
# Determine if we have a cursor, and if so then decode it.
|
||||
encoded = request.query_params.get(self.cursor_query_param)
|
||||
if encoded is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
querystring = b64decode(encoded.encode('ascii')).decode('ascii')
|
||||
tokens = parse.parse_qs(querystring, keep_blank_values=True)
|
||||
|
||||
offset = tokens.get('o', ['0'])[0]
|
||||
offset = _positive_int(offset, cutoff=self.offset_cutoff)
|
||||
|
||||
reverse = tokens.get('r', ['0'])[0]
|
||||
reverse = bool(int(reverse))
|
||||
|
||||
position = tokens.get('p', [None])[0]
|
||||
except (TypeError, ValueError):
|
||||
raise NotFound(self.invalid_cursor_message)
|
||||
|
||||
return Cursor(offset=offset, reverse=reverse, position=position)
|
||||
|
||||
def encode_cursor(self, cursor):
|
||||
"""
|
||||
Given a Cursor instance, return an url with encoded cursor.
|
||||
"""
|
||||
tokens = {}
|
||||
if cursor.offset != 0:
|
||||
tokens['o'] = str(cursor.offset)
|
||||
if cursor.reverse:
|
||||
tokens['r'] = '1'
|
||||
if cursor.position is not None:
|
||||
tokens['p'] = cursor.position
|
||||
|
||||
querystring = parse.urlencode(tokens, doseq=True)
|
||||
encoded = b64encode(querystring.encode('ascii')).decode('ascii')
|
||||
return replace_query_param(self.base_url, self.cursor_query_param, encoded)
|
||||
|
||||
def _get_position_from_instance(self, instance, ordering):
|
||||
field_name = ordering[0].lstrip('-')
|
||||
if isinstance(instance, dict):
|
||||
attr = instance[field_name]
|
||||
else:
|
||||
attr = getattr(instance, field_name)
|
||||
return str(attr)
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
return Response(OrderedDict([
|
||||
('next', self.get_next_link()),
|
||||
('previous', self.get_previous_link()),
|
||||
('results', data)
|
||||
]))
|
||||
|
||||
def get_paginated_response_schema(self, schema):
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'next': {
|
||||
'type': 'string',
|
||||
'nullable': True,
|
||||
},
|
||||
'previous': {
|
||||
'type': 'string',
|
||||
'nullable': True,
|
||||
},
|
||||
'results': schema,
|
||||
},
|
||||
}
|
||||
|
||||
def get_html_context(self):
|
||||
return {
|
||||
'previous_url': self.get_previous_link(),
|
||||
'next_url': self.get_next_link()
|
||||
}
|
||||
|
||||
def to_html(self):
|
||||
template = loader.get_template(self.template)
|
||||
context = self.get_html_context()
|
||||
return template.render(context)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
fields = [
|
||||
coreapi.Field(
|
||||
name=self.cursor_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.String(
|
||||
title='Cursor',
|
||||
description=force_str(self.cursor_query_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
if self.page_size_query_param is not None:
|
||||
fields.append(
|
||||
coreapi.Field(
|
||||
name=self.page_size_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.Integer(
|
||||
title='Page size',
|
||||
description=force_str(self.page_size_query_description)
|
||||
)
|
||||
)
|
||||
)
|
||||
return fields
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
'name': self.cursor_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_str(self.cursor_query_description),
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
}
|
||||
]
|
||||
if self.page_size_query_param is not None:
|
||||
parameters.append(
|
||||
{
|
||||
'name': self.page_size_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_str(self.page_size_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
}
|
||||
)
|
||||
return parameters
|
@ -0,0 +1,209 @@
|
||||
"""
|
||||
Parsers are used to parse the content of incoming HTTP requests.
|
||||
|
||||
They give us a generic way of being able to handle various media types
|
||||
on the request, such as form content or json encoded data.
|
||||
"""
|
||||
import codecs
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.files.uploadhandler import StopFutureHandlers
|
||||
from django.http import QueryDict
|
||||
from django.http.multipartparser import ChunkIter
|
||||
from django.http.multipartparser import \
|
||||
MultiPartParser as DjangoMultiPartParser
|
||||
from django.http.multipartparser import MultiPartParserError
|
||||
|
||||
from rest_framework import renderers
|
||||
from rest_framework.compat import parse_header_parameters
|
||||
from rest_framework.exceptions import ParseError
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import json
|
||||
|
||||
|
||||
class DataAndFiles:
|
||||
def __init__(self, data, files):
|
||||
self.data = data
|
||||
self.files = files
|
||||
|
||||
|
||||
class BaseParser:
|
||||
"""
|
||||
All parsers should extend `BaseParser`, specifying a `media_type`
|
||||
attribute, and overriding the `.parse()` method.
|
||||
"""
|
||||
media_type = None
|
||||
|
||||
def parse(self, stream, media_type=None, parser_context=None):
|
||||
"""
|
||||
Given a stream to read from, return the parsed representation.
|
||||
Should return parsed data, or a `DataAndFiles` object consisting of the
|
||||
parsed data and files.
|
||||
"""
|
||||
raise NotImplementedError(".parse() must be overridden.")
|
||||
|
||||
|
||||
class JSONParser(BaseParser):
|
||||
"""
|
||||
Parses JSON-serialized data.
|
||||
"""
|
||||
media_type = 'application/json'
|
||||
renderer_class = renderers.JSONRenderer
|
||||
strict = api_settings.STRICT_JSON
|
||||
|
||||
def parse(self, stream, media_type=None, parser_context=None):
|
||||
"""
|
||||
Parses the incoming bytestream as JSON and returns the resulting data.
|
||||
"""
|
||||
parser_context = parser_context or {}
|
||||
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
|
||||
|
||||
try:
|
||||
decoded_stream = codecs.getreader(encoding)(stream)
|
||||
parse_constant = json.strict_constant if self.strict else None
|
||||
return json.load(decoded_stream, parse_constant=parse_constant)
|
||||
except ValueError as exc:
|
||||
raise ParseError('JSON parse error - %s' % str(exc))
|
||||
|
||||
|
||||
class FormParser(BaseParser):
|
||||
"""
|
||||
Parser for form data.
|
||||
"""
|
||||
media_type = 'application/x-www-form-urlencoded'
|
||||
|
||||
def parse(self, stream, media_type=None, parser_context=None):
|
||||
"""
|
||||
Parses the incoming bytestream as a URL encoded form,
|
||||
and returns the resulting QueryDict.
|
||||
"""
|
||||
parser_context = parser_context or {}
|
||||
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
|
||||
return QueryDict(stream.read(), encoding=encoding)
|
||||
|
||||
|
||||
class MultiPartParser(BaseParser):
|
||||
"""
|
||||
Parser for multipart form data, which may include file data.
|
||||
"""
|
||||
media_type = 'multipart/form-data'
|
||||
|
||||
def parse(self, stream, media_type=None, parser_context=None):
|
||||
"""
|
||||
Parses the incoming bytestream as a multipart encoded form,
|
||||
and returns a DataAndFiles object.
|
||||
|
||||
`.data` will be a `QueryDict` containing all the form parameters.
|
||||
`.files` will be a `QueryDict` containing all the form files.
|
||||
"""
|
||||
parser_context = parser_context or {}
|
||||
request = parser_context['request']
|
||||
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
|
||||
meta = request.META.copy()
|
||||
meta['CONTENT_TYPE'] = media_type
|
||||
upload_handlers = request.upload_handlers
|
||||
|
||||
try:
|
||||
parser = DjangoMultiPartParser(meta, stream, upload_handlers, encoding)
|
||||
data, files = parser.parse()
|
||||
return DataAndFiles(data, files)
|
||||
except MultiPartParserError as exc:
|
||||
raise ParseError('Multipart form parse error - %s' % str(exc))
|
||||
|
||||
|
||||
class FileUploadParser(BaseParser):
|
||||
"""
|
||||
Parser for file upload data.
|
||||
"""
|
||||
media_type = '*/*'
|
||||
errors = {
|
||||
'unhandled': 'FileUpload parse error - none of upload handlers can handle the stream',
|
||||
'no_filename': 'Missing filename. Request should include a Content-Disposition header with a filename parameter.',
|
||||
}
|
||||
|
||||
def parse(self, stream, media_type=None, parser_context=None):
|
||||
"""
|
||||
Treats the incoming bytestream as a raw file upload and returns
|
||||
a `DataAndFiles` object.
|
||||
|
||||
`.data` will be None (we expect request body to be a file content).
|
||||
`.files` will be a `QueryDict` containing one 'file' element.
|
||||
"""
|
||||
parser_context = parser_context or {}
|
||||
request = parser_context['request']
|
||||
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
|
||||
meta = request.META
|
||||
upload_handlers = request.upload_handlers
|
||||
filename = self.get_filename(stream, media_type, parser_context)
|
||||
|
||||
if not filename:
|
||||
raise ParseError(self.errors['no_filename'])
|
||||
|
||||
# Note that this code is extracted from Django's handling of
|
||||
# file uploads in MultiPartParser.
|
||||
content_type = meta.get('HTTP_CONTENT_TYPE',
|
||||
meta.get('CONTENT_TYPE', ''))
|
||||
try:
|
||||
content_length = int(meta.get('HTTP_CONTENT_LENGTH',
|
||||
meta.get('CONTENT_LENGTH', 0)))
|
||||
except (ValueError, TypeError):
|
||||
content_length = None
|
||||
|
||||
# See if the handler will want to take care of the parsing.
|
||||
for handler in upload_handlers:
|
||||
result = handler.handle_raw_input(stream,
|
||||
meta,
|
||||
content_length,
|
||||
None,
|
||||
encoding)
|
||||
if result is not None:
|
||||
return DataAndFiles({}, {'file': result[1]})
|
||||
|
||||
# This is the standard case.
|
||||
possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size]
|
||||
chunk_size = min([2 ** 31 - 4] + possible_sizes)
|
||||
chunks = ChunkIter(stream, chunk_size)
|
||||
counters = [0] * len(upload_handlers)
|
||||
|
||||
for index, handler in enumerate(upload_handlers):
|
||||
try:
|
||||
handler.new_file(None, filename, content_type,
|
||||
content_length, encoding)
|
||||
except StopFutureHandlers:
|
||||
upload_handlers = upload_handlers[:index + 1]
|
||||
break
|
||||
|
||||
for chunk in chunks:
|
||||
for index, handler in enumerate(upload_handlers):
|
||||
chunk_length = len(chunk)
|
||||
chunk = handler.receive_data_chunk(chunk, counters[index])
|
||||
counters[index] += chunk_length
|
||||
if chunk is None:
|
||||
break
|
||||
|
||||
for index, handler in enumerate(upload_handlers):
|
||||
file_obj = handler.file_complete(counters[index])
|
||||
if file_obj is not None:
|
||||
return DataAndFiles({}, {'file': file_obj})
|
||||
|
||||
raise ParseError(self.errors['unhandled'])
|
||||
|
||||
def get_filename(self, stream, media_type, parser_context):
|
||||
"""
|
||||
Detects the uploaded file name. First searches a 'filename' url kwarg.
|
||||
Then tries to parse Content-Disposition header.
|
||||
"""
|
||||
try:
|
||||
return parser_context['kwargs']['filename']
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
meta = parser_context['request'].META
|
||||
disposition, params = parse_header_parameters(meta['HTTP_CONTENT_DISPOSITION'])
|
||||
if 'filename*' in params:
|
||||
return params['filename*']
|
||||
else:
|
||||
return params['filename']
|
||||
except (AttributeError, KeyError, ValueError):
|
||||
pass
|
@ -0,0 +1,303 @@
|
||||
"""
|
||||
Provides a set of pluggable permission policies.
|
||||
"""
|
||||
from django.http import Http404
|
||||
|
||||
from rest_framework import exceptions
|
||||
|
||||
SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')
|
||||
|
||||
|
||||
class OperationHolderMixin:
|
||||
def __and__(self, other):
|
||||
return OperandHolder(AND, self, other)
|
||||
|
||||
def __or__(self, other):
|
||||
return OperandHolder(OR, self, other)
|
||||
|
||||
def __rand__(self, other):
|
||||
return OperandHolder(AND, other, self)
|
||||
|
||||
def __ror__(self, other):
|
||||
return OperandHolder(OR, other, self)
|
||||
|
||||
def __invert__(self):
|
||||
return SingleOperandHolder(NOT, self)
|
||||
|
||||
|
||||
class SingleOperandHolder(OperationHolderMixin):
|
||||
def __init__(self, operator_class, op1_class):
|
||||
self.operator_class = operator_class
|
||||
self.op1_class = op1_class
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
op1 = self.op1_class(*args, **kwargs)
|
||||
return self.operator_class(op1)
|
||||
|
||||
|
||||
class OperandHolder(OperationHolderMixin):
|
||||
def __init__(self, operator_class, op1_class, op2_class):
|
||||
self.operator_class = operator_class
|
||||
self.op1_class = op1_class
|
||||
self.op2_class = op2_class
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
op1 = self.op1_class(*args, **kwargs)
|
||||
op2 = self.op2_class(*args, **kwargs)
|
||||
return self.operator_class(op1, op2)
|
||||
|
||||
|
||||
class AND:
|
||||
def __init__(self, op1, op2):
|
||||
self.op1 = op1
|
||||
self.op2 = op2
|
||||
|
||||
def has_permission(self, request, view):
|
||||
return (
|
||||
self.op1.has_permission(request, view) and
|
||||
self.op2.has_permission(request, view)
|
||||
)
|
||||
|
||||
def has_object_permission(self, request, view, obj):
|
||||
return (
|
||||
self.op1.has_object_permission(request, view, obj) and
|
||||
self.op2.has_object_permission(request, view, obj)
|
||||
)
|
||||
|
||||
|
||||
class OR:
|
||||
def __init__(self, op1, op2):
|
||||
self.op1 = op1
|
||||
self.op2 = op2
|
||||
|
||||
def has_permission(self, request, view):
|
||||
return (
|
||||
self.op1.has_permission(request, view) or
|
||||
self.op2.has_permission(request, view)
|
||||
)
|
||||
|
||||
def has_object_permission(self, request, view, obj):
|
||||
return (
|
||||
self.op1.has_permission(request, view)
|
||||
and self.op1.has_object_permission(request, view, obj)
|
||||
) or (
|
||||
self.op2.has_permission(request, view)
|
||||
and self.op2.has_object_permission(request, view, obj)
|
||||
)
|
||||
|
||||
|
||||
class NOT:
|
||||
def __init__(self, op1):
|
||||
self.op1 = op1
|
||||
|
||||
def has_permission(self, request, view):
|
||||
return not self.op1.has_permission(request, view)
|
||||
|
||||
def has_object_permission(self, request, view, obj):
|
||||
return not self.op1.has_object_permission(request, view, obj)
|
||||
|
||||
|
||||
class BasePermissionMetaclass(OperationHolderMixin, type):
|
||||
pass
|
||||
|
||||
|
||||
class BasePermission(metaclass=BasePermissionMetaclass):
|
||||
"""
|
||||
A base class from which all permission classes should inherit.
|
||||
"""
|
||||
|
||||
def has_permission(self, request, view):
|
||||
"""
|
||||
Return `True` if permission is granted, `False` otherwise.
|
||||
"""
|
||||
return True
|
||||
|
||||
def has_object_permission(self, request, view, obj):
|
||||
"""
|
||||
Return `True` if permission is granted, `False` otherwise.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
class AllowAny(BasePermission):
|
||||
"""
|
||||
Allow any access.
|
||||
This isn't strictly required, since you could use an empty
|
||||
permission_classes list, but it's useful because it makes the intention
|
||||
more explicit.
|
||||
"""
|
||||
|
||||
def has_permission(self, request, view):
|
||||
return True
|
||||
|
||||
|
||||
class IsAuthenticated(BasePermission):
|
||||
"""
|
||||
Allows access only to authenticated users.
|
||||
"""
|
||||
|
||||
def has_permission(self, request, view):
|
||||
return bool(request.user and request.user.is_authenticated)
|
||||
|
||||
|
||||
class IsAdminUser(BasePermission):
|
||||
"""
|
||||
Allows access only to admin users.
|
||||
"""
|
||||
|
||||
def has_permission(self, request, view):
|
||||
return bool(request.user and request.user.is_staff)
|
||||
|
||||
|
||||
class IsAuthenticatedOrReadOnly(BasePermission):
|
||||
"""
|
||||
The request is authenticated as a user, or is a read-only request.
|
||||
"""
|
||||
|
||||
def has_permission(self, request, view):
|
||||
return bool(
|
||||
request.method in SAFE_METHODS or
|
||||
request.user and
|
||||
request.user.is_authenticated
|
||||
)
|
||||
|
||||
|
||||
class DjangoModelPermissions(BasePermission):
|
||||
"""
|
||||
The request is authenticated using `django.contrib.auth` permissions.
|
||||
See: https://docs.djangoproject.com/en/dev/topics/auth/#permissions
|
||||
|
||||
It ensures that the user is authenticated, and has the appropriate
|
||||
`add`/`change`/`delete` permissions on the model.
|
||||
|
||||
This permission can only be applied against view classes that
|
||||
provide a `.queryset` attribute.
|
||||
"""
|
||||
|
||||
# Map methods into required permission codes.
|
||||
# Override this if you need to also provide 'view' permissions,
|
||||
# or if you want to provide custom permission codes.
|
||||
perms_map = {
|
||||
'GET': [],
|
||||
'OPTIONS': [],
|
||||
'HEAD': [],
|
||||
'POST': ['%(app_label)s.add_%(model_name)s'],
|
||||
'PUT': ['%(app_label)s.change_%(model_name)s'],
|
||||
'PATCH': ['%(app_label)s.change_%(model_name)s'],
|
||||
'DELETE': ['%(app_label)s.delete_%(model_name)s'],
|
||||
}
|
||||
|
||||
authenticated_users_only = True
|
||||
|
||||
def get_required_permissions(self, method, model_cls):
|
||||
"""
|
||||
Given a model and an HTTP method, return the list of permission
|
||||
codes that the user is required to have.
|
||||
"""
|
||||
kwargs = {
|
||||
'app_label': model_cls._meta.app_label,
|
||||
'model_name': model_cls._meta.model_name
|
||||
}
|
||||
|
||||
if method not in self.perms_map:
|
||||
raise exceptions.MethodNotAllowed(method)
|
||||
|
||||
return [perm % kwargs for perm in self.perms_map[method]]
|
||||
|
||||
def _queryset(self, view):
|
||||
assert hasattr(view, 'get_queryset') \
|
||||
or getattr(view, 'queryset', None) is not None, (
|
||||
'Cannot apply {} on a view that does not set '
|
||||
'`.queryset` or have a `.get_queryset()` method.'
|
||||
).format(self.__class__.__name__)
|
||||
|
||||
if hasattr(view, 'get_queryset'):
|
||||
queryset = view.get_queryset()
|
||||
assert queryset is not None, (
|
||||
'{}.get_queryset() returned None'.format(view.__class__.__name__)
|
||||
)
|
||||
return queryset
|
||||
return view.queryset
|
||||
|
||||
def has_permission(self, request, view):
|
||||
# Workaround to ensure DjangoModelPermissions are not applied
|
||||
# to the root view when using DefaultRouter.
|
||||
if getattr(view, '_ignore_model_permissions', False):
|
||||
return True
|
||||
|
||||
if not request.user or (
|
||||
not request.user.is_authenticated and self.authenticated_users_only):
|
||||
return False
|
||||
|
||||
queryset = self._queryset(view)
|
||||
perms = self.get_required_permissions(request.method, queryset.model)
|
||||
|
||||
return request.user.has_perms(perms)
|
||||
|
||||
|
||||
class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
|
||||
"""
|
||||
Similar to DjangoModelPermissions, except that anonymous users are
|
||||
allowed read-only access.
|
||||
"""
|
||||
authenticated_users_only = False
|
||||
|
||||
|
||||
class DjangoObjectPermissions(DjangoModelPermissions):
|
||||
"""
|
||||
The request is authenticated using Django's object-level permissions.
|
||||
It requires an object-permissions-enabled backend, such as Django Guardian.
|
||||
|
||||
It ensures that the user is authenticated, and has the appropriate
|
||||
`add`/`change`/`delete` permissions on the object using .has_perms.
|
||||
|
||||
This permission can only be applied against view classes that
|
||||
provide a `.queryset` attribute.
|
||||
"""
|
||||
perms_map = {
|
||||
'GET': [],
|
||||
'OPTIONS': [],
|
||||
'HEAD': [],
|
||||
'POST': ['%(app_label)s.add_%(model_name)s'],
|
||||
'PUT': ['%(app_label)s.change_%(model_name)s'],
|
||||
'PATCH': ['%(app_label)s.change_%(model_name)s'],
|
||||
'DELETE': ['%(app_label)s.delete_%(model_name)s'],
|
||||
}
|
||||
|
||||
def get_required_object_permissions(self, method, model_cls):
|
||||
kwargs = {
|
||||
'app_label': model_cls._meta.app_label,
|
||||
'model_name': model_cls._meta.model_name
|
||||
}
|
||||
|
||||
if method not in self.perms_map:
|
||||
raise exceptions.MethodNotAllowed(method)
|
||||
|
||||
return [perm % kwargs for perm in self.perms_map[method]]
|
||||
|
||||
def has_object_permission(self, request, view, obj):
|
||||
# authentication checks have already executed via has_permission
|
||||
queryset = self._queryset(view)
|
||||
model_cls = queryset.model
|
||||
user = request.user
|
||||
|
||||
perms = self.get_required_object_permissions(request.method, model_cls)
|
||||
|
||||
if not user.has_perms(perms, obj):
|
||||
# If the user does not have permissions we need to determine if
|
||||
# they have read permissions to see 403, or not, and simply see
|
||||
# a 404 response.
|
||||
|
||||
if request.method in SAFE_METHODS:
|
||||
# Read permissions already checked and failed, no need
|
||||
# to make another lookup.
|
||||
raise Http404
|
||||
|
||||
read_perms = self.get_required_object_permissions('GET', model_cls)
|
||||
if not user.has_perms(read_perms, obj):
|
||||
raise Http404
|
||||
|
||||
# Has read permissions.
|
||||
return False
|
||||
|
||||
return True
|
@ -0,0 +1,586 @@
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from urllib import parse
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
|
||||
from django.db.models import Manager
|
||||
from django.db.models.query import QuerySet
|
||||
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
|
||||
from django.utils.encoding import smart_str, uri_to_iri
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework.fields import (
|
||||
Field, SkipField, empty, get_attribute, is_simple_callable, iter_options
|
||||
)
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import html
|
||||
|
||||
|
||||
def method_overridden(method_name, klass, instance):
|
||||
"""
|
||||
Determine if a method has been overridden.
|
||||
"""
|
||||
method = getattr(klass, method_name)
|
||||
default_method = getattr(method, '__func__', method) # Python 3 compat
|
||||
return default_method is not getattr(instance, method_name).__func__
|
||||
|
||||
|
||||
class ObjectValueError(ValueError):
|
||||
"""
|
||||
Raised when `queryset.get()` failed due to an underlying `ValueError`.
|
||||
Wrapping prevents calling code conflating this with unrelated errors.
|
||||
"""
|
||||
|
||||
|
||||
class ObjectTypeError(TypeError):
|
||||
"""
|
||||
Raised when `queryset.get()` failed due to an underlying `TypeError`.
|
||||
Wrapping prevents calling code conflating this with unrelated errors.
|
||||
"""
|
||||
|
||||
|
||||
class Hyperlink(str):
|
||||
"""
|
||||
A string like object that additionally has an associated name.
|
||||
We use this for hyperlinked URLs that may render as a named link
|
||||
in some contexts, or render as a plain URL in others.
|
||||
"""
|
||||
def __new__(cls, url, obj):
|
||||
ret = super().__new__(cls, url)
|
||||
ret.obj = obj
|
||||
return ret
|
||||
|
||||
def __getnewargs__(self):
|
||||
return (str(self), self.name)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
# This ensures that we only called `__str__` lazily,
|
||||
# as in some cases calling __str__ on a model instances *might*
|
||||
# involve a database lookup.
|
||||
return str(self.obj)
|
||||
|
||||
is_hyperlink = True
|
||||
|
||||
|
||||
class PKOnlyObject:
|
||||
"""
|
||||
This is a mock object, used for when we only need the pk of the object
|
||||
instance, but still want to return an object with a .pk attribute,
|
||||
in order to keep the same interface as a regular model instance.
|
||||
"""
|
||||
def __init__(self, pk):
|
||||
self.pk = pk
|
||||
|
||||
def __str__(self):
|
||||
return "%s" % self.pk
|
||||
|
||||
|
||||
# We assume that 'validators' are intended for the child serializer,
|
||||
# rather than the parent serializer.
|
||||
MANY_RELATION_KWARGS = (
|
||||
'read_only', 'write_only', 'required', 'default', 'initial', 'source',
|
||||
'label', 'help_text', 'style', 'error_messages', 'allow_empty',
|
||||
'html_cutoff', 'html_cutoff_text'
|
||||
)
|
||||
|
||||
|
||||
class RelatedField(Field):
|
||||
queryset = None
|
||||
html_cutoff = None
|
||||
html_cutoff_text = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.queryset = kwargs.pop('queryset', self.queryset)
|
||||
|
||||
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
|
||||
if cutoff_from_settings is not None:
|
||||
cutoff_from_settings = int(cutoff_from_settings)
|
||||
self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings)
|
||||
|
||||
self.html_cutoff_text = kwargs.pop(
|
||||
'html_cutoff_text',
|
||||
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
|
||||
)
|
||||
if not method_overridden('get_queryset', RelatedField, self):
|
||||
assert self.queryset is not None or kwargs.get('read_only'), (
|
||||
'Relational field must provide a `queryset` argument, '
|
||||
'override `get_queryset`, or set read_only=`True`.'
|
||||
)
|
||||
assert not (self.queryset is not None and kwargs.get('read_only')), (
|
||||
'Relational fields should not provide a `queryset` argument, '
|
||||
'when setting read_only=`True`.'
|
||||
)
|
||||
kwargs.pop('many', None)
|
||||
kwargs.pop('allow_empty', None)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# We override this method in order to automagically create
|
||||
# `ManyRelatedField` classes instead when `many=True` is set.
|
||||
if kwargs.pop('many', False):
|
||||
return cls.many_init(*args, **kwargs)
|
||||
return super().__new__(cls, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def many_init(cls, *args, **kwargs):
|
||||
"""
|
||||
This method handles creating a parent `ManyRelatedField` instance
|
||||
when the `many=True` keyword argument is passed.
|
||||
|
||||
Typically you won't need to override this method.
|
||||
|
||||
Note that we're over-cautious in passing most arguments to both parent
|
||||
and child classes in order to try to cover the general case. If you're
|
||||
overriding this method you'll probably want something much simpler, eg:
|
||||
|
||||
@classmethod
|
||||
def many_init(cls, *args, **kwargs):
|
||||
kwargs['child'] = cls()
|
||||
return CustomManyRelatedField(*args, **kwargs)
|
||||
"""
|
||||
list_kwargs = {'child_relation': cls(*args, **kwargs)}
|
||||
for key in kwargs:
|
||||
if key in MANY_RELATION_KWARGS:
|
||||
list_kwargs[key] = kwargs[key]
|
||||
return ManyRelatedField(**list_kwargs)
|
||||
|
||||
def run_validation(self, data=empty):
|
||||
# We force empty strings to None values for relational fields.
|
||||
if data == '':
|
||||
data = None
|
||||
return super().run_validation(data)
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = self.queryset
|
||||
if isinstance(queryset, (QuerySet, Manager)):
|
||||
# Ensure queryset is re-evaluated whenever used.
|
||||
# Note that actually a `Manager` class may also be used as the
|
||||
# queryset argument. This occurs on ModelSerializer fields,
|
||||
# as it allows us to generate a more expressive 'repr' output
|
||||
# for the field.
|
||||
# Eg: 'MyRelationship(queryset=ExampleModel.objects.all())'
|
||||
queryset = queryset.all()
|
||||
return queryset
|
||||
|
||||
def use_pk_only_optimization(self):
|
||||
return False
|
||||
|
||||
def get_attribute(self, instance):
|
||||
if self.use_pk_only_optimization() and self.source_attrs:
|
||||
# Optimized case, return a mock object only containing the pk attribute.
|
||||
try:
|
||||
attribute_instance = get_attribute(instance, self.source_attrs[:-1])
|
||||
value = attribute_instance.serializable_value(self.source_attrs[-1])
|
||||
if is_simple_callable(value):
|
||||
# Handle edge case where the relationship `source` argument
|
||||
# points to a `get_relationship()` method on the model.
|
||||
value = value()
|
||||
|
||||
# Handle edge case where relationship `source` argument points
|
||||
# to an instance instead of a pk (e.g., a `@property`).
|
||||
value = getattr(value, 'pk', value)
|
||||
|
||||
return PKOnlyObject(pk=value)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Standard case, return the object instance.
|
||||
return super().get_attribute(instance)
|
||||
|
||||
def get_choices(self, cutoff=None):
|
||||
queryset = self.get_queryset()
|
||||
if queryset is None:
|
||||
# Ensure that field.choices returns something sensible
|
||||
# even when accessed with a read-only field.
|
||||
return {}
|
||||
|
||||
if cutoff is not None:
|
||||
queryset = queryset[:cutoff]
|
||||
|
||||
return OrderedDict([
|
||||
(
|
||||
self.to_representation(item),
|
||||
self.display_value(item)
|
||||
)
|
||||
for item in queryset
|
||||
])
|
||||
|
||||
@property
|
||||
def choices(self):
|
||||
return self.get_choices()
|
||||
|
||||
@property
|
||||
def grouped_choices(self):
|
||||
return self.choices
|
||||
|
||||
def iter_options(self):
|
||||
return iter_options(
|
||||
self.get_choices(cutoff=self.html_cutoff),
|
||||
cutoff=self.html_cutoff,
|
||||
cutoff_text=self.html_cutoff_text
|
||||
)
|
||||
|
||||
def display_value(self, instance):
|
||||
return str(instance)
|
||||
|
||||
|
||||
class StringRelatedField(RelatedField):
|
||||
"""
|
||||
A read only field that represents its targets using their
|
||||
plain string representation.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs['read_only'] = True
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def to_representation(self, value):
|
||||
return str(value)
|
||||
|
||||
|
||||
class PrimaryKeyRelatedField(RelatedField):
|
||||
default_error_messages = {
|
||||
'required': _('This field is required.'),
|
||||
'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'),
|
||||
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.pk_field = kwargs.pop('pk_field', None)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def use_pk_only_optimization(self):
|
||||
return True
|
||||
|
||||
def to_internal_value(self, data):
|
||||
if self.pk_field is not None:
|
||||
data = self.pk_field.to_internal_value(data)
|
||||
queryset = self.get_queryset()
|
||||
try:
|
||||
if isinstance(data, bool):
|
||||
raise TypeError
|
||||
return queryset.get(pk=data)
|
||||
except ObjectDoesNotExist:
|
||||
self.fail('does_not_exist', pk_value=data)
|
||||
except (TypeError, ValueError):
|
||||
self.fail('incorrect_type', data_type=type(data).__name__)
|
||||
|
||||
def to_representation(self, value):
|
||||
if self.pk_field is not None:
|
||||
return self.pk_field.to_representation(value.pk)
|
||||
return value.pk
|
||||
|
||||
|
||||
class HyperlinkedRelatedField(RelatedField):
|
||||
lookup_field = 'pk'
|
||||
view_name = None
|
||||
|
||||
default_error_messages = {
|
||||
'required': _('This field is required.'),
|
||||
'no_match': _('Invalid hyperlink - No URL match.'),
|
||||
'incorrect_match': _('Invalid hyperlink - Incorrect URL match.'),
|
||||
'does_not_exist': _('Invalid hyperlink - Object does not exist.'),
|
||||
'incorrect_type': _('Incorrect type. Expected URL string, received {data_type}.'),
|
||||
}
|
||||
|
||||
def __init__(self, view_name=None, **kwargs):
|
||||
if view_name is not None:
|
||||
self.view_name = view_name
|
||||
assert self.view_name is not None, 'The `view_name` argument is required.'
|
||||
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
|
||||
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
|
||||
self.format = kwargs.pop('format', None)
|
||||
|
||||
# We include this simply for dependency injection in tests.
|
||||
# We can't add it as a class attributes or it would expect an
|
||||
# implicit `self` argument to be passed.
|
||||
self.reverse = reverse
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def use_pk_only_optimization(self):
|
||||
return self.lookup_field == 'pk'
|
||||
|
||||
def get_object(self, view_name, view_args, view_kwargs):
|
||||
"""
|
||||
Return the object corresponding to a matched URL.
|
||||
|
||||
Takes the matched URL conf arguments, and should return an
|
||||
object instance, or raise an `ObjectDoesNotExist` exception.
|
||||
"""
|
||||
lookup_value = view_kwargs[self.lookup_url_kwarg]
|
||||
lookup_kwargs = {self.lookup_field: lookup_value}
|
||||
queryset = self.get_queryset()
|
||||
|
||||
try:
|
||||
return queryset.get(**lookup_kwargs)
|
||||
except ValueError:
|
||||
exc = ObjectValueError(str(sys.exc_info()[1]))
|
||||
raise exc.with_traceback(sys.exc_info()[2])
|
||||
except TypeError:
|
||||
exc = ObjectTypeError(str(sys.exc_info()[1]))
|
||||
raise exc.with_traceback(sys.exc_info()[2])
|
||||
|
||||
def get_url(self, obj, view_name, request, format):
|
||||
"""
|
||||
Given an object, return the URL that hyperlinks to the object.
|
||||
|
||||
May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
|
||||
attributes are not configured to correctly match the URL conf.
|
||||
"""
|
||||
# Unsaved objects will not yet have a valid URL.
|
||||
if hasattr(obj, 'pk') and obj.pk in (None, ''):
|
||||
return None
|
||||
|
||||
lookup_value = getattr(obj, self.lookup_field)
|
||||
kwargs = {self.lookup_url_kwarg: lookup_value}
|
||||
return self.reverse(view_name, kwargs=kwargs, request=request, format=format)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
request = self.context.get('request')
|
||||
try:
|
||||
http_prefix = data.startswith(('http:', 'https:'))
|
||||
except AttributeError:
|
||||
self.fail('incorrect_type', data_type=type(data).__name__)
|
||||
|
||||
if http_prefix:
|
||||
# If needed convert absolute URLs to relative path
|
||||
data = parse.urlparse(data).path
|
||||
prefix = get_script_prefix()
|
||||
if data.startswith(prefix):
|
||||
data = '/' + data[len(prefix):]
|
||||
|
||||
data = uri_to_iri(parse.unquote(data))
|
||||
|
||||
try:
|
||||
match = resolve(data)
|
||||
except Resolver404:
|
||||
self.fail('no_match')
|
||||
|
||||
try:
|
||||
expected_viewname = request.versioning_scheme.get_versioned_viewname(
|
||||
self.view_name, request
|
||||
)
|
||||
except AttributeError:
|
||||
expected_viewname = self.view_name
|
||||
|
||||
if match.view_name != expected_viewname:
|
||||
self.fail('incorrect_match')
|
||||
|
||||
try:
|
||||
return self.get_object(match.view_name, match.args, match.kwargs)
|
||||
except (ObjectDoesNotExist, ObjectValueError, ObjectTypeError):
|
||||
self.fail('does_not_exist')
|
||||
|
||||
def to_representation(self, value):
|
||||
assert 'request' in self.context, (
|
||||
"`%s` requires the request in the serializer"
|
||||
" context. Add `context={'request': request}` when instantiating "
|
||||
"the serializer." % self.__class__.__name__
|
||||
)
|
||||
|
||||
request = self.context['request']
|
||||
format = self.context.get('format')
|
||||
|
||||
# By default use whatever format is given for the current context
|
||||
# unless the target is a different type to the source.
|
||||
#
|
||||
# Eg. Consider a HyperlinkedIdentityField pointing from a json
|
||||
# representation to an html property of that representation...
|
||||
#
|
||||
# '/snippets/1/' should link to '/snippets/1/highlight/'
|
||||
# ...but...
|
||||
# '/snippets/1/.json' should link to '/snippets/1/highlight/.html'
|
||||
if format and self.format and self.format != format:
|
||||
format = self.format
|
||||
|
||||
# Return the hyperlink, or error if incorrectly configured.
|
||||
try:
|
||||
url = self.get_url(value, self.view_name, request, format)
|
||||
except NoReverseMatch:
|
||||
msg = (
|
||||
'Could not resolve URL for hyperlinked relationship using '
|
||||
'view name "%s". You may have failed to include the related '
|
||||
'model in your API, or incorrectly configured the '
|
||||
'`lookup_field` attribute on this field.'
|
||||
)
|
||||
if value in ('', None):
|
||||
value_string = {'': 'the empty string', None: 'None'}[value]
|
||||
msg += (
|
||||
" WARNING: The value of the field on the model instance "
|
||||
"was %s, which may be why it didn't match any "
|
||||
"entries in your URL conf." % value_string
|
||||
)
|
||||
raise ImproperlyConfigured(msg % self.view_name)
|
||||
|
||||
if url is None:
|
||||
return None
|
||||
|
||||
return Hyperlink(url, value)
|
||||
|
||||
|
||||
class HyperlinkedIdentityField(HyperlinkedRelatedField):
|
||||
"""
|
||||
A read-only field that represents the identity URL for an object, itself.
|
||||
|
||||
This is in contrast to `HyperlinkedRelatedField` which represents the
|
||||
URL of relationships to other objects.
|
||||
"""
|
||||
|
||||
def __init__(self, view_name=None, **kwargs):
|
||||
assert view_name is not None, 'The `view_name` argument is required.'
|
||||
kwargs['read_only'] = True
|
||||
kwargs['source'] = '*'
|
||||
super().__init__(view_name, **kwargs)
|
||||
|
||||
def use_pk_only_optimization(self):
|
||||
# We have the complete object instance already. We don't need
|
||||
# to run the 'only get the pk for this relationship' code.
|
||||
return False
|
||||
|
||||
|
||||
class SlugRelatedField(RelatedField):
|
||||
"""
|
||||
A read-write field that represents the target of the relationship
|
||||
by a unique 'slug' attribute.
|
||||
"""
|
||||
default_error_messages = {
|
||||
'does_not_exist': _('Object with {slug_name}={value} does not exist.'),
|
||||
'invalid': _('Invalid value.'),
|
||||
}
|
||||
|
||||
def __init__(self, slug_field=None, **kwargs):
|
||||
assert slug_field is not None, 'The `slug_field` argument is required.'
|
||||
self.slug_field = slug_field
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
queryset = self.get_queryset()
|
||||
try:
|
||||
return queryset.get(**{self.slug_field: data})
|
||||
except ObjectDoesNotExist:
|
||||
self.fail('does_not_exist', slug_name=self.slug_field, value=smart_str(data))
|
||||
except (TypeError, ValueError):
|
||||
self.fail('invalid')
|
||||
|
||||
def to_representation(self, obj):
|
||||
return getattr(obj, self.slug_field)
|
||||
|
||||
|
||||
class ManyRelatedField(Field):
|
||||
"""
|
||||
Relationships with `many=True` transparently get coerced into instead being
|
||||
a ManyRelatedField with a child relationship.
|
||||
|
||||
The `ManyRelatedField` class is responsible for handling iterating through
|
||||
the values and passing each one to the child relationship.
|
||||
|
||||
This class is treated as private API.
|
||||
You shouldn't generally need to be using this class directly yourself,
|
||||
and should instead simply set 'many=True' on the relationship.
|
||||
"""
|
||||
initial = []
|
||||
default_empty_html = []
|
||||
default_error_messages = {
|
||||
'not_a_list': _('Expected a list of items but got type "{input_type}".'),
|
||||
'empty': _('This list may not be empty.')
|
||||
}
|
||||
html_cutoff = None
|
||||
html_cutoff_text = None
|
||||
|
||||
def __init__(self, child_relation=None, *args, **kwargs):
|
||||
self.child_relation = child_relation
|
||||
self.allow_empty = kwargs.pop('allow_empty', True)
|
||||
|
||||
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
|
||||
if cutoff_from_settings is not None:
|
||||
cutoff_from_settings = int(cutoff_from_settings)
|
||||
self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings)
|
||||
|
||||
self.html_cutoff_text = kwargs.pop(
|
||||
'html_cutoff_text',
|
||||
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
|
||||
)
|
||||
assert child_relation is not None, '`child_relation` is a required argument.'
|
||||
super().__init__(*args, **kwargs)
|
||||
self.child_relation.bind(field_name='', parent=self)
|
||||
|
||||
def get_value(self, dictionary):
|
||||
# We override the default field access in order to support
|
||||
# lists in HTML forms.
|
||||
if html.is_html_input(dictionary):
|
||||
# Don't return [] if the update is partial
|
||||
if self.field_name not in dictionary:
|
||||
if getattr(self.root, 'partial', False):
|
||||
return empty
|
||||
return dictionary.getlist(self.field_name)
|
||||
|
||||
return dictionary.get(self.field_name, empty)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
if isinstance(data, str) or not hasattr(data, '__iter__'):
|
||||
self.fail('not_a_list', input_type=type(data).__name__)
|
||||
if not self.allow_empty and len(data) == 0:
|
||||
self.fail('empty')
|
||||
|
||||
return [
|
||||
self.child_relation.to_internal_value(item)
|
||||
for item in data
|
||||
]
|
||||
|
||||
def get_attribute(self, instance):
|
||||
# Can't have any relationships if not created
|
||||
if hasattr(instance, 'pk') and instance.pk is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
relationship = get_attribute(instance, self.source_attrs)
|
||||
except (KeyError, AttributeError) as exc:
|
||||
if self.default is not empty:
|
||||
return self.get_default()
|
||||
if self.allow_null:
|
||||
return None
|
||||
if not self.required:
|
||||
raise SkipField()
|
||||
msg = (
|
||||
'Got {exc_type} when attempting to get a value for field '
|
||||
'`{field}` on serializer `{serializer}`.\nThe serializer '
|
||||
'field might be named incorrectly and not match '
|
||||
'any attribute or key on the `{instance}` instance.\n'
|
||||
'Original exception text was: {exc}.'.format(
|
||||
exc_type=type(exc).__name__,
|
||||
field=self.field_name,
|
||||
serializer=self.parent.__class__.__name__,
|
||||
instance=instance.__class__.__name__,
|
||||
exc=exc
|
||||
)
|
||||
)
|
||||
raise type(exc)(msg)
|
||||
|
||||
return relationship.all() if hasattr(relationship, 'all') else relationship
|
||||
|
||||
def to_representation(self, iterable):
|
||||
return [
|
||||
self.child_relation.to_representation(value)
|
||||
for value in iterable
|
||||
]
|
||||
|
||||
def get_choices(self, cutoff=None):
|
||||
return self.child_relation.get_choices(cutoff)
|
||||
|
||||
@property
|
||||
def choices(self):
|
||||
return self.get_choices()
|
||||
|
||||
@property
|
||||
def grouped_choices(self):
|
||||
return self.choices
|
||||
|
||||
def iter_options(self):
|
||||
return iter_options(
|
||||
self.get_choices(cutoff=self.html_cutoff),
|
||||
cutoff=self.html_cutoff,
|
||||
cutoff_text=self.html_cutoff_text
|
||||
)
|
1075
srcs/.venv/lib/python3.11/site-packages/rest_framework/renderers.py
Normal file
1075
srcs/.venv/lib/python3.11/site-packages/rest_framework/renderers.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,455 @@
|
||||
"""
|
||||
The Request class is used as a wrapper around the standard request object.
|
||||
|
||||
The wrapped request then offers a richer API, in particular :
|
||||
|
||||
- content automatically parsed according to `Content-Type` header,
|
||||
and available as `request.data`
|
||||
- full support of PUT method, including support for file uploads
|
||||
- form overloading of HTTP method, content type and content
|
||||
"""
|
||||
import io
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.conf import settings
|
||||
from django.http import HttpRequest, QueryDict
|
||||
from django.http.request import RawPostDataException
|
||||
from django.utils.datastructures import MultiValueDict
|
||||
|
||||
from rest_framework import exceptions
|
||||
from rest_framework.compat import parse_header_parameters
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
def is_form_media_type(media_type):
|
||||
"""
|
||||
Return True if the media type is a valid form media type.
|
||||
"""
|
||||
base_media_type, params = parse_header_parameters(media_type)
|
||||
return (base_media_type == 'application/x-www-form-urlencoded' or
|
||||
base_media_type == 'multipart/form-data')
|
||||
|
||||
|
||||
class override_method:
|
||||
"""
|
||||
A context manager that temporarily overrides the method on a request,
|
||||
additionally setting the `view.request` attribute.
|
||||
|
||||
Usage:
|
||||
|
||||
with override_method(view, request, 'POST') as request:
|
||||
... # Do stuff with `view` and `request`
|
||||
"""
|
||||
|
||||
def __init__(self, view, request, method):
|
||||
self.view = view
|
||||
self.request = request
|
||||
self.method = method
|
||||
self.action = getattr(view, 'action', None)
|
||||
|
||||
def __enter__(self):
|
||||
self.view.request = clone_request(self.request, self.method)
|
||||
# For viewsets we also set the `.action` attribute.
|
||||
action_map = getattr(self.view, 'action_map', {})
|
||||
self.view.action = action_map.get(self.method.lower())
|
||||
return self.view.request
|
||||
|
||||
def __exit__(self, *args, **kwarg):
|
||||
self.view.request = self.request
|
||||
self.view.action = self.action
|
||||
|
||||
|
||||
class WrappedAttributeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def wrap_attributeerrors():
|
||||
"""
|
||||
Used to re-raise AttributeErrors caught during authentication, preventing
|
||||
these errors from otherwise being handled by the attribute access protocol.
|
||||
"""
|
||||
try:
|
||||
yield
|
||||
except AttributeError:
|
||||
info = sys.exc_info()
|
||||
exc = WrappedAttributeError(str(info[1]))
|
||||
raise exc.with_traceback(info[2])
|
||||
|
||||
|
||||
class Empty:
|
||||
"""
|
||||
Placeholder for unset attributes.
|
||||
Cannot use `None`, as that may be a valid value.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def _hasattr(obj, name):
|
||||
return not getattr(obj, name) is Empty
|
||||
|
||||
|
||||
def clone_request(request, method):
|
||||
"""
|
||||
Internal helper method to clone a request, replacing with a different
|
||||
HTTP method. Used for checking permissions against other methods.
|
||||
"""
|
||||
ret = Request(request=request._request,
|
||||
parsers=request.parsers,
|
||||
authenticators=request.authenticators,
|
||||
negotiator=request.negotiator,
|
||||
parser_context=request.parser_context)
|
||||
ret._data = request._data
|
||||
ret._files = request._files
|
||||
ret._full_data = request._full_data
|
||||
ret._content_type = request._content_type
|
||||
ret._stream = request._stream
|
||||
ret.method = method
|
||||
if hasattr(request, '_user'):
|
||||
ret._user = request._user
|
||||
if hasattr(request, '_auth'):
|
||||
ret._auth = request._auth
|
||||
if hasattr(request, '_authenticator'):
|
||||
ret._authenticator = request._authenticator
|
||||
if hasattr(request, 'accepted_renderer'):
|
||||
ret.accepted_renderer = request.accepted_renderer
|
||||
if hasattr(request, 'accepted_media_type'):
|
||||
ret.accepted_media_type = request.accepted_media_type
|
||||
if hasattr(request, 'version'):
|
||||
ret.version = request.version
|
||||
if hasattr(request, 'versioning_scheme'):
|
||||
ret.versioning_scheme = request.versioning_scheme
|
||||
return ret
|
||||
|
||||
|
||||
class ForcedAuthentication:
|
||||
"""
|
||||
This authentication class is used if the test client or request factory
|
||||
forcibly authenticated the request.
|
||||
"""
|
||||
|
||||
def __init__(self, force_user, force_token):
|
||||
self.force_user = force_user
|
||||
self.force_token = force_token
|
||||
|
||||
def authenticate(self, request):
|
||||
return (self.force_user, self.force_token)
|
||||
|
||||
|
||||
class Request:
|
||||
"""
|
||||
Wrapper allowing to enhance a standard `HttpRequest` instance.
|
||||
|
||||
Kwargs:
|
||||
- request(HttpRequest). The original request instance.
|
||||
- parsers(list/tuple). The parsers to use for parsing the
|
||||
request content.
|
||||
- authenticators(list/tuple). The authenticators used to try
|
||||
authenticating the request's user.
|
||||
"""
|
||||
|
||||
def __init__(self, request, parsers=None, authenticators=None,
|
||||
negotiator=None, parser_context=None):
|
||||
assert isinstance(request, HttpRequest), (
|
||||
'The `request` argument must be an instance of '
|
||||
'`django.http.HttpRequest`, not `{}.{}`.'
|
||||
.format(request.__class__.__module__, request.__class__.__name__)
|
||||
)
|
||||
|
||||
self._request = request
|
||||
self.parsers = parsers or ()
|
||||
self.authenticators = authenticators or ()
|
||||
self.negotiator = negotiator or self._default_negotiator()
|
||||
self.parser_context = parser_context
|
||||
self._data = Empty
|
||||
self._files = Empty
|
||||
self._full_data = Empty
|
||||
self._content_type = Empty
|
||||
self._stream = Empty
|
||||
|
||||
if self.parser_context is None:
|
||||
self.parser_context = {}
|
||||
self.parser_context['request'] = self
|
||||
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
|
||||
|
||||
force_user = getattr(request, '_force_auth_user', None)
|
||||
force_token = getattr(request, '_force_auth_token', None)
|
||||
if force_user is not None or force_token is not None:
|
||||
forced_auth = ForcedAuthentication(force_user, force_token)
|
||||
self.authenticators = (forced_auth,)
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s.%s: %s %r>' % (
|
||||
self.__class__.__module__,
|
||||
self.__class__.__name__,
|
||||
self.method,
|
||||
self.get_full_path())
|
||||
|
||||
def _default_negotiator(self):
|
||||
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
|
||||
|
||||
@property
|
||||
def content_type(self):
|
||||
meta = self._request.META
|
||||
return meta.get('CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', ''))
|
||||
|
||||
@property
|
||||
def stream(self):
|
||||
"""
|
||||
Returns an object that may be used to stream the request content.
|
||||
"""
|
||||
if not _hasattr(self, '_stream'):
|
||||
self._load_stream()
|
||||
return self._stream
|
||||
|
||||
@property
|
||||
def query_params(self):
|
||||
"""
|
||||
More semantically correct name for request.GET.
|
||||
"""
|
||||
return self._request.GET
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
if not _hasattr(self, '_full_data'):
|
||||
self._load_data_and_files()
|
||||
return self._full_data
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
"""
|
||||
Returns the user associated with the current request, as authenticated
|
||||
by the authentication classes provided to the request.
|
||||
"""
|
||||
if not hasattr(self, '_user'):
|
||||
with wrap_attributeerrors():
|
||||
self._authenticate()
|
||||
return self._user
|
||||
|
||||
@user.setter
|
||||
def user(self, value):
|
||||
"""
|
||||
Sets the user on the current request. This is necessary to maintain
|
||||
compatibility with django.contrib.auth where the user property is
|
||||
set in the login and logout functions.
|
||||
|
||||
Note that we also set the user on Django's underlying `HttpRequest`
|
||||
instance, ensuring that it is available to any middleware in the stack.
|
||||
"""
|
||||
self._user = value
|
||||
self._request.user = value
|
||||
|
||||
@property
|
||||
def auth(self):
|
||||
"""
|
||||
Returns any non-user authentication information associated with the
|
||||
request, such as an authentication token.
|
||||
"""
|
||||
if not hasattr(self, '_auth'):
|
||||
with wrap_attributeerrors():
|
||||
self._authenticate()
|
||||
return self._auth
|
||||
|
||||
@auth.setter
|
||||
def auth(self, value):
|
||||
"""
|
||||
Sets any non-user authentication information associated with the
|
||||
request, such as an authentication token.
|
||||
"""
|
||||
self._auth = value
|
||||
self._request.auth = value
|
||||
|
||||
@property
|
||||
def successful_authenticator(self):
|
||||
"""
|
||||
Return the instance of the authentication instance class that was used
|
||||
to authenticate the request, or `None`.
|
||||
"""
|
||||
if not hasattr(self, '_authenticator'):
|
||||
with wrap_attributeerrors():
|
||||
self._authenticate()
|
||||
return self._authenticator
|
||||
|
||||
def _load_data_and_files(self):
|
||||
"""
|
||||
Parses the request content into `self.data`.
|
||||
"""
|
||||
if not _hasattr(self, '_data'):
|
||||
self._data, self._files = self._parse()
|
||||
if self._files:
|
||||
self._full_data = self._data.copy()
|
||||
self._full_data.update(self._files)
|
||||
else:
|
||||
self._full_data = self._data
|
||||
|
||||
# if a form media type, copy data & files refs to the underlying
|
||||
# http request so that closable objects are handled appropriately.
|
||||
if is_form_media_type(self.content_type):
|
||||
self._request._post = self.POST
|
||||
self._request._files = self.FILES
|
||||
|
||||
def _load_stream(self):
|
||||
"""
|
||||
Return the content body of the request, as a stream.
|
||||
"""
|
||||
meta = self._request.META
|
||||
try:
|
||||
content_length = int(
|
||||
meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0))
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
content_length = 0
|
||||
|
||||
if content_length == 0:
|
||||
self._stream = None
|
||||
elif not self._request._read_started:
|
||||
self._stream = self._request
|
||||
else:
|
||||
self._stream = io.BytesIO(self.body)
|
||||
|
||||
def _supports_form_parsing(self):
|
||||
"""
|
||||
Return True if this requests supports parsing form data.
|
||||
"""
|
||||
form_media = (
|
||||
'application/x-www-form-urlencoded',
|
||||
'multipart/form-data'
|
||||
)
|
||||
return any(parser.media_type in form_media for parser in self.parsers)
|
||||
|
||||
def _parse(self):
|
||||
"""
|
||||
Parse the request content, returning a two-tuple of (data, files)
|
||||
|
||||
May raise an `UnsupportedMediaType`, or `ParseError` exception.
|
||||
"""
|
||||
media_type = self.content_type
|
||||
try:
|
||||
stream = self.stream
|
||||
except RawPostDataException:
|
||||
if not hasattr(self._request, '_post'):
|
||||
raise
|
||||
# If request.POST has been accessed in middleware, and a method='POST'
|
||||
# request was made with 'multipart/form-data', then the request stream
|
||||
# will already have been exhausted.
|
||||
if self._supports_form_parsing():
|
||||
return (self._request.POST, self._request.FILES)
|
||||
stream = None
|
||||
|
||||
if stream is None or media_type is None:
|
||||
if media_type and is_form_media_type(media_type):
|
||||
empty_data = QueryDict('', encoding=self._request._encoding)
|
||||
else:
|
||||
empty_data = {}
|
||||
empty_files = MultiValueDict()
|
||||
return (empty_data, empty_files)
|
||||
|
||||
parser = self.negotiator.select_parser(self, self.parsers)
|
||||
|
||||
if not parser:
|
||||
raise exceptions.UnsupportedMediaType(media_type)
|
||||
|
||||
try:
|
||||
parsed = parser.parse(stream, media_type, self.parser_context)
|
||||
except Exception:
|
||||
# If we get an exception during parsing, fill in empty data and
|
||||
# re-raise. Ensures we don't simply repeat the error when
|
||||
# attempting to render the browsable renderer response, or when
|
||||
# logging the request or similar.
|
||||
self._data = QueryDict('', encoding=self._request._encoding)
|
||||
self._files = MultiValueDict()
|
||||
self._full_data = self._data
|
||||
raise
|
||||
|
||||
# Parser classes may return the raw data, or a
|
||||
# DataAndFiles object. Unpack the result as required.
|
||||
try:
|
||||
return (parsed.data, parsed.files)
|
||||
except AttributeError:
|
||||
empty_files = MultiValueDict()
|
||||
return (parsed, empty_files)
|
||||
|
||||
def _authenticate(self):
|
||||
"""
|
||||
Attempt to authenticate the request using each authentication instance
|
||||
in turn.
|
||||
"""
|
||||
for authenticator in self.authenticators:
|
||||
try:
|
||||
user_auth_tuple = authenticator.authenticate(self)
|
||||
except exceptions.APIException:
|
||||
self._not_authenticated()
|
||||
raise
|
||||
|
||||
if user_auth_tuple is not None:
|
||||
self._authenticator = authenticator
|
||||
self.user, self.auth = user_auth_tuple
|
||||
return
|
||||
|
||||
self._not_authenticated()
|
||||
|
||||
def _not_authenticated(self):
|
||||
"""
|
||||
Set authenticator, user & authtoken representing an unauthenticated request.
|
||||
|
||||
Defaults are None, AnonymousUser & None.
|
||||
"""
|
||||
self._authenticator = None
|
||||
|
||||
if api_settings.UNAUTHENTICATED_USER:
|
||||
self.user = api_settings.UNAUTHENTICATED_USER()
|
||||
else:
|
||||
self.user = None
|
||||
|
||||
if api_settings.UNAUTHENTICATED_TOKEN:
|
||||
self.auth = api_settings.UNAUTHENTICATED_TOKEN()
|
||||
else:
|
||||
self.auth = None
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""
|
||||
If an attribute does not exist on this instance, then we also attempt
|
||||
to proxy it to the underlying HttpRequest object.
|
||||
"""
|
||||
try:
|
||||
return getattr(self._request, attr)
|
||||
except AttributeError:
|
||||
return self.__getattribute__(attr)
|
||||
|
||||
@property
|
||||
def DATA(self):
|
||||
raise NotImplementedError(
|
||||
'`request.DATA` has been deprecated in favor of `request.data` '
|
||||
'since version 3.0, and has been fully removed as of version 3.2.'
|
||||
)
|
||||
|
||||
@property
|
||||
def POST(self):
|
||||
# Ensure that request.POST uses our request parsing.
|
||||
if not _hasattr(self, '_data'):
|
||||
self._load_data_and_files()
|
||||
if is_form_media_type(self.content_type):
|
||||
return self._data
|
||||
return QueryDict('', encoding=self._request._encoding)
|
||||
|
||||
@property
|
||||
def FILES(self):
|
||||
# Leave this one alone for backwards compat with Django's request.FILES
|
||||
# Different from the other two cases, which are not valid property
|
||||
# names on the WSGIRequest class.
|
||||
if not _hasattr(self, '_files'):
|
||||
self._load_data_and_files()
|
||||
return self._files
|
||||
|
||||
@property
|
||||
def QUERY_PARAMS(self):
|
||||
raise NotImplementedError(
|
||||
'`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` '
|
||||
'since version 3.0, and has been fully removed as of version 3.2.'
|
||||
)
|
||||
|
||||
def force_plaintext_errors(self, value):
|
||||
# Hack to allow our exception handler to force choice of
|
||||
# plaintext or html error responses.
|
||||
self._request.is_ajax = lambda: value
|
@ -0,0 +1,103 @@
|
||||
"""
|
||||
The Response class in REST framework is similar to HTTPResponse, except that
|
||||
it is initialized with unrendered data, instead of a pre-rendered string.
|
||||
|
||||
The appropriate renderer is called during Django's template response rendering.
|
||||
"""
|
||||
from http.client import responses
|
||||
|
||||
from django.template.response import SimpleTemplateResponse
|
||||
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
|
||||
class Response(SimpleTemplateResponse):
|
||||
"""
|
||||
An HttpResponse that allows its data to be rendered into
|
||||
arbitrary media types.
|
||||
"""
|
||||
|
||||
def __init__(self, data=None, status=None,
|
||||
template_name=None, headers=None,
|
||||
exception=False, content_type=None):
|
||||
"""
|
||||
Alters the init arguments slightly.
|
||||
For example, drop 'template_name', and instead use 'data'.
|
||||
|
||||
Setting 'renderer' and 'media_type' will typically be deferred,
|
||||
For example being set automatically by the `APIView`.
|
||||
"""
|
||||
super().__init__(None, status=status)
|
||||
|
||||
if isinstance(data, Serializer):
|
||||
msg = (
|
||||
'You passed a Serializer instance as data, but '
|
||||
'probably meant to pass serialized `.data` or '
|
||||
'`.error`. representation.'
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
|
||||
self.data = data
|
||||
self.template_name = template_name
|
||||
self.exception = exception
|
||||
self.content_type = content_type
|
||||
|
||||
if headers:
|
||||
for name, value in headers.items():
|
||||
self[name] = value
|
||||
|
||||
@property
|
||||
def rendered_content(self):
|
||||
renderer = getattr(self, 'accepted_renderer', None)
|
||||
accepted_media_type = getattr(self, 'accepted_media_type', None)
|
||||
context = getattr(self, 'renderer_context', None)
|
||||
|
||||
assert renderer, ".accepted_renderer not set on Response"
|
||||
assert accepted_media_type, ".accepted_media_type not set on Response"
|
||||
assert context is not None, ".renderer_context not set on Response"
|
||||
context['response'] = self
|
||||
|
||||
media_type = renderer.media_type
|
||||
charset = renderer.charset
|
||||
content_type = self.content_type
|
||||
|
||||
if content_type is None and charset is not None:
|
||||
content_type = "{}; charset={}".format(media_type, charset)
|
||||
elif content_type is None:
|
||||
content_type = media_type
|
||||
self['Content-Type'] = content_type
|
||||
|
||||
ret = renderer.render(self.data, accepted_media_type, context)
|
||||
if isinstance(ret, str):
|
||||
assert charset, (
|
||||
'renderer returned unicode, and did not specify '
|
||||
'a charset value.'
|
||||
)
|
||||
return ret.encode(charset)
|
||||
|
||||
if not ret:
|
||||
del self['Content-Type']
|
||||
|
||||
return ret
|
||||
|
||||
@property
|
||||
def status_text(self):
|
||||
"""
|
||||
Returns reason text corresponding to our HTTP response status code.
|
||||
Provided for convenience.
|
||||
"""
|
||||
return responses.get(self.status_code, '')
|
||||
|
||||
def __getstate__(self):
|
||||
"""
|
||||
Remove attributes from the response that shouldn't be cached.
|
||||
"""
|
||||
state = super().__getstate__()
|
||||
for key in (
|
||||
'accepted_renderer', 'renderer_context', 'resolver_match',
|
||||
'client', 'request', 'json', 'wsgi_request'
|
||||
):
|
||||
if key in state:
|
||||
del state[key]
|
||||
state['_closable_objects'] = []
|
||||
return state
|
@ -0,0 +1,66 @@
|
||||
"""
|
||||
Provide urlresolver functions that return fully qualified URLs or view names
|
||||
"""
|
||||
from django.urls import NoReverseMatch
|
||||
from django.urls import reverse as django_reverse
|
||||
from django.utils.functional import lazy
|
||||
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils.urls import replace_query_param
|
||||
|
||||
|
||||
def preserve_builtin_query_params(url, request=None):
|
||||
"""
|
||||
Given an incoming request, and an outgoing URL representation,
|
||||
append the value of any built-in query parameters.
|
||||
"""
|
||||
if request is None:
|
||||
return url
|
||||
|
||||
overrides = [
|
||||
api_settings.URL_FORMAT_OVERRIDE,
|
||||
]
|
||||
|
||||
for param in overrides:
|
||||
if param and (param in request.GET):
|
||||
value = request.GET[param]
|
||||
url = replace_query_param(url, param, value)
|
||||
|
||||
return url
|
||||
|
||||
|
||||
def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
|
||||
"""
|
||||
If versioning is being used then we pass any `reverse` calls through
|
||||
to the versioning scheme instance, so that the resulting URL
|
||||
can be modified if needed.
|
||||
"""
|
||||
scheme = getattr(request, 'versioning_scheme', None)
|
||||
if scheme is not None:
|
||||
try:
|
||||
url = scheme.reverse(viewname, args, kwargs, request, format, **extra)
|
||||
except NoReverseMatch:
|
||||
# In case the versioning scheme reversal fails, fallback to the
|
||||
# default implementation
|
||||
url = _reverse(viewname, args, kwargs, request, format, **extra)
|
||||
else:
|
||||
url = _reverse(viewname, args, kwargs, request, format, **extra)
|
||||
|
||||
return preserve_builtin_query_params(url, request)
|
||||
|
||||
|
||||
def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
|
||||
"""
|
||||
Same as `django.urls.reverse`, but optionally takes a request
|
||||
and returns a fully qualified URL, using the request to get the base URL.
|
||||
"""
|
||||
if format is not None:
|
||||
kwargs = kwargs or {}
|
||||
kwargs['format'] = format
|
||||
url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
|
||||
if request:
|
||||
return request.build_absolute_uri(url)
|
||||
return url
|
||||
|
||||
|
||||
reverse_lazy = lazy(reverse, str)
|
@ -0,0 +1,348 @@
|
||||
"""
|
||||
Routers provide a convenient and consistent way of automatically
|
||||
determining the URL conf for your API.
|
||||
|
||||
They are used by simply instantiating a Router class, and then registering
|
||||
all the required ViewSets with that router.
|
||||
|
||||
For example, you might have a `urls.py` that looks something like this:
|
||||
|
||||
router = routers.DefaultRouter()
|
||||
router.register('users', UserViewSet, 'user')
|
||||
router.register('accounts', AccountViewSet, 'account')
|
||||
|
||||
urlpatterns = router.urls
|
||||
"""
|
||||
import itertools
|
||||
from collections import OrderedDict, namedtuple
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.urls import NoReverseMatch, re_path
|
||||
|
||||
from rest_framework import views
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.schemas import SchemaGenerator
|
||||
from rest_framework.schemas.views import SchemaView
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
Route = namedtuple('Route', ['url', 'mapping', 'name', 'detail', 'initkwargs'])
|
||||
DynamicRoute = namedtuple('DynamicRoute', ['url', 'name', 'detail', 'initkwargs'])
|
||||
|
||||
|
||||
def escape_curly_brackets(url_path):
|
||||
"""
|
||||
Double brackets in regex of url_path for escape string formatting
|
||||
"""
|
||||
return url_path.replace('{', '{{').replace('}', '}}')
|
||||
|
||||
|
||||
def flatten(list_of_lists):
|
||||
"""
|
||||
Takes an iterable of iterables, returns a single iterable containing all items
|
||||
"""
|
||||
return itertools.chain(*list_of_lists)
|
||||
|
||||
|
||||
class BaseRouter:
|
||||
def __init__(self):
|
||||
self.registry = []
|
||||
|
||||
def register(self, prefix, viewset, basename=None):
|
||||
if basename is None:
|
||||
basename = self.get_default_basename(viewset)
|
||||
self.registry.append((prefix, viewset, basename))
|
||||
|
||||
# invalidate the urls cache
|
||||
if hasattr(self, '_urls'):
|
||||
del self._urls
|
||||
|
||||
def get_default_basename(self, viewset):
|
||||
"""
|
||||
If `basename` is not specified, attempt to automatically determine
|
||||
it from the viewset.
|
||||
"""
|
||||
raise NotImplementedError('get_default_basename must be overridden')
|
||||
|
||||
def get_urls(self):
|
||||
"""
|
||||
Return a list of URL patterns, given the registered viewsets.
|
||||
"""
|
||||
raise NotImplementedError('get_urls must be overridden')
|
||||
|
||||
@property
|
||||
def urls(self):
|
||||
if not hasattr(self, '_urls'):
|
||||
self._urls = self.get_urls()
|
||||
return self._urls
|
||||
|
||||
|
||||
class SimpleRouter(BaseRouter):
|
||||
|
||||
routes = [
|
||||
# List route.
|
||||
Route(
|
||||
url=r'^{prefix}{trailing_slash}$',
|
||||
mapping={
|
||||
'get': 'list',
|
||||
'post': 'create'
|
||||
},
|
||||
name='{basename}-list',
|
||||
detail=False,
|
||||
initkwargs={'suffix': 'List'}
|
||||
),
|
||||
# Dynamically generated list routes. Generated using
|
||||
# @action(detail=False) decorator on methods of the viewset.
|
||||
DynamicRoute(
|
||||
url=r'^{prefix}/{url_path}{trailing_slash}$',
|
||||
name='{basename}-{url_name}',
|
||||
detail=False,
|
||||
initkwargs={}
|
||||
),
|
||||
# Detail route.
|
||||
Route(
|
||||
url=r'^{prefix}/{lookup}{trailing_slash}$',
|
||||
mapping={
|
||||
'get': 'retrieve',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy'
|
||||
},
|
||||
name='{basename}-detail',
|
||||
detail=True,
|
||||
initkwargs={'suffix': 'Instance'}
|
||||
),
|
||||
# Dynamically generated detail routes. Generated using
|
||||
# @action(detail=True) decorator on methods of the viewset.
|
||||
DynamicRoute(
|
||||
url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$',
|
||||
name='{basename}-{url_name}',
|
||||
detail=True,
|
||||
initkwargs={}
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(self, trailing_slash=True):
|
||||
self.trailing_slash = '/' if trailing_slash else ''
|
||||
super().__init__()
|
||||
|
||||
def get_default_basename(self, viewset):
|
||||
"""
|
||||
If `basename` is not specified, attempt to automatically determine
|
||||
it from the viewset.
|
||||
"""
|
||||
queryset = getattr(viewset, 'queryset', None)
|
||||
|
||||
assert queryset is not None, '`basename` argument not specified, and could ' \
|
||||
'not automatically determine the name from the viewset, as ' \
|
||||
'it does not have a `.queryset` attribute.'
|
||||
|
||||
return queryset.model._meta.object_name.lower()
|
||||
|
||||
def get_routes(self, viewset):
|
||||
"""
|
||||
Augment `self.routes` with any dynamically generated routes.
|
||||
|
||||
Returns a list of the Route namedtuple.
|
||||
"""
|
||||
# converting to list as iterables are good for one pass, known host needs to be checked again and again for
|
||||
# different functions.
|
||||
known_actions = list(flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)]))
|
||||
extra_actions = viewset.get_extra_actions()
|
||||
|
||||
# checking action names against the known actions list
|
||||
not_allowed = [
|
||||
action.__name__ for action in extra_actions
|
||||
if action.__name__ in known_actions
|
||||
]
|
||||
if not_allowed:
|
||||
msg = ('Cannot use the @action decorator on the following '
|
||||
'methods, as they are existing routes: %s')
|
||||
raise ImproperlyConfigured(msg % ', '.join(not_allowed))
|
||||
|
||||
# partition detail and list actions
|
||||
detail_actions = [action for action in extra_actions if action.detail]
|
||||
list_actions = [action for action in extra_actions if not action.detail]
|
||||
|
||||
routes = []
|
||||
for route in self.routes:
|
||||
if isinstance(route, DynamicRoute) and route.detail:
|
||||
routes += [self._get_dynamic_route(route, action) for action in detail_actions]
|
||||
elif isinstance(route, DynamicRoute) and not route.detail:
|
||||
routes += [self._get_dynamic_route(route, action) for action in list_actions]
|
||||
else:
|
||||
routes.append(route)
|
||||
|
||||
return routes
|
||||
|
||||
def _get_dynamic_route(self, route, action):
|
||||
initkwargs = route.initkwargs.copy()
|
||||
initkwargs.update(action.kwargs)
|
||||
|
||||
url_path = escape_curly_brackets(action.url_path)
|
||||
|
||||
return Route(
|
||||
url=route.url.replace('{url_path}', url_path),
|
||||
mapping=action.mapping,
|
||||
name=route.name.replace('{url_name}', action.url_name),
|
||||
detail=route.detail,
|
||||
initkwargs=initkwargs,
|
||||
)
|
||||
|
||||
def get_method_map(self, viewset, method_map):
|
||||
"""
|
||||
Given a viewset, and a mapping of http methods to actions,
|
||||
return a new mapping which only includes any mappings that
|
||||
are actually implemented by the viewset.
|
||||
"""
|
||||
bound_methods = {}
|
||||
for method, action in method_map.items():
|
||||
if hasattr(viewset, action):
|
||||
bound_methods[method] = action
|
||||
return bound_methods
|
||||
|
||||
def get_lookup_regex(self, viewset, lookup_prefix=''):
|
||||
"""
|
||||
Given a viewset, return the portion of URL regex that is used
|
||||
to match against a single instance.
|
||||
|
||||
Note that lookup_prefix is not used directly inside REST rest_framework
|
||||
itself, but is required in order to nicely support nested router
|
||||
implementations, such as drf-nested-routers.
|
||||
|
||||
https://github.com/alanjds/drf-nested-routers
|
||||
"""
|
||||
base_regex = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})'
|
||||
# Use `pk` as default field, unset set. Default regex should not
|
||||
# consume `.json` style suffixes and should break at '/' boundaries.
|
||||
lookup_field = getattr(viewset, 'lookup_field', 'pk')
|
||||
lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
|
||||
lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+')
|
||||
return base_regex.format(
|
||||
lookup_prefix=lookup_prefix,
|
||||
lookup_url_kwarg=lookup_url_kwarg,
|
||||
lookup_value=lookup_value
|
||||
)
|
||||
|
||||
def get_urls(self):
|
||||
"""
|
||||
Use the registered viewsets to generate a list of URL patterns.
|
||||
"""
|
||||
ret = []
|
||||
|
||||
for prefix, viewset, basename in self.registry:
|
||||
lookup = self.get_lookup_regex(viewset)
|
||||
routes = self.get_routes(viewset)
|
||||
|
||||
for route in routes:
|
||||
|
||||
# Only actions which actually exist on the viewset will be bound
|
||||
mapping = self.get_method_map(viewset, route.mapping)
|
||||
if not mapping:
|
||||
continue
|
||||
|
||||
# Build the url pattern
|
||||
regex = route.url.format(
|
||||
prefix=prefix,
|
||||
lookup=lookup,
|
||||
trailing_slash=self.trailing_slash
|
||||
)
|
||||
|
||||
# If there is no prefix, the first part of the url is probably
|
||||
# controlled by project's urls.py and the router is in an app,
|
||||
# so a slash in the beginning will (A) cause Django to give
|
||||
# warnings and (B) generate URLS that will require using '//'.
|
||||
if not prefix and regex[:2] == '^/':
|
||||
regex = '^' + regex[2:]
|
||||
|
||||
initkwargs = route.initkwargs.copy()
|
||||
initkwargs.update({
|
||||
'basename': basename,
|
||||
'detail': route.detail,
|
||||
})
|
||||
|
||||
view = viewset.as_view(mapping, **initkwargs)
|
||||
name = route.name.format(basename=basename)
|
||||
ret.append(re_path(regex, view, name=name))
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class APIRootView(views.APIView):
|
||||
"""
|
||||
The default basic root view for DefaultRouter
|
||||
"""
|
||||
_ignore_model_permissions = True
|
||||
schema = None # exclude from schema
|
||||
api_root_dict = None
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
# Return a plain {"name": "hyperlink"} response.
|
||||
ret = OrderedDict()
|
||||
namespace = request.resolver_match.namespace
|
||||
for key, url_name in self.api_root_dict.items():
|
||||
if namespace:
|
||||
url_name = namespace + ':' + url_name
|
||||
try:
|
||||
ret[key] = reverse(
|
||||
url_name,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
request=request,
|
||||
format=kwargs.get('format')
|
||||
)
|
||||
except NoReverseMatch:
|
||||
# Don't bail out if eg. no list routes exist, only detail routes.
|
||||
continue
|
||||
|
||||
return Response(ret)
|
||||
|
||||
|
||||
class DefaultRouter(SimpleRouter):
|
||||
"""
|
||||
The default router extends the SimpleRouter, but also adds in a default
|
||||
API root view, and adds format suffix patterns to the URLs.
|
||||
"""
|
||||
include_root_view = True
|
||||
include_format_suffixes = True
|
||||
root_view_name = 'api-root'
|
||||
default_schema_renderers = None
|
||||
APIRootView = APIRootView
|
||||
APISchemaView = SchemaView
|
||||
SchemaGenerator = SchemaGenerator
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if 'root_renderers' in kwargs:
|
||||
self.root_renderers = kwargs.pop('root_renderers')
|
||||
else:
|
||||
self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_api_root_view(self, api_urls=None):
|
||||
"""
|
||||
Return a basic root view.
|
||||
"""
|
||||
api_root_dict = OrderedDict()
|
||||
list_name = self.routes[0].name
|
||||
for prefix, viewset, basename in self.registry:
|
||||
api_root_dict[prefix] = list_name.format(basename=basename)
|
||||
|
||||
return self.APIRootView.as_view(api_root_dict=api_root_dict)
|
||||
|
||||
def get_urls(self):
|
||||
"""
|
||||
Generate the list of URL patterns, including a default root view
|
||||
for the API, and appending `.json` style format suffixes.
|
||||
"""
|
||||
urls = super().get_urls()
|
||||
|
||||
if self.include_root_view:
|
||||
view = self.get_api_root_view(api_urls=urls)
|
||||
root_url = re_path(r'^$', view, name=self.root_view_name)
|
||||
urls.append(root_url)
|
||||
|
||||
if self.include_format_suffixes:
|
||||
urls = format_suffix_patterns(urls)
|
||||
|
||||
return urls
|
@ -0,0 +1,58 @@
|
||||
"""
|
||||
rest_framework.schemas
|
||||
|
||||
schemas:
|
||||
__init__.py
|
||||
generators.py # Top-down schema generation
|
||||
inspectors.py # Per-endpoint view introspection
|
||||
utils.py # Shared helper functions
|
||||
views.py # Houses `SchemaView`, `APIView` subclass.
|
||||
|
||||
We expose a minimal "public" API directly from `schemas`. This covers the
|
||||
basic use-cases:
|
||||
|
||||
from rest_framework.schemas import (
|
||||
AutoSchema,
|
||||
ManualSchema,
|
||||
get_schema_view,
|
||||
SchemaGenerator,
|
||||
)
|
||||
|
||||
Other access should target the submodules directly
|
||||
"""
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
from . import coreapi, openapi
|
||||
from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa
|
||||
from .inspectors import DefaultSchema # noqa
|
||||
|
||||
|
||||
def get_schema_view(
|
||||
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
|
||||
public=False, patterns=None, generator_class=None,
|
||||
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
|
||||
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
|
||||
version=None):
|
||||
"""
|
||||
Return a schema view.
|
||||
"""
|
||||
if generator_class is None:
|
||||
if coreapi.is_enabled():
|
||||
generator_class = coreapi.SchemaGenerator
|
||||
else:
|
||||
generator_class = openapi.SchemaGenerator
|
||||
|
||||
generator = generator_class(
|
||||
title=title, url=url, description=description,
|
||||
urlconf=urlconf, patterns=patterns, version=version
|
||||
)
|
||||
|
||||
# Avoid import cycle on APIView
|
||||
from .views import SchemaView
|
||||
return SchemaView.as_view(
|
||||
renderer_classes=renderer_classes,
|
||||
schema_generator=generator,
|
||||
public=public,
|
||||
authentication_classes=authentication_classes,
|
||||
permission_classes=permission_classes,
|
||||
)
|
@ -0,0 +1,616 @@
|
||||
import warnings
|
||||
from collections import Counter, OrderedDict
|
||||
from urllib import parse
|
||||
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.compat import coreapi, coreschema, uritemplate
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
from .generators import BaseSchemaGenerator
|
||||
from .inspectors import ViewInspector
|
||||
from .utils import get_pk_description, is_list_view
|
||||
|
||||
|
||||
def common_path(paths):
|
||||
split_paths = [path.strip('/').split('/') for path in paths]
|
||||
s1 = min(split_paths)
|
||||
s2 = max(split_paths)
|
||||
common = s1
|
||||
for i, c in enumerate(s1):
|
||||
if c != s2[i]:
|
||||
common = s1[:i]
|
||||
break
|
||||
return '/' + '/'.join(common)
|
||||
|
||||
|
||||
def is_custom_action(action):
|
||||
return action not in {
|
||||
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
|
||||
}
|
||||
|
||||
|
||||
def distribute_links(obj):
|
||||
for key, value in obj.items():
|
||||
distribute_links(value)
|
||||
|
||||
for preferred_key, link in obj.links:
|
||||
key = obj.get_available_key(preferred_key)
|
||||
obj[key] = link
|
||||
|
||||
|
||||
INSERT_INTO_COLLISION_FMT = """
|
||||
Schema Naming Collision.
|
||||
|
||||
coreapi.Link for URL path {value_url} cannot be inserted into schema.
|
||||
Position conflicts with coreapi.Link for URL path {target_url}.
|
||||
|
||||
Attempted to insert link with keys: {keys}.
|
||||
|
||||
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
|
||||
to customise schema structure.
|
||||
"""
|
||||
|
||||
|
||||
class LinkNode(OrderedDict):
|
||||
def __init__(self):
|
||||
self.links = []
|
||||
self.methods_counter = Counter()
|
||||
super().__init__()
|
||||
|
||||
def get_available_key(self, preferred_key):
|
||||
if preferred_key not in self:
|
||||
return preferred_key
|
||||
|
||||
while True:
|
||||
current_val = self.methods_counter[preferred_key]
|
||||
self.methods_counter[preferred_key] += 1
|
||||
|
||||
key = '{}_{}'.format(preferred_key, current_val)
|
||||
if key not in self:
|
||||
return key
|
||||
|
||||
|
||||
def insert_into(target, keys, value):
|
||||
"""
|
||||
Nested dictionary insertion.
|
||||
|
||||
>>> example = {}
|
||||
>>> insert_into(example, ['a', 'b', 'c'], 123)
|
||||
>>> example
|
||||
LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
|
||||
"""
|
||||
for key in keys[:-1]:
|
||||
if key not in target:
|
||||
target[key] = LinkNode()
|
||||
target = target[key]
|
||||
|
||||
try:
|
||||
target.links.append((keys[-1], value))
|
||||
except TypeError:
|
||||
msg = INSERT_INTO_COLLISION_FMT.format(
|
||||
value_url=value.url,
|
||||
target_url=target.url,
|
||||
keys=keys
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
"""
|
||||
Original CoreAPI version.
|
||||
"""
|
||||
# Map HTTP methods onto actions.
|
||||
default_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
|
||||
coerce_method_names = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
assert coreschema, '`coreschema` must be installed for schema support.'
|
||||
|
||||
super().__init__(title, url, description, patterns, urlconf)
|
||||
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
|
||||
def get_links(self, request=None):
|
||||
"""
|
||||
Return a dictionary containing all the links that should be
|
||||
included in the API schema.
|
||||
"""
|
||||
links = LinkNode()
|
||||
|
||||
paths, view_endpoints = self._get_paths_and_endpoints(request)
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = view.schema.get_link(path, method, base_url=self.url)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
|
||||
return links
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
self._initialise_endpoints()
|
||||
|
||||
links = self.get_links(None if public else request)
|
||||
if not links:
|
||||
return None
|
||||
|
||||
url = self.url
|
||||
if not url and request is not None:
|
||||
url = request.build_absolute_uri()
|
||||
|
||||
distribute_links(links)
|
||||
return coreapi.Document(
|
||||
title=self.title, description=self.description,
|
||||
url=url, content=links
|
||||
)
|
||||
|
||||
# Method for generating the link layout....
|
||||
def get_keys(self, subpath, method, view):
|
||||
"""
|
||||
Return a list of keys that should be used to layout a link within
|
||||
the schema document.
|
||||
|
||||
/users/ ("users", "list"), ("users", "create")
|
||||
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
|
||||
/users/enabled/ ("users", "enabled") # custom viewset list action
|
||||
/users/{pk}/star/ ("users", "star") # custom viewset detail action
|
||||
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
|
||||
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have explicitly named actions.
|
||||
action = view.action
|
||||
else:
|
||||
# Views have no associated action, so we determine one from the method.
|
||||
if is_list_view(subpath, method, view):
|
||||
action = 'list'
|
||||
else:
|
||||
action = self.default_mapping[method.lower()]
|
||||
|
||||
named_path_components = [
|
||||
component for component
|
||||
in subpath.strip('/').split('/')
|
||||
if '{' not in component
|
||||
]
|
||||
|
||||
if is_custom_action(action):
|
||||
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
|
||||
mapped_methods = {
|
||||
# Don't count head mapping, e.g. not part of the schema
|
||||
method for method in view.action_map if method != 'head'
|
||||
}
|
||||
if len(mapped_methods) > 1:
|
||||
action = self.default_mapping[method.lower()]
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
return named_path_components + [action]
|
||||
else:
|
||||
return named_path_components[:-1] + [action]
|
||||
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
return named_path_components + [action]
|
||||
|
||||
def determine_path_prefix(self, paths):
|
||||
"""
|
||||
Given a list of all paths, return the common prefix which should be
|
||||
discounted when generating a schema structure.
|
||||
|
||||
This will be the longest common string that does not include that last
|
||||
component of the URL, or the last component before a path parameter.
|
||||
|
||||
For example:
|
||||
|
||||
/api/v1/users/
|
||||
/api/v1/users/{pk}/
|
||||
|
||||
The path prefix is '/api/v1'
|
||||
"""
|
||||
prefixes = []
|
||||
for path in paths:
|
||||
components = path.strip('/').split('/')
|
||||
initial_components = []
|
||||
for component in components:
|
||||
if '{' in component:
|
||||
break
|
||||
initial_components.append(component)
|
||||
prefix = '/'.join(initial_components[:-1])
|
||||
if not prefix:
|
||||
# We can just break early in the case that there's at least
|
||||
# one URL that doesn't have a path prefix.
|
||||
return '/'
|
||||
prefixes.append('/' + prefix + '/')
|
||||
return common_path(prefixes)
|
||||
|
||||
# View Inspectors #
|
||||
|
||||
|
||||
def field_to_schema(field):
|
||||
title = force_str(field.label) if field.label else ''
|
||||
description = force_str(field.help_text) if field.help_text else ''
|
||||
|
||||
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
|
||||
child_schema = field_to_schema(field.child)
|
||||
return coreschema.Array(
|
||||
items=child_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.DictField):
|
||||
return coreschema.Object(
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.Serializer):
|
||||
return coreschema.Object(
|
||||
properties=OrderedDict([
|
||||
(key, field_to_schema(value))
|
||||
for key, value
|
||||
in field.fields.items()
|
||||
]),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ManyRelatedField):
|
||||
related_field_schema = field_to_schema(field.child_relation)
|
||||
|
||||
return coreschema.Array(
|
||||
items=related_field_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.PrimaryKeyRelatedField):
|
||||
schema_cls = coreschema.String
|
||||
model = getattr(field.queryset, 'model', None)
|
||||
if model is not None:
|
||||
model_field = model._meta.pk
|
||||
if isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
return schema_cls(title=title, description=description)
|
||||
elif isinstance(field, serializers.RelatedField):
|
||||
return coreschema.String(title=title, description=description)
|
||||
elif isinstance(field, serializers.MultipleChoiceField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.Enum(enum=list(field.choices)),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ChoiceField):
|
||||
return coreschema.Enum(
|
||||
enum=list(field.choices),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.BooleanField):
|
||||
return coreschema.Boolean(title=title, description=description)
|
||||
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
|
||||
return coreschema.Number(title=title, description=description)
|
||||
elif isinstance(field, serializers.IntegerField):
|
||||
return coreschema.Integer(title=title, description=description)
|
||||
elif isinstance(field, serializers.DateField):
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='date'
|
||||
)
|
||||
elif isinstance(field, serializers.DateTimeField):
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='date-time'
|
||||
)
|
||||
elif isinstance(field, serializers.JSONField):
|
||||
return coreschema.Object(title=title, description=description)
|
||||
|
||||
if field.style.get('base_template') == 'textarea.html':
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='textarea'
|
||||
)
|
||||
|
||||
return coreschema.String(title=title, description=description)
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
"""
|
||||
Default inspector for APIView
|
||||
|
||||
Responsible for per-view introspection and schema generation.
|
||||
"""
|
||||
def __init__(self, manual_fields=None):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `manual_fields`: list of `coreapi.Field` instances that
|
||||
will be added to auto-generated fields, overwriting on `Field.name`
|
||||
"""
|
||||
super().__init__()
|
||||
if manual_fields is None:
|
||||
manual_fields = []
|
||||
self._manual_fields = manual_fields
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
"""
|
||||
Generate `coreapi.Link` for self.view, path and method.
|
||||
|
||||
This is the main _public_ access point.
|
||||
|
||||
Parameters:
|
||||
|
||||
* path: Route path for view from URLConf.
|
||||
* method: The HTTP request method.
|
||||
* base_url: The project "mount point" as given to SchemaGenerator
|
||||
"""
|
||||
fields = self.get_path_fields(path, method)
|
||||
fields += self.get_serializer_fields(path, method)
|
||||
fields += self.get_pagination_fields(path, method)
|
||||
fields += self.get_filter_fields(path, method)
|
||||
|
||||
manual_fields = self.get_manual_fields(path, method)
|
||||
fields = self.update_fields(fields, manual_fields)
|
||||
|
||||
if fields and any([field.location in ('form', 'body') for field in fields]):
|
||||
encoding = self.get_encoding(path, method)
|
||||
else:
|
||||
encoding = None
|
||||
|
||||
description = self.get_description(path, method)
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=parse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=encoding,
|
||||
fields=fields,
|
||||
description=description
|
||||
)
|
||||
|
||||
def get_path_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
templated path variables.
|
||||
"""
|
||||
view = self.view
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
fields = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
title = ''
|
||||
description = ''
|
||||
schema_cls = coreschema.String
|
||||
kwargs = {}
|
||||
if model is not None:
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except Exception:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.verbose_name:
|
||||
title = force_str(model_field.verbose_name)
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_str(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
|
||||
kwargs['pattern'] = view.lookup_value_regex
|
||||
elif isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
|
||||
field = coreapi.Field(
|
||||
name=variable,
|
||||
location='path',
|
||||
required=True,
|
||||
schema=schema_cls(title=title, description=description, **kwargs)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_serializer_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
request body input, as determined by the serializer class.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return []
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return []
|
||||
|
||||
try:
|
||||
serializer = view.get_serializer()
|
||||
except exceptions.APIException:
|
||||
serializer = None
|
||||
warnings.warn('{}.get_serializer() raised an exception during '
|
||||
'schema generation. Serializer fields will not be '
|
||||
'generated for {} {}.'
|
||||
.format(view.__class__.__name__, method, path))
|
||||
|
||||
if isinstance(serializer, serializers.ListSerializer):
|
||||
return [
|
||||
coreapi.Field(
|
||||
name='data',
|
||||
location='body',
|
||||
required=True,
|
||||
schema=coreschema.Array()
|
||||
)
|
||||
]
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for field in serializer.fields.values():
|
||||
if field.read_only or isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
required = field.required and method != 'PATCH'
|
||||
field = coreapi.Field(
|
||||
name=field.field_name,
|
||||
location='form',
|
||||
required=required,
|
||||
schema=field_to_schema(field)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_pagination_fields(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
pagination = getattr(view, 'pagination_class', None)
|
||||
if not pagination:
|
||||
return []
|
||||
|
||||
paginator = view.pagination_class()
|
||||
return paginator.get_schema_fields(view)
|
||||
|
||||
def _allows_filters(self, path, method):
|
||||
"""
|
||||
Determine whether to include filter Fields in schema.
|
||||
|
||||
Default implementation looks for ModelViewSet or GenericAPIView
|
||||
actions/methods that cause filtering on the default implementation.
|
||||
|
||||
Override to adjust behaviour for your view.
|
||||
|
||||
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
|
||||
to allow changes based on user experience.
|
||||
"""
|
||||
if getattr(self.view, 'filter_backends', None) is None:
|
||||
return False
|
||||
|
||||
if hasattr(self.view, 'action'):
|
||||
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
|
||||
|
||||
return method.lower() in ["get", "put", "patch", "delete"]
|
||||
|
||||
def get_filter_fields(self, path, method):
|
||||
if not self._allows_filters(path, method):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for filter_backend in self.view.filter_backends:
|
||||
fields += filter_backend().get_schema_fields(self.view)
|
||||
return fields
|
||||
|
||||
def get_manual_fields(self, path, method):
|
||||
return self._manual_fields
|
||||
|
||||
@staticmethod
|
||||
def update_fields(fields, update_with):
|
||||
"""
|
||||
Update list of coreapi.Field instances, overwriting on `Field.name`.
|
||||
|
||||
Utility function to handle replacing coreapi.Field fields
|
||||
from a list by name. Used to handle `manual_fields`.
|
||||
|
||||
Parameters:
|
||||
|
||||
* `fields`: list of `coreapi.Field` instances to update
|
||||
* `update_with: list of `coreapi.Field` instances to add or replace.
|
||||
"""
|
||||
if not update_with:
|
||||
return fields
|
||||
|
||||
by_name = OrderedDict((f.name, f) for f in fields)
|
||||
for f in update_with:
|
||||
by_name[f.name] = f
|
||||
fields = list(by_name.values())
|
||||
return fields
|
||||
|
||||
def get_encoding(self, path, method):
|
||||
"""
|
||||
Return the 'encoding' parameter to use for a given endpoint.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
# Core API supports the following request encodings over HTTP...
|
||||
supported_media_types = {
|
||||
'application/json',
|
||||
'application/x-www-form-urlencoded',
|
||||
'multipart/form-data',
|
||||
}
|
||||
parser_classes = getattr(view, 'parser_classes', [])
|
||||
for parser_class in parser_classes:
|
||||
media_type = getattr(parser_class, 'media_type', None)
|
||||
if media_type in supported_media_types:
|
||||
return media_type
|
||||
# Raw binary uploads are supported with "application/octet-stream"
|
||||
if media_type == '*/*':
|
||||
return 'application/octet-stream'
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ManualSchema(ViewInspector):
|
||||
"""
|
||||
Allows providing a list of coreapi.Fields,
|
||||
plus an optional description.
|
||||
"""
|
||||
def __init__(self, fields, description='', encoding=None):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `fields`: list of `coreapi.Field` instances.
|
||||
* `description`: String description for view. Optional.
|
||||
"""
|
||||
super().__init__()
|
||||
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
|
||||
self._fields = fields
|
||||
self._description = description
|
||||
self._encoding = encoding
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=parse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=self._encoding,
|
||||
fields=self._fields,
|
||||
description=self._description
|
||||
)
|
||||
|
||||
|
||||
def is_enabled():
|
||||
"""Is CoreAPI Mode enabled?"""
|
||||
return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema)
|
@ -0,0 +1,239 @@
|
||||
"""
|
||||
generators.py # Top-down schema generation
|
||||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
import re
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.admindocs.views import simplify_regex
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.http import Http404
|
||||
from django.urls import URLPattern, URLResolver
|
||||
|
||||
from rest_framework import exceptions
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils.model_meta import _get_pk
|
||||
|
||||
|
||||
def get_pk_name(model):
|
||||
meta = model._meta.concrete_model._meta
|
||||
return _get_pk(meta).name
|
||||
|
||||
|
||||
def is_api_view(callback):
|
||||
"""
|
||||
Return `True` if the given view callback is a REST framework view/viewset.
|
||||
"""
|
||||
# Avoid import cycle on APIView
|
||||
from rest_framework.views import APIView
|
||||
cls = getattr(callback, 'cls', None)
|
||||
return (cls is not None) and issubclass(cls, APIView)
|
||||
|
||||
|
||||
def endpoint_ordering(endpoint):
|
||||
path, method, callback = endpoint
|
||||
method_priority = {
|
||||
'GET': 0,
|
||||
'POST': 1,
|
||||
'PUT': 2,
|
||||
'PATCH': 3,
|
||||
'DELETE': 4
|
||||
}.get(method, 5)
|
||||
return (method_priority,)
|
||||
|
||||
|
||||
_PATH_PARAMETER_COMPONENT_RE = re.compile(
|
||||
r'<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>'
|
||||
)
|
||||
|
||||
|
||||
class EndpointEnumerator:
|
||||
"""
|
||||
A class to determine the available API endpoints that a project exposes.
|
||||
"""
|
||||
def __init__(self, patterns=None, urlconf=None):
|
||||
if patterns is None:
|
||||
if urlconf is None:
|
||||
# Use the default Django URL conf
|
||||
urlconf = settings.ROOT_URLCONF
|
||||
|
||||
# Load the given URLconf module
|
||||
if isinstance(urlconf, str):
|
||||
urls = import_module(urlconf)
|
||||
else:
|
||||
urls = urlconf
|
||||
patterns = urls.urlpatterns
|
||||
|
||||
self.patterns = patterns
|
||||
|
||||
def get_api_endpoints(self, patterns=None, prefix=''):
|
||||
"""
|
||||
Return a list of all available API endpoints by inspecting the URL conf.
|
||||
"""
|
||||
if patterns is None:
|
||||
patterns = self.patterns
|
||||
|
||||
api_endpoints = []
|
||||
|
||||
for pattern in patterns:
|
||||
path_regex = prefix + str(pattern.pattern)
|
||||
if isinstance(pattern, URLPattern):
|
||||
path = self.get_path_from_regex(path_regex)
|
||||
callback = pattern.callback
|
||||
if self.should_include_endpoint(path, callback):
|
||||
for method in self.get_allowed_methods(callback):
|
||||
endpoint = (path, method, callback)
|
||||
api_endpoints.append(endpoint)
|
||||
|
||||
elif isinstance(pattern, URLResolver):
|
||||
nested_endpoints = self.get_api_endpoints(
|
||||
patterns=pattern.url_patterns,
|
||||
prefix=path_regex
|
||||
)
|
||||
api_endpoints.extend(nested_endpoints)
|
||||
|
||||
return sorted(api_endpoints, key=endpoint_ordering)
|
||||
|
||||
def get_path_from_regex(self, path_regex):
|
||||
"""
|
||||
Given a URL conf regex, return a URI template string.
|
||||
"""
|
||||
# ???: Would it be feasible to adjust this such that we generate the
|
||||
# path, plus the kwargs, plus the type from the convertor, such that we
|
||||
# could feed that straight into the parameter schema object?
|
||||
|
||||
path = simplify_regex(path_regex)
|
||||
|
||||
# Strip Django 2.0 convertors as they are incompatible with uritemplate format
|
||||
return re.sub(_PATH_PARAMETER_COMPONENT_RE, r'{\g<parameter>}', path)
|
||||
|
||||
def should_include_endpoint(self, path, callback):
|
||||
"""
|
||||
Return `True` if the given endpoint should be included.
|
||||
"""
|
||||
if not is_api_view(callback):
|
||||
return False # Ignore anything except REST framework views.
|
||||
|
||||
if callback.cls.schema is None:
|
||||
return False
|
||||
|
||||
if 'schema' in callback.initkwargs:
|
||||
if callback.initkwargs['schema'] is None:
|
||||
return False
|
||||
|
||||
if path.endswith('.{format}') or path.endswith('.{format}/'):
|
||||
return False # Ignore .json style URLs.
|
||||
|
||||
return True
|
||||
|
||||
def get_allowed_methods(self, callback):
|
||||
"""
|
||||
Return a list of the valid HTTP methods for this endpoint.
|
||||
"""
|
||||
if hasattr(callback, 'actions'):
|
||||
actions = set(callback.actions)
|
||||
http_method_names = set(callback.cls.http_method_names)
|
||||
methods = [method.upper() for method in actions & http_method_names]
|
||||
else:
|
||||
methods = callback.cls().allowed_methods
|
||||
|
||||
return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
|
||||
|
||||
|
||||
class BaseSchemaGenerator:
|
||||
endpoint_inspector_cls = EndpointEnumerator
|
||||
|
||||
# 'pk' isn't great as an externally exposed name for an identifier,
|
||||
# so by default we prefer to use the actual model field name for schemas.
|
||||
# Set by 'SCHEMA_COERCE_PATH_PK'.
|
||||
coerce_path_pk = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
|
||||
|
||||
self.patterns = patterns
|
||||
self.urlconf = urlconf
|
||||
self.title = title
|
||||
self.description = description
|
||||
self.version = version
|
||||
self.url = url
|
||||
self.endpoints = None
|
||||
|
||||
def _initialise_endpoints(self):
|
||||
if self.endpoints is None:
|
||||
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
|
||||
self.endpoints = inspector.get_api_endpoints()
|
||||
|
||||
def _get_paths_and_endpoints(self, request):
|
||||
"""
|
||||
Generate (path, method, view) given (path, method, callback) for paths.
|
||||
"""
|
||||
paths = []
|
||||
view_endpoints = []
|
||||
for path, method, callback in self.endpoints:
|
||||
view = self.create_view(callback, method, request)
|
||||
path = self.coerce_path(path, method, view)
|
||||
paths.append(path)
|
||||
view_endpoints.append((path, method, view))
|
||||
|
||||
return paths, view_endpoints
|
||||
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls(**getattr(callback, 'initkwargs', {}))
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
view.action_map = getattr(callback, 'actions', None)
|
||||
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def coerce_path(self, path, method, view):
|
||||
"""
|
||||
Coerce {pk} path arguments into the name of the model field,
|
||||
where possible. This is cleaner for an external representation.
|
||||
(Ie. "this is an identifier", not "this is a database primary key")
|
||||
"""
|
||||
if not self.coerce_path_pk or '{pk}' not in path:
|
||||
return path
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
raise NotImplementedError(".get_schema() must be implemented in subclasses.")
|
||||
|
||||
def has_view_permissions(self, path, method, view):
|
||||
"""
|
||||
Return `True` if the incoming request has the correct view permissions.
|
||||
"""
|
||||
if view.request is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
view.check_permissions(view.request)
|
||||
except (exceptions.APIException, Http404, PermissionDenied):
|
||||
return False
|
||||
return True
|
@ -0,0 +1,125 @@
|
||||
"""
|
||||
inspectors.py # Per-endpoint view introspection
|
||||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
import re
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
from django.utils.encoding import smart_str
|
||||
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import formatting
|
||||
|
||||
|
||||
class ViewInspector:
|
||||
"""
|
||||
Descriptor class on APIView.
|
||||
|
||||
Provide subclass for per-view schema generation
|
||||
"""
|
||||
|
||||
# Used in _get_description_section()
|
||||
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
|
||||
|
||||
def __init__(self):
|
||||
self.instance_schemas = WeakKeyDictionary()
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
"""
|
||||
Enables `ViewInspector` as a Python _Descriptor_.
|
||||
|
||||
This is how `view.schema` knows about `view`.
|
||||
|
||||
`__get__` is called when the descriptor is accessed on the owner.
|
||||
(That will be when view.schema is called in our case.)
|
||||
|
||||
`owner` is always the owner class. (An APIView, or subclass for us.)
|
||||
`instance` is the view instance or `None` if accessed from the class,
|
||||
rather than an instance.
|
||||
|
||||
See: https://docs.python.org/3/howto/descriptor.html for info on
|
||||
descriptor usage.
|
||||
"""
|
||||
if instance in self.instance_schemas:
|
||||
return self.instance_schemas[instance]
|
||||
|
||||
self.view = instance
|
||||
return self
|
||||
|
||||
def __set__(self, instance, other):
|
||||
self.instance_schemas[instance] = other
|
||||
if other is not None:
|
||||
other.view = instance
|
||||
|
||||
@property
|
||||
def view(self):
|
||||
"""View property."""
|
||||
assert self._view is not None, (
|
||||
"Schema generation REQUIRES a view instance. (Hint: you accessed "
|
||||
"`schema` from the view class rather than an instance.)"
|
||||
)
|
||||
return self._view
|
||||
|
||||
@view.setter
|
||||
def view(self, value):
|
||||
self._view = value
|
||||
|
||||
@view.deleter
|
||||
def view(self):
|
||||
self._view = None
|
||||
|
||||
def get_description(self, path, method):
|
||||
"""
|
||||
Determine a path description.
|
||||
|
||||
This will be based on the method docstring if one exists,
|
||||
or else the class docstring.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
method_name = getattr(view, 'action', method.lower())
|
||||
method_docstring = getattr(view, method_name, None).__doc__
|
||||
if method_docstring:
|
||||
# An explicit docstring on the method or action.
|
||||
return self._get_description_section(view, method.lower(), formatting.dedent(smart_str(method_docstring)))
|
||||
else:
|
||||
return self._get_description_section(view, getattr(view, 'action', method.lower()),
|
||||
view.get_view_description())
|
||||
|
||||
def _get_description_section(self, view, header, description):
|
||||
lines = [line for line in description.splitlines()]
|
||||
current_section = ''
|
||||
sections = {'': ''}
|
||||
|
||||
for line in lines:
|
||||
if self.header_regex.match(line):
|
||||
current_section, separator, lead = line.partition(':')
|
||||
sections[current_section] = lead.strip()
|
||||
else:
|
||||
sections[current_section] += '\n' + line
|
||||
|
||||
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
|
||||
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
if header in sections:
|
||||
return sections[header].strip()
|
||||
if header in coerce_method_names:
|
||||
if coerce_method_names[header] in sections:
|
||||
return sections[coerce_method_names[header]].strip()
|
||||
return sections[''].strip()
|
||||
|
||||
|
||||
class DefaultSchema(ViewInspector):
|
||||
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""
|
||||
def __get__(self, instance, owner):
|
||||
result = super().__get__(instance, owner)
|
||||
if not isinstance(result, DefaultSchema):
|
||||
return result
|
||||
|
||||
inspector_class = api_settings.DEFAULT_SCHEMA_CLASS
|
||||
assert issubclass(inspector_class, ViewInspector), (
|
||||
"DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
|
||||
)
|
||||
inspector = inspector_class()
|
||||
inspector.view = instance
|
||||
return inspector
|
@ -0,0 +1,722 @@
|
||||
import re
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from decimal import Decimal
|
||||
from operator import attrgetter
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from django.core.validators import (
|
||||
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
|
||||
MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
|
||||
)
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
from rest_framework import (
|
||||
RemovedInDRF315Warning, exceptions, renderers, serializers
|
||||
)
|
||||
from rest_framework.compat import uritemplate
|
||||
from rest_framework.fields import _UnvalidatedField, empty
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
from .generators import BaseSchemaGenerator
|
||||
from .inspectors import ViewInspector
|
||||
from .utils import get_pk_description, is_list_view
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
|
||||
def get_info(self):
|
||||
# Title and version are required by openapi specification 3.x
|
||||
info = {
|
||||
'title': self.title or '',
|
||||
'version': self.version or ''
|
||||
}
|
||||
|
||||
if self.description is not None:
|
||||
info['description'] = self.description
|
||||
|
||||
return info
|
||||
|
||||
def check_duplicate_operation_id(self, paths):
|
||||
ids = {}
|
||||
for route in paths:
|
||||
for method in paths[route]:
|
||||
if 'operationId' not in paths[route][method]:
|
||||
continue
|
||||
operation_id = paths[route][method]['operationId']
|
||||
if operation_id in ids:
|
||||
warnings.warn(
|
||||
'You have a duplicated operationId in your OpenAPI schema: {operation_id}\n'
|
||||
'\tRoute: {route1}, Method: {method1}\n'
|
||||
'\tRoute: {route2}, Method: {method2}\n'
|
||||
'\tAn operationId has to be unique across your schema. Your schema may not work in other tools.'
|
||||
.format(
|
||||
route1=ids[operation_id]['route'],
|
||||
method1=ids[operation_id]['method'],
|
||||
route2=route,
|
||||
method2=method,
|
||||
operation_id=operation_id
|
||||
)
|
||||
)
|
||||
ids[operation_id] = {
|
||||
'route': route,
|
||||
'method': method
|
||||
}
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a OpenAPI schema.
|
||||
"""
|
||||
self._initialise_endpoints()
|
||||
components_schemas = {}
|
||||
|
||||
# Iterate endpoints generating per method path operations.
|
||||
paths = {}
|
||||
_, view_endpoints = self._get_paths_and_endpoints(None if public else request)
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
|
||||
operation = view.schema.get_operation(path, method)
|
||||
components = view.schema.get_components(path, method)
|
||||
for k in components.keys():
|
||||
if k not in components_schemas:
|
||||
continue
|
||||
if components_schemas[k] == components[k]:
|
||||
continue
|
||||
warnings.warn('Schema component "{}" has been overriden with a different value.'.format(k))
|
||||
|
||||
components_schemas.update(components)
|
||||
|
||||
# Normalise path for any provided mount url.
|
||||
if path.startswith('/'):
|
||||
path = path[1:]
|
||||
path = urljoin(self.url or '/', path)
|
||||
|
||||
paths.setdefault(path, {})
|
||||
paths[path][method.lower()] = operation
|
||||
|
||||
self.check_duplicate_operation_id(paths)
|
||||
|
||||
# Compile final schema.
|
||||
schema = {
|
||||
'openapi': '3.0.2',
|
||||
'info': self.get_info(),
|
||||
'paths': paths,
|
||||
}
|
||||
|
||||
if len(components_schemas) > 0:
|
||||
schema['components'] = {
|
||||
'schemas': components_schemas
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
# View Inspectors
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
|
||||
def __init__(self, tags=None, operation_id_base=None, component_name=None):
|
||||
"""
|
||||
:param operation_id_base: user-defined name in operationId. If empty, it will be deducted from the Model/Serializer/View name.
|
||||
:param component_name: user-defined component's name. If empty, it will be deducted from the Serializer's class name.
|
||||
"""
|
||||
if tags and not all(isinstance(tag, str) for tag in tags):
|
||||
raise ValueError('tags must be a list or tuple of string.')
|
||||
self._tags = tags
|
||||
self.operation_id_base = operation_id_base
|
||||
self.component_name = component_name
|
||||
super().__init__()
|
||||
|
||||
request_media_types = []
|
||||
response_media_types = []
|
||||
|
||||
method_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partialUpdate',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
|
||||
def get_operation(self, path, method):
|
||||
operation = {}
|
||||
|
||||
operation['operationId'] = self.get_operation_id(path, method)
|
||||
operation['description'] = self.get_description(path, method)
|
||||
|
||||
parameters = []
|
||||
parameters += self.get_path_parameters(path, method)
|
||||
parameters += self.get_pagination_parameters(path, method)
|
||||
parameters += self.get_filter_parameters(path, method)
|
||||
operation['parameters'] = parameters
|
||||
|
||||
request_body = self.get_request_body(path, method)
|
||||
if request_body:
|
||||
operation['requestBody'] = request_body
|
||||
operation['responses'] = self.get_responses(path, method)
|
||||
operation['tags'] = self.get_tags(path, method)
|
||||
|
||||
return operation
|
||||
|
||||
def get_component_name(self, serializer):
|
||||
"""
|
||||
Compute the component's name from the serializer.
|
||||
Raise an exception if the serializer's class name is "Serializer" (case-insensitive).
|
||||
"""
|
||||
if self.component_name is not None:
|
||||
return self.component_name
|
||||
|
||||
# use the serializer's class name as the component name.
|
||||
component_name = serializer.__class__.__name__
|
||||
# We remove the "serializer" string from the class name.
|
||||
pattern = re.compile("serializer", re.IGNORECASE)
|
||||
component_name = pattern.sub("", component_name)
|
||||
|
||||
if component_name == "":
|
||||
raise Exception(
|
||||
'"{}" is an invalid class name for schema generation. '
|
||||
'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"'
|
||||
.format(serializer.__class__.__name__)
|
||||
)
|
||||
|
||||
return component_name
|
||||
|
||||
def get_components(self, path, method):
|
||||
"""
|
||||
Return components with their properties from the serializer.
|
||||
"""
|
||||
|
||||
if method.lower() == 'delete':
|
||||
return {}
|
||||
|
||||
request_serializer = self.get_request_serializer(path, method)
|
||||
response_serializer = self.get_response_serializer(path, method)
|
||||
|
||||
components = {}
|
||||
|
||||
if isinstance(request_serializer, serializers.Serializer):
|
||||
component_name = self.get_component_name(request_serializer)
|
||||
content = self.map_serializer(request_serializer)
|
||||
components.setdefault(component_name, content)
|
||||
|
||||
if isinstance(response_serializer, serializers.Serializer):
|
||||
component_name = self.get_component_name(response_serializer)
|
||||
content = self.map_serializer(response_serializer)
|
||||
components.setdefault(component_name, content)
|
||||
|
||||
return components
|
||||
|
||||
def _to_camel_case(self, snake_str):
|
||||
components = snake_str.split('_')
|
||||
# We capitalize the first letter of each component except the first one
|
||||
# with the 'title' method and join them together.
|
||||
return components[0] + ''.join(x.title() for x in components[1:])
|
||||
|
||||
def get_operation_id_base(self, path, method, action):
|
||||
"""
|
||||
Compute the base part for operation ID from the model, serializer or view name.
|
||||
"""
|
||||
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
|
||||
|
||||
if self.operation_id_base is not None:
|
||||
name = self.operation_id_base
|
||||
|
||||
# Try to deduce the ID from the view's model
|
||||
elif model is not None:
|
||||
name = model.__name__
|
||||
|
||||
# Try with the serializer class name
|
||||
elif self.get_serializer(path, method) is not None:
|
||||
name = self.get_serializer(path, method).__class__.__name__
|
||||
if name.endswith('Serializer'):
|
||||
name = name[:-10]
|
||||
|
||||
# Fallback to the view name
|
||||
else:
|
||||
name = self.view.__class__.__name__
|
||||
if name.endswith('APIView'):
|
||||
name = name[:-7]
|
||||
elif name.endswith('View'):
|
||||
name = name[:-4]
|
||||
|
||||
# Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly
|
||||
# comes at the end of the name
|
||||
if name.endswith(action.title()): # ListView, UpdateAPIView, ThingDelete ...
|
||||
name = name[:-len(action)]
|
||||
|
||||
if action == 'list' and not name.endswith('s'): # listThings instead of listThing
|
||||
name += 's'
|
||||
|
||||
return name
|
||||
|
||||
def get_operation_id(self, path, method):
|
||||
"""
|
||||
Compute an operation ID from the view type and get_operation_id_base method.
|
||||
"""
|
||||
method_name = getattr(self.view, 'action', method.lower())
|
||||
if is_list_view(path, method, self.view):
|
||||
action = 'list'
|
||||
elif method_name not in self.method_mapping:
|
||||
action = self._to_camel_case(method_name)
|
||||
else:
|
||||
action = self.method_mapping[method.lower()]
|
||||
|
||||
name = self.get_operation_id_base(path, method, action)
|
||||
|
||||
return action + name
|
||||
|
||||
def get_path_parameters(self, path, method):
|
||||
"""
|
||||
Return a list of parameters from templated path variables.
|
||||
"""
|
||||
assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'
|
||||
|
||||
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
|
||||
parameters = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
description = ''
|
||||
if model is not None: # TODO: test this.
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except Exception:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_str(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
parameter = {
|
||||
"name": variable,
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"description": description,
|
||||
'schema': {
|
||||
'type': 'string', # TODO: integer, pattern, ...
|
||||
},
|
||||
}
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
def get_filter_parameters(self, path, method):
|
||||
if not self.allows_filters(path, method):
|
||||
return []
|
||||
parameters = []
|
||||
for filter_backend in self.view.filter_backends:
|
||||
parameters += filter_backend().get_schema_operation_parameters(self.view)
|
||||
return parameters
|
||||
|
||||
def allows_filters(self, path, method):
|
||||
"""
|
||||
Determine whether to include filter Fields in schema.
|
||||
|
||||
Default implementation looks for ModelViewSet or GenericAPIView
|
||||
actions/methods that cause filtering on the default implementation.
|
||||
"""
|
||||
if getattr(self.view, 'filter_backends', None) is None:
|
||||
return False
|
||||
if hasattr(self.view, 'action'):
|
||||
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
|
||||
return method.lower() in ["get", "put", "patch", "delete"]
|
||||
|
||||
def get_pagination_parameters(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
paginator = self.get_paginator()
|
||||
if not paginator:
|
||||
return []
|
||||
|
||||
return paginator.get_schema_operation_parameters(view)
|
||||
|
||||
def map_choicefield(self, field):
|
||||
choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates
|
||||
if all(isinstance(choice, bool) for choice in choices):
|
||||
type = 'boolean'
|
||||
elif all(isinstance(choice, int) for choice in choices):
|
||||
type = 'integer'
|
||||
elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer`
|
||||
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
|
||||
type = 'number'
|
||||
elif all(isinstance(choice, str) for choice in choices):
|
||||
type = 'string'
|
||||
else:
|
||||
type = None
|
||||
|
||||
mapping = {
|
||||
# The value of `enum` keyword MUST be an array and SHOULD be unique.
|
||||
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20
|
||||
'enum': choices
|
||||
}
|
||||
|
||||
# If We figured out `type` then and only then we should set it. It must be a string.
|
||||
# Ref: https://swagger.io/docs/specification/data-models/data-types/#mixed-type
|
||||
# It is optional but it can not be null.
|
||||
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
|
||||
if type:
|
||||
mapping['type'] = type
|
||||
return mapping
|
||||
|
||||
def map_field(self, field):
|
||||
|
||||
# Nested Serializers, `many` or not.
|
||||
if isinstance(field, serializers.ListSerializer):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self.map_serializer(field.child)
|
||||
}
|
||||
if isinstance(field, serializers.Serializer):
|
||||
data = self.map_serializer(field)
|
||||
data['type'] = 'object'
|
||||
return data
|
||||
|
||||
# Related fields.
|
||||
if isinstance(field, serializers.ManyRelatedField):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self.map_field(field.child_relation)
|
||||
}
|
||||
if isinstance(field, serializers.PrimaryKeyRelatedField):
|
||||
model = getattr(field.queryset, 'model', None)
|
||||
if model is not None:
|
||||
model_field = model._meta.pk
|
||||
if isinstance(model_field, models.AutoField):
|
||||
return {'type': 'integer'}
|
||||
|
||||
# ChoiceFields (single and multiple).
|
||||
# Q:
|
||||
# - Is 'type' required?
|
||||
# - can we determine the TYPE of a choicefield?
|
||||
if isinstance(field, serializers.MultipleChoiceField):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self.map_choicefield(field)
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.ChoiceField):
|
||||
return self.map_choicefield(field)
|
||||
|
||||
# ListField.
|
||||
if isinstance(field, serializers.ListField):
|
||||
mapping = {
|
||||
'type': 'array',
|
||||
'items': {},
|
||||
}
|
||||
if not isinstance(field.child, _UnvalidatedField):
|
||||
mapping['items'] = self.map_field(field.child)
|
||||
return mapping
|
||||
|
||||
# DateField and DateTimeField type is string
|
||||
if isinstance(field, serializers.DateField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'date',
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.DateTimeField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'date-time',
|
||||
}
|
||||
|
||||
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
|
||||
# see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
|
||||
# see also: https://swagger.io/docs/specification/data-models/data-types/#string
|
||||
if isinstance(field, serializers.EmailField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'email'
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.URLField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'uri'
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.UUIDField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'uuid'
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.IPAddressField):
|
||||
content = {
|
||||
'type': 'string',
|
||||
}
|
||||
if field.protocol != 'both':
|
||||
content['format'] = field.protocol
|
||||
return content
|
||||
|
||||
if isinstance(field, serializers.DecimalField):
|
||||
if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
|
||||
content = {
|
||||
'type': 'string',
|
||||
'format': 'decimal',
|
||||
}
|
||||
else:
|
||||
content = {
|
||||
'type': 'number'
|
||||
}
|
||||
|
||||
if field.decimal_places:
|
||||
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
|
||||
if field.max_whole_digits:
|
||||
content['maximum'] = int(field.max_whole_digits * '9') + 1
|
||||
content['minimum'] = -content['maximum']
|
||||
self._map_min_max(field, content)
|
||||
return content
|
||||
|
||||
if isinstance(field, serializers.FloatField):
|
||||
content = {
|
||||
'type': 'number',
|
||||
}
|
||||
self._map_min_max(field, content)
|
||||
return content
|
||||
|
||||
if isinstance(field, serializers.IntegerField):
|
||||
content = {
|
||||
'type': 'integer'
|
||||
}
|
||||
self._map_min_max(field, content)
|
||||
# 2147483647 is max for int32_size, so we use int64 for format
|
||||
if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
|
||||
content['format'] = 'int64'
|
||||
return content
|
||||
|
||||
if isinstance(field, serializers.FileField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'binary'
|
||||
}
|
||||
|
||||
# Simplest cases, default to 'string' type:
|
||||
FIELD_CLASS_SCHEMA_TYPE = {
|
||||
serializers.BooleanField: 'boolean',
|
||||
serializers.JSONField: 'object',
|
||||
serializers.DictField: 'object',
|
||||
serializers.HStoreField: 'object',
|
||||
}
|
||||
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
|
||||
|
||||
def _map_min_max(self, field, content):
|
||||
if field.max_value:
|
||||
content['maximum'] = field.max_value
|
||||
if field.min_value:
|
||||
content['minimum'] = field.min_value
|
||||
|
||||
def map_serializer(self, serializer):
|
||||
# Assuming we have a valid serializer instance.
|
||||
required = []
|
||||
properties = {}
|
||||
|
||||
for field in serializer.fields.values():
|
||||
if isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
if field.required:
|
||||
required.append(field.field_name)
|
||||
|
||||
schema = self.map_field(field)
|
||||
if field.read_only:
|
||||
schema['readOnly'] = True
|
||||
if field.write_only:
|
||||
schema['writeOnly'] = True
|
||||
if field.allow_null:
|
||||
schema['nullable'] = True
|
||||
if field.default is not None and field.default != empty and not callable(field.default):
|
||||
schema['default'] = field.default
|
||||
if field.help_text:
|
||||
schema['description'] = str(field.help_text)
|
||||
self.map_field_validators(field, schema)
|
||||
|
||||
properties[field.field_name] = schema
|
||||
|
||||
result = {
|
||||
'type': 'object',
|
||||
'properties': properties
|
||||
}
|
||||
if required:
|
||||
result['required'] = required
|
||||
|
||||
return result
|
||||
|
||||
def map_field_validators(self, field, schema):
|
||||
"""
|
||||
map field validators
|
||||
"""
|
||||
for v in field.validators:
|
||||
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
|
||||
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
|
||||
if isinstance(v, EmailValidator):
|
||||
schema['format'] = 'email'
|
||||
if isinstance(v, URLValidator):
|
||||
schema['format'] = 'uri'
|
||||
if isinstance(v, RegexValidator):
|
||||
# In Python, the token \Z does what \z does in other engines.
|
||||
# https://stackoverflow.com/questions/53283160
|
||||
schema['pattern'] = v.regex.pattern.replace('\\Z', '\\z')
|
||||
elif isinstance(v, MaxLengthValidator):
|
||||
attr_name = 'maxLength'
|
||||
if isinstance(field, serializers.ListField):
|
||||
attr_name = 'maxItems'
|
||||
schema[attr_name] = v.limit_value
|
||||
elif isinstance(v, MinLengthValidator):
|
||||
attr_name = 'minLength'
|
||||
if isinstance(field, serializers.ListField):
|
||||
attr_name = 'minItems'
|
||||
schema[attr_name] = v.limit_value
|
||||
elif isinstance(v, MaxValueValidator):
|
||||
schema['maximum'] = v.limit_value
|
||||
elif isinstance(v, MinValueValidator):
|
||||
schema['minimum'] = v.limit_value
|
||||
elif isinstance(v, DecimalValidator) and \
|
||||
not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
|
||||
if v.decimal_places:
|
||||
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
|
||||
if v.max_digits:
|
||||
digits = v.max_digits
|
||||
if v.decimal_places is not None and v.decimal_places > 0:
|
||||
digits -= v.decimal_places
|
||||
schema['maximum'] = int(digits * '9') + 1
|
||||
schema['minimum'] = -schema['maximum']
|
||||
|
||||
def get_paginator(self):
|
||||
pagination_class = getattr(self.view, 'pagination_class', None)
|
||||
if pagination_class:
|
||||
return pagination_class()
|
||||
return None
|
||||
|
||||
def map_parsers(self, path, method):
|
||||
return list(map(attrgetter('media_type'), self.view.parser_classes))
|
||||
|
||||
def map_renderers(self, path, method):
|
||||
media_types = []
|
||||
for renderer in self.view.renderer_classes:
|
||||
# BrowsableAPIRenderer not relevant to OpenAPI spec
|
||||
if issubclass(renderer, renderers.BrowsableAPIRenderer):
|
||||
continue
|
||||
media_types.append(renderer.media_type)
|
||||
return media_types
|
||||
|
||||
def get_serializer(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return None
|
||||
|
||||
try:
|
||||
return view.get_serializer()
|
||||
except exceptions.APIException:
|
||||
warnings.warn('{}.get_serializer() raised an exception during '
|
||||
'schema generation. Serializer fields will not be '
|
||||
'generated for {} {}.'
|
||||
.format(view.__class__.__name__, method, path))
|
||||
return None
|
||||
|
||||
def get_request_serializer(self, path, method):
|
||||
"""
|
||||
Override this method if your view uses a different serializer for
|
||||
handling request body.
|
||||
"""
|
||||
return self.get_serializer(path, method)
|
||||
|
||||
def get_response_serializer(self, path, method):
|
||||
"""
|
||||
Override this method if your view uses a different serializer for
|
||||
populating response data.
|
||||
"""
|
||||
return self.get_serializer(path, method)
|
||||
|
||||
def get_reference(self, serializer):
|
||||
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}
|
||||
|
||||
def get_request_body(self, path, method):
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return {}
|
||||
|
||||
self.request_media_types = self.map_parsers(path, method)
|
||||
|
||||
serializer = self.get_request_serializer(path, method)
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
item_schema = {}
|
||||
else:
|
||||
item_schema = self.get_reference(serializer)
|
||||
|
||||
return {
|
||||
'content': {
|
||||
ct: {'schema': item_schema}
|
||||
for ct in self.request_media_types
|
||||
}
|
||||
}
|
||||
|
||||
def get_responses(self, path, method):
|
||||
if method == 'DELETE':
|
||||
return {
|
||||
'204': {
|
||||
'description': ''
|
||||
}
|
||||
}
|
||||
|
||||
self.response_media_types = self.map_renderers(path, method)
|
||||
|
||||
serializer = self.get_response_serializer(path, method)
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
item_schema = {}
|
||||
else:
|
||||
item_schema = self.get_reference(serializer)
|
||||
|
||||
if is_list_view(path, method, self.view):
|
||||
response_schema = {
|
||||
'type': 'array',
|
||||
'items': item_schema,
|
||||
}
|
||||
paginator = self.get_paginator()
|
||||
if paginator:
|
||||
response_schema = paginator.get_paginated_response_schema(response_schema)
|
||||
else:
|
||||
response_schema = item_schema
|
||||
status_code = '201' if method == 'POST' else '200'
|
||||
return {
|
||||
status_code: {
|
||||
'content': {
|
||||
ct: {'schema': response_schema}
|
||||
for ct in self.response_media_types
|
||||
},
|
||||
# description is a mandatory property,
|
||||
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
|
||||
# TODO: put something meaningful into it
|
||||
'description': ""
|
||||
}
|
||||
}
|
||||
|
||||
def get_tags(self, path, method):
|
||||
# If user have specified tags, use them.
|
||||
if self._tags:
|
||||
return self._tags
|
||||
|
||||
# First element of a specific path could be valid tag. This is a fallback solution.
|
||||
# PUT, PATCH, GET(Retrieve), DELETE: /user_profile/{id}/ tags = [user-profile]
|
||||
# POST, GET(List): /user_profile/ tags = [user-profile]
|
||||
if path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return [path.split('/')[0].replace('_', '-')]
|
||||
|
||||
def _get_reference(self, serializer):
|
||||
warnings.warn(
|
||||
"Method `_get_reference()` has been renamed to `get_reference()`. "
|
||||
"The old name will be removed in DRF v3.15.",
|
||||
RemovedInDRF315Warning, stacklevel=2
|
||||
)
|
||||
return self.get_reference(serializer)
|
@ -0,0 +1,41 @@
|
||||
"""
|
||||
utils.py # Shared helper functions
|
||||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework.mixins import RetrieveModelMixin
|
||||
|
||||
|
||||
def is_list_view(path, method, view):
|
||||
"""
|
||||
Return True if the given path/method appears to represent a list view.
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have an explicitly defined action, which we can inspect.
|
||||
return view.action == 'list'
|
||||
|
||||
if method.lower() != 'get':
|
||||
return False
|
||||
if isinstance(view, RetrieveModelMixin):
|
||||
return False
|
||||
path_components = path.strip('/').split('/')
|
||||
if path_components and '{' in path_components[-1]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_pk_description(model, model_field):
|
||||
if isinstance(model_field, models.AutoField):
|
||||
value_type = _('unique integer value')
|
||||
elif isinstance(model_field, models.UUIDField):
|
||||
value_type = _('UUID string')
|
||||
else:
|
||||
value_type = _('unique value')
|
||||
|
||||
return _('A {value_type} identifying this {name}.').format(
|
||||
value_type=value_type,
|
||||
name=model._meta.verbose_name,
|
||||
)
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user