Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/app/app_suite_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package app_test
package app

import (
"testing"
Expand Down
89 changes: 89 additions & 0 deletions internal/app/openapi.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package app

import (
"fmt"
"net/http"
"strings"

catalogapi "github.com/dcm-project/control-plane/api/catalog/v1alpha1"
policyapi "github.com/dcm-project/control-plane/api/policy/v1alpha1"
spproviderapi "github.com/dcm-project/control-plane/api/sp/v1alpha1/provider"
sprmapi "github.com/dcm-project/control-plane/api/sp/v1alpha1/resource_manager"
"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/openapi3filter"
nethttpmiddleware "github.com/oapi-codegen/nethttp-middleware"
)

const apiV1Alpha1Prefix = "/api/v1alpha1"

type openAPIValidators struct {
catalog func(http.Handler) http.Handler
policy func(http.Handler) http.Handler
provider func(http.Handler) http.Handler
rm func(http.Handler) http.Handler
}

func newOpenAPIValidators() (*openAPIValidators, error) {
catalogSpec, err := catalogapi.GetSpec()
if err != nil {
return nil, fmt.Errorf("load catalog OpenAPI spec: %w", err)
}

policySpec, err := policyapi.GetSpec()
if err != nil {
return nil, fmt.Errorf("load policy OpenAPI spec: %w", err)
}

providerSpec, err := spproviderapi.GetSpec()
if err != nil {
return nil, fmt.Errorf("load service provider OpenAPI spec: %w", err)
}

rmSpec, err := sprmapi.GetSpec()
if err != nil {
return nil, fmt.Errorf("load resource manager OpenAPI spec: %w", err)
}

return &openAPIValidators{
catalog: oapiRequestValidator(catalogSpec),
policy: oapiRequestValidator(policySpec),
provider: oapiRequestValidator(providerSpec),
rm: oapiRequestValidator(rmSpec),
}, nil
}

func oapiRequestValidator(spec *openapi3.T) func(http.Handler) http.Handler {
return nethttpmiddleware.OapiRequestValidatorWithOptions(spec, &nethttpmiddleware.Options{
Options: openapi3filter.Options{
AuthenticationFunc: openapi3filter.NoopAuthenticationFunc,
},
SilenceServersWarning: true,
})
}

func (v *openAPIValidators) middleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
if !strings.HasPrefix(path, apiV1Alpha1Prefix) {
next.ServeHTTP(w, r)
return
}
if path == monolithHealthPath {
next.ServeHTTP(w, r)
return
}

switch {
case strings.HasPrefix(path, apiV1Alpha1Prefix+"/service-type-instances"):
v.rm(next).ServeHTTP(w, r)
case strings.HasPrefix(path, apiV1Alpha1Prefix+"/providers"):
v.provider(next).ServeHTTP(w, r)
case strings.HasPrefix(path, apiV1Alpha1Prefix+"/policies"):
v.policy(next).ServeHTTP(w, r)
default:
v.catalog(next).ServeHTTP(w, r)
}
})
}
}
74 changes: 74 additions & 0 deletions internal/app/openapi_validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package app

import (
"net/http"
"net/http/httptest"
"strings"

"github.com/go-chi/chi/v5"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("OpenAPI request validation", func() {
var validators *openAPIValidators

BeforeEach(func() {
var err error
validators, err = newOpenAPIValidators()
Expect(err).NotTo(HaveOccurred())
})

Describe("monolith health", func() {
It("bypasses domain validators", func() {
router := chi.NewRouter()
router.Use(validators.middleware())
registerMonolithHealth(router)

req := httptest.NewRequest(http.MethodGet, monolithHealthPath, nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)

Expect(rec.Code).To(Equal(http.StatusOK))
})
})

Describe("policy routes", func() {
It("rejects malformed JSON on POST /policies", func() {
expectInvalidJSONRejected(validators, "/api/v1alpha1/policies")
})
})

Describe("catalog routes", func() {
It("rejects malformed JSON on POST /catalog-items", func() {
expectInvalidJSONRejected(validators, "/api/v1alpha1/catalog-items")
})
})

Describe("SP provider routes", func() {
It("rejects malformed JSON on POST /providers", func() {
expectInvalidJSONRejected(validators, "/api/v1alpha1/providers")
})
})

Describe("SP resource manager routes", func() {
It("rejects malformed JSON on POST /service-type-instances", func() {
expectInvalidJSONRejected(validators, "/api/v1alpha1/service-type-instances")
})
})
})

func expectInvalidJSONRejected(validators *openAPIValidators, path string) {
router := chi.NewRouter()
router.Use(validators.middleware())
router.Post(path, func(_ http.ResponseWriter, _ *http.Request) {
Fail("validator passed invalid JSON through to handler for POST " + path)
})

req := httptest.NewRequest(http.MethodPost, path, strings.NewReader("not-json"))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)

Expect(rec.Code).To(Equal(http.StatusBadRequest), rec.Body.String())
}
16 changes: 13 additions & 3 deletions internal/app/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,16 @@ func Run() int {
cleanupScheduler.Start(ctx)
defer cleanupScheduler.Stop()

router := newRouter(RouteHandlers{
router, err := newRouter(RouteHandlers{
Catalog: cataloghandlers.NewHandler(catalogSvc, logger),
Policy: policyhandlers.NewPolicyHandler(policyService),
SPProvider: spproviderhandler.NewHandler(spProviderService),
SPRM: sprmhandler.NewHandler(spInstanceService),
})
if err != nil {
slog.Error("Failed to configure HTTP router", "error", err)
return 1
}

listener, err := net.Listen("tcp", cfg.Service.BindAddress)
if err != nil {
Expand Down Expand Up @@ -204,10 +208,16 @@ type RouteHandlers struct {
SPRM sprmserver.StrictServerInterface
}

func newRouter(h RouteHandlers) chi.Router {
func newRouter(h RouteHandlers) (chi.Router, error) {
validators, err := newOpenAPIValidators()
if err != nil {
return nil, err
}

router := chi.NewRouter()
router.Use(middleware.RequestID)
router.Use(middleware.Recoverer)
router.Use(validators.middleware())

const baseURL = "/api/v1alpha1"

Expand Down Expand Up @@ -237,7 +247,7 @@ func newRouter(h RouteHandlers) chi.Router {
)
router.Mount(baseURL, apiRouter)

return router
return router, nil
}

func buildPlacementClient(cfg *Config, svc *placementservice.PlacementService, logger *slog.Logger) (catalogplacement.Client, error) {
Expand Down