Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 257 additions & 0 deletions .github/tests/test_sem_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
import os

import pandas as pd
import pytest

import lotus
from lotus.models import LM
from lotus.types import EnsembleStrategy, ReasoningStrategy

# Ollama-backed tests are opt-in: unless the ENABLE_OLLAMA_TESTS environment
# variable is set to "true" (case-insensitive), the module-level skip marker
# below skips every test in this file.
ENABLE_OLLAMA_TESTS = os.environ.get("ENABLE_OLLAMA_TESTS", "false").lower() == "true"

pytestmark = pytest.mark.skipif(
    not ENABLE_OLLAMA_TESTS,
    reason="Set ENABLE_OLLAMA_TESTS=true to run Ollama-backed tests",
)

# Model identifier used both to build the LM fixture and to parametrize tests.
MODEL_NAME = "ollama/llama3.1"


@pytest.fixture(scope="session")
def setup_models():
    """Return a ``{model name: LM}`` mapping, built once per test session."""
    models = {}
    models[MODEL_NAME] = LM(model=MODEL_NAME)
    return models


@pytest.fixture(autouse=True)
def print_usage_after_each_test(setup_models):
    """Print and then reset each model's usage after every test.

    The fixture is ``autouse``, so it applies to every test in this module
    without being requested. The test body runs at the ``yield``; the code
    after it is the teardown.

    Args:
        setup_models: mapping of model name to model object, as returned by
            the ``setup_models`` fixture.
    """
    yield
    # Only the model objects are needed here, so iterate the values directly
    # instead of unpacking (and discarding) the keys from ``.items()``.
    for lm in setup_models.values():
        lm.print_total_usage()
        lm.reset_stats()
        lm.reset_cache()


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_basic(setup_models, model):
    """Smoke-test ``sem_filter`` on a two-row positive/negative sentiment frame."""
    lotus.settings.configure(lm=setup_models[model])

    positive_text = "I am really excited to go to class today!"
    negative_text = "I am very sad"
    df = pd.DataFrame({"Text": [positive_text, negative_text]})

    filtered = df.sem_filter("{Text} is a positive sentiment")

    assert isinstance(filtered, pd.DataFrame)
    assert "Text" in filtered.columns
    # At least one row is expected to pass the filter.
    assert len(filtered) >= 1
    # The clearly negative sentence must not be among the surviving rows.
    surviving_texts = filtered["Text"].tolist()
    assert negative_text not in surviving_texts


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_with_sampling_pick_first(setup_models, model):
    """Run ``sem_filter`` with ``n_sample=2`` and the PICK_FIRST ensemble.

    Checks that the call returns a DataFrame that still carries the input
    columns, contains only rows taken from the input, and keeps at least
    one row.
    """
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["Today is fantastic!", "This is terrible.", "Pretty good overall."]})
    user_instruction = "{Text} is a positive sentiment"

    # Exercise resampling path (n_sample>1) but keep ensemble as PICK_FIRST.
    filtered = df.sem_filter(
        user_instruction,
        n_sample=2,
        temperature=0.9,
        # explicit for clarity (default is PICK_FIRST)
        ensemble=EnsembleStrategy.PICK_FIRST,
    )

    assert isinstance(filtered, pd.DataFrame)
    assert "Text" in filtered.columns
    # Reviewer note carried over from the PR discussion: "we should add
    # assertions checking the output columns are as expected".
    # Every input column must still be present in the output.
    assert set(df.columns) <= set(filtered.columns)
    # Filtering may only drop rows; it must never introduce new text.
    assert set(filtered["Text"]) <= set(df["Text"])
    # NOTE(review): the n_sample>1 path presumably also adds per-run columns
    # (e.g. "raw_output_1_filter") - confirm for PICK_FIRST before asserting them.
    assert len(filtered) >= 1


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_with_majority_ensemble_runs(setup_models, model):
    """Run ``sem_filter`` with three samples and the MAJORITY ensemble.

    Model outputs can vary between runs, so the test does not pin which rows
    survive. It checks only structural properties that hold for any vote
    outcome: the result is a DataFrame with the ``Text`` column, it has no
    more rows than the input, and every surviving row comes from the input.
    """
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["Great job!", "Awful service.", "Fine, I guess."]})
    user_instruction = "{Text} is a positive sentiment"

    # Exercise the MAJORITY voting path (boolean labels after postprocess).
    # We keep assertions tolerant because model outputs can vary.
    filtered = df.sem_filter(
        user_instruction,
        n_sample=3,
        temperature=0.9,
        ensemble=EnsembleStrategy.MAJORITY,
    )

    assert isinstance(filtered, pd.DataFrame)
    assert "Text" in filtered.columns
    # The previous ``len(filtered) >= 0`` check could never fail. The size may
    # vary with the votes, but a filter can never return more rows than it
    # was given, nor rows that were not in the input.
    assert len(filtered) <= len(df)
    assert set(filtered["Text"]) <= set(df["Text"])


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_invalid_n_sample_raises(setup_models, model):
    """``sem_filter`` called with ``n_sample=0`` must raise ``ValueError``."""
    lotus.settings.configure(lm=setup_models[model])

    single_row = pd.DataFrame({"Text": ["Neutral thing"]})
    invalid_kwargs = {"n_sample": 0}

    with pytest.raises(ValueError):
        single_row.sem_filter("{Text} is a positive sentiment", **invalid_kwargs)


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_return_all_and_explanations(setup_models, model):
    """Check the extra columns produced by ``return_all`` and ``return_explanations``."""
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["I love sunshine", "I hate rain"]})

    # Request every row plus explanations, using the ZS_COT reasoning strategy.
    filter_options = {
        "return_all": True,
        "return_explanations": True,
        "strategy": ReasoningStrategy.ZS_COT,
    }
    full_df = df.sem_filter("{Text} is a positive sentiment", **filter_options)

    assert isinstance(full_df, pd.DataFrame)
    output_columns = full_df.columns
    # return_all=True is expected to add an un-suffixed 'filter_label' column.
    assert "filter_label" in output_columns
    # The explanation column is expected to carry the "_filter" suffix.
    assert "explanation_filter" in output_columns


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_return_stats_tuple(setup_models, model):
    """With ``return_stats=True`` the result must be a ``(DataFrame, dict)`` pair."""
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["I love this", "I dislike that"]})

    result = df.sem_filter("{Text} is a positive sentiment", return_stats=True)

    # Shape of the return value: a 2-tuple.
    assert isinstance(result, tuple)
    assert len(result) == 2

    # Element types: the filtered frame first, the stats mapping second.
    out_df = result[0]
    stats = result[1]
    assert isinstance(out_df, pd.DataFrame)
    assert isinstance(stats, dict)


