From d828c33aee098db3261708574c5f346137cd9519 Mon Sep 17 00:00:00 2001 From: Martin Alderete Date: Sun, 27 Aug 2023 00:15:27 +0200 Subject: [PATCH] Improved heartbeat collector Added argument's placeholders rather than string concatenation when using query data with non-trusted arguments (CLI inputs) Signed-off-by: Martin Alderete --- collector/heartbeat.go | 6 ++---- collector/heartbeat_test.go | 11 ++++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/collector/heartbeat.go b/collector/heartbeat.go index 7bc5fc58..5f626d85 100644 --- a/collector/heartbeat.go +++ b/collector/heartbeat.go @@ -18,7 +18,6 @@ package collector import ( "context" "database/sql" - "fmt" "strconv" "github.com/alecthomas/kingpin/v2" @@ -33,7 +32,7 @@ const ( // timestamps. %s will be replaced by the database and table name. // The second column allows gets the server timestamp at the exact same // time the query is run. - heartbeatQuery = "SELECT UNIX_TIMESTAMP(ts), UNIX_TIMESTAMP(%s), server_id from `%s`.`%s`" + heartbeatQuery = "SELECT UNIX_TIMESTAMP(ts), UNIX_TIMESTAMP(?), server_id from ?.?" ) var ( @@ -101,8 +100,7 @@ func nowExpr() string { // Scrape collects data from database connection and sends it over channel as prometheus metric. func (ScrapeHeartbeat) Scrape(ctx context.Context, db *sql.DB, ch chan<- prometheus.Metric, logger log.Logger) error { - query := fmt.Sprintf(heartbeatQuery, nowExpr(), *collectHeartbeatDatabase, *collectHeartbeatTable) - heartbeatRows, err := db.QueryContext(ctx, query) + heartbeatRows, err := db.QueryContext(ctx, heartbeatQuery, nowExpr(), *collectHeartbeatDatabase, *collectHeartbeatTable) if err != nil { return err } diff --git a/collector/heartbeat_test.go b/collector/heartbeat_test.go index 48e35925..8acbdc72 100644 --- a/collector/heartbeat_test.go +++ b/collector/heartbeat_test.go @@ -15,6 +15,7 @@ package collector import ( "context" + "database/sql/driver" "fmt" "testing" @@ -29,17 +30,17 @@ import ( type ScrapeHeartbeatTestCase struct { Args []string Columns []string - Query string + SQLArgs []driver.Value } var ScrapeHeartbeatTestCases = []ScrapeHeartbeatTestCase{ { []string{ "--collect.heartbeat.database", "heartbeat-test", - "--collect.heartbeat.table", "heartbeat-test", + "--collect.heartbeat.table", "heartbeat-test' OR 1=1;", }, []string{"UNIX_TIMESTAMP(ts)", "UNIX_TIMESTAMP(NOW(6))", "server_id"}, - "SELECT UNIX_TIMESTAMP(ts), UNIX_TIMESTAMP(NOW(6)), server_id from `heartbeat-test`.`heartbeat-test`", + []driver.Value{"NOW(6)", "heartbeat-test", "heartbeat-test' OR 1=1;"}, }, { []string{ @@ -48,7 +49,7 @@ var ScrapeHeartbeatTestCases = []ScrapeHeartbeatTestCase{ "--collect.heartbeat.utc", }, []string{"UNIX_TIMESTAMP(ts)", "UNIX_TIMESTAMP(UTC_TIMESTAMP(6))", "server_id"}, - "SELECT UNIX_TIMESTAMP(ts), UNIX_TIMESTAMP(UTC_TIMESTAMP(6)), server_id from `heartbeat-test`.`heartbeat-test`", + []driver.Value{"UTC_TIMESTAMP(6)", "heartbeat-test", "heartbeat-test"}, }, } @@ -68,7 +69,7 @@ func TestScrapeHeartbeat(t *testing.T) { rows := sqlmock.NewRows(tt.Columns). AddRow("1487597613.001320", "1487598113.448042", 1) - mock.ExpectQuery(sanitizeQuery(tt.Query)).WillReturnRows(rows) + mock.ExpectQuery("SELECT").WithArgs(tt.SQLArgs...).WillReturnRows(rows) ch := make(chan prometheus.Metric) go func() {