Add GenericForeignKey field and reference helper model
authorJessica Tallon <jessica@megworld.co.uk>
Wed, 25 Feb 2015 12:41:53 +0000 (13:41 +0100)
committerJessica Tallon <jessica@megworld.co.uk>
Tue, 26 May 2015 14:48:58 +0000 (16:48 +0200)
mediagoblin/db/migrations.py
mediagoblin/db/models.py

index 74c1194fe078b4aa875596083d91dd03c22617dd..446f30df6ab586dea6753da6e8b103d915ae11c6 100644 (file)
@@ -1249,3 +1249,18 @@ def datetime_to_utc(db):
 
     # Commit this to the database
     db.commit()
+
+class GenericModelReference_V0(declarative_base()):
+    __tablename__ = "core__generic_model_reference"
+
+    id = Column(Integer, primary_key=True)
+    obj_pk = Column(Integer, nullable=False)
+    model_type = Column(Unicode, nullable=False)
+
+@RegisterMigration(27, MIGRATIONS)
+def create_generic_model_reference(db):
+    """ Creates the Generic Model Reference table """
+    GenericModelReference_V0.__table__.create(db.bind)
+    db.commit()
+
+
index e8fb17a76ae40e3df5f2a622d2ec566f44ade97c..97f8b398dd6aedc8afd1e0b76130ed695a1ae0ef 100644 (file)
@@ -26,7 +26,8 @@ import datetime
 from sqlalchemy import Column, Integer, Unicode, UnicodeText, DateTime, \
         Boolean, ForeignKey, UniqueConstraint, PrimaryKeyConstraint, \
         SmallInteger, Date
-from sqlalchemy.orm import relationship, backref, with_polymorphic, validates
+from sqlalchemy.orm import relationship, backref, with_polymorphic, validates, \
+        class_mapper
 from sqlalchemy.orm.collections import attribute_mapped_collection
 from sqlalchemy.sql.expression import desc
 from sqlalchemy.ext.associationproxy import association_proxy
@@ -47,6 +48,81 @@ from pytz import UTC
 
 _log = logging.getLogger(__name__)
 
+class GenericModelReference(Base):
+    """
+    Represents a relationship to any model that is defined with a integer pk
+
+    NB: This model should not be used directly but through the GenericForeignKey
+        field provided. 
+    """
+    __tablename__ = "core__generic_model_reference"
+
+    id = Column(Integer, primary_key=True)
+    obj_pk = Column(Integer, nullable=False)
+
+    # This will be the tablename of the model
+    model_type = Column(Unicode, nullable=False)
+
+    @property
+    def get(self):
+        # This can happen if it's yet to be saved
+        if self.model_type is None or self.obj_pk is None:
+            return None
+
+        model = self._get_model_from_type(self.model_type)
+        return model.query.filter_by(id=self.obj_pk)
+
+    @property
+    def set(self, obj):
+        model = obj.__class__
+
+        # Check we've been given a object
+        if not issubclass(model, Base):
+            raise ValueError("Only models can be set as GenericForeignKeys")
+
+        # Check that the model has an explicit __tablename__ declaration
+        if getattr(model, "__tablename__", None) is None:
+            raise ValueError("Models must have __tablename__ attribute")
+
+        # Check that it's not a composite primary key
+        primary_keys = [key.name for key in class_mapper(model).primary_key]
+        if len(primary_keys) > 1:
+            raise ValueError("Models can not have composite primary keys")
+
+        # Check that the field on the model is a an integer field
+        pk_column = getattr(model, primary_keys[0])
+        if issubclass(Integer, pk_column):
+            raise ValueError("Only models with integer pks can be set")
+
+        # Ensure that everything has it's ID set
+        obj.save(commit=False)
+
+        self.obj_pk = obj.id
+        self.model_type = obj.__tablename__
+
+    def _get_model_from_type(self, model_type):
+        """ Gets a model from a tablename (model type) """
+        if getattr(self, "_TYPE_MAP", None) is None:
+            # We want to build on the class (not the instance) a map of all the
+            # models by the table name (type) for easy lookup, this is done on
+            # the class so it can be shared between all instances
+
+            # to prevent circular imports do import here
+            self._TYPE_MAP = dict(((m.__tablename__, m) for m in MODELS))
+            setattr(self.__class__._TYPE_MAP, self._TYPE_MAP)
+
+        return self._TYPE_MAP[model_type]
+
+
+class GenericForeignKey(ForeignKey):
+
+    def __init__(self, *args, **kwargs):
+        super(GenericForeignKey, self).__init__(
+            "core__generic_model_reference.id",
+            *args,
+            **kwargs
+        )
+
 class Location(Base):
     """ Represents a physical location """
     __tablename__ = "core__locations"
@@ -1416,7 +1492,7 @@ MODELS = [
        Privilege, PrivilegeUserAssociation,
     RequestToken, AccessToken, NonceTimestamp,
     Activity, ActivityIntermediator, Generator,
-    Location]
+    Location, GenericModelReference]
 
 """
  Foundations are the default rows that are created immediately after the tables