Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
# AI Rewriter Configuration

# Gemini API Key (required)
# AI Provider: gemini (default), claude, grok, ollama
AI_PROVIDER=gemini

# Gemini API Key (required when AI_PROVIDER=gemini)
# Get your API key from: https://aistudio.google.com/apikey
GEMINI_API_KEY=your_api_key_here

# Anthropic API Key (required when AI_PROVIDER=claude)
# ANTHROPIC_API_KEY=your_key_here

# xAI API Key (required when AI_PROVIDER=grok)
# XAI_API_KEY=your_key_here

# Ollama settings (required when AI_PROVIDER=ollama)
# OLLAMA_BASE_URL=http://localhost:11434
# OLLAMA_MODEL=llama3

# Enable Guardrails AI security validation (optional, defaults to true)
# Set to false to disable prompt injection and toxic content checks
ENABLE_SECURITY=true

# Server Port (optional, defaults to 8787)
PORT=8787

Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ dependencies = [
"uvicorn>=0.40.0",
]

[project.optional-dependencies]
claude = ["anthropic>=0.52.0"]
grok = ["openai>=1.30.0"]
ollama = ["openai>=1.30.0"]
all = ["anthropic>=0.52.0", "openai>=1.30.0"]

[dependency-groups]
dev = [
"pytest>=9.0.2",
Expand Down
39 changes: 22 additions & 17 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,41 @@
import os

from fastapi import FastAPI, HTTPException
from google import genai

from .models import RewriteRequest
from .prompts import mode_prompts
from .security import RewriteRequest, validate_input
import os
from .providers import ProviderError, generate
from .security import validate_input

client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
# Server port; optional, defaults to 8787.
PORT = int(os.environ.get("PORT", "8787"))

# Security validation stays enabled unless it is explicitly switched off.
# Only a recognised "off" value disables it, so a typo, stray whitespace or
# an alternative spelling of "true" (1, yes, on) cannot silently turn the
# prompt-injection and toxicity checks off.
ENABLE_SECURITY = os.environ.get("ENABLE_SECURITY", "true").strip().lower() not in {
    "false",
    "0",
    "no",
    "off",
}
app = FastAPI()


@app.post("/rewrite")
def rewrite(req: RewriteRequest):
# POST /rewrite handler: validates the request text, builds the mode-specific
# prompt and returns the generated text as {"result": ...}.
# NOTE(review): this span comes from a rendered diff of src/main.py, so lines
# from before and after the change appear together (the unconditional
# validation just below and the ENABLE_SECURITY-guarded one after it; the
# direct Gemini call and the generate() call further down).
# Validate input for security threats; any validation failure is reported to
# the client as HTTP 400.
try:
validated_text = validate_input(req.text)
except Exception as e:
raise HTTPException(
status_code=400, detail=f"Security validation failed: {str(e)}"
)
# Same validation, but only when ENABLE_SECURITY is true.
if ENABLE_SECURITY:
try:
validated_text = validate_input(req.text)
except Exception as e:
raise HTTPException(
status_code=400, detail=f"Security validation failed: {str(e)}"
)
else:
# Validation switched off: the raw request text goes into the prompt as-is.
validated_text = req.text

# Build prompt with validated input; a mode without an entry in mode_prompts
# falls back to the 'default' prompt.
prompt = f"""
{mode_prompts.get(req.mode, mode_prompts['default'])}

Message:
{validated_text}
"""
# Direct call to Gemini through the module-level client.
response = client.models.generate_content(
model="models/gemini-2.5-flash",
contents=prompt,
)
return {"result": response.text.strip()}
# Provider-agnostic call; a ProviderError is passed on to the client with the
# status code it carries.
try:
result = generate(prompt)
except ProviderError as e:
raise HTTPException(status_code=e.status_code, detail=str(e))

return {"result": result}


@app.get("/health")
Expand Down
29 changes: 29 additions & 0 deletions src/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Request models for AI Rewriter."""

from pydantic import BaseModel, field_validator


class RewriteRequest(BaseModel):
    """Incoming rewrite payload: the text to rewrite plus a rewrite mode.

    Both fields are checked by pydantic field validators.
    """

    text: str
    mode: str = "default"

    @field_validator("text")
    @classmethod
    def validate_text_length(cls, v: str) -> str:
        """Reject empty text and text longer than 5000 characters.

        The upper bound mitigates LLM04 (DoS) by capping the input size.
        """
        size = len(v)
        if size > 5000:
            raise ValueError(f"Text too long ({size} chars). Max: 5000")
        if not v:
            raise ValueError("Text cannot be empty")
        return v

    @field_validator("mode")
    @classmethod
    def validate_mode(cls, v: str) -> str:
        """Accept only the known rewrite modes; raise ValueError otherwise."""
        if v in ("default", "formal", "short", "friendly", "claude-prompt"):
            return v
        raise ValueError(f"Invalid mode: {v}")
130 changes: 130 additions & 0 deletions src/providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
Multi-model provider support for AI Rewriter.

Supports: Gemini (default), Claude, Grok, Ollama.
Provider is selected via AI_PROVIDER environment variable.
"""

import logging
import os

_log = logging.getLogger(__name__)

# Provider names accepted in the AI_PROVIDER environment variable.
VALID_PROVIDERS = {"gemini", "claude", "grok", "ollama"}

# LiteLLM model identifiers ("<litellm-provider>/<model>") for the hosted
# providers. Ollama is not listed because its model name is read from
# OLLAMA_MODEL at call time (see get_litellm_model).
_LITELLM_MAP = {
    "gemini": "gemini/gemini-2.5-flash",
    "claude": "anthropic/claude-sonnet-4-6",
    # LiteLLM routes xAI models through its "xai/" prefix (authenticated with
    # XAI_API_KEY); an "openai/" prefix would target the OpenAI endpoint.
    "grok": "xai/grok-3",
}


class ProviderError(Exception):
    """Failure reported by an AI provider API (auth, billing, rate limits, etc.).

    Attributes:
        status_code: Status code attached to the failure; 502 unless given.
    """

    def __init__(self, message: str, status_code: int = 502):
        self.status_code = status_code
        Exception.__init__(self, message)


def _get_provider() -> str:
    """Return the lower-cased AI_PROVIDER setting, defaulting to "gemini".

    Raises:
        ValueError: If the configured name is not in VALID_PROVIDERS.
    """
    name = os.environ.get("AI_PROVIDER", "gemini").lower()
    if name in VALID_PROVIDERS:
        return name
    options = ", ".join(sorted(VALID_PROVIDERS))
    raise ValueError(f"Unknown AI_PROVIDER: '{name}'. Valid options: {options}")


def get_litellm_model() -> str:
    """Return the LiteLLM model identifier for the configured provider.

    Hosted providers are looked up in _LITELLM_MAP. For Ollama the identifier
    is "ollama_chat/<model>", with the model taken from OLLAMA_MODEL
    (default "llama3").
    """
    selected = _get_provider()
    if selected != "ollama":
        return _LITELLM_MAP[selected]
    return "ollama_chat/" + os.environ.get("OLLAMA_MODEL", "llama3")


def _generate_gemini(prompt: str) -> str:
    """Generate text with Gemini 2.5 Flash through the google-genai client.

    Args:
        prompt: Full prompt text, sent as the request contents.

    Returns:
        The response text with surrounding whitespace stripped.

    Raises:
        ProviderError: If GEMINI_API_KEY is not set (status 500), the API
            reports an error (status taken from the API error), or the
            response carries no text (status 502).
    """
    # Imported inside the function, so the SDK is only needed when this
    # provider is actually selected.
    from google import genai
    from google.genai.errors import APIError

    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        # Report a missing key as a provider failure instead of letting a bare
        # KeyError escape.
        raise ProviderError("GEMINI_API_KEY is not set", status_code=500)

    try:
        client = genai.Client(api_key=api_key)
        response = client.models.generate_content(
            model="models/gemini-2.5-flash",
            contents=prompt,
        )
    except APIError as e:
        raise ProviderError(str(e), status_code=e.code) from e

    text = response.text
    if text is None:
        # The SDK yields None when the response has no text parts; report that
        # instead of failing on None.strip().
        raise ProviderError("Gemini returned no text", status_code=502)
    return text.strip()


def _generate_claude(prompt: str) -> str:
    """Generate text with Claude through the Anthropic Messages API.

    Args:
        prompt: Full prompt text, sent as a single user message.

    Returns:
        The text of the first content block, stripped of surrounding
        whitespace.

    Raises:
        ProviderError: If ANTHROPIC_API_KEY is not set (status 500), the API
            returns an error status (that status), the API cannot be reached
            (status 502), or the response has no content (status 502).
    """
    # Imported inside the function, so the optional dependency is only needed
    # when this provider is actually selected.
    import anthropic

    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        # Report a missing key as a provider failure instead of letting a bare
        # KeyError escape.
        raise ProviderError("ANTHROPIC_API_KEY is not set", status_code=500)

    try:
        client = anthropic.Anthropic(api_key=api_key)
        message = client.messages.create(
            model="claude-sonnet-4-6",
            max_tokens=4096,
            messages=[{"role": "user", "content": prompt}],
        )
    except anthropic.APIStatusError as e:
        raise ProviderError(e.message, status_code=e.status_code) from e
    except anthropic.APIConnectionError as e:
        # Connection failures and timeouts carry no HTTP status; surface them
        # as a bad gateway rather than an unhandled exception.
        raise ProviderError(str(e), status_code=502) from e

    if not message.content:
        raise ProviderError("Claude returned no content", status_code=502)
    return message.content[0].text.strip()


def _generate_grok(prompt: str) -> str:
    """Generate text with Grok through xAI's OpenAI-compatible endpoint.

    Args:
        prompt: Full prompt text, sent as a single user message.

    Returns:
        The first choice's message content, stripped of surrounding
        whitespace.

    Raises:
        ProviderError: If XAI_API_KEY is not set (status 500), the API returns
            an error status (that status), the API cannot be reached
            (status 502), or the response carries no text (status 502).
    """
    # Imported inside the function, so the optional dependency is only needed
    # when this provider is actually selected.
    import openai

    api_key = os.environ.get("XAI_API_KEY")
    if not api_key:
        # Report a missing key as a provider failure instead of letting a bare
        # KeyError escape.
        raise ProviderError("XAI_API_KEY is not set", status_code=500)

    try:
        client = openai.OpenAI(
            api_key=api_key,
            base_url="https://api.x.ai/v1",
        )
        response = client.chat.completions.create(
            model="grok-3",
            messages=[{"role": "user", "content": prompt}],
        )
    except openai.APIStatusError as e:
        raise ProviderError(e.message, status_code=e.status_code) from e
    except openai.APIConnectionError as e:
        # Connection failures and timeouts carry no HTTP status; surface them
        # as a bad gateway rather than an unhandled exception.
        raise ProviderError(str(e), status_code=502) from e

    # message.content is optional in the chat-completions response.
    content = response.choices[0].message.content if response.choices else None
    if content is None:
        raise ProviderError("Grok returned no text", status_code=502)
    return content.strip()


def _generate_ollama(prompt: str) -> str:
    """Generate text with an Ollama server through its OpenAI-compatible API.

    The server address comes from OLLAMA_BASE_URL (default
    "http://localhost:11434") and the model from OLLAMA_MODEL (default
    "llama3").

    Args:
        prompt: Full prompt text, sent as a single user message.

    Returns:
        The first choice's message content, stripped of surrounding
        whitespace.

    Raises:
        ProviderError: If OLLAMA_BASE_URL is not an http(s) URL (status 500),
            the server returns an error status (that status), the server
            cannot be reached (status 502), or the response carries no text
            (status 502).
    """
    # Imported inside the function, so the optional dependency is only needed
    # when this provider is actually selected.
    import openai

    base_url = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
    if not base_url.startswith(("http://", "https://")):
        raise ProviderError(f"Invalid OLLAMA_BASE_URL: '{base_url}'", status_code=500)
    # Tolerate a trailing slash so the endpoint does not become ".../" + "/v1".
    base_url = base_url.rstrip("/")

    model = os.environ.get("OLLAMA_MODEL", "llama3")
    try:
        client = openai.OpenAI(
            api_key="ollama",  # placeholder value; the client requires a key
            base_url=f"{base_url}/v1",
        )
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )
    except openai.APIStatusError as e:
        raise ProviderError(e.message, status_code=e.status_code) from e
    except openai.APIConnectionError as e:
        # A server that is not running shows up as a connection error, which
        # has no HTTP status; surface it as a bad gateway.
        raise ProviderError(str(e), status_code=502) from e

    # message.content is optional in the chat-completions response.
    content = response.choices[0].message.content if response.choices else None
    if content is None:
        raise ProviderError("Ollama returned no text", status_code=502)
    return content.strip()