# NEW TESTS FOR MULTI-RUN ROLLOUT FUNCTIONALITY


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_with_rollout_columns(setup_models, model):
    """Check the per-run columns produced by a 3-sample MAJORITY filter.

    With ``return_all=True`` and ``return_explanations=True`` the output is
    expected to contain, for each of the three runs, a raw output, a parsed
    output and an explanation column, plus one ensemble answer column.
    """
    lotus.settings.configure(lm=setup_models[model])

    n_runs = 3
    texts = ["Great product!", "Terrible service.", "It's okay."]
    df = pd.DataFrame({"Text": texts})

    # return_all=True keeps every input row so all rollout data is visible.
    full_df = df.sem_filter(
        "{Text} is a positive sentiment",
        n_sample=n_runs,
        ensemble=EnsembleStrategy.MAJORITY,
        return_all=True,
        return_explanations=True,
        temperature=0.9,
    )

    assert isinstance(full_df, pd.DataFrame)

    columns = full_df.columns
    run_ids = range(1, n_runs + 1)

    # One column per run for each kind of per-run output.
    for prefix in ("raw_output", "parsed_output", "explanation"):
        for run in run_ids:
            assert f"{prefix}_{run}_filter" in columns

    # The combined answer across runs.
    assert "ensemble_answer_filter" in columns

    # No row may be dropped when return_all=True.
    assert len(full_df) == len(texts)

    # Parsed per-run outputs and the ensemble answer must be boolean columns.
    boolean_columns = [f"parsed_output_{run}_filter" for run in run_ids]
    boolean_columns.append("ensemble_answer_filter")
    for column_name in boolean_columns:
        assert full_df[column_name].dtype == bool


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_rollout_columns_filtered_rows(setup_models, model):
    """Check per-run columns on the filtered-only output (``return_all=False``)."""
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["Amazing experience!", "Worst ever.", "Pretty good."]})

    # return_all=False is passed explicitly: only rows that pass the filter
    # are expected in the result.
    sampling_options = {
        "n_sample": 3,
        "ensemble": EnsembleStrategy.MAJORITY,
        "return_all": False,
        "temperature": 0.9,
    }
    filtered_df = df.sem_filter("{Text} is a positive sentiment", **sampling_options)

    assert isinstance(filtered_df, pd.DataFrame)

    # Per-run columns for the first run, then the ensemble answer column.
    for expected_column in (
        "raw_output_1_filter",
        "parsed_output_1_filter",
        "ensemble_answer_filter",
    ):
        assert expected_column in filtered_df.columns

    # Every surviving row must carry a truthy ensemble answer.
    ensemble_answers = filtered_df["ensemble_answer_filter"]
    assert all(ensemble_answers)

    # At least one positive sample is expected to pass.
    assert len(filtered_df) >= 1


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_no_rollout_columns_when_n_sample_1(setup_models, model):
    """Without an explicit ``n_sample`` no per-run columns should be present."""
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["Great!", "Bad."]})

    # n_sample is left at its default here; only return_all is requested.
    full_df = df.sem_filter("{Text} is a positive sentiment", return_all=True)

    assert isinstance(full_df, pd.DataFrame)

    output_columns = full_df.columns

    # None of the rollout/ensemble columns may show up in this mode.
    rollout_only_columns = (
        "raw_output_1_filter",
        "parsed_output_1_filter",
        "ensemble_answer_filter",
    )
    for column_name in rollout_only_columns:
        assert column_name not in output_columns

    # The ordinary label column is still expected.
    assert "filter_label" in output_columns
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"docs"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
4 changes: 4 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ lm = LM(
- **Discussions**: https://github.com/lotus-data/lotus/discussions
- **Issues**: https://github.com/lotus-data/lotus/issues

## Code of Conduct

We are committed to providing a welcoming and inclusive environment for all contributors. Please be respectful and constructive in all interactions.


---

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,4 @@ If you find LOTUS or semantic operators useful, we'd appreciate if you can pleas
eprint={2407.11418},
url={https://arxiv.org/abs/2407.11418},
}
```
```
Loading