99from .preview import preview , repr_html
1010from .condition import (
1111 AndList ,
12+ Top ,
1213 Not ,
1314 make_condition ,
1415 assert_join_compatibility ,
@@ -52,6 +53,7 @@ class QueryExpression:
5253 _connection = None
5354 _heading = None
5455 _support = None
56+ _top = None
5557
5658 # If the query will be using distinct
5759 _distinct = False
@@ -121,17 +123,33 @@ def where_clause(self):
121123 else " WHERE (%s)" % ")AND(" .join (str (s ) for s in self .restriction )
122124 )
123125
126+ def sorting_clauses (self ):
127+ if not self ._top :
128+ return ""
129+ clause = ", " .join (
130+ _wrap_attributes (
131+ _flatten_attribute_list (self .primary_key , self ._top .order_by )
132+ )
133+ )
134+ if clause :
135+ clause = f" ORDER BY { clause } "
136+ if self ._top .limit is not None :
137+ clause += f" LIMIT { self ._top .limit } { f' OFFSET { self ._top .offset } ' if self ._top .offset else '' } "
138+
139+ return clause
140+
124141 def make_sql (self , fields = None ):
125142 """
126143 Make the SQL SELECT statement.
127144
128145 :param fields: used to explicitly set the select attributes
129146 """
130- return "SELECT {distinct}{fields} FROM {from_}{where}" .format (
147+ return "SELECT {distinct}{fields} FROM {from_}{where}{sorting} " .format (
131148 distinct = "DISTINCT " if self ._distinct else "" ,
132149 fields = self .heading .as_sql (fields or self .heading .names ),
133150 from_ = self .from_clause (),
134151 where = self .where_clause (),
152+ sorting = self .sorting_clauses (),
135153 )
136154
137155 # --------- query operators -----------
@@ -189,6 +207,14 @@ def restrict(self, restriction):
189207 string, or an AndList.
190208 """
191209 attributes = set ()
210+ if isinstance (restriction , Top ):
211+ result = (
212+ self .make_subquery ()
213+ if self ._top and not self ._top .__eq__ (restriction )
214+ else copy .copy (self )
215+ ) # make subquery to avoid overwriting existing Top
216+ result ._top = restriction
217+ return result
192218 new_condition = make_condition (self , restriction , attributes )
193219 if new_condition is True :
194220 return self # restriction has no effect, return the same object
@@ -202,8 +228,10 @@ def restrict(self, restriction):
202228 pass # all ok
203229 # If the new condition uses any new attributes, a subquery is required.
204230 # However, Aggregation's HAVING statement works fine with aliased attributes.
205- need_subquery = isinstance (self , Union ) or (
206- not isinstance (self , Aggregation ) and self .heading .new_attributes
231+ need_subquery = (
232+ isinstance (self , Union )
233+ or (not isinstance (self , Aggregation ) and self .heading .new_attributes )
234+ or self ._top
207235 )
208236 if need_subquery :
209237 result = self .make_subquery ()
@@ -539,19 +567,20 @@ def tail(self, limit=25, **fetch_kwargs):
539567
540568 def __len__ (self ):
541569 """:return: number of elements in the result set e.g. ``len(q1)``."""
542- return self .connection .query (
570+ result = self .make_subquery () if self ._top else copy .copy (self )
571+ return result .connection .query (
543572 "SELECT {select_} FROM {from_}{where}" .format (
544573 select_ = (
545574 "count(*)"
546- if any (self ._left )
575+ if any (result ._left )
547576 else "count(DISTINCT {fields})" .format (
548- fields = self .heading .as_sql (
549- self .primary_key , include_aliases = False
577+ fields = result .heading .as_sql (
578+ result .primary_key , include_aliases = False
550579 )
551580 )
552581 ),
553- from_ = self .from_clause (),
554- where = self .where_clause (),
582+ from_ = result .from_clause (),
583+ where = result .where_clause (),
555584 )
556585 ).fetchone ()[0 ]
557586
@@ -619,18 +648,12 @@ def __next__(self):
619648 # -- move on to next entry.
620649 return next (self )
621650
622- def cursor (self , offset = 0 , limit = None , order_by = None , as_dict = False ):
651+ def cursor (self , as_dict = False ):
623652 """
624653 See expression.fetch() for input description.
625654 :return: query cursor
626655 """
627- if offset and limit is None :
628- raise DataJointError ("limit is required when offset is set" )
629656 sql = self .make_sql ()
630- if order_by is not None :
631- sql += " ORDER BY " + ", " .join (order_by )
632- if limit is not None :
633- sql += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "" )
634657 logger .debug (sql )
635658 return self .connection .query (sql , as_dict = as_dict )
636659
@@ -701,23 +724,26 @@ def make_sql(self, fields=None):
701724 fields = self .heading .as_sql (fields or self .heading .names )
702725 assert self ._grouping_attributes or not self .restriction
703726 distinct = set (self .heading .names ) == set (self .primary_key )
704- return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}" .format (
705- distinct = "DISTINCT " if distinct else "" ,
706- fields = fields ,
707- from_ = self .from_clause (),
708- where = self .where_clause (),
709- group_by = (
710- ""
711- if not self .primary_key
712- else (
713- " GROUP BY `%s`" % "`,`" .join (self ._grouping_attributes )
714- + (
715- ""
716- if not self .restriction
717- else " HAVING (%s)" % ")AND(" .join (self .restriction )
727+ return (
728+ "SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}" .format (
729+ distinct = "DISTINCT " if distinct else "" ,
730+ fields = fields ,
731+ from_ = self .from_clause (),
732+ where = self .where_clause (),
733+ group_by = (
734+ ""
735+ if not self .primary_key
736+ else (
737+ " GROUP BY `%s`" % "`,`" .join (self ._grouping_attributes )
738+ + (
739+ ""
740+ if not self .restriction
741+ else " HAVING (%s)" % ")AND(" .join (self .restriction )
742+ )
718743 )
719- )
720- ),
744+ ),
745+ sorting = self .sorting_clauses (),
746+ )
721747 )
722748
723749 def __len__ (self ):
@@ -776,7 +802,7 @@ def make_sql(self):
776802 ):
777803 # no secondary attributes: use UNION DISTINCT
778804 fields = arg1 .primary_key
779- return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`" .format (
805+ return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}{sorting} `" .format (
780806 sql1 = (
781807 arg1 .make_sql ()
782808 if isinstance (arg1 , Union )
@@ -788,6 +814,7 @@ def make_sql(self):
788814 else arg2 .make_sql (fields )
789815 ),
790816 alias = next (self .__count ),
817+ sorting = self .sorting_clauses (),
791818 )
792819 # with secondary attributes, use union of left join with antijoin
793820 fields = self .heading .names
@@ -939,3 +966,25 @@ def aggr(self, group, **named_attributes):
939966 )
940967
941968 aggregate = aggr # alias for aggr
969+
970+
971+ def _flatten_attribute_list (primary_key , attrs ):
972+ """
973+ :param primary_key: list of attributes in primary key
974+ :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC"
975+ :return: generator of attributes where "KEY" is replaced with its component attributes
976+ """
977+ for a in attrs :
978+ if re .match (r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$" , a ):
979+ if primary_key :
980+ yield from primary_key
981+ elif re .match (r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$" , a ):
982+ if primary_key :
983+ yield from (q + " DESC" for q in primary_key )
984+ else :
985+ yield a
986+
987+
988+ def _wrap_attributes (attr ):
989+ for entry in attr : # wrap attribute names in backquotes
990+ yield re .sub (r"\b((?!asc|desc)\w+)\b" , r"`\1`" , entry , flags = re .IGNORECASE )
0 commit comments