Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions py_gasbuddy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import json
import logging
import re
Expand Down Expand Up @@ -131,6 +132,10 @@ def __init__(
self._cache_manager: GasBuddyCache | None = None
self._timeout = timeout
self._session = session
# Serialise CSRF refreshes within a single process — without this,
# two concurrent callers on a cold cache both GET /home and both
# write the token file.
self._token_lock = asyncio.Lock()

async def process_request(self, query: GraphQLQuery) -> dict[str, Any]:
"""Process API requests.
Expand Down Expand Up @@ -702,8 +707,24 @@ async def _get_headers(self) -> None:
if self._cf_last is None or self._cf_last:
return

_LOGGER.debug("Token invalid, getting a new one...")
# Serialise concurrent token-refresh attempts within this
# GasBuddy instance. After acquiring the lock, re-check whether
# the previous holder already populated the token so we don't
# double-fetch.
async with self._token_lock:
if self._cf_last is True and self._tag: # type: ignore[unreachable]
return # type: ignore[unreachable]

_LOGGER.debug("Token invalid, getting a new one...")
await self._refresh_token(url, method, json_data)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

async def _refresh_token(
self,
url: str,
method: str,
json_data: Any,
) -> None:
"""GET /home (or solver) and persist the extracted CSRF token."""
csrf_timeout = aiohttp.ClientTimeout(total=self._timeout / 1000)
async with self._get_session() as session:
http_method = getattr(session, method)
Expand Down Expand Up @@ -738,7 +759,12 @@ async def _get_headers(self) -> None:
data[TOKEN] = self._tag
encoded = json.dumps(data).encode("utf-8")
_LOGGER.debug("CSRF token found: %s", self._tag)
await self._cache_manager.write_cache(encoded)
if self._cache_manager is not None:
await self._cache_manager.write_cache(encoded)
# Mark this instance as having a fresh token so
# any coroutine still queued on _token_lock skips
# its own refresh after we release the lock.
self._cf_last = True
else:
_LOGGER.error("CSRF token not found.")
raise CSRFTokenMissing
Expand Down
39 changes: 33 additions & 6 deletions py_gasbuddy/cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Cache functions for py-gasbuddy."""

import asyncio
import json
import logging
import os
import uuid
from pathlib import Path
from typing import Any

Expand All @@ -20,15 +23,39 @@ def __init__(self, cache_file: str = "") -> None:
self._cache_file = Path.home() / ".cache" / "py_gasbuddy" / "token"
else:
self._cache_file = Path(cache_file)
# Serialise cache mutations within a single process. The HA
# coordinator + a parallel service call could otherwise race
# both reading and writing the same token file.
self._lock = asyncio.Lock()

async def write_cache(self, data: Any) -> None:
"""Write cache file."""
# Create parent directories if they don't exist
if not await aiofiles.os.path.exists(self._cache_file.parent):
await aiofiles.os.makedirs(self._cache_file.parent)
"""Atomically write the cache file.

async with aiofiles.open(self._cache_file, mode="wb") as file:
await file.write(data)
Writes to a uniquely-named sibling tempfile and ``os.replace``s
onto the final path, so concurrent writers can't produce a torn
file. The asyncio lock further serialises in-process writers.
"""
async with self._lock:
# Create parent directories if they don't exist. Use
# exist_ok=True so a racing process that creates the
# directory between our check and call doesn't blow up.
await aiofiles.os.makedirs(self._cache_file.parent, exist_ok=True)

tmp_path = self._cache_file.with_name(
f".{self._cache_file.name}.{os.getpid()}.{uuid.uuid4().hex}.tmp"
)
try:
async with aiofiles.open(tmp_path, mode="wb") as file:
await file.write(data)
# os.replace is atomic on POSIX and Windows ≥Vista.
await aiofiles.os.replace(tmp_path, self._cache_file)
except Exception:
# Best-effort cleanup of the tempfile on failure.
try:
await aiofiles.os.remove(tmp_path)
except OSError:
pass
raise

async def read_cache(self) -> Any:
"""Read cache file."""
Expand Down
83 changes: 80 additions & 3 deletions tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,10 @@ async def test_retry_logic(mock_aioclient, caplog):
)
# Patch asyncio.sleep used by backoff so the test doesn't actually
# wait through the exponential delay between retries.
with caplog.at_level(logging.DEBUG), patch("backoff._async.asyncio.sleep", new=AsyncMock()):
with (
caplog.at_level(logging.DEBUG),
patch("backoff._async.asyncio.sleep", new=AsyncMock()),
):
with pytest.raises(py_gasbuddy.LibraryError):
manager = py_gasbuddy.GasBuddy(station_id=205033)
await manager.price_lookup()
Expand All @@ -358,11 +361,14 @@ async def test_retry_succeeds_on_second_attempt(mock_aioclient, caplog):
mock_aioclient.post(
TEST_URL,
status=403,
body='<!DOCTYPE html><html><title>Just a moment...</title></html>',
body="<!DOCTYPE html><html><title>Just a moment...</title></html>",
)
mock_aioclient.post(TEST_URL, status=200, body=load_fixture("station.json"))

