diff --git a/demo-app/cmd/api/main.go b/demo-app/cmd/api/main.go index c17208f..af9aca2 100644 --- a/demo-app/cmd/api/main.go +++ b/demo-app/cmd/api/main.go @@ -49,8 +49,8 @@ func main() { }) }) - // Add handle group to HTTP server with /subApi path - api.HandleGroup("/subApi", subApiGroup) + // Mount handle group to HTTP server with /subApi path + api.Mount("/subApi", subApiGroup) // Register HTTP server as application server app.RegisterService("api", api) diff --git a/demo-app/cmd/auth/main.go b/demo-app/cmd/auth/main.go index 9215b10..ba2bc84 100644 --- a/demo-app/cmd/auth/main.go +++ b/demo-app/cmd/auth/main.go @@ -34,7 +34,7 @@ func main() { api.Use(log.NewTraceIDMiddleware(nil, "")) api.Use(httpserver.NewRecoverMiddleware()) - api.HandleGroup("/auth", authDomain.HandleGroup) + api.Mount("/auth", authDomain.HandleGroup) protected := httpserver.NewHandlerGroup() protected.Use(authDomain.Middleware) @@ -42,7 +42,7 @@ func main() { user := auth.UserFromContext(r.Context()) w.Write([]byte("Welcome, " + user.Username)) }) - api.HandleGroup("/api", protected) + api.Mount("/api", protected) api.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("pong")) diff --git a/docs/src/content/docs/packages/auth.mdx b/docs/src/content/docs/packages/auth.mdx index 01d6a09..827df0a 100644 --- a/docs/src/content/docs/packages/auth.mdx +++ b/docs/src/content/docs/packages/auth.mdx @@ -67,7 +67,7 @@ Core Components: ```go api := httpserver.New("8080", 3*time.Second) - api.HandleGroup("/auth", authDomain.HandleGroup) + api.Mount("/auth", authDomain.HandleGroup) ``` This exposes the following endpoints: @@ -91,7 +91,7 @@ Core Components: } }) - api.HandleGroup("/api", protectedGroup) + api.Mount("/api", protectedGroup) ``` The `AuthenticationMiddleware` returns 401 Unauthorized if no valid session is found. Use `auth.UserFromContext()` to access the authenticated user. @@ -136,7 +136,7 @@ app.RegisterDomain("auth", "main", authDomain) // Set up HTTP server with auth endpoints api := httpserver.New("8080", 3*time.Second) -api.HandleGroup("/auth", authDomain.HandleGroup) +api.Mount("/auth", authDomain.HandleGroup) app.RegisterService("api", api) app.Run(ctx) diff --git a/docs/src/content/docs/packages/httpserver.mdx b/docs/src/content/docs/packages/httpserver.mdx index 2abfb67..9b9e0aa 100644 --- a/docs/src/content/docs/packages/httpserver.mdx +++ b/docs/src/content/docs/packages/httpserver.mdx @@ -77,7 +77,7 @@ Core Components: 6. Mount the group on the server ```go - server.HandleGroup("/api", apiGroup) + server.Mount("/api", apiGroup) ``` The group is now accessible at `/api/users` and `/api/posts`. The path prefix is automatically stripped. diff --git a/httpserver/fileserver.go b/httpserver/fileserver.go index 70a5061..cbc127a 100644 --- a/httpserver/fileserver.go +++ b/httpserver/fileserver.go @@ -15,7 +15,7 @@ type FileServer struct { // NewFileServer creates a new FileServer instance with the given file system, base path, and port. func NewFileServer(fs fs.FS, basePath, port string) *FileServer { server := New(port, 1*time.Second) - server.HandleGroup(basePath, http.FileServer(http.FS(fs))) + server.Mount(basePath, http.FileServer(http.FS(fs))) return &FileServer{server: server} } diff --git a/httpserver/handlergroup.go b/httpserver/handlergroup.go index a7047ac..71988d2 100644 --- a/httpserver/handlergroup.go +++ b/httpserver/handlergroup.go @@ -2,6 +2,7 @@ package httpserver import ( "net/http" + "strings" ) // HandlerGroup represents a group of HTTP handlers that share common middlewares. @@ -37,9 +38,56 @@ func (hg *HandlerGroup) HandleFunc(pattern string, handler func(http.ResponseWri hg.mux.Handle(pattern, http.HandlerFunc(handler)) } -// HandleGroup applies `http.StripPrefix` to http.Handler and registers it for the given pattern -func (hg *HandlerGroup) HandleGroup(pattern string, handler http.Handler) { - hg.mux.Handle(pattern+"/", http.StripPrefix(pattern, handler)) +// Mount mounts handler at both prefix (group root) and prefix+"/" (subtree). +// The handler receives requests with the path prefix stripped; an empty stripped +// path is normalized to "/" so that nested groups can register "GET /" etc. +func (hg *HandlerGroup) Mount(prefix string, handler http.Handler) { + if prefix != "" && !strings.HasPrefix(prefix, "/") { + panic("httpserver: mount prefix must be a path starting with /") + } + + prefix = strings.TrimRight(prefix, "/") + if prefix == "" { + prefix = "/" + } + mounted := stripPrefix(prefix, handler) + + if prefix == "/" { + hg.mux.Handle(prefix, mounted) + return + } + + hg.mux.Handle(prefix, mounted) + hg.mux.Handle(prefix+"/", mounted) +} + +// stripPrefix returns a handler that strips prefix from r.URL.Path, writing a 404 +// if the request path does not start with prefix. If stripping leaves an empty +// path, it is normalized to "/". +func stripPrefix(prefix string, handler http.Handler) http.Handler { + if prefix == "/" { + return handler + } + + normalized := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "" { + r2 := r.Clone(r.Context()) + r2.URL.Path = "/" + + if r.URL.RawPath == "" { + handler.ServeHTTP(w, r2) + return + } + + r2.URL.RawPath = "/" + handler.ServeHTTP(w, r2) + return + } + + handler.ServeHTTP(w, r) + }) + + return http.StripPrefix(prefix, normalized) } // ServeHTTP implements the http.Handler interface, allowing HandlerGroup to diff --git a/httpserver/httpserver_test.go b/httpserver/httpserver_test.go index 64bb356..74c316c 100644 --- a/httpserver/httpserver_test.go +++ b/httpserver/httpserver_test.go @@ -61,14 +61,14 @@ func TestHTTPServer(t *testing.T) { } }) - t.Run("handle group", func(t *testing.T) { + t.Run("mount group", func(t *testing.T) { t.Parallel() hg := httpserver.NewHandlerGroup() hg.Handle("/test", &handler{}) server := httpserver.New("", 0) - server.HandleGroup("/hg", hg) + server.Mount("/hg", hg) r := httptest.NewRequest(http.MethodGet, "/hg/test", nil) w := httptest.NewRecorder() @@ -82,6 +82,208 @@ func TestHTTPServer(t *testing.T) { } }) + t.Run("mount group exact root", func(t *testing.T) { + t.Parallel() + + hg := httpserver.NewHandlerGroup() + hg.HandleFunc("GET /", func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("root")) + }) + + server := httpserver.New("", 0) + server.Mount("/hg", hg) + + r := httptest.NewRequest(http.MethodGet, "/hg", nil) + w := httptest.NewRecorder() + + server.ServeHTTP(w, r) + + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code to be 200, got %d", resp.StatusCode) + } + + if string(body) != "root" { + t.Fatalf("expected body to be 'root', got %s", string(body)) + } + }) + + t.Run("mount group subtree", func(t *testing.T) { + t.Parallel() + + hg := httpserver.NewHandlerGroup() + hg.HandleFunc("GET /verify", func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("verified")) + }) + + server := httpserver.New("", 0) + server.Mount("/domains", hg) + + r := httptest.NewRequest(http.MethodGet, "/domains/verify", nil) + w := httptest.NewRecorder() + + server.ServeHTTP(w, r) + + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code to be 200, got %d", resp.StatusCode) + } + + if string(body) != "verified" { + t.Fatalf("expected body to be 'verified', got %s", string(body)) + } + }) + + t.Run("mount group trailing slash pattern", func(t *testing.T) { + t.Parallel() + + hg := httpserver.NewHandlerGroup() + hg.HandleFunc("GET /", func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("root")) + }) + + server := httpserver.New("", 0) + server.Mount("/hg/", hg) + + r := httptest.NewRequest(http.MethodGet, "/hg", nil) + w := httptest.NewRecorder() + + server.ServeHTTP(w, r) + + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code to be 200, got %d", resp.StatusCode) + } + + if string(body) != "root" { + t.Fatalf("expected body to be 'root', got %s", string(body)) + } + }) + + t.Run("mount group root prefix", func(t *testing.T) { + t.Parallel() + + hg := httpserver.NewHandlerGroup() + hg.HandleFunc("GET /verify", func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("verified")) + }) + + server := httpserver.New("", 0) + server.Mount("/", hg) + + r := httptest.NewRequest(http.MethodGet, "/verify", nil) + w := httptest.NewRecorder() + + server.ServeHTTP(w, r) + + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code to be 200, got %d", resp.StatusCode) + } + + if string(body) != "verified" { + t.Fatalf("expected body to be 'verified', got %s", string(body)) + } + }) + + t.Run("mount group empty root prefix", func(t *testing.T) { + t.Parallel() + + hg := httpserver.NewHandlerGroup() + hg.HandleFunc("GET /verify", func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("verified")) + }) + + server := httpserver.New("", 0) + server.Mount("", hg) + + r := httptest.NewRequest(http.MethodGet, "/verify", nil) + w := httptest.NewRecorder() + + server.ServeHTTP(w, r) + + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code to be 200, got %d", resp.StatusCode) + } + + if string(body) != "verified" { + t.Fatalf("expected body to be 'verified', got %s", string(body)) + } + }) + + t.Run("mount rejects method pattern", func(t *testing.T) { + t.Parallel() + + hg := httpserver.NewHandlerGroup() + server := httpserver.New("", 0) + + defer func() { + if recover() == nil { + t.Fatal("expected Mount with method pattern to panic") + } + }() + + server.Mount("GET /hg", hg) + }) + + t.Run("mount rejects escaped prefix mismatch", func(t *testing.T) { + t.Parallel() + + hg := httpserver.NewHandlerGroup() + hg.HandleFunc("GET /", func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("root")) + }) + + server := httpserver.New("", 0) + server.Mount("/hg", hg) + + r := httptest.NewRequest(http.MethodGet, "/hg", nil) + r.URL.RawPath = "/%68g" + w := httptest.NewRecorder() + + server.ServeHTTP(w, r) + + resp := w.Result() + + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected status code to be 404, got %d", resp.StatusCode) + } + }) + + t.Run("mount group not found", func(t *testing.T) { + t.Parallel() + + hg := httpserver.NewHandlerGroup() + hg.HandleFunc("GET /", func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("root")) + }) + + server := httpserver.New("", 0) + server.Mount("/hg", hg) + + r := httptest.NewRequest(http.MethodGet, "/other", nil) + w := httptest.NewRecorder() + + server.ServeHTTP(w, r) + + resp := w.Result() + + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected status code to be 404, got %d", resp.StatusCode) + } + }) + t.Run("healthcheck", func(t *testing.T) { t.Parallel()