# Dispatch table: provider name -> function that turns a prompt into text.
_PROVIDERS = {
    "claude": _generate_claude,
    "gemini": _generate_gemini,
    "grok": _generate_grok,
    "ollama": _generate_ollama,
}


def generate(prompt: str) -> str:
    """Generate text for *prompt* with the provider named by AI_PROVIDER.

    The provider name is resolved on every call, logged at INFO level and
    used to pick the matching function from _PROVIDERS.
    """
    name = _get_provider()
    _log.info("Using AI provider: %s", name)
    handler = _PROVIDERS[name]
    return handler(prompt)
94 changes: 38 additions & 56 deletions src/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,67 +9,49 @@
Reference: https://owasp.org/www-project-top-10-for-large-language-model-applications/
"""

from guardrails import Guard
from guardrails.hub import UnusualPrompt, ToxicLanguage
from pydantic import BaseModel, field_validator


class RewriteRequest(BaseModel):
    """Incoming rewrite payload: the text to rewrite plus a rewrite mode.

    Both fields are checked by pydantic field validators.
    """

    text: str
    mode: str = "default"

    @field_validator("text")
    @classmethod
    def validate_text_length(cls, v: str) -> str:
        """Reject empty text and text longer than 5000 characters.

        The upper bound mitigates LLM04 (DoS) by capping the input size.
        """
        size = len(v)
        if size > 5000:
            raise ValueError(f"Text too long ({size} chars). Max: 5000")
        if not v:
            raise ValueError("Text cannot be empty")
        return v

    @field_validator("mode")
    @classmethod
    def validate_mode(cls, v: str) -> str:
        """Accept only the known rewrite modes; raise ValueError otherwise."""
        if v in ("default", "formal", "short", "friendly", "claude-prompt"):
            return v
        raise ValueError(f"Invalid mode: {v}")


def create_input_guard() -> Guard:
    """
    Build the guard used to validate incoming text.

    Validators:
    - UnusualPrompt: Detects jailbreaking and prompt injection (OWASP LLM01)
    - ToxicLanguage: Filters harmful content

    Both validators are configured with on_fail="exception".
    """
    guard = Guard()
    prompt_check = UnusualPrompt(
        llm_callable="gemini/gemini-2.5-flash", on_fail="exception"
    )
    toxicity_check = ToxicLanguage(
        threshold=0.5, validation_method="sentence", on_fail="exception"
    )
    return guard.use_many(prompt_check, toxicity_check)


# Global guard instance
input_guard = create_input_guard()
import logging

_log = logging.getLogger(__name__)

# Lazily built guard shared by all callers, plus a flag recording that
# initialisation has reached a final outcome (guard built, or validators
# not installed).
_input_guard = None
_guard_init_attempted = False


def _get_guard():
    """Return the shared input guard, building it on first use.

    Returns:
        The guard combining UnusualPrompt and ToxicLanguage, or None when the
        guardrails validators cannot be imported (a warning is logged once).

    Raises:
        Exception: Anything other than ImportError raised while building the
            guard, for example ValueError from get_litellm_model() when
            AI_PROVIDER is invalid. Such a failure is not cached: the next
            call retries instead of returning None and skipping validation.
    """
    global _input_guard, _guard_init_attempted
    if _guard_init_attempted:
        return _input_guard

    try:
        from guardrails import Guard
        from guardrails.hub import UnusualPrompt, ToxicLanguage
        from .providers import get_litellm_model

        guard = Guard().use(
            UnusualPrompt(
                llm_callable=get_litellm_model(), on_fail="exception"
            ),
            ToxicLanguage(
                threshold=0.5, validation_method="sentence", on_fail="exception"
            ),
        )
    except ImportError:
        _log.warning(
            "Guardrails hub validators not installed. "
            "Security validation is disabled. "
            "Run 'guardrails hub install hub://guardrails/unusual_prompt "
            "hub://guardrails/toxic_language' to enable."
        )
        guard = None

    # Record the outcome only once it is final. Setting the flag before the
    # attempt would turn any unexpected construction error into a permanent
    # "no guard" state in which every later request bypasses validation.
    _input_guard = guard
    _guard_init_attempted = True
    return _input_guard


def validate_input(text: str) -> str:
"""
Validate input against security threats.

Falls back to passthrough (the text is returned unchanged) if guardrails
hub validators are not installed.

Args:
text: Input text to validate

Returns:
Validated text

Raises:
Exception: If validation fails
"""
# NOTE(review): this span comes from a rendered diff of src/security.py, so
# the next line (module-level `input_guard`) and the lines after it (guard
# obtained from _get_guard()) are the before and after versions shown together.
validated_output = input_guard.validate(text)
guard = _get_guard()
# No guard available: skip validation and hand the text back as-is.
if guard is None:
return text
validated_output = guard.validate(text)
# The guard returns a result object; its validated_output attribute is returned.
return validated_output.validated_output
Loading