@@ -13,6 +13,7 @@ import (
1313 "database/sql/driver"
1414 "errors"
1515 "net"
16+ "strconv"
1617 "strings"
1718 "time"
1819)
@@ -26,6 +27,7 @@ type mysqlConn struct {
2627 maxPacketAllowed int
2728 maxWriteSize int
2829 flags clientFlag
30+ status statusFlag
2931 sequence uint8
3032 parseTime bool
3133 strict bool
@@ -46,6 +48,7 @@ type config struct {
4648 allowOldPasswords bool
4749 clientFoundRows bool
4850 columnsWithAlias bool
51+ interpolateParams bool
4952}
5053
5154// Handles parameters set in DSN after the connection is established
@@ -162,28 +165,174 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
162165 return stmt , err
163166}
164167
168+ // estimateParamLength calculates upper bound of string length from types.
169+ func estimateParamLength (args []driver.Value ) (int , bool ) {
170+ l := 0
171+ for _ , a := range args {
172+ switch v := a .(type ) {
173+ case int64 , float64 :
174+ // 24 (-1.7976931348623157e+308) may be upper bound. But I'm not sure.
175+ l += 25
176+ case bool :
177+ l += 1 // 0 or 1
178+ case time.Time :
179+ l += 30 // '1234-12-23 12:34:56.777777'
180+ case string :
181+ l += len (v )* 2 + 2
182+ case []byte :
183+ l += len (v )* 2 + 2
184+ default :
185+ return 0 , false
186+ }
187+ }
188+ return l , true
189+ }
190+
191+ func (mc * mysqlConn ) interpolateParams (query string , args []driver.Value ) (string , error ) {
192+ estimated , ok := estimateParamLength (args )
193+ if ! ok {
194+ return "" , driver .ErrSkip
195+ }
196+ estimated += len (query )
197+
198+ buf := make ([]byte , 0 , estimated )
199+ argPos := 0
200+
201+ for i := 0 ; i < len (query ); i ++ {
202+ q := strings .IndexByte (query [i :], '?' )
203+ if q == - 1 {
204+ buf = append (buf , query [i :]... )
205+ break
206+ }
207+ buf = append (buf , query [i :i + q ]... )
208+ i += q
209+
210+ arg := args [argPos ]
211+ argPos ++
212+
213+ if arg == nil {
214+ buf = append (buf , "NULL" ... )
215+ continue
216+ }
217+
218+ switch v := arg .(type ) {
219+ case int64 :
220+ buf = strconv .AppendInt (buf , v , 10 )
221+ case float64 :
222+ buf = strconv .AppendFloat (buf , v , 'g' , - 1 , 64 )
223+ case bool :
224+ if v {
225+ buf = append (buf , '1' )
226+ } else {
227+ buf = append (buf , '0' )
228+ }
229+ case time.Time :
230+ if v .IsZero () {
231+ buf = append (buf , "'0000-00-00'" ... )
232+ } else {
233+ v := v .In (mc .cfg .loc )
234+ v = v .Add (time .Nanosecond * 500 ) // To round under microsecond
235+ year := v .Year ()
236+ year100 := year / 100
237+ year1 := year % 100
238+ month := v .Month ()
239+ day := v .Day ()
240+ hour := v .Hour ()
241+ minute := v .Minute ()
242+ second := v .Second ()
243+ micro := v .Nanosecond () / 1000
244+
245+ buf = append (buf , []byte {
246+ '\'' ,
247+ digits10 [year100 ], digits01 [year100 ],
248+ digits10 [year1 ], digits01 [year1 ],
249+ '-' ,
250+ digits10 [month ], digits01 [month ],
251+ '-' ,
252+ digits10 [day ], digits01 [day ],
253+ ' ' ,
254+ digits10 [hour ], digits01 [hour ],
255+ ':' ,
256+ digits10 [minute ], digits01 [minute ],
257+ ':' ,
258+ digits10 [second ], digits01 [second ],
259+ }... )
260+
261+ if micro != 0 {
262+ micro10000 := micro / 10000
263+ micro100 := micro / 100 % 100
264+ micro1 := micro % 100
265+ buf = append (buf , []byte {
266+ '.' ,
267+ digits10 [micro10000 ], digits01 [micro10000 ],
268+ digits10 [micro100 ], digits01 [micro100 ],
269+ digits10 [micro1 ], digits01 [micro1 ],
270+ }... )
271+ }
272+ buf = append (buf , '\'' )
273+ }
274+ case []byte :
275+ if v == nil {
276+ buf = append (buf , "NULL" ... )
277+ } else {
278+ buf = append (buf , '\'' )
279+ if mc .status & statusNoBackslashEscapes == 0 {
280+ buf = escapeBytesBackslash (buf , v )
281+ } else {
282+ buf = escapeBytesQuotes (buf , v )
283+ }
284+ buf = append (buf , '\'' )
285+ }
286+ case string :
287+ buf = append (buf , '\'' )
288+ if mc .status & statusNoBackslashEscapes == 0 {
289+ buf = escapeStringBackslash (buf , v )
290+ } else {
291+ buf = escapeStringQuotes (buf , v )
292+ }
293+ buf = append (buf , '\'' )
294+ default :
295+ return "" , driver .ErrSkip
296+ }
297+
298+ if len (buf )+ 4 > mc .maxPacketAllowed {
299+ return "" , driver .ErrSkip
300+ }
301+ }
302+ if argPos != len (args ) {
303+ return "" , driver .ErrSkip
304+ }
305+ return string (buf ), nil
306+ }
307+
165308func (mc * mysqlConn ) Exec (query string , args []driver.Value ) (driver.Result , error ) {
166309 if mc .netConn == nil {
167310 errLog .Print (ErrInvalidConn )
168311 return nil , driver .ErrBadConn
169312 }
170- if len (args ) == 0 { // no args, fastpath
171- mc .affectedRows = 0
172- mc .insertId = 0
173-
174- err := mc .exec (query )
175- if err == nil {
176- return & mysqlResult {
177- affectedRows : int64 (mc .affectedRows ),
178- insertId : int64 (mc .insertId ),
179- }, err
313+ if len (args ) != 0 {
314+ if ! mc .cfg .interpolateParams {
315+ return nil , driver .ErrSkip
180316 }
181- return nil , err
317+ // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
318+ prepared , err := mc .interpolateParams (query , args )
319+ if err != nil {
320+ return nil , err
321+ }
322+ query = prepared
323+ args = nil
182324 }
325+ mc .affectedRows = 0
326+ mc .insertId = 0
183327
184- // with args, must use prepared stmt
185- return nil , driver .ErrSkip
186-
328+ err := mc .exec (query )
329+ if err == nil {
330+ return & mysqlResult {
331+ affectedRows : int64 (mc .affectedRows ),
332+ insertId : int64 (mc .insertId ),
333+ }, err
334+ }
335+ return nil , err
187336}
188337
189338// Internal function to execute commands
@@ -212,31 +361,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
212361 errLog .Print (ErrInvalidConn )
213362 return nil , driver .ErrBadConn
214363 }
215- if len (args ) == 0 { // no args, fastpath
216- // Send command
217- err := mc .writeCommandPacketStr (comQuery , query )
364+ if len (args ) != 0 {
365+ if ! mc .cfg .interpolateParams {
366+ return nil , driver .ErrSkip
367+ }
368+ // try client-side prepare to reduce roundtrip
369+ prepared , err := mc .interpolateParams (query , args )
370+ if err != nil {
371+ return nil , err
372+ }
373+ query = prepared
374+ args = nil
375+ }
376+ // Send command
377+ err := mc .writeCommandPacketStr (comQuery , query )
378+ if err == nil {
379+ // Read Result
380+ var resLen int
381+ resLen , err = mc .readResultSetHeaderPacket ()
218382 if err == nil {
219- // Read Result
220- var resLen int
221- resLen , err = mc .readResultSetHeaderPacket ()
222- if err == nil {
223- rows := new (textRows )
224- rows .mc = mc
225-
226- if resLen == 0 {
227- // no columns, no more data
228- return emptyRows {}, nil
229- }
230- // Columns
231- rows .columns , err = mc .readColumns (resLen )
232- return rows , err
383+ rows := new (textRows )
384+ rows .mc = mc
385+
386+ if resLen == 0 {
387+ // no columns, no more data
388+ return emptyRows {}, nil
233389 }
390+ // Columns
391+ rows .columns , err = mc .readColumns (resLen )
392+ return rows , err
234393 }
235- return nil , err
236394 }
237-
238- // with args, must use prepared stmt
239- return nil , driver .ErrSkip
395+ return nil , err
240396}
241397
242398// Gets the value of the given MySQL System Variable
0 commit comments