Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 5 additions & 78 deletions exodus_gw/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,16 @@
from contextlib import asynccontextmanager
from uuid import uuid4

import backoff
import botocore.exceptions
import dramatiq
from asgi_correlation_id import CorrelationIdMiddleware, correlation_id
from fastapi import Depends, FastAPI, Request
from fastapi import Depends, FastAPI
from fastapi.exception_handlers import http_exception_handler
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware

from . import docs
from . import docs, retry
from .auth import log_login
from .aws.util import xml_response
from .database import db_engine
Expand Down Expand Up @@ -237,76 +232,8 @@ async def s3_queues_shutdown() -> None:
await client.__aexit__(None, None, None)


def new_db_session(engine):
# Make a new DB session for use in the current request.
#
# This is in its own function so that it can be wrapped by tests.
return Session(bind=engine, autoflush=False, autocommit=False)


@app.middleware("http")
async def db_session(request: Request, call_next):
"""Maintain a DB session around each request, which is also shared
with the dramatiq broker.

An implicit commit occurs if and only if the request succeeds.
"""

request.state.db = new_db_session(app.state.db_engine)

# Any dramatiq operations should also make use of this session.
broker = dramatiq.get_broker()
broker.set_session(request.state.db) # type: ignore
try:
response = await call_next(request)
if response.status_code >= 200 and response.status_code < 300:
await run_in_threadpool(request.state.db.commit)
finally:
# Check if RetryRoute has already cleaned up the session
if request.state.db is not None:
# If not, the session should be cleaned up.
broker.set_session(None) # type: ignore
await run_in_threadpool(request.state.db.close)
request.state.db = None
return response


class RetryRoute(APIRoute):
def get_route_handler(self):
original_route_handler = super().get_route_handler()

async def retry_route_handler(request: Request):
max_tries = request.app.state.settings.db_session_max_tries

@backoff.on_exception(
backoff.expo, DBAPIError, max_tries=max_tries
)
async def retry_wrapper():
broker = dramatiq.get_broker()
if request.state.db is None:
# Create new DB session if last one had an error
request.state.db = new_db_session(
request.app.state.db_engine
)
broker.set_session(request.state.db)

try:
return await original_route_handler(request)

except DBAPIError:
# Rollback and clear DB session
await run_in_threadpool(request.state.db.rollback)
broker.set_session(None)
await run_in_threadpool(request.state.db.close)
request.state.db = None
raise

return await retry_wrapper()

return retry_route_handler


app.router.route_class = RetryRoute
app.router.route_class = retry.RetryRoute
app.add_middleware(BaseHTTPMiddleware, dispatch=retry.db_session)

app.include_router(service.router)
app.include_router(upload.router)
Expand Down
79 changes: 79 additions & 0 deletions exodus_gw/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import logging

import backoff
import dramatiq
from fastapi import Request
from fastapi.routing import APIRoute
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool

LOG = logging.getLogger("exodus-gw")


def new_db_session(engine):
# Make a new DB session for use in the current request.
#
# This is in its own function so that it can be wrapped by tests.
return Session(bind=engine, autoflush=False, autocommit=False)


async def db_session(request: Request, call_next):
"""Maintain a DB session around each request, which is also shared
with the dramatiq broker.

An implicit commit occurs if and only if the request succeeds.
"""

request.state.db = new_db_session(request.app.state.db_engine)

# Any dramatiq operations should also make use of this session.
broker = dramatiq.get_broker()
broker.set_session(request.state.db) # type: ignore
try:
response = await call_next(request)
if response.status_code >= 200 and response.status_code < 300:
await run_in_threadpool(request.state.db.commit)
finally:
# Check if RetryRoute has already cleaned up the session
if request.state.db is not None:
# If not, the session should be cleaned up.
broker.set_session(None) # type: ignore
await run_in_threadpool(request.state.db.close)
request.state.db = None
return response


class RetryRoute(APIRoute):
def get_route_handler(self):
original_route_handler = super().get_route_handler()

async def retry_route_handler(request: Request):
max_tries = request.app.state.settings.db_session_max_tries

@backoff.on_exception(
backoff.expo, DBAPIError, max_tries=max_tries
)
async def retry_wrapper():
broker = dramatiq.get_broker()
if request.state.db is None:
# Create new DB session if last one had an error
request.state.db = new_db_session(
request.app.state.db_engine
)
broker.set_session(request.state.db)

try:
return await original_route_handler(request)

except DBAPIError:
# Rollback and clear DB session
await run_in_threadpool(request.state.db.rollback)
broker.set_session(None)
await run_in_threadpool(request.state.db.close)
request.state.db = None
raise

return await retry_wrapper()

return retry_route_handler
3 changes: 2 additions & 1 deletion exodus_gw/routers/cdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
from exodus_gw import auth, models, schemas, worker

from .. import deps
from ..retry import RetryRoute
from ..settings import Environment, Settings

LOG = logging.getLogger("exodus-gw")

openapi_tag = {"name": "cdn", "description": __doc__}

router = APIRouter(tags=[openapi_tag["name"]])
router = APIRouter(tags=[openapi_tag["name"]], route_class=RetryRoute)


def build_policy(url: str, expiration: datetime):
Expand Down
3 changes: 2 additions & 1 deletion exodus_gw/routers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
from exodus_gw.aws.dynamodb import DynamoDB

from .. import auth, deps, models, schemas, worker
from ..retry import RetryRoute
from ..settings import Environment, Settings

LOG = logging.getLogger("exodus-gw")

openapi_tag = {"name": "config", "description": __doc__}

