Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ ignore = [
"T201",
"PLW0603",
"PLW2901",
"PLR0912",
"PLR0915",
"F841",
"SIM105",
Expand Down
168 changes: 44 additions & 124 deletions src/dualentry_cli/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Authentication for DualEntry CLI - OAuth flow via MCP endpoints and credential storage."""
"""Authentication for DualEntry CLI - OAuth 2.1 with PKCE via public API endpoints."""

from __future__ import annotations

Expand All @@ -8,37 +8,16 @@
import secrets
import socket
import webbrowser
from enum import StrEnum
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from urllib.parse import parse_qs, urlencode, urlparse
from urllib.parse import parse_qs, urlparse

import httpx
import keyring
import typer


class CodeChallengeMethod(StrEnum):
S256 = "S256"


class GrantType(StrEnum):
AUTHORIZATION_CODE = "authorization_code"
REFRESH_TOKEN = "refresh_token" # noqa: S105


class ResponseType(StrEnum):
CODE = "code"


class TokenEndpointAuthMethod(StrEnum):
NONE = "none"


_SERVICE_NAME = "dualentry-cli"
_KEY_NAME_ACCESS = "access_token"
_KEY_NAME_REFRESH = "refresh_token"
_KEY_NAME_API_KEY = "api_key" # legacy, still checked for migration
_KEY_NAME_API_KEY = "api_key"

_TOKEN_FILE = Path.home() / ".dualentry" / "tokens.json"

Expand All @@ -51,137 +30,85 @@ def _generate_pkce_pair() -> tuple[str, str]:
return verifier, challenge


# ── Token storage ────────────────────────────────────────────────────
# -- Credential storage ------------------------------------------------


def store_tokens(access_token: str, refresh_token: str) -> None:
"""Store OAuth tokens. Uses keyring with file fallback."""
def store_api_key(api_key: str) -> None:
"""Store API key. Uses keyring with file fallback."""
try:
keyring.set_password(_SERVICE_NAME, _KEY_NAME_ACCESS, access_token)
keyring.set_password(_SERVICE_NAME, _KEY_NAME_REFRESH, refresh_token)
keyring.set_password(_SERVICE_NAME, _KEY_NAME_API_KEY, api_key)
except Exception:
# Fallback to file storage (e.g. CI, headless)
_TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True)
_TOKEN_FILE.write_text(json.dumps({"access_token": access_token, "refresh_token": refresh_token}))
_TOKEN_FILE.write_text(json.dumps({"api_key": api_key}))
_TOKEN_FILE.chmod(0o600)


def load_tokens() -> tuple[str | None, str | None]:
"""Load OAuth tokens. Returns (access_token, refresh_token)."""
def load_api_key() -> str | None:
"""Load stored API key."""
try:
access = keyring.get_password(_SERVICE_NAME, _KEY_NAME_ACCESS)
refresh = keyring.get_password(_SERVICE_NAME, _KEY_NAME_REFRESH)
if access and refresh:
return access, refresh
key = keyring.get_password(_SERVICE_NAME, _KEY_NAME_API_KEY)
if key:
return key
except Exception:
pass
# File fallback
if _TOKEN_FILE.exists():
try:
data = json.loads(_TOKEN_FILE.read_text())
return data.get("access_token"), data.get("refresh_token")
return data.get("api_key")
except (json.JSONDecodeError, OSError):
pass
return None, None


def load_api_key() -> str | None:
"""Load legacy API key (for X_API_KEY env var compat check)."""
try:
return keyring.get_password(_SERVICE_NAME, _KEY_NAME_API_KEY)
except Exception:
return None
return None


def clear_credentials() -> None:
"""Clear all stored credentials."""
for key in (_KEY_NAME_ACCESS, _KEY_NAME_REFRESH, _KEY_NAME_API_KEY):
try:
keyring.delete_password(_SERVICE_NAME, key)
except Exception:
pass
try:
keyring.delete_password(_SERVICE_NAME, _KEY_NAME_API_KEY)
except Exception:
pass
if _TOKEN_FILE.exists():
try:
_TOKEN_FILE.unlink()
except OSError:
pass


# legacy alias
clear_api_key = clear_credentials
# -- OAuth endpoints ---------------------------------------------------


# ── MCP OAuth client registration ───────────────────────────────────


def _register_client(mcp_url: str, redirect_uri: str) -> dict:
"""Register as an OAuth client with the MCP server (dynamic client registration)."""
def _authorize(api_url: str, redirect_uri: str, code_challenge: str, state: str) -> str:
"""POST /public/v2/oauth/authorize/ — returns the WorkOS authorization URL."""
response = httpx.post(
f"{mcp_url}/register",
f"{api_url.rstrip('/')}/public/v2/oauth/authorize/",
json={
"client_name": "DualEntry CLI",
"redirect_uris": [redirect_uri],
"grant_types": [GrantType.AUTHORIZATION_CODE, GrantType.REFRESH_TOKEN],
"response_types": [ResponseType.CODE],
"token_endpoint_auth_method": TokenEndpointAuthMethod.NONE,
"redirect_uri": redirect_uri,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
"state": state,
},
timeout=30.0,
)
response.raise_for_status()
return response.json()

return response.json()["authorization_url"]

# ── OAuth flow ───────────────────────────────────────────────────────


def _start_authorize(mcp_url: str, client_id: str, redirect_uri: str, code_challenge: str, state: str) -> str:
"""Build the authorization URL and return it (the MCP /authorize endpoint redirects to WorkOS)."""
params = {
"response_type": ResponseType.CODE,
"client_id": client_id,
"redirect_uri": redirect_uri,
"code_challenge": code_challenge,
"code_challenge_method": CodeChallengeMethod.S256,
"state": state,
}
return f"{mcp_url}/authorize?{urlencode(params)}"


def _exchange_token(mcp_url: str, client_id: str, code: str, code_verifier: str, redirect_uri: str) -> dict:
"""Exchange authorization code for access/refresh tokens at MCP /token endpoint."""
def _exchange_code(api_url: str, code: str, code_verifier: str, redirect_uri: str) -> dict:
"""POST /public/v2/oauth/token/ — exchange auth code for API key."""
response = httpx.post(
f"{mcp_url}/token",
data={
"grant_type": GrantType.AUTHORIZATION_CODE,
"client_id": client_id,
f"{api_url.rstrip('/')}/public/v2/oauth/token/",
json={
"grant_type": "authorization_code",
"code": code,
"code_verifier": code_verifier,
"redirect_uri": redirect_uri,
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
timeout=30.0,
)
response.raise_for_status()
return response.json()


def refresh_access_token(mcp_url: str, client_id: str, refresh_token: str) -> dict:
"""Use refresh token to get a new access/refresh token pair."""
response = httpx.post(
f"{mcp_url}/token",
data={
"grant_type": GrantType.REFRESH_TOKEN,
"client_id": client_id,
"refresh_token": refresh_token,
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
timeout=30.0,
)
response.raise_for_status()
return response.json()


# ── Local callback server ────────────────────────────────────────────
# -- Local callback server ---------------------------------------------


def _find_free_port() -> int:
Expand Down Expand Up @@ -212,28 +139,22 @@ def log_message(self, format, *args):
pass


# ── Main login flow ──────────────────────────────────────────────────
# -- Main login flow ---------------------------------------------------


def run_login_flow(api_url: str) -> dict:
"""
Run the full OAuth login flow using MCP endpoints.
Run the full OAuth login flow via /public/v2/oauth/ endpoints.

Returns dict with access_token, refresh_token, and token metadata.
Returns dict with api_key, organization_id, user_email.
"""
mcp_url = f"{api_url.rstrip('/')}/mcp"

port = _find_free_port()
redirect_uri = f"http://localhost:{port}/callback"
verifier, challenge = _generate_pkce_pair()
state = secrets.token_urlsafe(16)

# Register as OAuth client
client_info = _register_client(mcp_url, redirect_uri)
client_id = client_info["client_id"]

# Build authorize URL
auth_url = _start_authorize(mcp_url, client_id, redirect_uri, challenge, state)
# Get authorization URL from backend
auth_url = _authorize(api_url, redirect_uri, challenge, state)

# Start local server and open browser
_CallbackHandler.code = None
Expand All @@ -254,12 +175,11 @@ def run_login_flow(api_url: str) -> dict:
typer.echo("State mismatch - possible CSRF attack.")
raise typer.Exit(code=1)

# Exchange code for tokens
token_response = _exchange_token(mcp_url, client_id, _CallbackHandler.code, verifier, redirect_uri)
# Exchange code for API key
token_response = _exchange_code(api_url, _CallbackHandler.code, verifier, redirect_uri)

return {
"access_token": token_response["access_token"],
"refresh_token": token_response.get("refresh_token", ""),
"expires_in": token_response.get("expires_in"),
"client_id": client_id,
"api_key": token_response["api_key"],
"organization_id": token_response["organization_id"],
"user_email": token_response["user_email"],
}
46 changes: 7 additions & 39 deletions src/dualentry_cli/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,14 @@ def __init__(self, status_code: int, detail: str):


class DualEntryClient:
def __init__(self, api_url: str, *, access_token: str | None = None, refresh_token: str | None = None, client_id: str | None = None, api_key: str | None = None):
def __init__(self, api_url: str, *, api_key: str):
self._api_url = api_url.rstrip("/")
self._base_url = f"{self._api_url}/public/v2"
self._access_token = access_token
self._refresh_token = refresh_token
self._client_id = client_id
self._api_key = api_key

headers = self._build_headers()
self._client = httpx.Client(base_url=self._base_url, headers=headers, timeout=30.0)

def _build_headers(self) -> dict[str, str]:
if self._api_key:
return {"X-API-KEY": self._api_key}
if self._access_token:
return {"Authorization": f"Bearer {self._access_token}"}
return {}
self._client = httpx.Client(
base_url=self._base_url,
headers={"X-API-KEY": api_key},
timeout=30.0,
)

@classmethod
def from_env(cls, api_url: str) -> DualEntryClient:
Expand All @@ -42,27 +33,6 @@ def from_env(cls, api_url: str) -> DualEntryClient:
raise ValueError(msg)
return cls(api_url=api_url, api_key=api_key)

def _try_refresh(self) -> bool:
"""Attempt to refresh the access token. Returns True if successful."""
if not self._refresh_token or not self._client_id:
return False
try:
from dualentry_cli.auth import refresh_access_token, store_tokens

mcp_url = f"{self._api_url}/mcp"
token_response = refresh_access_token(mcp_url, self._client_id, self._refresh_token)
self._access_token = token_response["access_token"]
self._refresh_token = token_response.get("refresh_token", self._refresh_token)
store_tokens(self._access_token, self._refresh_token)
self._client.headers.update({"Authorization": f"Bearer {self._access_token}"})
except Exception as exc:
import sys

print(f"Token refresh failed: {exc}. Re-login with: dualentry auth login", file=sys.stderr)
return False
else:
return True

def _handle_response(self, response: httpx.Response) -> dict:
if response.status_code >= 400:
try:
Expand All @@ -74,8 +44,6 @@ def _handle_response(self, response: httpx.Response) -> dict:

def _request(self, method: str, path: str, **kwargs) -> dict:
response = self._client.request(method, path, **kwargs)
if response.status_code in (401, 403) and self._access_token and self._try_refresh():
response = self._client.request(method, path, **kwargs)
return self._handle_response(response)

def get(self, path: str, params: dict[str, Any] | None = None) -> dict:
Expand All @@ -87,7 +55,7 @@ def paginate(self, path: str, params: dict[str, Any] | None = None, page_size: i
params["limit"] = page_size
params["offset"] = 0
all_items = []
max_pages = 1000 # safety guard against infinite loops
max_pages = 1000

for _ in range(max_pages):
data = self.get(path, params=params)
Expand Down
13 changes: 11 additions & 2 deletions src/dualentry_cli/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import json
import re
from pathlib import Path

import typer
Expand All @@ -21,6 +22,14 @@
EndDate = typer.Option(None, "--end-date", help="Filter to date (YYYY-MM-DD)")
Format = typer.Option("human", "--format", "-o", help="Output format: human or json")

_PREFIX_RE = re.compile(r"^[A-Z]{1,3}-(\d+)$")


def _strip_prefix(value: str) -> str:
"""Strip record prefix if present: 'IN-136159' -> '136159'."""
m = _PREFIX_RE.match(value)
return m.group(1) if m else value


def _build_filter_params(
search: str | None = None,
Expand Down Expand Up @@ -89,13 +98,13 @@ def list_cmd(

@app.command("get")
def get_cmd_with_number(
number: int = typer.Argument(help="Record number"),
number: str = typer.Argument(help="Record number (the Num column, not the # ID)"),
output: str = Format,
):
from dualentry_cli.main import get_client

client = get_client()
data = client.get(f"/{path}/{number}/")
data = client.get(f"/{path}/{_strip_prefix(number)}/")
format_output(data, resource=resource, fmt=output)

get_cmd_with_number.__doc__ = f"Get a {resource} by number."
Expand Down
Loading
Loading