# -*- test-case-name: twisted.test.test_persisted -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Different styles of persisted objects.
"""
from __future__ import division, absolute_import
# System Imports
import types
import pickle
try:
import copy_reg
except ImportError:
import copyreg as copy_reg
import copy
import inspect
from twisted.python.compat import _PY3, _PYPY
# Twisted Imports
from twisted.python import log
from twisted.python import reflect
oldModules = {}
try:
import cPickle
except ImportError:
cPickle = None
if cPickle is None or cPickle.PicklingError is pickle.PicklingError:
_UniversalPicklingError = pickle.PicklingError
else:
class _UniversalPicklingError(pickle.PicklingError,
cPickle.PicklingError):
"""
A PicklingError catchable by both L{cPickle.PicklingError} and
L{pickle.PicklingError} handlers.
"""
## First, register copy_reg support for object types that really ought to
## be picklable: methods, functions, modules and cStringIO objects.
def pickleMethod(method):
    """
    Reduce a bound method for C{copy_reg} so that it can be pickled.

    @param method: The method to reduce.
    @type method: L{types.MethodType}

    @return: a 2-tuple of L{unpickleMethod} and its arguments: the method's
        name, the object it is bound to, and the class to look the method up
        on (C{im_class} on Python 2, the bound object's class on Python 3).
    """
    if not _PY3:
        name, instance, owner = (method.im_func.__name__,
                                 method.im_self,
                                 method.im_class)
    else:
        name = method.__name__
        instance = method.__self__
        owner = instance.__class__
    return unpickleMethod, (name, instance, owner)
def _methodFunction(classObject, methodName):
    """
    Look up the function object implementing C{methodName} on C{classObject}.

    @param classObject: The class on which to look up the method.
    @type classObject: L{type} or L{types.ClassType}

    @param methodName: The name of the method whose function to retrieve.
    @type methodName: native L{str}

    @raise AttributeError: if C{classObject} has no attribute named
        C{methodName}.

    @return: the function object corresponding to the given method name.
    @rtype: L{types.FunctionType}
    """
    attribute = getattr(classObject, methodName)
    # On Python 2 the class attribute is an unbound method wrapping the
    # function; on Python 3, for ordinary methods, it is the function itself.
    return attribute if _PY3 else attribute.im_func
def unpickleMethod(im_name, im_self, im_class):
    """
    Support function for copy_reg to unpickle method refs.

    @param im_name: The name of the method.
    @type im_name: native L{str}

    @param im_self: The object the method was bound to: an instance, or a
        class for a classmethod.
    @type im_self: L{object}

    @param im_class: The class where the method was declared.
    @type im_class: L{types.ClassType} or L{type} or L{None}

    @raise AttributeError: if the method can be found neither on
        C{im_class}, nor on the current class of C{im_self}, nor (when
        C{im_self} is itself a class) bound on C{im_self}.

    @return: the attribute C{im_name} of C{im_class} if C{im_self} is
        L{None}; otherwise the method named C{im_name} bound to C{im_self}.
    """
    if im_self is None:
        return getattr(im_class, im_name)
    try:
        methodFunction = _methodFunction(im_class, im_name)
    except AttributeError:
        log.msg("Method", im_name, "not on class", im_class)
        # Attempt a last-ditch fix before giving up. If classes have changed
        # around since we pickled this method, we may still be able to get it
        # by looking on the instance's current class.
        if im_self.__class__ is not im_class:
            return unpickleMethod(im_name, im_self, im_self.__class__)
        # A method bound to a class (e.g. a classmethod) is pickled with the
        # metaclass as im_class, which does not define it.  Look it up on the
        # class itself, and accept it only if it is bound to that class.
        if inspect.isclass(im_self):
            bound = getattr(im_self, im_name, None)
            if getattr(bound, '__self__', None) is im_self:
                return bound
        raise
    if _PY3:
        # Python 3 method objects take no class argument.
        maybeClass = ()
    else:
        maybeClass = (im_class,)
    return types.MethodType(methodFunction, im_self, *maybeClass)


copy_reg.pickle(types.MethodType, pickleMethod, unpickleMethod)
def _pickleFunction(f):
    """
    Reduce, in the sense of L{pickle}'s C{object.__reduce__} special method, a
    function object into its constituent parts.

    Only functions that can be found again by importing their fully qualified
    name can be reduced: lambdas and functions defined inside other functions
    are rejected at pickling time.

    @param f: The function to reduce.
    @type f: L{types.FunctionType}

    @raise _UniversalPicklingError: if C{f} is a lambda, or is local to
        another function so its qualified name is not importable.

    @return: a 2-tuple of a reference to L{_unpickleFunction} and a tuple of
        its arguments, a 1-tuple of the function's fully qualified name.
    @rtype: 2-tuple of C{callable, native string}
    """
    if f.__name__ == '<lambda>':
        raise _UniversalPicklingError(
            "Cannot pickle lambda function: {}".format(f))
    # Python 2 functions have no __qualname__; for module-level functions
    # __name__ is the same thing.
    qualifiedName = getattr(f, '__qualname__', f.__name__)
    # Nested functions get a "<locals>" path component, which is not an
    # attribute anything can import; fail now rather than at unpickle time.
    if '<locals>' in qualifiedName.split('.'):
        raise _UniversalPicklingError(
            "Cannot pickle local function: {}".format(f))
    return (_unpickleFunction,
            tuple([".".join([f.__module__, qualifiedName])]))
def _unpickleFunction(fullyQualifiedName):
    """
    Import and return the function named by C{fullyQualifiedName}.

    This is equivalent to L{twisted.python.reflect.namedAny}. The import is
    done locally to avoid circular imports, and this function gives pickles a
    persistent name of their own that can be kept (and deprecated)
    independently of C{namedAny}.

    @param fullyQualifiedName: The fully qualified name of a function.
    @type fullyQualifiedName: native C{str}

    @return: A function object imported from the given location.
    @rtype: L{types.FunctionType}
    """
    from twisted.python import reflect as _reflectModule
    return _reflectModule.namedAny(fullyQualifiedName)


copy_reg.pickle(types.FunctionType, _pickleFunction, _unpickleFunction)
def pickleModule(module):
    """
    Reduce a module for C{copy_reg} so that it is pickled by name.

    @param module: The module to reduce.
    @type module: L{types.ModuleType}

    @return: a 2-tuple of L{unpickleModule} and a 1-tuple holding the
        module's C{__name__}.
    """
    arguments = (module.__name__,)
    return (unpickleModule, arguments)
def unpickleModule(name):
    """
    Support function for copy_reg to unpickle module refs.

    If C{name} is a key of C{oldModules}, the module it maps to is imported
    instead, and the move is logged.

    @param name: The fully qualified name of the pickled module.
    @type name: native L{str}

    @raise ImportError: if the module cannot be imported.

    @return: the module named C{name} (or its replacement).
    @rtype: L{types.ModuleType}
    """
    import importlib
    if name in oldModules:
        log.msg("Module has moved: %s" % name)
        name = oldModules[name]
        log.msg(name)
    # import_module returns the leaf module of a dotted name directly,
    # without the non-empty-fromlist trick __import__ needs -- a trick that,
    # for packages, also tries to import a submodule named after each
    # fromlist entry.
    return importlib.import_module(name)


copy_reg.pickle(types.ModuleType,
                pickleModule,
                unpickleModule)
def pickleStringO(stringo):
    """
    Support function for copy_reg to pickle cStringIO output objects.

    This is only called on Python 2, because the cStringIO module only exists
    on Python 2.

    @param stringo: The string output to pickle.
    @type stringo: L{cStringIO.OutputType}

    @return: a 2-tuple of (C{unpickleStringO}, (bytes, pointer)), where the
        pointer is the current seek position of C{stringo}.
    @rtype: 2-tuple of (function, (bytes, int))
    """
    return unpickleStringO, (stringo.getvalue(), stringo.tell())
def unpickleStringO(val, sek):
    """
    Convert the output of L{pickleStringO} into an appropriate type for the
    current Python version.

    This may be called on Python 3, when loading a pickle written on Python 2,
    and will then produce an L{io.StringIO}.

    @param val: The content of the file.
    @type val: native L{str}

    @param sek: The seek position of the file.
    @type sek: L{int}

    @return: a writable file-like object containing C{val} and positioned at
        C{sek}; it accepts native strings (bytes on Python 2, text on
        Python 3).
    @rtype: L{cStringIO.OutputType} on Python 2, L{io.StringIO} on Python 3.
    """
    x = _cStringIO()
    # Write instead of passing val to the constructor: on Python 2,
    # cStringIO.StringIO(val) would create a read-only InputType.
    x.write(val)
    x.seek(sek)
    return x
def pickleStringI(stringi):
    """
    Reduce a cStringIO input object for C{copy_reg}.

    Only cStringIO input objects are registered with this reducer, so it is
    only used on Python 2.

    @param stringi: The string input to pickle.
    @type stringi: L{cStringIO.InputType}

    @return: a 2-tuple of (C{unpickleStringI}, (bytes, pointer))
    @rtype: 2-tuple of (function, (bytes, int))
    """
    contents = stringi.getvalue()
    position = stringi.tell()
    return unpickleStringI, (contents, position)
def unpickleStringI(val, sek):
    """
    Convert the output of L{pickleStringI} into an appropriate type for the
    current Python version.

    This may be called on Python 3, when loading a pickle written on Python 2,
    and will then produce an L{io.StringIO}.

    @param val: The content of the file.
    @type val: native L{str}

    @param sek: The seek position of the file.
    @type sek: L{int}

    @return: a file-like object, positioned at C{sek}, from which C{val} can
        be read back.
    @rtype: L{cStringIO.InputType} on Python 2, L{io.StringIO} on Python 3.
    """
    x = _cStringIO(val)
    x.seek(sek)
    return x
try:
from cStringIO import InputType, OutputType, StringIO as _cStringIO
except ImportError:
from io import StringIO as _cStringIO
else:
copy_reg.pickle(OutputType, pickleStringO, unpickleStringO)
copy_reg.pickle(InputType, pickleStringI, unpickleStringI)
class Ephemeral:
    """
    This type of object is never persisted; if possible, even references to it
    are eliminated.
    """

    def __reduce__(self):
        """
        Pickle an instance of any L{Ephemeral} subclass as a fresh, plain
        L{Ephemeral}, discarding both the subclass and the instance state.
        """
        return Ephemeral, ()

    def __getstate__(self):
        """
        Log a warning, plus (where the garbage collector can report them) the
        objects referring to this one, and provide no state.
        """
        log.msg( "WARNING: serializing ephemeral %s" % self )
        if _PYPY:
            return None
        import gc
        getReferrers = getattr(gc, 'get_referrers', None)
        if not getReferrers:
            return None
        for referrer in getReferrers(self):
            log.msg( " referred to by %s" % (referrer,))
        return None

    def __setstate__(self, state):
        """
        Log a warning and turn this object into a plain L{Ephemeral}, ignoring
        C{state}.
        """
        log.msg( "WARNING: unserializing ephemeral %s" % self.__class__ )
        self.__class__ = Ephemeral
# Versioned objects awaiting doUpgrade, keyed by id(); Versioned.__setstate__
# adds each object here as it is unpickled.
versionedsToUpgrade = {}
# ids of objects in versionedsToUpgrade whose upgrade has already started.
upgraded = {}


def doUpgrade():
    """
    Upgrade every object registered in C{versionedsToUpgrade}, then start over
    with fresh, empty registries.
    """
    global versionedsToUpgrade, upgraded
    pending = list(versionedsToUpgrade.values())
    for candidate in pending:
        requireUpgrade(candidate)
    versionedsToUpgrade, upgraded = {}, {}


def requireUpgrade(obj):
    """
    Require that a Versioned instance be upgraded completely first.

    Calls C{obj.versionUpgrade()} only if C{obj} is registered for an upgrade
    and has not been upgraded yet, so it is safe to call more than once.

    @return: C{obj}.
    """
    key = id(obj)
    if key not in versionedsToUpgrade or key in upgraded:
        return obj
    # Mark before upgrading so a re-entrant call for the same object is a
    # no-op instead of a recursive upgrade.
    upgraded[key] = 1
    obj.versionUpgrade()
    return obj
def _aybabtu(c):
    """
    Get all of the parent classes of C{c}, not including C{c} itself, which are
    strict subclasses of L{Versioned}.

    @param c: a class
    @returns: list of classes, in the method resolution order of C{c}
    """
    # Neither c itself nor Versioned belongs in the result.
    unwanted = (c, Versioned)
    found = []
    for base in inspect.getmro(c):
        if base in unwanted or base in found:
            continue
        if issubclass(base, Versioned):
            found.append(base)
    return found
class Versioned:
    """
    This type of object is persisted with versioning information.

    I have a single class attribute, the int persistenceVersion. After I am
    unserialized (and styles.doUpgrade() is called), self.upgradeToVersionX()
    will be called for each version upgrade I must undergo.

    For example, if I serialize an instance of a Foo(Versioned) at version 4
    and then unserialize it when the code is at version 9, the calls::

      self.upgradeToVersion5()
      self.upgradeToVersion6()
      self.upgradeToVersion7()
      self.upgradeToVersion8()
      self.upgradeToVersion9()

    will be made. If any of these methods are undefined, a warning message
    will be logged.

    Attribute names listed in a class's own persistenceForgets are left out
    of the serialized state.
    """

    persistenceVersion = 0
    persistenceForgets = ()

    def __setstate__(self, state):
        """
        Install C{state} as this object's C{__dict__} and register the object
        in C{versionedsToUpgrade} so that C{doUpgrade} will upgrade it.
        """
        versionedsToUpgrade[id(self)] = self
        self.__dict__ = state

    def __getstate__(self, dict=None):
        """
        Get state, adding a version number to it on its way out.

        @param dict: The state to serialize.  When L{None}, this object's
            C{__dict__} is used; any other value, including an empty dict, is
            used as given.

        @return: a shallow copy of the state.  Attributes named in the
            C{persistenceForgets} of this class, or of any of its L{Versioned}
            ancestors, are removed when that class defines the attribute
            itself.  For each such class that defines its own
            C{persistenceVersion}, the entry
            C{"<qualified class name>.persistenceVersion"} records that
            version.
        """
        dct = copy.copy(self.__dict__ if dict is None else dict)
        bases = _aybabtu(self.__class__)
        bases.reverse()
        bases.append(self.__class__) # don't forget me!!
        for base in bases:
            if 'persistenceForgets' in base.__dict__:
                for slot in base.persistenceForgets:
                    dct.pop(slot, None)
            if 'persistenceVersion' in base.__dict__:
                dct['%s.persistenceVersion' % reflect.qual(base)] = (
                    base.persistenceVersion)
        return dct

    def versionUpgrade(self):
        """(internal) Do a version upgrade.

        For this class and each of its L{Versioned} ancestors, from the most
        basic to the most derived, call that class's own
        C{upgradeToVersionN} methods for every version after the persisted
        one, up to and including its current C{persistenceVersion}.  The
        persisted version markers are removed from C{__dict__} as they are
        consumed.
        """
        bases = _aybabtu(self.__class__)
        # put the bases in order so superclasses' persistenceVersion methods
        # will be called first.
        bases.reverse()
        bases.append(self.__class__) # don't forget me!!
        # first let's look for old-skool versioned's
        if "persistenceVersion" in self.__dict__:
            # Hacky heuristic: if more than one class subclasses Versioned,
            # we'll assume that the higher version number wins for the older
            # class, so we'll consider the attribute the version of the older
            # class. There are obviously possibly times when this will
            # eventually be an incorrect assumption, but hopefully old-school
            # persistenceVersion stuff won't make it that far into multiple
            # classes inheriting from Versioned.
            pver = self.__dict__.pop('persistenceVersion')
            highestVersion = 0
            highestBase = None
            for base in bases:
                if 'persistenceVersion' not in base.__dict__:
                    continue
                if base.persistenceVersion > highestVersion:
                    highestBase = base
                    highestVersion = base.persistenceVersion
            if highestBase is not None:
                self.__dict__['%s.persistenceVersion' %
                              reflect.qual(highestBase)] = pver
        for base in bases:
            # ugly hack, but it's what the user expects, really
            if (Versioned not in base.__bases__ and
                'persistenceVersion' not in base.__dict__):
                continue
            currentVers = base.persistenceVersion
            pverName = '%s.persistenceVersion' % reflect.qual(base)
            # Always consume the stored marker, even when it records version
            # 0, so it does not linger as a stray entry in __dict__.
            persistVers = self.__dict__.pop(pverName, None) or 0
            assert persistVers <= currentVers, "Sorry, can't go backwards in time."
            while persistVers < currentVers:
                persistVers = persistVers + 1
                method = base.__dict__.get(
                    'upgradeToVersion%s' % persistVers, None)
                if method:
                    log.msg("Upgrading %s (of %s @ %s) to version %s" % (
                        reflect.qual(base), reflect.qual(self.__class__),
                        id(self), persistVers))
                    method(self)
                else:
                    log.msg('Warning: cannot upgrade %s to version %s' % (
                        base, persistVers))