Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/apierrors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,8 @@ func NewBadRequest(message string) error {
func NewConflict(message string) error {
return fmt.Errorf("%w: %s", ErrConflict, message)
}

// NewForbidden creates a new forbidden error with the given message
func NewForbidden(message string) error {
return fmt.Errorf("%w: %s", ErrForbidden, message)
}
22 changes: 13 additions & 9 deletions internal/db/artifacts.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,23 +369,27 @@ func GetOrCreateArtifact(ctx context.Context, orgID uuid.UUID, artifactName stri
db := internalctx.GetDb(ctx)
rows, err := db.Query(
ctx,
`SELECT `+artifactOutputExpr+`
FROM Artifact a
WHERE a.name = @name AND a.organization_id = @orgId`,
`WITH inserted AS (
INSERT INTO Artifact (name, organization_id)
VALUES (@name, @orgId)
ON CONFLICT ON CONSTRAINT Artifact_unique_name DO NOTHING
RETURNING *
)
SELECT`+artifactOutputExpr+`FROM inserted a
UNION ALL
SELECT`+artifactOutputExpr+`FROM Artifact a
WHERE a.organization_id = @orgId
AND a.name = @name
AND NOT EXISTS (SELECT 1 FROM inserted)`,
pgx.NamedArgs{
"name": artifactName,
"orgId": orgID,
},
)
if err != nil {
return nil, fmt.Errorf("could not query artifact: %w", err)
return nil, fmt.Errorf("could not get or create artifact: %w", err)
}
if result, err := pgx.CollectExactlyOneRow(rows, pgx.RowToStructByName[types.Artifact]); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
artifact := &types.Artifact{Name: artifactName, OrganizationID: orgID}
err = CreateArtifact(ctx, artifact)
return artifact, err
}
return nil, fmt.Errorf("could not collect artifact: %w", err)
} else {
return &result, nil
Expand Down
5 changes: 5 additions & 0 deletions internal/db/auth_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package db

import (
"context"
"errors"
"fmt"
"time"

"github.com/distr-sh/distr/internal/apierrors"
internalctx "github.com/distr-sh/distr/internal/context"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
Expand Down Expand Up @@ -38,6 +40,9 @@ func DeleteOIDCState(ctx context.Context, id uuid.UUID) (string, time.Time, erro
CreatedAt time.Time `db:"created_at"`
}])
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
err = apierrors.ErrNotFound
}
return "", time.Time{}, err
}
return r.PKCECodeVerifier, r.CreatedAt, nil
Expand Down
8 changes: 7 additions & 1 deletion internal/db/customer_organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,13 @@ func CreateCustomerOrganization(ctx context.Context, customerOrg *types.Customer
}
result, err := pgx.CollectExactlyOneRow(rows, pgx.RowToStructByName[types.CustomerOrganization])
if err != nil {
return err
if pgErr, ok := errors.AsType[*pgconn.PgError](err); ok &&
pgErr.Code == pgerrcode.ForeignKeyViolation &&
pgErr.ConstraintName == "customerorganization_image_id_fkey" {
return apierrors.NewBadRequest("invalid image ID")
}

return fmt.Errorf("could not scan created CustomerOrganization: %w", err)
} else {
*customerOrg = result
return nil
Expand Down
4 changes: 4 additions & 0 deletions internal/db/deployment_target_log_records.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ func SaveDeploymentTargetLogRecords(
}),
)

if pgErr, ok := errors.AsType[*pgconn.PgError](err); ok && pgErr.Code == pgerrcode.ForeignKeyViolation {
return apierrors.NewBadRequest("invalid deployment target ID")
}

return err
}

Expand Down
5 changes: 5 additions & 0 deletions internal/handlers/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,11 @@ func agentPutDeploymentTargetLogsHandler() http.HandlerFunc {
}

