Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions demo-app/cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ func main() {
})
})

// Add handle group to HTTP server with /subApi path
api.HandleGroup("/subApi", subApiGroup)
// Mount handle group to HTTP server with /subApi path
api.Mount("/subApi", subApiGroup)

// Register HTTP server as application server
app.RegisterService("api", api)
Expand Down
4 changes: 2 additions & 2 deletions demo-app/cmd/auth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ func main() {
api.Use(log.NewTraceIDMiddleware(nil, ""))
api.Use(httpserver.NewRecoverMiddleware())

api.HandleGroup("/auth", authDomain.HandleGroup)
api.Mount("/auth", authDomain.HandleGroup)

protected := httpserver.NewHandlerGroup()
protected.Use(authDomain.Middleware)
protected.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
user := auth.UserFromContext(r.Context())
w.Write([]byte("Welcome, " + user.Username))
})
api.HandleGroup("/api", protected)
api.Mount("/api", protected)

api.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("pong"))
Expand Down
6 changes: 3 additions & 3 deletions docs/src/content/docs/packages/auth.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Core Components:
```go
api := httpserver.New("8080", 3*time.Second)

api.HandleGroup("/auth", authDomain.HandleGroup)
api.Mount("/auth", authDomain.HandleGroup)
```

This exposes the following endpoints:
Expand All @@ -91,7 +91,7 @@ Core Components:
}
})

