Skip to content

Commit bd0b6e8

Browse files
authored
Merge pull request #2224 from actiontech/fis-issue2175
use 1 instead select fields
2 parents 63159d6 + b398e3e commit bd0b6e8

File tree

2 files changed

+83
-1
lines changed

2 files changed

+83
-1
lines changed

sqle/driver/mysql/audit_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6376,6 +6376,62 @@ func TestDMLCheckSelectRows(t *testing.T) {
63766376
WillReturnRows(sqlmock.NewRows([]string{"COUNT(1)"}).AddRow("100000000"))
63776377
runSingleRuleInspectCase(rule, t, "", inspect6, "select * from exist_tb_2 where user_id in (select v3 from exist_tb_3)", newTestResult().addResult(rulepkg.DMLCheckSelectRows))
63786378

6379+
inspect7 := NewMockInspect(e)
6380+
handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 limit 10, 10")).
6381+
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
6382+
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` LIMIT 10,10) as t")).
6383+
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("100000000"))
6384+
runSingleRuleInspectCase(rule, t, "", inspect7, "select id, v1 as id from exist_tb_2 limit 10, 10", newTestResult().addResult(rulepkg.DMLCheckSelectRows))
6385+
6386+
inspect8 := NewMockInspect(e)
6387+
handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 group by id, v1")).
6388+
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
6389+
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`,`v1`) as t")).
6390+
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("100000000"))
6391+
runSingleRuleInspectCase(rule, t, "", inspect8, "select id, v1 as id from exist_tb_2 group by id, v1", newTestResult().addResult(rulepkg.DMLCheckSelectRows))
6392+
6393+
inspect9 := NewMockInspect(e)
6394+
handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 limit 10, 10")).
6395+
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
6396+
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` LIMIT 10,10) as t")).
6397+
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10"))
6398+
runSingleRuleInspectCase(rule, t, "", inspect9, "select id, v1 as id from exist_tb_2 limit 10, 10", newTestResult())
6399+
6400+
inspect10 := NewMockInspect(e)
6401+
handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 group by id, v1")).
6402+
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
6403+
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`,`v1`) as t")).
6404+
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10"))
6405+
runSingleRuleInspectCase(rule, t, "", inspect10, "select id, v1 as id from exist_tb_2 group by id, v1", newTestResult())
6406+
6407+
inspect11 := NewMockInspect(e)
6408+
handler.ExpectQuery(regexp.QuoteMeta("select max(v1) from exist_tb_2 group by id")).
6409+
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
6410+
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
6411+
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10"))
6412+
runSingleRuleInspectCase(rule, t, "", inspect11, "select max(v1) from exist_tb_2 group by id", newTestResult())
6413+
6414+
inspect12 := NewMockInspect(e)
6415+
handler.ExpectQuery(regexp.QuoteMeta("select max(v1) from exist_tb_2 group by id")).
6416+
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
6417+
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
6418+
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10000000"))
6419+
runSingleRuleInspectCase(rule, t, "", inspect12, "select max(v1) from exist_tb_2 group by id", newTestResult().addResult(rulepkg.DMLCheckSelectRows))
6420+
6421+
inspect13 := NewMockInspect(e)
6422+
handler.ExpectQuery(regexp.QuoteMeta("select max(v1) as id, id from exist_tb_2 group by id")).
6423+
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
6424+
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
6425+
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10"))
6426+
runSingleRuleInspectCase(rule, t, "", inspect13, "select max(v1) as id, id from exist_tb_2 group by id", newTestResult())
6427+
6428+
inspect14 := NewMockInspect(e)
6429+
handler.ExpectQuery(regexp.QuoteMeta("select max(v1) as id, id from exist_tb_2 group by id")).
6430+
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
6431+
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
6432+
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10000000"))
6433+
runSingleRuleInspectCase(rule, t, "", inspect14, "select max(v1) as id, id from exist_tb_2 group by id", newTestResult().addResult(rulepkg.DMLCheckSelectRows))
6434+
63796435
}
63806436

63816437
func TestDMLCheckScanRows(t *testing.T) {

sqle/driver/mysql/util/util.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,15 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe
6868
// 2. SELECT COUNT(1) FROM test LIMIT 10,10 类型的SQL结果集为空
6969
// 已上两种情况,使用子查询 select count(*) from (输入的sql) as t的方式来获取影响行数
7070
if cannotConvert {
71+
// 将select语句中的查询字段替换为数字1
72+
// https://github.com/actiontech/sqle/issues/2175
73+
newSql, err := useIntReplaceSelectFields(node)
74+
if err != nil {
75+
log.NewEntry().Errorf("replace select fields failed, err: %v", err)
76+
newSql = originSql
77+
}
7178
// 移除后缀分号,避免sql语法错误
72-
trimSuffix := strings.TrimRight(originSql, ";")
79+
trimSuffix := strings.TrimRight(newSql, ";")
7380
affectRowSql = fmt.Sprintf("select count(*) from (%s) as t", trimSuffix)
7481
} else {
7582
sqlBuilder := new(strings.Builder)
@@ -112,6 +119,25 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe
112119
return affectCount, nil
113120
}
114121

122+
func useIntReplaceSelectFields(node ast.StmtNode) (string, error) {
123+
stmt, ok := node.(*ast.SelectStmt)
124+
if !ok {
125+
return "", errors.New("pass parameter is not select node")
126+
}
127+
newValue := &driver.ValueExpr{}
128+
newValue.SetInt64(1)
129+
selectFields := &ast.SelectField{Expr: newValue}
130+
stmt.Fields.Fields = []*ast.SelectField{selectFields}
131+
132+
sqlBuilder := new(strings.Builder)
133+
err := stmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, sqlBuilder))
134+
if err != nil {
135+
return "", err
136+
}
137+
affectRowSql := sqlBuilder.String()
138+
return affectRowSql, nil
139+
}
140+
115141
func getSelectNodeFromDelete(stmt *ast.DeleteStmt) *ast.SelectStmt {
116142
newSelect := newSelectWithCount()
117143

0 commit comments

Comments
 (0)