diff --git a/api/v2board/cert.go b/api/v2board/cert.go new file mode 100644 index 00000000..2239fdba --- /dev/null +++ b/api/v2board/cert.go @@ -0,0 +1,55 @@ +package panel + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" +) + +type NodeCertPair struct { + Cert string `json:"cert"` + Key string `json:"key"` +} + +func (c *Client) GetNodeCertPair() (pair *NodeCertPair, changed bool, err error) { + const path = "/api/v1/server/UniProxy/nodecert" + r, err := c.client. + R(). + SetHeader("If-None-Match", c.nodeCertEtag). + ForceContentType("application/json"). + Get(path) + + if err != nil { + return nil, false, err + } + if r == nil { + return nil, false, fmt.Errorf("received nil response") + } + if r.StatusCode() == 304 { + return nil, false, nil + } + hash := sha256.Sum256(r.Body()) + newBodyHash := hex.EncodeToString(hash[:]) + if c.nodeCertBodyHash == newBodyHash { + return nil, false, nil + } + c.nodeCertBodyHash = newBodyHash + c.nodeCertEtag = r.Header().Get("ETag") + + defer func() { + if r.RawBody() != nil { + r.RawBody().Close() + } + }() + + pair = &NodeCertPair{} + err = json.Unmarshal(r.Body(), pair) + if err != nil { + return nil, false, fmt.Errorf("decode node cert error: %s", err) + } + if pair.Cert == "" || pair.Key == "" { + return nil, false, fmt.Errorf("received empty cert or key") + } + return pair, true, nil +} diff --git a/api/v2board/panel.go b/api/v2board/panel.go index 169a62e7..0617383a 100644 --- a/api/v2board/panel.go +++ b/api/v2board/panel.go @@ -21,6 +21,8 @@ type Client struct { NodeId int nodeEtag string userEtag string + nodeCertEtag string + nodeCertBodyHash string responseBodyHash string UserList *UserListBody AliveMap *AliveMap @@ -29,7 +31,7 @@ type Client struct { func New(c *conf.NodeConfig) (*Client, error) { client := resty.New() client.SetRetryCount(3) - client.SetHeader("User-Agent", fmt.Sprintf("v2node go-resty/%s (https://github.com/go-resty/resty)", resty.Version)) + client.SetHeader("User-Agent", fmt.Sprintf("v2node go-resty/%s (https://github.com/go-resty/resty)", resty.Version)) if c.Timeout > 0 { client.SetTimeout(time.Duration(c.Timeout) * time.Second) } else { diff --git a/node/cert.go b/node/cert.go index b55cdc25..9198787b 100644 --- a/node/cert.go +++ b/node/cert.go @@ -1,8 +1,10 @@ package node import ( + "bytes" "crypto/rand" "crypto/rsa" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" @@ -12,7 +14,9 @@ import ( "time" log "github.com/sirupsen/logrus" + panel "github.com/wyx2685/v2node/api/v2board" "github.com/wyx2685/v2node/common/file" + "github.com/wyx2685/v2node/common/task" ) func (c *Controller) renewCertTask() error { @@ -29,6 +33,70 @@ func (c *Controller) renewCertTask() error { return nil } +func (c *Controller) syncOnlineCertTask() error { + cert := c.info.Common.CertInfo + certPEM, keyPEM, changed, err := c.fetchOnlineCertPair() + if err != nil { + log.WithFields(log.Fields{ + "tag": c.tag, + "err": err, + }).Info("sync online cert error") + return nil + } + if !changed { + return nil + } + currentCert, currentKey, err := loadCertPair(cert.CertFile, cert.KeyFile) + if err == nil && bytes.Equal(currentCert, certPEM) && bytes.Equal(currentKey, keyPEM) { + return nil + } + if err := writeCertPair(cert.CertFile, cert.KeyFile, certPEM, keyPEM); err != nil { + log.WithFields(log.Fields{ + "tag": c.tag, + "err": err, + }).Info("write online cert error") + return nil + } + log.WithField("tag", c.tag).Info("Online cert updated") + return nil +} + +func (c *Controller) startCertTask(node *panel.NodeInfo) { + if node.Security != panel.Tls { + return + } + switch c.info.Common.CertInfo.CertMode { + case "none", "", "file", "self": + case "dns", "http": + c.renewCertPeriodic = &task.Task{ + Name: "renewCertTask", + Interval: time.Hour * 24, + Execute: c.renewCertTask, + Reload: c.reloadTask, + } + log.WithField("tag", c.tag).Info("Start renew cert") + _ = c.renewCertPeriodic.Start(true) + case "online": + interval := node.PullInterval * 60 + if interval <= 0 { + interval = time.Hour + } + c.renewCertPeriodic = &task.Task{ + Name: "syncOnlineCertTask", + Interval: interval, + Execute: c.syncOnlineCertTask, + Reload: c.reloadTask, + } + log.WithField("tag", c.tag).Info("Start sync online cert") + _ = c.renewCertPeriodic.Start(false) + default: + log.WithFields(log.Fields{ + "tag": c.tag, + "certmode": c.info.Common.CertInfo.CertMode, + }).Warn("Skip unknown cert task mode") + } +} + func (c *Controller) requestCert() error { cert := c.info.Common.CertInfo switch cert.CertMode { @@ -66,12 +134,122 @@ func (c *Controller) requestCert() error { if err != nil { return fmt.Errorf("generate self cert error: %s", err) } + case "online": + if cert.CertFile == "" || cert.KeyFile == "" { + return fmt.Errorf("cert file path or key file path not exist") + } + certPEM, keyPEM, changed, err := c.fetchOnlineCertPair() + if err != nil { + if localErr := validateStoredCertPair(cert.CertFile, cert.KeyFile, cert.CertDomain); localErr == nil { + log.WithFields(log.Fields{ + "tag": c.tag, + "err": err, + }).Warn("fetch online cert failed, keep local cert") + return nil + } + return fmt.Errorf("fetch online cert error: %s", err) + } + if !changed { + if err := validateStoredCertPair(cert.CertFile, cert.KeyFile, cert.CertDomain); err != nil { + return fmt.Errorf("online cert not modified and local cert invalid: %s", err) + } + return nil + } + currentCert, currentKey, err := loadCertPair(cert.CertFile, cert.KeyFile) + if err == nil && bytes.Equal(currentCert, certPEM) && bytes.Equal(currentKey, keyPEM) { + return nil + } + if err := writeCertPair(cert.CertFile, cert.KeyFile, certPEM, keyPEM); err != nil { + return fmt.Errorf("write online cert error: %s", err) + } default: return fmt.Errorf("unsupported certmode: %s", cert.CertMode) } return nil } +func (c *Controller) fetchOnlineCertPair() ([]byte, []byte, bool, error) { + pair, changed, err := c.apiClient.GetNodeCertPair() + if err != nil { + return nil, nil, false, fmt.Errorf("get node cert pair error: %w", err) + } + if !changed { + return nil, nil, false, nil + } + certPEM := []byte(pair.Cert) + keyPEM := []byte(pair.Key) + if err := validateOnlineCertPair(certPEM, keyPEM, c.info.Common.CertInfo.CertDomain); err != nil { + return nil, nil, false, err + } + return certPEM, keyPEM, true, nil +} + +func loadCertPair(certPath, keyPath string) ([]byte, []byte, error) { + if !file.IsExist(certPath) || !file.IsExist(keyPath) { + return nil, nil, fmt.Errorf("cert file path or key file path not exist") + } + certPEM, err := os.ReadFile(certPath) + if err != nil { + return nil, nil, fmt.Errorf("read cert file error: %w", err) + } + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + return nil, nil, fmt.Errorf("read key file error: %w", err) + } + return certPEM, keyPEM, nil +} + +func validateStoredCertPair(certPath, keyPath, domain string) error { + certPEM, keyPEM, err := loadCertPair(certPath, keyPath) + if err != nil { + return err + } + return validateOnlineCertPair(certPEM, keyPEM, domain) +} + +func validateOnlineCertPair(certPEM, keyPEM []byte, domain string) error { + pair, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return fmt.Errorf("parse cert pair error: %w", err) + } + if len(pair.Certificate) == 0 { + return fmt.Errorf("certificate chain is empty") + } + leaf, err := x509.ParseCertificate(pair.Certificate[0]) + if err != nil { + return fmt.Errorf("parse leaf certificate error: %w", err) + } + now := time.Now() + if now.Before(leaf.NotBefore) { + return fmt.Errorf("certificate is not valid before %s", leaf.NotBefore) + } + if !now.Before(leaf.NotAfter) { + return fmt.Errorf("certificate expired at %s", leaf.NotAfter) + } + if domain != "" { + if err := leaf.VerifyHostname(domain); err != nil { + return fmt.Errorf("certificate does not match domain %s: %w", domain, err) + } + } + return nil +} + +func writeCertPair(certPath, keyPath string, certPEM, keyPEM []byte) error { + if err := checkPath(certPath); err != nil { + return fmt.Errorf("check cert path error: %w", err) + } + if err := os.WriteFile(certPath, certPEM, 0644); err != nil { + return fmt.Errorf("write cert file error: %w", err) + } + if err := checkPath(keyPath); err != nil { + return fmt.Errorf("check key path error: %w", err) + } + if err := os.WriteFile(keyPath, keyPEM, 0644); err != nil { + return fmt.Errorf("write key file error: %w", err) + } + return nil +} + func generateSelfSslCertificate(domain, certPath, keyPath string) error { key, _ := rsa.GenerateKey(rand.Reader, 2048) tmpl := &x509.Certificate{ diff --git a/node/task.go b/node/task.go index 3ac30d3c..484e1dee 100644 --- a/node/task.go +++ b/node/task.go @@ -1,8 +1,6 @@ package node import ( - "time" - log "github.com/sirupsen/logrus" panel "github.com/wyx2685/v2node/api/v2board" "github.com/wyx2685/v2node/common/task" @@ -29,21 +27,7 @@ func (c *Controller) startTasks(node *panel.NodeInfo) { _ = c.nodeInfoMonitorPeriodic.Start(false) log.WithField("tag", c.tag).Info("Start report node status") _ = c.userReportPeriodic.Start(false) - if node.Security == panel.Tls { - switch c.info.Common.CertInfo.CertMode { - case "none", "", "file", "self": - default: - c.renewCertPeriodic = &task.Task{ - Name: "renewCertTask", - Interval: time.Hour * 24, - Execute: c.renewCertTask, - Reload: c.reloadTask, - } - log.WithField("tag", c.tag).Info("Start renew cert") - // delay to start renewCert - _ = c.renewCertPeriodic.Start(true) - } - } + c.startCertTask(node) } func (c *Controller) reloadTask() {