api.HandleGroup("/api", protectedGroup)
api.Mount("/api", protectedGroup)
```

The `AuthenticationMiddleware` returns 401 Unauthorized if no valid session is found. Use `auth.UserFromContext()` to access the authenticated user.
Expand Down Expand Up @@ -136,7 +136,7 @@ app.RegisterDomain("auth", "main", authDomain)

// Set up HTTP server with auth endpoints
api := httpserver.New("8080", 3*time.Second)
api.HandleGroup("/auth", authDomain.HandleGroup)
api.Mount("/auth", authDomain.HandleGroup)
app.RegisterService("api", api)

app.Run(ctx)
Expand Down
2 changes: 1 addition & 1 deletion docs/src/content/docs/packages/httpserver.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ Core Components:
6. Mount the group on the server

```go
server.HandleGroup("/api", apiGroup)
server.Mount("/api", apiGroup)
```

The group is now accessible at `/api/users` and `/api/posts`. The path prefix is automatically stripped.
Expand Down
2 changes: 1 addition & 1 deletion httpserver/fileserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type FileServer struct {
// NewFileServer creates a new FileServer instance with the given file system, base path, and port.
func NewFileServer(fs fs.FS, basePath, port string) *FileServer {
server := New(port, 1*time.Second)
server.HandleGroup(basePath, http.FileServer(http.FS(fs)))
server.Mount(basePath, http.FileServer(http.FS(fs)))

return &FileServer{server: server}
}
Expand Down
54 changes: 51 additions & 3 deletions httpserver/handlergroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package httpserver

import (
"net/http"
"strings"
)

// HandlerGroup represents a group of HTTP handlers that share common middlewares.
Expand Down Expand Up @@ -37,9 +38,56 @@ func (hg *HandlerGroup) HandleFunc(pattern string, handler func(http.ResponseWri
hg.mux.Handle(pattern, http.HandlerFunc(handler))
}

// HandleGroup applies `http.StripPrefix` to http.Handler and registers it for the given pattern
func (hg *HandlerGroup) HandleGroup(pattern string, handler http.Handler) {
hg.mux.Handle(pattern+"/", http.StripPrefix(pattern, handler))
// Mount mounts handler at both prefix (group root) and prefix+"/" (subtree).
// The handler receives requests with the path prefix stripped; an empty stripped
// path is normalized to "/" so that nested groups can register "GET /" etc.
func (hg *HandlerGroup) Mount(prefix string, handler http.Handler) {
if prefix != "" && !strings.HasPrefix(prefix, "/") {
panic("httpserver: mount prefix must be a path starting with /")
}

prefix = strings.TrimRight(prefix, "/")
if prefix == "" {
prefix = "/"
}
mounted := stripPrefix(prefix, handler)

if prefix == "/" {
hg.mux.Handle(prefix, mounted)
return
}

hg.mux.Handle(prefix, mounted)
hg.mux.Handle(prefix+"/", mounted)
}

// stripPrefix returns a handler that strips prefix from r.URL.Path, writing a 404
// if the request path does not start with prefix. If stripping leaves an empty
// path, it is normalized to "/".
func stripPrefix(prefix string, handler http.Handler) http.Handler {
if prefix == "/" {
return handler
}

normalized := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "" {
r2 := r.Clone(r.Context())
r2.URL.Path = "/"

if r.URL.RawPath == "" {
handler.ServeHTTP(w, r2)
return
}

r2.URL.RawPath = "/"
handler.ServeHTTP(w, r2)
return
}

handler.ServeHTTP(w, r)
})

return http.StripPrefix(prefix, normalized)
}

// ServeHTTP implements the http.Handler interface, allowing HandlerGroup to
Expand Down
206 changes: 204 additions & 2 deletions httpserver/httpserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ func TestHTTPServer(t *testing.T) {
}
})

t.Run("handle group", func(t *testing.T) {
t.Run("mount group", func(t *testing.T) {
t.Parallel()

hg := httpserver.NewHandlerGroup()
hg.Handle("/test", &handler{})

server := httpserver.New("", 0)
server.HandleGroup("/hg", hg)
server.Mount("/hg", hg)

r := httptest.NewRequest(http.MethodGet, "/hg/test", nil)
w := httptest.NewRecorder()
Expand All @@ -82,6 +82,208 @@ func TestHTTPServer(t *testing.T) {
}
})

t.Run("mount group exact root", func(t *testing.T) {
t.Parallel()

hg := httpserver.NewHandlerGroup()
hg.HandleFunc("GET /", func(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte("root"))
})

server := httpserver.New("", 0)
server.Mount("/hg", hg)

r := httptest.NewRequest(http.MethodGet, "/hg", nil)
w := httptest.NewRecorder()

server.ServeHTTP(w, r)

resp := w.Result()
body, _ := io.ReadAll(resp.Body)

if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status code to be 200, got %d", resp.StatusCode)
}

if string(body) != "root" {
t.Fatalf("expected body to be 'root', got %s", string(body))
}
})

t.Run("mount group subtree", func(t *testing.T) {
t.Parallel()

hg := httpserver.NewHandlerGroup()
hg.HandleFunc("GET /verify", func(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte("verified"))
})

server := httpserver.New("", 0)
server.Mount("/domains", hg)

r := httptest.NewRequest(http.MethodGet, "/domains/verify", nil)
w := httptest.NewRecorder()

server.ServeHTTP(w, r)

resp := w.Result()
body, _ := io.ReadAll(resp.Body)

if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status code to be 200, got %d", resp.StatusCode)
}

if string(body) != "verified" {
t.Fatalf("expected body to be 'verified', got %s", string(body))
}
})

t.Run("mount group trailing slash pattern", func(t *testing.T) {
t.Parallel()

hg := httpserver.NewHandlerGroup()
hg.HandleFunc("GET /", func(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte("root"))
})

server := httpserver.New("", 0)
server.Mount("/hg/", hg)

r := httptest.NewRequest(http.MethodGet, "/hg", nil)
w := httptest.NewRecorder()

server.ServeHTTP(w, r)

resp := w.Result()
body, _ := io.ReadAll(resp.Body)

if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status code to be 200, got %d", resp.StatusCode)
}

if string(body) != "root" {
t.Fatalf("expected body to be 'root', got %s", string(body))
}
})

t.Run("mount group root prefix", func(t *testing.T) {
t.Parallel()

hg := httpserver.NewHandlerGroup()
hg.HandleFunc("GET /verify", func(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte("verified"))
})

server := httpserver.New("", 0)
server.Mount("/", hg)

r := httptest.NewRequest(http.MethodGet, "/verify", nil)
w := httptest.NewRecorder()

server.ServeHTTP(w, r)

resp := w.Result()
body, _ := io.ReadAll(resp.Body)

if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status code to be 200, got %d", resp.StatusCode)
}

if string(body) != "verified" {
t.Fatalf("expected body to be 'verified', got %s", string(body))
}
})

t.Run("mount group empty root prefix", func(t *testing.T) {
t.Parallel()

hg := httpserver.NewHandlerGroup()
hg.HandleFunc("GET /verify", func(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte("verified"))
})

server := httpserver.New("", 0)
server.Mount("", hg)

r := httptest.NewRequest(http.MethodGet, "/verify", nil)
w := httptest.NewRecorder()

server.ServeHTTP(w, r)

resp := w.Result()
body, _ := io.ReadAll(resp.Body)

if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status code to be 200, got %d", resp.StatusCode)
}

if string(body) != "verified" {
t.Fatalf("expected body to be 'verified', got %s", string(body))
}
})

t.Run("mount rejects method pattern", func(t *testing.T) {
t.Parallel()

hg := httpserver.NewHandlerGroup()
server := httpserver.New("", 0)

defer func() {
if recover() == nil {
t.Fatal("expected Mount with method pattern to panic")
}
}()

server.Mount("GET /hg", hg)
})

t.Run("mount rejects escaped prefix mismatch", func(t *testing.T) {
t.Parallel()

hg := httpserver.NewHandlerGroup()
hg.HandleFunc("GET /", func(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte("root"))
})

server := httpserver.New("", 0)
server.Mount("/hg", hg)

r := httptest.NewRequest(http.MethodGet, "/hg", nil)
r.URL.RawPath = "/%68g"
w := httptest.NewRecorder()

server.ServeHTTP(w, r)

resp := w.Result()

if resp.StatusCode != http.StatusNotFound {
t.Fatalf("expected status code to be 404, got %d", resp.StatusCode)
}
})

t.Run("mount group not found", func(t *testing.T) {
t.Parallel()

hg := httpserver.NewHandlerGroup()
hg.HandleFunc("GET /", func(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte("root"))
})

server := httpserver.New("", 0)
server.Mount("/hg", hg)

r := httptest.NewRequest(http.MethodGet, "/other", nil)
w := httptest.NewRecorder()

server.ServeHTTP(w, r)

resp := w.Result()

if resp.StatusCode != http.StatusNotFound {
t.Fatalf("expected status code to be 404, got %d", resp.StatusCode)
}
})

t.Run("healthcheck", func(t *testing.T) {
t.Parallel()

Expand Down
Loading