Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mysql.template.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@
"username": "user",
"password": "password",
"server": "localhost",
"database": "mydb"
"database": "mydb",
"tls": false,
"caCertPath": "/path/to/CA/certificate"
}
53 changes: 44 additions & 9 deletions utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@ package utils

import (
"bufio"
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/json"
"errors"
"fmt"
"github.com/rs/zerolog"
"golang.org/x/sys/unix"
"io"
"os"
"path/filepath"
"pbench/log"
"reflect"
"strings"

"github.com/go-sql-driver/mysql"
"github.com/rs/zerolog"
"golang.org/x/sys/unix"
)

const (
Expand Down Expand Up @@ -64,6 +68,23 @@ func InitLogFile(logPath string) (finalizer func()) {
}
}

func createTLSConfig(caCertPath string) (*tls.Config, error) {
rootCertPool := x509.NewCertPool()
pem, err := os.ReadFile(caCertPath)
if err != nil {
log.Error().Err(err).Msg("failed to read CA certificate")
return nil, err
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
log.Error().Msg("failed to append CA certificate")
return nil, fmt.Errorf("failed to append CA certificate from PEM")
}
return &tls.Config{
MinVersion: tls.VersionTLS13, // Explicitly setting TLS version to 1.3 to remediate vulnerability CWE-327
RootCAs: rootCertPool,
}, nil
}

func InitMySQLConnFromCfg(cfgPath string) *sql.DB {
if cfgPath == "" {
return nil
Expand All @@ -73,17 +94,31 @@ func InitMySQLConnFromCfg(cfgPath string) *sql.DB {
return nil
} else {
mySQLCfg := &struct {
Username string `json:"username"`
Password string `json:"password"`
Server string `json:"server"`
Database string `json:"database"`
}{}
Username string `json:"username"`
Password string `json:"password"`
Server string `json:"server"`
Database string `json:"database"`
TLS bool `json:"tls"`
CaCertPath string `json:"caCertPath"`
}{
TLS: false,
}
if err := json.Unmarshal(cfgBytes, mySQLCfg); err != nil {
log.Error().Err(err).Msg("failed to unmarshal MySQL connection config for the run recorder")
return nil
}
if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true",
mySQLCfg.Username, mySQLCfg.Password, mySQLCfg.Server, mySQLCfg.Database)); err != nil {
tlsType := "false"
if mySQLCfg.TLS {
tlsType = "custom"
tlsConfig, err := createTLSConfig(mySQLCfg.CaCertPath)
if err != nil {
log.Error().Msg("TLS enabled but failed to load certificates")
return nil
}
mysql.RegisterTLSConfig(tlsType, tlsConfig)
}
if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s&parseTime=true",
mySQLCfg.Username, mySQLCfg.Password, mySQLCfg.Server, mySQLCfg.Database, tlsType)); err != nil {
log.Error().Err(err).Msg("failed to initialize MySQL connection for the run recorder")
return nil
} else if err = db.Ping(); err != nil {
Expand Down
47 changes: 46 additions & 1 deletion utils/utils_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package utils

import (
"github.com/stretchr/testify/assert"
"encoding/json"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

func TestExpandHomeDirectory(t *testing.T) {
Expand All @@ -25,3 +27,46 @@ func TestExpandHomeDirectory_JustTilde(t *testing.T) {
ExpandHomeDirectory(&path)
assert.Equal(t, os.Getenv("HOME"), path)
}

func TestCreateTLSConfig_InvalidCAPath(t *testing.T) {
// tests error handling when CA cert file doesn't exist
tlsConfig, err := createTLSConfig("/nonexistent/ca.pem")

assert.Error(t, err, "should return error for non-existent CA certificate")
assert.Nil(t, tlsConfig, "should return nil config on error")
}

func TestCreateTLSConfig_InvalidCAPEM(t *testing.T) {
// tests error handling when CA cert has invalid PEM content
tmpDir := t.TempDir()
caPath := filepath.Join(tmpDir, "invalid-ca.pem")
err := os.WriteFile(caPath, []byte("invalid pem content"), 0644)
assert.NoError(t, err)

tlsConfig, err := createTLSConfig(caPath)

assert.Error(t, err, "should return error for invalid PEM content")
assert.Nil(t, tlsConfig, "should return nil config on error")
}

func TestInitMySQLConnFromCfg_TLSEnabledInvalidCerts(t *testing.T) {
// When TLS is enabled but certificates are invalid, function should return nil early
config := map[string]interface{}{
"username": "testuser",
"password": "testpass",
"server": "localhost:3306",
"database": "testdb",
"tls": true,
"caCertPath": "/nonexistent/ca.pem",
}

tmpDir := t.TempDir()
cfgPath := filepath.Join(tmpDir, "config.json")
configJSON, err := json.Marshal(config)
assert.NoError(t, err)
err = os.WriteFile(cfgPath, configJSON, 0644)
assert.NoError(t, err)

db := InitMySQLConnFromCfg(cfgPath)
assert.Nil(t, db, "should return nil when TLS config fails")
}