168168from collections import namedtuple
169169
170170import psycopg2
171+ from inspect import isclass
171172from postgres .context_managers import ConnectionContextManager
172173from postgres .context_managers import CursorContextManager
173174from postgres .cursors import SimpleTupleCursor , SimpleNamedTupleCursor
@@ -214,8 +215,8 @@ def __str__(self):
214215class NotAModel (Exception ):
215216 def __str__ (self ):
216217 return "Only subclasses of postgres.orm.Model can be registered as " \
217- "orm models. {} (registered for {}) doesn't fit the bill." \
218- .format (self .args [0 ]. __name__ , self . args [ 1 ] )
218+ "orm models. {} doesn't fit the bill." \
219+ .format (self .args [0 ])
219220
220221class NoTypeSpecified (Exception ):
221222 def __str__ (self ):
@@ -621,8 +622,7 @@ def register_model(self, ModelSubclass, typname=None):
621622 subclassing :py:class:`~postgres.orm.Model`.
622623
623624 """
624- if not issubclass (ModelSubclass , Model ):
625- raise NotAModel (ModelSubclass )
625+ self ._validate_model_subclass (ModelSubclass )
626626
627627 if typname is None :
628628 typname = getattr (ModelSubclass , 'typname' , None )
@@ -676,11 +676,14 @@ def unregister_model(self, ModelSubclass):
676676 del self .model_registry [key ]
677677
678678
679- def check_registration (self , ModelSubclass ):
679+ def check_registration (self , ModelSubclass , include_subsubclasses = False ):
680680 """Check whether an ORM model is registered.
681681
682682 :param ModelSubclass: the :py:class:`~postgres.orm.Model` subclass to
683683 check for
684+ :param bool include_subsubclasses: whether to also check for subclasses
685+ of :py:class:`ModelSubclass` or just :py:class:`ModelSubclass`
686+ itself
684687
685688 :returns: the :py:attr:`typname` (a string) for which this model is
686689 registered, or a list of strings if it's registered for multiple
@@ -690,7 +693,13 @@ def check_registration(self, ModelSubclass):
690693 :raises: :py:exc:`~postgres.NotRegistered`
691694
692695 """
693- keys = [k for k ,v in self .model_registry .items () if v is ModelSubclass ]
696+ self ._validate_model_subclass (ModelSubclass )
697+
698+ if include_subsubclasses :
699+ filt = lambda v : v is ModelSubclass or issubclass (ModelSubclass , v )
700+ else :
701+ filt = lambda v : v is ModelSubclass
702+ keys = [k for k ,v in self .model_registry .items () if filt (v )]
694703 if not keys :
695704 raise NotRegistered (ModelSubclass )
696705 if len (keys ) == 1 :
@@ -700,6 +709,11 @@ def check_registration(self, ModelSubclass):
700709 return keys
701710
702711
712+ def _validate_model_subclass (self , ModelSubclass , ):
713+ if not isclass (ModelSubclass ) or not issubclass (ModelSubclass , Model ):
714+ raise NotAModel (ModelSubclass )
715+
716+
703717# Class Factories
704718# ===============
705719
0 commit comments