Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions johnny/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

from hashlib import md5
from uuid import uuid4
from types import MethodType

import django
from django.db.models.signals import post_save, post_delete
from django.utils import six

from . import localstore, signals
from . import settings
Expand Down Expand Up @@ -264,6 +266,19 @@ def sql_key(self, generation, sql, params, order, result_type,
return '%s_%s_query_%s.%s' % (self.prefix, using, generation, suffix)


def _get_original(original, instance, *args, **kwargs):
"""
Return the value from the call to the original method.
"""
if original.im_class == instance.__class__:
return original(instance, *args, **kwargs)
else: # allow compiler proxies as well
if six.PY3:
return MethodType(original.__func__, instance)(*args, **kwargs)
else:
return MethodType(original.__func__, instance, instance.__class__)(*args, **kwargs)


# XXX: Thread safety concerns? Should we only need to patch once per process?
class QueryCacheBackend(object):
"""This class is the engine behind the query cache. It reads the queries
Expand Down Expand Up @@ -311,7 +326,7 @@ def newfun(cls, *args, **kwargs):
result_type = kwargs.get('result_type', MULTI)

if any([isinstance(cls, c) for c in self._write_compilers]):
return original(cls, *args, **kwargs)
return _get_original(original, cls, *args, **kwargs)
try:
sql, params = cls.as_sql()
if not sql:
Expand Down Expand Up @@ -360,7 +375,7 @@ def newfun(cls, *args, **kwargs):
query=(sql, params, ordering_aliases),
key=key)

val = original(cls, *args, **kwargs)
val = _get_original(original, cls, *args, **kwargs)

if hasattr(val, '__iter__'):
#Can't permanently cache lazy iterables without creating
Expand All @@ -383,7 +398,7 @@ def newfun(cls, *args, **kwargs):
from django.db.models.sql import compiler
# we have to do this before we check the tables, since the tables
# are actually being set in the original function
ret = original(cls, *args, **kwargs)
ret = _get_original(original, cls, *args, **kwargs)

if isinstance(cls, compiler.SQLInsertCompiler):
#Inserts are a special case where cls.tables
Expand Down