Skip to content

Commit 25db825

Browse files
authored
add middleware (#3)
1 parent ee37b77 commit 25db825

2 files changed

Lines changed: 53 additions & 0 deletions

File tree

arpc.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,3 +339,18 @@ func (m *Manager) Handler(f any) http.Handler {
339339
}
340340
})
341341
}
342+
343+
type Middleware func(r *http.Request) error
344+
345+
func (m *Manager) Middleware(f Middleware) func(http.Handler) http.Handler {
346+
return func(h http.Handler) http.Handler {
347+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
348+
err := f(r)
349+
if err != nil {
350+
m.encodeAndHookError(w, r, nil, err)
351+
return
352+
}
353+
h.ServeHTTP(w, r)
354+
})
355+
}
356+
}

arpc_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,41 @@ func TestManager_WrapError(t *testing.T) {
235235
assert.Equal(t, http.StatusOK, w.Code)
236236
assert.JSONEq(t, `{"ok":false,"error":{"code":"1000","message":"some error"}}`, w.Body.String())
237237
}
238+
239+
func TestMiddleware(t *testing.T) {
240+
t.Parallel()
241+
242+
m := arpc.New()
243+
244+
t.Run("Error", func(t *testing.T) {
245+
runHandler := false
246+
h := m.Middleware(func(r *http.Request) error {
247+
return arpc.NewError("middleware error")
248+
})(m.Handler(func() {
249+
runHandler = true
250+
}))
251+
252+
w := httptest.NewRecorder()
253+
r := httptest.NewRequest("POST", "/", nil)
254+
h.ServeHTTP(w, r)
255+
256+
assert.False(t, runHandler)
257+
assert.JSONEq(t, `{"ok":false,"error":{"message":"middleware error"}}`, w.Body.String())
258+
})
259+
260+
t.Run("OK", func(t *testing.T) {
261+
runHandler := false
262+
h := m.Middleware(func(r *http.Request) error {
263+
return nil
264+
})(m.Handler(func() {
265+
runHandler = true
266+
}))
267+
268+
w := httptest.NewRecorder()
269+
r := httptest.NewRequest("POST", "/", nil)
270+
h.ServeHTTP(w, r)
271+
272+
assert.True(t, runHandler)
273+
assert.JSONEq(t, `{"ok":true,"result":{}}`, w.Body.String())
274+
})
275+
}

0 commit comments

Comments
 (0)