49 changes: 42 additions & 7 deletions lotus/models/lm.py
@@ -7,7 +7,7 @@
 
 import litellm
 import numpy as np
-from litellm import batch_completion, completion_cost
+from litellm import batch_completion, completion_cost, model_cost
 from litellm.exceptions import AuthenticationError
 from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse
 from litellm.utils import token_counter
@@ -60,8 +60,8 @@ def __init__(
         self,
         model: str = "gpt-4o-mini",
         temperature: float = 0.0,
-        max_ctx_len: int = 128000,
-        max_tokens: int = 512,
+        max_ctx_len: int | None = None,
+        max_tokens: int | None = None,
         max_batch_size: int = 64,
         rate_limit: int | None = None,
         tokenizer: Tokenizer | None = None,
@@ -76,8 +76,8 @@
         Args:
             model (str): Name of the model to use. Defaults to "gpt-4o-mini".
             temperature (float): Sampling temperature. Defaults to 0.0.
-            max_ctx_len (int): Maximum context length in tokens. Defaults to 128000.
-            max_tokens (int): Maximum number of tokens to generate. Defaults to 512.
+            max_ctx_len (int | None): Maximum context length in tokens. If None, derived from litellm's model_cost. Defaults to None.
+            max_tokens (int | None): Maximum number of tokens to generate. If None, derived from litellm's model_cost. Defaults to None.
             max_batch_size (int): Maximum batch size for concurrent requests. Defaults to 64.
             rate_limit (int | None): Maximum requests per minute. If set, caps max_batch_size and adds delays.
             tokenizer (Tokenizer | None): Custom tokenizer instance. Defaults to None.
@@ -87,8 +87,12 @@
             **kwargs: Additional keyword arguments passed to the underlying LLM API.
         """
         self.model = model
-        self.max_ctx_len = max_ctx_len
-        self.max_tokens = max_tokens
+
+        # Derive max_ctx_len and max_tokens from model_cost if not provided
+        derived_max_ctx_len, derived_max_tokens = self._get_model_limits(model)
+        self.max_ctx_len = max_ctx_len if max_ctx_len is not None else derived_max_ctx_len
+        self.max_tokens = max_tokens if max_tokens is not None else derived_max_tokens
+
         self.rate_limit = rate_limit
         if rate_limit is not None:
             self._rate_limit_delay: float = 60 / rate_limit
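
A quick usage sketch of how the new constructor defaults would behave (assuming the class defined in lotus/models/lm.py is importable as lotus.models.LM; that import path is not shown in this diff). Explicit values still take precedence over the derived ones:

from lotus.models import LM

# Limits derived from litellm's model_cost entry for "gpt-4o-mini".
lm = LM(model="gpt-4o-mini")

# Explicit values override the derived defaults.
lm_small = LM(model="gpt-4o-mini", max_ctx_len=32000, max_tokens=256)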
@@ -107,6 +111,37 @@ def __init__(
 
         self.cache = cache or CacheFactory.create_default_cache()
 
+    def _get_model_limits(self, model: str) -> tuple[int, int]:
+        """
+        Get max_ctx_len and max_tokens from litellm model_cost mapping.
+
+        Args:
+            model (str): The model name to look up.
+
+        Returns:
+            tuple[int, int]: A tuple of (max_ctx_len, max_tokens) with defaults to fall back on if not provided.
+        """
+        # Default fallback values
+        default_max_ctx_len = 128000
+        default_max_tokens = 512
+
+        # Try to get model info from model_cost
+        model_info = model_cost.get(model)
+        if model_info is None:
+            return default_max_ctx_len, default_max_tokens
+
+        # Get max_input_tokens for context length (max_ctx_len)
+        derived_max_ctx_len = model_info.get("max_input_tokens")
+        if derived_max_ctx_len is None:
+            derived_max_ctx_len = model_info.get("max_tokens", default_max_ctx_len)
+
+        # Get max_output_tokens for generation limit (max_tokens)
+        derived_max_tokens = model_info.get("max_output_tokens")
+        if derived_max_tokens is None:
+            derived_max_tokens = model_info.get("max_tokens", default_max_tokens)
+
+        return derived_max_ctx_len, derived_max_tokens
+
     def __call__(
         self,
         messages: list[list[dict[str, str]]],
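
For readers who want to check which limits a given model would pick up, here is a minimal standalone sketch that mirrors the fallback order used by _get_model_limits, run directly against litellm's model_cost mapping. The helper name preview_limits and the model names are illustrative, not part of the PR; the resulting numbers depend on the installed litellm version.

from litellm import model_cost

def preview_limits(model: str, defaults: tuple[int, int] = (128000, 512)) -> tuple[int, int]:
    # Mirror the PR's fallback order: prefer max_input_tokens / max_output_tokens,
    # fall back to max_tokens, then to the hard-coded defaults.
    info = model_cost.get(model)
    if info is None:
        return defaults
    max_ctx_len = info.get("max_input_tokens")
    if max_ctx_len is None:
        max_ctx_len = info.get("max_tokens", defaults[0])
    max_tokens = info.get("max_output_tokens")
    if max_tokens is None:
        max_tokens = info.get("max_tokens", defaults[1])
    return max_ctx_len, max_tokens

print(preview_limits("gpt-4o-mini"))     # derived from litellm's mapping
print(preview_limits("my-local-model"))  # unknown model -> (128000, 512)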