Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion src/blaxel/core/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,41 @@
import logging
import os

PROVIDER_DEBUG_LOG_OPT_IN_ENV = "BL_ALLOW_PROVIDER_DEBUG_LOGS"
PROVIDER_DEBUG_LOGGER_NAMES = (
"google_adk",
"google_adk.google.adk.models.lite_llm",
"LiteLLM",
"LiteLLM Proxy",
"LiteLLM Router",
"openai",
"openai._base_client",
"httpx",
"httpcore",
)


def _is_truthy_env(value: str | None) -> bool:
return value is not None and value.strip().lower() in {"1", "true", "yes", "on"}


def provider_debug_logs_allowed() -> bool:
return _is_truthy_env(os.environ.get(PROVIDER_DEBUG_LOG_OPT_IN_ENV))


def suppress_provider_debug_loggers() -> None:
if provider_debug_logs_allowed():
return

for logger_name in PROVIDER_DEBUG_LOGGER_NAMES:
provider_logger = logging.getLogger(logger_name)
if provider_logger.level == logging.NOTSET or provider_logger.level < logging.WARNING:
provider_logger.setLevel(logging.WARNING)
for handler in provider_logger.handlers:
if handler.level == logging.NOTSET or handler.level < logging.WARNING:
handler.setLevel(logging.WARNING)


try:
from opentelemetry.trace import get_current_span

