55 "encoding/json"
66 "fmt"
77 "math"
8+ "runtime"
89
910 . "github.com/go-mysql-org/go-mysql/mysql"
1011 "github.com/go-mysql-org/go-mysql/utils"
@@ -56,18 +57,34 @@ func (s *Stmt) Close() error {
5657 return nil
5758}
5859
60+ // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
5961func (s * Stmt ) write (args ... interface {}) error {
62+ defer clear (s .conn .queryAttributes )
6063 paramsNum := s .params
6164
6265 if len (args ) != paramsNum {
6366 return fmt .Errorf ("argument mismatch, need %d but got %d" , s .params , len (args ))
6467 }
6568
66- paramTypes := make ([]byte , paramsNum << 1 )
67- paramValues := make ([][]byte , paramsNum )
69+ if (s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 ) && (s .conn .includeLine >= 0 ) {
70+ _ , file , line , ok := runtime .Caller (s .conn .includeLine )
71+ if ok {
72+ lineAttr := QueryAttribute {
73+ Name : "_line" ,
74+ Value : fmt .Sprintf ("%s:%d" , file , line ),
75+ }
76+ s .conn .queryAttributes = append (s .conn .queryAttributes , lineAttr )
77+ }
78+ }
79+
80+ qaLen := len (s .conn .queryAttributes )
81+ paramTypes := make ([][]byte , paramsNum + qaLen )
82+ paramFlags := make ([][]byte , paramsNum + qaLen )
83+ paramValues := make ([][]byte , paramsNum + qaLen )
84+ paramNames := make ([][]byte , paramsNum + qaLen )
6885
6986 //NULL-bitmap, length: (num-params+7)
70- nullBitmap := make ([]byte , (paramsNum + 7 )>> 3 )
87+ nullBitmap := make ([]byte , (paramsNum + qaLen + 7 )>> 3 )
7188
7289 length := 1 + 4 + 1 + 4 + ((paramsNum + 7 ) >> 3 ) + 1 + (paramsNum << 1 )
7390
@@ -76,76 +93,89 @@ func (s *Stmt) write(args ...interface{}) error {
7693 for i := range args {
7794 if args [i ] == nil {
7895 nullBitmap [i / 8 ] |= 1 << (uint (i ) % 8 )
79- paramTypes [i << 1 ] = MYSQL_TYPE_NULL
96+ paramTypes [i ] = []byte {MYSQL_TYPE_NULL }
97+ paramNames [i ] = []byte {0 } // length encoded, no name
98+ paramFlags [i ] = []byte {0 }
8099 continue
81100 }
82101
83102 newParamBoundFlag = 1
84103
85104 switch v := args [i ].(type ) {
86105 case int8 :
87- paramTypes [i << 1 ] = MYSQL_TYPE_TINY
106+ paramTypes [i ] = [] byte { MYSQL_TYPE_TINY }
88107 paramValues [i ] = []byte {byte (v )}
89108 case int16 :
90- paramTypes [i << 1 ] = MYSQL_TYPE_SHORT
109+ paramTypes [i ] = [] byte { MYSQL_TYPE_SHORT }
91110 paramValues [i ] = Uint16ToBytes (uint16 (v ))
92111 case int32 :
93- paramTypes [i << 1 ] = MYSQL_TYPE_LONG
112+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONG }
94113 paramValues [i ] = Uint32ToBytes (uint32 (v ))
95114 case int :
96- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
115+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
97116 paramValues [i ] = Uint64ToBytes (uint64 (v ))
98117 case int64 :
99- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
118+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
100119 paramValues [i ] = Uint64ToBytes (uint64 (v ))
101120 case uint8 :
102- paramTypes [i << 1 ] = MYSQL_TYPE_TINY
103- paramTypes [( i << 1 ) + 1 ] = 0x80
121+ paramTypes [i ] = [] byte { MYSQL_TYPE_TINY }
122+ paramFlags [ i ] = [] byte { PARAM_UNSIGNED }
104123 paramValues [i ] = []byte {v }
105124 case uint16 :
106- paramTypes [i << 1 ] = MYSQL_TYPE_SHORT
107- paramTypes [( i << 1 ) + 1 ] = 0x80
125+ paramTypes [i ] = [] byte { MYSQL_TYPE_SHORT }
126+ paramFlags [ i ] = [] byte { PARAM_UNSIGNED }
108127 paramValues [i ] = Uint16ToBytes (v )
109128 case uint32 :
110- paramTypes [i << 1 ] = MYSQL_TYPE_LONG
111- paramTypes [( i << 1 ) + 1 ] = 0x80
129+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONG }
130+ paramFlags [ i ] = [] byte { PARAM_UNSIGNED }
112131 paramValues [i ] = Uint32ToBytes (v )
113132 case uint :
114- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
115- paramTypes [( i << 1 ) + 1 ] = 0x80
133+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
134+ paramFlags [ i ] = [] byte { PARAM_UNSIGNED }
116135 paramValues [i ] = Uint64ToBytes (uint64 (v ))
117136 case uint64 :
118- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
119- paramTypes [( i << 1 ) + 1 ] = 0x80
137+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
138+ paramFlags [ i ] = [] byte { PARAM_UNSIGNED }
120139 paramValues [i ] = Uint64ToBytes (v )
121140 case bool :
122- paramTypes [i << 1 ] = MYSQL_TYPE_TINY
141+ paramTypes [i ] = [] byte { MYSQL_TYPE_TINY }
123142 if v {
124143 paramValues [i ] = []byte {1 }
125144 } else {
126145 paramValues [i ] = []byte {0 }
127146 }
128147 case float32 :
129- paramTypes [i << 1 ] = MYSQL_TYPE_FLOAT
148+ paramTypes [i ] = [] byte { MYSQL_TYPE_FLOAT }
130149 paramValues [i ] = Uint32ToBytes (math .Float32bits (v ))
131150 case float64 :
132- paramTypes [i << 1 ] = MYSQL_TYPE_DOUBLE
151+ paramTypes [i ] = [] byte { MYSQL_TYPE_DOUBLE }
133152 paramValues [i ] = Uint64ToBytes (math .Float64bits (v ))
134153 case string :
135- paramTypes [i << 1 ] = MYSQL_TYPE_STRING
154+ paramTypes [i ] = [] byte { MYSQL_TYPE_STRING }
136155 paramValues [i ] = append (PutLengthEncodedInt (uint64 (len (v ))), v ... )
137156 case []byte :
138- paramTypes [i << 1 ] = MYSQL_TYPE_STRING
157+ paramTypes [i ] = [] byte { MYSQL_TYPE_STRING }
139158 paramValues [i ] = append (PutLengthEncodedInt (uint64 (len (v ))), v ... )
140159 case json.RawMessage :
141- paramTypes [i << 1 ] = MYSQL_TYPE_STRING
160+ paramTypes [i ] = [] byte { MYSQL_TYPE_STRING }
142161 paramValues [i ] = append (PutLengthEncodedInt (uint64 (len (v ))), v ... )
143162 default :
144163 return fmt .Errorf ("invalid argument type %T" , args [i ])
145164 }
165+ paramNames [i ] = []byte {0 } // length encoded, no name
166+ if paramFlags [i ] == nil {
167+ paramFlags [i ] = []byte {0 }
168+ }
146169
147170 length += len (paramValues [i ])
148171 }
172+ for i , qa := range s .conn .queryAttributes {
173+ tf := qa .TypeAndFlag ()
174+ paramTypes [(i + paramsNum )] = []byte {tf [0 ]}
175+ paramFlags [i + paramsNum ] = []byte {tf [1 ]}
176+ paramValues [i + paramsNum ] = qa .ValueBytes ()
177+ paramNames [i + paramsNum ] = PutLengthEncodedString ([]byte (qa .Name ))
178+ }
149179
150180 data := utils .BytesBufferGet ()
151181 defer func () {
@@ -159,25 +189,40 @@ func (s *Stmt) write(args ...interface{}) error {
159189 data .WriteByte (COM_STMT_EXECUTE )
160190 data .Write ([]byte {byte (s .id ), byte (s .id >> 8 ), byte (s .id >> 16 ), byte (s .id >> 24 )})
161191
162- //flag: CURSOR_TYPE_NO_CURSOR
163- data .WriteByte (0x00 )
192+ flags := CURSOR_TYPE_NO_CURSOR
193+ if paramsNum > 0 {
194+ flags |= PARAMETER_COUNT_AVAILABLE
195+ }
196+ data .WriteByte (flags )
164197
165198 //iteration-count, always 1
166199 data .Write ([]byte {1 , 0 , 0 , 0 })
167200
168- if s .params > 0 {
169- data .Write (nullBitmap )
170-
171- //new-params-bound-flag
172- data .WriteByte (newParamBoundFlag )
173-
174- if newParamBoundFlag == 1 {
175- //type of each parameter, length: num-params * 2
176- data .Write (paramTypes )
177-
178- //value of each parameter
179- for _ , v := range paramValues {
180- data .Write (v )
201+ if paramsNum > 0 || (s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 && (flags & PARAMETER_COUNT_AVAILABLE > 0 )) {
202+ if s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 {
203+ paramsNum += len (s .conn .queryAttributes )
204+ data .Write (PutLengthEncodedInt (uint64 (paramsNum )))
205+ }
206+ if paramsNum > 0 {
207+ data .Write (nullBitmap )
208+
209+ //new-params-bound-flag
210+ data .WriteByte (newParamBoundFlag )
211+
212+ if newParamBoundFlag == 1 {
213+ for i := 0 ; i < paramsNum ; i ++ {
214+ data .Write (paramTypes [i ])
215+ data .Write (paramFlags [i ])
216+
217+ if s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 {
218+ data .Write (paramNames [i ])
219+ }
220+ }
221+
222+ //value of each parameter
223+ for _ , v := range paramValues {
224+ data .Write (v )
225+ }
181226 }
182227 }
183228 }
0 commit comments