router = APIRouter(tags=[openapi_tag["name"]])
router = APIRouter(tags=[openapi_tag["name"]], route_class=RetryRoute)

# Paths segments (e.g., "/dist" in "/content/dist/rhel") may contain
# any number of alphanumeric characters, dollars ($), hyphens (-), or
Expand Down
3 changes: 2 additions & 1 deletion exodus_gw/routers/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from sqlalchemy.orm import Session

from .. import auth, deps, models, schemas, settings
from ..retry import RetryRoute
from .config import CONFIG_SCHEMA, config_post

LOG = logging.getLogger("exodus-gw")

openapi_tag = {"name": "deploy", "description": __doc__}

router = APIRouter(tags=[openapi_tag["name"]])
router = APIRouter(tags=[openapi_tag["name"]], route_class=RetryRoute)


@router.post(
Expand Down
3 changes: 2 additions & 1 deletion exodus_gw/routers/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,14 @@
from sqlalchemy.orm import Session, noload

from .. import auth, deps, models, schemas, worker
from ..retry import RetryRoute
from ..settings import Environment, Settings

LOG = logging.getLogger("exodus-gw")

openapi_tag = {"name": "publish", "description": __doc__}

router = APIRouter(tags=[openapi_tag["name"]])
router = APIRouter(tags=[openapi_tag["name"]], route_class=RetryRoute)


@router.post(
Expand Down
3 changes: 2 additions & 1 deletion exodus_gw/routers/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from .. import deps, models, schemas
from ..auth import CallContext
from ..models import DramatiqConsumer
from ..retry import RetryRoute
from ..settings import Settings

LOG = logging.getLogger("exodus-gw")

openapi_tag = {"name": "service", "description": __doc__}

router = APIRouter(tags=[openapi_tag["name"]])
router = APIRouter(tags=[openapi_tag["name"]], route_class=RetryRoute)


@router.get(
Expand Down
3 changes: 2 additions & 1 deletion exodus_gw/routers/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,14 @@
validate_object_key,
xml_response,
)
from ..retry import RetryRoute
from ..settings import Environment, Settings

LOG = logging.getLogger("s3")

openapi_tag = {"name": "upload", "description": __doc__}

router = APIRouter(tags=[openapi_tag["name"]])
router = APIRouter(tags=[openapi_tag["name"]], route_class=RetryRoute)


async def _already_uploaded(
Expand Down
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from exodus_gw import database, main, models, settings # noqa
from exodus_gw.dramatiq import Broker
from exodus_gw.retry import new_db_session

from .async_utils import BlockDetector

Expand Down Expand Up @@ -208,13 +209,13 @@ def db_session_block_detector():
"""Wrap DB sessions created by the app with an object to detect
incorrect async/non-async mixing blocking the main thread.
"""
old_ctor = main.new_db_session
old_ctor = new_db_session

def new_ctor(engine):
real_session = old_ctor(engine)
return BlockDetector(real_session)

with mock.patch("exodus_gw.main.new_db_session", new=new_ctor):
with mock.patch("exodus_gw.retry.new_db_session", new=new_ctor):
yield


Expand Down
68 changes: 68 additions & 0 deletions tests/routers/test_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi import HTTPException
from fastapi.testclient import TestClient
from freezegun import freeze_time
from sqlalchemy.exc import OperationalError

from exodus_gw import routers, schemas
from exodus_gw.main import app
Expand Down Expand Up @@ -1721,3 +1722,70 @@ def test_update_publish_items_origin_paths_invalid_link_to(db, auth_header):
"/content/origin/files/sha256/03/0344062dca731c0d5c24148722537e181d752ca8cda0097005f9268a51658b0a/test-3.rpm",
)
}


def test_update_publish_items_deadlock_retry(db, auth_header):
"""Ensure that deadlock errors are retried and eventually succeed.

When a database deadlock occurs during a publish item update, the retry
mechanism should catch the DBAPIError and retry the operation. If the
deadlock is resolved on retry, the operation should succeed.
"""

publish_id = "11224567-e89b-12d3-a456-426614174000"

publish = Publish(id=publish_id, env="test", state="PENDING")

db.add(publish)
db.commit()

# Track the number of calls to simulate a deadlock on first attempt only
call_count = {"count": 0}

# We need to mock at the SQLAlchemy Session level where execute is called
# during the insert operation
original_session_execute = sqlalchemy.orm.Session.execute

def mock_execute_with_deadlock(self, *args, **kwargs):
# Only intercept INSERT statements on Item table
if args and hasattr(args[0], "compile"):
statement_str = str(args[0].compile())
if "INSERT INTO item" in statement_str:
call_count["count"] += 1
# Raise a deadlock error on the first call only
if call_count["count"] == 1:
# Simulate a PostgreSQL deadlock error
raise OperationalError(
"deadlock detected",
"INSERT INTO item ...",
orig=Exception("deadlock detected"),
)
# On subsequent calls or other statements, use the original execute
return original_session_execute(self, *args, **kwargs)

with mock.patch.object(
sqlalchemy.orm.Session, "execute", mock_execute_with_deadlock
):
with TestClient(app) as client:
r = client.put(
"/test/publish/%s" % publish_id,
json=[
{
"web_uri": "/uri1",
"object_key": "1" * 64,
"content_type": "application/octet-stream",
},
],
headers=auth_header(roles=["test-publisher"]),
)

# Should succeed after retry
assert r.status_code == 200

# Should have been called twice (first attempt + one retry)
assert call_count["count"] == 2

# Item should be in the database
db.refresh(publish)
assert len(publish.items) == 1
assert publish.items[0].web_uri == "/uri1"
Loading