Expand Down Expand Up @@ -118,9 +153,11 @@ def init_logger(log_level: str):
Parameters:
log_level (str): The logging level to set (e.g., "DEBUG", "INFO").
"""
suppress_provider_debug_loggers()
# Disable urllib3 logging
logging.getLogger("urllib3").setLevel(logging.CRITICAL)
logging.getLogger("httpx").setLevel(logging.CRITICAL)
if not provider_debug_logs_allowed():
logging.getLogger("httpx").setLevel(logging.CRITICAL)
handler = logging.StreamHandler()

logger_type = os.environ.get("BL_LOGGER", "http")
Expand All @@ -129,3 +166,4 @@ def init_logger(log_level: str):
else:
handler.setFormatter(ColoredFormatter("%(levelname)s %(name)s - %(message)s"))
logging.basicConfig(level=log_level, handlers=[handler])
suppress_provider_debug_loggers()
12 changes: 10 additions & 2 deletions src/blaxel/googleadk/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from blaxel.core import bl_model as bl_model_core
from blaxel.core import settings
from blaxel.core.common.logger import suppress_provider_debug_loggers

logger = getLogger(__name__)

Expand All @@ -26,6 +27,7 @@ async def acompletion(self, model, messages, tools, **kwargs):
Returns:
The model response as a message.
"""
suppress_provider_debug_loggers()
auth_headers = settings.auth.get_headers()
extra = dict(auth_headers)
# When auth uses X-Blaxel-Authorization (API keys), override the
Expand All @@ -35,12 +37,14 @@ async def acompletion(self, model, messages, tools, **kwargs):
if "Authorization" not in auth_headers:
extra["Authorization"] = ""
kwargs["extra_headers"] = extra
return await super().acompletion(
response = await super().acompletion(
model=model,
messages=messages,
tools=tools,
**kwargs,
)
suppress_provider_debug_loggers()
return response

def completion(self, model, messages, tools, stream=False, **kwargs):
"""Synchronously calls completion. This is used for streaming only.
Expand All @@ -55,6 +59,7 @@ def completion(self, model, messages, tools, stream=False, **kwargs):
Returns:
The response from the model.
"""
suppress_provider_debug_loggers()
auth_headers = settings.auth.get_headers()
extra = dict(auth_headers)
# When auth uses X-Blaxel-Authorization (API keys), override the
Expand All @@ -64,16 +69,19 @@ def completion(self, model, messages, tools, stream=False, **kwargs):
if "Authorization" not in auth_headers:
extra["Authorization"] = ""
kwargs["extra_headers"] = extra
return super().completion(
response = super().completion(
model=model,
messages=messages,
tools=tools,
stream=stream,
**kwargs,
)
suppress_provider_debug_loggers()
return response


async def get_google_adk_model(url: str, type: str, model: str, **kwargs):
suppress_provider_debug_loggers()
llm_client = AuthenticatedLiteLLMClient()
if type == "mistral":
return LiteLlm(
Expand Down
98 changes: 98 additions & 0 deletions tests/core/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import io
import logging

import pytest

from blaxel.core.common.logger import (
PROVIDER_DEBUG_LOGGER_NAMES,
suppress_provider_debug_loggers,
)


@pytest.fixture
def restore_provider_loggers():
tracked = {}
for name in PROVIDER_DEBUG_LOGGER_NAMES:
logger = logging.getLogger(name)
tracked[name] = {
"level": logger.level,
"propagate": logger.propagate,
"handlers": list(logger.handlers),
"handler_levels": {handler: handler.level for handler in logger.handlers},
}

yield

for name, state in tracked.items():
logger = logging.getLogger(name)
for handler in list(logger.handlers):
if handler not in state["handlers"]:
logger.removeHandler(handler)
handler.close()
logger.handlers[:] = state["handlers"]
for handler, level in state["handler_levels"].items():
handler.setLevel(level)
logger.setLevel(state["level"])
logger.propagate = state["propagate"]


def test_suppresses_high_risk_provider_loggers_by_default(
monkeypatch, restore_provider_loggers
):
monkeypatch.delenv("BL_ALLOW_PROVIDER_DEBUG_LOGS", raising=False)

for logger_name in PROVIDER_DEBUG_LOGGER_NAMES:
logger = logging.getLogger(logger_name)
handler = logging.StreamHandler(io.StringIO())
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
handler.setLevel(logging.DEBUG)

suppress_provider_debug_loggers()

for logger_name in PROVIDER_DEBUG_LOGGER_NAMES:
logger = logging.getLogger(logger_name)
assert logger.getEffectiveLevel() >= logging.WARNING
assert all(handler.level >= logging.WARNING for handler in logger.handlers)


def test_provider_debug_payload_patterns_do_not_emit_by_default(
monkeypatch, restore_provider_loggers
):
monkeypatch.delenv("BL_ALLOW_PROVIDER_DEBUG_LOGS", raising=False)
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setLevel(logging.DEBUG)
logger = logging.getLogger("LiteLLM")
logger.addHandler(handler)
logger.propagate = False
logger.setLevel(logging.DEBUG)

suppress_provider_debug_loggers()
logger.debug(
"X-Blaxel-Authorization: Bearer "
"eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjMifQ.signature "
"Request options LLM Request RAW RESPONSE"
)

assert stream.getvalue() == ""


def test_provider_debug_opt_in_preserves_debug_behavior(
monkeypatch, restore_provider_loggers
):
monkeypatch.setenv("BL_ALLOW_PROVIDER_DEBUG_LOGS", "true")
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setLevel(logging.DEBUG)
logger = logging.getLogger("LiteLLM")
logger.addHandler(handler)
logger.propagate = False
logger.setLevel(logging.DEBUG)

suppress_provider_debug_loggers()
logger.debug("visible provider debug")

assert logger.getEffectiveLevel() == logging.DEBUG
assert handler.level == logging.DEBUG
assert "visible provider debug" in stream.getvalue()
69 changes: 69 additions & 0 deletions tests/integration/googleadk/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Google ADK Model Integration Tests."""

import logging
import re

pytest_plugins = []
import pytest # noqa: E402

Expand All @@ -16,6 +19,21 @@
"sandbox-openai",
]

BLOCKED_PROVIDER_LOG_PATTERNS = {
"x_blaxel_authorization": re.compile(r"x-blaxel-authorization", re.IGNORECASE),
"bearer": re.compile(r"\bbearer\b", re.IGNORECASE),
"jwt_like": re.compile(r"eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\."),
"request_options": re.compile(r"Request options"),
"llm_request": re.compile(r"LLM Request"),
"raw_response": re.compile(r"RAW RESPONSE"),
}


def blocked_provider_log_counts(text: str) -> dict[str, int]:
return {
name: len(pattern.findall(text)) for name, pattern in BLOCKED_PROVIDER_LOG_PATTERNS.items()
}


@pytest.mark.asyncio(loop_scope="class")
class TestBlModel:
Expand Down Expand Up @@ -52,3 +70,54 @@ async def test_can_call_model(self, model_name: str):
collected_text += part.text

assert len(collected_text) > 0

@pytest.mark.parametrize("model_name", TEST_MODELS)
async def test_provider_debug_logs_are_suppressed(
self,
model_name: str,
monkeypatch: pytest.MonkeyPatch,
tmp_path,
capsys: pytest.CaptureFixture[str],
):
"""Force root DEBUG and verify provider internals stay out of captured logs."""
monkeypatch.delenv("BL_ALLOW_PROVIDER_DEBUG_LOGS", raising=False)
root_logger = logging.getLogger()
original_root_level = root_logger.level
log_path = tmp_path / "provider-debug.log"
capture_handler = logging.FileHandler(log_path)
capture_handler.setLevel(logging.DEBUG)
root_logger.addHandler(capture_handler)
root_logger.setLevel(logging.DEBUG)

try:
model = await bl_model(model_name)

request = LlmRequest(
contents=[
types.Content(
parts=[types.Part(text="Say hello in one word")],
role="user",
)
],
config=types.GenerateContentConfig(),
)

collected_text = ""
async for response in model.generate_content_async(request):
assert response is not None
if response.content and response.content.parts:
for part in response.content.parts:
if part.text:
collected_text += part.text
finally:
root_logger.removeHandler(capture_handler)
capture_handler.close()
root_logger.setLevel(original_root_level)

captured = capsys.readouterr()
log_text = log_path.read_text()
log_path.unlink(missing_ok=True)
counts = blocked_provider_log_counts(f"{log_text}\n{captured.out}\n{captured.err}")

assert len(collected_text) > 0
assert not any(counts.values()), f"blocked provider debug logs emitted: {counts}"
Loading