docker setup

This commit is contained in:
AdrienLSH
2023-11-23 16:43:30 +01:00
parent fd19180e1d
commit f29003c66a
5410 changed files with 869440 additions and 0 deletions

View File

@ -0,0 +1,33 @@
r"""
______ _____ _____ _____ __
| ___ \ ___/ ___|_ _| / _| | |
| |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| |__
| /| __| `--. \ | | | _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ /
| |\ \| |___/\__/ / | | | | | | | (_| | | | | | | __/\ V V / (_) | | | <
\_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_|
"""
import django
__title__ = 'Django REST framework'
__version__ = '3.14.0'
__author__ = 'Tom Christie'
__license__ = 'BSD 3-Clause'
__copyright__ = 'Copyright 2011-2019 Encode OSS Ltd'
# Version synonym
VERSION = __version__
# Header encoding (see RFC5987)
HTTP_HEADER_ENCODING = 'iso-8859-1'
# Default datetime input and output formats
ISO_8601 = 'iso-8601'
if django.VERSION < (3, 2):
default_app_config = 'rest_framework.apps.RestFrameworkConfig'
class RemovedInDRF315Warning(PendingDeprecationWarning):
pass

View File

@ -0,0 +1,10 @@
from django.apps import AppConfig
class RestFrameworkConfig(AppConfig):
name = 'rest_framework'
verbose_name = "Django REST framework"
def ready(self):
# Add System checks
from .checks import pagination_system_check # NOQA

View File

@ -0,0 +1,232 @@
"""
Provides various authentication policies.
"""
import base64
import binascii
from django.contrib.auth import authenticate, get_user_model
from django.middleware.csrf import CsrfViewMiddleware
from django.utils.translation import gettext_lazy as _
from rest_framework import HTTP_HEADER_ENCODING, exceptions
def get_authorization_header(request):
"""
Return request's 'Authorization:' header, as a bytestring.
Hide some test client ickyness where the header can be unicode.
"""
auth = request.META.get('HTTP_AUTHORIZATION', b'')
if isinstance(auth, str):
# Work around django test client oddness
auth = auth.encode(HTTP_HEADER_ENCODING)
return auth
class CSRFCheck(CsrfViewMiddleware):
def _reject(self, request, reason):
# Return the failure reason instead of an HttpResponse
return reason
class BaseAuthentication:
"""
All authentication classes should extend BaseAuthentication.
"""
def authenticate(self, request):
"""
Authenticate the request and return a two-tuple of (user, token).
"""
raise NotImplementedError(".authenticate() must be overridden.")
def authenticate_header(self, request):
"""
Return a string to be used as the value of the `WWW-Authenticate`
header in a `401 Unauthenticated` response, or `None` if the
authentication scheme should return `403 Permission Denied` responses.
"""
pass
class BasicAuthentication(BaseAuthentication):
"""
HTTP Basic authentication against username/password.
"""
www_authenticate_realm = 'api'
def authenticate(self, request):
"""
Returns a `User` if a correct username and password have been supplied
using HTTP Basic authentication. Otherwise returns `None`.
"""
auth = get_authorization_header(request).split()
if not auth or auth[0].lower() != b'basic':
return None
if len(auth) == 1:
msg = _('Invalid basic header. No credentials provided.')
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
msg = _('Invalid basic header. Credentials string should not contain spaces.')
raise exceptions.AuthenticationFailed(msg)
try:
try:
auth_decoded = base64.b64decode(auth[1]).decode('utf-8')
except UnicodeDecodeError:
auth_decoded = base64.b64decode(auth[1]).decode('latin-1')
auth_parts = auth_decoded.partition(':')
except (TypeError, UnicodeDecodeError, binascii.Error):
msg = _('Invalid basic header. Credentials not correctly base64 encoded.')
raise exceptions.AuthenticationFailed(msg)
userid, password = auth_parts[0], auth_parts[2]
return self.authenticate_credentials(userid, password, request)
def authenticate_credentials(self, userid, password, request=None):
"""
Authenticate the userid and password against username and password
with optional request for context.
"""
credentials = {
get_user_model().USERNAME_FIELD: userid,
'password': password
}
user = authenticate(request=request, **credentials)
if user is None:
raise exceptions.AuthenticationFailed(_('Invalid username/password.'))
if not user.is_active:
raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
return (user, None)
def authenticate_header(self, request):
return 'Basic realm="%s"' % self.www_authenticate_realm
class SessionAuthentication(BaseAuthentication):
"""
Use Django's session framework for authentication.
"""
def authenticate(self, request):
"""
Returns a `User` if the request session currently has a logged in user.
Otherwise returns `None`.
"""
# Get the session-based user from the underlying HttpRequest object
user = getattr(request._request, 'user', None)
# Unauthenticated, CSRF validation not required
if not user or not user.is_active:
return None
self.enforce_csrf(request)
# CSRF passed with authenticated user
return (user, None)
def enforce_csrf(self, request):
"""
Enforce CSRF validation for session based authentication.
"""
def dummy_get_response(request): # pragma: no cover
return None
check = CSRFCheck(dummy_get_response)
# populates request.META['CSRF_COOKIE'], which is used in process_view()
check.process_request(request)
reason = check.process_view(request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
class TokenAuthentication(BaseAuthentication):
"""
Simple token based authentication.
Clients should authenticate by passing the token key in the "Authorization"
HTTP header, prepended with the string "Token ". For example:
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
"""
keyword = 'Token'
model = None
def get_model(self):
if self.model is not None:
return self.model
from rest_framework.authtoken.models import Token
return Token
"""
A custom token model may be used, but must have the following properties.
* key -- The string identifying the token
* user -- The user to which the token belongs
"""
def authenticate(self, request):
auth = get_authorization_header(request).split()
if not auth or auth[0].lower() != self.keyword.lower().encode():
return None
if len(auth) == 1:
msg = _('Invalid token header. No credentials provided.')
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
msg = _('Invalid token header. Token string should not contain spaces.')
raise exceptions.AuthenticationFailed(msg)
try:
token = auth[1].decode()
except UnicodeError:
msg = _('Invalid token header. Token string should not contain invalid characters.')
raise exceptions.AuthenticationFailed(msg)
return self.authenticate_credentials(token)
def authenticate_credentials(self, key):
model = self.get_model()
try:
token = model.objects.select_related('user').get(key=key)
except model.DoesNotExist:
raise exceptions.AuthenticationFailed(_('Invalid token.'))
if not token.user.is_active:
raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
return (token.user, token)
def authenticate_header(self, request):
return self.keyword
class RemoteUserAuthentication(BaseAuthentication):
"""
REMOTE_USER authentication.
To use this, set up your web server to perform authentication, which will
set the REMOTE_USER environment variable. You will need to have
'django.contrib.auth.backends.RemoteUserBackend in your
AUTHENTICATION_BACKENDS setting
"""
# Name of request header to grab username from. This will be the key as
# used in the request.META dictionary, i.e. the normalization of headers to
# all uppercase and the addition of "HTTP_" prefix apply.
header = "REMOTE_USER"
def authenticate(self, request):
user = authenticate(request=request, remote_user=request.META.get(self.header))
if user and user.is_active:
return (user, None)

View File

@ -0,0 +1,4 @@
import django
if django.VERSION < (3, 2):
default_app_config = 'rest_framework.authtoken.apps.AuthTokenConfig'

View File

@ -0,0 +1,51 @@
from django.contrib import admin
from django.contrib.admin.utils import quote
from django.contrib.admin.views.main import ChangeList
from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError
from django.urls import reverse
from rest_framework.authtoken.models import Token, TokenProxy
User = get_user_model()
class TokenChangeList(ChangeList):
"""Map to matching User id"""
def url_for_result(self, result):
pk = result.user.pk
return reverse('admin:%s_%s_change' % (self.opts.app_label,
self.opts.model_name),
args=(quote(pk),),
current_app=self.model_admin.admin_site.name)
class TokenAdmin(admin.ModelAdmin):
list_display = ('key', 'user', 'created')
fields = ('user',)
ordering = ('-created',)
actions = None # Actions not compatible with mapped IDs.
def get_changelist(self, request, **kwargs):
return TokenChangeList
def get_object(self, request, object_id, from_field=None):
"""
Map from User ID to matching Token.
"""
queryset = self.get_queryset(request)
field = User._meta.pk
try:
object_id = field.to_python(object_id)
user = User.objects.get(**{field.name: object_id})
return queryset.get(user=user)
except (queryset.model.DoesNotExist, User.DoesNotExist, ValidationError, ValueError):
return None
def delete_model(self, request, obj):
# Map back to actual Token, since delete() uses pk.
token = Token.objects.get(key=obj.key)
return super().delete_model(request, token)
admin.site.register(TokenProxy, TokenAdmin)

View File

@ -0,0 +1,7 @@
from django.apps import AppConfig
from django.utils.translation import gettext_lazy as _
class AuthTokenConfig(AppConfig):
name = 'rest_framework.authtoken'
verbose_name = _("Auth Token")

View File

@ -0,0 +1,45 @@
from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand, CommandError
from rest_framework.authtoken.models import Token
UserModel = get_user_model()
class Command(BaseCommand):
help = 'Create DRF Token for a given user'
def create_user_token(self, username, reset_token):
user = UserModel._default_manager.get_by_natural_key(username)
if reset_token:
Token.objects.filter(user=user).delete()
token = Token.objects.get_or_create(user=user)
return token[0]
def add_arguments(self, parser):
parser.add_argument('username', type=str)
parser.add_argument(
'-r',
'--reset',
action='store_true',
dest='reset_token',
default=False,
help='Reset existing User token and create a new one',
)
def handle(self, *args, **options):
username = options['username']
reset_token = options['reset_token']
try:
token = self.create_user_token(username, reset_token)
except UserModel.DoesNotExist:
raise CommandError(
'Cannot create the Token: user {} does not exist'.format(
username)
)
self.stdout.write(
'Generated token {} for user {}'.format(token.key, username))

View File

@ -0,0 +1,54 @@
import binascii
import os
from django.conf import settings
from django.db import models
from django.utils.translation import gettext_lazy as _
class Token(models.Model):
"""
The default authorization token model.
"""
key = models.CharField(_("Key"), max_length=40, primary_key=True)
user = models.OneToOneField(
settings.AUTH_USER_MODEL, related_name='auth_token',
on_delete=models.CASCADE, verbose_name=_("User")
)
created = models.DateTimeField(_("Created"), auto_now_add=True)
class Meta:
# Work around for a bug in Django:
# https://code.djangoproject.com/ticket/19422
#
# Also see corresponding ticket:
# https://github.com/encode/django-rest-framework/issues/705
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
verbose_name = _("Token")
verbose_name_plural = _("Tokens")
def save(self, *args, **kwargs):
if not self.key:
self.key = self.generate_key()
return super().save(*args, **kwargs)
@classmethod
def generate_key(cls):
return binascii.hexlify(os.urandom(20)).decode()
def __str__(self):
return self.key
class TokenProxy(Token):
"""
Proxy mapping pk to user pk for use in admin.
"""
@property
def pk(self):
return self.user_id
class Meta:
proxy = 'rest_framework.authtoken' in settings.INSTALLED_APPS
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
verbose_name = "token"

View File

@ -0,0 +1,42 @@
from django.contrib.auth import authenticate
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
class AuthTokenSerializer(serializers.Serializer):
username = serializers.CharField(
label=_("Username"),
write_only=True
)
password = serializers.CharField(
label=_("Password"),
style={'input_type': 'password'},
trim_whitespace=False,
write_only=True
)
token = serializers.CharField(
label=_("Token"),
read_only=True
)
def validate(self, attrs):
username = attrs.get('username')
password = attrs.get('password')
if username and password:
user = authenticate(request=self.context.get('request'),
username=username, password=password)
# The authenticate call simply returns None for is_active=False
# users. (Assuming the default ModelBackend authentication
# backend.)
if not user:
msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg, code='authorization')
else:
msg = _('Must include "username" and "password".')
raise serializers.ValidationError(msg, code='authorization')
attrs['user'] = user
return attrs

View File

@ -0,0 +1,62 @@
from rest_framework import parsers, renderers
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.compat import coreapi, coreschema
from rest_framework.response import Response
from rest_framework.schemas import ManualSchema
from rest_framework.schemas import coreapi as coreapi_schema
from rest_framework.views import APIView
class ObtainAuthToken(APIView):
throttle_classes = ()
permission_classes = ()
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
renderer_classes = (renderers.JSONRenderer,)
serializer_class = AuthTokenSerializer
if coreapi_schema.is_enabled():
schema = ManualSchema(
fields=[
coreapi.Field(
name="username",
required=True,
location='form',
schema=coreschema.String(
title="Username",
description="Valid username for authentication",
),
),
coreapi.Field(
name="password",
required=True,
location='form',
schema=coreschema.String(
title="Password",
description="Valid password for authentication",
),
),
],
encoding="application/json",
)
def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)
def post(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user)
return Response({'token': token.key})
obtain_auth_token = ObtainAuthToken.as_view()

View File

@ -0,0 +1,21 @@
from django.core.checks import Tags, Warning, register
@register(Tags.compatibility)
def pagination_system_check(app_configs, **kwargs):
errors = []
# Use of default page size setting requires a default Paginator class
from rest_framework.settings import api_settings
if api_settings.PAGE_SIZE and not api_settings.DEFAULT_PAGINATION_CLASS:
errors.append(
Warning(
"You have specified a default PAGE_SIZE pagination rest_framework setting, "
"without specifying also a DEFAULT_PAGINATION_CLASS.",
hint="The default for DEFAULT_PAGINATION_CLASS is None. "
"In previous versions this was PageNumberPagination. "
"If you wish to define PAGE_SIZE globally whilst defining "
"pagination_class on a per-view basis you may silence this check.",
id="rest_framework.W001"
)
)
return errors

View File

@ -0,0 +1,184 @@
"""
The `compat` module provides support for backwards compatibility with older
versions of Django/Python, and compatibility wrappers around optional packages.
"""
import django
from django.conf import settings
from django.views.generic import View
def unicode_http_header(value):
# Coerce HTTP header value to unicode.
if isinstance(value, bytes):
return value.decode('iso-8859-1')
return value
def distinct(queryset, base):
if settings.DATABASES[queryset.db]["ENGINE"] == "django.db.backends.oracle":
# distinct analogue for Oracle users
return base.filter(pk__in=set(queryset.values_list('pk', flat=True)))
return queryset.distinct()
# django.contrib.postgres requires psycopg2
try:
from django.contrib.postgres import fields as postgres_fields
except ImportError:
postgres_fields = None
# coreapi is required for CoreAPI schema generation
try:
import coreapi
except ImportError:
coreapi = None
# uritemplate is required for OpenAPI and CoreAPI schema generation
try:
import uritemplate
except ImportError:
uritemplate = None
# coreschema is optional
try:
import coreschema
except ImportError:
coreschema = None
# pyyaml is optional
try:
import yaml
except ImportError:
yaml = None
# requests is optional
try:
import requests
except ImportError:
requests = None
# PATCH method is not implemented by Django
if 'patch' not in View.http_method_names:
View.http_method_names = View.http_method_names + ['patch']
# Markdown is optional (version 3.0+ required)
try:
import markdown
HEADERID_EXT_PATH = 'markdown.extensions.toc'
LEVEL_PARAM = 'baselevel'
def apply_markdown(text):
"""
Simple wrapper around :func:`markdown.markdown` to set the base level
of '#' style headers to <h2>.
"""
extensions = [HEADERID_EXT_PATH]
extension_configs = {
HEADERID_EXT_PATH: {
LEVEL_PARAM: '2'
}
}
md = markdown.Markdown(
extensions=extensions, extension_configs=extension_configs
)
md_filter_add_syntax_highlight(md)
return md.convert(text)
except ImportError:
apply_markdown = None
markdown = None
try:
import pygments
from pygments.formatters import HtmlFormatter
from pygments.lexers import TextLexer, get_lexer_by_name
def pygments_highlight(text, lang, style):
lexer = get_lexer_by_name(lang, stripall=False)
formatter = HtmlFormatter(nowrap=True, style=style)
return pygments.highlight(text, lexer, formatter)
def pygments_css(style):
formatter = HtmlFormatter(style=style)
return formatter.get_style_defs('.highlight')
except ImportError:
pygments = None
def pygments_highlight(text, lang, style):
return text
def pygments_css(style):
return None
if markdown is not None and pygments is not None:
# starting from this blogpost and modified to support current markdown extensions API
# https://zerokspot.com/weblog/2008/06/18/syntax-highlighting-in-markdown-with-pygments/
import re
from markdown.preprocessors import Preprocessor
class CodeBlockPreprocessor(Preprocessor):
pattern = re.compile(
r'^\s*``` *([^\n]+)\n(.+?)^\s*```', re.M | re.S)
formatter = HtmlFormatter()
def run(self, lines):
def repl(m):
try:
lexer = get_lexer_by_name(m.group(1))
except (ValueError, NameError):
lexer = TextLexer()
code = m.group(2).replace('\t', ' ')
code = pygments.highlight(code, lexer, self.formatter)
code = code.replace('\n\n', '\n&nbsp;\n').replace('\n', '<br />').replace('\\@', '@')
return '\n\n%s\n\n' % code
ret = self.pattern.sub(repl, "\n".join(lines))
return ret.split("\n")
def md_filter_add_syntax_highlight(md):
md.preprocessors.register(CodeBlockPreprocessor(), 'highlight', 40)
return True
else:
def md_filter_add_syntax_highlight(md):
return False
if django.VERSION >= (4, 2):
# Django 4.2+: use the stock parse_header_parameters function
# Note: Django 4.1 also has an implementation of parse_header_parameters
# which is slightly different from the one in 4.2, it needs
# the compatibility shim as well.
from django.utils.http import parse_header_parameters
else:
# Django <= 4.1: create a compatibility shim for parse_header_parameters
from django.http.multipartparser import parse_header
def parse_header_parameters(line):
# parse_header works with bytes, but parse_header_parameters
# works with strings. Call encode to convert the line to bytes.
main_value_pair, params = parse_header(line.encode())
return main_value_pair, {
# parse_header will convert *some* values to string.
# parse_header_parameters converts *all* values to string.
# Make sure all values are converted by calling decode on
# any remaining non-string values.
k: v if isinstance(v, str) else v.decode()
for k, v in params.items()
}
# `separators` argument to `json.dumps()` differs between 2.x and 3.x
# See: https://bugs.python.org/issue22767
SHORT_SEPARATORS = (',', ':')
LONG_SEPARATORS = (', ', ': ')
INDENT_SEPARATORS = (',', ': ')

View File

