diff --git a/znet/server.go b/znet/server.go index a86bcc4e..b6c88b95 100644 --- a/znet/server.go +++ b/znet/server.go @@ -1,6 +1,7 @@ package znet import ( + "context" "crypto/rand" "crypto/tls" "errors" @@ -9,6 +10,7 @@ import ( "net/http" "os" "os/signal" + "sync" "sync/atomic" "syscall" "time" @@ -26,6 +28,11 @@ import ( "github.com/aceld/zinx/zpack" ) +// wsShutdownTimeout is the maximum time to wait for active WebSocket connections +// to finish when stopping the server gracefully. +// (wsShutdownTimeout 是优雅停服时等待活跃 WebSocket 连接完成的最长时间) +const wsShutdownTimeout = 5 * time.Second + // Server interface implementation, defines a Server service class // (接口实现,定义一个Server服务类) type Server struct { @@ -71,6 +78,14 @@ type Server struct { // (异步捕获连接关闭状态) exitChan chan struct{} + // stopOnce ensures Stop() is idempotent and exitChan is closed only once + // (stopOnce 保证 Stop() 幂等,exitChan 只被关闭一次) + stopOnce sync.Once + + // WebSocket HTTP server instance, used for graceful shutdown + // (WebSocket HTTP 服务实例,用于优雅停服) + wsServer *http.Server + // Decoder for dealing with message fragmentation and reassembly // (断粘包解码器) decoder ziface.IDecoder @@ -315,7 +330,11 @@ func (s *Server) ListenTcpConn() { func (s *Server) ListenWebsocketConn() { zlog.Ins().InfoF("[START] WEBSOCKET Server name: %s,listener at IP: %s, Port %d, Path %s is starting", s.Name, s.IP, s.WsPort, s.WsPath) - http.HandleFunc(s.WsPath, func(w http.ResponseWriter, r *http.Request) { + + // Use a local ServeMux to avoid polluting the global http.DefaultServeMux + // (使用局部 ServeMux 避免污染全局 http.DefaultServeMux) + mux := http.NewServeMux() + mux.HandleFunc(s.WsPath, func(w http.ResponseWriter, r *http.Request) { // 1. Check if the server has reached the maximum allowed number of connections // (设置服务器最大连接控制,如果超过最大连接,则等待) if s.ConnMgr.Len() >= zconf.GlobalObject.MaxConn { @@ -357,18 +376,44 @@ func (s *Server) ListenWebsocketConn() { }) - if zconf.GlobalObject.CertFile != "" && zconf.GlobalObject.PrivateKeyFile != "" { - err := http.ListenAndServeTLS(fmt.Sprintf("%s:%d", s.IP, s.WsPort), zconf.GlobalObject.CertFile, zconf.GlobalObject.PrivateKeyFile, nil) - if err != nil { - panic(err) + // Create an explicit http.Server so we can shut it down gracefully later + // (显式创建 http.Server,以便后续能够优雅停服) + srv := &http.Server{ + Addr: fmt.Sprintf("%s:%d", s.IP, s.WsPort), + Handler: mux, + } + s.wsServer = srv + + // Start the HTTP server in a background goroutine + // (在后台 goroutine 中启动 HTTP Server) + go func() { + var err error + if zconf.GlobalObject.CertFile != "" && zconf.GlobalObject.PrivateKeyFile != "" { + err = srv.ListenAndServeTLS(zconf.GlobalObject.CertFile, zconf.GlobalObject.PrivateKeyFile) + } else { + err = srv.ListenAndServe() } - } else { - err := http.ListenAndServe(fmt.Sprintf("%s:%d", s.IP, s.WsPort), nil) - if err != nil { - panic(err) + // http.ErrServerClosed is returned after a successful Shutdown/Close call — + // this is not an error. (http.ErrServerClosed 是正常关闭后的返回值,不视为错误) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + zlog.Ins().ErrorF("websocket server ListenAndServe err: %v", err) } - } + }() + // Block until Stop() signals exit via exitChan + // (阻塞等待 Stop() 通过 exitChan 发出退出信号) + <-s.exitChan + + // Gracefully shut down the WebSocket HTTP server with a timeout + // (带超时的优雅关闭 WebSocket HTTP Server) + ctx, cancel := context.WithTimeout(context.Background(), wsShutdownTimeout) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + zlog.Ins().ErrorF("websocket server shutdown err: %v", err) + // Fall back to forceful close if graceful shutdown fails + // (优雅关闭失败时兜底强制关闭) + _ = srv.Close() + } } func (s *Server) ListenKcpConn() { @@ -463,7 +508,15 @@ func (s *Server) Stop() { // Clear other connection information or other information that needs to be cleaned up // (将其他需要清理的连接信息或者其他信息 也要一并停止或者清理) s.ConnMgr.ClearConn() - close(s.exitChan) + + // Use sync.Once to ensure exitChan is closed only once, making Stop() safe to call + // multiple times without panicking. Closing exitChan signals all listeners + // (including ListenWebsocketConn) to shut down gracefully. + // (使用 sync.Once 确保 exitChan 只被关闭一次,使 Stop() 可重复调用不 panic。 + // 关闭 exitChan 会通知所有监听协程(包括 ListenWebsocketConn)执行优雅停服。) + s.stopOnce.Do(func() { + close(s.exitChan) + }) } // Serve runs the server (运行服务) diff --git a/znet/server_test.go b/znet/server_test.go index 65c1e069..273a9a79 100644 --- a/znet/server_test.go +++ b/znet/server_test.go @@ -8,10 +8,21 @@ import ( "testing" "time" + "github.com/aceld/zinx/zconf" "github.com/aceld/zinx/ziface" "github.com/aceld/zinx/zpack" ) +// retryInterval is the polling interval used by test helper functions that wait +// for a port to become available or unavailable. +// (retryInterval 是测试辅助函数轮询端口状态时的间隔时间) +const retryInterval = 20 * time.Millisecond + +// dialTimeout is the per-attempt TCP dial timeout used when checking if a port +// is accepting connections. +// (dialTimeout 是检测端口是否在监听时每次 TCP 拨号的超时时间) +const dialTimeout = 50 * time.Millisecond + // run in terminal: // go test -v ./znet -run=TestServer @@ -212,3 +223,111 @@ func TestCloseConnectionBeforeSendMsg(t *testing.T) { wg.Wait() s.Stop() } + +// waitForPort retries binding to addr until it succeeds (port released) or the +// deadline is exceeded. It returns nil on success and an error otherwise. +// (重试绑定端口直到成功或超时,成功返回 nil) +func waitForPort(addr string, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + ln, err := net.Listen("tcp", addr) + if err == nil { + _ = ln.Close() + return nil + } + time.Sleep(retryInterval) + } + return fmt.Errorf("port %s not released within %v", addr, timeout) +} + +// waitForPortListening retries connecting to addr until a connection succeeds +// (server is ready) or the deadline is exceeded. +// (重试连接直到成功或超时,用于等待服务端就绪) +func waitForPortListening(addr string, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + conn, err := net.DialTimeout("tcp", addr, dialTimeout) + if err == nil { + _ = conn.Close() + return nil + } + time.Sleep(retryInterval) + } + return fmt.Errorf("port %s not listening within %v", addr, timeout) +} + +// TestWebsocketServerGracefulStop verifies that calling Stop() on a WebSocket-only +// server does not block, and that the HTTP listener port is released after Stop returns. +// (验证 WebSocket-only 模式下 Stop() 不阻塞且端口被释放) +func TestWebsocketServerGracefulStop(t *testing.T) { + // Use a dedicated port to avoid conflicts with TCP tests (使用独立端口避免冲突) + const wsPort = 19990 + + config := &zconf.Config{ + Host: "127.0.0.1", + WsPort: wsPort, + WsPath: "/ws", + Mode: zconf.ServerModeWebsocket, + } + s := NewUserConfServer(config) + s.Start() + + // Wait until the server is actually listening (等待服务端真正开始监听) + addr := fmt.Sprintf("127.0.0.1:%d", wsPort) + if err := waitForPortListening(addr, 3*time.Second); err != nil { + t.Fatalf("server did not start listening: %v", err) + } + + // Stop() must return promptly — if it blocks, the test will time out. + // (Stop() 必须及时返回,否则测试会超时) + done := make(chan struct{}) + go func() { + s.Stop() + close(done) + }() + + select { + case <-done: + // Stop returned in time — good. + case <-time.After(3 * time.Second): + t.Fatal("Stop() blocked for more than 3s in websocket-only mode") + } + + // Verify the port is released by retrying until the bind succeeds or timeout. + // (重试绑定端口,验证端口已被释放) + if err := waitForPort(addr, 3*time.Second); err != nil { + t.Fatalf("port %d not released after Stop(): %v", wsPort, err) + } +} + +// TestWebsocketServerStopIdempotent verifies that calling Stop() multiple times +// does not panic (thanks to sync.Once protecting the exitChan close). +// (验证多次调用 Stop() 不会 panic) +func TestWebsocketServerStopIdempotent(t *testing.T) { + const wsPort = 19991 + + config := &zconf.Config{ + Host: "127.0.0.1", + WsPort: wsPort, + WsPath: "/ws", + Mode: zconf.ServerModeWebsocket, + } + s := NewUserConfServer(config) + s.Start() + + // Wait until the server is actually listening (等待服务端真正开始监听) + addr := fmt.Sprintf("127.0.0.1:%d", wsPort) + if err := waitForPortListening(addr, 3*time.Second); err != nil { + t.Fatalf("server did not start listening: %v", err) + } + + // Calling Stop() twice must not panic. + // (两次调用 Stop() 不应 panic) + defer func() { + if r := recover(); r != nil { + t.Fatalf("Stop() panicked on second call: %v", r) + } + }() + s.Stop() + s.Stop() +}