diff --git a/src/blaxel/core/common/logger.py b/src/blaxel/core/common/logger.py index a21a3064..380f565c 100644 --- a/src/blaxel/core/common/logger.py +++ b/src/blaxel/core/common/logger.py @@ -7,6 +7,41 @@ import logging import os +PROVIDER_DEBUG_LOG_OPT_IN_ENV = "BL_ALLOW_PROVIDER_DEBUG_LOGS" +PROVIDER_DEBUG_LOGGER_NAMES = ( + "google_adk", + "google_adk.google.adk.models.lite_llm", + "LiteLLM", + "LiteLLM Proxy", + "LiteLLM Router", + "openai", + "openai._base_client", + "httpx", + "httpcore", +) + + +def _is_truthy_env(value: str | None) -> bool: + return value is not None and value.strip().lower() in {"1", "true", "yes", "on"} + + +def provider_debug_logs_allowed() -> bool: + return _is_truthy_env(os.environ.get(PROVIDER_DEBUG_LOG_OPT_IN_ENV)) + + +def suppress_provider_debug_loggers() -> None: + if provider_debug_logs_allowed(): + return + + for logger_name in PROVIDER_DEBUG_LOGGER_NAMES: + provider_logger = logging.getLogger(logger_name) + if provider_logger.level == logging.NOTSET or provider_logger.level < logging.WARNING: + provider_logger.setLevel(logging.WARNING) + for handler in provider_logger.handlers: + if handler.level == logging.NOTSET or handler.level < logging.WARNING: + handler.setLevel(logging.WARNING) + + try: from opentelemetry.trace import get_current_span @@ -118,9 +153,11 @@ def init_logger(log_level: str): Parameters: log_level (str): The logging level to set (e.g., "DEBUG", "INFO"). """ + suppress_provider_debug_loggers() # Disable urllib3 logging logging.getLogger("urllib3").setLevel(logging.CRITICAL) - logging.getLogger("httpx").setLevel(logging.CRITICAL) + if not provider_debug_logs_allowed(): + logging.getLogger("httpx").setLevel(logging.CRITICAL) handler = logging.StreamHandler() logger_type = os.environ.get("BL_LOGGER", "http") @@ -129,3 +166,4 @@ def init_logger(log_level: str): else: handler.setFormatter(ColoredFormatter("%(levelname)s %(name)s - %(message)s")) logging.basicConfig(level=log_level, handlers=[handler]) + suppress_provider_debug_loggers() diff --git a/src/blaxel/googleadk/model.py b/src/blaxel/googleadk/model.py index 7f742abc..0f3f803d 100644 --- a/src/blaxel/googleadk/model.py +++ b/src/blaxel/googleadk/model.py @@ -7,6 +7,7 @@ from blaxel.core import bl_model as bl_model_core from blaxel.core import settings +from blaxel.core.common.logger import suppress_provider_debug_loggers logger = getLogger(__name__) @@ -26,6 +27,7 @@ async def acompletion(self, model, messages, tools, **kwargs): Returns: The model response as a message. """ + suppress_provider_debug_loggers() auth_headers = settings.auth.get_headers() extra = dict(auth_headers) # When auth uses X-Blaxel-Authorization (API keys), override the @@ -35,12 +37,14 @@ async def acompletion(self, model, messages, tools, **kwargs): if "Authorization" not in auth_headers: extra["Authorization"] = "" kwargs["extra_headers"] = extra - return await super().acompletion( + response = await super().acompletion( model=model, messages=messages, tools=tools, **kwargs, ) + suppress_provider_debug_loggers() + return response def completion(self, model, messages, tools, stream=False, **kwargs): """Synchronously calls completion. This is used for streaming only. @@ -55,6 +59,7 @@ def completion(self, model, messages, tools, stream=False, **kwargs): Returns: The response from the model. """ + suppress_provider_debug_loggers() auth_headers = settings.auth.get_headers() extra = dict(auth_headers) # When auth uses X-Blaxel-Authorization (API keys), override the @@ -64,16 +69,19 @@ def completion(self, model, messages, tools, stream=False, **kwargs): if "Authorization" not in auth_headers: extra["Authorization"] = "" kwargs["extra_headers"] = extra - return super().completion( + response = super().completion( model=model, messages=messages, tools=tools, stream=stream, **kwargs, ) + suppress_provider_debug_loggers() + return response async def get_google_adk_model(url: str, type: str, model: str, **kwargs): + suppress_provider_debug_loggers() llm_client = AuthenticatedLiteLLMClient() if type == "mistral": return LiteLlm( diff --git a/tests/core/test_logger.py b/tests/core/test_logger.py new file mode 100644 index 00000000..f9b74246 --- /dev/null +++ b/tests/core/test_logger.py @@ -0,0 +1,98 @@ +import io +import logging + +import pytest + +from blaxel.core.common.logger import ( + PROVIDER_DEBUG_LOGGER_NAMES, + suppress_provider_debug_loggers, +) + + +@pytest.fixture +def restore_provider_loggers(): + tracked = {} + for name in PROVIDER_DEBUG_LOGGER_NAMES: + logger = logging.getLogger(name) + tracked[name] = { + "level": logger.level, + "propagate": logger.propagate, + "handlers": list(logger.handlers), + "handler_levels": {handler: handler.level for handler in logger.handlers}, + } + + yield + + for name, state in tracked.items(): + logger = logging.getLogger(name) + for handler in list(logger.handlers): + if handler not in state["handlers"]: + logger.removeHandler(handler) + handler.close() + logger.handlers[:] = state["handlers"] + for handler, level in state["handler_levels"].items(): + handler.setLevel(level) + logger.setLevel(state["level"]) + logger.propagate = state["propagate"] + + +def test_suppresses_high_risk_provider_loggers_by_default( + monkeypatch, restore_provider_loggers +): + monkeypatch.delenv("BL_ALLOW_PROVIDER_DEBUG_LOGS", raising=False) + + for logger_name in PROVIDER_DEBUG_LOGGER_NAMES: + logger = logging.getLogger(logger_name) + handler = logging.StreamHandler(io.StringIO()) + logger.addHandler(handler) + logger.setLevel(logging.DEBUG) + handler.setLevel(logging.DEBUG) + + suppress_provider_debug_loggers() + + for logger_name in PROVIDER_DEBUG_LOGGER_NAMES: + logger = logging.getLogger(logger_name) + assert logger.getEffectiveLevel() >= logging.WARNING + assert all(handler.level >= logging.WARNING for handler in logger.handlers) + + +def test_provider_debug_payload_patterns_do_not_emit_by_default( + monkeypatch, restore_provider_loggers +): + monkeypatch.delenv("BL_ALLOW_PROVIDER_DEBUG_LOGS", raising=False) + stream = io.StringIO() + handler = logging.StreamHandler(stream) + handler.setLevel(logging.DEBUG) + logger = logging.getLogger("LiteLLM") + logger.addHandler(handler) + logger.propagate = False + logger.setLevel(logging.DEBUG) + + suppress_provider_debug_loggers() + logger.debug( + "X-Blaxel-Authorization: Bearer " + "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjMifQ.signature " + "Request options LLM Request RAW RESPONSE" + ) + + assert stream.getvalue() == "" + + +def test_provider_debug_opt_in_preserves_debug_behavior( + monkeypatch, restore_provider_loggers +): + monkeypatch.setenv("BL_ALLOW_PROVIDER_DEBUG_LOGS", "true") + stream = io.StringIO() + handler = logging.StreamHandler(stream) + handler.setLevel(logging.DEBUG) + logger = logging.getLogger("LiteLLM") + logger.addHandler(handler) + logger.propagate = False + logger.setLevel(logging.DEBUG) + + suppress_provider_debug_loggers() + logger.debug("visible provider debug") + + assert logger.getEffectiveLevel() == logging.DEBUG + assert handler.level == logging.DEBUG + assert "visible provider debug" in stream.getvalue() diff --git a/tests/integration/googleadk/test_model.py b/tests/integration/googleadk/test_model.py index 80c4b3c8..dfe13400 100644 --- a/tests/integration/googleadk/test_model.py +++ b/tests/integration/googleadk/test_model.py @@ -1,5 +1,8 @@ """Google ADK Model Integration Tests.""" +import logging +import re + pytest_plugins = [] import pytest # noqa: E402 @@ -16,6 +19,21 @@ "sandbox-openai", ] +BLOCKED_PROVIDER_LOG_PATTERNS = { + "x_blaxel_authorization": re.compile(r"x-blaxel-authorization", re.IGNORECASE), + "bearer": re.compile(r"\bbearer\b", re.IGNORECASE), + "jwt_like": re.compile(r"eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\."), + "request_options": re.compile(r"Request options"), + "llm_request": re.compile(r"LLM Request"), + "raw_response": re.compile(r"RAW RESPONSE"), +} + + +def blocked_provider_log_counts(text: str) -> dict[str, int]: + return { + name: len(pattern.findall(text)) for name, pattern in BLOCKED_PROVIDER_LOG_PATTERNS.items() + } + @pytest.mark.asyncio(loop_scope="class") class TestBlModel: @@ -52,3 +70,54 @@ async def test_can_call_model(self, model_name: str): collected_text += part.text assert len(collected_text) > 0 + + @pytest.mark.parametrize("model_name", TEST_MODELS) + async def test_provider_debug_logs_are_suppressed( + self, + model_name: str, + monkeypatch: pytest.MonkeyPatch, + tmp_path, + capsys: pytest.CaptureFixture[str], + ): + """Force root DEBUG and verify provider internals stay out of captured logs.""" + monkeypatch.delenv("BL_ALLOW_PROVIDER_DEBUG_LOGS", raising=False) + root_logger = logging.getLogger() + original_root_level = root_logger.level + log_path = tmp_path / "provider-debug.log" + capture_handler = logging.FileHandler(log_path) + capture_handler.setLevel(logging.DEBUG) + root_logger.addHandler(capture_handler) + root_logger.setLevel(logging.DEBUG) + + try: + model = await bl_model(model_name) + + request = LlmRequest( + contents=[ + types.Content( + parts=[types.Part(text="Say hello in one word")], + role="user", + ) + ], + config=types.GenerateContentConfig(), + ) + + collected_text = "" + async for response in model.generate_content_async(request): + assert response is not None + if response.content and response.content.parts: + for part in response.content.parts: + if part.text: + collected_text += part.text + finally: + root_logger.removeHandler(capture_handler) + capture_handler.close() + root_logger.setLevel(original_root_level) + + captured = capsys.readouterr() + log_text = log_path.read_text() + log_path.unlink(missing_ok=True) + counts = blocked_provider_log_counts(f"{log_text}\n{captured.out}\n{captured.err}") + + assert len(collected_text) > 0 + assert not any(counts.values()), f"blocked provider debug logs emitted: {counts}"