Fix the GenericForeignKey implementation
authorJessica Tallon <jessica@megworld.co.uk>
Tue, 28 Apr 2015 17:53:25 +0000 (19:53 +0200)
committerJessica Tallon <jessica@megworld.co.uk>
Tue, 26 May 2015 14:48:58 +0000 (16:48 +0200)
mediagoblin/db/models.py

index 4b59279223ae07ec284ff6cd5096951db841cb92..b4cfe2a82622addcfdcaf41ca605eb951f14c3c7 100644 (file)
@@ -25,7 +25,7 @@ import datetime
 
 from sqlalchemy import Column, Integer, Unicode, UnicodeText, DateTime, \
         Boolean, ForeignKey, UniqueConstraint, PrimaryKeyConstraint, \
-        SmallInteger, Date
+        SmallInteger, Date, types
 from sqlalchemy.orm import relationship, backref, with_polymorphic, validates, \
         class_mapper
 from sqlalchemy.orm.collections import attribute_mapped_collection
@@ -69,7 +69,7 @@ class GenericModelReference(Base):
             return None
 
         model = self._get_model_from_type(self.model_type)
-        return model.query.filter_by(id=self.obj_pk)
+        return model.query.filter_by(id=self.obj_pk).first()
 
     def set_object(self, obj):
         model = obj.__class__
@@ -100,38 +100,30 @@ class GenericModelReference(Base):
 
     def _get_model_from_type(self, model_type):
         """ Gets a model from a tablename (model type) """
-        if getattr(self, "_TYPE_MAP", None) is None:
+        if getattr(self.__class__, "_TYPE_MAP", None) is None:
             # We want to build on the class (not the instance) a map of all the
             # models by the table name (type) for easy lookup, this is done on
             # the class so it can be shared between all instances
 
             # to prevent circular imports do import here
             self._TYPE_MAP = dict(((m.__tablename__, m) for m in MODELS))
-            setattr(self.__class__._TYPE_MAP, self._TYPE_MAP)
+            setattr(self.__class__, "_TYPE_MAP",  self._TYPE_MAP)
 
-        return self._TYPE_MAP[model_type]
+        return self.__class__._TYPE_MAP[model_type]
 
 
-class GenericForeignKey(ForeignKey):
+class GenericForeignKey(types.TypeDecorator):
 
-    def __init__(self, *args, **kwargs):
-        super(GenericForeignKey, self).__init__(
-            GenericModelReference.id,
-            *args,
-            **kwargs
-        )
+    impl = Integer
 
-    def __get__(self, *args, **kwargs):
+    def process_result_value(self, value, *args, **kwargs):
         """ Looks up GenericModelReference and model for field """
-        # Find the value of the foreign key.
-        ref = super(self, GenericForeignKey).__get__(*args, **kwargs)
-
         # If this hasn't been set yet return None
-        if ref is None:
+        if value is None:
             return None
 
         # Look up the GenericModelReference for this.
-        gmr = GenericModelReference.query.filter_by(id=ref).first()
+        gmr = GenericModelReference.query.filter_by(id=value).first()
 
         # If it's set to something invalid (i.e. no GMR exists return None)
         if gmr is None:
@@ -140,6 +132,30 @@ class GenericForeignKey(ForeignKey):
         # Ask the GMR for the corresponding model
         return gmr.get_object()
 
+    def process_bind_param(self, value, *args, **kwargs):
+        """ Save the foreign key """
+        if value is None:
+            return None
+
+        # Is there one for this already.
+        model = type(value)
+        pk = getattr(value, "id")
+
+        gmr = GenericModelReference.query.filter_by(id=pk).first()
+        if gmr is None:
+            # We need to create one
+            gmr = GenericModelReference(
+                obj_pk=pk,
+                model_type=model.__tablename__
+            )
+            gmr.save()
+
+        return gmr.id
+
+    def _set_parent_with_dispatch(self, parent):
+        self.parent = parent
+
+
 
 class Location(Base):
     """ Represents a physical location """
@@ -1431,11 +1447,9 @@ class Activity(Base, ActivityMixin):
     generator = Column(Integer,
                        ForeignKey("core__generators.id"),
                        nullable=True)
-    object = Column(Integer,
-                    GenericForeignKey(),
+    object = Column(GenericForeignKey(),
                     nullable=False)
-    target = Column(Integer,
-                    GenericForeignKey(),
+    target = Column(GenericForeignKey(),
                     nullable=True)
 
     get_actor = relationship(User,