Skip to content

Commit 32f151a

Browse files
authored
feat: implement whitelist (#119)
### TL;DR Implemented query validation to enhance security and prevent potential SQL injection attacks. ### What changed? - Added a `ValidateQuery` function in `utils.go` to check for disallowed patterns and ensure only allowed functions are used in queries. - Integrated query validation in the `ClickHouseConnector` methods for executing queries. - Updated error handling in `logs_handlers.go` and `transactions_handlers.go` to potentially use `BadRequestError` for disallowed functions. ### How to test? 1. Try running queries with allowed functions (e.g., `sum`, `count`, `reinterpretAsUInt256`) and ensure they work as expected. 2. Attempt to use disallowed patterns or functions in queries and verify that they are rejected with appropriate error messages. 3. Test different types of queries (SELECT, INSERT, UPDATE, etc.) to confirm that only SELECT queries are allowed. ### Why make this change? This change enhances the security of the application by preventing potential SQL injection attacks and restricting the use of potentially harmful functions or query patterns. It ensures that only safe, pre-approved functions can be used in queries, reducing the risk of unauthorized data access or manipulation.
2 parents 6d744ac + 9cb0aac commit 32f151a

File tree

4 files changed

+66
-1
lines changed

4 files changed

+66
-1
lines changed

internal/common/utils.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package common
33
import (
44
"fmt"
55
"math/big"
6+
"regexp"
67
"strings"
78
"unicode"
89
)
@@ -169,3 +170,51 @@ func isType(word string) bool {
169170

170171
return types[word]
171172
}
173+
174+
var allowedFunctions = map[string]struct{}{
175+
"sum": {},
176+
"count": {},
177+
"reinterpretAsUInt256": {},
178+
"reverse": {},
179+
"unhex": {},
180+
"substring": {},
181+
"length": {},
182+
"toUInt256": {},
183+
"if": {},
184+
}
185+
186+
var disallowedPatterns = []string{
187+
`(?i)\b(UNION|INSERT|DELETE|UPDATE|DROP|CREATE|ALTER|TRUNCATE|EXEC|;|--)`,
188+
}
189+
190+
// validateQuery checks the query for disallowed patterns and ensures only allowed functions are used.
191+
func ValidateQuery(query string) error {
192+
// Check for disallowed patterns
193+
for _, pattern := range disallowedPatterns {
194+
matched, err := regexp.MatchString(pattern, query)
195+
if err != nil {
196+
return fmt.Errorf("error checking disallowed patterns: %v", err)
197+
}
198+
if matched {
199+
return fmt.Errorf("query contains disallowed keywords or patterns")
200+
}
201+
}
202+
203+
// Ensure the query is a SELECT statement
204+
trimmedQuery := strings.TrimSpace(strings.ToUpper(query))
205+
if !strings.HasPrefix(trimmedQuery, "SELECT") {
206+
return fmt.Errorf("only SELECT queries are allowed")
207+
}
208+
209+
// Extract function names and validate them
210+
functionPattern := regexp.MustCompile(`(?i)(\b\w+\b)\s*\(`)
211+
matches := functionPattern.FindAllStringSubmatch(query, -1)
212+
for _, match := range matches {
213+
funcName := match[1]
214+
if _, ok := allowedFunctions[funcName]; !ok {
215+
return fmt.Errorf("function '%s' is not allowed", funcName)
216+
}
217+
}
218+
219+
return nil
220+
}

internal/handlers/logs_handlers.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ func handleLogsRequest(c *gin.Context, contractAddress, signature string) {
170170
aggregatesResult, err := mainStorage.GetAggregations("logs", qf)
171171
if err != nil {
172172
log.Error().Err(err).Msg("Error querying aggregates")
173+
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
173174
api.InternalErrorHandler(c)
174175
return
175176
}
@@ -180,6 +181,7 @@ func handleLogsRequest(c *gin.Context, contractAddress, signature string) {
180181
logsResult, err := mainStorage.GetLogs(qf)
181182
if err != nil {
182183
log.Error().Err(err).Msg("Error querying logs")
184+
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
183185
api.InternalErrorHandler(c)
184186
return
185187
}

internal/handlers/transactions_handlers.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ func handleTransactionsRequest(c *gin.Context, contractAddress, signature string
172172
aggregatesResult, err := mainStorage.GetAggregations("transactions", qf)
173173
if err != nil {
174174
log.Error().Err(err).Msg("Error querying aggregates")
175+
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
175176
api.InternalErrorHandler(c)
176177
return
177178
}
@@ -181,7 +182,8 @@ func handleTransactionsRequest(c *gin.Context, contractAddress, signature string
181182
// Retrieve logs data
182183
transactionsResult, err := mainStorage.GetTransactions(qf)
183184
if err != nil {
184-
log.Error().Err(err).Msg("Error querying tran")
185+
log.Error().Err(err).Msg("Error querying transactions")
186+
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
185187
api.InternalErrorHandler(c)
186188
return
187189
}

internal/storage/clickhouse.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,9 @@ func (c *ClickHouseConnector) GetBlocks(qf QueryFilter) (blocks []common.Block,
301301

302302
query += getLimitClause(int(qf.Limit))
303303

304+
if err := common.ValidateQuery(query); err != nil {
305+
return nil, err
306+
}
304307
rows, err := c.conn.Query(context.Background(), query)
305308
if err != nil {
306309
return nil, err
@@ -369,6 +372,9 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
369372
query += fmt.Sprintf(" GROUP BY %s", groupByColumns)
370373
}
371374

375+
if err := common.ValidateQuery(query); err != nil {
376+
return QueryResult[interface{}]{}, err
377+
}
372378
// Execute the query
373379
rows, err := c.conn.Query(context.Background(), query)
374380
if err != nil {
@@ -421,6 +427,9 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
421427
func executeQuery[T any](c *ClickHouseConnector, table, columns string, qf QueryFilter, scanFunc func(driver.Rows) (T, error)) (QueryResult[T], error) {
422428
query := c.buildQuery(table, columns, qf)
423429

430+
if err := common.ValidateQuery(query); err != nil {
431+
return QueryResult[T]{}, err
432+
}
424433
rows, err := c.conn.Query(context.Background(), query)
425434
if err != nil {
426435
return QueryResult[T]{}, err
@@ -856,6 +865,9 @@ func (c *ClickHouseConnector) GetTraces(qf QueryFilter) (traces []common.Trace,
856865

857866
query += getLimitClause(int(qf.Limit))
858867

868+
if err := common.ValidateQuery(query); err != nil {
869+
return nil, err
870+
}
859871
rows, err := c.conn.Query(context.Background(), query)
860872
if err != nil {
861873
return nil, err

0 commit comments

Comments
 (0)