From 6185a4b9e6465ecd9c4806ffeec7688f7baa1f2f Mon Sep 17 00:00:00 2001 From: Jessica Tallon Date: Tue, 28 Apr 2015 19:53:25 +0200 Subject: [PATCH] Fix the GenericForeignKey implementation --- mediagoblin/db/models.py | 58 +++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/mediagoblin/db/models.py b/mediagoblin/db/models.py index 4b592792..b4cfe2a8 100644 --- a/mediagoblin/db/models.py +++ b/mediagoblin/db/models.py @@ -25,7 +25,7 @@ import datetime from sqlalchemy import Column, Integer, Unicode, UnicodeText, DateTime, \ Boolean, ForeignKey, UniqueConstraint, PrimaryKeyConstraint, \ - SmallInteger, Date + SmallInteger, Date, types from sqlalchemy.orm import relationship, backref, with_polymorphic, validates, \ class_mapper from sqlalchemy.orm.collections import attribute_mapped_collection @@ -69,7 +69,7 @@ class GenericModelReference(Base): return None model = self._get_model_from_type(self.model_type) - return model.query.filter_by(id=self.obj_pk) + return model.query.filter_by(id=self.obj_pk).first() def set_object(self, obj): model = obj.__class__ @@ -100,38 +100,30 @@ class GenericModelReference(Base): def _get_model_from_type(self, model_type): """ Gets a model from a tablename (model type) """ - if getattr(self, "_TYPE_MAP", None) is None: + if getattr(self.__class__, "_TYPE_MAP", None) is None: # We want to build on the class (not the instance) a map of all the # models by the table name (type) for easy lookup, this is done on # the class so it can be shared between all instances # to prevent circular imports do import here self._TYPE_MAP = dict(((m.__tablename__, m) for m in MODELS)) - setattr(self.__class__._TYPE_MAP, self._TYPE_MAP) + setattr(self.__class__, "_TYPE_MAP", self._TYPE_MAP) - return self._TYPE_MAP[model_type] + return self.__class__._TYPE_MAP[model_type] -class GenericForeignKey(ForeignKey): +class GenericForeignKey(types.TypeDecorator): - def __init__(self, *args, **kwargs): - super(GenericForeignKey, self).__init__( - GenericModelReference.id, - *args, - **kwargs - ) + impl = Integer - def __get__(self, *args, **kwargs): + def process_result_value(self, value, *args, **kwargs): """ Looks up GenericModelReference and model for field """ - # Find the value of the foreign key. - ref = super(self, GenericForeignKey).__get__(*args, **kwargs) - # If this hasn't been set yet return None - if ref is None: + if value is None: return None # Look up the GenericModelReference for this. - gmr = GenericModelReference.query.filter_by(id=ref).first() + gmr = GenericModelReference.query.filter_by(id=value).first() # If it's set to something invalid (i.e. no GMR exists return None) if gmr is None: @@ -140,6 +132,30 @@ class GenericForeignKey(ForeignKey): # Ask the GMR for the corresponding model return gmr.get_object() + def process_bind_param(self, value, *args, **kwargs): + """ Save the foreign key """ + if value is None: + return None + + # Is there one for this already. + model = type(value) + pk = getattr(value, "id") + + gmr = GenericModelReference.query.filter_by(id=pk).first() + if gmr is None: + # We need to create one + gmr = GenericModelReference( + obj_pk=pk, + model_type=model.__tablename__ + ) + gmr.save() + + return gmr.id + + def _set_parent_with_dispatch(self, parent): + self.parent = parent + + class Location(Base): """ Represents a physical location """ @@ -1431,11 +1447,9 @@ class Activity(Base, ActivityMixin): generator = Column(Integer, ForeignKey("core__generators.id"), nullable=True) - object = Column(Integer, - GenericForeignKey(), + object = Column(GenericForeignKey(), nullable=False) - target = Column(Integer, - GenericForeignKey(), + target = Column(GenericForeignKey(), nullable=True) get_actor = relationship(User, -- 2.25.1