@ -0,0 +1,233 @@
"""
The most important decorator in this module is `@api_view`, which is used
for writing function-based views with REST framework.
There are also various decorators for setting the API policies on function
based views, as well as the `@action` decorator, which is used to annotate
methods on viewsets that should be included by routers.
"""
import types
from django.forms.utils import pretty_name
from rest_framework.views import APIView
def api_view(http_method_names=None):
"""
Decorator that converts a function-based view into an APIView subclass.
Takes a list of allowed methods for the view as an argument.
"""
http_method_names = ['GET'] if (http_method_names is None) else http_method_names
def decorator(func):
WrappedAPIView = type(
'WrappedAPIView',
(APIView,),
{'__doc__': func.__doc__}
)
# Note, the above allows us to set the docstring.
# It is the equivalent of:
#
# class WrappedAPIView(APIView):
# pass
# WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
# api_view applied without (method_names)
assert not(isinstance(http_method_names, types.FunctionType)), \
'@api_view missing list of allowed HTTP methods'
# api_view applied with eg. string instead of list of strings
assert isinstance(http_method_names, (list, tuple)), \
'@api_view expected a list of strings, received %s' % type(http_method_names).__name__
allowed_methods = set(http_method_names) | {'options'}
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
def handler(self, *args, **kwargs):
return func(*args, **kwargs)
for method in http_method_names:
setattr(WrappedAPIView, method.lower(), handler)
WrappedAPIView.__name__ = func.__name__
WrappedAPIView.__module__ = func.__module__
WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes',
APIView.renderer_classes)
WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
APIView.parser_classes)
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
APIView.authentication_classes)
WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',
APIView.throttle_classes)
WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
APIView.permission_classes)
WrappedAPIView.schema = getattr(func, 'schema',
APIView.schema)
return WrappedAPIView.as_view()
return decorator
def renderer_classes(renderer_classes):
def decorator(func):
func.renderer_classes = renderer_classes
return func
return decorator
def parser_classes(parser_classes):
def decorator(func):
func.parser_classes = parser_classes
return func
return decorator
def authentication_classes(authentication_classes):
def decorator(func):
func.authentication_classes = authentication_classes
return func
return decorator
def throttle_classes(throttle_classes):
def decorator(func):
func.throttle_classes = throttle_classes
return func
return decorator
def permission_classes(permission_classes):
def decorator(func):
func.permission_classes = permission_classes
return func
return decorator
def schema(view_inspector):
def decorator(func):
func.schema = view_inspector
return func
return decorator
def action(methods=None, detail=None, url_path=None, url_name=None, **kwargs):
"""
Mark a ViewSet method as a routable action.
`@action`-decorated functions will be endowed with a `mapping` property,
a `MethodMapper` that can be used to add additional method-based behaviors
on the routed action.
:param methods: A list of HTTP method names this action responds to.
Defaults to GET only.
:param detail: Required. Determines whether this action applies to
instance/detail requests or collection/list requests.
:param url_path: Define the URL segment for this action. Defaults to the
name of the method decorated.
:param url_name: Define the internal (`reverse`) URL name for this action.
Defaults to the name of the method decorated with underscores
replaced with dashes.
:param kwargs: Additional properties to set on the view. This can be used
to override viewset-level *_classes settings, equivalent to
how the `@renderer_classes` etc. decorators work for function-
based API views.
"""
methods = ['get'] if methods is None else methods
methods = [method.lower() for method in methods]
assert detail is not None, (
"@action() missing required argument: 'detail'"
)
# name and suffix are mutually exclusive
if 'name' in kwargs and 'suffix' in kwargs:
raise TypeError("`name` and `suffix` are mutually exclusive arguments.")
def decorator(func):
func.mapping = MethodMapper(func, methods)
func.detail = detail
func.url_path = url_path if url_path else func.__name__
func.url_name = url_name if url_name else func.__name__.replace('_', '-')
# These kwargs will end up being passed to `ViewSet.as_view()` within
# the router, which eventually delegates to Django's CBV `View`,
# which assigns them as instance attributes for each request.
func.kwargs = kwargs
# Set descriptive arguments for viewsets
if 'name' not in kwargs and 'suffix' not in kwargs:
func.kwargs['name'] = pretty_name(func.__name__)
func.kwargs['description'] = func.__doc__ or None
return func
return decorator
class MethodMapper(dict):
"""
Enables mapping HTTP methods to different ViewSet methods for a single,
logical action.
Example usage:
class MyViewSet(ViewSet):
@action(detail=False)
def example(self, request, **kwargs):
...
@example.mapping.post
def create_example(self, request, **kwargs):
...
"""
def __init__(self, action, methods):
self.action = action
for method in methods:
self[method] = self.action.__name__
def _map(self, method, func):
assert method not in self, (
"Method '%s' has already been mapped to '.%s'." % (method, self[method]))
assert func.__name__ != self.action.__name__, (
"Method mapping does not behave like the property decorator. You "
"cannot use the same method name for each mapping declaration.")
self[method] = func.__name__
return func
def get(self, func):
return self._map('get', func)
def post(self, func):
return self._map('post', func)
def put(self, func):
return self._map('put', func)
def patch(self, func):
return self._map('patch', func)
def delete(self, func):
return self._map('delete', func)
def head(self, func):
return self._map('head', func)
def options(self, func):
return self._map('options', func)
def trace(self, func):
return self._map('trace', func)

View File

@ -0,0 +1,88 @@
from django.urls import include, path
from rest_framework.renderers import (
CoreJSONRenderer, DocumentationRenderer, SchemaJSRenderer
)
from rest_framework.schemas import SchemaGenerator, get_schema_view
from rest_framework.settings import api_settings
def get_docs_view(
title=None, description=None, schema_url=None, urlconf=None,
public=True, patterns=None, generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
renderer_classes=None):
if renderer_classes is None:
renderer_classes = [DocumentationRenderer, CoreJSONRenderer]
return get_schema_view(
title=title,
url=schema_url,
urlconf=urlconf,
description=description,
renderer_classes=renderer_classes,
public=public,
patterns=patterns,
generator_class=generator_class,
authentication_classes=authentication_classes,
permission_classes=permission_classes,
)
def get_schemajs_view(
title=None, description=None, schema_url=None, urlconf=None,
public=True, patterns=None, generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
renderer_classes = [SchemaJSRenderer]
return get_schema_view(
title=title,
url=schema_url,
urlconf=urlconf,
description=description,
renderer_classes=renderer_classes,
public=public,
patterns=patterns,
generator_class=generator_class,
authentication_classes=authentication_classes,
permission_classes=permission_classes,
)
def include_docs_urls(
title=None, description=None, schema_url=None, urlconf=None,
public=True, patterns=None, generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
renderer_classes=None):
docs_view = get_docs_view(
title=title,
description=description,
schema_url=schema_url,
urlconf=urlconf,
public=public,
patterns=patterns,
generator_class=generator_class,
authentication_classes=authentication_classes,
renderer_classes=renderer_classes,
permission_classes=permission_classes,
)
schema_js_view = get_schemajs_view(
title=title,
description=description,
schema_url=schema_url,
urlconf=urlconf,
public=public,
patterns=patterns,
generator_class=generator_class,
authentication_classes=authentication_classes,
permission_classes=permission_classes,
)
urls = [
path('', docs_view, name='docs-index'),
path('schema.js', schema_js_view, name='schema-js')
]
return include((urls, 'api-docs'), namespace='api-docs')

View File

@ -0,0 +1,264 @@
"""
Handled exceptions raised by REST framework.
In addition, Django's built in 403 and 404 exceptions are handled.
(`django.http.Http404` and `django.core.exceptions.PermissionDenied`)
"""
import math
from django.http import JsonResponse
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
from django.utils.translation import ngettext
from rest_framework import status
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList
def _get_error_details(data, default_code=None):
"""
Descend into a nested data structure, forcing any
lazy translation strings or strings into `ErrorDetail`.
"""
if isinstance(data, (list, tuple)):
ret = [
_get_error_details(item, default_code) for item in data
]
if isinstance(data, ReturnList):
return ReturnList(ret, serializer=data.serializer)
return ret
elif isinstance(data, dict):
ret = {
key: _get_error_details(value, default_code)
for key, value in data.items()
}
if isinstance(data, ReturnDict):
return ReturnDict(ret, serializer=data.serializer)
return ret
text = force_str(data)
code = getattr(data, 'code', default_code)
return ErrorDetail(text, code)
def _get_codes(detail):
if isinstance(detail, list):
return [_get_codes(item) for item in detail]
elif isinstance(detail, dict):
return {key: _get_codes(value) for key, value in detail.items()}
return detail.code
def _get_full_details(detail):
if isinstance(detail, list):
return [_get_full_details(item) for item in detail]
elif isinstance(detail, dict):
return {key: _get_full_details(value) for key, value in detail.items()}
return {
'message': detail,
'code': detail.code
}
class ErrorDetail(str):
"""
A string-like object that can additionally have a code.
"""
code = None
def __new__(cls, string, code=None):
self = super().__new__(cls, string)
self.code = code
return self
def __eq__(self, other):
result = super().__eq__(other)
if result is NotImplemented:
return NotImplemented
try:
return result and self.code == other.code
except AttributeError:
return result
def __ne__(self, other):
result = self.__eq__(other)
if result is NotImplemented:
return NotImplemented
return not result
def __repr__(self):
return 'ErrorDetail(string=%r, code=%r)' % (
str(self),
self.code,
)
def __hash__(self):
return hash(str(self))
class APIException(Exception):
"""
Base class for REST framework exceptions.
Subclasses should provide `.status_code` and `.default_detail` properties.
"""
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
default_detail = _('A server error occurred.')
default_code = 'error'
def __init__(self, detail=None, code=None):
if detail is None:
detail = self.default_detail
if code is None:
code = self.default_code
self.detail = _get_error_details(detail, code)
def __str__(self):
return str(self.detail)
def get_codes(self):
"""
Return only the code part of the error details.
Eg. {"name": ["required"]}
"""
return _get_codes(self.detail)
def get_full_details(self):
"""
Return both the message & code parts of the error details.
Eg. {"name": [{"message": "This field is required.", "code": "required"}]}
"""
return _get_full_details(self.detail)
# The recommended style for using `ValidationError` is to keep it namespaced
# under `serializers`, in order to minimize potential confusion with Django's
# built in `ValidationError`. For example:
#
# from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid')
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Invalid input.')
default_code = 'invalid'
def __init__(self, detail=None, code=None):
if detail is None:
detail = self.default_detail
if code is None:
code = self.default_code
# For validation failures, we may collect many errors together,
# so the details should always be coerced to a list if not already.
if isinstance(detail, tuple):
detail = list(detail)
elif not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail]
self.detail = _get_error_details(detail, code)
class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Malformed request.')
default_code = 'parse_error'
class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Incorrect authentication credentials.')
default_code = 'authentication_failed'
class NotAuthenticated(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Authentication credentials were not provided.')
default_code = 'not_authenticated'
class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN
default_detail = _('You do not have permission to perform this action.')
default_code = 'permission_denied'
class NotFound(APIException):
status_code = status.HTTP_404_NOT_FOUND
default_detail = _('Not found.')
default_code = 'not_found'
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
default_detail = _('Method "{method}" not allowed.')
default_code = 'method_not_allowed'
def __init__(self, method, detail=None, code=None):
if detail is None:
detail = force_str(self.default_detail).format(method=method)
super().__init__(detail, code)
class NotAcceptable(APIException):
status_code = status.HTTP_406_NOT_ACCEPTABLE
default_detail = _('Could not satisfy the request Accept header.')
default_code = 'not_acceptable'
def __init__(self, detail=None, code=None, available_renderers=None):
self.available_renderers = available_renderers
super().__init__(detail, code)
class UnsupportedMediaType(APIException):
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
default_detail = _('Unsupported media type "{media_type}" in request.')
default_code = 'unsupported_media_type'
def __init__(self, media_type, detail=None, code=None):
if detail is None:
detail = force_str(self.default_detail).format(media_type=media_type)
super().__init__(detail, code)
class Throttled(APIException):
status_code = status.HTTP_429_TOO_MANY_REQUESTS
default_detail = _('Request was throttled.')
extra_detail_singular = _('Expected available in {wait} second.')
extra_detail_plural = _('Expected available in {wait} seconds.')
default_code = 'throttled'
def __init__(self, wait=None, detail=None, code=None):
if detail is None:
detail = force_str(self.default_detail)
if wait is not None:
wait = math.ceil(wait)
detail = ' '.join((
detail,
force_str(ngettext(self.extra_detail_singular.format(wait=wait),
self.extra_detail_plural.format(wait=wait),
wait))))
self.wait = wait
super().__init__(detail, code)
def server_error(request, *args, **kwargs):
"""
Generic 500 error handler.
"""
data = {
'error': 'Server Error (500)'
}
return JsonResponse(data, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def bad_request(request, exception, *args, **kwargs):
"""
Generic 400 error handler.
"""
data = {
'error': 'Bad Request (400)'
}
return JsonResponse(data, status=status.HTTP_400_BAD_REQUEST)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,333 @@
"""
Provides generic filtering backends that can be used to filter the results
returned by list views.
"""
import operator
from functools import reduce
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.db.models.constants import LOOKUP_SEP
from django.template import loader
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
from rest_framework.compat import coreapi, coreschema, distinct
from rest_framework.settings import api_settings
class BaseFilterBackend:
"""
A base class from which all filter backend classes should inherit.
"""
def filter_queryset(self, request, queryset, view):
"""
Return a filtered queryset.
"""
raise NotImplementedError(".filter_queryset() must be overridden.")
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return []
def get_schema_operation_parameters(self, view):
return []
class SearchFilter(BaseFilterBackend):
# The URL query parameter used for the search.
search_param = api_settings.SEARCH_PARAM
template = 'rest_framework/filters/search.html'
lookup_prefixes = {
'^': 'istartswith',
'=': 'iexact',
'@': 'search',
'$': 'iregex',
}
search_title = _('Search')
search_description = _('A search term.')
def get_search_fields(self, view, request):
"""
Search fields are obtained from the view, but the request is always
passed to this method. Sub-classes can override this method to
dynamically change the search fields based on request content.
"""
return getattr(view, 'search_fields', None)
def get_search_terms(self, request):
"""
Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited.
"""
params = request.query_params.get(self.search_param, '')
params = params.replace('\x00', '') # strip null characters
params = params.replace(',', ' ')
return params.split()
def construct_search(self, field_name):
lookup = self.lookup_prefixes.get(field_name[0])
if lookup:
field_name = field_name[1:]
else:
lookup = 'icontains'
return LOOKUP_SEP.join([field_name, lookup])
def must_call_distinct(self, queryset, search_fields):
"""
Return True if 'distinct()' should be used to query the given lookups.
"""
for search_field in search_fields:
opts = queryset.model._meta
if search_field[0] in self.lookup_prefixes:
search_field = search_field[1:]
# Annotated fields do not need to be distinct
if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
continue
parts = search_field.split(LOOKUP_SEP)
for part in parts:
field = opts.get_field(part)
if hasattr(field, 'get_path_info'):
# This field is a relation, update opts to follow the relation
path_info = field.get_path_info()
opts = path_info[-1].to_opts
if any(path.m2m for path in path_info):
# This field is a m2m relation so we know we need to call distinct
return True
else:
# This field has a custom __ query transform but is not a relational field.
break
return False
def filter_queryset(self, request, queryset, view):
search_fields = self.get_search_fields(view, request)
search_terms = self.get_search_terms(request)
if not search_fields or not search_terms:
return queryset
orm_lookups = [
self.construct_search(str(search_field))
for search_field in search_fields
]
base = queryset
conditions = []
for search_term in search_terms:
queries = [
models.Q(**{orm_lookup: search_term})
for orm_lookup in orm_lookups
]
conditions.append(reduce(operator.or_, queries))
queryset = queryset.filter(reduce(operator.and_, conditions))
if self.must_call_distinct(queryset, search_fields):
# Filtering against a many-to-many field requires us to
# call queryset.distinct() in order to avoid duplicate items
# in the resulting queryset.
# We try to avoid this if possible, for performance reasons.
queryset = distinct(queryset, base)
return queryset
def to_html(self, request, queryset, view):
if not getattr(view, 'search_fields', None):
return ''
term = self.get_search_terms(request)
term = term[0] if term else ''
context = {
'param': self.search_param,
'term': term
}
template = loader.get_template(self.template)
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
name=self.search_param,
required=False,
location='query',
schema=coreschema.String(
title=force_str(self.search_title),
description=force_str(self.search_description)
)
)
]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.search_param,
'required': False,
'in': 'query',
'description': force_str(self.search_description),
'schema': {
'type': 'string',
},
},
]
class OrderingFilter(BaseFilterBackend):
# The URL query parameter used for the ordering.
ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None
ordering_title = _('Ordering')
ordering_description = _('Which field to use when ordering the results.')
template = 'rest_framework/filters/ordering.html'
def get_ordering(self, request, queryset, view):
"""
Ordering is set by a comma delimited ?ordering=... query parameter.
The `ordering` query parameter can be overridden by setting
the `ordering_param` value on the OrderingFilter or by
specifying an `ORDERING_PARAM` value in the API settings.
"""
params = request.query_params.get(self.ordering_param)
if params:
fields = [param.strip() for param in params.split(',')]
ordering = self.remove_invalid_fields(queryset, fields, view, request)
if ordering:
return ordering
# No ordering was included, or all the ordering fields were invalid
return self.get_default_ordering(view)
def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None)
if isinstance(ordering, str):
return (ordering,)
return ordering
def get_default_valid_fields(self, queryset, view, context={}):
# If `ordering_fields` is not specified, then we determine a default
# based on the serializer class, if one exists on the view.
if hasattr(view, 'get_serializer_class'):
try:
serializer_class = view.get_serializer_class()
except AssertionError:
# Raised by the default implementation if
# no serializer_class was found
serializer_class = None
else:
serializer_class = getattr(view, 'serializer_class', None)
if serializer_class is None:
msg = (
"Cannot use %s on a view which does not have either a "
"'serializer_class', an overriding 'get_serializer_class' "
"or 'ordering_fields' attribute."
)
raise ImproperlyConfigured(msg % self.__class__.__name__)
model_class = queryset.model
model_property_names = [
# 'pk' is a property added in Django's Model class, however it is valid for ordering.
attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk'
]
return [
(field.source.replace('.', '__') or field_name, field.label)
for field_name, field in serializer_class(context=context).fields.items()
if (
not getattr(field, 'write_only', False) and
not field.source == '*' and
field.source not in model_property_names
)
]
def get_valid_fields(self, queryset, view, context={}):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
if valid_fields is None:
# Default to allowing filtering on serializer fields
return self.get_default_valid_fields(queryset, view, context)
elif valid_fields == '__all__':
# View explicitly allows filtering on any model field
valid_fields = [
(field.name, field.verbose_name) for field in queryset.model._meta.fields
]
valid_fields += [
(key, key.title().split('__'))
for key in queryset.query.annotations
]
else:
valid_fields = [
(item, item) if isinstance(item, str) else item
for item in valid_fields
]
return valid_fields
def remove_invalid_fields(self, queryset, fields, view, request):
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
def term_valid(term):
if term.startswith("-"):
term = term[1:]
return term in valid_fields
return [term for term in fields if term_valid(term)]
def filter_queryset(self, request, queryset, view):
ordering = self.get_ordering(request, queryset, view)
if ordering:
return queryset.order_by(*ordering)
return queryset
def get_template_context(self, request, queryset, view):
current = self.get_ordering(request, queryset, view)
current = None if not current else current[0]
options = []
context = {
'request': request,
'current': current,
'param': self.ordering_param,
}
for key, label in self.get_valid_fields(queryset, view, context):
options.append((key, '%s - %s' % (label, _('ascending'))))
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
context['options'] = options
return context
def to_html(self, request, queryset, view):
template = loader.get_template(self.template)
context = self.get_template_context(request, queryset, view)
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
name=self.ordering_param,
required=False,
location='query',
schema=coreschema.String(
title=force_str(self.ordering_title),
description=force_str(self.ordering_description)
)
)
]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.ordering_param,
'required': False,
'in': 'query',
'description': force_str(self.ordering_description),
'schema': {
'type': 'string',
},
},
]

View File

