Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 140 additions & 11 deletions backend/open_webui/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import inspect
import json
import asyncio
import hashlib
import hmac

from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
Expand Down Expand Up @@ -57,23 +59,138 @@
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def validate_valves_schema(Valves, valves_data: dict) -> bool:
    """Check whether ``valves_data`` is acceptable input for ``Valves``.

    The check depends on what kind of class ``Valves`` is:

    * Pydantic models (subclasses of ``BaseModel``) are validated by a trial
      instantiation with the ``None`` values dropped. Field types are
      therefore checked by Pydantic itself, and unknown keys are handled
      according to the model's own ``extra`` configuration.
    * Any other class is checked by key only: every key of ``valves_data``
      must appear in the class's ``__annotations__``. Value types are not
      inspected on this path.

    Args:
        Valves: The valves class declared by a function module.
        valves_data: Mapping of valve names to stored values.

    Returns:
        ``True`` when the data passes the check above, ``False`` otherwise
        (including when ``Valves`` is not a class or ``valves_data`` is not
        a dict). Failures are logged; this function does not raise for
        invalid input.
    """
    if not inspect.isclass(Valves):
        return False

    # Only a dict can be splatted into ``Valves(**...)``. Reject anything else
    # up front instead of letting ``.items()`` / ``.keys()`` raise further down.
    if not isinstance(valves_data, dict):
        log.error("valves data is not a dict")
        return False

    try:
        is_pydantic_model = issubclass(Valves, BaseModel)
    except Exception as e:
        log.error(f"Error getting Valves fields: {e}")
        return False

    if is_pydantic_model:
        # Let Pydantic do the validation; the instance itself is discarded.
        try:
            Valves(**{k: v for k, v in valves_data.items() if v is not None})
        except Exception as e:
            log.error(f"Pydantic validation failed for valves: {e}")
            return False
        return True

    # Plain class: the annotations are the only declared schema available.
    valid_fields = set(getattr(Valves, "__annotations__", {}))
    invalid_keys = set(valves_data) - valid_fields
    if invalid_keys:
        log.error(f"Invalid valve keys detected: {invalid_keys}")
        return False

    return True


def validate_pipes_output(pipes_output) -> bool:
    """Check that a function module's ``pipes`` value has the expected shape.

    The value is accepted only when all of the following hold:

    * it is a ``list`` with at most 1000 items;
    * every item is a ``dict`` containing ``'id'`` and ``'name'`` keys;
    * both values are ``str``;
    * ``id`` is at most 256 characters and ``name`` at most 512;
    * every character of ``id`` is alphanumeric or one of ``-``, ``_``, ``.``.

    Args:
        pipes_output: The object returned by (or stored in) ``pipes``.

    Returns:
        ``True`` when every rule above is satisfied, ``False`` otherwise.
        Each rejection is logged; nothing is raised.
    """
    max_pipes = 1000
    max_id_length = 256
    max_name_length = 512
    allowed_id_punctuation = "-_."

    if not isinstance(pipes_output, list):
        log.error("pipes output is not a list")
        return False

    # Enforce the item cap before walking the list, so an oversized payload is
    # rejected without doing any per-item work.
    if len(pipes_output) > max_pipes:
        log.error("pipes output exceeds maximum number of items")
        return False

    for pipe in pipes_output:
        if not isinstance(pipe, dict):
            log.error("pipe item is not a dictionary")
            return False

        if "id" not in pipe or "name" not in pipe:
            log.error("pipe item missing required 'id' or 'name' field")
            return False

        pipe_id = pipe["id"]
        pipe_name = pipe["name"]

        if not isinstance(pipe_id, str) or not isinstance(pipe_name, str):
            log.error("pipe 'id' or 'name' is not a string")
            return False

        # Length is checked before the character scan so that the scan, and
        # the id echoed in its log message, are bounded in size.
        if len(pipe_id) > max_id_length or len(pipe_name) > max_name_length:
            log.error("pipe id or name exceeds maximum length")
            return False

        # NOTE: str.isalnum() also accepts non-ASCII letters and digits, and
        # an empty id passes this check.
        if not all(c.isalnum() or c in allowed_id_punctuation for c in pipe_id):
            log.error(f"pipe id contains unsafe characters: {pipe_id}")
            return False

    return True


def get_function_module_by_id(request: Request, pipe_id: str):
    """Load a function module and attach its configured valves.

    Steps, in order:

    1. If the request carries a user, look the function up and check that
       the user has ``read`` access to it.
    2. Fetch the module through ``get_function_module_from_cache``.
    3. If the module exposes both ``valves`` and ``Valves``, build the valves
       instance from the stored values (validated first), or from defaults
       when nothing is stored.

    Args:
        request: The incoming request; ``request.state.user`` is read if set.
        pipe_id: Identifier of the function to load.

    Returns:
        The function module, with ``valves`` populated where applicable.

    Raises:
        HTTPException: 404 when a user is present and the function does not
            exist; 403 when the user lacks read access; 400 when the stored
            valves fail validation or the valves class cannot be instantiated.
    """
    # NOTE(review): when no user is attached to the request the access check
    # below is skipped entirely - confirm that this is intended.
    user = getattr(request.state, "user", None)
    if user:
        function = Functions.get_function_by_id(pipe_id)
        if not function:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Function not found",
            )

        # NOTE(review): assumes the function record exposes `access_control`;
        # verify against the Functions model.
        if not has_access(
            user.id, type="read", access_control=function.access_control
        ):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Insufficient permissions to access this function",
            )

    function_module, _, _ = get_function_module_from_cache(request, pipe_id)

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        Valves = function_module.Valves
        valves = Functions.get_function_valves_by_id(pipe_id)

        if valves:
            # Drop unset values so the class defaults apply to them.
            filtered_valves = {k: v for k, v in valves.items() if v is not None}

            # Validate before instantiating so that a bad stored configuration
            # is reported as a 400 rather than an arbitrary exception.
            if not validate_valves_schema(Valves, filtered_valves):
                log.error(f"Invalid valves schema for function {pipe_id}")
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid valves configuration",
                )

            try:
                function_module.valves = Valves(**filtered_valves)
            except Exception as e:
                log.exception(f"Error loading valves for function {pipe_id}: {e}")
                # Chain the cause so the original traceback is preserved.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Error instantiating valves: {str(e)}",
                ) from e
        else:
            try:
                function_module.valves = Valves()
            except Exception as e:
                log.exception(
                    f"Error creating default valves for function {pipe_id}: {e}"
                )
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Error creating default valves: {str(e)}",
                ) from e

    return function_module

Expand Down Expand Up @@ -103,8 +220,14 @@ async def get_function_models(request):
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes

# Validate pipes output to prevent code injection
if not validate_pipes_output(sub_pipes):
log.error(f"Invalid pipes output for function {pipe.id}, skipping")
sub_pipes = []

except Exception as e:
log.exception(e)
log.exception(f"Error executing pipes for function {pipe.id}: {e}")
sub_pipes = []

log.debug(
Expand Down Expand Up @@ -211,11 +334,17 @@ def get_function_params(function_module, form_data, user, extra_params=None):

if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e:
log.exception(e)

# Validate user valves before instantiation
if user_valves and not validate_valves_schema(function_module.UserValves, user_valves):
log.error(f"Invalid user valves schema for function {pipe_id} and user {user.id}")
params["__user__"]["valves"] = function_module.UserValves()
else:
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()

return params

Expand Down Expand Up @@ -350,4 +479,4 @@ async def stream_content():
return res.model_dump()

message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data["model"], message)