diff --git a/llmebench/models/Gemini.py b/llmebench/models/Gemini.py
index 934c675e..9b8dd9d5 100644
--- a/llmebench/models/Gemini.py
+++ b/llmebench/models/Gemini.py
@@ -6,6 +6,7 @@
 
 import vertexai
 import vertexai.preview.generative_models as generative_models
+from google.oauth2 import service_account
 from vertexai.generative_models import FinishReason, GenerativeModel, Part
 
 from llmebench.models.model_base import ModelBase
@@ -53,50 +54,78 @@ class GeminiModel(ModelBase):
     def __init__(
         self,
         project_id=None,
-        api_key=None,
         model_name=None,
+        location=None,
+        credentials_path=None,  # path to JSON file
+        credentials_info=None,  # dict or JSON string
         timeout=20,
         temperature=0,
+        tolerance=1e-7,
         top_p=0.95,
         max_tokens=2000,
         **kwargs,
     ):
-        # API parameters
-        # self.api_url = api_url or os.getenv("AZURE_DEPLOYMENT_API_URL")
-        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
+        """Set up Vertex AI access and generation parameters.
+
+        Credentials are resolved in order: `credentials_info` (dict or
+        JSON string), then `credentials_path` or the
+        `GOOGLE_APPLICATION_CREDENTIALS` environment variable (path to a
+        service-account JSON file), else Application Default Credentials.
+        `project_id`, `model_name` and `location` fall back to the
+        `GOOGLE_PROJECT_ID`, `MODEL` and `VERTEX_LOCATION` environment
+        variables; `location` finally defaults to "us-central1".
+        """
         self.project_id = project_id or os.getenv("GOOGLE_PROJECT_ID")
         self.model_name = model_name or os.getenv("MODEL")
-        if self.api_key is None:
+        self.location = location or os.getenv("VERTEX_LOCATION") or "us-central1"
+        self.credentials = None
+
+        # 1. Prefer explicit credentials_info (dict or JSON string)
+        if credentials_info:
+            if isinstance(credentials_info, str):
+                credentials_info = json.loads(credentials_info)
+            self.credentials = service_account.Credentials.from_service_account_info(
+                credentials_info
+            )
+        # 2. Else, load from path (arg or env)
+        elif credentials_path or os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
+            path = credentials_path or os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
+            self.credentials = service_account.Credentials.from_service_account_file(
+                path
+            )
+        # 3. Else, None: will fall back to ADC (Application Default Credentials)
+
+        if not self.project_id:
             raise Exception(
-                "API Key must be provided as model config or environment variable (`GOOGLE_API_KEY`)"
+                "PROJECT_ID must be set (argument or `GOOGLE_PROJECT_ID` in .env)"
             )
-        if self.project_id is None:
-            raise Exception(
-                "PROJECT_ID must be provided as model config or environment variable (`GOOGLE_PROJECT_ID`)"
-            )
+        if not self.model_name:
+            raise Exception("MODEL must be set (argument or `MODEL` in .env)")
         self.api_timeout = timeout
+
+        vertexai.init(
+            project=self.project_id,
+            location=self.location,
+            credentials=self.credentials,
+        )
+
+        self.tolerance = tolerance
+        # Model inference fails if temperature is exactly 0, so clamp it
+        # to at least `tolerance` to work around the issue
+        self.temperature = max(temperature, tolerance)
+        self.top_p = top_p
+        self.max_tokens = max_tokens
+
         self.safety_settings = {
             generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
             generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
             generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
             generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
         }
-        # Parameters
-        tolerance = 1e-7
-        self.temperature = temperature
-        if self.temperature < tolerance:
-            # Currently, the model inference fails if temperature
-            # is exactly 0, so we nudge it slightly to work around
-            # the issue
-            self.temperature += tolerance
-        self.top_p = top_p
-        self.max_tokens = max_tokens
 
         super(GeminiModel, self).__init__(
             retry_exceptions=(TimeoutError, GeminiFailure), **kwargs
         )
-        vertexai.init(project=self.project_id, location="us-central1")
-        # self.client = GenerativeModel(self.model_name)
 
     def summarize_response(self, response):
         """Returns the "outputs" key's value, if available"""
@@ -127,20 +156,6 @@ def prompt(self, processed_input):
             This method raises this exception if the server responded with a non-ok
             response
         """
-        # headers = {
-        #     "Content-Type": "application/json",
-        #     "Authorization": "Bearer " + self.api_key,
-        # }
-        # body = {
-        #     "input_data": {
-        #         "input_string": processed_input,
-        #         "parameters": {
-        #             "max_tokens": self.max_tokens,
-        #             "temperature": self.temperature,
-        #             "top_p": self.top_p,
-        #         },
-        #     }
-        # }
         generation_config = {
             "max_output_tokens": self.max_tokens,
             "temperature": self.temperature,
diff --git a/tests/models/test_Gemini.py b/tests/models/test_Gemini.py
index 7d189962..55c399e6 100644
--- a/tests/models/test_Gemini.py
+++ b/tests/models/test_Gemini.py
@@ -53,18 +53,20 @@ class TestGeminiDepModelConfig(unittest.TestCase):
     def test_gemini_deployed_model_config(self):
         "Test if model config parameters passed as arguments are used"
         model = GeminiModel(
-            project_id="test_project_id", api_key="secret-key", model_name="gemini-test"
+            project_id="test_project_id",
+            model_name="gemini-test",
+            location="europe-west4",
         )
 
         self.assertEqual(model.project_id, "test_project_id")
-        self.assertEqual(model.api_key, "secret-key")
+        self.assertEqual(model.location, "europe-west4")
         self.assertEqual(model.model_name, "gemini-test")
 
     @patch.dict(
         "os.environ",
         {
             "GOOGLE_PROJECT_ID": "test_project_id",
-            "GOOGLE_API_KEY": "secret-key",
+            "VERTEX_LOCATION": "europe-west4",
             "MODEL": "gemini-test",
         },
     )
@@ -73,23 +75,25 @@ def test_gemini_deployed_model_config_env_var(self):
         model = GeminiModel()
 
         self.assertEqual(model.project_id, "test_project_id")
-        self.assertEqual(model.api_key, "secret-key")
+        self.assertEqual(model.location, "europe-west4")
         self.assertEqual(model.model_name, "gemini-test")
 
     @patch.dict(
         "os.environ",
         {
             "GOOGLE_PROJECT_ID": "test_project_id",
-            "GOOGLE_API_KEY": "secret-env-key",
+            "VERTEX_LOCATION": "europe-west4",
             "MODEL": "gemini-test",
         },
     )
     def test_gemini_deployed_model_config_priority(self):
         "Test if model config parameters passed directly get priority"
         model = GeminiModel(
-            project_id="test_project_id", api_key="secret-key", model_name="gemini_test"
+            project_id="test_project_id",
+            model_name="gemini_test",
+            location="us-central1",
         )
 
         self.assertEqual(model.project_id, "test_project_id")
-        self.assertEqual(model.api_key, "secret-key")
+        self.assertEqual(model.location, "us-central1")
         self.assertEqual(model.model_name, "gemini_test")