Implement semantic map + filter #209
Open
IsmaelKabore wants to merge 3 commits into lotus-data:main from IsmaelKabore:feature/semantic-map-filter
New file (+257 lines):

```python
import os

import pandas as pd
import pytest

import lotus
from lotus.models import LM
from lotus.types import EnsembleStrategy, ReasoningStrategy

# Gate the whole module behind an env var so devs/CI without Ollama skip cleanly.
ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true"

pytestmark = pytest.mark.skipif(
    not ENABLE_OLLAMA_TESTS,
    reason="Set ENABLE_OLLAMA_TESTS=true to run Ollama-backed tests",
)

MODEL_NAME = "ollama/llama3.1"


@pytest.fixture(scope="session")
def setup_models():
    return {MODEL_NAME: LM(model=MODEL_NAME)}


@pytest.fixture(autouse=True)
def print_usage_after_each_test(setup_models):
    yield
    for _, m in setup_models.items():
        m.print_total_usage()
        m.reset_stats()
        m.reset_cache()


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_basic(setup_models, model):
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame(
        {
            "Text": [
                "I am really excited to go to class today!",
                "I am very sad",
            ]
        }
    )
    user_instruction = "{Text} is a positive sentiment"

    filtered = df.sem_filter(user_instruction)

    assert isinstance(filtered, pd.DataFrame)
    assert "Text" in filtered.columns
    assert len(filtered) >= 1
    # Tolerant: ensure obviously negative text is unlikely to pass
    assert "I am very sad" not in filtered["Text"].tolist()


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_with_sampling_pick_first(setup_models, model):
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["Today is fantastic!", "This is terrible.", "Pretty good overall."]})
    user_instruction = "{Text} is a positive sentiment"

    # Exercise resampling path (n_sample>1) but keep ensemble as PICK_FIRST.
    filtered = df.sem_filter(
        user_instruction,
        n_sample=2,
        temperature=0.9,
        # explicit for clarity (default is PICK_FIRST)
        ensemble=EnsembleStrategy.PICK_FIRST,
    )

    assert isinstance(filtered, pd.DataFrame)
    assert "Text" in filtered.columns
    assert len(filtered) >= 1


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_with_majority_ensemble_runs(setup_models, model):
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["Great job!", "Awful service.", "Fine, I guess."]})
    user_instruction = "{Text} is a positive sentiment"

    # Exercise the MAJORITY voting path (boolean labels after postprocess).
    # We keep assertions tolerant because model outputs can vary.
    filtered = df.sem_filter(
        user_instruction,
        n_sample=3,
        temperature=0.9,
        ensemble=EnsembleStrategy.MAJORITY,
    )

    assert isinstance(filtered, pd.DataFrame)
    assert "Text" in filtered.columns
    # Should not error; size can vary depending on votes
    assert len(filtered) >= 0


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_invalid_n_sample_raises(setup_models, model):
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["Neutral thing"]})
    user_instruction = "{Text} is a positive sentiment"

    with pytest.raises(ValueError):
        df.sem_filter(user_instruction, n_sample=0)


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_return_all_and_explanations(setup_models, model):
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["I love sunshine", "I hate rain"]})
    user_instruction = "{Text} is a positive sentiment"

    # Ask for all rows + explanations (ZS_COT forces explanation path on)
    full_df = df.sem_filter(
        user_instruction,
        return_all=True,
        return_explanations=True,
        strategy=ReasoningStrategy.ZS_COT,
    )

    assert isinstance(full_df, pd.DataFrame)
    # When return_all=True, an extra 'filter_label' column is added (no suffix)
    assert "filter_label" in full_df.columns
    # Explanation column uses the default suffix "_filter"
    assert "explanation_filter" in full_df.columns


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_return_stats_tuple(setup_models, model):
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["I love this", "I dislike that"]})
    user_instruction = "{Text} is a positive sentiment"

    result = df.sem_filter(
        user_instruction,
        return_stats=True,
    )

    # When return_stats=True, we get (DataFrame, stats_dict)
    assert isinstance(result, tuple) and len(result) == 2
    out_df, stats = result
    assert isinstance(out_df, pd.DataFrame)
    assert isinstance(stats, dict)


# NEW TESTS FOR MULTI-RUN ROLLOUT FUNCTIONALITY


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_with_rollout_columns(setup_models, model):
    """Test that per-run rollout columns are correctly added when n_sample > 1."""
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["Great product!", "Terrible service.", "It's okay."]})
    user_instruction = "{Text} is a positive sentiment"

    # Test with return_all=True to see all rollout data
    full_df = df.sem_filter(
        user_instruction,
        n_sample=3,
        ensemble=EnsembleStrategy.MAJORITY,
        return_all=True,
        return_explanations=True,
        temperature=0.9,
    )

    assert isinstance(full_df, pd.DataFrame)

    # Check that per-run columns exist
    assert "raw_output_1_filter" in full_df.columns
    assert "raw_output_2_filter" in full_df.columns
    assert "raw_output_3_filter" in full_df.columns

    assert "parsed_output_1_filter" in full_df.columns
    assert "parsed_output_2_filter" in full_df.columns
    assert "parsed_output_3_filter" in full_df.columns

    assert "explanation_1_filter" in full_df.columns
    assert "explanation_2_filter" in full_df.columns
    assert "explanation_3_filter" in full_df.columns

    # Check that ensemble answer column exists
    assert "ensemble_answer_filter" in full_df.columns

    # Verify that all rows are present (return_all=True)
    assert len(full_df) == 3

    # Verify data types
    assert full_df["parsed_output_1_filter"].dtype == bool
    assert full_df["parsed_output_2_filter"].dtype == bool
    assert full_df["parsed_output_3_filter"].dtype == bool
    assert full_df["ensemble_answer_filter"].dtype == bool


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_rollout_columns_filtered_rows(setup_models, model):
    """Test that per-run columns work correctly when return_all=False (filtered rows only)."""
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["Amazing experience!", "Worst ever.", "Pretty good."]})
    user_instruction = "{Text} is a positive sentiment"

    # Test with return_all=False (default) - only rows that pass the filter
    filtered_df = df.sem_filter(
        user_instruction,
        n_sample=3,
        ensemble=EnsembleStrategy.MAJORITY,
        return_all=False,  # explicit for clarity
        temperature=0.9,
    )

    assert isinstance(filtered_df, pd.DataFrame)

    # Check that per-run columns exist
    assert "raw_output_1_filter" in filtered_df.columns
    assert "parsed_output_1_filter" in filtered_df.columns

    # Check ensemble answer column
    assert "ensemble_answer_filter" in filtered_df.columns

    # All rows in filtered result should have ensemble_answer=True
    assert all(filtered_df["ensemble_answer_filter"])

    # Verify we got at least some positive samples
    assert len(filtered_df) >= 1


@pytest.mark.parametrize("model", [MODEL_NAME])
def test_df_sem_filter_no_rollout_columns_when_n_sample_1(setup_models, model):
    """Test that per-run columns are NOT added when n_sample=1 (default behavior)."""
    lotus.settings.configure(lm=setup_models[model])

    df = pd.DataFrame({"Text": ["Great!", "Bad."]})
    user_instruction = "{Text} is a positive sentiment"

    # Default n_sample=1, so no rollout columns should appear
    full_df = df.sem_filter(
        user_instruction,
        return_all=True,
    )

    assert isinstance(full_df, pd.DataFrame)

    # These columns should NOT exist when n_sample=1
    assert "raw_output_1_filter" not in full_df.columns
    assert "parsed_output_1_filter" not in full_df.columns
    assert "ensemble_answer_filter" not in full_df.columns

    # But the regular filter_label should exist
    assert "filter_label" in full_df.columns
```
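The MAJORITY ensemble exercised above reduces the `n_sample` boolean votes for each row to a single `ensemble_answer_filter` label. As an illustration only (this is a hypothetical sketch, not the lotus implementation, and lotus's tie-breaking may differ), a majority reduction over per-run votes could look like:

```python
from collections import Counter

# Hypothetical sketch of a MAJORITY-style reduction over per-run boolean
# votes; ties resolve to False here, which may differ from lotus's behavior.
def majority_vote(votes):
    counts = Counter(votes)
    return counts[True] > counts[False]

# Three rollouts per row, reduced to one ensemble answer per row.
per_row_votes = [
    [True, True, False],
    [False, False, True],
]
ensemble_answers = [majority_vote(v) for v in per_row_votes]
print(ensemble_answers)  # [True, False]
```

Note that actually running the suite requires `ENABLE_OLLAMA_TESTS=true` in the environment (per the module-level skip guard) and a reachable Ollama server serving `ollama/llama3.1`.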
New file (+7 lines):

```json
{
    "python.testing.pytestArgs": [
        "docs"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": true
}
```
Review comment: we should add assertions checking that the output columns are as expected.
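As an illustration of the kind of column assertion the reviewer suggests, a tolerant check could compare the output's columns against an expected set (this sketch uses a stand-in DataFrame; in the real tests the frame would come from `df.sem_filter(...)`, and the expected set depends on the arguments passed):

```python
import pandas as pd

# Stand-in for a sem_filter result with return_all=True (hypothetical data).
out = pd.DataFrame({"Text": ["Great!"], "filter_label": [True]})

# Assert the output contains exactly the input columns plus the added label.
expected_cols = {"Text", "filter_label"}
assert set(out.columns) == expected_cols
assert out["filter_label"].dtype == bool
print("column checks passed")
```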