Skip to content
Open
39 changes: 39 additions & 0 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lotus
from lotus.models import LM, SentenceTransformersRM
from lotus.types import CascadeArgs
from lotus.sem_ops.ensembling import Ensemble, EnsembleConfig, EnsembleStrategy
from lotus.vector_store import FaissVS

################################################################################
Expand Down Expand Up @@ -499,6 +500,44 @@ def test_filter_cascade(setup_models):
assert stats["filters_resolved_by_helper_model"] > 0, stats


@pytest.mark.skipif(not ENABLE_OPENAI_TESTS, reason="Skipping test because OpenAI tests are not enabled")
def test_filter_ensembling(setup_models):
    """sem_filter with n_sample=2 and a majority-vote ensemble exposes per-run columns.

    Verifies that the raw output / parsed output / explanation columns for each
    sample are present on the returned frame, and that the aggregated
    ``filter_label`` matches the expected sentiment of each row.
    """
    models = setup_models
    lotus.settings.configure(lm=models["gpt-4o-mini"])

    data = {
        "Text": [
            "I am really excited to go to class today!",
            "I am very sad",
        ]
    }
    df = pd.DataFrame(data)
    user_instruction = "{Text} is a positive sentiment"

    # Test with n_sample=2 and majority vote
    filtered_df = df.sem_filter(
        user_instruction,
        n_sample=2,
        ensemble=Ensemble(EnsembleConfig(strategy=EnsembleStrategy.MAJORITY_VOTE)),
        return_raw_outputs=True,
        return_explanations=True,
        return_all=True,
    )

    # Check for new columns
    expected_cols = [
        "raw_output_1", "parsed_output_1", "explanation_1",
        "raw_output_2", "parsed_output_2", "explanation_2",
        "filter_label",  # Ensemble result
    ]
    for col in expected_cols:
        assert col in filtered_df.columns, f"Column {col} missing from dataframe"

    # Check ensemble logic (both samples should be True for first row).
    # Use truthiness instead of `== True` / `== False` (PEP 8 E712); coerce with
    # bool() since pandas may hand back numpy.bool_ rather than the bool singletons.
    assert bool(filtered_df.iloc[0]["filter_label"])
    assert not bool(filtered_df.iloc[1]["filter_label"])


@pytest.mark.skipif(not ENABLE_OPENAI_TESTS, reason="Skipping test because OpenAI tests are not enabled")
def test_join_cascade(setup_models):
models = setup_models
Expand Down
26 changes: 25 additions & 1 deletion .github/tests/multimodality_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

import lotus
from lotus.dtype_extensions import ImageArray
from lotus.dtype_extensions import ImageArray, AudioArray
from lotus.models import LM, SentenceTransformersRM
from lotus.vector_store import FaissVS

Expand All @@ -20,13 +20,15 @@

MODEL_NAME_TO_ENABLED = {
"gpt-4o-mini": ENABLE_OPENAI_TESTS,
"gpt-4o-audio-preview": ENABLE_OPENAI_TESTS,
"clip-ViT-B-32": ENABLE_LOCAL_TESTS,
}
ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled])

MODEL_NAME_TO_CLS = {
"clip-ViT-B-32": SentenceTransformersRM,
"gpt-4o-mini": LM,
"gpt-4o-audio-preview": LM,
}


Expand Down Expand Up @@ -228,3 +230,25 @@ def test_sim_join_operation_text_index(setup_models, model):
("https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", "bird"),
]
assert expected_result == list(zip(joined_df["image"], joined_df["element"]))


@pytest.mark.parametrize("model", get_enabled("gpt-4o-audio-preview"))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to enable gpt-4o-audio-preview

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I enabled gpt-4o-audio-preview in multimodality_tests.py and updated the test to use a valid WAV file input. The test test_filter_operation_audio now passes locally!

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@harshitgupta412, could you please review the changes?

def test_filter_operation_audio(setup_models, model):
    """Smoke-test sem_filter on an audio column: it should run and yield a DataFrame."""
    import base64

    lm = setup_models[model]
    lotus.settings.configure(lm=lm)

    # Read a real wav file so the payload is a genuinely valid audio format.
    with open("test_audio.wav", "rb") as wav_file:
        encoded = base64.b64encode(wav_file.read()).decode("utf-8")
    data_uri = "data:audio/wav;base64," + encoded

    df = pd.DataFrame({"audio": AudioArray([data_uri, data_uri])})
    user_instruction = "{audio} contains audio"

    # Just verify it runs without error and returns a dataframe
    filtered_df = df.sem_filter(user_instruction)
    assert isinstance(filtered_df, pd.DataFrame)


