Skip to content

Commit e5dffc6

Browse files
amathxbtadambaloghCopilot
authored
fix: Firebase token refresh, create_model notes mismatch (#202)
* fix: correct Permit2 allowance threshold from 10% to 100% (closes #188) The allowance guard used amount_base * 0.1 (10%) which allowed x402 payments to proceed when only 10% of the required OPG was approved. This caused downstream payment failures when the allowance was between 10-99% of the required amount. Fix: compare allowance_before against the full amount_base so the approval step is skipped only when allowance is already sufficient. * fix: ModelHub Firebase token refresh + create_model version/notes fix (closes #164, #157) Bug 1 — Firebase idToken expiry (closes #164): ModelHub cached self._hub_user at login time and never refreshed the idToken. Firebase tokens expire after 3600 s, so any API call made more than ~1 hour after construction silently fails with 401. Fix: add _get_auth_token() which checks time.time() against a stored expiry and calls firebase_app.auth().refresh(refreshToken) when the token is within _TOKEN_REFRESH_MARGIN_SEC (60 s) of expiry. All methods now call _get_auth_token() instead of reading idToken directly. Bug 2 — create_model passes version label as notes (closes #157): create_model(model_name, model_desc, version='1.00') called self.create_version(model_name, version) which mapped the version string '1.00' to the positional parameter of create_version. The server ignores that field as a version specifier and auto-assigns its own version string, so the argument was silently discarded. Fix: call self.create_version(created_name, notes=f'Initial version {version}') to make the intent explicit, and rename the local variable from model_name to created_name to avoid shadowing the input parameter. * fix: cast idToken to str to satisfy mypy [no-any-return] on line 83 firebase is an untyped package (type: ignore[import-untyped]), so self._hub_user['idToken'] resolves to Any. Since _get_auth_token() is declared -> str, mypy raised: error: Returning Any from function declared to return 'str' [no-any-return] Fix: wrap with str() cast which is always safe since Firebase idTokens are JWT strings. * Update src/opengradient/client/model_hub.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: kukac <adambalogh@users.noreply.github.com> --------- Signed-off-by: kukac <adambalogh@users.noreply.github.com> Co-authored-by: kukac <adambalogh@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 7eaf127 commit e5dffc6

1 file changed

Lines changed: 58 additions & 28 deletions

File tree

src/opengradient/client/model_hub.py

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Model Hub for creating, versioning, and uploading ML models."""
22

33
import os
4+
import time
45
from typing import Dict, List, Optional
56

67
import firebase # type: ignore[import-untyped]
@@ -19,6 +20,9 @@
1920
"databaseURL": os.getenv("FIREBASE_DATABASE_URL", ""),
2021
}
2122

23+
# Firebase idTokens expire after 3600 seconds; refresh 60 s before expiry
24+
_TOKEN_REFRESH_MARGIN_SEC = 60
25+
2226

2327
class ModelHub:
2428
"""
@@ -34,15 +38,49 @@ class ModelHub:
3438
"""
3539

3640
def __init__(self, email: Optional[str] = None, password: Optional[str] = None):
37-
self._hub_user = self._login(email, password) if email is not None else None
41+
self._firebase_app = None
42+
self._hub_user = None
43+
self._token_expiry: float = 0.0
44+
45+
if email is not None:
46+
self._firebase_app, self._hub_user = self._login(email, password)
47+
expires_in = int(self._hub_user.get("expiresIn", 3600))
48+
self._token_expiry = time.time() + expires_in
3849

3950
@staticmethod
4051
def _login(email: str, password: Optional[str]):
4152
if not _FIREBASE_CONFIG.get("apiKey"):
4253
raise ValueError("Firebase API Key is missing in environment variables")
4354

4455
firebase_app = firebase.initialize_app(_FIREBASE_CONFIG)
45-
return firebase_app.auth().sign_in_with_email_and_password(email, password)
56+
user = firebase_app.auth().sign_in_with_email_and_password(email, password)
57+
return firebase_app, user
58+
59+
def _get_auth_token(self) -> str:
60+
"""Return a valid Firebase idToken, refreshing it if it has expired or is
61+
about to expire within ``_TOKEN_REFRESH_MARGIN_SEC`` seconds.
62+
63+
Raises:
64+
ValueError: If the user is not authenticated.
65+
"""
66+
if not self._hub_user:
67+
raise ValueError("User not authenticated")
68+
69+
if time.time() >= self._token_expiry - _TOKEN_REFRESH_MARGIN_SEC:
70+
# Refresh the token using the stored refresh token
71+
refresh_token = self._hub_user.get("refreshToken")
72+
if not refresh_token or self._firebase_app is None:
73+
raise ValueError(
74+
"Cannot refresh Firebase token: missing refresh token or Firebase app. "
75+
"Please re-authenticate by creating a new ModelHub instance."
76+
)
77+
refreshed = self._firebase_app.auth().refresh(refresh_token)
78+
self._hub_user["idToken"] = refreshed["idToken"]
79+
self._hub_user["refreshToken"] = refreshed.get("refreshToken", refresh_token)
80+
expires_in = int(refreshed.get("expiresIn", 3600))
81+
self._token_expiry = time.time() + expires_in
82+
83+
return str(self._hub_user["idToken"]) # cast Any->str for mypy [no-any-return]
4684

4785
def create_model(self, model_name: str, model_desc: str, version: str = "1.00") -> ModelRepository:
4886
"""
@@ -51,19 +89,17 @@ def create_model(self, model_name: str, model_desc: str, version: str = "1.00")
5189
Args:
5290
model_name (str): The name of the model.
5391
model_desc (str): The description of the model.
54-
version (str): The version identifier (default is "1.00").
92+
version (str): A label used in the initial version notes (default is "1.00").
93+
Note: the actual version string is assigned by the server.
5594
5695
Returns:
57-
dict: The server response containing model details.
96+
ModelRepository: Object containing the model name and server-assigned version string.
5897
5998
Raises:
60-
CreateModelError: If the model creation fails.
99+
RuntimeError: If the model creation fails.
61100
"""
62-
if not self._hub_user:
63-
raise ValueError("User not authenticated")
64-
65101
url = "https://api.opengradient.ai/api/v0/models/"
66-
headers = {"Authorization": f"Bearer {self._hub_user['idToken']}", "Content-Type": "application/json"}
102+
headers = {"Authorization": f"Bearer {self._get_auth_token()}", "Content-Type": "application/json"}
67103
payload = {"name": model_name, "description": model_desc}
68104

69105
try:
@@ -74,14 +110,18 @@ def create_model(self, model_name: str, model_desc: str, version: str = "1.00")
74110
raise RuntimeError(f"Model creation failed: {error_details}") from e
75111

76112
json_response = response.json()
77-
model_name = json_response.get("name")
78-
if not model_name:
113+
created_name = json_response.get("name")
114+
if not created_name:
79115
raise Exception(f"Model creation response missing 'name'. Full response: {json_response}")
80116

81-
# Create the specified version for the newly created model
82-
version_response = self.create_version(model_name, version)
117+
# Create the initial version for the newly created model.
118+
# Pass `version` as release notes (e.g. "1.00") since the server assigns
119+
# its own version string — previously `version` was incorrectly passed as
120+
# the positional `notes` argument, resulting in raw version labels as notes
121+
# rather than the clearer "Initial version <label>" format used here.
122+
version_response = self.create_version(created_name, notes=f"Initial version {version}")
83123

84-
return ModelRepository(model_name, version_response["versionString"])
124+
return ModelRepository(created_name, version_response["versionString"])
85125

86126
def create_version(self, model_name: str, notes: str = "", is_major: bool = False) -> dict:
87127
"""
@@ -98,11 +138,8 @@ def create_version(self, model_name: str, notes: str = "", is_major: bool = Fals
98138
Raises:
99139
Exception: If the version creation fails.
100140
"""
101-
if not self._hub_user:
102-
raise ValueError("User not authenticated")
103-
104141
url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions"
105-
headers = {"Authorization": f"Bearer {self._hub_user['idToken']}", "Content-Type": "application/json"}
142+
headers = {"Authorization": f"Bearer {self._get_auth_token()}", "Content-Type": "application/json"}
106143
payload = {"notes": notes, "is_major": is_major}
107144

108145
try:
@@ -136,20 +173,16 @@ def upload(self, model_path: str, model_name: str, version: str) -> FileUploadRe
136173
version (str): The version identifier for the model.
137174
138175
Returns:
139-
dict: The processed result.
176+
FileUploadResult: The processed result.
140177
141178
Raises:
142179
RuntimeError: If the upload fails.
143180
"""
144-
145-
if not self._hub_user:
146-
raise ValueError("User not authenticated")
147-
148181
if not os.path.exists(model_path):
149182
raise FileNotFoundError(f"Model file not found: {model_path}")
150183

151184
url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions/{version}/files"
152-
headers = {"Authorization": f"Bearer {self._hub_user['idToken']}"}
185+
headers = {"Authorization": f"Bearer {self._get_auth_token()}"}
153186

154187
try:
155188
with open(model_path, "rb") as file:
@@ -191,11 +224,8 @@ def list_files(self, model_name: str, version: str) -> List[Dict]:
191224
Raises:
192225
RuntimeError: If the file listing fails.
193226
"""
194-
if not self._hub_user:
195-
raise ValueError("User not authenticated")
196-
197227
url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions/{version}/files"
198-
headers = {"Authorization": f"Bearer {self._hub_user['idToken']}"}
228+
headers = {"Authorization": f"Bearer {self._get_auth_token()}"}
199229

200230
try:
201231
response = requests.get(url, headers=headers)

0 commit comments

Comments
 (0)