Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions api/v2board/cert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package panel

import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
)

type NodeCertPair struct {
Cert string `json:"cert"`
Key string `json:"key"`
}

func (c *Client) GetNodeCertPair() (pair *NodeCertPair, changed bool, err error) {
const path = "/api/v1/server/UniProxy/nodecert"
r, err := c.client.
R().
SetHeader("If-None-Match", c.nodeCertEtag).
ForceContentType("application/json").
Get(path)

if err != nil {
return nil, false, err
}
if r == nil {
return nil, false, fmt.Errorf("received nil response")
}
if r.StatusCode() == 304 {
return nil, false, nil
}
hash := sha256.Sum256(r.Body())
newBodyHash := hex.EncodeToString(hash[:])
if c.nodeCertBodyHash == newBodyHash {
return nil, false, nil
}
c.nodeCertBodyHash = newBodyHash
c.nodeCertEtag = r.Header().Get("ETag")

defer func() {
if r.RawBody() != nil {
r.RawBody().Close()
}
}()

pair = &NodeCertPair{}
err = json.Unmarshal(r.Body(), pair)
if err != nil {
return nil, false, fmt.Errorf("decode node cert error: %s", err)
}
if pair.Cert == "" || pair.Key == "" {
return nil, false, fmt.Errorf("received empty cert or key")
}
return pair, true, nil
}
4 changes: 3 additions & 1 deletion api/v2board/panel.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type Client struct {
NodeId int
nodeEtag string
userEtag string
nodeCertEtag string
nodeCertBodyHash string
responseBodyHash string
UserList *UserListBody
AliveMap *AliveMap
Expand All @@ -29,7 +31,7 @@ type Client struct {
func New(c *conf.NodeConfig) (*Client, error) {
client := resty.New()
client.SetRetryCount(3)
client.SetHeader("User-Agent", fmt.Sprintf("v2node go-resty/%s (https://github.com/go-resty/resty)", resty.Version))
client.SetHeader("User-Agent", fmt.Sprintf("v2node go-resty/%s (https://github.com/go-resty/resty)", resty.Version))
if c.Timeout > 0 {
client.SetTimeout(time.Duration(c.Timeout) * time.Second)
} else {
Expand Down
178 changes: 178 additions & 0 deletions node/cert.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package node

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
Expand All @@ -12,7 +14,9 @@ import (
"time"

log "github.com/sirupsen/logrus"
panel "github.com/wyx2685/v2node/api/v2board"
"github.com/wyx2685/v2node/common/file"
"github.com/wyx2685/v2node/common/task"
)

func (c *Controller) renewCertTask() error {
Expand All @@ -29,6 +33,70 @@ func (c *Controller) renewCertTask() error {
return nil
}

func (c *Controller) syncOnlineCertTask() error {
cert := c.info.Common.CertInfo
certPEM, keyPEM, changed, err := c.fetchOnlineCertPair()
if err != nil {
log.WithFields(log.Fields{
"tag": c.tag,
"err": err,
}).Info("sync online cert error")
return nil
}
if !changed {
return nil
}
currentCert, currentKey, err := loadCertPair(cert.CertFile, cert.KeyFile)
if err == nil && bytes.Equal(currentCert, certPEM) && bytes.Equal(currentKey, keyPEM) {
return nil
}
if err := writeCertPair(cert.CertFile, cert.KeyFile, certPEM, keyPEM); err != nil {
log.WithFields(log.Fields{
"tag": c.tag,
"err": err,
}).Info("write online cert error")
return nil
}
log.WithField("tag", c.tag).Info("Online cert updated")
return nil
}

func (c *Controller) startCertTask(node *panel.NodeInfo) {
if node.Security != panel.Tls {
return
}
switch c.info.Common.CertInfo.CertMode {
case "none", "", "file", "self":
case "dns", "http":
c.renewCertPeriodic = &task.Task{
Name: "renewCertTask",
Interval: time.Hour * 24,
Execute: c.renewCertTask,
Reload: c.reloadTask,
}
log.WithField("tag", c.tag).Info("Start renew cert")
_ = c.renewCertPeriodic.Start(true)
case "online":
interval := node.PullInterval * 60
if interval <= 0 {
interval = time.Hour
}
c.renewCertPeriodic = &task.Task{
Name: "syncOnlineCertTask",
Interval: interval,
Execute: c.syncOnlineCertTask,
Reload: c.reloadTask,
}
log.WithField("tag", c.tag).Info("Start sync online cert")
_ = c.renewCertPeriodic.Start(false)
default:
log.WithFields(log.Fields{
"tag": c.tag,
"certmode": c.info.Common.CertInfo.CertMode,
}).Warn("Skip unknown cert task mode")
}
}

func (c *Controller) requestCert() error {
cert := c.info.Common.CertInfo
switch cert.CertMode {
Expand Down Expand Up @@ -66,12 +134,122 @@ func (c *Controller) requestCert() error {
if err != nil {
return fmt.Errorf("generate self cert error: %s", err)
}
case "online":
if cert.CertFile == "" || cert.KeyFile == "" {
return fmt.Errorf("cert file path or key file path not exist")
}
certPEM, keyPEM, changed, err := c.fetchOnlineCertPair()
if err != nil {
if localErr := validateStoredCertPair(cert.CertFile, cert.KeyFile, cert.CertDomain); localErr == nil {
log.WithFields(log.Fields{
"tag": c.tag,
"err": err,
}).Warn("fetch online cert failed, keep local cert")
return nil
}
return fmt.Errorf("fetch online cert error: %s", err)
}
if !changed {
if err := validateStoredCertPair(cert.CertFile, cert.KeyFile, cert.CertDomain); err != nil {
return fmt.Errorf("online cert not modified and local cert invalid: %s", err)
}
return nil
}
currentCert, currentKey, err := loadCertPair(cert.CertFile, cert.KeyFile)
if err == nil && bytes.Equal(currentCert, certPEM) && bytes.Equal(currentKey, keyPEM) {
return nil
}
if err := writeCertPair(cert.CertFile, cert.KeyFile, certPEM, keyPEM); err != nil {
return fmt.Errorf("write online cert error: %s", err)
}
default:
return fmt.Errorf("unsupported certmode: %s", cert.CertMode)
}
return nil
}

func (c *Controller) fetchOnlineCertPair() ([]byte, []byte, bool, error) {
pair, changed, err := c.apiClient.GetNodeCertPair()
if err != nil {
return nil, nil, false, fmt.Errorf("get node cert pair error: %w", err)
}
if !changed {
return nil, nil, false, nil
}
certPEM := []byte(pair.Cert)
keyPEM := []byte(pair.Key)
if err := validateOnlineCertPair(certPEM, keyPEM, c.info.Common.CertInfo.CertDomain); err != nil {
return nil, nil, false, err
}
return certPEM, keyPEM, true, nil
}

func loadCertPair(certPath, keyPath string) ([]byte, []byte, error) {
if !file.IsExist(certPath) || !file.IsExist(keyPath) {
return nil, nil, fmt.Errorf("cert file path or key file path not exist")
}
certPEM, err := os.ReadFile(certPath)
if err != nil {
return nil, nil, fmt.Errorf("read cert file error: %w", err)
}
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
return nil, nil, fmt.Errorf("read key file error: %w", err)
}
return certPEM, keyPEM, nil
}

func validateStoredCertPair(certPath, keyPath, domain string) error {
certPEM, keyPEM, err := loadCertPair(certPath, keyPath)
if err != nil {
return err
}
return validateOnlineCertPair(certPEM, keyPEM, domain)
}

func validateOnlineCertPair(certPEM, keyPEM []byte, domain string) error {
pair, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return fmt.Errorf("parse cert pair error: %w", err)
}
if len(pair.Certificate) == 0 {
return fmt.Errorf("certificate chain is empty")
}
leaf, err := x509.ParseCertificate(pair.Certificate[0])
if err != nil {
return fmt.Errorf("parse leaf certificate error: %w", err)
}
now := time.Now()
if now.Before(leaf.NotBefore) {
return fmt.Errorf("certificate is not valid before %s", leaf.NotBefore)
}
if !now.Before(leaf.NotAfter) {
return fmt.Errorf("certificate expired at %s", leaf.NotAfter)
}
if domain != "" {
if err := leaf.VerifyHostname(domain); err != nil {
return fmt.Errorf("certificate does not match domain %s: %w", domain, err)
}
}
return nil
}

func writeCertPair(certPath, keyPath string, certPEM, keyPEM []byte) error {
if err := checkPath(certPath); err != nil {
return fmt.Errorf("check cert path error: %w", err)
}
if err := os.WriteFile(certPath, certPEM, 0644); err != nil {
return fmt.Errorf("write cert file error: %w", err)
}
if err := checkPath(keyPath); err != nil {
return fmt.Errorf("check key path error: %w", err)
}
if err := os.WriteFile(keyPath, keyPEM, 0644); err != nil {
return fmt.Errorf("write key file error: %w", err)
}
return nil
}

func generateSelfSslCertificate(domain, certPath, keyPath string) error {
key, _ := rsa.GenerateKey(rand.Reader, 2048)
tmpl := &x509.Certificate{
Expand Down
18 changes: 1 addition & 17 deletions node/task.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package node

import (
"time"

log "github.com/sirupsen/logrus"
panel "github.com/wyx2685/v2node/api/v2board"
"github.com/wyx2685/v2node/common/task"
Expand All @@ -29,21 +27,7 @@ func (c *Controller) startTasks(node *panel.NodeInfo) {
_ = c.nodeInfoMonitorPeriodic.Start(false)
log.WithField("tag", c.tag).Info("Start report node status")
_ = c.userReportPeriodic.Start(false)
if node.Security == panel.Tls {
switch c.info.Common.CertInfo.CertMode {
case "none", "", "file", "self":
default:
c.renewCertPeriodic = &task.Task{
Name: "renewCertTask",
Interval: time.Hour * 24,
Execute: c.renewCertTask,
Reload: c.reloadTask,
}
log.WithField("tag", c.tag).Info("Start renew cert")
// delay to start renewCert
_ = c.renewCertPeriodic.Start(true)
}
}
c.startCertTask(node)
}

func (c *Controller) reloadTask() {
Expand Down
Loading