98 changes: 98 additions & 0 deletions examples/ensembling_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Example: Using Test-Time Scaling (Ensembling) with sem_filter

This example demonstrates how to use the new ensembling feature in sem_filter
to improve prediction accuracy by aggregating multiple LLM samples.
"""

import pandas as pd

import lotus
from lotus.models import LM
from lotus.sem_ops.ensembling import Ensemble, EnsembleConfig, EnsembleStrategy

# Configure the language model
lm = LM(model="gpt-4o-mini")
lotus.settings.configure(lm=lm)

# Create a sample DataFrame with movie reviews
df = pd.DataFrame({
"review": [
"This movie was absolutely fantastic! Best film I've seen all year.",
"Terrible waste of time. The plot made no sense whatsoever.",
"It was okay, had some good moments but also some boring parts.",
"A masterpiece of modern cinema. Highly recommend!",
"I fell asleep halfway through. Very disappointing.",
]
})

# Example 1: Basic ensembling with default MAJORITY_VOTE strategy
print("Example 1: Basic Ensembling (Majority Vote)")
print("-" * 50)

result = df.sem_filter(
"The {review} expresses a positive sentiment",
n_sample=3, # Run 3 samples and aggregate
)

print(f"Filtered to {len(result)} positive reviews")
print(result)

# Example 2: Using a custom ensemble configuration
print("\nExample 2: Custom Ensemble Configuration (Weighted Average)")
print("-" * 50)

# Create a custom ensemble with weighted average strategy
config = EnsembleConfig(
strategy=EnsembleStrategy.WEIGHTED_AVERAGE,
weights=[0.5, 0.3, 0.2], # Weight earlier samples more heavily
)
ensemble = Ensemble(config)

result = df.sem_filter(
"The {review} mentions specific plot details",
n_sample=3,
ensemble=ensemble,
)

print(f"Filtered to {len(result)} reviews with plot details")
print(result)

# Example 3: Accessing per-run data
print("\nExample 3: Accessing Per-Run Data")
print("-" * 50)

# Use return_all=True to get full output object with per-run details
result_with_details, stats = df.sem_filter(
"The {review} is written in a sarcastic tone",
n_sample=5,
return_stats=True,
return_all=True, # Return all rows, not just filtered ones
)

# The output contains predictions from all runs
# Access via the _raw_outputs attribute
print("Total samples run: 5")
print(f"Stats: {stats}")
print(result_with_details)

# Example 4: Consensus strategy (only returns True if all samples agree)
print("\nExample 4: Consensus Strategy")
print("-" * 50)

config = EnsembleConfig(
strategy=EnsembleStrategy.CONSENSUS,
default=False, # Default to False if no consensus
)
ensemble = Ensemble(config)

result = df.sem_filter(
"The {review} contains extremely strong language",
n_sample=3,
ensemble=ensemble,
)

print(f"Filtered to {len(result)} reviews (required unanimous agreement)")
print(result)

print("\nDone!")
8 changes: 7 additions & 1 deletion lotus/dtype_extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
import pandas as pd

from lotus.dtype_extensions.audio import AudioArray, AudioDtype
from lotus.dtype_extensions.image import ImageArray, ImageDtype

# Register the custom dtypes with pandas so Series/DataFrames can carry them.
pd.api.extensions.register_extension_dtype(ImageDtype)
pd.api.extensions.register_extension_dtype(AudioDtype)


def convert_to_base_data(data: pd.Series | list) -> list:
    """
    Convert ``data`` to its proper base data type.

    - For ordinary pandas dtypes, this returns ``data.tolist()``.
    - For ImageDtype, this returns a list of ``PIL.Image.Image``.
    - For AudioDtype, this returns a list of audio data.
    - A plain list is returned unchanged.
    """
    if isinstance(data, pd.Series):
        if isinstance(data.dtype, ImageDtype):
            return [data.array.get_image(i) for i in range(len(data))]
        if isinstance(data.dtype, AudioDtype):
            return [data.array.get_audio(i) for i in range(len(data))]
        return data.tolist()

    return data


__all__ = ["ImageDtype", "ImageArray", "AudioDtype", "AudioArray", "convert_to_base_data"]

Loading