1- import operator
2- import csv
3- import six
41import codecs
2+ import csv
3+ import operator
54import os .path
65import re
6+
7+ import prettytable
8+ import six
79import sqlalchemy
810import sqlparse
9- import prettytable
10- from pgspecial .main import PGSpecial
11+
1112from .column_guesser import ColumnGuesserMixin
1213
14+ try :
15+ from pgspecial .main import PGSpecial
16+ except ImportError :
17+ PGSpecial = None
18+
1319
1420def unduplicate_field_names (field_names ):
1521 """Append a number to duplicate field names to make them unique. """
@@ -23,6 +29,7 @@ def unduplicate_field_names(field_names):
2329 res .append (k )
2430 return res
2531
32+
2633class UnicodeWriter (object ):
2734 """
2835 A CSV writer which will write rows to CSV file "f",
@@ -38,19 +45,17 @@ def __init__(self, f, dialect=csv.excel, encoding="utf-8", **kwds):
3845
3946 def writerow (self , row ):
4047 if six .PY2 :
41- _row = [s .encode ("utf-8" )
42- if hasattr (s , "encode" )
43- else s
48+ _row = [s .encode ("utf-8" ) if hasattr (s , "encode" ) else s
4449 for s in row ]
4550 else :
4651 _row = row
4752 self .writer .writerow (_row )
4853 # Fetch UTF-8 output from the queue ...
4954 data = self .queue .getvalue ()
5055 if six .PY2 :
51- data = data .decode ("utf-8" )
52- # ... and reencode it into the target encoding
53- data = self .encoder .encode (data )
56+ data = data .decode ("utf-8" )
57+ # ... and reencode it into the target encoding
58+ data = self .encoder .encode (data )
5459 # write to the target stream
5560 self .stream .write (data )
5661 # empty queue
@@ -61,14 +66,20 @@ def writerows(self, rows):
6166 for row in rows :
6267 self .writerow (row )
6368
69+
6470class CsvResultDescriptor (object ):
6571 """Provides IPython Notebook-friendly output for the feedback after a ``.csv`` called."""
72+
6673 def __init__ (self , file_path ):
6774 self .file_path = file_path
75+
6876 def __repr__ (self ):
69- return 'CSV results at %s' % os .path .join (os .path .abspath ('.' ), self .file_path )
77+ return 'CSV results at %s' % os .path .join (
78+ os .path .abspath ('.' ), self .file_path )
79+
7080 def _repr_html_ (self ):
71- return '<a href="%s">CSV results</a>' % os .path .join ('.' , 'files' , self .file_path )
81+ return '<a href="%s">CSV results</a>' % os .path .join ('.' , 'files' ,
82+ self .file_path )
7283
7384
7485def _nonbreaking_spaces (match_obj ):
@@ -81,6 +92,7 @@ def _nonbreaking_spaces(match_obj):
8192 spaces = ' ' * len (match_obj .group (2 ))
8293 return '%s%s' % (match_obj .group (1 ), spaces )
8394
95+
8496_cell_with_spaces_pattern = re .compile (r'(<td>)( {2,})' )
8597
8698
@@ -90,6 +102,7 @@ class ResultSet(list, ColumnGuesserMixin):
90102
91103 Can access rows listwise, or by string value of leftmost column.
92104 """
105+
93106 def __init__ (self , sqlaproxy , sql , config ):
94107 self .keys = sqlaproxy .keys ()
95108 self .sql = sql
@@ -115,7 +128,8 @@ def _repr_html_(self):
115128 self .pretty .add_rows (self )
116129 result = self .pretty .get_html_string ()
117130 result = _cell_with_spaces_pattern .sub (_nonbreaking_spaces , result )
118- if self .config .displaylimit and len (self ) > self .config .displaylimit :
131+ if self .config .displaylimit and len (
132+ self ) > self .config .displaylimit :
119133 result = '%s\n <span style="font-style:italic;text-align:center;">%d rows, truncated to displaylimit of %d</span>' % (
120134 result , len (self ), self .config .displaylimit )
121135 return result
@@ -140,6 +154,7 @@ def __getitem__(self, key):
140154 if len (result ) > 1 :
141155 raise KeyError ('%d results for "%s"' % (len (result ), key ))
142156 return result [0 ]
157+
143158 def dict (self ):
144159 """Returns a single dict built from the result set
145160
@@ -214,7 +229,7 @@ def plot(self, title=None, **kwargs):
214229 plt .ylabel (ylabel )
215230 return plot
216231
217- def bar (self , key_word_sep = " " , title = None , ** kwargs ):
232+ def bar (self , key_word_sep = " " , title = None , ** kwargs ):
218233 """Generates a pylab bar plot from the result set.
219234
220235 ``matplotlib`` must be installed, and in an
@@ -238,8 +253,7 @@ def bar(self, key_word_sep = " ", title=None, **kwargs):
238253 self .guess_pie_columns (xlabel_sep = key_word_sep )
239254 plot = plt .bar (range (len (self .ys [0 ])), self .ys [0 ], ** kwargs )
240255 if self .xlabels :
241- plt .xticks (range (len (self .xlabels )), self .xlabels ,
242- rotation = 45 )
256+ plt .xticks (range (len (self .xlabels )), self .xlabels , rotation = 45 )
243257 plt .xlabel (self .xlabel )
244258 plt .ylabel (self .ys [0 ].name )
245259 return plot
@@ -248,7 +262,7 @@ def csv(self, filename=None, **format_params):
248262 """Generate results in comma-separated form. Write to ``filename`` if given.
249263 Any other parameters will be passed on to csv.writer."""
250264 if not self .pretty :
251- return None # no results
265+ return None # no results
252266 self .pretty .add_rows (self )
253267 if filename :
254268 encoding = format_params .get ('encoding' , 'utf-8' )
@@ -276,17 +290,37 @@ def interpret_rowcount(rowcount):
276290 result = '%d rows affected.' % rowcount
277291 return result
278292
293+
279294class FakeResultProxy (object ):
280295 """A fake class that pretends to behave like the ResultProxy from
281296 SqlAlchemy.
282297 """
298+
283299 def __init__ (self , cursor , headers ):
284300 self .fetchall = cursor .fetchall
285301 self .fetchmany = cursor .fetchmany
286302 self .rowcount = cursor .rowcount
287303 self .keys = lambda : headers
288304 self .returns_rows = True
289305
306+ # some dialects have autocommit
307+ # specific dialects break when commit is used:
308+ _COMMIT_BLACKLIST_DIALECTS = ('mssql' , 'clickhouse' )
309+
310+
311+ def _commit (conn , config ):
312+ """Issues a commit, if appropriate for current config and dialect"""
313+
314+ _should_commit = config .autocommit and all (
315+ dialect not in str (conn .dialect )
316+ for dialect in _COMMIT_BLACKLIST_DIALECTS )
317+
318+ if _should_commit :
319+ try :
320+ conn .session .execute ('commit' )
321+ except sqlalchemy .exc .OperationalError :
322+ pass # not all engines can commit
323+
290324
291325def run (conn , sql , config , user_namespace ):
292326 if sql .strip ():
@@ -295,20 +329,16 @@ def run(conn, sql, config, user_namespace):
295329 if first_word == 'begin' :
296330 raise Exception ("ipython_sql does not support transactions" )
297331 if first_word .startswith ('\\ ' ) and 'postgres' in str (conn .dialect ):
332+ if not PGSpecial :
333+ raise ImportError ('pgspecial not installed' )
298334 pgspecial = PGSpecial ()
299335 _ , cur , headers , _ = pgspecial .execute (
300- conn .session .connection .cursor (),
301- statement )[0 ]
336+ conn .session .connection .cursor (), statement )[0 ]
302337 result = FakeResultProxy (cur , headers )
303338 else :
304339 txt = sqlalchemy .sql .text (statement )
305340 result = conn .session .execute (txt , user_namespace )
306- try :
307- # mssql has autocommit
308- if config .autocommit and ('mssql' not in str (conn .dialect )):
309- conn .session .execute ('commit' )
310- except sqlalchemy .exc .OperationalError :
311- pass # not all engines can commit
341+ _commit (conn = conn , config = config )
312342 if result and config .feedback :
313343 print (interpret_rowcount (result .rowcount ))
314344 resultset = ResultSet (result , statement , config )
@@ -322,11 +352,10 @@ def run(conn, sql, config, user_namespace):
322352
323353
324354class PrettyTable (prettytable .PrettyTable ):
325-
326355 def __init__ (self , * args , ** kwargs ):
327356 self .row_count = 0
328357 self .displaylimit = None
329- return super (PrettyTable , self ).__init__ (* args , ** kwargs )
358+ return super (PrettyTable , self ).__init__ (* args , ** kwargs )
330359
331360 def add_rows (self , data ):
332361 if self .row_count and (data .config .displaylimit == self .displaylimit ):
0 commit comments