55from sqlalchemy .sql import ClauseElement
66
77from . import json_support
8- from .declarative import Model
8+ from .declarative import Model , InvertDict
99from .exceptions import NoSuchRowError
1010from .loader import AliasLoader , ModelLoader
1111
@@ -78,7 +78,7 @@ class UpdateRequest:
7878 specific model instance and its database row.
7979
8080 """
81- def __init__ (self , instance ):
81+ def __init__ (self , instance : 'CRUDModel' ):
8282 self ._instance = instance
8383 self ._values = {}
8484 self ._props = {}
@@ -88,7 +88,7 @@ def __init__(self, instance):
8888 try :
8989 self ._locator = instance .lookup ()
9090 except LookupError :
91- # apply() will fail anyway, but still allow updates ()
91+ # apply() will fail anyway, but still allow update ()
9292 pass
9393
9494 def _set (self , key , value ):
@@ -124,7 +124,7 @@ async def apply(self, bind=None, timeout=DEFAULT):
124124 json_updates = {}
125125 for prop , value in self ._props .items ():
126126 value = prop .save (self ._instance , value )
127- updates = json_updates .setdefault (prop .column_name , {})
127+ updates = json_updates .setdefault (prop .prop_name , {})
128128 if self ._literal :
129129 updates [prop .name ] = value
130130 else :
@@ -133,26 +133,28 @@ async def apply(self, bind=None, timeout=DEFAULT):
133133 elif not isinstance (value , ClauseElement ):
134134 value = sa .cast (value , sa .Unicode )
135135 updates [sa .cast (prop .name , sa .Unicode )] = value
136- for column_name , updates in json_updates .items ():
137- column = getattr (cls , column_name )
136+ for prop_name , updates in json_updates .items ():
137+ prop = getattr (cls , prop_name )
138138 from .dialects .asyncpg import JSONB
139- if isinstance (column .type , JSONB ):
139+ if isinstance (prop .type , JSONB ):
140140 if self ._literal :
141- values [column_name ] = column .concat (updates )
141+ values [prop_name ] = prop .concat (updates )
142142 else :
143- values [column_name ] = column .concat (
143+ values [prop_name ] = prop .concat (
144144 sa .func .jsonb_build_object (
145145 * itertools .chain (* updates .items ())))
146146 else :
147- raise TypeError ('{} is not supported.' .format (column .type ))
147+ raise TypeError ('{} is not supported to update json '
148+ 'properties in Gino. Please consider using '
149+ 'JSONB.' .format (prop .type ))
148150
149151 opts = dict (return_model = False )
150152 if timeout is not DEFAULT :
151153 opts ['timeout' ] = timeout
152154 clause = type (self ._instance ).update .where (
153155 self ._locator ,
154156 ).values (
155- ** values ,
157+ ** self . _instance . _get_sa_values ( values ) ,
156158 ).returning (
157159 * [getattr (cls , key ) for key in values ],
158160 ).execution_options (** opts )
@@ -161,7 +163,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
161163 row = await bind .first (clause )
162164 if not row :
163165 raise NoSuchRowError ()
164- self ._instance .__values__ .update (row )
166+ for k , v in row .items ():
167+ self ._instance .__values__ [
168+ self ._instance ._column_name_map .invert_get (k )] = v
165169 for prop in self ._props :
166170 prop .reload (self ._instance )
167171 return self
@@ -409,6 +413,7 @@ class CRUDModel(Model):
409413 """
410414
411415 _update_request_cls = UpdateRequest
416+ _column_name_map = InvertDict ()
412417
413418 def __init__ (self , ** values ):
414419 super ().__init__ ()
@@ -421,10 +426,10 @@ def _init_table(cls, sub_cls):
421426 for each_cls in sub_cls .__mro__ [::- 1 ]:
422427 for k , v in each_cls .__dict__ .items ():
423428 if isinstance (v , json_support .JSONProperty ):
424- if not hasattr (sub_cls , v .column_name ):
429+ if not hasattr (sub_cls , v .prop_name ):
425430 raise AttributeError (
426431 'Requires "{}" JSON[B] column.' .format (
427- v .column_name ))
432+ v .prop_name ))
428433 v .name = k
429434 if rv is not None :
430435 rv .__model__ = weakref .ref (sub_cls )
@@ -440,12 +445,12 @@ async def _create(self, bind=None, timeout=DEFAULT):
440445 cls = type (self )
441446 # noinspection PyUnresolvedReferences,PyProtectedMember
442447 cls ._check_abstract ()
443- keys = set (self .__profile__ .keys () if self .__profile__ else [])
444- for key in keys :
448+ profile_keys = set (self .__profile__ .keys () if self .__profile__ else [])
449+ for key in profile_keys :
445450 cls .__dict__ .get (key ).save (self )
446451 # initialize default values
447452 for key , prop in cls .__dict__ .items ():
448- if key in keys :
453+ if key in profile_keys :
449454 continue
450455 if isinstance (prop , json_support .JSONProperty ):
451456 if prop .default is None or prop .after_get .method is not None :
@@ -458,15 +463,25 @@ async def _create(self, bind=None, timeout=DEFAULT):
458463 if timeout is not DEFAULT :
459464 opts ['timeout' ] = timeout
460465 # noinspection PyArgumentList
461- q = cls .__table__ .insert ().values (** self .__values__ ).returning (
462- * cls ).execution_options (** opts )
466+ q = cls .__table__ .insert ().values (
467+ ** self ._get_sa_values (self .__values__ )
468+ ).returning (
469+ * cls
470+ ).execution_options (** opts )
463471 if bind is None :
464472 bind = cls .__metadata__ .bind
465473 row = await bind .first (q )
466- self .__values__ .update (row )
474+ for k , v in row .items ():
475+ self .__values__ [self ._column_name_map .invert_get (k )] = v
467476 self .__profile__ = None
468477 return self
469478
479+ def _get_sa_values (self , instance_values : dict ) -> dict :
480+ values = {}
481+ for k , v in instance_values .items ():
482+ values [self ._column_name_map [k ]] = v
483+ return values
484+
470485 @classmethod
471486 async def get (cls , ident , bind = None , timeout = DEFAULT ):
472487 """
@@ -592,11 +607,12 @@ def to_dict(self):
592607
593608 """
594609 cls = type (self )
595- keys = set (c .name for c in cls )
610+ # noinspection PyTypeChecker
611+ keys = set (cls ._column_name_map .invert_get (c .name ) for c in cls )
596612 for key , prop in cls .__dict__ .items ():
597613 if isinstance (prop , json_support .JSONProperty ):
598614 keys .add (key )
599- keys .discard (prop .column_name )
615+ keys .discard (prop .prop_name )
600616 return dict ((k , getattr (self , k )) for k in keys )
601617
602618 @classmethod
0 commit comments