55from sqlalchemy .sql import ClauseElement
66
77from . import json_support
8- from .declarative import Model
8+ from .declarative import Model , InvertDict
99from .exceptions import NoSuchRowError
1010from .loader import AliasLoader , ModelLoader
1111
@@ -78,7 +78,7 @@ class UpdateRequest:
7878 specific model instance and its database row.
7979
8080 """
81- def __init__ (self , instance ):
81+ def __init__ (self , instance : 'CRUDModel' ):
8282 self ._instance = instance
8383 self ._values = {}
8484 self ._props = {}
@@ -124,7 +124,7 @@ async def apply(self, bind=None, timeout=DEFAULT):
124124 json_updates = {}
125125 for prop , value in self ._props .items ():
126126 value = prop .save (self ._instance , value )
127- updates = json_updates .setdefault (prop .column_name , {})
127+ updates = json_updates .setdefault (prop .prop_name , {})
128128 if self ._literal :
129129 updates [prop .name ] = value
130130 else :
@@ -133,26 +133,26 @@ async def apply(self, bind=None, timeout=DEFAULT):
133133 elif not isinstance (value , ClauseElement ):
134134 value = sa .cast (value , sa .Unicode )
135135 updates [sa .cast (prop .name , sa .Unicode )] = value
136- for column_name , updates in json_updates .items ():
137- column = getattr (cls , column_name )
136+ for prop_name , updates in json_updates .items ():
137+ prop = getattr (cls , prop_name )
138138 from .dialects .asyncpg import JSONB
139- if isinstance (column .type , JSONB ):
139+ if isinstance (prop .type , JSONB ):
140140 if self ._literal :
141- values [column_name ] = column .concat (updates )
141+ values [prop_name ] = prop .concat (updates )
142142 else :
143- values [column_name ] = column .concat (
143+ values [prop_name ] = prop .concat (
144144 sa .func .jsonb_build_object (
145145 * itertools .chain (* updates .items ())))
146146 else :
147- raise TypeError ('{} is not supported.' .format (column .type ))
147+ raise TypeError ('{} is not supported.' .format (prop .type ))
148148
149149 opts = dict (return_model = False )
150150 if timeout is not DEFAULT :
151151 opts ['timeout' ] = timeout
152152 clause = type (self ._instance ).update .where (
153153 self ._locator ,
154154 ).values (
155- ** values ,
155+ ** self . _instance . _get_sa_values ( values ) ,
156156 ).returning (
157157 * [getattr (cls , key ) for key in values ],
158158 ).execution_options (** opts )
@@ -161,7 +161,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
161161 row = await bind .first (clause )
162162 if not row :
163163 raise NoSuchRowError ()
164- self ._instance .__values__ .update (row )
164+ for k , v in row .items ():
165+ self ._instance .__values__ [
166+ self ._instance ._column_name_map .invert_get (k )] = v
165167 for prop in self ._props :
166168 prop .reload (self ._instance )
167169 return self
@@ -409,6 +411,7 @@ class CRUDModel(Model):
409411 """
410412
411413 _update_request_cls = UpdateRequest
414+ _column_name_map = InvertDict ()
412415
413416 def __init__ (self , ** values ):
414417 super ().__init__ ()
@@ -421,10 +424,10 @@ def _init_table(cls, sub_cls):
421424 for each_cls in sub_cls .__mro__ [::- 1 ]:
422425 for k , v in each_cls .__dict__ .items ():
423426 if isinstance (v , json_support .JSONProperty ):
424- if not hasattr (sub_cls , v .column_name ):
427+ if not hasattr (sub_cls , v .prop_name ):
425428 raise AttributeError (
426429 'Requires "{}" JSON[B] column.' .format (
427- v .column_name ))
430+ v .prop_name ))
428431 v .name = k
429432 if rv is not None :
430433 rv .__model__ = weakref .ref (sub_cls )
@@ -440,12 +443,12 @@ async def _create(self, bind=None, timeout=DEFAULT):
440443 cls = type (self )
441444 # noinspection PyUnresolvedReferences,PyProtectedMember
442445 cls ._check_abstract ()
443- keys = set (self .__profile__ .keys () if self .__profile__ else [])
444- for key in keys :
446+ profile_keys = set (self .__profile__ .keys () if self .__profile__ else [])
447+ for key in profile_keys :
445448 cls .__dict__ .get (key ).save (self )
446449 # initialize default values
447450 for key , prop in cls .__dict__ .items ():
448- if key in keys :
451+ if key in profile_keys :
449452 continue
450453 if isinstance (prop , json_support .JSONProperty ):
451454 if prop .default is None or prop .after_get .method is not None :
@@ -458,15 +461,25 @@ async def _create(self, bind=None, timeout=DEFAULT):
458461 if timeout is not DEFAULT :
459462 opts ['timeout' ] = timeout
460463 # noinspection PyArgumentList
461- q = cls .__table__ .insert ().values (** self .__values__ ).returning (
462- * cls ).execution_options (** opts )
464+ q = cls .__table__ .insert ().values (
465+ ** self ._get_sa_values (self .__values__ )
466+ ).returning (
467+ * cls
468+ ).execution_options (** opts )
463469 if bind is None :
464470 bind = cls .__metadata__ .bind
465471 row = await bind .first (q )
466- self .__values__ .update (row )
472+ for k , v in row .items ():
473+ self .__values__ [self ._column_name_map .invert_get (k )] = v
467474 self .__profile__ = None
468475 return self
469476
477+ def _get_sa_values (self , instance_values : dict ) -> dict :
478+ values = {}
479+ for k , v in instance_values .items ():
480+ values [self ._column_name_map [k ]] = v
481+ return values
482+
470483 @classmethod
471484 async def get (cls , ident , bind = None , timeout = DEFAULT ):
472485 """
@@ -592,11 +605,12 @@ def to_dict(self):
592605
593606 """
594607 cls = type (self )
595- keys = set (c .name for c in cls )
608+ # noinspection PyTypeChecker
609+ keys = set (cls ._column_name_map .invert_get (c .name ) for c in cls )
596610 for key , prop in cls .__dict__ .items ():
597611 if isinstance (prop , json_support .JSONProperty ):
598612 keys .add (key )
599- keys .discard (prop .column_name )
613+ keys .discard (prop .prop_name )
600614 return dict ((k , getattr (self , k )) for k in keys )
601615
602616 @classmethod
0 commit comments