diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 72f0b292..ba3cbe67 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,8 +42,11 @@ jobs: strategy: fail-fast: false matrix: - test-suite: [lm-openai, lm-ollama, rm, multimodality, utility_operators, cache] + test-suite: [settings, lm-openai, lm-ollama, rm, multimodality, utility_operators, cache] include: + - test-suite: settings + file: tests/test_settings.py + timeout: 5 - test-suite: lm-openai file: .github/tests/lm_tests.py timeout: 10 diff --git a/docs/configurations.rst b/docs/configurations.rst index b1aedda0..b543c280 100644 --- a/docs/configurations.rst +++ b/docs/configurations.rst @@ -59,3 +59,107 @@ Configurable Parameters lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) +Scoped Settings with ``context()`` +------------------------------------ + +``lotus.settings.context(**kwargs)`` is a context manager that temporarily overrides +settings for the duration of a ``with`` block. The previous values are always restored +on exit — even if an exception is raised. + +This is useful for: + +* Switching to a cheaper model for one step in a pipeline without affecting the rest +* Running an evaluation judge with a fresh model and ``enable_cache=False`` +* Isolating settings in tests so one test cannot pollute another +* Running concurrent threads or asyncio tasks with independent settings + +Basic usage +~~~~~~~~~~~ + +.. code-block:: python + + import lotus + from lotus.models import LM + + lm = LM(model="gpt-4o") + cheap_lm = LM(model="gpt-4o-mini") + lotus.settings.configure(lm=lm) + + # Use the cheap model only for this step; gpt-4o is restored afterward + with lotus.settings.context(lm=cheap_lm, enable_cache=False): + df = df.sem_filter("Is {Review} positive?") + + # Back to gpt-4o here + df = df.sem_map("Summarise {Review} in one sentence.") + +Nested contexts +~~~~~~~~~~~~~~~ + +Contexts can be nested. Each level saves and restores independently. + +.. code-block:: python + + with lotus.settings.context(lm=cheap_lm): + # inner context adds another override on top + with lotus.settings.context(enable_cache=True): + df = df.sem_map(...) # cheap_lm + enable_cache=True + df = df.sem_filter(...) # cheap_lm only, enable_cache restored + +Concurrent threads +~~~~~~~~~~~~~~~~~~ + +Because ``context()`` uses ``contextvars.ContextVar`` internally, each thread sees +only its own overrides. Threads cannot overwrite each other's settings even though +they share the same ``lotus.settings`` object. + +.. code-block:: python + + import threading + import lotus + from lotus.models import LM + + lotus.settings.configure(lm=LM(model="gpt-4o-mini")) + + def analyse(df, model, results, key): + with lotus.settings.context(lm=model): + results[key] = df.sem_map("Summarise {Text}.") + + results = {} + t1 = threading.Thread(target=analyse, args=(df1, LM("gpt-4o-mini"), results, "fast")) + t2 = threading.Thread(target=analyse, args=(df2, LM("gpt-4o"), results, "quality")) + t1.start(); t2.start() + t1.join(); t2.join() + +Concurrent asyncio tasks +~~~~~~~~~~~~~~~~~~~~~~~~~ + +``asyncio`` tasks created with ``asyncio.create_task()`` or ``asyncio.gather()`` each +receive a copy of the caller's context, so ``ContextVar`` mutations inside one task +are invisible to others. + +.. code-block:: python + + import asyncio + import lotus + from lotus.models import LM + + lotus.settings.configure(lm=LM(model="gpt-4o-mini")) + + async def run(df, model): + with lotus.settings.context(lm=model): + await asyncio.sleep(0) # yield; other tasks see their own model + return df.sem_map("Classify {Text}.") + + async def main(): + results = await asyncio.gather( + run(df1, LM("gpt-4o-mini")), + run(df2, LM("gpt-4o")), + ) + + asyncio.run(main()) + +.. note:: + + ``configure()`` mutates the global settings object and is **not** thread-safe. + Use ``context()`` whenever settings need to differ across concurrent execution paths. + diff --git a/examples/settings_examples/concurrent_asyncio.py b/examples/settings_examples/concurrent_asyncio.py new file mode 100644 index 00000000..6c762849 --- /dev/null +++ b/examples/settings_examples/concurrent_asyncio.py @@ -0,0 +1,60 @@ +"""Run concurrent asyncio tasks, each with its own model and cache settings. + +asyncio tasks created with asyncio.create_task() or asyncio.gather() each +receive a copy of the current contextvars context, so ContextVar mutations +inside one task are invisible to others. This makes context() safe to use +in async pipelines without any extra locking. +""" + +import asyncio + +import pandas as pd + +import lotus +from lotus.models import LM + +baseline_lm = LM(model="gpt-4o-mini") +lotus.settings.configure(lm=baseline_lm, enable_cache=True) + +fast_lm = LM(model="gpt-4o-mini") +quality_lm = LM(model="gpt-4o") + +tech_articles = [ + "Researchers demonstrate a new battery chemistry that doubles energy density.", + "New compiler optimisation cuts inference latency by 40% on edge devices.", + "Open-source robotics framework gains traction in warehouse automation.", +] + +science_articles = [ + "Study links gut microbiome diversity to improved cognitive function.", + "James Webb telescope captures earliest known galaxy formation.", + "CRISPR variant achieves record efficiency in correcting sickle-cell mutations.", +] + + +async def summarise(articles: list[str], lm: LM, label: str) -> pd.DataFrame: + """Summarise a batch of articles using the provided model.""" + df = pd.DataFrame({"Article": articles}) + with lotus.settings.context(lm=lm, enable_cache=False): + # Simulate async I/O between LM calls + await asyncio.sleep(0) + df = df.sem_map("Summarise {Article} in one sentence.") + print(f"\n{label} summaries (model: {lm.model}):") + print(df["Article"].to_string(index=False)) + return df + + +async def main() -> None: + # Both tasks run concurrently; each sees only its own lm override + tech_task = asyncio.create_task(summarise(tech_articles, fast_lm, "Tech")) + science_task = asyncio.create_task(summarise(science_articles, quality_lm, "Science")) + + tech_df, science_df = await asyncio.gather(tech_task, science_task) + + # Global settings are untouched by either task + assert lotus.settings.lm is baseline_lm + assert lotus.settings.enable_cache is True + print("\nGlobal settings unchanged after tasks completed.") + + +asyncio.run(main()) diff --git a/examples/settings_examples/concurrent_threads.py b/examples/settings_examples/concurrent_threads.py new file mode 100644 index 00000000..8158809a --- /dev/null +++ b/examples/settings_examples/concurrent_threads.py @@ -0,0 +1,66 @@ +"""Run parallel analyses on different data segments, each with its own model. + +Because context() uses contextvars.ContextVar, each thread sees only its own +settings overlay. The threads cannot overwrite each other's lm or enable_cache +even though they all share the same global lotus.settings object. +""" + +import threading + +import pandas as pd + +import lotus +from lotus.models import LM + +# Global baseline — used by any code that runs outside a context +baseline_lm = LM(model="gpt-4o-mini") +lotus.settings.configure(lm=baseline_lm) + +# Two specialised models for different analysis tasks +sentiment_lm = LM(model="gpt-4o-mini") +topic_lm = LM(model="gpt-4o-mini") + +reviews = [ + "The battery life on this laptop is incredible — lasts all day easily.", + "Screen is gorgeous but the keyboard feels cheap.", + "Runs hot under load and the fan is very loud.", + "Best value for money I've found in this price range.", + "Customer support was unhelpful when I had a setup issue.", + "Surprisingly lightweight for a 15-inch machine.", +] + +results: dict[str, pd.DataFrame] = {} + + +def run_sentiment(data: list[str]) -> None: + df = pd.DataFrame({"Review": data}) + with lotus.settings.context(lm=sentiment_lm): + results["sentiment"] = df.sem_map( + "Classify the sentiment of {Review} as Positive, Negative, or Neutral." + ) + + +def run_topic(data: list[str]) -> None: + df = pd.DataFrame({"Review": data}) + with lotus.settings.context(lm=topic_lm): + results["topic"] = df.sem_map( + "Identify the main topic of {Review} in two words or fewer." + ) + + +t1 = threading.Thread(target=run_sentiment, args=(reviews,)) +t2 = threading.Thread(target=run_topic, args=(reviews,)) + +t1.start() +t2.start() +t1.join() +t2.join() + +print("Sentiment analysis:") +print(results["sentiment"].to_string(index=False)) +print("\nTopic analysis:") +print(results["topic"].to_string(index=False)) + +# Global settings are untouched by either thread +assert lotus.settings.lm is baseline_lm +print("\nGlobal lm unchanged after threads exited.") diff --git a/examples/settings_examples/eval_cache_isolation.py b/examples/settings_examples/eval_cache_isolation.py new file mode 100644 index 00000000..3d0098ed --- /dev/null +++ b/examples/settings_examples/eval_cache_isolation.py @@ -0,0 +1,53 @@ +"""Run an evaluation with a different model and cache setting than the main pipeline. + +Without context(), toggling enable_cache or swapping lm for an eval step +would leak into subsequent pipeline operations. The context manager ensures +the eval runs in isolation and the original settings are always restored. +""" + +import pandas as pd + +import lotus +from lotus.models import LM +from lotus.evals import llm_as_judge + +# Production model with caching enabled +prod_lm = LM(model="gpt-4o-mini") +# Dedicated eval judge — should not be cached so results are always fresh +judge_lm = LM(model="gpt-4o") + +lotus.settings.configure(lm=prod_lm, enable_cache=True) + +data = { + "question": [ + "What is the capital of France?", + "Who wrote Romeo and Juliet?", + "What is the speed of light?", + ], + "answer": [ + "Paris is the capital of France.", + "Romeo and Juliet was written by William Shakespeare.", + "The speed of light is approximately 300,000 km/s.", + ], +} +df = pd.DataFrame(data) + +# Step 1: Generate additional context with the cached production model +df = df.sem_map("Expand {answer} with one additional relevant fact.") +print("Expanded answers:") +print(df[["question", "answer"]].to_string(index=False)) + +# Step 2: Evaluate quality using the judge model, with caching disabled +# so every eval call goes to the model rather than returning a stale result. +with lotus.settings.context(lm=judge_lm, enable_cache=False): + scores = df.llm_as_judge( + judge_instruction="Rate the accuracy of this {answer} to the {question} from 1-10. Output only the number.", + n_trials=1, + ) + print("\nEval scores (judge: gpt-4o, cache disabled):") + print(scores) + +# Verify settings are restored +assert lotus.settings.lm is prod_lm +assert lotus.settings.enable_cache is True +print("\nSettings restored: lm=gpt-4o-mini, enable_cache=True") diff --git a/examples/settings_examples/scoped_model_switching.py b/examples/settings_examples/scoped_model_switching.py new file mode 100644 index 00000000..13d98ca3 --- /dev/null +++ b/examples/settings_examples/scoped_model_switching.py @@ -0,0 +1,39 @@ +"""Temporarily switch to a different model for one step in a pipeline. + +The global lm is restored automatically after the context exits, so later +steps continue using the original model without any manual save/restore. +""" + +import pandas as pd + +import lotus +from lotus.models import LM + +# Global model used for most pipeline steps +lm = LM(model="gpt-4o") +# Cheaper/faster model for a high-volume intermediate step +cheap_lm = LM(model="gpt-4o-mini") + +lotus.settings.configure(lm=lm) + +data = { + "Paper Title": [ + "Attention Is All You Need", + "BERT: Pre-training of Deep Bidirectional Transformers", + "Deep Residual Learning for Image Recognition", + "Generative Adversarial Networks", + "Neural Machine Translation by Jointly Learning to Align and Translate", + ] +} +df = pd.DataFrame(data) + +# Step 1: Use the cheap model for a coarse filter (high volume, low stakes) +with lotus.settings.context(lm=cheap_lm): + df = df.sem_filter("Is {Paper Title} related to natural language processing?") + print(f"After NLP filter ({len(df)} papers remaining):") + print(df["Paper Title"].tolist()) + +# Step 2: Back to the global (high-quality) model for the final summarization +df = df.sem_map("Write a one-sentence summary of the contributions of {Paper Title}.") +print("\nSummaries (generated with gpt-4o):") +print(df) diff --git a/lotus/settings.py b/lotus/settings.py index 49fad071..4c4b0aa7 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -1,8 +1,15 @@ +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, Generator + import lotus.models import lotus.vector_store from lotus.types import SerializationFormat -# NOTE: Settings class is not thread-safe +# context() is safe for concurrent use across threads and asyncio tasks. +# Direct mutation via configure() or attribute assignment is not thread-safe. + +_settings_context: ContextVar[dict[str, Any] | None] = ContextVar("_settings_context", default=None) class Settings: @@ -22,13 +29,47 @@ class Settings: # Parallel groupby settings parallel_groupby_max_threads: int = 8 - def configure(self, **kwargs): + def __getattribute__(self, name: str) -> Any: + # For known settings fields, check the per-context overlay first. + annotations = object.__getattribute__(self, "__class__").__annotations__ + if name in annotations: + ctx = _settings_context.get() + if ctx is not None and name in ctx: + return ctx[name] + return object.__getattribute__(self, name) + + def configure(self, **kwargs: Any) -> None: for key, value in kwargs.items(): if not hasattr(self, key): raise ValueError(f"Invalid setting: {key}") setattr(self, key, value) - def __str__(self): + @contextmanager + def context(self, **kwargs: Any) -> Generator["Settings", None, None]: + """Temporarily override settings in the current thread or asyncio task. + + Each thread and asyncio task sees only its own overrides — concurrent + callers cannot interfere with each other. Supports nesting and + guarantees restoration even if an exception is raised. + + Example:: + + with lotus.settings.context(enable_cache=False, lm=eval_lm): + result = df.sem_filter("...") + """ + # Validate all keys before making any changes. + for key in kwargs: + if not hasattr(self, key): + raise ValueError(f"Invalid setting: {key}") + + current = _settings_context.get() or {} + token = _settings_context.set({**current, **kwargs}) + try: + yield self + finally: + _settings_context.reset(token) + + def __str__(self) -> str: return str(vars(self)) diff --git a/tests/test_settings.py b/tests/test_settings.py index 4f251bbd..54e780ae 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,3 +1,6 @@ +import asyncio +import threading + import pytest from lotus.settings import SerializationFormat, Settings @@ -13,13 +16,136 @@ def test_initial_values(self, settings): assert settings.rm is None assert settings.helper_lm is None assert settings.reranker is None - assert settings.enable_message_cache is False + assert settings.enable_cache is False assert settings.serialization_format == SerializationFormat.DEFAULT def test_configure_method(self, settings): - settings.configure(enable_message_cache=True) - assert settings.enable_message_cache is True + settings.configure(enable_cache=True) + assert settings.enable_cache is True def test_invalid_setting(self, settings): with pytest.raises(ValueError, match="Invalid setting: invalid_setting"): settings.configure(invalid_setting=True) + + +class TestSettingsContext: + @pytest.fixture + def settings(self): + return Settings() + + def test_context_restores_on_exit(self, settings): + settings.configure(enable_cache=True) + with settings.context(enable_cache=False): + assert settings.enable_cache is False + assert settings.enable_cache is True + + def test_context_restores_class_default_on_exit(self, settings): + # enable_cache starts at class default (False, not set as instance attr) + assert "enable_cache" not in vars(settings) + with settings.context(enable_cache=True): + assert settings.enable_cache is True + # ContextVar is reset; class default takes over without touching instance dict + assert "enable_cache" not in vars(settings) + assert settings.enable_cache is False + + def test_context_restores_on_exception(self, settings): + settings.configure(enable_cache=True) + with pytest.raises(RuntimeError): + with settings.context(enable_cache=False): + assert settings.enable_cache is False + raise RuntimeError("boom") + assert settings.enable_cache is True + + def test_context_yields_settings(self, settings): + with settings.context(enable_cache=True) as s: + assert s is settings + assert s.enable_cache is True + + def test_context_multiple_overrides(self, settings): + settings.configure(enable_cache=True, parallel_groupby_max_threads=4) + with settings.context(enable_cache=False, parallel_groupby_max_threads=16): + assert settings.enable_cache is False + assert settings.parallel_groupby_max_threads == 16 + assert settings.enable_cache is True + assert settings.parallel_groupby_max_threads == 4 + + def test_nested_contexts(self, settings): + settings.configure(enable_cache=False) + with settings.context(enable_cache=True): + assert settings.enable_cache is True + with settings.context(enable_cache=False): + assert settings.enable_cache is False + assert settings.enable_cache is True + assert settings.enable_cache is False + + def test_context_invalid_setting_raises(self, settings): + settings.configure(enable_cache=True) + with pytest.raises(ValueError, match="Invalid setting: bad_key"): + with settings.context(bad_key=True): + pass # pragma: no cover + # Settings must be unchanged after the failed context entry + assert settings.enable_cache is True + + def test_context_serialization_format(self, settings): + settings.configure(serialization_format=SerializationFormat.JSON) + with settings.context(serialization_format=SerializationFormat.XML): + assert settings.serialization_format == SerializationFormat.XML + assert settings.serialization_format == SerializationFormat.JSON + + +class TestSettingsContextConcurrency: + @pytest.fixture + def settings(self): + return Settings() + + def test_thread_isolation(self, settings): + """Two threads entering context() simultaneously see only their own overrides.""" + results: dict[int, bool] = {} + barrier = threading.Barrier(2) + + def run(thread_id: int, value: bool) -> None: + with settings.context(enable_cache=value): + barrier.wait() # both threads inside context at the same time + results[thread_id] = settings.enable_cache + + t1 = threading.Thread(target=run, args=(1, True)) + t2 = threading.Thread(target=run, args=(2, False)) + t1.start() + t2.start() + t1.join() + t2.join() + + assert results[1] is True + assert results[2] is False + + def test_thread_baseline_unaffected(self, settings): + """Global baseline is unchanged after threads exit their contexts.""" + settings.configure(enable_cache=False) + barrier = threading.Barrier(2) + + def run(value: bool) -> None: + with settings.context(enable_cache=value): + barrier.wait() + + threads = [threading.Thread(target=run, args=(v,)) for v in (True, False)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert settings.enable_cache is False + + def test_asyncio_task_isolation(self, settings): + """Two asyncio tasks entering context() see only their own overrides.""" + + async def run(value: bool) -> bool: + with settings.context(enable_cache=value): + await asyncio.sleep(0) # yield so both tasks overlap + return settings.enable_cache + + async def main() -> tuple[bool, bool]: + return await asyncio.gather(run(True), run(False)) + + r_true, r_false = asyncio.run(main()) + assert r_true is True + assert r_false is False