if err := db.SaveDeploymentTargetLogRecords(ctx, deploymentTarget.ID, records); err != nil {
if errors.Is(err, apierrors.ErrBadRequest) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

log.Error("error saving deployment target log records", zap.Error(err))
sentry.GetHubFromContext(ctx).CaptureException(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
Expand Down
28 changes: 25 additions & 3 deletions internal/handlers/auth_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ func authLoginOidcCallbackHandler(w http.ResponseWriter, r *http.Request) {

pkceVerifier, err := verifyOIDCState(r)
if err != nil {
if errors.Is(err, apierrors.ErrBadRequest) {
http.Redirect(w, r, redirectToLoginOIDCFailed, http.StatusFound)
return
}

sentry.GetHubFromContext(ctx).CaptureException(err)
log.Warn("could not verify OIDC state", zap.Error(err))
http.Redirect(w, r, redirectToLoginOIDCFailed, http.StatusFound)
Expand All @@ -68,7 +73,21 @@ func authLoginOidcCallbackHandler(w http.ResponseWriter, r *http.Request) {

provider := oidc.Provider(r.PathValue("oidcProvider"))
log = log.With(zap.String("provider", string(provider)))

if oidcError := r.URL.Query().Get("error"); oidcError != "" {
log.Warn("OIDC provider returned error",
zap.String("error", oidcError),
zap.String("error_description", r.URL.Query().Get("error_description")))
http.Redirect(w, r, redirectToLoginOIDCFailed, http.StatusFound)
Comment on lines +77 to +81
return
}

code := r.URL.Query().Get("code")
if code == "" {
log.Warn("OIDC callback missing code parameter")
http.Redirect(w, r, redirectToLoginOIDCFailed, http.StatusFound)
return
}

oidcer := internalctx.GetOIDCer(ctx)
email, emailVerified, err := oidcer.GetEmailForCode(ctx, provider, code, pkceVerifier, r)
Expand Down Expand Up @@ -133,15 +152,18 @@ func authLoginOidcCallbackHandler(w http.ResponseWriter, r *http.Request) {
func verifyOIDCState(r *http.Request) (string, error) {
state, err := uuid.Parse(r.URL.Query().Get("state"))
if err != nil {
return "", err
return "", fmt.Errorf("%w: %w", apierrors.ErrBadRequest, err)
}
pkceVerifier, createdAt, err := db.DeleteOIDCState(r.Context(), state)
if err != nil {
if errors.Is(err, apierrors.ErrNotFound) {
return "", apierrors.ErrBadRequest
}
return "", err
}
if createdAt.Before(time.Now().UTC().Add(-1 * time.Minute)) {
return "", fmt.Errorf("got an OIDC state that is too old: %v, created_at: %v, now: %v",
state, createdAt, time.Now().UTC())
return "", fmt.Errorf("%w: got an OIDC state that is too old: %v, created_at: %v, now: %v",
apierrors.ErrBadRequest, state, createdAt, time.Now().UTC())
}
return pkceVerifier, nil
}
29 changes: 11 additions & 18 deletions internal/handlers/customer_organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,30 +93,23 @@ func createCustomerOrganizationHandler() http.HandlerFunc {

err = db.RunTx(ctx, func(ctx context.Context) error {
if limitReached, err := subscription.IsCustomerOrganizationLimitReached(ctx, *auth.CurrentOrg()); err != nil {
log.Error("failed to get customer orgs", zap.Error(err))
sentry.GetHubFromContext(ctx).CaptureException(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return err
} else if limitReached {
err = errors.New("customer limit reached")
http.Error(w, err.Error(), http.StatusForbidden)
return err
}

if err := db.CreateCustomerOrganization(ctx, &customerOrganization); err != nil {
log.Error("failed to create customer org", zap.Error(err))
sentry.GetHubFromContext(ctx).CaptureException(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return err
return apierrors.NewForbidden("customer limit reached")
}

return nil
return db.CreateCustomerOrganization(ctx, &customerOrganization)
})

if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else {
if err == nil {
RespondJSON(w, mapping.CustomerOrganizationToAPI(customerOrganization))
} else if errors.Is(err, apierrors.ErrBadRequest) {
http.Error(w, err.Error(), http.StatusBadRequest)
} else if errors.Is(err, apierrors.ErrForbidden) {
http.Error(w, err.Error(), http.StatusForbidden)
} else {
log.Error("failed to create customer org", zap.Error(err))
sentry.GetHubFromContext(ctx).CaptureException(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
}
Expand Down
48 changes: 44 additions & 4 deletions internal/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ func ParseUserRole(value string) (UserRole, error) {
}
}

func (ref *UserRole) UnmarshalJSON(data []byte) error {
var value string
if err := json.Unmarshal(data, &value); err != nil {
return err
} else if userRole, err := ParseUserRole(value); err != nil {
Comment on lines +37 to +41
return err
} else {
*ref = userRole
return nil
}
Comment thread
kosmoz marked this conversation as resolved.
}

type OrderDirection string

const (
Expand Down Expand Up @@ -137,8 +149,39 @@ func (ref *DeploymentStatusType) UnmarshalJSON(data []byte) error {
}
}

type DeploymentType string

const (
DeploymentTypeDocker DeploymentType = "docker"
DeploymentTypeKubernetes DeploymentType = "kubernetes"
)

var ErrInvalidDeploymentType = errors.New("invalid deployment type")

func ParseDeploymentType(value string) (DeploymentType, error) {
switch value {
case string(DeploymentTypeDocker):
return DeploymentTypeDocker, nil
case string(DeploymentTypeKubernetes):
return DeploymentTypeKubernetes, nil
default:
return "", fmt.Errorf("%w: %v", ErrInvalidDeploymentType, value)
}
}

func (ref *DeploymentType) UnmarshalJSON(data []byte) error {
var value string
if err := json.Unmarshal(data, &value); err != nil {
return err
} else if deploymentType, err := ParseDeploymentType(value); err != nil {
return err
} else {
*ref = deploymentType
return nil
}
}

type (
DeploymentType string
HelmChartType string
DeploymentTargetScope string
DockerType string
Expand All @@ -148,9 +191,6 @@ type (
)

const (
DeploymentTypeDocker DeploymentType = "docker"
DeploymentTypeKubernetes DeploymentType = "kubernetes"

HelmChartTypeRepository HelmChartType = "repository"
HelmChartTypeOCI HelmChartType = "oci"

Expand Down
42 changes: 42 additions & 0 deletions internal/types/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,45 @@ func TestDeploymentStatusTypeParsing(t *testing.T) {
err = json.Unmarshal([]byte(`{"type": "does-not-exist"}`), &target)
g.Expect(err).To(MatchError(ErrInvalidDeploymentStatusType))
}

func TestUserRoleParsing(t *testing.T) {
g := NewWithT(t)

var target struct {
Role UserRole `json:"role"`
}

err := json.Unmarshal([]byte(`{"role": "read_only"}`), &target)
g.Expect(err).NotTo(HaveOccurred())
g.Expect(target.Role).To(Equal(UserRoleReadOnly))

err = json.Unmarshal([]byte(`{"role": "read_write"}`), &target)
g.Expect(err).NotTo(HaveOccurred())
g.Expect(target.Role).To(Equal(UserRoleReadWrite))

err = json.Unmarshal([]byte(`{"role": "admin"}`), &target)
g.Expect(err).NotTo(HaveOccurred())
g.Expect(target.Role).To(Equal(UserRoleAdmin))

err = json.Unmarshal([]byte(`{"role": "superuser"}`), &target)
g.Expect(err).To(HaveOccurred())
}

func TestDeploymentTypeParsing(t *testing.T) {
g := NewWithT(t)

var target struct {
Type DeploymentType `json:"type"`
}

err := json.Unmarshal([]byte(`{"type": "docker"}`), &target)
g.Expect(err).NotTo(HaveOccurred())
g.Expect(target.Type).To(Equal(DeploymentTypeDocker))

err = json.Unmarshal([]byte(`{"type": "kubernetes"}`), &target)
g.Expect(err).NotTo(HaveOccurred())
g.Expect(target.Type).To(Equal(DeploymentTypeKubernetes))

err = json.Unmarshal([]byte(`{"type": "swarm"}`), &target)
g.Expect(err).To(MatchError(ErrInvalidDeploymentType))
}
Loading