@ -0,0 +1,291 @@
"""
Generic views that provide commonly needed behaviour.
"""
from django.core.exceptions import ValidationError
from django.db.models.query import QuerySet
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
from rest_framework import mixins, views
from rest_framework.settings import api_settings
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
"""
Same as Django's standard shortcut, but make sure to also raise 404
if the filter_kwargs don't match the required types.
"""
try:
return _get_object_or_404(queryset, *filter_args, **filter_kwargs)
except (TypeError, ValueError, ValidationError):
raise Http404
class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
# You'll need to either set these attributes,
# or override `get_queryset()`/`get_serializer_class()`.
# If you are overriding a view method, it is important that you call
# `get_queryset()` instead of accessing the `queryset` property directly,
# as `queryset` will get evaluated only once, and those results are cached
# for all subsequent requests.
queryset = None
serializer_class = None
# If you want to use object lookups other than pk, set 'lookup_field'.
# For more complex lookup requirements override `get_object()`.
lookup_field = 'pk'
lookup_url_kwarg = None
# The filter backend classes to use for queryset filtering
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
# The style to use for queryset pagination.
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
def get_queryset(self):
"""
Get the list of items for this view.
This must be an iterable, and may be a queryset.
Defaults to using `self.queryset`.
This method should always be used rather than accessing `self.queryset`
directly, as `self.queryset` gets evaluated only once, and those results
are cached for all subsequent requests.
You may want to override this if you need to provide different
querysets depending on the incoming request.
(Eg. return a list of items that is specific to the user)
"""
assert self.queryset is not None, (
"'%s' should either include a `queryset` attribute, "
"or override the `get_queryset()` method."
% self.__class__.__name__
)
queryset = self.queryset
if isinstance(queryset, QuerySet):
# Ensure queryset is re-evaluated on each request.
queryset = queryset.all()
return queryset
def get_object(self):
"""
Returns the object the view is displaying.
You may want to override this if you need to provide non-standard
queryset lookups. Eg if objects are referenced using multiple
keyword arguments in the url conf.
"""
queryset = self.filter_queryset(self.get_queryset())
# Perform the lookup filtering.
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
assert lookup_url_kwarg in self.kwargs, (
'Expected view %s to be called with a URL keyword argument '
'named "%s". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.' %
(self.__class__.__name__, lookup_url_kwarg)
)
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
obj = get_object_or_404(queryset, **filter_kwargs)
# May raise a permission denied
self.check_object_permissions(self.request, obj)
return obj
def get_serializer(self, *args, **kwargs):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
"""
serializer_class = self.get_serializer_class()
kwargs.setdefault('context', self.get_serializer_context())
return serializer_class(*args, **kwargs)
def get_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
assert self.serializer_class is not None, (
"'%s' should either include a `serializer_class` attribute, "
"or override the `get_serializer_class()` method."
% self.__class__.__name__
)
return self.serializer_class
def get_serializer_context(self):
"""
Extra context provided to the serializer class.
"""
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def filter_queryset(self, queryset):
"""
Given a queryset, filter it with whichever filter backend is in use.
You are unlikely to want to override this method, although you may need
to call it either from a list view, or from a custom `get_object`
method if you want to apply the configured filtering backend to the
default queryset.
"""
for backend in list(self.filter_backends):
queryset = backend().filter_queryset(self.request, queryset, self)
return queryset
@property
def paginator(self):
"""
The paginator instance associated with the view, or `None`.
"""
if not hasattr(self, '_paginator'):
if self.pagination_class is None:
self._paginator = None
else:
self._paginator = self.pagination_class()
return self._paginator
def paginate_queryset(self, queryset):
"""
Return a single page of results, or `None` if pagination is disabled.
"""
if self.paginator is None:
return None
return self.paginator.paginate_queryset(queryset, self.request, view=self)
def get_paginated_response(self, data):
"""
Return a paginated style `Response` object for the given output data.
"""
assert self.paginator is not None
return self.paginator.get_paginated_response(data)
# Concrete view classes that provide method handlers
# by composing the mixin classes with the base view.
class CreateAPIView(mixins.CreateModelMixin,
GenericAPIView):
"""
Concrete view for creating a model instance.
"""
def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs)
class ListAPIView(mixins.ListModelMixin,
GenericAPIView):
"""
Concrete view for listing a queryset.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
class RetrieveAPIView(mixins.RetrieveModelMixin,
GenericAPIView):
"""
Concrete view for retrieving a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
class DestroyAPIView(mixins.DestroyModelMixin,
GenericAPIView):
"""
Concrete view for deleting a model instance.
"""
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
class UpdateAPIView(mixins.UpdateModelMixin,
GenericAPIView):
"""
Concrete view for updating a model instance.
"""
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
class ListCreateAPIView(mixins.ListModelMixin,
mixins.CreateModelMixin,
GenericAPIView):
"""
Concrete view for listing a queryset or creating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs)
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
GenericAPIView):
"""
Concrete view for retrieving, updating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
GenericAPIView):
"""
Concrete view for retrieving or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
GenericAPIView):
"""
Concrete view for retrieving, updating or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)

View File

@ -0,0 +1,71 @@
from django.core.management.base import BaseCommand
from django.utils.module_loading import import_string
from rest_framework import renderers
from rest_framework.schemas import coreapi
from rest_framework.schemas.openapi import SchemaGenerator
OPENAPI_MODE = 'openapi'
COREAPI_MODE = 'coreapi'
class Command(BaseCommand):
help = "Generates configured API schema for project."
def get_mode(self):
return COREAPI_MODE if coreapi.is_enabled() else OPENAPI_MODE
def add_arguments(self, parser):
parser.add_argument('--title', dest="title", default='', type=str)
parser.add_argument('--url', dest="url", default=None, type=str)
parser.add_argument('--description', dest="description", default=None, type=str)
if self.get_mode() == COREAPI_MODE:
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
else:
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str)
parser.add_argument('--urlconf', dest="urlconf", default=None, type=str)
parser.add_argument('--generator_class', dest="generator_class", default=None, type=str)
parser.add_argument('--file', dest="file", default=None, type=str)
parser.add_argument('--api_version', dest="api_version", default='', type=str)
def handle(self, *args, **options):
if options['generator_class']:
generator_class = import_string(options['generator_class'])
else:
generator_class = self.get_generator_class()
generator = generator_class(
url=options['url'],
title=options['title'],
description=options['description'],
urlconf=options['urlconf'],
version=options['api_version'],
)
schema = generator.get_schema(request=None, public=True)
renderer = self.get_renderer(options['format'])
output = renderer.render(schema, renderer_context={})
if options['file']:
with open(options['file'], 'wb') as f:
f.write(output)
else:
self.stdout.write(output.decode())
def get_renderer(self, format):
if self.get_mode() == COREAPI_MODE:
renderer_cls = {
'corejson': renderers.CoreJSONRenderer,
'openapi': renderers.CoreAPIOpenAPIRenderer,
'openapi-json': renderers.CoreAPIJSONOpenAPIRenderer,
}[format]
return renderer_cls()
renderer_cls = {
'openapi': renderers.OpenAPIRenderer,
'openapi-json': renderers.JSONOpenAPIRenderer,
}[format]
return renderer_cls()
def get_generator_class(self):
if self.get_mode() == COREAPI_MODE:
return coreapi.SchemaGenerator
return SchemaGenerator

View File

@ -0,0 +1,150 @@
"""
The metadata API is used to allow customization of how `OPTIONS` requests
are handled. We currently provide a single default implementation that returns
some fairly ad-hoc information about the view.
Future implementations might use JSON schema or other definitions in order
to return this information in a more standardized way.
"""
from collections import OrderedDict
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.utils.encoding import force_str
from rest_framework import exceptions, serializers
from rest_framework.request import clone_request
from rest_framework.utils.field_mapping import ClassLookupDict
class BaseMetadata:
def determine_metadata(self, request, view):
"""
Return a dictionary of metadata about the view.
Used to return responses for OPTIONS requests.
"""
raise NotImplementedError(".determine_metadata() must be overridden.")
class SimpleMetadata(BaseMetadata):
"""
This is the default metadata implementation.
It returns an ad-hoc set of information about the view.
There are not any formalized standards for `OPTIONS` responses
for us to base this on.
"""
label_lookup = ClassLookupDict({
serializers.Field: 'field',
serializers.BooleanField: 'boolean',
serializers.CharField: 'string',
serializers.UUIDField: 'string',
serializers.URLField: 'url',
serializers.EmailField: 'email',
serializers.RegexField: 'regex',
serializers.SlugField: 'slug',
serializers.IntegerField: 'integer',
serializers.FloatField: 'float',
serializers.DecimalField: 'decimal',
serializers.DateField: 'date',
serializers.DateTimeField: 'datetime',
serializers.TimeField: 'time',
serializers.ChoiceField: 'choice',
serializers.MultipleChoiceField: 'multiple choice',
serializers.FileField: 'file upload',
serializers.ImageField: 'image upload',
serializers.ListField: 'list',
serializers.DictField: 'nested object',
serializers.Serializer: 'nested object',
})
def determine_metadata(self, request, view):
metadata = OrderedDict()
metadata['name'] = view.get_view_name()
metadata['description'] = view.get_view_description()
metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes]
metadata['parses'] = [parser.media_type for parser in view.parser_classes]
if hasattr(view, 'get_serializer'):
actions = self.determine_actions(request, view)
if actions:
metadata['actions'] = actions
return metadata
def determine_actions(self, request, view):
"""
For generic class based views we return information about
the fields that are accepted for 'PUT' and 'POST' methods.
"""
actions = {}
for method in {'PUT', 'POST'} & set(view.allowed_methods):
view.request = clone_request(request, method)
try:
# Test global permissions
if hasattr(view, 'check_permissions'):
view.check_permissions(view.request)
# Test object permissions
if method == 'PUT' and hasattr(view, 'get_object'):
view.get_object()
except (exceptions.APIException, PermissionDenied, Http404):
pass
else:
# If user has appropriate permissions for the view, include
# appropriate metadata about the fields that should be supplied.
serializer = view.get_serializer()
actions[method] = self.get_serializer_info(serializer)
finally:
view.request = request
return actions
def get_serializer_info(self, serializer):
"""
Given an instance of a serializer, return a dictionary of metadata
about its fields.
"""
if hasattr(serializer, 'child'):
# If this is a `ListSerializer` then we want to examine the
# underlying child serializer instance instead.
serializer = serializer.child
return OrderedDict([
(field_name, self.get_field_info(field))
for field_name, field in serializer.fields.items()
if not isinstance(field, serializers.HiddenField)
])
def get_field_info(self, field):
"""
Given an instance of a serializer field, return a dictionary
of metadata about it.
"""
field_info = OrderedDict()
field_info['type'] = self.label_lookup[field]
field_info['required'] = getattr(field, 'required', False)
attrs = [
'read_only', 'label', 'help_text',
'min_length', 'max_length',
'min_value', 'max_value'
]
for attr in attrs:
value = getattr(field, attr, None)
if value is not None and value != '':
field_info[attr] = force_str(value, strings_only=True)
if getattr(field, 'child', None):
field_info['child'] = self.get_field_info(field.child)
elif getattr(field, 'fields', None):
field_info['children'] = self.get_serializer_info(field)
if (not field_info.get('read_only') and
not isinstance(field, (serializers.RelatedField, serializers.ManyRelatedField)) and
hasattr(field, 'choices')):
field_info['choices'] = [
{
'value': choice_value,
'display_name': force_str(choice_name, strings_only=True)
}
for choice_value, choice_name in field.choices.items()
]
return field_info

View File

@ -0,0 +1,95 @@
"""
Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
"""
from rest_framework import status
from rest_framework.response import Response
from rest_framework.settings import api_settings
class CreateModelMixin:
"""
Create a model instance.
"""
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def perform_create(self, serializer):
serializer.save()
def get_success_headers(self, data):
try:
return {'Location': str(data[api_settings.URL_FIELD_NAME])}
except (TypeError, KeyError):
return {}
class ListModelMixin:
"""
List a queryset.
"""
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
class RetrieveModelMixin:
"""
Retrieve a model instance.
"""
def retrieve(self, request, *args, **kwargs):
instance = self.get_object()
serializer = self.get_serializer(instance)
return Response(serializer.data)
class UpdateModelMixin:
"""
Update a model instance.
"""
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
if getattr(instance, '_prefetched_objects_cache', None):
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
instance._prefetched_objects_cache = {}
return Response(serializer.data)
def perform_update(self, serializer):
serializer.save()
def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True
return self.update(request, *args, **kwargs)
class DestroyModelMixin:
"""
Destroy a model instance.
"""
def destroy(self, request, *args, **kwargs):
instance = self.get_object()
self.perform_destroy(instance)
return Response(status=status.HTTP_204_NO_CONTENT)
def perform_destroy(self, instance):
instance.delete()

View File

@ -0,0 +1,97 @@
"""
Content negotiation deals with selecting an appropriate renderer given the
incoming request. Typically this will be based on the request's Accept header.
"""
from django.http import Http404
from rest_framework import exceptions
from rest_framework.settings import api_settings
from rest_framework.utils.mediatypes import (
_MediaType, media_type_matches, order_by_precedence
)
class BaseContentNegotiation:
def select_parser(self, request, parsers):
raise NotImplementedError('.select_parser() must be implemented')
def select_renderer(self, request, renderers, format_suffix=None):
raise NotImplementedError('.select_renderer() must be implemented')
class DefaultContentNegotiation(BaseContentNegotiation):
settings = api_settings
def select_parser(self, request, parsers):
"""
Given a list of parsers and a media type, return the appropriate
parser to handle the incoming request.
"""
for parser in parsers:
if media_type_matches(parser.media_type, request.content_type):
return parser
return None
def select_renderer(self, request, renderers, format_suffix=None):
"""
Given a request and a list of renderers, return a two-tuple of:
(renderer, media type).
"""
# Allow URL style format override. eg. "?format=json
format_query_param = self.settings.URL_FORMAT_OVERRIDE
format = format_suffix or request.query_params.get(format_query_param)
if format:
renderers = self.filter_renderers(renderers, format)
accepts = self.get_accept_list(request)
# Check the acceptable media types against each renderer,
# attempting more specific media types first
# NB. The inner loop here isn't as bad as it first looks :)
# Worst case is we're looping over len(accept_list) * len(self.renderers)
for media_type_set in order_by_precedence(accepts):
for renderer in renderers:
for media_type in media_type_set:
if media_type_matches(renderer.media_type, media_type):
# Return the most specific media type as accepted.
media_type_wrapper = _MediaType(media_type)
if (
_MediaType(renderer.media_type).precedence >
media_type_wrapper.precedence
):
# Eg client requests '*/*'
# Accepted media type is 'application/json'
full_media_type = ';'.join(
(renderer.media_type,) +
tuple(
'{}={}'.format(key, value)
for key, value in media_type_wrapper.params.items()
)
)
return renderer, full_media_type
else:
# Eg client requests 'application/json; indent=8'
# Accepted media type is 'application/json; indent=8'
return renderer, media_type
raise exceptions.NotAcceptable(available_renderers=renderers)
def filter_renderers(self, renderers, format):
"""
If there is a '.json' style format suffix, filter the renderers
so that we only negotiation against those that accept that format.
"""
renderers = [renderer for renderer in renderers
if renderer.format == format]
if not renderers:
raise Http404
return renderers
def get_accept_list(self, request):
"""
Given the incoming request, return a tokenized list of media
type strings.
"""
header = request.META.get('HTTP_ACCEPT', '*/*')
return [token.strip() for token in header.split(',')]

View File

@ -0,0 +1,980 @@
"""
Pagination serializers determine the structure of the output that should
be used for paginated responses.
"""
from base64 import b64decode, b64encode
from collections import OrderedDict, namedtuple
from urllib import parse
from django.core.paginator import InvalidPage
from django.core.paginator import Paginator as DjangoPaginator
from django.template import loader
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
from rest_framework.compat import coreapi, coreschema
from rest_framework.exceptions import NotFound
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.utils.urls import remove_query_param, replace_query_param
def _positive_int(integer_string, strict=False, cutoff=None):
"""
Cast a string to a strictly positive integer.
"""
ret = int(integer_string)
if ret < 0 or (ret == 0 and strict):
raise ValueError()
if cutoff:
return min(ret, cutoff)
return ret
def _divide_with_ceil(a, b):
"""
Returns 'a' divided by 'b', with any remainder rounded up.
"""
if a % b:
return (a // b) + 1
return a // b
def _get_displayed_page_numbers(current, final):
"""
This utility function determines a list of page numbers to display.
This gives us a nice contextually relevant set of page numbers.
For example:
current=14, final=16 -> [1, None, 13, 14, 15, 16]
This implementation gives one page to each side of the cursor,
or two pages to the side when the cursor is at the edge, then
ensures that any breaks between non-continuous page numbers never
remove only a single page.
For an alternative implementation which gives two pages to each side of
the cursor, eg. as in GitHub issue list pagination, see:
https://gist.github.com/tomchristie/321140cebb1c4a558b15
"""
assert current >= 1
assert final >= current
if final <= 5:
return list(range(1, final + 1))
# We always include the first two pages, last two pages, and
# two pages either side of the current page.
included = {1, current - 1, current, current + 1, final}
# If the break would only exclude a single page number then we
# may as well include the page number instead of the break.
if current <= 4:
included.add(2)
included.add(3)
if current >= final - 3:
included.add(final - 1)
included.add(final - 2)
# Now sort the page numbers and drop anything outside the limits.
included = [
idx for idx in sorted(included)
if 0 < idx <= final
]
# Finally insert any `...` breaks
if current > 4:
included.insert(1, None)
if current < final - 3:
included.insert(len(included) - 1, None)
return included
def _get_page_links(page_numbers, current, url_func):
"""
Given a list of page numbers and `None` page breaks,
return a list of `PageLink` objects.
"""
page_links = []
for page_number in page_numbers:
if page_number is None:
page_link = PAGE_BREAK
else:
page_link = PageLink(
url=url_func(page_number),
number=page_number,
is_active=(page_number == current),
is_break=False
)
page_links.append(page_link)
return page_links
def _reverse_ordering(ordering_tuple):
"""
Given an order_by tuple such as `('-created', 'uuid')` reverse the
ordering and return a new tuple, eg. `('created', '-uuid')`.
"""
def invert(x):
return x[1:] if x.startswith('-') else '-' + x
return tuple([invert(item) for item in ordering_tuple])
Cursor = namedtuple('Cursor', ['offset', 'reverse', 'position'])
PageLink = namedtuple('PageLink', ['url', 'number', 'is_active', 'is_break'])
PAGE_BREAK = PageLink(url=None, number=None, is_active=False, is_break=True)
class BasePagination:
display_page_controls = False
def paginate_queryset(self, queryset, request, view=None): # pragma: no cover
raise NotImplementedError('paginate_queryset() must be implemented.')
def get_paginated_response(self, data): # pragma: no cover
raise NotImplementedError('get_paginated_response() must be implemented.')
def get_paginated_response_schema(self, schema):
return schema
def to_html(self): # pragma: no cover
raise NotImplementedError('to_html() must be implemented to display page controls.')
def get_results(self, data):
return data['results']
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
return []
def get_schema_operation_parameters(self, view):
return []
class PageNumberPagination(BasePagination):
"""
A simple page number based style that supports page numbers as
query parameters. For example:
http://api.example.org/accounts/?page=4
http://api.example.org/accounts/?page=4&page_size=100
"""
# The default page size.
# Defaults to `None`, meaning pagination is disabled.
page_size = api_settings.PAGE_SIZE
django_paginator_class = DjangoPaginator
# Client can control the page using this query parameter.
page_query_param = 'page'
page_query_description = _('A page number within the paginated result set.')
# Client can control the page size using this query parameter.
# Default is 'None'. Set to eg 'page_size' to enable usage.
page_size_query_param = None
page_size_query_description = _('Number of results to return per page.')
# Set to an integer to limit the maximum page size the client may request.
# Only relevant if 'page_size_query_param' has also been set.
max_page_size = None
last_page_strings = ('last',)
template = 'rest_framework/pagination/numbers.html'
invalid_page_message = _('Invalid page.')
def paginate_queryset(self, queryset, request, view=None):
"""
Paginate a queryset if required, either returning a
page object, or `None` if pagination is not configured for this view.
"""
page_size = self.get_page_size(request)
if not page_size:
return None
paginator = self.django_paginator_class(queryset, page_size)
page_number = self.get_page_number(request, paginator)
try:
self.page = paginator.page(page_number)
except InvalidPage as exc:
msg = self.invalid_page_message.format(
page_number=page_number, message=str(exc)
)
raise NotFound(msg)
if paginator.num_pages > 1 and self.template is not None:
# The browsable API should display pagination controls.
self.display_page_controls = True
self.request = request
return list(self.page)
def get_page_number(self, request, paginator):
page_number = request.query_params.get(self.page_query_param, 1)
if page_number in self.last_page_strings:
page_number = paginator.num_pages
return page_number
def get_paginated_response(self, data):
return Response(OrderedDict([
('count', self.page.paginator.count),
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', data)
]))
def get_paginated_response_schema(self, schema):
return {
'type': 'object',
'properties': {
'count': {
'type': 'integer',
'example': 123,
},
'next': {
'type': 'string',
'nullable': True,
'format': 'uri',
'example': 'http://api.example.org/accounts/?{page_query_param}=4'.format(
page_query_param=self.page_query_param)
},
'previous': {
'type': 'string',
'nullable': True,
'format': 'uri',
'example': 'http://api.example.org/accounts/?{page_query_param}=2'.format(
page_query_param=self.page_query_param)
},
'results': schema,
},
}
def get_page_size(self, request):
if self.page_size_query_param:
try:
return _positive_int(
request.query_params[self.page_size_query_param],
strict=True,
cutoff=self.max_page_size
)
except (KeyError, ValueError):
pass
return self.page_size
def get_next_link(self):
if not self.page.has_next():
return None
url = self.request.build_absolute_uri()
page_number = self.page.next_page_number()
return replace_query_param(url, self.page_query_param, page_number)
def get_previous_link(self):
if not self.page.has_previous():
return None
url = self.request.build_absolute_uri()
page_number = self.page.previous_page_number()
if page_number == 1:
return remove_query_param(url, self.page_query_param)
return replace_query_param(url, self.page_query_param, page_number)
def get_html_context(self):
base_url = self.request.build_absolute_uri()
def page_number_to_url(page_number):
if page_number == 1:
return remove_query_param(base_url, self.page_query_param)
else:
return replace_query_param(base_url, self.page_query_param, page_number)
current = self.page.number
final = self.page.paginator.num_pages
page_numbers = _get_displayed_page_numbers(current, final)
page_links = _get_page_links(page_numbers, current, page_number_to_url)
return {
'previous_url': self.get_previous_link(),
'next_url': self.get_next_link(),
'page_links': page_links
}
def to_html(self):
template = loader.get_template(self.template)
context = self.get_html_context()
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
fields = [
coreapi.Field(
name=self.page_query_param,
required=False,
location='query',
schema=coreschema.Integer(
title='Page',
description=force_str(self.page_query_description)
)
)
]
if self.page_size_query_param is not None:
fields.append(
coreapi.Field(
name=self.page_size_query_param,
required=False,
location='query',
schema=coreschema.Integer(
title='Page size',
description=force_str(self.page_size_query_description)
)
)
)
return fields
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.page_query_param,
'required': False,
'in': 'query',
'description': force_str(self.page_query_description),
'schema': {
'type': 'integer',
},
},
]
if self.page_size_query_param is not None:
parameters.append(
{
'name': self.page_size_query_param,
'required': False,
'in': 'query',
'description': force_str(self.page_size_query_description),
'schema': {
'type': 'integer',
},
},
)
return parameters
class LimitOffsetPagination(BasePagination):
"""
A limit/offset based style. For example:
http://api.example.org/accounts/?limit=100
http://api.example.org/accounts/?offset=400&limit=100
"""
default_limit = api_settings.PAGE_SIZE
limit_query_param = 'limit'
limit_query_description = _('Number of results to return per page.')
offset_query_param = 'offset'
offset_query_description = _('The initial index from which to return the results.')
max_limit = None
template = 'rest_framework/pagination/numbers.html'
def paginate_queryset(self, queryset, request, view=None):
self.limit = self.get_limit(request)
if self.limit is None:
return None
self.count = self.get_count(queryset)
self.offset = self.get_offset(request)
self.request = request
if self.count > self.limit and self.template is not None:
self.display_page_controls = True
if self.count == 0 or self.offset > self.count:
return []
return list(queryset[self.offset:self.offset + self.limit])
def get_paginated_response(self, data):
return Response(OrderedDict([
('count', self.count),
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', data)
]))
def get_paginated_response_schema(self, schema):
return {
'type': 'object',
'properties': {
'count': {
'type': 'integer',
'example': 123,
},
'next': {
'type': 'string',
'nullable': True,
'format': 'uri',
'example': 'http://api.example.org/accounts/?{offset_param}=400&{limit_param}=100'.format(
offset_param=self.offset_query_param, limit_param=self.limit_query_param),
},
'previous': {
'type': 'string',
'nullable': True,
'format': 'uri',
'example': 'http://api.example.org/accounts/?{offset_param}=200&{limit_param}=100'.format(
offset_param=self.offset_query_param, limit_param=self.limit_query_param),
},
'results': schema,
},
}
def get_limit(self, request):
if self.limit_query_param:
try:
return _positive_int(
request.query_params[self.limit_query_param],
strict=True,
cutoff=self.max_limit
)
except (KeyError, ValueError):
pass
return self.default_limit
def get_offset(self, request):
try:
return _positive_int(
request.query_params[self.offset_query_param],
)
except (KeyError, ValueError):
return 0
def get_next_link(self):
if self.offset + self.limit >= self.count:
return None
url = self.request.build_absolute_uri()
url = replace_query_param(url, self.limit_query_param, self.limit)
offset = self.offset + self.limit
return replace_query_param(url, self.offset_query_param, offset)
def get_previous_link(self):
if self.offset <= 0:
return None
url = self.request.build_absolute_uri()
url = replace_query_param(url, self.limit_query_param, self.limit)
if self.offset - self.limit <= 0:
return remove_query_param(url, self.offset_query_param)
offset = self.offset - self.limit
return replace_query_param(url, self.offset_query_param, offset)
def get_html_context(self):
base_url = self.request.build_absolute_uri()
if self.limit:
current = _divide_with_ceil(self.offset, self.limit) + 1
# The number of pages is a little bit fiddly.
# We need to sum both the number of pages from current offset to end
# plus the number of pages up to the current offset.
# When offset is not strictly divisible by the limit then we may
# end up introducing an extra page as an artifact.
final = (
_divide_with_ceil(self.count - self.offset, self.limit) +
_divide_with_ceil(self.offset, self.limit)
)
final = max(final, 1)
else:
current = 1
final = 1
if current > final:
current = final
def page_number_to_url(page_number):
if page_number == 1:
return remove_query_param(base_url, self.offset_query_param)
else:
offset = self.offset + ((page_number - current) * self.limit)
return replace_query_param(base_url, self.offset_query_param, offset)
page_numbers = _get_displayed_page_numbers(current, final)
page_links = _get_page_links(page_numbers, current, page_number_to_url)
return {
'previous_url': self.get_previous_link(),
'next_url': self.get_next_link(),
'page_links': page_links
}
def to_html(self):
template = loader.get_template(self.template)
context = self.get_html_context()
return template.render(context)
def get_count(self, queryset):
"""
Determine an object count, supporting either querysets or regular lists.
"""
try:
return queryset.count()
except (AttributeError, TypeError):
return len(queryset)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
name=self.limit_query_param,
required=False,
location='query',
schema=coreschema.Integer(
title='Limit',
description=force_str(self.limit_query_description)
)
),
coreapi.Field(
name=self.offset_query_param,
required=False,
location='query',
schema=coreschema.Integer(
title='Offset',
description=force_str(self.offset_query_description)
)
)
]
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.limit_query_param,
'required': False,
'in': 'query',
'description': force_str(self.limit_query_description),
'schema': {
'type': 'integer',
},
},
{
'name': self.offset_query_param,
'required': False,
'in': 'query',
'description': force_str(self.offset_query_description),
'schema': {
'type': 'integer',
},
},
]
return parameters
class CursorPagination(BasePagination):
"""
The cursor pagination implementation is necessarily complex.
For an overview of the position/offset style we use, see this post:
https://cra.mr/2011/03/08/building-cursors-for-the-disqus-api
"""
cursor_query_param = 'cursor'
cursor_query_description = _('The pagination cursor value.')
page_size = api_settings.PAGE_SIZE
invalid_cursor_message = _('Invalid cursor')
ordering = '-created'
template = 'rest_framework/pagination/previous_and_next.html'
# Client can control the page size using this query parameter.
# Default is 'None'. Set to eg 'page_size' to enable usage.
page_size_query_param = None
page_size_query_description = _('Number of results to return per page.')
# Set to an integer to limit the maximum page size the client may request.
# Only relevant if 'page_size_query_param' has also been set.
max_page_size = None
# The offset in the cursor is used in situations where we have a
# nearly-unique index. (Eg millisecond precision creation timestamps)
# We guard against malicious users attempting to cause expensive database
# queries, by having a hard cap on the maximum possible size of the offset.
offset_cutoff = 1000
def paginate_queryset(self, queryset, request, view=None):
self.page_size = self.get_page_size(request)
if not self.page_size:
return None
self.base_url = request.build_absolute_uri()
self.ordering = self.get_ordering(request, queryset, view)
self.cursor = self.decode_cursor(request)
if self.cursor is None:
(offset, reverse, current_position) = (0, False, None)
else:
(offset, reverse, current_position) = self.cursor
# Cursor pagination always enforces an ordering.
if reverse:
queryset = queryset.order_by(*_reverse_ordering(self.ordering))
else:
queryset = queryset.order_by(*self.ordering)
# If we have a cursor with a fixed position then filter by that.
if current_position is not None:
order = self.ordering[0]
is_reversed = order.startswith('-')
order_attr = order.lstrip('-')
# Test for: (cursor reversed) XOR (queryset reversed)
if self.cursor.reverse != is_reversed:
kwargs = {order_attr + '__lt': current_position}
else:
kwargs = {order_attr + '__gt': current_position}
queryset = queryset.filter(**kwargs)
# If we have an offset cursor then offset the entire page by that amount.
# We also always fetch an extra item in order to determine if there is a
# page following on from this one.
results = list(queryset[offset:offset + self.page_size + 1])
self.page = list(results[:self.page_size])
# Determine the position of the final item following the page.
if len(results) > len(self.page):
has_following_position = True
following_position = self._get_position_from_instance(results[-1], self.ordering)
else:
has_following_position = False
following_position = None
if reverse:
# If we have a reverse queryset, then the query ordering was in reverse
# so we need to reverse the items again before returning them to the user.
self.page = list(reversed(self.page))
# Determine next and previous positions for reverse cursors.
self.has_next = (current_position is not None) or (offset > 0)
self.has_previous = has_following_position
if self.has_next:
self.next_position = current_position
if self.has_previous:
self.previous_position = following_position
else:
# Determine next and previous positions for forward cursors.
self.has_next = has_following_position
self.has_previous = (current_position is not None) or (offset > 0)
if self.has_next:
self.next_position = following_position
if self.has_previous:
self.previous_position = current_position
# Display page controls in the browsable API if there is more
# than one page.
if (self.has_previous or self.has_next) and self.template is not None:
self.display_page_controls = True
return self.page
def get_page_size(self, request):
if self.page_size_query_param:
try:
return _positive_int(
request.query_params[self.page_size_query_param],
strict=True,
cutoff=self.max_page_size
)
except (KeyError, ValueError):
pass
return self.page_size
def get_next_link(self):
if not self.has_next:
return None
if self.page and self.cursor and self.cursor.reverse and self.cursor.offset != 0:
# If we're reversing direction and we have an offset cursor
# then we cannot use the first position we find as a marker.
compare = self._get_position_from_instance(self.page[-1], self.ordering)
else:
compare = self.next_position
offset = 0
has_item_with_unique_position = False
for item in reversed(self.page):
position = self._get_position_from_instance(item, self.ordering)
if position != compare:
# The item in this position and the item following it
# have different positions. We can use this position as
# our marker.
has_item_with_unique_position = True
break
# The item in this position has the same position as the item
# following it, we can't use it as a marker position, so increment
# the offset and keep seeking to the previous item.
compare = position
offset += 1
if self.page and not has_item_with_unique_position:
# There were no unique positions in the page.
if not self.has_previous:
# We are on the first page.
# Our cursor will have an offset equal to the page size,
# but no position to filter against yet.
offset = self.page_size
position = None
elif self.cursor.reverse:
# The change in direction will introduce a paging artifact,
# where we end up skipping forward a few extra items.
offset = 0
position = self.previous_position
else:
# Use the position from the existing cursor and increment
# it's offset by the page size.
offset = self.cursor.offset + self.page_size
position = self.previous_position
if not self.page:
position = self.next_position
cursor = Cursor(offset=offset, reverse=False, position=position)
return self.encode_cursor(cursor)
def get_previous_link(self):
if not self.has_previous:
return None
if self.page and self.cursor and not self.cursor.reverse and self.cursor.offset != 0:
# If we're reversing direction and we have an offset cursor
# then we cannot use the first position we find as a marker.
compare = self._get_position_from_instance(self.page[0], self.ordering)
else:
compare = self.previous_position
offset = 0
has_item_with_unique_position = False
for item in self.page:
position = self._get_position_from_instance(item, self.ordering)
if position != compare:
# The item in this position and the item following it
# have different positions. We can use this position as
# our marker.
has_item_with_unique_position = True
break
# The item in this position has the same position as the item
# following it, we can't use it as a marker position, so increment
# the offset and keep seeking to the previous item.
compare = position
offset += 1
if self.page and not has_item_with_unique_position:
# There were no unique positions in the page.
if not self.has_next:
# We are on the final page.
# Our cursor will have an offset equal to the page size,
# but no position to filter against yet.
offset = self.page_size
position = None
elif self.cursor.reverse:
# Use the position from the existing cursor and increment
# it's offset by the page size.
offset = self.cursor.offset + self.page_size
position = self.next_position
else:
# The change in direction will introduce a paging artifact,
# where we end up skipping back a few extra items.
offset = 0
position = self.next_position
if not self.page:
position = self.previous_position
cursor = Cursor(offset=offset, reverse=True, position=position)
return self.encode_cursor(cursor)
def get_ordering(self, request, queryset, view):
"""
Return a tuple of strings, that may be used in an `order_by` method.
"""
ordering_filters = [
filter_cls for filter_cls in getattr(view, 'filter_backends', [])
if hasattr(filter_cls, 'get_ordering')
]
if ordering_filters:
# If a filter exists on the view that implements `get_ordering`
# then we defer to that filter to determine the ordering.
filter_cls = ordering_filters[0]
filter_instance = filter_cls()
ordering = filter_instance.get_ordering(request, queryset, view)
assert ordering is not None, (
'Using cursor pagination, but filter class {filter_cls} '
'returned a `None` ordering.'.format(
filter_cls=filter_cls.__name__
)
)
else:
# The default case is to check for an `ordering` attribute
# on this pagination instance.
ordering = self.ordering
assert ordering is not None, (
'Using cursor pagination, but no ordering attribute was declared '
'on the pagination class.'
)
assert '__' not in ordering, (
'Cursor pagination does not support double underscore lookups '
'for orderings. Orderings should be an unchanging, unique or '
'nearly-unique field on the model, such as "-created" or "pk".'
)
assert isinstance(ordering, (str, list, tuple)), (
'Invalid ordering. Expected string or tuple, but got {type}'.format(
type=type(ordering).__name__
)
)
if isinstance(ordering, str):
return (ordering,)
return tuple(ordering)
def decode_cursor(self, request):
"""
Given a request with a cursor, return a `Cursor` instance.
"""
# Determine if we have a cursor, and if so then decode it.
encoded = request.query_params.get(self.cursor_query_param)
if encoded is None:
return None
try:
querystring = b64decode(encoded.encode('ascii')).decode('ascii')
tokens = parse.parse_qs(querystring, keep_blank_values=True)
offset = tokens.get('o', ['0'])[0]
offset = _positive_int(offset, cutoff=self.offset_cutoff)
reverse = tokens.get('r', ['0'])[0]
reverse = bool(int(reverse))
position = tokens.get('p', [None])[0]
except (TypeError, ValueError):
raise NotFound(self.invalid_cursor_message)
return Cursor(offset=offset, reverse=reverse, position=position)
def encode_cursor(self, cursor):
"""
Given a Cursor instance, return an url with encoded cursor.
"""
tokens = {}
if cursor.offset != 0:
tokens['o'] = str(cursor.offset)
if cursor.reverse:
tokens['r'] = '1'
if cursor.position is not None:
tokens['p'] = cursor.position
querystring = parse.urlencode(tokens, doseq=True)
encoded = b64encode(querystring.encode('ascii')).decode('ascii')
return replace_query_param(self.base_url, self.cursor_query_param, encoded)
def _get_position_from_instance(self, instance, ordering):
field_name = ordering[0].lstrip('-')
if isinstance(instance, dict):
attr = instance[field_name]
else:
attr = getattr(instance, field_name)
return str(attr)
def get_paginated_response(self, data):
return Response(OrderedDict([
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', data)
]))
def get_paginated_response_schema(self, schema):
return {
'type': 'object',
'properties': {
'next': {
'type': 'string',
'nullable': True,
},
'previous': {
'type': 'string',
'nullable': True,
},
'results': schema,
},
}
def get_html_context(self):
return {
'previous_url': self.get_previous_link(),
'next_url': self.get_next_link()
}
def to_html(self):
template = loader.get_template(self.template)
context = self.get_html_context()
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
fields = [
coreapi.Field(
name=self.cursor_query_param,
required=False,
location='query',
schema=coreschema.String(
title='Cursor',
description=force_str(self.cursor_query_description)
)
)
]
if self.page_size_query_param is not None:
fields.append(
coreapi.Field(
name=self.page_size_query_param,
required=False,
location='query',
schema=coreschema.Integer(
title='Page size',
description=force_str(self.page_size_query_description)
)
)
)
return fields
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.cursor_query_param,
'required': False,
'in': 'query',
'description': force_str(self.cursor_query_description),
'schema': {
'type': 'string',
},
}
]
if self.page_size_query_param is not None:
parameters.append(
{
'name': self.page_size_query_param,
'required': False,
'in': 'query',
'description': force_str(self.page_size_query_description),
'schema': {
'type': 'integer',
},
}
)
return parameters

View File

@ -0,0 +1,209 @@
"""
Parsers are used to parse the content of incoming HTTP requests.
They give us a generic way of being able to handle various media types
on the request, such as form content or json encoded data.
"""
import codecs
from django.conf import settings
from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
from django.http.multipartparser import ChunkIter
from django.http.multipartparser import \
MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError
from rest_framework import renderers
from rest_framework.compat import parse_header_parameters
from rest_framework.exceptions import ParseError
from rest_framework.settings import api_settings
from rest_framework.utils import json
class DataAndFiles:
def __init__(self, data, files):
self.data = data
self.files = files
class BaseParser:
"""
All parsers should extend `BaseParser`, specifying a `media_type`
attribute, and overriding the `.parse()` method.
"""
media_type = None
def parse(self, stream, media_type=None, parser_context=None):
"""
Given a stream to read from, return the parsed representation.
Should return parsed data, or a `DataAndFiles` object consisting of the
parsed data and files.
"""
raise NotImplementedError(".parse() must be overridden.")
class JSONParser(BaseParser):
"""
Parses JSON-serialized data.
"""
media_type = 'application/json'
renderer_class = renderers.JSONRenderer
strict = api_settings.STRICT_JSON
def parse(self, stream, media_type=None, parser_context=None):
"""
Parses the incoming bytestream as JSON and returns the resulting data.
"""
parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
try:
decoded_stream = codecs.getreader(encoding)(stream)
parse_constant = json.strict_constant if self.strict else None
return json.load(decoded_stream, parse_constant=parse_constant)
except ValueError as exc:
raise ParseError('JSON parse error - %s' % str(exc))
class FormParser(BaseParser):
"""
Parser for form data.
"""
media_type = 'application/x-www-form-urlencoded'
def parse(self, stream, media_type=None, parser_context=None):
"""
Parses the incoming bytestream as a URL encoded form,
and returns the resulting QueryDict.
"""
parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
return QueryDict(stream.read(), encoding=encoding)
class MultiPartParser(BaseParser):
"""
Parser for multipart form data, which may include file data.
"""
media_type = 'multipart/form-data'
def parse(self, stream, media_type=None, parser_context=None):
"""
Parses the incoming bytestream as a multipart encoded form,
and returns a DataAndFiles object.
`.data` will be a `QueryDict` containing all the form parameters.
`.files` will be a `QueryDict` containing all the form files.
"""
parser_context = parser_context or {}
request = parser_context['request']
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
meta = request.META.copy()
meta['CONTENT_TYPE'] = media_type
upload_handlers = request.upload_handlers
try:
parser = DjangoMultiPartParser(meta, stream, upload_handlers, encoding)
data, files = parser.parse()
return DataAndFiles(data, files)
except MultiPartParserError as exc:
raise ParseError('Multipart form parse error - %s' % str(exc))
class FileUploadParser(BaseParser):
"""
Parser for file upload data.
"""
media_type = '*/*'
errors = {
'unhandled': 'FileUpload parse error - none of upload handlers can handle the stream',
'no_filename': 'Missing filename. Request should include a Content-Disposition header with a filename parameter.',
}
def parse(self, stream, media_type=None, parser_context=None):
"""
Treats the incoming bytestream as a raw file upload and returns
a `DataAndFiles` object.
`.data` will be None (we expect request body to be a file content).
`.files` will be a `QueryDict` containing one 'file' element.
"""
parser_context = parser_context or {}
request = parser_context['request']
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
meta = request.META
upload_handlers = request.upload_handlers
filename = self.get_filename(stream, media_type, parser_context)
if not filename:
raise ParseError(self.errors['no_filename'])
# Note that this code is extracted from Django's handling of
# file uploads in MultiPartParser.
content_type = meta.get('HTTP_CONTENT_TYPE',
meta.get('CONTENT_TYPE', ''))
try:
content_length = int(meta.get('HTTP_CONTENT_LENGTH',
meta.get('CONTENT_LENGTH', 0)))
except (ValueError, TypeError):
content_length = None
# See if the handler will want to take care of the parsing.
for handler in upload_handlers:
result = handler.handle_raw_input(stream,
meta,
content_length,
None,
encoding)
if result is not None:
return DataAndFiles({}, {'file': result[1]})
# This is the standard case.
possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size]
chunk_size = min([2 ** 31 - 4] + possible_sizes)
chunks = ChunkIter(stream, chunk_size)
counters = [0] * len(upload_handlers)
for index, handler in enumerate(upload_handlers):
try:
handler.new_file(None, filename, content_type,
content_length, encoding)
except StopFutureHandlers:
upload_handlers = upload_handlers[:index + 1]
break
for chunk in chunks:
for index, handler in enumerate(upload_handlers):
chunk_length = len(chunk)
chunk = handler.receive_data_chunk(chunk, counters[index])
counters[index] += chunk_length
if chunk is None:
break
for index, handler in enumerate(upload_handlers):
file_obj = handler.file_complete(counters[index])
if file_obj is not None:
return DataAndFiles({}, {'file': file_obj})
raise ParseError(self.errors['unhandled'])
def get_filename(self, stream, media_type, parser_context):
"""
Detects the uploaded file name. First searches a 'filename' url kwarg.
Then tries to parse Content-Disposition header.
"""
try:
return parser_context['kwargs']['filename']
except KeyError:
pass
try:
meta = parser_context['request'].META
disposition, params = parse_header_parameters(meta['HTTP_CONTENT_DISPOSITION'])
if 'filename*' in params:
return params['filename*']
else:
return params['filename']
except (AttributeError, KeyError, ValueError):
pass

View File

@ -0,0 +1,303 @@
"""
Provides a set of pluggable permission policies.
"""
from django.http import Http404
from rest_framework import exceptions
SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')
class OperationHolderMixin:
def __and__(self, other):
return OperandHolder(AND, self, other)
def __or__(self, other):
return OperandHolder(OR, self, other)
def __rand__(self, other):
return OperandHolder(AND, other, self)
def __ror__(self, other):
return OperandHolder(OR, other, self)
def __invert__(self):
return SingleOperandHolder(NOT, self)
class SingleOperandHolder(OperationHolderMixin):
def __init__(self, operator_class, op1_class):
self.operator_class = operator_class
self.op1_class = op1_class
def __call__(self, *args, **kwargs):
op1 = self.op1_class(*args, **kwargs)
return self.operator_class(op1)
class OperandHolder(OperationHolderMixin):
def __init__(self, operator_class, op1_class, op2_class):
self.operator_class = operator_class
self.op1_class = op1_class
self.op2_class = op2_class
def __call__(self, *args, **kwargs):
op1 = self.op1_class(*args, **kwargs)
op2 = self.op2_class(*args, **kwargs)
return self.operator_class(op1, op2)
class AND:
def __init__(self, op1, op2):
self.op1 = op1
self.op2 = op2
def has_permission(self, request, view):
return (
self.op1.has_permission(request, view) and
self.op2.has_permission(request, view)
)
def has_object_permission(self, request, view, obj):
return (
self.op1.has_object_permission(request, view, obj) and
self.op2.has_object_permission(request, view, obj)
)
class OR:
def __init__(self, op1, op2):
self.op1 = op1
self.op2 = op2
def has_permission(self, request, view):
return (
self.op1.has_permission(request, view) or
self.op2.has_permission(request, view)
)
def has_object_permission(self, request, view, obj):
return (
self.op1.has_permission(request, view)
and self.op1.has_object_permission(request, view, obj)
) or (
self.op2.has_permission(request, view)
and self.op2.has_object_permission(request, view, obj)
)
class NOT:
def __init__(self, op1):
self.op1 = op1
def has_permission(self, request, view):
return not self.op1.has_permission(request, view)
def has_object_permission(self, request, view, obj):
return not self.op1.has_object_permission(request, view, obj)
class BasePermissionMetaclass(OperationHolderMixin, type):
pass
class BasePermission(metaclass=BasePermissionMetaclass):
"""
A base class from which all permission classes should inherit.
"""
def has_permission(self, request, view):
"""
Return `True` if permission is granted, `False` otherwise.
"""
return True
def has_object_permission(self, request, view, obj):
"""
Return `True` if permission is granted, `False` otherwise.
"""
return True
class AllowAny(BasePermission):
"""
Allow any access.
This isn't strictly required, since you could use an empty
permission_classes list, but it's useful because it makes the intention
more explicit.
"""
def has_permission(self, request, view):
return True
class IsAuthenticated(BasePermission):
"""
Allows access only to authenticated users.
"""
def has_permission(self, request, view):
return bool(request.user and request.user.is_authenticated)
class IsAdminUser(BasePermission):
"""
Allows access only to admin users.
"""
def has_permission(self, request, view):
return bool(request.user and request.user.is_staff)
class IsAuthenticatedOrReadOnly(BasePermission):
"""
The request is authenticated as a user, or is a read-only request.
"""
def has_permission(self, request, view):
return bool(
request.method in SAFE_METHODS or
request.user and
request.user.is_authenticated
)
class DjangoModelPermissions(BasePermission):
"""
The request is authenticated using `django.contrib.auth` permissions.
See: https://docs.djangoproject.com/en/dev/topics/auth/#permissions
It ensures that the user is authenticated, and has the appropriate
`add`/`change`/`delete` permissions on the model.
This permission can only be applied against view classes that
provide a `.queryset` attribute.
"""
# Map methods into required permission codes.
# Override this if you need to also provide 'view' permissions,
# or if you want to provide custom permission codes.
perms_map = {
'GET': [],
'OPTIONS': [],
'HEAD': [],
'POST': ['%(app_label)s.add_%(model_name)s'],
'PUT': ['%(app_label)s.change_%(model_name)s'],
'PATCH': ['%(app_label)s.change_%(model_name)s'],
'DELETE': ['%(app_label)s.delete_%(model_name)s'],
}
authenticated_users_only = True
def get_required_permissions(self, method, model_cls):
"""
Given a model and an HTTP method, return the list of permission
codes that the user is required to have.
"""
kwargs = {
'app_label': model_cls._meta.app_label,
'model_name': model_cls._meta.model_name
}
if method not in self.perms_map:
raise exceptions.MethodNotAllowed(method)
return [perm % kwargs for perm in self.perms_map[method]]
def _queryset(self, view):
assert hasattr(view, 'get_queryset') \
or getattr(view, 'queryset', None) is not None, (
'Cannot apply {} on a view that does not set '
'`.queryset` or have a `.get_queryset()` method.'
).format(self.__class__.__name__)
if hasattr(view, 'get_queryset'):
queryset = view.get_queryset()
assert queryset is not None, (
'{}.get_queryset() returned None'.format(view.__class__.__name__)
)
return queryset
return view.queryset
def has_permission(self, request, view):
# Workaround to ensure DjangoModelPermissions are not applied
# to the root view when using DefaultRouter.
if getattr(view, '_ignore_model_permissions', False):
return True
if not request.user or (
not request.user.is_authenticated and self.authenticated_users_only):
return False
queryset = self._queryset(view)
perms = self.get_required_permissions(request.method, queryset.model)
return request.user.has_perms(perms)
class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
"""
Similar to DjangoModelPermissions, except that anonymous users are
allowed read-only access.
"""
authenticated_users_only = False
class DjangoObjectPermissions(DjangoModelPermissions):
"""
The request is authenticated using Django's object-level permissions.
It requires an object-permissions-enabled backend, such as Django Guardian.
It ensures that the user is authenticated, and has the appropriate
`add`/`change`/`delete` permissions on the object using .has_perms.
This permission can only be applied against view classes that
provide a `.queryset` attribute.
"""
perms_map = {
'GET': [],
'OPTIONS': [],
'HEAD': [],
'POST': ['%(app_label)s.add_%(model_name)s'],
'PUT': ['%(app_label)s.change_%(model_name)s'],
'PATCH': ['%(app_label)s.change_%(model_name)s'],
'DELETE': ['%(app_label)s.delete_%(model_name)s'],
}
def get_required_object_permissions(self, method, model_cls):
kwargs = {
'app_label': model_cls._meta.app_label,
'model_name': model_cls._meta.model_name
}
if method not in self.perms_map:
raise exceptions.MethodNotAllowed(method)
return [perm % kwargs for perm in self.perms_map[method]]
def has_object_permission(self, request, view, obj):
# authentication checks have already executed via has_permission
queryset = self._queryset(view)
model_cls = queryset.model
user = request.user
perms = self.get_required_object_permissions(request.method, model_cls)
if not user.has_perms(perms, obj):
# If the user does not have permissions we need to determine if
# they have read permissions to see 403, or not, and simply see
# a 404 response.
if request.method in SAFE_METHODS:
# Read permissions already checked and failed, no need
# to make another lookup.
raise Http404
read_perms = self.get_required_object_permissions('GET', model_cls)
if not user.has_perms(read_perms, obj):
raise Http404
# Has read permissions.
return False
return True

View File

@ -0,0 +1,586 @@
import sys
from collections import OrderedDict
from urllib import parse
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
from django.db.models import Manager
from django.db.models.query import QuerySet
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
from django.utils.encoding import smart_str, uri_to_iri
from django.utils.translation import gettext_lazy as _
from rest_framework.fields import (
Field, SkipField, empty, get_attribute, is_simple_callable, iter_options
)
from rest_framework.reverse import reverse
from rest_framework.settings import api_settings
from rest_framework.utils import html
def method_overridden(method_name, klass, instance):
"""
Determine if a method has been overridden.
"""
method = getattr(klass, method_name)
default_method = getattr(method, '__func__', method) # Python 3 compat
return default_method is not getattr(instance, method_name).__func__
class ObjectValueError(ValueError):
"""
Raised when `queryset.get()` failed due to an underlying `ValueError`.
Wrapping prevents calling code conflating this with unrelated errors.
"""
class ObjectTypeError(TypeError):
"""
Raised when `queryset.get()` failed due to an underlying `TypeError`.
Wrapping prevents calling code conflating this with unrelated errors.
"""
class Hyperlink(str):
"""
A string like object that additionally has an associated name.
We use this for hyperlinked URLs that may render as a named link
in some contexts, or render as a plain URL in others.
"""
def __new__(cls, url, obj):
ret = super().__new__(cls, url)
ret.obj = obj
return ret
def __getnewargs__(self):
return (str(self), self.name)
@property
def name(self):
# This ensures that we only called `__str__` lazily,
# as in some cases calling __str__ on a model instances *might*
# involve a database lookup.
return str(self.obj)
is_hyperlink = True
class PKOnlyObject:
"""
This is a mock object, used for when we only need the pk of the object
instance, but still want to return an object with a .pk attribute,
in order to keep the same interface as a regular model instance.
"""
def __init__(self, pk):
self.pk = pk
def __str__(self):
return "%s" % self.pk
# We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer.
MANY_RELATION_KWARGS = (
'read_only', 'write_only', 'required', 'default', 'initial', 'source',
'label', 'help_text', 'style', 'error_messages', 'allow_empty',
'html_cutoff', 'html_cutoff_text'
)
class RelatedField(Field):
queryset = None
html_cutoff = None
html_cutoff_text = None
def __init__(self, **kwargs):
self.queryset = kwargs.pop('queryset', self.queryset)
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
if cutoff_from_settings is not None:
cutoff_from_settings = int(cutoff_from_settings)
self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings)
self.html_cutoff_text = kwargs.pop(
'html_cutoff_text',
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
)
if not method_overridden('get_queryset', RelatedField, self):
assert self.queryset is not None or kwargs.get('read_only'), (
'Relational field must provide a `queryset` argument, '
'override `get_queryset`, or set read_only=`True`.'
)
assert not (self.queryset is not None and kwargs.get('read_only')), (
'Relational fields should not provide a `queryset` argument, '
'when setting read_only=`True`.'
)
kwargs.pop('many', None)
kwargs.pop('allow_empty', None)
super().__init__(**kwargs)
def __new__(cls, *args, **kwargs):
# We override this method in order to automagically create
# `ManyRelatedField` classes instead when `many=True` is set.
if kwargs.pop('many', False):
return cls.many_init(*args, **kwargs)
return super().__new__(cls, *args, **kwargs)
@classmethod
def many_init(cls, *args, **kwargs):
"""
This method handles creating a parent `ManyRelatedField` instance
when the `many=True` keyword argument is passed.
Typically you won't need to override this method.
Note that we're over-cautious in passing most arguments to both parent
and child classes in order to try to cover the general case. If you're
overriding this method you'll probably want something much simpler, eg:
@classmethod
def many_init(cls, *args, **kwargs):
kwargs['child'] = cls()
return CustomManyRelatedField(*args, **kwargs)
"""
list_kwargs = {'child_relation': cls(*args, **kwargs)}
for key in kwargs:
if key in MANY_RELATION_KWARGS:
list_kwargs[key] = kwargs[key]
return ManyRelatedField(**list_kwargs)
def run_validation(self, data=empty):
# We force empty strings to None values for relational fields.
if data == '':
data = None
return super().run_validation(data)
def get_queryset(self):
queryset = self.queryset
if isinstance(queryset, (QuerySet, Manager)):
# Ensure queryset is re-evaluated whenever used.
# Note that actually a `Manager` class may also be used as the
# queryset argument. This occurs on ModelSerializer fields,
# as it allows us to generate a more expressive 'repr' output
# for the field.
# Eg: 'MyRelationship(queryset=ExampleModel.objects.all())'
queryset = queryset.all()
return queryset
def use_pk_only_optimization(self):
return False
def get_attribute(self, instance):
if self.use_pk_only_optimization() and self.source_attrs:
# Optimized case, return a mock object only containing the pk attribute.
try:
attribute_instance = get_attribute(instance, self.source_attrs[:-1])
value = attribute_instance.serializable_value(self.source_attrs[-1])
if is_simple_callable(value):
# Handle edge case where the relationship `source` argument
# points to a `get_relationship()` method on the model.
value = value()
# Handle edge case where relationship `source` argument points
# to an instance instead of a pk (e.g., a `@property`).
value = getattr(value, 'pk', value)
return PKOnlyObject(pk=value)
except AttributeError:
pass
# Standard case, return the object instance.
return super().get_attribute(instance)
def get_choices(self, cutoff=None):
queryset = self.get_queryset()
if queryset is None:
# Ensure that field.choices returns something sensible
# even when accessed with a read-only field.
return {}
if cutoff is not None:
queryset = queryset[:cutoff]
return OrderedDict([
(
self.to_representation(item),
self.display_value(item)
)
for item in queryset
])
@property
def choices(self):
return self.get_choices()
@property
def grouped_choices(self):
return self.choices
def iter_options(self):
return iter_options(
self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff,
cutoff_text=self.html_cutoff_text
)
def display_value(self, instance):
return str(instance)
class StringRelatedField(RelatedField):
"""
A read only field that represents its targets using their
plain string representation.
"""
def __init__(self, **kwargs):
kwargs['read_only'] = True
super().__init__(**kwargs)
def to_representation(self, value):
return str(value)
class PrimaryKeyRelatedField(RelatedField):
default_error_messages = {
'required': _('This field is required.'),
'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'),
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
}
def __init__(self, **kwargs):
self.pk_field = kwargs.pop('pk_field', None)
super().__init__(**kwargs)
def use_pk_only_optimization(self):
return True
def to_internal_value(self, data):
if self.pk_field is not None:
data = self.pk_field.to_internal_value(data)
queryset = self.get_queryset()
try:
if isinstance(data, bool):
raise TypeError
return queryset.get(pk=data)
except ObjectDoesNotExist:
self.fail('does_not_exist', pk_value=data)
except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(data).__name__)
def to_representation(self, value):
if self.pk_field is not None:
return self.pk_field.to_representation(value.pk)
return value.pk
class HyperlinkedRelatedField(RelatedField):
lookup_field = 'pk'
view_name = None
default_error_messages = {
'required': _('This field is required.'),
'no_match': _('Invalid hyperlink - No URL match.'),
'incorrect_match': _('Invalid hyperlink - Incorrect URL match.'),
'does_not_exist': _('Invalid hyperlink - Object does not exist.'),
'incorrect_type': _('Incorrect type. Expected URL string, received {data_type}.'),
}
def __init__(self, view_name=None, **kwargs):
if view_name is not None:
self.view_name = view_name
assert self.view_name is not None, 'The `view_name` argument is required.'
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
self.format = kwargs.pop('format', None)
# We include this simply for dependency injection in tests.
# We can't add it as a class attributes or it would expect an
# implicit `self` argument to be passed.
self.reverse = reverse
super().__init__(**kwargs)
def use_pk_only_optimization(self):
return self.lookup_field == 'pk'
def get_object(self, view_name, view_args, view_kwargs):
"""
Return the object corresponding to a matched URL.
Takes the matched URL conf arguments, and should return an
object instance, or raise an `ObjectDoesNotExist` exception.
"""
lookup_value = view_kwargs[self.lookup_url_kwarg]
lookup_kwargs = {self.lookup_field: lookup_value}
queryset = self.get_queryset()
try:
return queryset.get(**lookup_kwargs)
except ValueError:
exc = ObjectValueError(str(sys.exc_info()[1]))
raise exc.with_traceback(sys.exc_info()[2])
except TypeError:
exc = ObjectTypeError(str(sys.exc_info()[1]))
raise exc.with_traceback(sys.exc_info()[2])
def get_url(self, obj, view_name, request, format):
"""
Given an object, return the URL that hyperlinks to the object.
May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
attributes are not configured to correctly match the URL conf.
"""
# Unsaved objects will not yet have a valid URL.
if hasattr(obj, 'pk') and obj.pk in (None, ''):
return None
lookup_value = getattr(obj, self.lookup_field)
kwargs = {self.lookup_url_kwarg: lookup_value}
return self.reverse(view_name, kwargs=kwargs, request=request, format=format)
def to_internal_value(self, data):
request = self.context.get('request')
try:
http_prefix = data.startswith(('http:', 'https:'))
except AttributeError:
self.fail('incorrect_type', data_type=type(data).__name__)
if http_prefix:
# If needed convert absolute URLs to relative path
data = parse.urlparse(data).path
prefix = get_script_prefix()
if data.startswith(prefix):
data = '/' + data[len(prefix):]
data = uri_to_iri(parse.unquote(data))
try:
match = resolve(data)
except Resolver404:
self.fail('no_match')
try:
expected_viewname = request.versioning_scheme.get_versioned_viewname(
self.view_name, request
)
except AttributeError:
expected_viewname = self.view_name
if match.view_name != expected_viewname:
self.fail('incorrect_match')
try:
return self.get_object(match.view_name, match.args, match.kwargs)
except (ObjectDoesNotExist, ObjectValueError, ObjectTypeError):
self.fail('does_not_exist')
def to_representation(self, value):
assert 'request' in self.context, (
"`%s` requires the request in the serializer"
" context. Add `context={'request': request}` when instantiating "
"the serializer." % self.__class__.__name__
)
request = self.context['request']
format = self.context.get('format')
# By default use whatever format is given for the current context
# unless the target is a different type to the source.
#
# Eg. Consider a HyperlinkedIdentityField pointing from a json
# representation to an html property of that representation...
#
# '/snippets/1/' should link to '/snippets/1/highlight/'
# ...but...
# '/snippets/1/.json' should link to '/snippets/1/highlight/.html'
if format and self.format and self.format != format:
format = self.format
# Return the hyperlink, or error if incorrectly configured.
try:
url = self.get_url(value, self.view_name, request, format)
except NoReverseMatch:
msg = (
'Could not resolve URL for hyperlinked relationship using '
'view name "%s". You may have failed to include the related '
'model in your API, or incorrectly configured the '
'`lookup_field` attribute on this field.'
)
if value in ('', None):
value_string = {'': 'the empty string', None: 'None'}[value]
msg += (
" WARNING: The value of the field on the model instance "
"was %s, which may be why it didn't match any "
"entries in your URL conf." % value_string
)
raise ImproperlyConfigured(msg % self.view_name)
if url is None:
return None
return Hyperlink(url, value)
class HyperlinkedIdentityField(HyperlinkedRelatedField):
"""
A read-only field that represents the identity URL for an object, itself.
This is in contrast to `HyperlinkedRelatedField` which represents the
URL of relationships to other objects.
"""
def __init__(self, view_name=None, **kwargs):
assert view_name is not None, 'The `view_name` argument is required.'
kwargs['read_only'] = True
kwargs['source'] = '*'
super().__init__(view_name, **kwargs)
def use_pk_only_optimization(self):
# We have the complete object instance already. We don't need
# to run the 'only get the pk for this relationship' code.
return False
class SlugRelatedField(RelatedField):
"""
A read-write field that represents the target of the relationship
by a unique 'slug' attribute.
"""
default_error_messages = {
'does_not_exist': _('Object with {slug_name}={value} does not exist.'),
'invalid': _('Invalid value.'),
}
def __init__(self, slug_field=None, **kwargs):
assert slug_field is not None, 'The `slug_field` argument is required.'
self.slug_field = slug_field
super().__init__(**kwargs)
def to_internal_value(self, data):
queryset = self.get_queryset()
try:
return queryset.get(**{self.slug_field: data})
except ObjectDoesNotExist:
self.fail('does_not_exist', slug_name=self.slug_field, value=smart_str(data))
except (TypeError, ValueError):
self.fail('invalid')
def to_representation(self, obj):
return getattr(obj, self.slug_field)
class ManyRelatedField(Field):
"""
Relationships with `many=True` transparently get coerced into instead being
a ManyRelatedField with a child relationship.
The `ManyRelatedField` class is responsible for handling iterating through
the values and passing each one to the child relationship.
This class is treated as private API.
You shouldn't generally need to be using this class directly yourself,
and should instead simply set 'many=True' on the relationship.
"""
initial = []
default_empty_html = []
default_error_messages = {
'not_a_list': _('Expected a list of items but got type "{input_type}".'),
'empty': _('This list may not be empty.')
}
html_cutoff = None
html_cutoff_text = None
def __init__(self, child_relation=None, *args, **kwargs):
self.child_relation = child_relation
self.allow_empty = kwargs.pop('allow_empty', True)
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
if cutoff_from_settings is not None:
cutoff_from_settings = int(cutoff_from_settings)
self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings)
self.html_cutoff_text = kwargs.pop(
'html_cutoff_text',
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
)
assert child_relation is not None, '`child_relation` is a required argument.'
super().__init__(*args, **kwargs)
self.child_relation.bind(field_name='', parent=self)
def get_value(self, dictionary):
# We override the default field access in order to support
# lists in HTML forms.
if html.is_html_input(dictionary):
# Don't return [] if the update is partial
if self.field_name not in dictionary:
if getattr(self.root, 'partial', False):
return empty
return dictionary.getlist(self.field_name)
return dictionary.get(self.field_name, empty)
def to_internal_value(self, data):
if isinstance(data, str) or not hasattr(data, '__iter__'):
self.fail('not_a_list', input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
self.fail('empty')
return [
self.child_relation.to_internal_value(item)
for item in data
]
def get_attribute(self, instance):
# Can't have any relationships if not created
if hasattr(instance, 'pk') and instance.pk is None:
return []
try:
relationship = get_attribute(instance, self.source_attrs)
except (KeyError, AttributeError) as exc:
if self.default is not empty:
return self.get_default()
if self.allow_null:
return None
if not self.required:
raise SkipField()
msg = (
'Got {exc_type} when attempting to get a value for field '
'`{field}` on serializer `{serializer}`.\nThe serializer '
'field might be named incorrectly and not match '
'any attribute or key on the `{instance}` instance.\n'
'Original exception text was: {exc}.'.format(
exc_type=type(exc).__name__,
field=self.field_name,
serializer=self.parent.__class__.__name__,
instance=instance.__class__.__name__,
exc=exc
)
)
raise type(exc)(msg)
return relationship.all() if hasattr(relationship, 'all') else relationship
def to_representation(self, iterable):
return [
self.child_relation.to_representation(value)
for value in iterable
]
def get_choices(self, cutoff=None):
return self.child_relation.get_choices(cutoff)
@property
def choices(self):
return self.get_choices()
@property
def grouped_choices(self):
return self.choices
def iter_options(self):
return iter_options(
self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff,
cutoff_text=self.html_cutoff_text
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,455 @@
"""
The Request class is used as a wrapper around the standard request object.
The wrapped request then offers a richer API, in particular :
- content automatically parsed according to `Content-Type` header,
and available as `request.data`
- full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content
"""
import io
import sys
from contextlib import contextmanager
from django.conf import settings
from django.http import HttpRequest, QueryDict
from django.http.request import RawPostDataException
from django.utils.datastructures import MultiValueDict
from rest_framework import exceptions
from rest_framework.compat import parse_header_parameters
from rest_framework.settings import api_settings
def is_form_media_type(media_type):
"""
Return True if the media type is a valid form media type.
"""
base_media_type, params = parse_header_parameters(media_type)
return (base_media_type == 'application/x-www-form-urlencoded' or
base_media_type == 'multipart/form-data')
class override_method:
"""
A context manager that temporarily overrides the method on a request,
additionally setting the `view.request` attribute.
Usage:
with override_method(view, request, 'POST') as request:
... # Do stuff with `view` and `request`
"""
def __init__(self, view, request, method):
self.view = view
self.request = request
self.method = method
self.action = getattr(view, 'action', None)
def __enter__(self):
self.view.request = clone_request(self.request, self.method)
# For viewsets we also set the `.action` attribute.
action_map = getattr(self.view, 'action_map', {})
self.view.action = action_map.get(self.method.lower())
return self.view.request
def __exit__(self, *args, **kwarg):
self.view.request = self.request
self.view.action = self.action
class WrappedAttributeError(Exception):
pass
@contextmanager
def wrap_attributeerrors():
"""
Used to re-raise AttributeErrors caught during authentication, preventing
these errors from otherwise being handled by the attribute access protocol.
"""
try:
yield
except AttributeError:
info = sys.exc_info()
exc = WrappedAttributeError(str(info[1]))
raise exc.with_traceback(info[2])
class Empty:
"""
Placeholder for unset attributes.
Cannot use `None`, as that may be a valid value.
"""
pass
def _hasattr(obj, name):
return not getattr(obj, name) is Empty
def clone_request(request, method):
"""
Internal helper method to clone a request, replacing with a different
HTTP method. Used for checking permissions against other methods.
"""
ret = Request(request=request._request,
parsers=request.parsers,
authenticators=request.authenticators,
negotiator=request.negotiator,
parser_context=request.parser_context)
ret._data = request._data
ret._files = request._files
ret._full_data = request._full_data
ret._content_type = request._content_type
ret._stream = request._stream
ret.method = method
if hasattr(request, '_user'):
ret._user = request._user
if hasattr(request, '_auth'):
ret._auth = request._auth
if hasattr(request, '_authenticator'):
ret._authenticator = request._authenticator
if hasattr(request, 'accepted_renderer'):
ret.accepted_renderer = request.accepted_renderer
if hasattr(request, 'accepted_media_type'):
ret.accepted_media_type = request.accepted_media_type
if hasattr(request, 'version'):
ret.version = request.version
if hasattr(request, 'versioning_scheme'):
ret.versioning_scheme = request.versioning_scheme
return ret
class ForcedAuthentication:
"""
This authentication class is used if the test client or request factory
forcibly authenticated the request.
"""
def __init__(self, force_user, force_token):
self.force_user = force_user
self.force_token = force_token
def authenticate(self, request):
return (self.force_user, self.force_token)
class Request:
"""
Wrapper allowing to enhance a standard `HttpRequest` instance.
Kwargs:
- request(HttpRequest). The original request instance.
- parsers(list/tuple). The parsers to use for parsing the
request content.
- authenticators(list/tuple). The authenticators used to try
authenticating the request's user.
"""
def __init__(self, request, parsers=None, authenticators=None,
negotiator=None, parser_context=None):
assert isinstance(request, HttpRequest), (
'The `request` argument must be an instance of '
'`django.http.HttpRequest`, not `{}.{}`.'
.format(request.__class__.__module__, request.__class__.__name__)
)
self._request = request
self.parsers = parsers or ()
self.authenticators = authenticators or ()
self.negotiator = negotiator or self._default_negotiator()
self.parser_context = parser_context
self._data = Empty
self._files = Empty
self._full_data = Empty
self._content_type = Empty
self._stream = Empty
if self.parser_context is None:
self.parser_context = {}
self.parser_context['request'] = self
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
force_user = getattr(request, '_force_auth_user', None)
force_token = getattr(request, '_force_auth_token', None)
if force_user is not None or force_token is not None:
forced_auth = ForcedAuthentication(force_user, force_token)
self.authenticators = (forced_auth,)
def __repr__(self):
return '<%s.%s: %s %r>' % (
self.__class__.__module__,
self.__class__.__name__,
self.method,
self.get_full_path())
def _default_negotiator(self):
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
@property
def content_type(self):
meta = self._request.META
return meta.get('CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', ''))
@property
def stream(self):
"""
Returns an object that may be used to stream the request content.
"""
if not _hasattr(self, '_stream'):
self._load_stream()
return self._stream
@property
def query_params(self):
"""
More semantically correct name for request.GET.
"""
return self._request.GET
@property
def data(self):
if not _hasattr(self, '_full_data'):
self._load_data_and_files()
return self._full_data
@property
def user(self):
"""
Returns the user associated with the current request, as authenticated
by the authentication classes provided to the request.
"""
if not hasattr(self, '_user'):
with wrap_attributeerrors():
self._authenticate()
return self._user
@user.setter
def user(self, value):
"""
Sets the user on the current request. This is necessary to maintain
compatibility with django.contrib.auth where the user property is
set in the login and logout functions.
Note that we also set the user on Django's underlying `HttpRequest`
instance, ensuring that it is available to any middleware in the stack.
"""
self._user = value
self._request.user = value
@property
def auth(self):
"""
Returns any non-user authentication information associated with the
request, such as an authentication token.
"""
if not hasattr(self, '_auth'):
with wrap_attributeerrors():
self._authenticate()
return self._auth
@auth.setter
def auth(self, value):
"""
Sets any non-user authentication information associated with the
request, such as an authentication token.
"""
self._auth = value
self._request.auth = value
@property
def successful_authenticator(self):
"""
Return the instance of the authentication instance class that was used
to authenticate the request, or `None`.
"""
if not hasattr(self, '_authenticator'):
with wrap_attributeerrors():
self._authenticate()
return self._authenticator
def _load_data_and_files(self):
"""
Parses the request content into `self.data`.
"""
if not _hasattr(self, '_data'):
self._data, self._files = self._parse()
if self._files:
self._full_data = self._data.copy()
self._full_data.update(self._files)
else:
self._full_data = self._data
# if a form media type, copy data & files refs to the underlying
# http request so that closable objects are handled appropriately.
if is_form_media_type(self.content_type):
self._request._post = self.POST
self._request._files = self.FILES
def _load_stream(self):
"""
Return the content body of the request, as a stream.
"""
meta = self._request.META
try:
content_length = int(
meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0))
)
except (ValueError, TypeError):
content_length = 0
if content_length == 0:
self._stream = None
elif not self._request._read_started:
self._stream = self._request
else:
self._stream = io.BytesIO(self.body)
def _supports_form_parsing(self):
"""
Return True if this requests supports parsing form data.
"""
form_media = (
'application/x-www-form-urlencoded',
'multipart/form-data'
)
return any(parser.media_type in form_media for parser in self.parsers)
def _parse(self):
"""
Parse the request content, returning a two-tuple of (data, files)
May raise an `UnsupportedMediaType`, or `ParseError` exception.
"""
media_type = self.content_type
try:
stream = self.stream
except RawPostDataException:
if not hasattr(self._request, '_post'):
raise
# If request.POST has been accessed in middleware, and a method='POST'
# request was made with 'multipart/form-data', then the request stream
# will already have been exhausted.
if self._supports_form_parsing():
return (self._request.POST, self._request.FILES)
stream = None
if stream is None or media_type is None:
if media_type and is_form_media_type(media_type):
empty_data = QueryDict('', encoding=self._request._encoding)
else:
empty_data = {}
empty_files = MultiValueDict()
return (empty_data, empty_files)
parser = self.negotiator.select_parser(self, self.parsers)
if not parser:
raise exceptions.UnsupportedMediaType(media_type)
try:
parsed = parser.parse(stream, media_type, self.parser_context)
except Exception:
# If we get an exception during parsing, fill in empty data and
# re-raise. Ensures we don't simply repeat the error when
# attempting to render the browsable renderer response, or when
# logging the request or similar.
self._data = QueryDict('', encoding=self._request._encoding)
self._files = MultiValueDict()
self._full_data = self._data
raise
# Parser classes may return the raw data, or a
# DataAndFiles object. Unpack the result as required.
try:
return (parsed.data, parsed.files)
except AttributeError:
empty_files = MultiValueDict()
return (parsed, empty_files)
def _authenticate(self):
"""
Attempt to authenticate the request using each authentication instance
in turn.
"""
for authenticator in self.authenticators:
try:
user_auth_tuple = authenticator.authenticate(self)
except exceptions.APIException:
self._not_authenticated()
raise
if user_auth_tuple is not None:
self._authenticator = authenticator
self.user, self.auth = user_auth_tuple
return
self._not_authenticated()
def _not_authenticated(self):
"""
Set authenticator, user & authtoken representing an unauthenticated request.
Defaults are None, AnonymousUser & None.
"""
self._authenticator = None
if api_settings.UNAUTHENTICATED_USER:
self.user = api_settings.UNAUTHENTICATED_USER()
else:
self.user = None
if api_settings.UNAUTHENTICATED_TOKEN:
self.auth = api_settings.UNAUTHENTICATED_TOKEN()
else:
self.auth = None
def __getattr__(self, attr):
"""
If an attribute does not exist on this instance, then we also attempt
to proxy it to the underlying HttpRequest object.
"""
try:
return getattr(self._request, attr)
except AttributeError:
return self.__getattribute__(attr)
@property
def DATA(self):
raise NotImplementedError(
'`request.DATA` has been deprecated in favor of `request.data` '
'since version 3.0, and has been fully removed as of version 3.2.'
)
@property
def POST(self):
# Ensure that request.POST uses our request parsing.
if not _hasattr(self, '_data'):
self._load_data_and_files()
if is_form_media_type(self.content_type):
return self._data
return QueryDict('', encoding=self._request._encoding)
@property
def FILES(self):
# Leave this one alone for backwards compat with Django's request.FILES
# Different from the other two cases, which are not valid property
# names on the WSGIRequest class.
if not _hasattr(self, '_files'):
self._load_data_and_files()
return self._files
@property
def QUERY_PARAMS(self):
raise NotImplementedError(
'`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` '
'since version 3.0, and has been fully removed as of version 3.2.'
)
def force_plaintext_errors(self, value):
# Hack to allow our exception handler to force choice of
# plaintext or html error responses.
self._request.is_ajax = lambda: value

View File

@ -0,0 +1,103 @@
"""
The Response class in REST framework is similar to HTTPResponse, except that
it is initialized with unrendered data, instead of a pre-rendered string.
The appropriate renderer is called during Django's template response rendering.
"""
from http.client import responses
from django.template.response import SimpleTemplateResponse
from rest_framework.serializers import Serializer
class Response(SimpleTemplateResponse):
"""
An HttpResponse that allows its data to be rendered into
arbitrary media types.
"""
def __init__(self, data=None, status=None,
template_name=None, headers=None,
exception=False, content_type=None):
"""
Alters the init arguments slightly.
For example, drop 'template_name', and instead use 'data'.
Setting 'renderer' and 'media_type' will typically be deferred,
For example being set automatically by the `APIView`.
"""
super().__init__(None, status=status)
if isinstance(data, Serializer):
msg = (
'You passed a Serializer instance as data, but '
'probably meant to pass serialized `.data` or '
'`.error`. representation.'
)
raise AssertionError(msg)
self.data = data
self.template_name = template_name
self.exception = exception
self.content_type = content_type
if headers:
for name, value in headers.items():
self[name] = value
@property
def rendered_content(self):
renderer = getattr(self, 'accepted_renderer', None)
accepted_media_type = getattr(self, 'accepted_media_type', None)
context = getattr(self, 'renderer_context', None)
assert renderer, ".accepted_renderer not set on Response"
assert accepted_media_type, ".accepted_media_type not set on Response"
assert context is not None, ".renderer_context not set on Response"
context['response'] = self
media_type = renderer.media_type
charset = renderer.charset
content_type = self.content_type
if content_type is None and charset is not None:
content_type = "{}; charset={}".format(media_type, charset)
elif content_type is None:
content_type = media_type
self['Content-Type'] = content_type
ret = renderer.render(self.data, accepted_media_type, context)
if isinstance(ret, str):
assert charset, (
'renderer returned unicode, and did not specify '
'a charset value.'
)
return ret.encode(charset)
if not ret:
del self['Content-Type']
return ret
@property
def status_text(self):
"""
Returns reason text corresponding to our HTTP response status code.
Provided for convenience.
"""
return responses.get(self.status_code, '')
def __getstate__(self):
"""
Remove attributes from the response that shouldn't be cached.
"""
state = super().__getstate__()
for key in (
'accepted_renderer', 'renderer_context', 'resolver_match',
'client', 'request', 'json', 'wsgi_request'
):
if key in state:
del state[key]
state['_closable_objects'] = []
return state

View File

@ -0,0 +1,66 @@
"""
Provide urlresolver functions that return fully qualified URLs or view names
"""
from django.urls import NoReverseMatch
from django.urls import reverse as django_reverse
from django.utils.functional import lazy
from rest_framework.settings import api_settings
from rest_framework.utils.urls import replace_query_param
def preserve_builtin_query_params(url, request=None):
"""
Given an incoming request, and an outgoing URL representation,
append the value of any built-in query parameters.
"""
if request is None:
return url
overrides = [
api_settings.URL_FORMAT_OVERRIDE,
]
for param in overrides:
if param and (param in request.GET):
value = request.GET[param]
url = replace_query_param(url, param, value)
return url
def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
"""
If versioning is being used then we pass any `reverse` calls through
to the versioning scheme instance, so that the resulting URL
can be modified if needed.
"""
scheme = getattr(request, 'versioning_scheme', None)
if scheme is not None:
try:
url = scheme.reverse(viewname, args, kwargs, request, format, **extra)
except NoReverseMatch:
# In case the versioning scheme reversal fails, fallback to the
# default implementation
url = _reverse(viewname, args, kwargs, request, format, **extra)
else:
url = _reverse(viewname, args, kwargs, request, format, **extra)
return preserve_builtin_query_params(url, request)
def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
"""
Same as `django.urls.reverse`, but optionally takes a request
and returns a fully qualified URL, using the request to get the base URL.
"""
if format is not None:
kwargs = kwargs or {}
kwargs['format'] = format
url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
if request:
return request.build_absolute_uri(url)
return url
reverse_lazy = lazy(reverse, str)

View File

@ -0,0 +1,348 @@
"""
Routers provide a convenient and consistent way of automatically
determining the URL conf for your API.
They are used by simply instantiating a Router class, and then registering
all the required ViewSets with that router.
For example, you might have a `urls.py` that looks something like this:
router = routers.DefaultRouter()
router.register('users', UserViewSet, 'user')
router.register('accounts', AccountViewSet, 'account')
urlpatterns = router.urls
"""
import itertools
from collections import OrderedDict, namedtuple
from django.core.exceptions import ImproperlyConfigured
from django.urls import NoReverseMatch, re_path
from rest_framework import views
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
from rest_framework.schemas.views import SchemaView
from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns
Route = namedtuple('Route', ['url', 'mapping', 'name', 'detail', 'initkwargs'])
DynamicRoute = namedtuple('DynamicRoute', ['url', 'name', 'detail', 'initkwargs'])
def escape_curly_brackets(url_path):
"""
Double brackets in regex of url_path for escape string formatting
"""
return url_path.replace('{', '{{').replace('}', '}}')
def flatten(list_of_lists):
"""
Takes an iterable of iterables, returns a single iterable containing all items
"""
return itertools.chain(*list_of_lists)
class BaseRouter:
def __init__(self):
self.registry = []
def register(self, prefix, viewset, basename=None):
if basename is None:
basename = self.get_default_basename(viewset)
self.registry.append((prefix, viewset, basename))
# invalidate the urls cache
if hasattr(self, '_urls'):
del self._urls
def get_default_basename(self, viewset):
"""
If `basename` is not specified, attempt to automatically determine
it from the viewset.
"""
raise NotImplementedError('get_default_basename must be overridden')
def get_urls(self):
"""
Return a list of URL patterns, given the registered viewsets.
"""
raise NotImplementedError('get_urls must be overridden')
@property
def urls(self):
if not hasattr(self, '_urls'):
self._urls = self.get_urls()
return self._urls
class SimpleRouter(BaseRouter):
routes = [
# List route.
Route(
url=r'^{prefix}{trailing_slash}$',
mapping={
'get': 'list',
'post': 'create'
},
name='{basename}-list',
detail=False,
initkwargs={'suffix': 'List'}
),
# Dynamically generated list routes. Generated using
# @action(detail=False) decorator on methods of the viewset.
DynamicRoute(
url=r'^{prefix}/{url_path}{trailing_slash}$',
name='{basename}-{url_name}',
detail=False,
initkwargs={}
),
# Detail route.
Route(
url=r'^{prefix}/{lookup}{trailing_slash}$',
mapping={
'get': 'retrieve',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy'
},
name='{basename}-detail',
detail=True,
initkwargs={'suffix': 'Instance'}
),
# Dynamically generated detail routes. Generated using
# @action(detail=True) decorator on methods of the viewset.
DynamicRoute(
url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$',
name='{basename}-{url_name}',
detail=True,
initkwargs={}
),
]
def __init__(self, trailing_slash=True):
self.trailing_slash = '/' if trailing_slash else ''
super().__init__()
def get_default_basename(self, viewset):
"""
If `basename` is not specified, attempt to automatically determine
it from the viewset.
"""
queryset = getattr(viewset, 'queryset', None)
assert queryset is not None, '`basename` argument not specified, and could ' \
'not automatically determine the name from the viewset, as ' \
'it does not have a `.queryset` attribute.'
return queryset.model._meta.object_name.lower()
def get_routes(self, viewset):
"""
Augment `self.routes` with any dynamically generated routes.
Returns a list of the Route namedtuple.
"""
# converting to list as iterables are good for one pass, known host needs to be checked again and again for
# different functions.
known_actions = list(flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)]))
extra_actions = viewset.get_extra_actions()
# checking action names against the known actions list
not_allowed = [
action.__name__ for action in extra_actions
if action.__name__ in known_actions
]
if not_allowed:
msg = ('Cannot use the @action decorator on the following '
'methods, as they are existing routes: %s')
raise ImproperlyConfigured(msg % ', '.join(not_allowed))
# partition detail and list actions
detail_actions = [action for action in extra_actions if action.detail]
list_actions = [action for action in extra_actions if not action.detail]
routes = []
for route in self.routes:
if isinstance(route, DynamicRoute) and route.detail:
routes += [self._get_dynamic_route(route, action) for action in detail_actions]
elif isinstance(route, DynamicRoute) and not route.detail:
routes += [self._get_dynamic_route(route, action) for action in list_actions]
else:
routes.append(route)
return routes
def _get_dynamic_route(self, route, action):
initkwargs = route.initkwargs.copy()
initkwargs.update(action.kwargs)
url_path = escape_curly_brackets(action.url_path)
return Route(
url=route.url.replace('{url_path}', url_path),
mapping=action.mapping,
name=route.name.replace('{url_name}', action.url_name),
detail=route.detail,
initkwargs=initkwargs,
)
def get_method_map(self, viewset, method_map):
"""
Given a viewset, and a mapping of http methods to actions,
return a new mapping which only includes any mappings that
are actually implemented by the viewset.
"""
bound_methods = {}
for method, action in method_map.items():
if hasattr(viewset, action):
bound_methods[method] = action
return bound_methods
def get_lookup_regex(self, viewset, lookup_prefix=''):
"""
Given a viewset, return the portion of URL regex that is used
to match against a single instance.
Note that lookup_prefix is not used directly inside REST rest_framework
itself, but is required in order to nicely support nested router
implementations, such as drf-nested-routers.
https://github.com/alanjds/drf-nested-routers
"""
base_regex = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})'
# Use `pk` as default field, unset set. Default regex should not
# consume `.json` style suffixes and should break at '/' boundaries.
lookup_field = getattr(viewset, 'lookup_field', 'pk')
lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+')
return base_regex.format(
lookup_prefix=lookup_prefix,
lookup_url_kwarg=lookup_url_kwarg,
lookup_value=lookup_value
)
def get_urls(self):
"""
Use the registered viewsets to generate a list of URL patterns.
"""
ret = []
for prefix, viewset, basename in self.registry:
lookup = self.get_lookup_regex(viewset)
routes = self.get_routes(viewset)
for route in routes:
# Only actions which actually exist on the viewset will be bound
mapping = self.get_method_map(viewset, route.mapping)
if not mapping:
continue
# Build the url pattern
regex = route.url.format(
prefix=prefix,
lookup=lookup,
trailing_slash=self.trailing_slash
)
# If there is no prefix, the first part of the url is probably
# controlled by project's urls.py and the router is in an app,
# so a slash in the beginning will (A) cause Django to give
# warnings and (B) generate URLS that will require using '//'.
if not prefix and regex[:2] == '^/':
regex = '^' + regex[2:]
initkwargs = route.initkwargs.copy()
initkwargs.update({
'basename': basename,
'detail': route.detail,
})
view = viewset.as_view(mapping, **initkwargs)
name = route.name.format(basename=basename)
ret.append(re_path(regex, view, name=name))
return ret
class APIRootView(views.APIView):
"""
The default basic root view for DefaultRouter
"""
_ignore_model_permissions = True
schema = None # exclude from schema
api_root_dict = None
def get(self, request, *args, **kwargs):
# Return a plain {"name": "hyperlink"} response.
ret = OrderedDict()
namespace = request.resolver_match.namespace
for key, url_name in self.api_root_dict.items():
if namespace:
url_name = namespace + ':' + url_name
try:
ret[key] = reverse(
url_name,
args=args,
kwargs=kwargs,
request=request,
format=kwargs.get('format')
)
except NoReverseMatch:
# Don't bail out if eg. no list routes exist, only detail routes.
continue
return Response(ret)
class DefaultRouter(SimpleRouter):
"""
The default router extends the SimpleRouter, but also adds in a default
API root view, and adds format suffix patterns to the URLs.
"""
include_root_view = True
include_format_suffixes = True
root_view_name = 'api-root'
default_schema_renderers = None
APIRootView = APIRootView
APISchemaView = SchemaView
SchemaGenerator = SchemaGenerator
def __init__(self, *args, **kwargs):
if 'root_renderers' in kwargs:
self.root_renderers = kwargs.pop('root_renderers')
else:
self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
super().__init__(*args, **kwargs)
def get_api_root_view(self, api_urls=None):
"""
Return a basic root view.
"""
api_root_dict = OrderedDict()
list_name = self.routes[0].name
for prefix, viewset, basename in self.registry:
api_root_dict[prefix] = list_name.format(basename=basename)
return self.APIRootView.as_view(api_root_dict=api_root_dict)
def get_urls(self):
"""
Generate the list of URL patterns, including a default root view
for the API, and appending `.json` style format suffixes.
"""
urls = super().get_urls()
if self.include_root_view:
view = self.get_api_root_view(api_urls=urls)
root_url = re_path(r'^$', view, name=self.root_view_name)
urls.append(root_url)
if self.include_format_suffixes:
urls = format_suffix_patterns(urls)
return urls

View File

@ -0,0 +1,58 @@
"""
rest_framework.schemas
schemas:
__init__.py
generators.py # Top-down schema generation
inspectors.py # Per-endpoint view introspection
utils.py # Shared helper functions
views.py # Houses `SchemaView`, `APIView` subclass.
We expose a minimal "public" API directly from `schemas`. This covers the
basic use-cases:
from rest_framework.schemas import (
AutoSchema,
ManualSchema,
get_schema_view,
SchemaGenerator,
)
Other access should target the submodules directly
"""
from rest_framework.settings import api_settings
from . import coreapi, openapi
from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa
from .inspectors import DefaultSchema # noqa
def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
public=False, patterns=None, generator_class=None,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
version=None):
"""
Return a schema view.
"""
if generator_class is None:
if coreapi.is_enabled():
generator_class = coreapi.SchemaGenerator
else:
generator_class = openapi.SchemaGenerator
generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns, version=version
)
# Avoid import cycle on APIView
from .views import SchemaView
return SchemaView.as_view(
renderer_classes=renderer_classes,
schema_generator=generator,
public=public,
authentication_classes=authentication_classes,
permission_classes=permission_classes,
)

View File

@ -0,0 +1,616 @@
import warnings
from collections import Counter, OrderedDict
from urllib import parse
from django.db import models
from django.utils.encoding import force_str
from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths]
s1 = min(split_paths)
s2 = max(split_paths)
common = s1
for i, c in enumerate(s1):
if c != s2[i]:
common = s1[:i]
break
return '/' + '/'.join(common)
def is_custom_action(action):
return action not in {
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
}
def distribute_links(obj):
for key, value in obj.items():
distribute_links(value)
for preferred_key, link in obj.links:
key = obj.get_available_key(preferred_key)
obj[key] = link
INSERT_INTO_COLLISION_FMT = """
Schema Naming Collision.
coreapi.Link for URL path {value_url} cannot be inserted into schema.
Position conflicts with coreapi.Link for URL path {target_url}.
Attempted to insert link with keys: {keys}.
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
to customise schema structure.
"""
class LinkNode(OrderedDict):
def __init__(self):
self.links = []
self.methods_counter = Counter()
super().__init__()
def get_available_key(self, preferred_key):
if preferred_key not in self:
return preferred_key
while True:
current_val = self.methods_counter[preferred_key]
self.methods_counter[preferred_key] += 1
key = '{}_{}'.format(preferred_key, current_val)
if key not in self:
return key
def insert_into(target, keys, value):
"""
Nested dictionary insertion.
>>> example = {}
>>> insert_into(example, ['a', 'b', 'c'], 123)
>>> example
LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
"""
for key in keys[:-1]:
if key not in target:
target[key] = LinkNode()
target = target[key]
try:
target.links.append((keys[-1], value))
except TypeError:
msg = INSERT_INTO_COLLISION_FMT.format(
value_url=value.url,
target_url=target.url,
keys=keys
)
raise ValueError(msg)
class SchemaGenerator(BaseSchemaGenerator):
"""
Original CoreAPI version.
"""
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
super().__init__(title, url, description, patterns, urlconf)
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
def get_links(self, request=None):
"""
Return a dictionary containing all the links that should be
included in the API schema.
"""
links = LinkNode()
paths, view_endpoints = self._get_paths_and_endpoints(request)
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
return links
def get_schema(self, request=None, public=False):
"""
Generate a `coreapi.Document` representing the API schema.
"""
self._initialise_endpoints()
links = self.get_links(None if public else request)
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
distribute_links(links)
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
# Method for generating the link layout....
def get_keys(self, subpath, method, view):
"""
Return a list of keys that should be used to layout a link within
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
"""
if hasattr(view, 'action'):
# Viewsets have explicitly named actions.
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view):
action = 'list'
else:
action = self.default_mapping[method.lower()]
named_path_components = [
component for component
in subpath.strip('/').split('/')
if '{' not in component
]
if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
mapped_methods = {
# Don't count head mapping, e.g. not part of the schema
method for method in view.action_map if method != 'head'
}
if len(mapped_methods) > 1:
action = self.default_mapping[method.lower()]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
return named_path_components + [action]
else:
return named_path_components[:-1] + [action]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]
def determine_path_prefix(self, paths):
"""
Given a list of all paths, return the common prefix which should be
discounted when generating a schema structure.
This will be the longest common string that does not include that last
component of the URL, or the last component before a path parameter.
For example:
/api/v1/users/
/api/v1/users/{pk}/
The path prefix is '/api/v1'
"""
prefixes = []
for path in paths:
components = path.strip('/').split('/')
initial_components = []
for component in components:
if '{' in component:
break
initial_components.append(component)
prefix = '/'.join(initial_components[:-1])
if not prefix:
# We can just break early in the case that there's at least
# one URL that doesn't have a path prefix.
return '/'
prefixes.append('/' + prefix + '/')
return common_path(prefixes)
# View Inspectors #
def field_to_schema(field):
title = force_str(field.label) if field.label else ''
description = force_str(field.help_text) if field.help_text else ''
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
)
elif isinstance(field, serializers.DictField):
return coreschema.Object(
title=title,
description=description
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
title=title,
description=description
)
elif isinstance(field, serializers.ManyRelatedField):
related_field_schema = field_to_schema(field.child_relation)
return coreschema.Array(
items=related_field_schema,
title=title,
description=description
)
elif isinstance(field, serializers.PrimaryKeyRelatedField):
schema_cls = coreschema.String
model = getattr(field.queryset, 'model', None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
return schema_cls(title=title, description=description)
elif isinstance(field, serializers.RelatedField):
return coreschema.String(title=title, description=description)
elif isinstance(field, serializers.MultipleChoiceField):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices)),
title=title,
description=description
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices),
title=title,
description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
return coreschema.Number(title=title, description=description)
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
elif isinstance(field, serializers.DateField):
return coreschema.String(
title=title,
description=description,
format='date'
)
elif isinstance(field, serializers.DateTimeField):
return coreschema.String(
title=title,
description=description,
format='date-time'
)
elif isinstance(field, serializers.JSONField):
return coreschema.Object(title=title, description=description)
if field.style.get('base_template') == 'textarea.html':
return coreschema.String(
title=title,
description=description,
format='textarea'
)
return coreschema.String(title=title, description=description)
class AutoSchema(ViewInspector):
"""
Default inspector for APIView
Responsible for per-view introspection and schema generation.
"""
def __init__(self, manual_fields=None):
"""
Parameters:
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""
super().__init__()
if manual_fields is None:
manual_fields = []
self._manual_fields = manual_fields
def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method)
fields += self.get_pagination_fields(path, method)
fields += self.get_filter_fields(path, method)
manual_fields = self.get_manual_fields(path, method)
fields = self.update_fields(fields, manual_fields)
if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method)
else:
encoding = None
description = self.get_description(path, method)
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=encoding,
fields=fields,
description=description
)
def get_path_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
templated path variables.
"""
view = self.view
model = getattr(getattr(view, 'queryset', None), 'model', None)
fields = []
for variable in uritemplate.variables(path):
title = ''
description = ''
schema_cls = coreschema.String
kwargs = {}
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.verbose_name:
title = force_str(model_field.verbose_name)
if model_field is not None and model_field.help_text:
description = force_str(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
kwargs['pattern'] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
location='path',
required=True,
schema=schema_cls(title=title, description=description, **kwargs)
)
fields.append(field)
return fields
def get_serializer_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
def get_pagination_fields(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_fields(view)
def _allows_filters(self, path, method):
"""
Determine whether to include filter Fields in schema.
Default implementation looks for ModelViewSet or GenericAPIView
actions/methods that cause filtering on the default implementation.
Override to adjust behaviour for your view.
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
to allow changes based on user experience.
"""
if getattr(self.view, 'filter_backends', None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"]
def get_filter_fields(self, path, method):
if not self._allows_filters(path, method):
return []
fields = []
for filter_backend in self.view.filter_backends:
fields += filter_backend().get_schema_fields(self.view)
return fields
def get_manual_fields(self, path, method):
return self._manual_fields
@staticmethod
def update_fields(fields, update_with):
"""
Update list of coreapi.Field instances, overwriting on `Field.name`.
Utility function to handle replacing coreapi.Field fields
from a list by name. Used to handle `manual_fields`.
Parameters:
* `fields`: list of `coreapi.Field` instances to update
* `update_with: list of `coreapi.Field` instances to add or replace.
"""
if not update_with:
return fields
by_name = OrderedDict((f.name, f) for f in fields)
for f in update_with:
by_name[f.name] = f
fields = list(by_name.values())
return fields
def get_encoding(self, path, method):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
view = self.view
# Core API supports the following request encodings over HTTP...
supported_media_types = {
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
}
parser_classes = getattr(view, 'parser_classes', [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
return None
class ManualSchema(ViewInspector):
"""
Allows providing a list of coreapi.Fields,
plus an optional description.
"""
def __init__(self, fields, description='', encoding=None):
"""
Parameters:
* `fields`: list of `coreapi.Field` instances.
* `description`: String description for view. Optional.
"""
super().__init__()
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
self._encoding = encoding
def get_link(self, path, method, base_url):
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=self._encoding,
fields=self._fields,
description=self._description
)
def is_enabled():
"""Is CoreAPI Mode enabled?"""
return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema)

View File

@ -0,0 +1,239 @@
"""
generators.py # Top-down schema generation
See schemas.__init__.py for package overview.
"""
import re
from importlib import import_module
from django.conf import settings
from django.contrib.admindocs.views import simplify_regex
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.urls import URLPattern, URLResolver
from rest_framework import exceptions
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
from rest_framework.utils.model_meta import _get_pk
def get_pk_name(model):
meta = model._meta.concrete_model._meta
return _get_pk(meta).name
def is_api_view(callback):
"""
Return `True` if the given view callback is a REST framework view/viewset.
"""
# Avoid import cycle on APIView
from rest_framework.views import APIView
cls = getattr(callback, 'cls', None)
return (cls is not None) and issubclass(cls, APIView)
def endpoint_ordering(endpoint):
path, method, callback = endpoint
method_priority = {
'GET': 0,
'POST': 1,
'PUT': 2,
'PATCH': 3,
'DELETE': 4
}.get(method, 5)
return (method_priority,)
_PATH_PARAMETER_COMPONENT_RE = re.compile(
r'<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>'
)
class EndpointEnumerator:
"""
A class to determine the available API endpoints that a project exposes.
"""
def __init__(self, patterns=None, urlconf=None):
if patterns is None:
if urlconf is None:
# Use the default Django URL conf
urlconf = settings.ROOT_URLCONF
# Load the given URLconf module
if isinstance(urlconf, str):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
self.patterns = patterns
def get_api_endpoints(self, patterns=None, prefix=''):
"""
Return a list of all available API endpoints by inspecting the URL conf.
"""
if patterns is None:
patterns = self.patterns
api_endpoints = []
for pattern in patterns:
path_regex = prefix + str(pattern.pattern)
if isinstance(pattern, URLPattern):
path = self.get_path_from_regex(path_regex)
callback = pattern.callback
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
endpoint = (path, method, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, URLResolver):
nested_endpoints = self.get_api_endpoints(
patterns=pattern.url_patterns,
prefix=path_regex
)
api_endpoints.extend(nested_endpoints)
return sorted(api_endpoints, key=endpoint_ordering)
def get_path_from_regex(self, path_regex):
"""
Given a URL conf regex, return a URI template string.
"""
# ???: Would it be feasible to adjust this such that we generate the
# path, plus the kwargs, plus the type from the convertor, such that we
# could feed that straight into the parameter schema object?
path = simplify_regex(path_regex)
# Strip Django 2.0 convertors as they are incompatible with uritemplate format
return re.sub(_PATH_PARAMETER_COMPONENT_RE, r'{\g<parameter>}', path)
def should_include_endpoint(self, path, callback):
"""
Return `True` if the given endpoint should be included.
"""
if not is_api_view(callback):
return False # Ignore anything except REST framework views.
if callback.cls.schema is None:
return False
if 'schema' in callback.initkwargs:
if callback.initkwargs['schema'] is None:
return False
if path.endswith('.{format}') or path.endswith('.{format}/'):
return False # Ignore .json style URLs.
return True
def get_allowed_methods(self, callback):
"""
Return a list of the valid HTTP methods for this endpoint.
"""
if hasattr(callback, 'actions'):
actions = set(callback.actions)
http_method_names = set(callback.cls.http_method_names)
methods = [method.upper() for method in actions & http_method_names]
else:
methods = callback.cls().allowed_methods
return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
class BaseSchemaGenerator:
endpoint_inspector_cls = EndpointEnumerator
# 'pk' isn't great as an externally exposed name for an identifier,
# so by default we prefer to use the actual model field name for schemas.
# Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
if url and not url.endswith('/'):
url += '/'
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
self.patterns = patterns
self.urlconf = urlconf
self.title = title
self.description = description
self.version = version
self.url = url
self.endpoints = None
def _initialise_endpoints(self):
if self.endpoints is None:
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
self.endpoints = inspector.get_api_endpoints()
def _get_paths_and_endpoints(self, request):
"""
Generate (path, method, view) given (path, method, callback) for paths.
"""
paths = []
view_endpoints = []
for path, method, callback in self.endpoints:
view = self.create_view(callback, method, request)
path = self.coerce_path(path, method, view)
paths.append(path)
view_endpoints.append((path, method, view))
return paths, view_endpoints
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls(**getattr(callback, 'initkwargs', {}))
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
if request is not None:
view.request = clone_request(request, method)
return view
def coerce_path(self, path, method, view):
"""
Coerce {pk} path arguments into the name of the model field,
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
if not self.coerce_path_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
def get_schema(self, request=None, public=False):
raise NotImplementedError(".get_schema() must be implemented in subclasses.")
def has_view_permissions(self, path, method, view):
"""
Return `True` if the incoming request has the correct view permissions.
"""
if view.request is None:
return True
try:
view.check_permissions(view.request)
except (exceptions.APIException, Http404, PermissionDenied):
return False
return True

View File

@ -0,0 +1,125 @@
"""
inspectors.py # Per-endpoint view introspection
See schemas.__init__.py for package overview.
"""
import re
from weakref import WeakKeyDictionary
from django.utils.encoding import smart_str
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
class ViewInspector:
"""
Descriptor class on APIView.
Provide subclass for per-view schema generation
"""
# Used in _get_description_section()
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
def __init__(self):
self.instance_schemas = WeakKeyDictionary()
def __get__(self, instance, owner):
"""
Enables `ViewInspector` as a Python _Descriptor_.
This is how `view.schema` knows about `view`.
`__get__` is called when the descriptor is accessed on the owner.
(That will be when view.schema is called in our case.)
`owner` is always the owner class. (An APIView, or subclass for us.)
`instance` is the view instance or `None` if accessed from the class,
rather than an instance.
See: https://docs.python.org/3/howto/descriptor.html for info on
descriptor usage.
"""
if instance in self.instance_schemas:
return self.instance_schemas[instance]
self.view = instance
return self
def __set__(self, instance, other):
self.instance_schemas[instance] = other
if other is not None:
other.view = instance
@property
def view(self):
"""View property."""
assert self._view is not None, (
"Schema generation REQUIRES a view instance. (Hint: you accessed "
"`schema` from the view class rather than an instance.)"
)
return self._view
@view.setter
def view(self, value):
self._view = value
@view.deleter
def view(self):
self._view = None
def get_description(self, path, method):
"""
Determine a path description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_str(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()),
view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if self.header_regex.match(line):
current_section, separator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""
def __get__(self, instance, owner):
result = super().__get__(instance, owner)
if not isinstance(result, DefaultSchema):
return result
inspector_class = api_settings.DEFAULT_SCHEMA_CLASS
assert issubclass(inspector_class, ViewInspector), (
"DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
)
inspector = inspector_class()
inspector.view = instance
return inspector

View File

@ -0,0 +1,722 @@
import re
import warnings
from collections import OrderedDict
from decimal import Decimal
from operator import attrgetter
from urllib.parse import urljoin
from django.core.validators import (
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
)
from django.db import models
from django.utils.encoding import force_str
from rest_framework import (
RemovedInDRF315Warning, exceptions, renderers, serializers
)
from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.settings import api_settings
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
class SchemaGenerator(BaseSchemaGenerator):
def get_info(self):
# Title and version are required by openapi specification 3.x
info = {
'title': self.title or '',
'version': self.version or ''
}
if self.description is not None:
info['description'] = self.description
return info
def check_duplicate_operation_id(self, paths):
ids = {}
for route in paths:
for method in paths[route]:
if 'operationId' not in paths[route][method]:
continue
operation_id = paths[route][method]['operationId']
if operation_id in ids:
warnings.warn(
'You have a duplicated operationId in your OpenAPI schema: {operation_id}\n'
'\tRoute: {route1}, Method: {method1}\n'
'\tRoute: {route2}, Method: {method2}\n'
'\tAn operationId has to be unique across your schema. Your schema may not work in other tools.'
.format(
route1=ids[operation_id]['route'],
method1=ids[operation_id]['method'],
route2=route,
method2=method,
operation_id=operation_id
)
)
ids[operation_id] = {
'route': route,
'method': method
}
def get_schema(self, request=None, public=False):
"""
Generate a OpenAPI schema.
"""
self._initialise_endpoints()
components_schemas = {}
# Iterate endpoints generating per method path operations.
paths = {}
_, view_endpoints = self._get_paths_and_endpoints(None if public else request)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
operation = view.schema.get_operation(path, method)
components = view.schema.get_components(path, method)
for k in components.keys():
if k not in components_schemas:
continue
if components_schemas[k] == components[k]:
continue
warnings.warn('Schema component "{}" has been overriden with a different value.'.format(k))
components_schemas.update(components)
# Normalise path for any provided mount url.
if path.startswith('/'):
path = path[1:]
path = urljoin(self.url or '/', path)
paths.setdefault(path, {})
paths[path][method.lower()] = operation
self.check_duplicate_operation_id(paths)
# Compile final schema.
schema = {
'openapi': '3.0.2',
'info': self.get_info(),
'paths': paths,
}
if len(components_schemas) > 0:
schema['components'] = {
'schemas': components_schemas
}
return schema
# View Inspectors
class AutoSchema(ViewInspector):
def __init__(self, tags=None, operation_id_base=None, component_name=None):
"""
:param operation_id_base: user-defined name in operationId. If empty, it will be deducted from the Model/Serializer/View name.
:param component_name: user-defined component's name. If empty, it will be deducted from the Serializer's class name.
"""
if tags and not all(isinstance(tag, str) for tag in tags):
raise ValueError('tags must be a list or tuple of string.')
self._tags = tags
self.operation_id_base = operation_id_base
self.component_name = component_name
super().__init__()
request_media_types = []
response_media_types = []
method_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partialUpdate',
'delete': 'destroy',
}
def get_operation(self, path, method):
operation = {}
operation['operationId'] = self.get_operation_id(path, method)
operation['description'] = self.get_description(path, method)
parameters = []
parameters += self.get_path_parameters(path, method)
parameters += self.get_pagination_parameters(path, method)
parameters += self.get_filter_parameters(path, method)
operation['parameters'] = parameters
request_body = self.get_request_body(path, method)
if request_body:
operation['requestBody'] = request_body
operation['responses'] = self.get_responses(path, method)
operation['tags'] = self.get_tags(path, method)
return operation
def get_component_name(self, serializer):
"""
Compute the component's name from the serializer.
Raise an exception if the serializer's class name is "Serializer" (case-insensitive).
"""
if self.component_name is not None:
return self.component_name
# use the serializer's class name as the component name.
component_name = serializer.__class__.__name__
# We remove the "serializer" string from the class name.
pattern = re.compile("serializer", re.IGNORECASE)
component_name = pattern.sub("", component_name)
if component_name == "":
raise Exception(
'"{}" is an invalid class name for schema generation. '
'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"'
.format(serializer.__class__.__name__)
)
return component_name
def get_components(self, path, method):
"""
Return components with their properties from the serializer.
"""
if method.lower() == 'delete':
return {}
request_serializer = self.get_request_serializer(path, method)
response_serializer = self.get_response_serializer(path, method)
components = {}
if isinstance(request_serializer, serializers.Serializer):
component_name = self.get_component_name(request_serializer)
content = self.map_serializer(request_serializer)
components.setdefault(component_name, content)
if isinstance(response_serializer, serializers.Serializer):
component_name = self.get_component_name(response_serializer)
content = self.map_serializer(response_serializer)
components.setdefault(component_name, content)
return components
def _to_camel_case(self, snake_str):
components = snake_str.split('_')
# We capitalize the first letter of each component except the first one
# with the 'title' method and join them together.
return components[0] + ''.join(x.title() for x in components[1:])
def get_operation_id_base(self, path, method, action):
"""
Compute the base part for operation ID from the model, serializer or view name.
"""
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
if self.operation_id_base is not None:
name = self.operation_id_base
# Try to deduce the ID from the view's model
elif model is not None:
name = model.__name__
# Try with the serializer class name
elif self.get_serializer(path, method) is not None:
name = self.get_serializer(path, method).__class__.__name__
if name.endswith('Serializer'):
name = name[:-10]
# Fallback to the view name
else:
name = self.view.__class__.__name__
if name.endswith('APIView'):
name = name[:-7]
elif name.endswith('View'):
name = name[:-4]
# Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly
# comes at the end of the name
if name.endswith(action.title()): # ListView, UpdateAPIView, ThingDelete ...
name = name[:-len(action)]
if action == 'list' and not name.endswith('s'): # listThings instead of listThing
name += 's'
return name
def get_operation_id(self, path, method):
"""
Compute an operation ID from the view type and get_operation_id_base method.
"""
method_name = getattr(self.view, 'action', method.lower())
if is_list_view(path, method, self.view):
action = 'list'
elif method_name not in self.method_mapping:
action = self._to_camel_case(method_name)
else:
action = self.method_mapping[method.lower()]
name = self.get_operation_id_base(path, method, action)
return action + name
def get_path_parameters(self, path, method):
"""
Return a list of parameters from templated path variables.
"""
assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
parameters = []
for variable in uritemplate.variables(path):
description = ''
if model is not None: # TODO: test this.
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.help_text:
description = force_str(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
parameter = {
"name": variable,
"in": "path",
"required": True,
"description": description,
'schema': {
'type': 'string', # TODO: integer, pattern, ...
},
}
parameters.append(parameter)
return parameters
def get_filter_parameters(self, path, method):
if not self.allows_filters(path, method):
return []
parameters = []
for filter_backend in self.view.filter_backends:
parameters += filter_backend().get_schema_operation_parameters(self.view)
return parameters
def allows_filters(self, path, method):
"""
Determine whether to include filter Fields in schema.
Default implementation looks for ModelViewSet or GenericAPIView
actions/methods that cause filtering on the default implementation.
"""
if getattr(self.view, 'filter_backends', None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"]
def get_pagination_parameters(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
paginator = self.get_paginator()
if not paginator:
return []
return paginator.get_schema_operation_parameters(view)
def map_choicefield(self, field):
choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates
if all(isinstance(choice, bool) for choice in choices):
type = 'boolean'
elif all(isinstance(choice, int) for choice in choices):
type = 'integer'
elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer`
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
type = 'number'
elif all(isinstance(choice, str) for choice in choices):
type = 'string'
else:
type = None
mapping = {
# The value of `enum` keyword MUST be an array and SHOULD be unique.
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20
'enum': choices
}
# If We figured out `type` then and only then we should set it. It must be a string.
# Ref: https://swagger.io/docs/specification/data-models/data-types/#mixed-type
# It is optional but it can not be null.
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
if type:
mapping['type'] = type
return mapping
def map_field(self, field):
# Nested Serializers, `many` or not.
if isinstance(field, serializers.ListSerializer):
return {
'type': 'array',
'items': self.map_serializer(field.child)
}
if isinstance(field, serializers.Serializer):
data = self.map_serializer(field)
data['type'] = 'object'
return data
# Related fields.
if isinstance(field, serializers.ManyRelatedField):
return {
'type': 'array',
'items': self.map_field(field.child_relation)
}
if isinstance(field, serializers.PrimaryKeyRelatedField):
model = getattr(field.queryset, 'model', None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
return {'type': 'integer'}
# ChoiceFields (single and multiple).
# Q:
# - Is 'type' required?
# - can we determine the TYPE of a choicefield?
if isinstance(field, serializers.MultipleChoiceField):
return {
'type': 'array',
'items': self.map_choicefield(field)
}
if isinstance(field, serializers.ChoiceField):
return self.map_choicefield(field)
# ListField.
if isinstance(field, serializers.ListField):
mapping = {
'type': 'array',
'items': {},
}
if not isinstance(field.child, _UnvalidatedField):
mapping['items'] = self.map_field(field.child)
return mapping
# DateField and DateTimeField type is string
if isinstance(field, serializers.DateField):
return {
'type': 'string',
'format': 'date',
}
if isinstance(field, serializers.DateTimeField):
return {
'type': 'string',
'format': 'date-time',
}
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
# see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
# see also: https://swagger.io/docs/specification/data-models/data-types/#string
if isinstance(field, serializers.EmailField):
return {
'type': 'string',
'format': 'email'
}
if isinstance(field, serializers.URLField):
return {
'type': 'string',
'format': 'uri'
}
if isinstance(field, serializers.UUIDField):
return {
'type': 'string',
'format': 'uuid'
}
if isinstance(field, serializers.IPAddressField):
content = {
'type': 'string',
}
if field.protocol != 'both':
content['format'] = field.protocol
return content
if isinstance(field, serializers.DecimalField):
if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
content = {
'type': 'string',
'format': 'decimal',
}
else:
content = {
'type': 'number'
}
if field.decimal_places:
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
if field.max_whole_digits:
content['maximum'] = int(field.max_whole_digits * '9') + 1
content['minimum'] = -content['maximum']
self._map_min_max(field, content)
return content
if isinstance(field, serializers.FloatField):
content = {
'type': 'number',
}
self._map_min_max(field, content)
return content
if isinstance(field, serializers.IntegerField):
content = {
'type': 'integer'
}
self._map_min_max(field, content)
# 2147483647 is max for int32_size, so we use int64 for format
if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
content['format'] = 'int64'
return content
if isinstance(field, serializers.FileField):
return {
'type': 'string',
'format': 'binary'
}
# Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = {
serializers.BooleanField: 'boolean',
serializers.JSONField: 'object',
serializers.DictField: 'object',
serializers.HStoreField: 'object',
}
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
def _map_min_max(self, field, content):
if field.max_value:
content['maximum'] = field.max_value
if field.min_value:
content['minimum'] = field.min_value
def map_serializer(self, serializer):
# Assuming we have a valid serializer instance.
required = []
properties = {}
for field in serializer.fields.values():
if isinstance(field, serializers.HiddenField):
continue
if field.required:
required.append(field.field_name)
schema = self.map_field(field)
if field.read_only:
schema['readOnly'] = True
if field.write_only:
schema['writeOnly'] = True
if field.allow_null:
schema['nullable'] = True
if field.default is not None and field.default != empty and not callable(field.default):
schema['default'] = field.default
if field.help_text:
schema['description'] = str(field.help_text)
self.map_field_validators(field, schema)
properties[field.field_name] = schema
result = {
'type': 'object',
'properties': properties
}
if required:
result['required'] = required
return result
def map_field_validators(self, field, schema):
"""
map field validators
"""
for v in field.validators:
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
if isinstance(v, EmailValidator):
schema['format'] = 'email'
if isinstance(v, URLValidator):
schema['format'] = 'uri'
if isinstance(v, RegexValidator):
# In Python, the token \Z does what \z does in other engines.
# https://stackoverflow.com/questions/53283160
schema['pattern'] = v.regex.pattern.replace('\\Z', '\\z')
elif isinstance(v, MaxLengthValidator):
attr_name = 'maxLength'
if isinstance(field, serializers.ListField):
attr_name = 'maxItems'
schema[attr_name] = v.limit_value
elif isinstance(v, MinLengthValidator):
attr_name = 'minLength'
if isinstance(field, serializers.ListField):
attr_name = 'minItems'
schema[attr_name] = v.limit_value
elif isinstance(v, MaxValueValidator):
schema['maximum'] = v.limit_value
elif isinstance(v, MinValueValidator):
schema['minimum'] = v.limit_value
elif isinstance(v, DecimalValidator) and \
not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
if v.decimal_places:
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
if v.max_digits:
digits = v.max_digits
if v.decimal_places is not None and v.decimal_places > 0:
digits -= v.decimal_places
schema['maximum'] = int(digits * '9') + 1
schema['minimum'] = -schema['maximum']
def get_paginator(self):
pagination_class = getattr(self.view, 'pagination_class', None)
if pagination_class:
return pagination_class()
return None
def map_parsers(self, path, method):
return list(map(attrgetter('media_type'), self.view.parser_classes))
def map_renderers(self, path, method):
media_types = []
for renderer in self.view.renderer_classes:
# BrowsableAPIRenderer not relevant to OpenAPI spec
if issubclass(renderer, renderers.BrowsableAPIRenderer):
continue
media_types.append(renderer.media_type)
return media_types
def get_serializer(self, path, method):
view = self.view
if not hasattr(view, 'get_serializer'):
return None
try:
return view.get_serializer()
except exceptions.APIException:
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
return None
def get_request_serializer(self, path, method):
"""
Override this method if your view uses a different serializer for
handling request body.
"""
return self.get_serializer(path, method)
def get_response_serializer(self, path, method):
"""
Override this method if your view uses a different serializer for
populating response data.
"""
return self.get_serializer(path, method)
def get_reference(self, serializer):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}
def get_request_body(self, path, method):
if method not in ('PUT', 'PATCH', 'POST'):
return {}
self.request_media_types = self.map_parsers(path, method)
serializer = self.get_request_serializer(path, method)
if not isinstance(serializer, serializers.Serializer):
item_schema = {}
else:
item_schema = self.get_reference(serializer)
return {
'content': {
ct: {'schema': item_schema}
for ct in self.request_media_types
}
}
def get_responses(self, path, method):
if method == 'DELETE':
return {
'204': {
'description': ''
}
}
self.response_media_types = self.map_renderers(path, method)
serializer = self.get_response_serializer(path, method)
if not isinstance(serializer, serializers.Serializer):
item_schema = {}
else:
item_schema = self.get_reference(serializer)
if is_list_view(path, method, self.view):
response_schema = {
'type': 'array',
'items': item_schema,
}
paginator = self.get_paginator()
if paginator:
response_schema = paginator.get_paginated_response_schema(response_schema)
else:
response_schema = item_schema
status_code = '201' if method == 'POST' else '200'
return {
status_code: {
'content': {
ct: {'schema': response_schema}
for ct in self.response_media_types
},
# description is a mandatory property,
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
# TODO: put something meaningful into it
'description': ""
}
}
def get_tags(self, path, method):
# If user have specified tags, use them.
if self._tags:
return self._tags
# First element of a specific path could be valid tag. This is a fallback solution.
# PUT, PATCH, GET(Retrieve), DELETE: /user_profile/{id}/ tags = [user-profile]
# POST, GET(List): /user_profile/ tags = [user-profile]
if path.startswith('/'):
path = path[1:]
return [path.split('/')[0].replace('_', '-')]
def _get_reference(self, serializer):
warnings.warn(
"Method `_get_reference()` has been renamed to `get_reference()`. "
"The old name will be removed in DRF v3.15.",
RemovedInDRF315Warning, stacklevel=2
)
return self.get_reference(serializer)

View File

@ -0,0 +1,41 @@
"""
utils.py # Shared helper functions
See schemas.__init__.py for package overview.
"""
from django.db import models
from django.utils.translation import gettext_lazy as _
from rest_framework.mixins import RetrieveModelMixin
def is_list_view(path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
if hasattr(view, 'action'):
# Viewsets have an explicitly defined action, which we can inspect.
return view.action == 'list'
if method.lower() != 'get':
return False
if isinstance(view, RetrieveModelMixin):
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
return True
def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField):
value_type = _('unique integer value')
elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string')
else:
value_type = _('unique value')
return _('A {value_type} identifying this {name}.').format(
value_type=value_type,
name=model._meta.verbose_name,
)

Some files were not shown because too many files have changed in this diff Show More