@@ -56,7 +56,7 @@ def __init__(self, model, *column_names, **extras):
5656 self .model = model
5757 self ._distinct = None
5858 if column_names :
59- self .columns = [ getattr (model , name ) for name in column_names ]
59+ self .columns = self . _column_loader (model , column_names )
6060 else :
6161 self .columns = model
6262 self .extras = dict ((key , self .get (value ))
@@ -123,11 +123,28 @@ def get_from(self):
123123
124124 def load (self , * column_names , ** extras ):
125125 if column_names :
126- self .columns = [getattr (self .model , name ) for name in column_names ]
126+ self .columns = self ._column_loader (self .model , column_names )
127+
127128 self .extras .update ((key , self .get (value ))
128129 for key , value in extras .items ())
129130 return self
130131
132+ @classmethod
133+ def _column_loader (cls , model , column_names ):
134+ def column_formatter (column_name ):
135+ if isinstance (column_name , str ):
136+ return getattr (model , column_name )
137+ elif isinstance (column_name , Column ):
138+ if column_name not in model :
139+ raise AttributeError ('Column {} does not belong '
140+ 'to this model' .format (column_name ))
141+ return column_name
142+ else :
143+ raise TypeError ('Unknown column name {} type {}' .
144+ format (column_name , type (column_name )))
145+
146+ return [column_formatter (column_name ) for column_name in column_names ]
147+
131148 def on (self , on_clause ):
132149 self .on_clause = on_clause
133150 return self
0 commit comments