with caplog.at_level(logging.DEBUG), patch("backoff._async.asyncio.sleep", new=AsyncMock()):
with (
caplog.at_level(logging.DEBUG),
patch("backoff._async.asyncio.sleep", new=AsyncMock()),
):
manager = py_gasbuddy.GasBuddy(station_id=205033)
data = await manager.price_lookup()

Expand Down Expand Up @@ -1192,3 +1198,74 @@ async def test_location_search_pagination(mock_aioclient, caplog, tmp_path):
result_err = await manager.location_search(zipcode=12345)
assert result_err["results"] == []
assert result_err["next_cursor"] is None


async def test_cache_write_atomic_failure(tmp_path):
"""Test that a write failure cleans up the tempfile and propagates the error."""
from unittest.mock import patch

cache_file = tmp_path / "test_cache"
cache = py_gasbuddy.cache.GasBuddyCache(str(cache_file))

# Mock aiofiles.os.replace to fail
with patch("aiofiles.os.replace", side_effect=OSError("Atomic replace failed")):
with pytest.raises(OSError, match="Atomic replace failed"):
await cache.write_cache(b"data")

# The final cache file should not exist
assert not cache_file.exists()

# The temp files in that directory should also be cleaned up (no tmp files)
tmp_files = list(tmp_path.glob(".*.tmp"))
assert not tmp_files


async def test_token_refresh_concurrency(mock_aioclient):
"""Test that concurrent token-refresh attempts are serialized and only one HTTP fetch is executed."""
import asyncio
from unittest.mock import patch

# Mock GB_HOME_URL response
mock_aioclient.get(
GB_URL,
status=200,
body=load_fixture("index.html"),
)

manager = py_gasbuddy.GasBuddy()

# Track how many times _refresh_token is invoked
original_refresh = manager._refresh_token
refresh_calls = 0

async def spy_refresh(*args, **kwargs):
nonlocal refresh_calls
refresh_calls += 1
# Introduce a small sleep to simulate network delay, forcing the concurrency
await asyncio.sleep(0.05)
await original_refresh(*args, **kwargs)

with patch.object(manager, "_refresh_token", side_effect=spy_refresh):
# Trigger concurrent _get_headers calls
await asyncio.gather(
manager._get_headers(),
manager._get_headers(),
)

# Only one token refresh should have occurred
assert refresh_calls == 1
await manager.clear_cache()


async def test_cache_write_atomic_failure_cleanup_fails(tmp_path):
"""Test that a write failure propagates correctly even when the tempfile cleanup itself fails with OSError."""
from unittest.mock import patch

cache_file = tmp_path / "test_cache"
cache = py_gasbuddy.cache.GasBuddyCache(str(cache_file))

# Mock replace to fail, and mock remove to fail as well
with patch("aiofiles.os.replace", side_effect=OSError("Atomic replace failed")):
with patch("aiofiles.os.remove", side_effect=OSError("Remove failed")):
with pytest.raises(OSError, match="Atomic replace failed"):
await cache.write_cache(b"data")
Loading