From 12302387f4d0458760bed5381a024f715d2e32ef Mon Sep 17 00:00:00 2001 From: Saransh Gupta Date: Mon, 30 Mar 2026 17:39:20 -0700 Subject: [PATCH] feat: add tls verification to mysql dsn - Update mysql.template.json with new TLS options - Handle custom TLS configuration - Add tests --- mysql.template.json | 4 +++- utils/utils.go | 53 +++++++++++++++++++++++++++++++++++++-------- utils/utils_test.go | 47 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 93 insertions(+), 11 deletions(-) diff --git a/mysql.template.json b/mysql.template.json index 9cf36a21..b7cd0590 100644 --- a/mysql.template.json +++ b/mysql.template.json @@ -2,5 +2,7 @@ "username": "user", "password": "password", "server": "localhost", - "database": "mydb" + "database": "mydb", + "tls": false, + "caCertPath": "/path/to/CA/certificate" } diff --git a/utils/utils.go b/utils/utils.go index 8556cf3f..faff8ee5 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -2,18 +2,22 @@ package utils import ( "bufio" + "crypto/tls" + "crypto/x509" "database/sql" "encoding/json" "errors" "fmt" - "github.com/rs/zerolog" - "golang.org/x/sys/unix" "io" "os" "path/filepath" "pbench/log" "reflect" "strings" + + "github.com/go-sql-driver/mysql" + "github.com/rs/zerolog" + "golang.org/x/sys/unix" ) const ( @@ -64,6 +68,23 @@ func InitLogFile(logPath string) (finalizer func()) { } } +func createTLSConfig(caCertPath string) (*tls.Config, error) { + rootCertPool := x509.NewCertPool() + pem, err := os.ReadFile(caCertPath) + if err != nil { + log.Error().Err(err).Msg("failed to read CA certificate") + return nil, err + } + if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { + log.Error().Msg("failed to append CA certificate") + return nil, fmt.Errorf("failed to append CA certificate from PEM") + } + return &tls.Config{ + MinVersion: tls.VersionTLS13, // Explicitly setting TLS version to 1.3 to remediate vulnerability CWE-327 + RootCAs: rootCertPool, + }, nil +} + func InitMySQLConnFromCfg(cfgPath string) *sql.DB { if cfgPath == "" { return nil @@ -73,17 +94,31 @@ func InitMySQLConnFromCfg(cfgPath string) *sql.DB { return nil } else { mySQLCfg := &struct { - Username string `json:"username"` - Password string `json:"password"` - Server string `json:"server"` - Database string `json:"database"` - }{} + Username string `json:"username"` + Password string `json:"password"` + Server string `json:"server"` + Database string `json:"database"` + TLS bool `json:"tls"` + CaCertPath string `json:"caCertPath"` + }{ + TLS: false, + } if err := json.Unmarshal(cfgBytes, mySQLCfg); err != nil { log.Error().Err(err).Msg("failed to unmarshal MySQL connection config for the run recorder") return nil } - if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true", - mySQLCfg.Username, mySQLCfg.Password, mySQLCfg.Server, mySQLCfg.Database)); err != nil { + tlsType := "false" + if mySQLCfg.TLS { + tlsType = "custom" + tlsConfig, err := createTLSConfig(mySQLCfg.CaCertPath) + if err != nil { + log.Error().Msg("TLS enabled but failed to load certificates") + return nil + } + mysql.RegisterTLSConfig(tlsType, tlsConfig) + } + if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s&parseTime=true", + mySQLCfg.Username, mySQLCfg.Password, mySQLCfg.Server, mySQLCfg.Database, tlsType)); err != nil { log.Error().Err(err).Msg("failed to initialize MySQL connection for the run recorder") return nil } else if err = db.Ping(); err != nil { diff --git a/utils/utils_test.go b/utils/utils_test.go index 66158419..882f8d07 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -1,10 +1,12 @@ package utils import ( - "github.com/stretchr/testify/assert" + "encoding/json" "os" "path/filepath" "testing" + + "github.com/stretchr/testify/assert" ) func TestExpandHomeDirectory(t *testing.T) { @@ -25,3 +27,46 @@ func TestExpandHomeDirectory_JustTilde(t *testing.T) { ExpandHomeDirectory(&path) assert.Equal(t, os.Getenv("HOME"), path) } + +func TestCreateTLSConfig_InvalidCAPath(t *testing.T) { + // tests error handling when CA cert file doesn't exist + tlsConfig, err := createTLSConfig("/nonexistent/ca.pem") + + assert.Error(t, err, "should return error for non-existent CA certificate") + assert.Nil(t, tlsConfig, "should return nil config on error") +} + +func TestCreateTLSConfig_InvalidCAPEM(t *testing.T) { + // tests error handling when CA cert has invalid PEM content + tmpDir := t.TempDir() + caPath := filepath.Join(tmpDir, "invalid-ca.pem") + err := os.WriteFile(caPath, []byte("invalid pem content"), 0644) + assert.NoError(t, err) + + tlsConfig, err := createTLSConfig(caPath) + + assert.Error(t, err, "should return error for invalid PEM content") + assert.Nil(t, tlsConfig, "should return nil config on error") +} + +func TestInitMySQLConnFromCfg_TLSEnabledInvalidCerts(t *testing.T) { + // When TLS is enabled but certificates are invalid, function should return nil early + config := map[string]interface{}{ + "username": "testuser", + "password": "testpass", + "server": "localhost:3306", + "database": "testdb", + "tls": true, + "caCertPath": "/nonexistent/ca.pem", + } + + tmpDir := t.TempDir() + cfgPath := filepath.Join(tmpDir, "config.json") + configJSON, err := json.Marshal(config) + assert.NoError(t, err) + err = os.WriteFile(cfgPath, configJSON, 0644) + assert.NoError(t, err) + + db := InitMySQLConnFromCfg(cfgPath) + assert.Nil(t, db, "should return nil when TLS config fails") +}