Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 58 additions & 42 deletions rest_framework/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,6 @@ def is_api_view(callback):
return (cls is not None) and issubclass(cls, APIView)


def insert_into(target, keys, item):
"""
Insert `item` into the nested dictionary `target`.

For example:

target = {}
insert_into(target, ('users', 'list'), Link(...))
insert_into(target, ('users', 'detail'), Link(...))
assert target == {'users': {'list': Link(...), 'detail': Link(...)}}
"""
for key in keys[:1]:
if key not in target:
target[key] = {}
target = target[key]
target[keys[-1]] = item


class SchemaGenerator(object):
default_mapping = {
'get': 'read',
Expand Down Expand Up @@ -84,7 +66,7 @@ def get_schema(self, request=None):
self.endpoints = self.get_api_endpoints(self.patterns)

links = []
for key, path, method, callback in self.endpoints:
for path, method, category, action, callback in self.endpoints:
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
Expand All @@ -102,16 +84,21 @@ def get_schema(self, request=None):
view.request = None

link = self.get_link(path, method, callback, view)
links.append((key, link))
links.append((category, action, link))

if not link:
if not links:
return None

# Generate the schema content structure, from the endpoints.
# ('users', 'list'), Link -> {'users': {'list': Link()}}
# Generate the schema content structure, eg:
# {'users': {'list': Link()}}
content = {}
for key, link in links:
insert_into(content, key, link)
for category, action, link in links:
if category is None:
content[action] = link
elif category in content:
content[category][action] = link
else:
content[category] = {action: link}

# Return the schema document.
return coreapi.Document(title=self.title, content=content, url=self.url)
Expand All @@ -129,8 +116,8 @@ def get_api_endpoints(self, patterns, prefix=''):
callback = pattern.callback
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
key = self.get_key(path, method, callback)
endpoint = (key, path, method, callback)
action = self.get_action(path, method, callback)
endpoint = (path, method, action, callback)
api_endpoints.append(endpoint)

elif isinstance(pattern, RegexURLResolver):
Expand All @@ -140,7 +127,21 @@ def get_api_endpoints(self, patterns, prefix=''):
)
api_endpoints.extend(nested_endpoints)

return api_endpoints
return self.add_categories(api_endpoints)

def add_categories(self, api_endpoints):
"""
(path, method, action, callback) -> (path, method, category, action, callback)
"""
# Determine the top level categories for the schema content,
# based on the URLs of the endpoints. Eg `set(['users', 'organisations'])`
paths = [endpoint[0] for endpoint in api_endpoints]
categories = self.get_categories(paths)

return [
(path, method, self.get_category(categories, path), action, callback)
for (path, method, action, callback) in api_endpoints
]

def get_path(self, path_regex):
"""
Expand Down Expand Up @@ -177,23 +178,38 @@ def get_allowed_methods(self, callback):
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
]

def get_key(self, path, method, callback):
def get_action(self, path, method, callback):
"""
Return a tuple of strings, indicating the identity to use for a
given endpoint. eg. ('users', 'list').
Return a description action string for the endpoint, eg. 'list'.
"""
category = None
for item in path.strip('/').split('/'):
if '{' in item:
break
category = item

actions = getattr(callback, 'actions', self.default_mapping)
action = actions[method.lower()]

if category:
return (category, action)
return (action,)
return actions[method.lower()]

def get_categories(self, paths):
categories = set()
split_paths = set([
tuple(path.split("{")[0].strip('/').split('/'))
for path in paths
])

while split_paths:
for split_path in list(split_paths):
if len(split_path) == 0:
split_paths.remove(split_path)
elif len(split_path) == 1:
categories.add(split_path[0])
split_paths.remove(split_path)
elif split_path[0] in categories:
split_paths.remove(split_path)

return categories

def get_category(self, categories, path):
path_components = path.split("{")[0].strip('/').split('/')
for path_component in path_components:
if path_component in categories:
return path_component
return None

# Methods for generating each individual `Link` instance...

Expand Down
14 changes: 13 additions & 1 deletion tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from rest_framework import filters, pagination, permissions, serializers
from rest_framework.compat import coreapi
from rest_framework.decorators import detail_route
from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response
from rest_framework.routers import DefaultRouter
from rest_framework.schemas import SchemaGenerator
Expand Down Expand Up @@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet):
def custom_action(self, request, pk):
return super(ExampleSerializer, self).retrieve(self, request)

@list_route()
def custom_list_action(self, request):
return super(ExampleViewSet, self).list(self, request)

def get_serializer(self, *args, **kwargs):
assert self.request
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
Expand Down Expand Up @@ -88,6 +92,10 @@ def test_anonymous_request(self):
coreapi.Field('ordering', required=False, location='query')
]
),
'custom_list_action': coreapi.Link(
url='/example/custom_list_action/',
action='get'
),
'retrieve': coreapi.Link(
url='/example/{pk}/',
action='get',
Expand Down Expand Up @@ -144,6 +152,10 @@ def test_authenticated_request(self):
coreapi.Field('d', required=False, location='form'),
]
),
'custom_list_action': coreapi.Link(
url='/example/custom_list_action/',
action='get'
),
'update': coreapi.Link(
url='/example/{pk}/',
action='put',
Expand Down