-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_model_resolution.py
More file actions
88 lines (74 loc) · 2.84 KB
/
test_model_resolution.py
File metadata and controls
88 lines (74 loc) · 2.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import unittest
from unittest.mock import MagicMock
import google.generativeai as genai
# The logic to be tested
def resolve_model_name(available_models):
    """
    Pick the preferred model name out of ``available_models``.

    available_models: iterable of objects exposing a ``.name`` string
    attribute. A leading ``models/`` on a name is ignored while matching.

    Returns the name exactly as the matching model reported it (prefix
    included when it had one) for the first entry of the priority list that
    is available. When nothing on the priority list is available, the bare
    string "gemini-1.5-flash" is returned instead. If several models reduce
    to the same short name, the one appearing last in ``available_models``
    supplies the returned name.
    """
    prefix = "models/"

    # Ordered from most to least preferred; compared against short names
    # (i.e. without the 'models/' prefix that list_models usually includes).
    preference = (
        "gemini-1.5-flash",
        "gemini-1.5-flash-latest",
        "gemini-1.5-flash-001",
        "gemini-flash-latest",
        "gemini-1.5-pro",
        "gemini-pro",
    )

    # short name -> name as reported; doubles as the membership index.
    reported_by_short = {}
    for model in available_models:
        reported = model.name
        short = reported[len(prefix):] if reported.startswith(prefix) else reported
        reported_by_short[short] = reported

    # First preferred entry that is actually available wins.
    for wanted in preference:
        if wanted in reported_by_short:
            return reported_by_short[wanted]

    # Nothing matched: fall back to the default short name.
    return "gemini-1.5-flash"
class TestModelResolution(unittest.TestCase):
    """Checks the priority order and fallback of ``resolve_model_name``."""

    def make_models(self, names):
        """Return one MagicMock per entry in ``names``, posing as a model."""
        fakes = []
        for model_name in names:
            fake = MagicMock()
            # Assigned after construction: MagicMock(name=...) would name the
            # mock itself rather than set a plain ``name`` attribute.
            fake.name = model_name
            fake.supported_generation_methods = ["generateContent"]
            fakes.append(fake)
        return fakes

    def test_exact_match(self):
        """The top-priority model is chosen when it is listed."""
        listed = self.make_models(["models/gemini-1.5-flash", "models/gemini-pro"])
        self.assertEqual(resolve_model_name(listed), "models/gemini-1.5-flash")

    def test_variant_match(self):
        """A flash variant outranks gemini-pro."""
        listed = self.make_models(
            ["models/gemini-1.5-flash-latest", "models/gemini-pro"]
        )
        self.assertEqual(
            resolve_model_name(listed), "models/gemini-1.5-flash-latest"
        )

    def test_fallback_to_pro(self):
        """gemini-pro is used when no flash model is listed."""
        listed = self.make_models(["models/gemini-pro", "models/text-bison-001"])
        self.assertEqual(resolve_model_name(listed), "models/gemini-pro")

    def test_no_match_fallback(self):
        """An unrecognised listing yields the default name."""
        listed = self.make_models(["models/unknown-model"])
        self.assertEqual(resolve_model_name(listed), "gemini-1.5-flash")

    def test_empty_list(self):
        """An empty listing yields the default name."""
        self.assertEqual(resolve_model_name([]), "gemini-1.5-flash")

    def test_preference_order(self):
        """Priority, not listing order, decides: flash beats pro."""
        listed = self.make_models(["models/gemini-pro", "models/gemini-1.5-flash"])
        self.assertEqual(resolve_model_name(listed), "models/gemini-1.5-flash")
# When this file is executed directly (not imported), hand control to
# unittest's runner so the test cases defined above are executed.
if __name__ == "__main__":
    unittest.main()