@@ -118,8 +118,9 @@ func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) {
118118 }
119119 defer db .Close ()
120120
121+ db .Exec ("DROP TABLE IF EXISTS test" )
122+
121123 dbt := & DBTest {t , db }
122- dbt .db .Exec ("DROP TABLE IF EXISTS test" )
123124 for _ , test := range tests {
124125 test (dbt )
125126 dbt .db .Exec ("DROP TABLE IF EXISTS test" )
@@ -743,46 +744,62 @@ func TestStrict(t *testing.T) {
743744 runTests (t , "TestStrict" , func (dbt * DBTest ) {
744745 dbt .mustExec ("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))" )
745746
746- queries := [... ][2 ]string {
747- {"DROP TABLE IF EXISTS no_such_table" , "Note 1051: Unknown table 'no_such_table'" },
748- {"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')" ,
749- "Warning 1265: Data truncated for column 'b' at row 1\r \n " +
750- "Warning 1048: Column 'a' cannot be null\r \n " +
751- "Warning 1264: Out of range value for column 'a' at row 3\r \n " +
752- "Warning 1265: Data truncated for column 'b' at row 3" ,
753- },
747+ var queries = [... ]struct {
748+ in string
749+ codes []string
750+ }{
751+ {"DROP TABLE IF EXISTS no_such_table" , []string {"1051" }},
752+ {"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')" , []string {"1265" , "1048" , "1264" , "1265" }},
754753 }
755754 var err error
756755
757- // text protocol
758- for i := range queries {
759- _ , err = dbt .db .Exec (queries [i ][0 ])
756+ var checkWarnings = func (err error , mode string , idx int ) {
760757 if err == nil {
761- dbt .Errorf ("Expecteded strict error on query [text] %s" , queries [i ][0 ])
762- } else if err .Error () != queries [i ][1 ] {
763- dbt .Errorf ("Unexpected error on query [text] %s: %s != %s" , queries [i ][0 ], err .Error (), queries [i ][1 ])
758+ dbt .Errorf ("Expected STRICT error on query [%s] %s" , mode , queries [idx ].in )
759+ }
760+
761+ if warnings , ok := err .(MySQLWarnings ); ok {
762+ var codes = make ([]string , len (warnings ))
763+ for i := range warnings {
764+ codes [i ] = warnings [i ].Code
765+ }
766+ if len (codes ) != len (queries [idx ].codes ) {
767+ dbt .Errorf ("Unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v" , mode , queries [idx ].in , queries [idx ].codes , codes )
768+ }
769+
770+ for i := range warnings {
771+ if codes [i ] != queries [idx ].codes [i ] {
772+ dbt .Errorf ("Unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v" , mode , queries [idx ].in , queries [idx ].codes , codes )
773+ return
774+ }
775+ }
776+
777+ } else {
778+ dbt .Errorf ("Unexpected error on query [%s] %s: %s" , mode , queries [idx ].in , err .Error ())
764779 }
765780 }
766781
782+ // text protocol
783+ for i := range queries {
784+ _ , err = dbt .db .Exec (queries [i ].in )
785+ checkWarnings (err , "text" , i )
786+ }
787+
767788 var stmt * sql.Stmt
768789
769790 // binary protocol
770791 for i := range queries {
771- stmt , err = dbt .db .Prepare (queries [i ][ 0 ] )
792+ stmt , err = dbt .db .Prepare (queries [i ]. in )
772793 if err != nil {
773- dbt .Error ("Error on preparing query %: " , queries [i ][ 0 ] , err .Error ())
794+ dbt .Error ("Error on preparing query %: " , queries [i ]. in , err .Error ())
774795 }
775796
776797 _ , err = stmt .Exec ()
777- if err == nil {
778- dbt .Errorf ("Expecteded strict error on query [binary] %s" , queries [i ][0 ])
779- } else if err .Error () != queries [i ][1 ] {
780- dbt .Errorf ("Unexpected error on query [binary] %s: %s != %s" , queries [i ][0 ], err .Error (), queries [i ][1 ])
781- }
798+ checkWarnings (err , "binary" , i )
782799
783800 err = stmt .Close ()
784801 if err != nil {
785- dbt .Error ("Error on closing stmt for query %: " , queries [i ][ 0 ] , err .Error ())
802+ dbt .Error ("Error on closing stmt for query %: " , queries [i ]. in , err .Error ())
786803 }
787804 }
788805 })
0 commit comments