From eac924dce10a302fc2cd651852846b2b613c3896 Mon Sep 17 00:00:00 2001 From: lcy Date: Mon, 27 Apr 2026 15:51:22 +0800 Subject: [PATCH] fix: prevent reload storms on panel API failures --- api/panel/node.go | 3 +- api/panel/panel.go | 38 ++++--- api/panel/server.go | 148 ++++++++++++++++++++++++---- api/panel/server_test.go | 98 ++++++++++++++++++ api/panel/user.go | 20 ++-- cmd/server.go | 34 +++++-- common/serverstatus/serverstatus.go | 2 +- common/task/task.go | 32 +++--- core/xray.go | 15 ++- node/task.go | 5 - 10 files changed, 316 insertions(+), 79 deletions(-) create mode 100644 api/panel/server_test.go diff --git a/api/panel/node.go b/api/panel/node.go index 4c4b9b8..09115f9 100644 --- a/api/panel/node.go +++ b/api/panel/node.go @@ -2,7 +2,6 @@ package panel import ( "fmt" - "path" "time" ) @@ -38,7 +37,7 @@ func (c *ClientV1) ReportNodeStatus(nodeStatus *NodeStatus) (err error) { UpdatedAt: time.Now().UnixMilli(), } if _, err = c.Client.R().SetBody(status).ForceContentType("application/json").Post(p); err != nil { - return fmt.Errorf("访问 %s 失败: %v", path.Join(c.APIHost+p), err.Error()) + return fmt.Errorf("访问 %s 失败: %s", endpointURL(c.APIHost, p), sanitizeError(err, c.SecretKey)) } return nil } diff --git a/api/panel/panel.go b/api/panel/panel.go index 34ec0aa..15bb248 100644 --- a/api/panel/panel.go +++ b/api/panel/panel.go @@ -1,18 +1,18 @@ package panel import ( - "errors" "fmt" + "regexp" "strconv" "strings" "time" - "github.com/sirupsen/logrus" - "github.com/go-resty/resty/v2" "github.com/perfect-panel/ppanel-node/conf" ) +var secretKeyPattern = regexp.MustCompile(`secret_key=[^&\s"]+`) + type ClientV1 struct { Client *resty.Client APIHost string @@ -30,7 +30,25 @@ type ClientV2 struct { SecretKey string ServerId int ServerConfigEtag string - responseBodyHash string + serverConfigHash string +} + +func endpointURL(base, p string) string { + return strings.TrimRight(base, "/") + p +} + +func redactSecret(s, secret string) string { + if secret != "" { + s = strings.ReplaceAll(s, secret, "") + } + return secretKeyPattern.ReplaceAllString(s, "secret_key=") +} + +func sanitizeError(err error, secret string) string { + if err == nil { + return "" + } + return redactSecret(err.Error(), secret) } func NewClientV1(c *conf.NodeApiConfig) (*ClientV1, error) { @@ -41,12 +59,6 @@ func NewClientV1(c *conf.NodeApiConfig) (*ClientV1, error) { } else { client.SetTimeout(30 * time.Second) } - client.OnError(func(req *resty.Request, err error) { - var v *resty.ResponseError - if errors.As(err, &v) { - logrus.Error(v.Err) - } - }) client.SetBaseURL(c.APIHost) // Check node type c.NodeType = strings.ToLower(c.NodeType) @@ -88,12 +100,6 @@ func NewClientV2(c *conf.ServerApiConfig) *ClientV2 { } else { client.SetTimeout(30 * time.Second) } - client.OnError(func(req *resty.Request, err error) { - var v *resty.ResponseError - if errors.As(err, &v) { - logrus.Error(v.Err) - } - }) client.SetBaseURL(c.ApiHost) client.SetQueryParams(map[string]string{ "secret_key": c.SecretKey, diff --git a/api/panel/server.go b/api/panel/server.go index b15337b..874e5b5 100644 --- a/api/panel/server.go +++ b/api/panel/server.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "sort" ) type ServerConfigResponse struct { @@ -26,6 +27,18 @@ type Data struct { Total int `json:"total"` } +type semanticServerConfigData struct { + TrafficReportThreshold int `json:"traffic_report_threshold"` + PushInterval int `json:"push_interval"` + PullInterval int `json:"pull_interval"` + IPStrategy string `json:"ip_strategy"` + DNS []DNSItem `json:"dns"` + Block []string `json:"block"` + Outbound []Outbound `json:"outbound"` + Protocols []Protocol `json:"protocols"` + Total int `json:"total"` +} + type DNSItem struct { Proto string `json:"proto"` Address string `json:"address"` @@ -102,42 +115,135 @@ func GetServerConfig(ctx context.Context, c *ClientV2) (*ServerConfigResponse, e // 优先检查错误,避免处理无效响应 if err != nil { - return nil, fmt.Errorf("访问 %s 失败: %v", client.BaseURL+path, err.Error()) + return nil, fmt.Errorf("访问 %s 失败: %s", endpointURL(client.BaseURL, path), sanitizeError(err, c.SecretKey)) + } + + if r == nil { + return nil, fmt.Errorf("服务端返回为空") } - + // 检查 HTTP 状态码 if r.StatusCode() == 304 { return nil, nil } if r.StatusCode() >= 400 { body := r.Body() - return nil, fmt.Errorf("访问 %s 失败: %s", client.BaseURL+path, string(body)) - } - - // 只有在成功响应时才检查 hash - hash := sha256.Sum256(r.Body()) - newBodyHash := hex.EncodeToString(hash[:]) - if c.responseBodyHash == newBodyHash { - return nil, nil + return nil, fmt.Errorf("访问 %s 失败: %s", endpointURL(client.BaseURL, path), redactSecret(string(body), c.SecretKey)) } - c.responseBodyHash = newBodyHash c.ServerConfigEtag = r.Header().Get("ETag") - if r != nil { - defer func() { - if r.RawBody() != nil { - r.RawBody().Close() - } - }() - } else { - return nil, fmt.Errorf("服务端返回为空") - } resp := &ServerConfigResponse{} err = json.Unmarshal(r.Body(), resp) if err != nil { return nil, fmt.Errorf("解码响应体失败: %s", err) } - if resp.Data.Protocols == nil { + if resp.Data == nil || resp.Data.Protocols == nil { return nil, fmt.Errorf("协议配置为空") } + newConfigHash, err := semanticServerConfigHash(resp) + if err != nil { + return nil, err + } + if c.serverConfigHash == newConfigHash { + return nil, nil + } + c.serverConfigHash = newConfigHash return resp, nil } + +func semanticServerConfigHash(resp *ServerConfigResponse) (string, error) { + normalized := normalizeServerConfigData(resp.Data) + body, err := json.Marshal(normalized) + if err != nil { + return "", fmt.Errorf("编码服务端配置指纹失败: %s", err) + } + hash := sha256.Sum256(body) + return hex.EncodeToString(hash[:]), nil +} + +func normalizeServerConfigData(data *Data) semanticServerConfigData { + if data == nil { + return semanticServerConfigData{ + DNS: []DNSItem{}, + Block: []string{}, + Outbound: []Outbound{}, + Protocols: []Protocol{}, + } + } + + dnsItems := cloneDNSItems(data.DNS) + blockItems := cloneStringSlice(data.Block) + outboundItems := cloneOutboundItems(data.Outbound) + protocolItems := cloneProtocolItems(data.Protocols) + + sort.Strings(blockItems) + sort.SliceStable(protocolItems, func(i, j int) bool { + if protocolItems[i].Type != protocolItems[j].Type { + return protocolItems[i].Type < protocolItems[j].Type + } + if protocolItems[i].Port != protocolItems[j].Port { + return protocolItems[i].Port < protocolItems[j].Port + } + if protocolItems[i].Transport != protocolItems[j].Transport { + return protocolItems[i].Transport < protocolItems[j].Transport + } + return protocolItems[i].Security < protocolItems[j].Security + }) + + return semanticServerConfigData{ + TrafficReportThreshold: data.TrafficReportThreshold, + PushInterval: data.PushInterval, + PullInterval: data.PullInterval, + IPStrategy: data.IPStrategy, + DNS: dnsItems, + Block: blockItems, + Outbound: outboundItems, + Protocols: protocolItems, + Total: data.Total, + } +} + +func cloneStringSlice(items *[]string) []string { + if items == nil { + return []string{} + } + clone := make([]string, len(*items)) + copy(clone, *items) + return clone +} + +func cloneDNSItems(items *[]DNSItem) []DNSItem { + if items == nil { + return []DNSItem{} + } + clone := make([]DNSItem, len(*items)) + for i, item := range *items { + clone[i] = item + clone[i].Domains = make([]string, len(item.Domains)) + copy(clone[i].Domains, item.Domains) + sort.Strings(clone[i].Domains) + } + return clone +} + +func cloneOutboundItems(items *[]Outbound) []Outbound { + if items == nil { + return []Outbound{} + } + clone := make([]Outbound, len(*items)) + for i, item := range *items { + clone[i] = item + clone[i].Rules = make([]string, len(item.Rules)) + copy(clone[i].Rules, item.Rules) + sort.Strings(clone[i].Rules) + } + return clone +} + +func cloneProtocolItems(items *[]Protocol) []Protocol { + if items == nil { + return []Protocol{} + } + clone := make([]Protocol, len(*items)) + copy(clone, *items) + return clone +} diff --git a/api/panel/server_test.go b/api/panel/server_test.go new file mode 100644 index 0000000..775c92e --- /dev/null +++ b/api/panel/server_test.go @@ -0,0 +1,98 @@ +package panel + +import "testing" + +func TestSemanticServerConfigHashNormalizesNonSemanticOrder(t *testing.T) { + first := &ServerConfigResponse{Data: &Data{ + PushInterval: 60, + PullInterval: 90, + DNS: &[]DNSItem{ + {Proto: "udp", Address: "1.1.1.1", Domains: []string{"suffix:example.com", "keyword:google"}}, + }, + Block: &[]string{"suffix:b.example", "suffix:a.example"}, + Outbound: &[]Outbound{ + {Name: "proxy", Protocol: "socks", Address: "127.0.0.1", Port: 1080, Rules: []string{"suffix:b.example", "suffix:a.example"}}, + }, + Protocols: &[]Protocol{ + {Type: "hysteria", Port: 443, Security: "tls"}, + {Type: "vless", Port: 8443, Security: "reality", Transport: "tcp"}, + }, + Total: 2, + }} + second := &ServerConfigResponse{Data: &Data{ + PushInterval: 60, + PullInterval: 90, + DNS: &[]DNSItem{ + {Proto: "udp", Address: "1.1.1.1", Domains: []string{"keyword:google", "suffix:example.com"}}, + }, + Block: &[]string{"suffix:a.example", "suffix:b.example"}, + Outbound: &[]Outbound{ + {Name: "proxy", Protocol: "socks", Address: "127.0.0.1", Port: 1080, Rules: []string{"suffix:a.example", "suffix:b.example"}}, + }, + Protocols: &[]Protocol{ + {Type: "vless", Port: 8443, Security: "reality", Transport: "tcp"}, + {Type: "hysteria", Port: 443, Security: "tls"}, + }, + Total: 2, + }} + + firstHash, err := semanticServerConfigHash(first) + if err != nil { + t.Fatal(err) + } + secondHash, err := semanticServerConfigHash(second) + if err != nil { + t.Fatal(err) + } + if firstHash != secondHash { + t.Fatalf("expected semantically equal configs to match: %s != %s", firstHash, secondHash) + } +} + +func TestSemanticServerConfigHashKeepsOutboundOrder(t *testing.T) { + first := &ServerConfigResponse{Data: &Data{ + Outbound: &[]Outbound{ + {Name: "first", Protocol: "socks", Address: "127.0.0.1", Port: 1080}, + {Name: "second", Protocol: "socks", Address: "127.0.0.2", Port: 1081}, + }, + Protocols: &[]Protocol{}, + }} + second := &ServerConfigResponse{Data: &Data{ + Outbound: &[]Outbound{ + {Name: "second", Protocol: "socks", Address: "127.0.0.2", Port: 1081}, + {Name: "first", Protocol: "socks", Address: "127.0.0.1", Port: 1080}, + }, + Protocols: &[]Protocol{}, + }} + + firstHash, err := semanticServerConfigHash(first) + if err != nil { + t.Fatal(err) + } + secondHash, err := semanticServerConfigHash(second) + if err != nil { + t.Fatal(err) + } + if firstHash == secondHash { + t.Fatal("expected outbound order changes to remain significant") + } +} + +func TestSemanticServerConfigHashTreatsNilSlicesAsEmpty(t *testing.T) { + firstHash, err := semanticServerConfigHash(&ServerConfigResponse{Data: &Data{Protocols: &[]Protocol{}}}) + if err != nil { + t.Fatal(err) + } + secondHash, err := semanticServerConfigHash(&ServerConfigResponse{Data: &Data{ + DNS: &[]DNSItem{}, + Block: &[]string{}, + Outbound: &[]Outbound{}, + Protocols: &[]Protocol{}, + }}) + if err != nil { + t.Fatal(err) + } + if firstHash != secondHash { + t.Fatalf("expected nil and empty slices to match: %s != %s", firstHash, secondHash) + } +} diff --git a/api/panel/user.go b/api/panel/user.go index fb099d8..4ea8f11 100644 --- a/api/panel/user.go +++ b/api/panel/user.go @@ -3,7 +3,7 @@ package panel import ( "context" "fmt" - "path" + "io" "encoding/json/jsontext" "encoding/json/v2" @@ -41,6 +41,9 @@ func (c *ClientV1) GetUserList(ctx context.Context) ([]UserInfo, error) { ForceContentType("application/json"). SetDoNotParseResponse(true). Get(p) + if err != nil { + return nil, fmt.Errorf("访问 %s 失败: %s", endpointURL(c.APIHost, p), sanitizeError(err, c.SecretKey)) + } if r == nil || r.RawResponse == nil { return nil, fmt.Errorf("服务端响应为空") } @@ -50,12 +53,9 @@ func (c *ClientV1) GetUserList(ctx context.Context) ([]UserInfo, error) { return nil, nil } - if err != nil { - return nil, fmt.Errorf("访问 %s 失败: %s", path.Join(c.APIHost+p), err) - } if r.StatusCode() >= 400 { - body := r.Body() - return nil, fmt.Errorf("访问 %s 失败: %s", path.Join(c.APIHost+p), string(body)) + body, _ := io.ReadAll(r.RawResponse.Body) + return nil, fmt.Errorf("访问 %s 失败: %s", endpointURL(c.APIHost, p), redactSecret(string(body), c.SecretKey)) } userlist := &UserListBody{} dec := jsontext.NewDecoder(r.RawResponse.Body) @@ -142,11 +142,11 @@ func (c *ClientV1) ReportUserTraffic(ctx context.Context, userTraffic *[]UserTra ForceContentType("application/json"). Post(p) if err != nil { - return fmt.Errorf("访问 %s 失败: %s", path.Join(c.APIHost+p), err) + return fmt.Errorf("访问 %s 失败: %s", endpointURL(c.APIHost, p), sanitizeError(err, c.SecretKey)) } if r.StatusCode() >= 400 { body := r.Body() - return fmt.Errorf("访问 %s 失败: %s", path.Join(c.APIHost+p), string(body)) + return fmt.Errorf("访问 %s 失败: %s", endpointURL(c.APIHost, p), redactSecret(string(body), c.SecretKey)) } return nil @@ -163,11 +163,11 @@ func (c *ClientV1) ReportNodeOnlineUsers(ctx context.Context, data *[]OnlineUser ForceContentType("application/json"). Post(p) if err != nil { - return fmt.Errorf("访问 %s 失败: %s", path.Join(c.APIHost+p), err) + return fmt.Errorf("访问 %s 失败: %s", endpointURL(c.APIHost, p), sanitizeError(err, c.SecretKey)) } if r.StatusCode() >= 400 { body := r.Body() - return fmt.Errorf("访问 %s 失败: %s", path.Join(c.APIHost+p), string(body)) + return fmt.Errorf("访问 %s 失败: %s", endpointURL(c.APIHost, p), redactSecret(string(body), c.SecretKey)) } return nil diff --git a/cmd/server.go b/cmd/server.go index 336b557..289679e 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -9,6 +9,7 @@ import ( "os/signal" "runtime" "syscall" + "time" "github.com/perfect-panel/ppanel-node/api/panel" "github.com/perfect-panel/ppanel-node/conf" @@ -24,6 +25,8 @@ var ( watch bool ) +const reloadDebounce = 60 * time.Second + var serverCommand = cobra.Command{ Use: "server", Short: "Run ppnode server", @@ -95,7 +98,11 @@ func serverHandle(_ *cobra.Command, _ []string) { log.WithField("err", err).Error("启动Xray核心失败") return } - defer xraycore.Close() + defer func() { + if xraycore != nil { + _ = xraycore.Close() + } + }() nodes, err := node.New(xraycore, c, serverconfig) if err != nil { log.WithField("err", err).Error("获取节点配置失败") @@ -125,14 +132,21 @@ func serverHandle(_ *cobra.Command, _ []string) { osSignals := make(chan os.Signal, 1) signal.Notify(osSignals, syscall.SIGINT, syscall.SIGTERM) + var lastReload time.Time for { select { case <-osSignals: nodes.Close() _ = xraycore.Close() + xraycore = nil return case <-reloadCh: + if !lastReload.IsZero() && time.Since(lastReload) < reloadDebounce { + log.Debug("重启信号被防抖跳过") + continue + } + lastReload = time.Now() log.Info("收到重启信号,正在重新加载配置...") if err := reload(config, &nodes, &xraycore); err != nil { log.WithField("err", err).Error("重启失败") @@ -149,11 +163,6 @@ func reload(config string, nodes **node.Node, xcore **core.XrayCore) error { oldReloadCh = (*xcore).ReloadCh } - (*nodes).Close() - if err := (*xcore).Close(); err != nil { - return err - } - newConf := conf.New() if err := newConf.LoadFromPath(config); err != nil { return err @@ -173,9 +182,22 @@ func reload(config string, nodes **node.Node, xcore **core.XrayCore) error { } newNodes, err := node.New(newCore, newConf, serverconfig) if err != nil { + _ = newCore.Close() return err } + + if *nodes != nil { + (*nodes).Close() + } + if *xcore != nil { + if err := (*xcore).Close(); err != nil { + _ = newCore.Close() + return err + } + } + if err := newNodes.Start(); err != nil { + _ = newCore.Close() return err } diff --git a/common/serverstatus/serverstatus.go b/common/serverstatus/serverstatus.go index aa8f36b..4737f4a 100644 --- a/common/serverstatus/serverstatus.go +++ b/common/serverstatus/serverstatus.go @@ -46,7 +46,7 @@ func GetSystemInfo() (Cpu float64, Mem float64, Disk float64, Uptime uint64, err } if errorString != "" { - err = fmt.Errorf(errorString) + err = fmt.Errorf("%s", errorString) } return Cpu, Mem, Disk, Uptime, err diff --git a/common/task/task.go b/common/task/task.go index 867f256..35dbafc 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -15,8 +15,9 @@ type Task struct { Execute func(context.Context) error Access sync.RWMutex Running bool - ReloadCh chan struct{} Stop chan struct{} + cancel context.CancelFunc + wg sync.WaitGroup } func (t *Task) Start(first bool) error { @@ -27,12 +28,16 @@ func (t *Task) Start(first bool) error { } t.Running = true t.Stop = make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + t.cancel = cancel + t.wg.Add(1) t.Access.Unlock() go func() { + defer t.wg.Done() timer := time.NewTimer(t.Interval) defer timer.Stop() if first { - if err := t.ExecuteWithTimeout(); err != nil { + if err := t.ExecuteWithTimeout(ctx); err != nil { return } } @@ -44,9 +49,11 @@ func (t *Task) Start(first bool) error { // continue case <-t.Stop: return + case <-ctx.Done(): + return } - if err := t.ExecuteWithTimeout(); err != nil { + if err := t.ExecuteWithTimeout(ctx); err != nil { log.Errorf("Task %s execution error: %v", t.Name, err) return } @@ -56,8 +63,8 @@ func (t *Task) Start(first bool) error { return nil } -func (t *Task) ExecuteWithTimeout() error { - ctx, cancel := context.WithTimeout(context.Background(), min(5*t.Interval, 5*time.Minute)) +func (t *Task) ExecuteWithTimeout(parent context.Context) error { + ctx, cancel := context.WithTimeout(parent, min(5*t.Interval, 5*time.Minute)) defer cancel() done := make(chan error, 1) @@ -67,15 +74,10 @@ func (t *Task) ExecuteWithTimeout() error { select { case <-ctx.Done(): - log.Errorf("Task %s execution timed out, reloading", t.Name) - if t.ReloadCh != nil { - select { - case t.ReloadCh <- struct{}{}: - default: - } - } else { - log.Panic("Reload failed") + if errors.Is(parent.Err(), context.Canceled) { + return nil } + log.Errorf("Task %s execution timed out", t.Name) return nil case err := <-done: if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { @@ -89,6 +91,9 @@ func (t *Task) safeStop() { t.Access.Lock() if t.Running { t.Running = false + if t.cancel != nil { + t.cancel() + } close(t.Stop) } t.Access.Unlock() @@ -96,5 +101,6 @@ func (t *Task) safeStop() { func (t *Task) Close() { t.safeStop() + t.wg.Wait() log.Warningf("Task %s stopped", t.Name) } diff --git a/core/xray.go b/core/xray.go index 955a3f8..b3969d2 100644 --- a/core/xray.go +++ b/core/xray.go @@ -76,14 +76,18 @@ func (v *XrayCore) Close() error { defer v.access.Unlock() if v.serverConfigMonitorPeriodic != nil { v.serverConfigMonitorPeriodic.Close() + v.serverConfigMonitorPeriodic = nil } v.Config = nil v.ihm = nil v.ohm = nil v.dispatcher = nil - err := v.Server.Close() - if err != nil { - return err + if v.Server != nil { + err := v.Server.Close() + v.Server = nil + if err != nil { + return err + } } return nil } @@ -145,9 +149,9 @@ func (c *XrayCore) startTasks(serverconfig *panel.ServerConfigResponse) { pullinverval = 60 } c.serverConfigMonitorPeriodic = &task.Task{ + Name: "serverConfigMonitor", Interval: time.Duration(pullinverval) * time.Second, Execute: c.ServerConfigMonitor, - ReloadCh: c.ReloadCh, } _ = c.serverConfigMonitorPeriodic.Start(false) } @@ -159,12 +163,13 @@ func (c *XrayCore) ServerConfigMonitor(ctx context.Context) (err error) { return nil } if newServerConfig != nil { - log.Error("检测到服务端配置变更,正在重启节点...") // Non-blocking signal to avoid goroutine stuck when channel is full or nil if c.ReloadCh != nil { select { case c.ReloadCh <- struct{}{}: + log.Info("检测到服务端配置变更,已提交节点重启信号") default: + log.Debug("检测到服务端配置变更,已有重启信号等待处理") } } } diff --git a/node/task.go b/node/task.go index f74d2f6..b6977e6 100644 --- a/node/task.go +++ b/node/task.go @@ -18,14 +18,12 @@ func (c *Controller) startTasks(node *panel.NodeInfo) { Name: "userListMonitor", Interval: time.Duration(node.PullInterval) * time.Second, Execute: c.userListMonitor, - ReloadCh: c.server.ReloadCh, } // report user traffic task c.userReportPeriodic = &task.Task{ Name: "reportUserTraffic", Interval: time.Duration(node.PushInterval) * time.Second, Execute: c.reportUserTrafficTask, - ReloadCh: c.server.ReloadCh, } _ = c.userListMonitorPeriodic.Start(false) log.WithField("节点", c.tag).Info("用户列表监控任务已启动") @@ -57,7 +55,6 @@ func (c *Controller) startTasks(node *panel.NodeInfo) { Name: "renewCert", Interval: time.Hour * 24, Execute: c.renewCertTask, - ReloadCh: c.server.ReloadCh, } log.WithField("节点", c.tag).Info("证书定期更新任务已启动") // delay to start renewCert @@ -75,7 +72,6 @@ func (c *Controller) reloadTask() { c.startTasks(c.info) } - func (c *Controller) userListMonitor(ctx context.Context) (err error) { // get user info newU, err := c.apiClient.GetUserList(ctx) @@ -236,4 +232,3 @@ func compareUserList(old, new []panel.UserInfo) (deleted, added []panel.UserInfo return deleted, added } -