Skip to content

Commit

Permalink
Merge pull request #140 from ardens-jw/master
Browse files Browse the repository at this point in the history
Feature: Provide support for RDS MySQL IAM Authentication
  • Loading branch information
dewey authored Sep 26, 2024
2 parents 62daa91 + 512daaa commit 3f78c6b
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 15 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ For some database backends some special functionality is available:
which will use the equivalent of `rds generate-db-auth-token`
for the password. For this driver, the `AWS_REGION` environment variable
must be set.
* rds-mysql: This type of URL expects a working AWS configuration
which will use the equivalent of `rds generate-db-auth-token`
for the password. For this driver, the `AWS_REGION` environment variable
must be set.
Why this exporter exists
========================
Expand Down
13 changes: 7 additions & 6 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,13 @@ type Job struct {
}

type connection struct {
conn *sqlx.DB
url string
driver string
host string
database string
user string
conn *sqlx.DB
url string
driver string
host string
database string
user string
tokenExpirationTime time.Time
}

// Query is an SQL query that is executed on a connection
Expand Down
102 changes: 93 additions & 9 deletions job.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,27 @@ var (
CloudSQLPrefix = "cloudsql+"
)

func handleRDSMySQLIAMAuth(conn string) (string, time.Time, error) {
dsn := strings.TrimPrefix(conn, "rds-mysql://")
config, err := mysql.ParseDSN(dsn)
if err != nil {
return "", time.Time{}, fmt.Errorf("failed to parse MySQL DSN: %v", err)
}

sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))

token, err := rdsutils.BuildAuthToken(config.Addr, os.Getenv("AWS_REGION"), config.User, sess.Config.Credentials)
if err != nil {
return "", time.Time{}, fmt.Errorf("failed to build RDS auth token: %v", err)
}

expirationTime := time.Now().Add(14 * time.Minute)

return token, expirationTime, nil
}

// Init will initialize the metric descriptors
func (j *Job) Init(logger log.Logger, queries map[string]string) error {
j.log = log.With(logger, "job", j.Name)
Expand Down Expand Up @@ -207,23 +228,53 @@ func (j *Job) updateConnections() {
continue
}

// MySQL DSNs do not parse cleanly as URLs as of Go 1.12.8+
if strings.HasPrefix(conn, "mysql://") {
config, err := mysql.ParseDSN(strings.TrimPrefix(conn, "mysql://"))
// Handle both RDS MySQL and regular MySQL connections
if strings.HasPrefix(conn, "rds-mysql://") || strings.HasPrefix(conn, "mysql://") {
isRDS := strings.HasPrefix(conn, "rds-mysql://")
var dsn string
var expirationTime time.Time

trimmedConn := conn
if isRDS {
trimmedConn = strings.TrimPrefix(conn, "rds-mysql://")
} else {
trimmedConn = strings.TrimPrefix(conn, "mysql://")
}

config, err := mysql.ParseDSN(trimmedConn)
if err != nil {
level.Error(j.log).Log("msg", "Failed to parse MySQL DSN", "url", conn, "err", err)
continue
}

if isRDS {
authToken, tokenExpiration, err := handleRDSMySQLIAMAuth(conn)
if err != nil {
level.Error(j.log).Log("msg", "Failed to build RDS auth token", "url", conn, "err", err)
continue
}
config.Passwd = authToken
config.AllowCleartextPasswords = true
expirationTime = tokenExpiration
}

dsn = config.FormatDSN()
if isRDS {
dsn = "rds-mysql://" + dsn
}

j.conns = append(j.conns, &connection{
conn: nil,
url: conn,
driver: "mysql",
host: config.Addr,
database: config.DBName,
user: config.User,
conn: nil,
url: dsn,
driver: "mysql",
host: config.Addr,
database: config.DBName,
user: config.User,
tokenExpirationTime: expirationTime,
})
continue
}

if strings.HasPrefix(conn, "rds-postgres://") {
// Reuse Postgres driver by stripping "rds-" from connection URL after building the RDS authentication token
conn = strings.TrimPrefix(conn, "rds-")
Expand Down Expand Up @@ -438,12 +489,45 @@ func (j *Job) runOnce() error {
func (c *connection) connect(job *Job) error {
// already connected
if c.conn != nil {
if strings.HasPrefix(c.url, "rds-mysql://") && time.Now().After(c.tokenExpirationTime) {
level.Warn(job.log).Log("msg", "Connection token expired, reconnecting")

authToken, expirationTime, err := handleRDSMySQLIAMAuth(c.url)
if err != nil {
return fmt.Errorf("failed to refresh RDS MySQL IAM Auth token: %w", err)
}

config, err := mysql.ParseDSN(strings.TrimPrefix(c.url, "rds-mysql://"))
if err != nil {
return fmt.Errorf("failed to parse MySQL DSN: %w", err)
}

config.Passwd = authToken
dsn := "rds-mysql://" + config.FormatDSN()

// Close the existing connection
c.conn.Close()
c.conn = nil

// Update the connection details
c.tokenExpirationTime = expirationTime
c.url = dsn

// Connect to the database with the new token
conn, err := sqlx.Connect(c.driver, strings.TrimPrefix(dsn, "rds-mysql://"))
if err != nil {
return fmt.Errorf("failed to connect to the database: %w", err)
}
c.conn = conn
return nil
}
return nil
}
dsn := c.url
switch c.driver {
case "mysql":
dsn = strings.TrimPrefix(dsn, "mysql://")
dsn = strings.TrimPrefix(dsn, "rds-mysql://")
case "clickhouse+tcp", "clickhouse+http": // Support both http and tcp connections
dsn = strings.TrimPrefix(dsn, "clickhouse+")
c.driver = "clickhouse"
Expand Down

0 comments on commit 3f78c6b

Please sign in to comment.