Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 64 additions & 11 deletions znet/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package znet

import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
Expand All @@ -9,6 +10,7 @@ import (
"net/http"
"os"
"os/signal"
"sync"
"sync/atomic"
"syscall"
"time"
Expand All @@ -26,6 +28,11 @@ import (
"github.com/aceld/zinx/zpack"
)

// wsShutdownTimeout bounds the graceful shutdown of the WebSocket HTTP server.
// ListenWebsocketConn uses it as the deadline of the context handed to
// http.Server.Shutdown; if Shutdown returns an error (for example because this
// deadline expired) the server is force-closed instead.
const wsShutdownTimeout = 5 * time.Second

// Server interface implementation, defines a Server service class
// (接口实现,定义一个Server服务类)
type Server struct {
Expand Down Expand Up @@ -71,6 +78,14 @@ type Server struct {
// (异步捕获连接关闭状态)
exitChan chan struct{}

// stopOnce ensures Stop() is idempotent and exitChan is closed only once
// (stopOnce 保证 Stop() 幂等,exitChan 只被关闭一次)
stopOnce sync.Once

// WebSocket HTTP server instance, used for graceful shutdown
// (WebSocket HTTP 服务实例,用于优雅停服)
wsServer *http.Server

// Decoder for dealing with message fragmentation and reassembly
// (断粘包解码器)
decoder ziface.IDecoder
Expand Down Expand Up @@ -315,7 +330,11 @@ func (s *Server) ListenTcpConn() {

func (s *Server) ListenWebsocketConn() {
zlog.Ins().InfoF("[START] WEBSOCKET Server name: %s,listener at IP: %s, Port %d, Path %s is starting", s.Name, s.IP, s.WsPort, s.WsPath)
http.HandleFunc(s.WsPath, func(w http.ResponseWriter, r *http.Request) {

// Use a local ServeMux to avoid polluting the global http.DefaultServeMux
// (使用局部 ServeMux 避免污染全局 http.DefaultServeMux)
mux := http.NewServeMux()
mux.HandleFunc(s.WsPath, func(w http.ResponseWriter, r *http.Request) {
// 1. Check if the server has reached the maximum allowed number of connections
// (设置服务器最大连接控制,如果超过最大连接,则等待)
if s.ConnMgr.Len() >= zconf.GlobalObject.MaxConn {
Expand Down Expand Up @@ -357,18 +376,44 @@ func (s *Server) ListenWebsocketConn() {

})

if zconf.GlobalObject.CertFile != "" && zconf.GlobalObject.PrivateKeyFile != "" {
err := http.ListenAndServeTLS(fmt.Sprintf("%s:%d", s.IP, s.WsPort), zconf.GlobalObject.CertFile, zconf.GlobalObject.PrivateKeyFile, nil)
if err != nil {
panic(err)
// Create an explicit http.Server so we can shut it down gracefully later
// (显式创建 http.Server,以便后续能够优雅停服)
srv := &http.Server{
Addr: fmt.Sprintf("%s:%d", s.IP, s.WsPort),
Handler: mux,
}
s.wsServer = srv

// Start the HTTP server in a background goroutine
// (在后台 goroutine 中启动 HTTP Server)
go func() {
var err error
if zconf.GlobalObject.CertFile != "" && zconf.GlobalObject.PrivateKeyFile != "" {
err = srv.ListenAndServeTLS(zconf.GlobalObject.CertFile, zconf.GlobalObject.PrivateKeyFile)
} else {
err = srv.ListenAndServe()
}
} else {
err := http.ListenAndServe(fmt.Sprintf("%s:%d", s.IP, s.WsPort), nil)
if err != nil {
panic(err)
// http.ErrServerClosed is returned after a successful Shutdown/Close call —
// this is not an error. (http.ErrServerClosed 是正常关闭后的返回值,不视为错误)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
zlog.Ins().ErrorF("websocket server ListenAndServe err: %v", err)
}
}
}()

// Block until Stop() signals exit via exitChan
// (阻塞等待 Stop() 通过 exitChan 发出退出信号)
<-s.exitChan

// Gracefully shut down the WebSocket HTTP server with a timeout
// (带超时的优雅关闭 WebSocket HTTP Server)
ctx, cancel := context.WithTimeout(context.Background(), wsShutdownTimeout)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
zlog.Ins().ErrorF("websocket server shutdown err: %v", err)
// Fall back to forceful close if graceful shutdown fails
// (优雅关闭失败时兜底强制关闭)
_ = srv.Close()
}
}

func (s *Server) ListenKcpConn() {
Expand Down Expand Up @@ -463,7 +508,15 @@ func (s *Server) Stop() {
// Clear other connection information or other information that needs to be cleaned up
// (将其他需要清理的连接信息或者其他信息 也要一并停止或者清理)
s.ConnMgr.ClearConn()
close(s.exitChan)

// Use sync.Once to ensure exitChan is closed only once, making Stop() safe to call
// multiple times without panicking. Closing exitChan signals all listeners
// (including ListenWebsocketConn) to shut down gracefully.
// (使用 sync.Once 确保 exitChan 只被关闭一次,使 Stop() 可重复调用不 panic。
// 关闭 exitChan 会通知所有监听协程(包括 ListenWebsocketConn)执行优雅停服。)
s.stopOnce.Do(func() {
close(s.exitChan)
})
}

// Serve runs the server (运行服务)
Expand Down
119 changes: 119 additions & 0 deletions znet/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,21 @@ import (
"testing"
"time"

"github.com/aceld/zinx/zconf"
"github.com/aceld/zinx/ziface"
"github.com/aceld/zinx/zpack"
)

// Timing knobs shared by the port-polling test helpers.
const (
	// retryInterval is how long a helper sleeps between two consecutive
	// attempts while it waits for a port to change state.
	retryInterval = 20 * time.Millisecond

	// dialTimeout caps a single TCP dial attempt made while probing whether
	// a port is accepting connections.
	dialTimeout = 50 * time.Millisecond
)

// run in terminal:
// go test -v ./znet -run=TestServer

Expand Down Expand Up @@ -212,3 +223,111 @@ func TestCloseConnectionBeforeSendMsg(t *testing.T) {
wg.Wait()
s.Stop()
}

// waitForPort polls addr by trying to bind a TCP listener on it. As soon as a
// bind succeeds the listener is closed again and nil is returned, meaning the
// port is free. Failed binds are retried every retryInterval until timeout has
// elapsed, at which point an error naming addr and timeout is returned.
func waitForPort(addr string, timeout time.Duration) error {
	// The post statement sleeps only after a failed bind; a successful bind
	// returns from inside the loop body.
	for stopAt := time.Now().Add(timeout); time.Now().Before(stopAt); time.Sleep(retryInterval) {
		listener, bindErr := net.Listen("tcp", addr)
		if bindErr != nil {
			continue
		}
		_ = listener.Close()
		return nil
	}
	return fmt.Errorf("port %s not released within %v", addr, timeout)
}

// waitForPortListening polls addr by dialing it over TCP, each attempt bounded
// by dialTimeout. The first successful dial closes the connection and returns
// nil, meaning something is accepting connections on addr. Failed dials are
// retried every retryInterval until timeout has elapsed, at which point an
// error naming addr and timeout is returned.
func waitForPortListening(addr string, timeout time.Duration) error {
	cutoff := time.Now().Add(timeout)
	for {
		// The deadline is checked before every attempt, including the first.
		if !time.Now().Before(cutoff) {
			return fmt.Errorf("port %s not listening within %v", addr, timeout)
		}
		if probe, dialErr := net.DialTimeout("tcp", addr, dialTimeout); dialErr == nil {
			_ = probe.Close()
			return nil
		}
		time.Sleep(retryInterval)
	}
}

// TestWebsocketServerGracefulStop verifies that calling Stop() on a WebSocket-only
// server does not block, and that the HTTP listener port is released after Stop returns.
// (验证 WebSocket-only 模式下 Stop() 不阻塞且端口被释放)
func TestWebsocketServerGracefulStop(t *testing.T) {
// Use a dedicated port to avoid conflicts with TCP tests (使用独立端口避免冲突)
const wsPort = 19990

config := &zconf.Config{
Host: "127.0.0.1",
WsPort: wsPort,
WsPath: "/ws",
Mode: zconf.ServerModeWebsocket,
}
s := NewUserConfServer(config)
s.Start()

// Wait until the server is actually listening (等待服务端真正开始监听)
addr := fmt.Sprintf("127.0.0.1:%d", wsPort)
if err := waitForPortListening(addr, 3*time.Second); err != nil {
t.Fatalf("server did not start listening: %v", err)
}

// Stop() must return promptly — if it blocks, the test will time out.
// (Stop() 必须及时返回,否则测试会超时)
done := make(chan struct{})
go func() {
s.Stop()
close(done)
}()

select {
case <-done:
// Stop returned in time — good.
case <-time.After(3 * time.Second):
t.Fatal("Stop() blocked for more than 3s in websocket-only mode")
}

// Verify the port is released by retrying until the bind succeeds or timeout.
// (重试绑定端口,验证端口已被释放)
if err := waitForPort(addr, 3*time.Second); err != nil {
t.Fatalf("port %d not released after Stop(): %v", wsPort, err)
}
}

// TestWebsocketServerStopIdempotent verifies that calling Stop() multiple times
// does not panic (thanks to sync.Once protecting the exitChan close).
// (验证多次调用 Stop() 不会 panic)
func TestWebsocketServerStopIdempotent(t *testing.T) {
const wsPort = 19991

config := &zconf.Config{
Host: "127.0.0.1",
WsPort: wsPort,
WsPath: "/ws",
Mode: zconf.ServerModeWebsocket,
}
s := NewUserConfServer(config)
s.Start()

// Wait until the server is actually listening (等待服务端真正开始监听)
addr := fmt.Sprintf("127.0.0.1:%d", wsPort)
if err := waitForPortListening(addr, 3*time.Second); err != nil {
t.Fatalf("server did not start listening: %v", err)
}

// Calling Stop() twice must not panic.
// (两次调用 Stop() 不应 panic)
defer func() {
if r := recover(); r != nil {
t.Fatalf("Stop() panicked on second call: %v", r)
}
}()
s.Stop()
s.Stop()
}
Loading