From d75fe1e7513adc10adbbc5781a33c02d0fb69dfa Mon Sep 17 00:00:00 2001
From: Shreya Shankar
Date: Thu, 21 Aug 2025 13:48:15 -0700
Subject: [PATCH] feat: read from litellm defaults for context window limits

---
 lotus/models/lm.py | 49 ++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 42 insertions(+), 7 deletions(-)

diff --git a/lotus/models/lm.py b/lotus/models/lm.py
index ad0b9e4d..ad8abefa 100644
--- a/lotus/models/lm.py
+++ b/lotus/models/lm.py
@@ -7,7 +7,7 @@
 
 import litellm
 import numpy as np
-from litellm import batch_completion, completion_cost
+from litellm import batch_completion, completion_cost, model_cost
 from litellm.exceptions import AuthenticationError
 from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse
 from litellm.utils import token_counter
@@ -60,8 +60,8 @@ def __init__(
         self,
         model: str = "gpt-4o-mini",
         temperature: float = 0.0,
-        max_ctx_len: int = 128000,
-        max_tokens: int = 512,
+        max_ctx_len: int | None = None,
+        max_tokens: int | None = None,
         max_batch_size: int = 64,
         rate_limit: int | None = None,
         tokenizer: Tokenizer | None = None,
@@ -76,8 +76,8 @@
         Args:
             model (str): Name of the model to use. Defaults to "gpt-4o-mini".
             temperature (float): Sampling temperature. Defaults to 0.0.
-            max_ctx_len (int): Maximum context length in tokens. Defaults to 128000.
-            max_tokens (int): Maximum number of tokens to generate. Defaults to 512.
+            max_ctx_len (int | None): Maximum context length in tokens. If None, derived from litellm's model_cost. Defaults to None.
+            max_tokens (int | None): Maximum number of tokens to generate. If None, derived from litellm's model_cost. Defaults to None.
             max_batch_size (int): Maximum batch size for concurrent requests. Defaults to 64.
             rate_limit (int | None): Maximum requests per minute. If set, caps max_batch_size and adds delays.
             tokenizer (Tokenizer | None): Custom tokenizer instance. Defaults to None.
@@ -87,8 +87,12 @@
             **kwargs: Additional keyword arguments passed to the underlying LLM API.
         """
         self.model = model
-        self.max_ctx_len = max_ctx_len
-        self.max_tokens = max_tokens
+
+        # Derive max_ctx_len and max_tokens from model_cost if not provided
+        derived_max_ctx_len, derived_max_tokens = self._get_model_limits(model)
+        self.max_ctx_len = max_ctx_len if max_ctx_len is not None else derived_max_ctx_len
+        self.max_tokens = max_tokens if max_tokens is not None else derived_max_tokens
+
         self.rate_limit = rate_limit
         if rate_limit is not None:
             self._rate_limit_delay: float = 60 / rate_limit
@@ -107,6 +111,37 @@
 
         self.cache = cache or CacheFactory.create_default_cache()
 
+    def _get_model_limits(self, model: str) -> tuple[int, int]:
+        """
+        Get max_ctx_len and max_tokens from litellm's model_cost mapping.
+
+        Args:
+            model (str): The model name to look up.
+
+        Returns:
+            tuple[int, int]: A tuple of (max_ctx_len, max_tokens), falling back to defaults if the model is not in the mapping.
+        """
+        # Default fallback values
+        default_max_ctx_len = 128000
+        default_max_tokens = 512
+
+        # Try to get model info from model_cost
+        model_info = model_cost.get(model)
+        if model_info is None:
+            return default_max_ctx_len, default_max_tokens
+
+        # Get max_input_tokens for context length (max_ctx_len)
+        derived_max_ctx_len = model_info.get("max_input_tokens")
+        if derived_max_ctx_len is None:
+            derived_max_ctx_len = model_info.get("max_tokens", default_max_ctx_len)
+
+        # Get max_output_tokens for generation limit (max_tokens)
+        derived_max_tokens = model_info.get("max_output_tokens")
+        if derived_max_tokens is None:
+            derived_max_tokens = model_info.get("max_tokens", default_max_tokens)
+
+        return derived_max_ctx_len, derived_max_tokens
+
     def __call__(
         self,
         messages: list[list[dict[str, str]]],