Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion reversion/revisions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from asgiref.sync import sync_to_async
from contextvars import ContextVar
from collections import namedtuple, defaultdict
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from functools import wraps
from django.apps import apps
from django.core import serializers
Expand Down Expand Up @@ -279,6 +280,25 @@ def _dummy_context():
yield


@asynccontextmanager
async def _acreate_revision_context(manage_manually, using):
_push_frame(manage_manually, using)
try:
yield
# Only save for a db if that's the last stack frame for that db.
if not any(using in frame.db_versions for frame in _stack.get()[:-1]):
current_frame = _current_frame()
await sync_to_async(_save_revision)(
versions=current_frame.db_versions[using].values(),
user=current_frame.user,
comment=current_frame.comment,
meta=current_frame.meta,
date_created=current_frame.date_created,
using=using,
)
finally:
_pop_frame()

@contextmanager
def _create_revision_context(manage_manually, using, atomic):
context = transaction.atomic(using=using) if atomic else _dummy_context()
Expand Down Expand Up @@ -307,6 +327,11 @@ def _create_revision_context(manage_manually, using, atomic):
finally:
_pop_frame()

def acreate_revision(manage_manually=False, using=None):
from reversion.models import Revision
using = using or router.db_for_write(Revision)
return _ContextWrapper(_acreate_revision_context, (manage_manually, using))


def create_revision(manage_manually=False, using=None, atomic=True):
from reversion.models import Revision
Expand All @@ -321,6 +346,12 @@ def __init__(self, func, args):
self._args = args
self._context = func(*args)

async def __aenter__(self):
return await self._context.__aenter__()

async def __aexit__(self, exc_type, exc_value, traceback):
return await self._context.__aexit__(exc_type, exc_value, traceback)

def __enter__(self):
return self._context.__enter__()

Expand Down
68 changes: 57 additions & 11 deletions reversion/views.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,54 @@
from functools import wraps

from reversion.revisions import create_revision as create_revision_base, set_user, get_user
from reversion.revisions import (
create_revision as create_revision_base,
acreate_revision as acreate_revision_base,
set_user,
get_user,
)


def _request_creates_revision(request):
return request.method not in ("OPTIONS", "GET", "HEAD")


def _set_user_from_request(request):
if getattr(request, "user", None) and request.user.is_authenticated and get_user() is None:
if (
getattr(request, "user", None)
and request.user.is_authenticated
and get_user() is None
):
set_user(request.user)


def create_revision(manage_manually=False, using=None, atomic=True, request_creates_revision=None):
def acreate_revision(manage_manually=False, using=None, request_creates_revision=None):
"""
View decorator that wraps the request in a revision.

The revision will have it's user set from the request automatically.
"""
request_creates_revision = request_creates_revision or _request_creates_revision

def decorator(func):
@wraps(func)
async def do_revision_view(request, *args, **kwargs):
if request_creates_revision(request):
async with acreate_revision_base(
manage_manually=manage_manually, using=using
):
response = await func(request, *args, **kwargs)
_set_user_from_request(request)
return response
return await func(request, *args, **kwargs)

return do_revision_view

return decorator


def create_revision(
manage_manually=False, using=None, atomic=True, request_creates_revision=None
):
"""
View decorator that wraps the request in a revision.

Expand All @@ -24,17 +60,20 @@ def decorator(func):
@wraps(func)
def do_revision_view(request, *args, **kwargs):
if request_creates_revision(request):
with create_revision_base(manage_manually=manage_manually, using=using, atomic=atomic):
with create_revision_base(
manage_manually=manage_manually, using=using, atomic=atomic
):
response = func(request, *args, **kwargs)
_set_user_from_request(request)
return response
return func(request, *args, **kwargs)

return do_revision_view

return decorator


class RevisionMixin:

"""
A class-based view mixin that wraps the request in a revision.

Expand All @@ -49,12 +88,19 @@ class RevisionMixin:

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dispatch = create_revision(
manage_manually=self.revision_manage_manually,
using=self.revision_using,
atomic=self.revision_atomic,
request_creates_revision=self.revision_request_creates_revision
)(self.dispatch)
if getattr(self.__class__, "view_is_async", False):
self.dispatch = acreate_revision(
manage_manually=self.revision_manage_manually,
using=self.revision_using,
request_creates_revision=self.revision_request_creates_revision,
)(self.dispatch)
else:
self.dispatch = create_revision(
manage_manually=self.revision_manage_manually,
using=self.revision_using,
atomic=revision_atomic,
request_creates_revision=self.revision_request_creates_revision,
)(self.dispatch)

def revision_request_creates_revision(self, request):
return _request_creates_revision(request)