diff --git a/pykalshi/__init__.py b/pykalshi/__init__.py index 85b294f..1627337 100644 --- a/pykalshi/__init__.py +++ b/pykalshi/__init__.py @@ -4,7 +4,7 @@ A clean, modular interface for the Kalshi trading API. """ -__version__ = "0.4.0" +__version__ = "0.5.0" import logging diff --git a/pykalshi/_async/__init__.py b/pykalshi/_async/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pykalshi/_async/api_keys.py b/pykalshi/_async/api_keys.py new file mode 100644 index 0000000..4766967 --- /dev/null +++ b/pykalshi/_async/api_keys.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..models import APIKey, GeneratedAPIKey, APILimits + +if TYPE_CHECKING: + from .client import AsyncKalshiClient + + +class AsyncAPIKeys: + """API key management and account limits.""" + + def __init__(self, client: AsyncKalshiClient) -> None: + self._client = client + + async def list(self) -> list[APIKey]: + """List all API keys for this account.""" + data = await self._client.get("/api_keys") + return [APIKey.model_validate(k) for k in data.get("api_keys", [])] + + async def create(self, public_key: str, name: str | None = None) -> str: + """Create an API key with a provided RSA public key. + + Args: + public_key: PEM-encoded RSA public key. + name: Optional name for the key. + + Returns: + The API key ID string. + """ + body: dict = {"public_key": public_key} + if name: + body["name"] = name + data = await self._client.post("/api_keys", body) + return data["api_key_id"] + + async def generate(self, name: str | None = None) -> GeneratedAPIKey: + """Generate a new API key pair (Kalshi creates both keys). + + Returns a GeneratedAPIKey with the private_key field populated. + The private key is only returned ONCE - store it securely. + + Args: + name: Optional name for the key. + """ + body: dict = {} + if name: + body["name"] = name + data = await self._client.post("/api_keys/generate", body) + return GeneratedAPIKey.model_validate(data) + + async def delete(self, key_id: str) -> None: + """Delete an API key. + + Args: + key_id: The API key ID to delete. + """ + await self._client.delete(f"/api_keys/{key_id}") + + async def get_limits(self) -> APILimits: + """Get API rate limits for this account.""" + data = await self._client.get("/account/limits") + return APILimits.model_validate(data) diff --git a/pykalshi/_async/client.py b/pykalshi/_async/client.py new file mode 100644 index 0000000..7ffc0d2 --- /dev/null +++ b/pykalshi/_async/client.py @@ -0,0 +1,405 @@ +"""Kalshi API Client.""" + +from __future__ import annotations + +import asyncio +import json +import logging +from functools import cached_property +from typing import Any, TYPE_CHECKING +from urllib.parse import urlencode + +import httpx + +from .._base import _BaseKalshiClient, _RETRYABLE_STATUS_CODES +from .events import AsyncEvent +from .markets import AsyncMarket, AsyncSeries +from .mve import AsyncMveCollection +from ..models import MarketModel, EventModel, SeriesModel, TradeModel, CandlestickResponse, MveCollectionModel +from ..dataframe import DataFrameList +from .portfolio import AsyncPortfolio +from ..enums import MarketStatus, CandlestickPeriod +from .exchange import AsyncExchange +from .api_keys import AsyncAPIKeys +from .communications import AsyncCommunications +from ..exceptions import RateLimitError +from .._utils import normalize_ticker, normalize_tickers + +if TYPE_CHECKING: + from ..afeed import AsyncFeed + from ..rate_limiter import AsyncRateLimiterProtocol + +logger = logging.getLogger(__name__) + + +class AsyncKalshiClient(_BaseKalshiClient): + """Authenticated client for the Kalshi Trading API. + + Usage: + async with AsyncKalshiClient.from_env() as client: + market = await client.get_market("TICKER") + balance = await client.portfolio.get_balance() + """ + + def __init__( + self, + api_key_id: str | None = None, + private_key_path: str | None = None, + api_base: str | None = None, + demo: bool = False, + timeout: float = 10.0, + max_retries: int = 3, + rate_limiter: AsyncRateLimiterProtocol | None = None, + ) -> None: + super().__init__( + api_key_id=api_key_id, + private_key_path=private_key_path, + api_base=api_base, + demo=demo, + timeout=timeout, + max_retries=max_retries, + rate_limiter=rate_limiter, + ) + self._session = httpx.AsyncClient() + + async def aclose(self) -> None: + """Close the underlying HTTP connection pool.""" + await self._session.aclose() + + async def __aenter__(self) -> AsyncKalshiClient: + return self + + async def __aexit__(self, *args: Any) -> None: + await self.aclose() + + # --- HTTP methods --- + + async def _request(self, method: str, endpoint: str, **kwargs: Any) -> httpx.Response: + """Execute async HTTP request with retry on transient failures.""" + url = f"{self.api_base}{endpoint}" + + for attempt in range(self.max_retries + 1): + if self.rate_limiter is not None: + wait_time = await self.rate_limiter.acquire() + if wait_time > 0: + logger.debug("Rate limiter waited %.3fs", wait_time) + + headers = self._get_headers(method, endpoint) + request_kwargs: dict[str, Any] = {"headers": headers, "timeout": self.timeout} + if "data" in kwargs: + request_kwargs["content"] = kwargs["data"] + try: + response = await self._session.request(method, url, **request_kwargs) + except httpx.TimeoutException as e: + if attempt == self.max_retries: + raise + wait = self._compute_backoff(attempt, None) + logger.warning( + "%s %s failed (%s), retry %d/%d in %.1fs", + method, endpoint, type(e).__name__, + attempt + 1, self.max_retries, wait, + ) + await asyncio.sleep(wait) + continue + except httpx.ConnectError as e: + if attempt == self.max_retries: + raise + wait = self._compute_backoff(attempt, None) + logger.warning( + "%s %s failed (%s), retry %d/%d in %.1fs", + method, endpoint, type(e).__name__, + attempt + 1, self.max_retries, wait, + ) + await asyncio.sleep(wait) + continue + + self._update_rate_limiter(response) + + if response.status_code not in _RETRYABLE_STATUS_CODES: + return response + if attempt == self.max_retries: + if response.status_code == 429: + raise RateLimitError( + 429, "Rate limit exceeded after retries", + method=method, endpoint=endpoint, + ) + return response + + wait = self._compute_backoff(attempt, response.headers.get("Retry-After")) + logger.warning( + "%s %s returned %d, retry %d/%d in %.1fs", + method, endpoint, response.status_code, + attempt + 1, self.max_retries, wait, + ) + await asyncio.sleep(wait) + + return response # unreachable, satisfies type checker + + async def get(self, endpoint: str) -> dict[str, Any]: + """Make authenticated GET request.""" + logger.debug("GET %s", endpoint) + response = await self._request("GET", endpoint) + return self._handle_response(response, method="GET", endpoint=endpoint) + + async def paginated_get( + self, + path: str, + response_key: str, + params: dict[str, Any], + fetch_all: bool = False, + ) -> list[dict]: + """Fetch items with automatic cursor-based pagination.""" + params = dict(params) + all_items: list[dict] = [] + while True: + filtered = {k: v for k, v in params.items() if v is not None} + endpoint = f"{path}?{urlencode(filtered)}" if filtered else path + response = await self.get(endpoint) + all_items.extend(response.get(response_key, [])) + cursor = response.get("cursor", "") + if not fetch_all or not cursor: + break + params["cursor"] = cursor + return all_items + + async def post(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: + """Make authenticated POST request.""" + logger.debug("POST %s", endpoint) + body = json.dumps(data, separators=(",", ":")) + response = await self._request("POST", endpoint, data=body) + return self._handle_response( + response, method="POST", endpoint=endpoint, request_body=data + ) + + async def put(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: + """Make authenticated PUT request.""" + logger.debug("PUT %s", endpoint) + body = json.dumps(data, separators=(",", ":")) + response = await self._request("PUT", endpoint, data=body) + return self._handle_response( + response, method="PUT", endpoint=endpoint, request_body=data + ) + + async def delete(self, endpoint: str, body: dict | None = None) -> dict[str, Any]: + """Make authenticated DELETE request.""" + logger.debug("DELETE %s", endpoint) + if body: + data = json.dumps(body, separators=(",", ":")) + response = await self._request("DELETE", endpoint, data=data) + else: + response = await self._request("DELETE", endpoint) + return self._handle_response(response, method="DELETE", endpoint=endpoint) + + # --- Domain accessors --- + + @cached_property + def portfolio(self) -> AsyncPortfolio: + return AsyncPortfolio(self) + + @cached_property + def exchange(self) -> AsyncExchange: + return AsyncExchange(self) + + @cached_property + def api_keys(self) -> AsyncAPIKeys: + return AsyncAPIKeys(self) + + @cached_property + def communications(self) -> AsyncCommunications: + return AsyncCommunications(self) + + def feed(self) -> AsyncFeed: + """Create a new async real-time data feed.""" + from ..afeed import AsyncFeed + return AsyncFeed(self) + + # --- Domain query methods --- + + async def get_market(self, ticker: str) -> AsyncMarket: + response = await self.get(f"/markets/{ticker.upper()}") + model = MarketModel.model_validate(response["market"]) + return AsyncMarket(self, model) + + async def get_markets( + self, + *, + status: MarketStatus | None = None, + mve_filter: str | None = None, + tickers: list[str] | None = None, + series_ticker: str | None = None, + event_ticker: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[AsyncMarket]: + params = { + "status": status.value if status is not None else None, + "mve_filter": mve_filter, + "tickers": ",".join(normalize_tickers(tickers)) if tickers else None, + "series_ticker": normalize_ticker(series_ticker), + "event_ticker": normalize_ticker(event_ticker), + "limit": limit, + "cursor": cursor, + **extra_params, + } + data = await self.paginated_get("/markets", "markets", params, fetch_all) + return DataFrameList(AsyncMarket(self, MarketModel.model_validate(m)) for m in data) + + async def get_event( + self, + event_ticker: str, + *, + with_nested_markets: bool = False, + ) -> AsyncEvent: + params = {} + if with_nested_markets: + params["with_nested_markets"] = "true" + endpoint = f"/events/{event_ticker.upper()}" + if params: + endpoint += "?" + urlencode(params) + response = await self.get(endpoint) + model = EventModel.model_validate(response["event"]) + return AsyncEvent(self, model) + + async def get_events( + self, + *, + series_ticker: str | None = None, + status: MarketStatus | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[AsyncEvent]: + params = { + "limit": limit, + "series_ticker": normalize_ticker(series_ticker), + "status": status.value if status is not None else None, + "cursor": cursor, + **extra_params, + } + data = await self.paginated_get("/events", "events", params, fetch_all) + return DataFrameList(AsyncEvent(self, EventModel.model_validate(e)) for e in data) + + async def get_series( + self, + series_ticker: str, + *, + include_volume: bool = False, + ) -> AsyncSeries: + params = {} + if include_volume: + params["include_volume"] = "true" + endpoint = f"/series/{series_ticker.upper()}" + if params: + endpoint += "?" + urlencode(params) + response = await self.get(endpoint) + model = SeriesModel.model_validate(response["series"]) + return AsyncSeries(self, model) + + async def get_all_series( + self, + *, + category: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[AsyncSeries]: + params = {"limit": limit, "category": category, "cursor": cursor, **extra_params} + data = await self.paginated_get("/series", "series", params, fetch_all) + return DataFrameList(AsyncSeries(self, SeriesModel.model_validate(s)) for s in data) + + async def get_mve_collection(self, collection_ticker: str) -> AsyncMveCollection: + response = await self.get(f"/multivariate_event_collections/{collection_ticker}") + model = MveCollectionModel.model_validate(response.get("multivariate_contract", response)) + return AsyncMveCollection(self, model) + + async def get_mve_collections( + self, + *, + status: str | None = None, + associated_event_ticker: str | None = None, + series_ticker: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + ) -> DataFrameList[AsyncMveCollection]: + params = { + "limit": limit, + "status": status, + "associated_event_ticker": normalize_ticker(associated_event_ticker), + "series_ticker": normalize_ticker(series_ticker), + "cursor": cursor, + } + data = await self.paginated_get( + "/multivariate_event_collections", "multivariate_contracts", params, fetch_all + ) + return DataFrameList( + AsyncMveCollection(self, MveCollectionModel.model_validate(c)) for c in data + ) + + async def get_multivariate_events( + self, + *, + series_ticker: str | None = None, + collection_ticker: str | None = None, + with_nested_markets: bool = False, + limit: int = 200, + cursor: str | None = None, + fetch_all: bool = False, + ) -> DataFrameList[AsyncEvent]: + params: dict = {"limit": limit} + if series_ticker: + params["series_ticker"] = normalize_ticker(series_ticker) + if collection_ticker: + params["collection_ticker"] = collection_ticker + if with_nested_markets: + params["with_nested_markets"] = "true" + if cursor: + params["cursor"] = cursor + + data = await self.paginated_get("/events/multivariate", "events", params, fetch_all) + return DataFrameList(AsyncEvent(self, EventModel.model_validate(e)) for e in data) + + async def get_trades( + self, + *, + ticker: str | None = None, + min_ts: int | None = None, + max_ts: int | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[TradeModel]: + params = { + "limit": limit, + "ticker": normalize_ticker(ticker), + "min_ts": min_ts, + "max_ts": max_ts, + "cursor": cursor, + **extra_params, + } + data = await self.paginated_get("/markets/trades", "trades", params, fetch_all) + return DataFrameList(TradeModel.model_validate(t) for t in data) + + async def get_candlesticks_batch( + self, + tickers: list[str], + start_ts: int, + end_ts: int, + period: CandlestickPeriod = CandlestickPeriod.ONE_HOUR, + ) -> dict[str, CandlestickResponse]: + query = urlencode({ + "market_tickers": ",".join(normalize_tickers(tickers)), + "start_ts": start_ts, + "end_ts": end_ts, + "period_interval": period.value, + }) + response = await self.get(f"/markets/candlesticks?{query}") + return { + item["market_ticker"]: CandlestickResponse.model_validate(item) + for item in response.get("markets", []) + } diff --git a/pykalshi/_async/communications.py b/pykalshi/_async/communications.py new file mode 100644 index 0000000..f1de7ef --- /dev/null +++ b/pykalshi/_async/communications.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from urllib.parse import urlencode + +from ..models import RfqModel, QuoteModel +from ..dataframe import DataFrameList + +if TYPE_CHECKING: + from .client import AsyncKalshiClient + + +class AsyncCommunications: + """RFQ (Request for Quote) and quote operations for combo trading. + + Multivariate event combos trade via the RFQ system rather than + standard limit orders. The flow is: + + 1. Create an RFQ broadcasting your intent to trade a combo. + 2. Market makers respond with two-sided quotes. + 3. Accept a quote to execute the trade. + + Usage: + # Create an RFQ for a combo market + rfq = await client.communications.create_rfq( + market_ticker="KXMVE-...", + contracts_fp="10.00", + ) + + # List active RFQs + rfqs = await client.communications.get_rfqs(status="active") + + # Respond to an RFQ as a market maker + quote = await client.communications.create_quote( + rfq_id=rfq.rfq_id, + yes_bid="0.45", + no_bid="0.55", + ) + """ + + def __init__(self, client: AsyncKalshiClient) -> None: + self._client = client + + async def create_rfq( + self, + market_ticker: str, + *, + contracts_fp: str | None = None, + target_cost_dollars: str | None = None, + rest_remainder: bool = False, + ) -> RfqModel: + """Create a Request for Quote. + + Args: + market_ticker: The combo market ticker to request quotes for. + contracts_fp: Number of contracts to trade (fixed-point string, e.g. "10.00"). + target_cost_dollars: Target cost in dollars (e.g. "10.00"). + Use this OR contracts_fp, not both. + rest_remainder: If True, rest any unfilled portion on the orderbook. + """ + + body: dict = { + "market_ticker": market_ticker.upper(), + "rest_remainder": rest_remainder, + } + if contracts_fp is not None: + body["contracts_fp"] = contracts_fp + if target_cost_dollars is not None: + body["target_cost_dollars"] = target_cost_dollars + + response = await self._client.post("/communications/rfqs", body) + return RfqModel.model_validate(response.get("rfq", response)) + + async def get_rfqs( + self, + *, + market_ticker: str | None = None, + status: str | None = None, + mve_collection_ticker: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + ) -> DataFrameList[RfqModel]: + """List RFQs. + + Args: + market_ticker: Filter by combo market ticker. + status: Filter by RFQ status (e.g., "active", "expired"). + mve_collection_ticker: Filter by collection ticker. + limit: Maximum results per page (default 100). + cursor: Pagination cursor. + fetch_all: If True, automatically fetch all pages. + """ + params: dict = {"limit": limit} + if market_ticker: + params["market_ticker"] = market_ticker.upper() + if status: + params["status"] = status + if mve_collection_ticker: + params["mve_collection_ticker"] = mve_collection_ticker + if cursor: + params["cursor"] = cursor + + data = await self._client.paginated_get("/communications/rfqs", "rfqs", params, fetch_all) + return DataFrameList(RfqModel.model_validate(r) for r in data) + + async def get_rfq(self, rfq_id: str) -> RfqModel: + """Get a single RFQ by ID.""" + response = await self._client.get(f"/communications/rfqs/{rfq_id}") + return RfqModel.model_validate(response.get("rfq", response)) + + async def create_quote( + self, + rfq_id: str, + *, + yes_bid: str, + no_bid: str, + rest_remainder: bool = False, + ) -> QuoteModel: + """Create a quote in response to an RFQ. + + Prices are in FixedPointDollars (e.g., "0.45"). + + Args: + rfq_id: ID of the RFQ to respond to. + yes_bid: Your bid price for the YES side (FixedPointDollars). + no_bid: Your bid price for the NO side (FixedPointDollars). + rest_remainder: If True, rest any unfilled portion on the orderbook. + """ + body: dict = { + "rfq_id": rfq_id, + "yes_bid": yes_bid, + "no_bid": no_bid, + "rest_remainder": rest_remainder, + } + + response = await self._client.post("/communications/quotes", body) + return QuoteModel.model_validate(response.get("quote", response)) + + async def get_quotes( + self, + *, + creator_user_id: str | None = None, + rfq_creator_user_id: str | None = None, + rfq_id: str | None = None, + market_ticker: str | None = None, + status: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + ) -> DataFrameList[QuoteModel]: + """List quotes. + + The API requires at least one of creator_user_id or rfq_creator_user_id. + + Args: + creator_user_id: Filter by quote creator. Required if rfq_creator_user_id not set. + rfq_creator_user_id: Filter by RFQ creator. Required if creator_user_id not set. + rfq_id: Filter by RFQ ID. + market_ticker: Filter by combo market ticker. + status: Filter by quote status. + limit: Maximum results per page (default 100). + cursor: Pagination cursor. + fetch_all: If True, automatically fetch all pages. + """ + params: dict = {"limit": limit} + if creator_user_id: + params["creator_user_id"] = creator_user_id + if rfq_creator_user_id: + params["rfq_creator_user_id"] = rfq_creator_user_id + if rfq_id: + params["rfq_id"] = rfq_id + if market_ticker: + params["market_ticker"] = market_ticker.upper() + if status: + params["status"] = status + if cursor: + params["cursor"] = cursor + + data = await self._client.paginated_get("/communications/quotes", "quotes", params, fetch_all) + return DataFrameList(QuoteModel.model_validate(q) for q in data) diff --git a/pykalshi/_async/events.py b/pykalshi/_async/events.py new file mode 100644 index 0000000..988eddb --- /dev/null +++ b/pykalshi/_async/events.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..models import EventModel, ForecastPercentileHistory +from ..dataframe import DataFrameList + +if TYPE_CHECKING: + from .client import AsyncKalshiClient + from .markets import AsyncMarket, AsyncSeries + + +class AsyncEvent: + """Represents a Kalshi Event. + + An event is a container for related markets (e.g., "Will X happen?" with + multiple outcome markets). + + Key fields are exposed as typed properties for IDE support. + All other EventModel fields are accessible via attribute delegation. + """ + + def __init__(self, client: AsyncKalshiClient, data: EventModel) -> None: + self._client = client + self.data = data + + # --- Typed properties for core fields --- + + @property + def event_ticker(self) -> str: + return self.data.event_ticker + + @property + def series_ticker(self) -> str: + return self.data.series_ticker + + @property + def title(self) -> str | None: + return self.data.title + + @property + def category(self) -> str | None: + return self.data.category + + @property + def mutually_exclusive(self) -> bool: + return self.data.mutually_exclusive + + # --- Domain logic --- + + async def get_markets(self) -> DataFrameList[AsyncMarket]: + """Get all markets for this event.""" + return await self._client.get_markets(event_ticker=self.data.event_ticker) + + async def get_series(self) -> AsyncSeries: + """Get the parent Series for this event.""" + return await self._client.get_series(self.series_ticker) + + async def get_forecast_percentile_history( + self, + percentiles: list[int] | None = None, + ) -> ForecastPercentileHistory: + """Get historical forecast data at various percentiles. + + Args: + percentiles: List of percentiles to fetch (e.g., [10, 25, 50, 75, 90]). + If None, returns all available percentiles. + + Returns: + ForecastPercentileHistory with percentile -> history mapping. + """ + endpoint = f"/events/{self.event_ticker}/forecast/percentile_history" + if percentiles: + endpoint += f"?percentiles={','.join(str(p) for p in percentiles)}" + response = await self._client.get(endpoint) + return ForecastPercentileHistory.model_validate(response) + + def __getattr__(self, name: str): + return getattr(self.data, name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AsyncEvent): + return NotImplemented + return self.data.event_ticker == other.data.event_ticker + + def __hash__(self) -> int: + return hash(self.data.event_ticker) + + def __repr__(self) -> str: + parts = [f"" + + def _repr_html_(self) -> str: + from .._repr import event_html + return event_html(self) diff --git a/pykalshi/_async/exchange.py b/pykalshi/_async/exchange.py new file mode 100644 index 0000000..5e9954e --- /dev/null +++ b/pykalshi/_async/exchange.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ..models import ExchangeStatus, Announcement +from ..exceptions import KalshiAPIError + +if TYPE_CHECKING: + from .client import AsyncKalshiClient + + +class AsyncExchange: + """Exchange status, schedule, and announcements.""" + + def __init__(self, client: AsyncKalshiClient) -> None: + self._client = client + + async def get_status(self) -> ExchangeStatus: + """Get current exchange operational status.""" + try: + data = await self._client.get("/exchange/status") + except KalshiAPIError as e: + if e.status_code == 503 and isinstance(e.response_body, dict): + data = e.response_body + else: + raise + return ExchangeStatus.model_validate(data) + + async def is_trading(self) -> bool: + """Quick check if trading is currently active.""" + status = await self.get_status() + return status.trading_active + + async def get_schedule(self) -> dict[str, Any]: + """Get exchange trading schedule (raw format).""" + data = await self._client.get("/exchange/schedule") + return data.get("schedule", {}) + + async def get_announcements(self) -> list[Announcement]: + """Get exchange-wide announcements.""" + data = await self._client.get("/exchange/announcements") + return [Announcement.model_validate(a) for a in data.get("announcements", [])] + + async def get_user_data_timestamp(self) -> int: + """Get timestamp of last user data validation (Unix ms).""" + data = await self._client.get("/exchange/user_data_timestamp") + return data.get("user_data_timestamp", 0) diff --git a/pykalshi/_async/markets.py b/pykalshi/_async/markets.py new file mode 100644 index 0000000..f12902e --- /dev/null +++ b/pykalshi/_async/markets.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from ..models import MarketModel, CandlestickResponse, OrderbookResponse, SeriesModel, TradeModel +from ..dataframe import DataFrameList +from ..enums import CandlestickPeriod, MarketStatus + +if TYPE_CHECKING: + from .client import AsyncKalshiClient + from .events import AsyncEvent + +logger = logging.getLogger(__name__) + + +class AsyncMarket: + """Represents a Kalshi Market. + + Key fields are exposed as typed properties for IDE support. + All other MarketModel fields are accessible via attribute delegation. + """ + + def __init__(self, client: AsyncKalshiClient, data: MarketModel) -> None: + self._client = client + self.data = data + + # --- Typed properties for core fields --- + + @property + def ticker(self) -> str: + return self.data.ticker + + @property + def event_ticker(self) -> str | None: + return self.data.event_ticker + + @property + def status(self) -> MarketStatus | None: + return self.data.status + + @property + def title(self) -> str | None: + return self.data.title + + @property + def subtitle(self) -> str | None: + return self.data.subtitle + + @property + def yes_bid_dollars(self) -> str | None: + return self.data.yes_bid_dollars + + @property + def yes_ask_dollars(self) -> str | None: + return self.data.yes_ask_dollars + + @property + def no_bid_dollars(self) -> str | None: + return self.data.no_bid_dollars + + @property + def no_ask_dollars(self) -> str | None: + return self.data.no_ask_dollars + + @property + def last_price_dollars(self) -> str | None: + return self.data.last_price_dollars + + @property + def volume_fp(self) -> str | None: + return self.data.volume_fp + + @property + def volume_24h_fp(self) -> str | None: + return self.data.volume_24h_fp + + @property + def open_interest_fp(self) -> str | None: + return self.data.open_interest_fp + + @property + def liquidity_dollars(self) -> str | None: + return self.data.liquidity_dollars + + @property + def open_time(self) -> str | None: + return self.data.open_time + + @property + def close_time(self) -> str | None: + return self.data.close_time + + @property + def result(self) -> str | None: + return self.data.result + + @property + def series_ticker(self) -> str | None: + return self.data.series_ticker + + async def resolve_series_ticker(self) -> str | None: + """Fetch series_ticker from the event API if not present in market data.""" + if self.data.series_ticker is not None: + return self.data.series_ticker + if not self.data.event_ticker: + return None + try: + event_response = await self._client.get(f"/events/{self.data.event_ticker}") + return event_response["event"]["series_ticker"] + except Exception as e: + logger.warning( + "Failed to resolve series_ticker for %s: %s", self.data.ticker, e + ) + return None + + async def get_orderbook(self, *, depth: int | None = None) -> OrderbookResponse: + """Get the orderbook for this market.""" + endpoint = f"/markets/{self.data.ticker}/orderbook" + if depth: + endpoint += f"?depth={depth}" + response = await self._client.get(endpoint) + return OrderbookResponse.model_validate(response) + + async def get_candlesticks( + self, + start_ts: int, + end_ts: int, + period: CandlestickPeriod = CandlestickPeriod.ONE_HOUR, + ) -> CandlestickResponse: + """Get candlestick data for this market.""" + series = await self.resolve_series_ticker() + if not series: + raise ValueError(f"Market {self.data.ticker} does not have a series_ticker.") + + query = f"start_ts={start_ts}&end_ts={end_ts}&period_interval={period.value}" + endpoint = f"/series/{series}/markets/{self.data.ticker}/candlesticks?{query}" + response = await self._client.get(endpoint) + return CandlestickResponse.model_validate(response) + + async def get_trades( + self, + min_ts: int | None = None, + max_ts: int | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + ) -> DataFrameList[TradeModel]: + """Get public trade history for this market.""" + return await self._client.get_trades( + ticker=self.ticker, + min_ts=min_ts, + max_ts=max_ts, + limit=limit, + cursor=cursor, + fetch_all=fetch_all, + ) + + async def get_event(self) -> AsyncEvent | None: + """Get the parent Event for this market.""" + if not self.event_ticker: + return None + return await self._client.get_event(self.event_ticker) + + def __getattr__(self, name: str): + return getattr(self.data, name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AsyncMarket): + return NotImplemented + return self.data.ticker == other.data.ticker + + def __hash__(self) -> int: + return hash(self.data.ticker) + + def __repr__(self) -> str: + status = self.status.value if self.status else "?" + parts = [f"" + + def _repr_html_(self) -> str: + from .._repr import market_html + return market_html(self) + + +class AsyncSeries: + """Represents a Kalshi Series (collection of related markets).""" + + def __init__(self, client: AsyncKalshiClient, data: SeriesModel) -> None: + self._client = client + self.data = data + + @property + def ticker(self) -> str: + return self.data.ticker + + @property + def title(self) -> str | None: + return self.data.title + + @property + def category(self) -> str | None: + return self.data.category + + async def get_markets(self, **kwargs) -> DataFrameList[AsyncMarket]: + """Get all markets in this series.""" + return await self._client.get_markets(series_ticker=self.ticker, **kwargs) + + async def get_events(self, **kwargs) -> DataFrameList[AsyncEvent]: + """Get all events in this series.""" + return await self._client.get_events(series_ticker=self.ticker, **kwargs) + + def __getattr__(self, name: str): + return getattr(self.data, name) + + def __repr__(self) -> str: + parts = [f"" + + def _repr_html_(self) -> str: + from .._repr import series_html + return series_html(self) diff --git a/pykalshi/_async/mve.py b/pykalshi/_async/mve.py new file mode 100644 index 0000000..70d051f --- /dev/null +++ b/pykalshi/_async/mve.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..models import MveCollectionModel, MveSelectedLeg, EventModel, MarketModel +from ..dataframe import DataFrameList +from ..enums import Side + +if TYPE_CHECKING: + from .client import AsyncKalshiClient + from .events import AsyncEvent + from .markets import AsyncMarket + + +class AsyncMveCollection: + """Represents a multivariate event collection (combo container). + + Collections define which events can be combined into combo markets. + Use create_market() to create a tradeable combo, then trade it via + client.communications (RFQ system). + """ + + def __init__(self, client: AsyncKalshiClient, data: MveCollectionModel) -> None: + self._client = client + self.data = data + + @property + def collection_ticker(self) -> str: + return self.data.collection_ticker + + @property + def title(self) -> str | None: + return self.data.title + + @property + def series_ticker(self) -> str | None: + return self.data.series_ticker + + async def create_market( + self, + selected_markets: list[dict[str, str]], + ) -> AsyncMarket: + """Create a combo market in this collection. + + Must be called before trading or looking up a combo. Each entry + specifies a leg of the combo. + + Args: + selected_markets: List of leg dicts, each with keys: + - market_ticker: The market ticker for this leg. + - event_ticker: The event ticker for this leg. + - side: "yes" or "no". + + Returns: + The created combo Market. + + Example: + market = await collection.create_market([ + {"market_ticker": "KXABC-A", "event_ticker": "KXABC", "side": "yes"}, + {"market_ticker": "KXDEF-B", "event_ticker": "KXDEF", "side": "yes"}, + ]) + """ + from .markets import AsyncMarket + + body = {"selected_markets": selected_markets, "with_market_payload": True} + response = await self._client.post( + f"/multivariate_event_collections/{self.collection_ticker}", body + ) + model = MarketModel.model_validate(response.get("market", response)) + return AsyncMarket(self._client, model) + + async def lookup_ticker( + self, + selected_markets: list[dict[str, str]], + ) -> dict: + """Look up tickers for a combo market by its leg combination. + + Returns 404 if the combination hasn't been previously created + via create_market(). + + Args: + selected_markets: List of leg dicts (same format as create_market). + + Returns: + Dict with market_ticker and event_ticker for the combo. + """ + body = {"selected_markets": selected_markets} + return await self._client.put( + f"/multivariate_event_collections/{self.collection_ticker}/lookup", body + ) + + async def get_events(self, *, with_nested_markets: bool = False) -> DataFrameList[AsyncEvent]: + """Get multivariate events in this collection. + + Args: + with_nested_markets: If True, include markets nested in each event. + """ + return await self._client.get_multivariate_events( + collection_ticker=self.collection_ticker, + with_nested_markets=with_nested_markets, + ) + + def __getattr__(self, name: str): + return getattr(self.data, name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AsyncMveCollection): + return NotImplemented + return self.data.collection_ticker == other.data.collection_ticker + + def __hash__(self) -> int: + return hash(self.data.collection_ticker) + + def __repr__(self) -> str: + parts = [f"" + + def _repr_html_(self) -> str: + from .._repr import mve_collection_html + return mve_collection_html(self) diff --git a/pykalshi/_async/orders.py b/pykalshi/_async/orders.py new file mode 100644 index 0000000..b217bf9 --- /dev/null +++ b/pykalshi/_async/orders.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import asyncio +import time +from typing import TYPE_CHECKING + +from ..models import OrderModel +from ..enums import OrderStatus, Action, Side, OrderType + +if TYPE_CHECKING: + from .client import AsyncKalshiClient + +TERMINAL_STATUSES = frozenset({OrderStatus.CANCELED, OrderStatus.EXECUTED}) + + +class AsyncOrder: + """Represents a Kalshi order. + + Key fields are exposed as typed properties for IDE support. + All other OrderModel fields are accessible via attribute delegation. + """ + + def __init__(self, client: AsyncKalshiClient, data: OrderModel) -> None: + self._client = client + self.data = data + + # --- Typed properties for core fields --- + + @property + def order_id(self) -> str: + return self.data.order_id + + @property + def ticker(self) -> str: + return self.data.ticker + + @property + def status(self) -> OrderStatus: + return self.data.status + + @property + def action(self) -> Action | None: + return self.data.action + + @property + def side(self) -> Side | None: + return self.data.side + + @property + def type(self) -> OrderType | None: + return self.data.type + + @property + def yes_price_dollars(self) -> str | None: + return self.data.yes_price_dollars + + @property + def no_price_dollars(self) -> str | None: + return self.data.no_price_dollars + + @property + def initial_count_fp(self) -> str | None: + return self.data.initial_count_fp + + @property + def fill_count_fp(self) -> str | None: + return self.data.fill_count_fp + + @property + def remaining_count_fp(self) -> str | None: + return self.data.remaining_count_fp + + @property + def created_time(self) -> str | None: + return self.data.created_time + + # --- Domain logic --- + + async def cancel(self) -> AsyncOrder: + """Cancel this order. + + Returns: + Self with updated data (status will be CANCELED). + """ + updated = await self._client.portfolio.cancel_order(self.order_id) + self.data = updated.data + return self + + async def amend( + self, + *, + count_fp: str | None = None, + yes_price_dollars: str | None = None, + no_price_dollars: str | None = None, + ) -> AsyncOrder: + """Amend this order's price or count. + + Args: + count_fp: New total contract count (fixed-point string). + yes_price_dollars: New YES price (dollar string). + no_price_dollars: New NO price (dollar string, converted to yes internally). + + Returns: + Self with updated data. + """ + updated = await self._client.portfolio.amend_order( + self.order_id, + count_fp=count_fp, + yes_price_dollars=yes_price_dollars, + no_price_dollars=no_price_dollars, + ticker=self.ticker, + action=self.action, + side=self.side, + ) + self.data = updated.data + return self + + async def decrease(self, reduce_by_fp: str) -> AsyncOrder: + """Decrease the remaining count of this order. + + Args: + reduce_by_fp: Number of contracts to reduce by (fixed-point string). + + Returns: + Self with updated data. + """ + updated = await self._client.portfolio.decrease_order(self.order_id, reduce_by_fp) + self.data = updated.data + return self + + async def refresh(self) -> AsyncOrder: + """Re-fetch this order's current state from the API. + + Returns: + Self with updated data. + """ + updated = await self._client.portfolio.get_order(self.order_id) + self.data = updated.data + return self + + async def wait_until_terminal( + self, timeout: float = 30.0, poll_interval: float = 0.5 + ) -> AsyncOrder: + """Block until order reaches a terminal state. + + Terminal states are: CANCELED, EXECUTED. + + Args: + timeout: Maximum seconds to wait before raising TimeoutError. + poll_interval: Seconds between refresh calls. + + Returns: + Self with updated data. + + Raises: + TimeoutError: If timeout is reached before terminal state. + """ + deadline = time.monotonic() + timeout + while self.status not in TERMINAL_STATUSES: + if time.monotonic() >= deadline: + raise TimeoutError( + f"Order {self.order_id} still {self.status.value} after {timeout}s" + ) + await asyncio.sleep(poll_interval) + await self.refresh() + return self + + def __getattr__(self, name: str): + return getattr(self.data, name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AsyncOrder): + return NotImplemented + return self.data.order_id == other.data.order_id + + def __hash__(self) -> int: + return hash(self.data.order_id) + + def __repr__(self) -> str: + action = self.action.value.upper() if self.action else "?" + side = self.side.value.upper() if self.side else "?" + price = self.yes_price_dollars if self.yes_price_dollars is not None else self.no_price_dollars + filled = self.fill_count_fp or "0" + total = self.initial_count_fp or "0" + return f"" + + def _repr_html_(self) -> str: + from .._repr import order_html + return order_html(self) diff --git a/pykalshi/_async/portfolio.py b/pykalshi/_async/portfolio.py new file mode 100644 index 0000000..ee2b3c6 --- /dev/null +++ b/pykalshi/_async/portfolio.py @@ -0,0 +1,647 @@ +from __future__ import annotations + +from decimal import Decimal +from typing import TYPE_CHECKING +from urllib.parse import urlencode + +from .orders import AsyncOrder +from ..enums import Action, Side, OrderStatus, TimeInForce, SelfTradePrevention, PositionCountFilter +from ..dataframe import DataFrameList +from .._utils import normalize_ticker, normalize_tickers +from ..models import ( + OrderModel, BalanceModel, PositionModel, FillModel, + SettlementModel, QueuePositionModel, OrderGroupModel, + SubaccountModel, SubaccountBalanceModel, SubaccountTransferModel, +) + +if TYPE_CHECKING: + from .client import AsyncKalshiClient + from .markets import AsyncMarket + + +class AsyncPortfolio: + """Authenticated user's portfolio and trading operations.""" + + def __init__(self, client: AsyncKalshiClient) -> None: + self._client = client + + async def get_balance(self) -> BalanceModel: + """Get portfolio balance. Values are dollar strings.""" + data = await self._client.get("/portfolio/balance") + return BalanceModel.model_validate(data) + + async def place_order( + self, + ticker: str | AsyncMarket, + action: Action, + side: Side, + count_fp: str, + *, + yes_price_dollars: str | None = None, + no_price_dollars: str | None = None, + client_order_id: str | None = None, + time_in_force: TimeInForce | None = None, + post_only: bool = False, + reduce_only: bool = False, + expiration_ts: int | None = None, + buy_max_cost_dollars: str | None = None, + self_trade_prevention: SelfTradePrevention | None = None, + order_group_id: str | None = None, + subaccount: int | None = None, + cancel_order_on_pause: bool | None = None, + ) -> AsyncOrder: + """Place an order on a market. + + Args: + ticker: Market ticker string or Market object. + action: BUY or SELL. + side: YES or NO. + count_fp: Number of contracts (fixed-point string, e.g. "10.00"). + yes_price_dollars: Price as dollar string (e.g. "0.45"). + no_price_dollars: Price as dollar string. Converted to + yes_price_dollars internally (yes = 1.00 - no). + client_order_id: Idempotency key. Resubmitting returns existing order. + time_in_force: GTC (default), IOC (immediate-or-cancel), FOK (fill-or-kill). + post_only: If True, reject order if it would take liquidity. + reduce_only: If True, only reduce existing position, never increase. + expiration_ts: Unix timestamp when order auto-cancels. + buy_max_cost_dollars: Maximum total cost (dollar string). Protects against slippage. + self_trade_prevention: Behavior on self-cross (CANCEL_RESTING or CANCEL_INCOMING). + order_group_id: Link to an order group for OCO/bracket strategies. + subaccount: Subaccount number (0 for primary, 1-32 for subaccounts). + cancel_order_on_pause: If True, cancel order if market is paused. + """ + # Extract market structure for validation when a Market object is passed + pls = None + fte = None + if not isinstance(ticker, str): + pls = getattr(ticker, 'price_level_structure', None) + fte = getattr(ticker, 'fractional_trading_enabled', None) + + order_data = self._build_order_data( + ticker, action, side, count_fp, + yes_price_dollars=yes_price_dollars, no_price_dollars=no_price_dollars, + client_order_id=client_order_id, time_in_force=time_in_force, + post_only=post_only, reduce_only=reduce_only, + expiration_ts=expiration_ts, buy_max_cost_dollars=buy_max_cost_dollars, + self_trade_prevention=self_trade_prevention, + order_group_id=order_group_id, subaccount=subaccount, + cancel_order_on_pause=cancel_order_on_pause, + price_level_structure=pls, + fractional_trading_enabled=fte, + ) + response = await self._client.post("/portfolio/orders", order_data) + model = OrderModel.model_validate(response["order"]) + return AsyncOrder(self._client, model) + + async def cancel_order(self, order_id: str, *, subaccount: int | None = None) -> AsyncOrder: + """Cancel a resting order. + + Args: + order_id: ID of the order to cancel. + subaccount: Subaccount number (0 for primary, 1-32 for subaccounts). + + Returns: + The canceled Order with updated status. + """ + endpoint = f"/portfolio/orders/{order_id}" + if subaccount is not None: + endpoint += f"?subaccount={subaccount}" + response = await self._client.delete(endpoint) + model = OrderModel.model_validate(response["order"]) + return AsyncOrder(self._client, model) + + async def amend_order( + self, + order_id: str, + *, + count_fp: str | None = None, + yes_price_dollars: str | None = None, + no_price_dollars: str | None = None, + subaccount: int | None = None, + # Required by API but can be fetched from existing order + ticker: str | None = None, + action: Action | None = None, + side: Side | None = None, + ) -> AsyncOrder: + """Amend a resting order's price or count. + + Args: + order_id: ID of the order to amend. + count_fp: New total contract count (fixed-point string). + yes_price_dollars: New YES price (dollar string). + no_price_dollars: New NO price (dollar string). Converted internally. + subaccount: Subaccount number (0 for primary, 1-32 for subaccounts). + ticker: Market ticker (fetched from order if not provided). + action: Order action (fetched from order if not provided). + side: Order side (fetched from order if not provided). + """ + if yes_price_dollars is not None and no_price_dollars is not None: + raise ValueError("Specify yes_price_dollars or no_price_dollars, not both") + + if no_price_dollars is not None: + yes_price_dollars = str(Decimal("1") - Decimal(no_price_dollars)) + + ticker = normalize_ticker(ticker) + + # Fetch original order to get required fields if not provided + if ticker is None or action is None or side is None or count_fp is None: + original = await self.get_order(order_id) + ticker = ticker or original.ticker + action = action or original.action + side = side or original.side + if count_fp is None: + count_fp = original.remaining_count_fp + + body: dict = { + "ticker": ticker, + "action": action.value if isinstance(action, Action) else action, + "side": side.value if isinstance(side, Side) else side, + "count_fp": count_fp, + } + if yes_price_dollars is not None: + body["yes_price_dollars"] = yes_price_dollars + if subaccount is not None: + body["subaccount"] = subaccount + + if "count_fp" not in body and "yes_price_dollars" not in body: + raise ValueError("Must specify at least one of count_fp, yes_price_dollars, or no_price_dollars") + + response = await self._client.post(f"/portfolio/orders/{order_id}/amend", body) + model = OrderModel.model_validate(response["order"]) + return AsyncOrder(self._client, model) + + async def decrease_order(self, order_id: str, reduce_by_fp: str) -> AsyncOrder: + """Decrease the remaining count of a resting order. + + Args: + order_id: ID of the order to decrease. + reduce_by_fp: Number of contracts to reduce by (fixed-point string). + """ + response = await self._client.post( + f"/portfolio/orders/{order_id}/decrease", {"reduce_by_fp": reduce_by_fp} + ) + model = OrderModel.model_validate(response["order"]) + return AsyncOrder(self._client, model) + + async def get_orders( + self, + *, + status: OrderStatus | None = None, + ticker: str | None = None, + event_ticker: str | None = None, + min_ts: int | None = None, + max_ts: int | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[AsyncOrder]: + """Get list of orders. + + Args: + status: Filter by order status (resting, canceled, executed). + ticker: Filter by market ticker. + event_ticker: Filter by event ticker (supports comma-separated, max 10). + min_ts: Filter orders after this Unix timestamp. + max_ts: Filter orders before this Unix timestamp. + limit: Maximum results per page (default 100, max 200). + cursor: Pagination cursor for fetching next page. + fetch_all: If True, automatically fetch all pages. + **extra_params: Additional API parameters (e.g., subaccount). + """ + params = { + "limit": limit, + "status": status.value if status is not None else None, + "ticker": normalize_ticker(ticker), + "event_ticker": normalize_ticker(event_ticker), + "min_ts": min_ts, + "max_ts": max_ts, + "cursor": cursor, + **extra_params, + } + data = await self._client.paginated_get("/portfolio/orders", "orders", params, fetch_all) + return DataFrameList(AsyncOrder(self._client, OrderModel.model_validate(d)) for d in data) + + async def get_order(self, order_id: str) -> AsyncOrder: + """Get a single order by ID.""" + response = await self._client.get(f"/portfolio/orders/{order_id}") + model = OrderModel.model_validate(response["order"]) + return AsyncOrder(self._client, model) + + async def get_positions( + self, + *, + ticker: str | None = None, + event_ticker: str | None = None, + count_filter: PositionCountFilter | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[PositionModel]: + """Get portfolio positions. + + Args: + ticker: Filter by specific market ticker. + event_ticker: Filter by event ticker (supports comma-separated, max 10). + count_filter: Filter positions with non-zero values (POSITION or TOTAL_TRADED). + limit: Maximum positions per page (default 100, max 1000). + cursor: Pagination cursor for fetching next page. + fetch_all: If True, automatically fetch all pages. + **extra_params: Additional API parameters (e.g., subaccount). + """ + params = { + "limit": limit, + "ticker": normalize_ticker(ticker), + "event_ticker": normalize_ticker(event_ticker), + "count_filter": count_filter.value if count_filter is not None else None, + "cursor": cursor, + **extra_params, + } + data = await self._client.paginated_get("/portfolio/positions", "market_positions", params, fetch_all) + return DataFrameList(PositionModel.model_validate(p) for p in data) + + async def get_fills( + self, + *, + ticker: str | None = None, + order_id: str | None = None, + min_ts: int | None = None, + max_ts: int | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[FillModel]: + """Get trade fills (executed trades). + + Args: + ticker: Filter by market ticker. + order_id: Filter by specific order ID. + min_ts: Minimum timestamp (Unix seconds). + max_ts: Maximum timestamp (Unix seconds). + limit: Maximum fills per page (default 100, max 200). + cursor: Pagination cursor for fetching next page. + fetch_all: If True, automatically fetch all pages. + **extra_params: Additional API parameters (e.g., subaccount). + """ + params = { + "limit": limit, + "ticker": normalize_ticker(ticker), + "order_id": order_id, + "min_ts": min_ts, + "max_ts": max_ts, + "cursor": cursor, + **extra_params, + } + data = await self._client.paginated_get("/portfolio/fills", "fills", params, fetch_all) + return DataFrameList(FillModel.model_validate(f) for f in data) + + # --- Batch Operations --- + + async def batch_place_orders(self, orders: list[dict]) -> DataFrameList[AsyncOrder]: + """Place multiple orders atomically. + + Args: + orders: List of order dicts with keys: ticker, action, side, count_fp, + yes_price_dollars/no_price_dollars, and optional advanced params. + + Example: + orders = [ + {"ticker": "KXBTC", "action": "buy", "side": "yes", "count_fp": "10.00", "yes_price_dollars": "0.45"}, + {"ticker": "KXBTC", "action": "buy", "side": "no", "count_fp": "10.00", "no_price_dollars": "0.45"}, + ] + results = await portfolio.batch_place_orders(orders) + """ + prepared = self._build_batch_orders(orders) + response = await self._client.post("/portfolio/orders/batched", {"orders": prepared}) + result = [] + for item in response.get("orders", []): + order_data = item.get("order") + if order_data is None: + continue + result.append(AsyncOrder(self._client, OrderModel.model_validate(order_data))) + return DataFrameList(result) + + async def batch_cancel_orders(self, order_ids: list[str]) -> DataFrameList[AsyncOrder]: + """Cancel multiple orders atomically. + + Args: + order_ids: List of order IDs to cancel (max 20). + + Returns: + The canceled Orders with updated status. + """ + orders = [{"order_id": oid} for oid in order_ids] + response = await self._client.delete("/portfolio/orders/batched", {"orders": orders}) + result = [] + for item in response.get("orders", []): + order_data = item.get("order") + if order_data is None: + continue + result.append(AsyncOrder(self._client, OrderModel.model_validate(order_data))) + return DataFrameList(result) + + # --- Queue Position --- + + async def get_queue_position(self, order_id: str) -> QueuePositionModel: + """Get queue position for a single resting order.""" + response = await self._client.get(f"/portfolio/orders/{order_id}/queue_position") + return QueuePositionModel( + order_id=order_id, + queue_position_fp=response.get("queue_position_fp", "0.00"), + ) + + async def get_queue_positions( + self, + *, + market_tickers: list[str] | None = None, + event_ticker: str | None = None, + ) -> DataFrameList[QueuePositionModel]: + """Get queue positions for all resting orders.""" + params: dict = {} + if market_tickers: + params["market_tickers"] = ",".join(normalize_tickers(market_tickers)) + if event_ticker: + params["event_ticker"] = normalize_ticker(event_ticker) + + endpoint = "/portfolio/orders/queue_positions" + if params: + endpoint = f"{endpoint}?{urlencode(params)}" + + response = await self._client.get(endpoint) + return DataFrameList( + QueuePositionModel.model_validate(qp) + for qp in response.get("queue_positions", []) + ) + + # --- Settlements --- + + async def get_settlements( + self, + *, + ticker: str | None = None, + event_ticker: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[SettlementModel]: + """Get settlement records for resolved positions.""" + params = { + "limit": limit, + "ticker": normalize_ticker(ticker), + "event_ticker": normalize_ticker(event_ticker), + "cursor": cursor, + **extra_params, + } + data = await self._client.paginated_get("/portfolio/settlements", "settlements", params, fetch_all) + return DataFrameList(SettlementModel.model_validate(s) for s in data) + + async def get_resting_order_value(self) -> str: + """Get total value of all resting orders as dollar string. + + NOTE: This endpoint is FCM-only (institutional accounts). + """ + response = await self._client.get("/portfolio/summary/total_resting_order_value") + return response.get("total_resting_order_value_dollars", "0") + + # --- Order Groups (Contract Rate Limiting) --- + + async def create_order_group(self, contracts_limit_fp: str) -> OrderGroupModel: + """Create an order group for rate-limiting contract matches. + + Args: + contracts_limit_fp: Maximum contracts (fixed-point string) that can be + matched in a rolling 15-second window. + + Returns: + Created OrderGroupModel. + """ + body: dict = {"contracts_limit_fp": contracts_limit_fp} + response = await self._client.post("/portfolio/order_groups/create", body) + return OrderGroupModel.model_validate(response) + + async def get_order_group(self, order_group_id: str) -> OrderGroupModel: + """Get an order group by ID.""" + response = await self._client.get(f"/portfolio/order_groups/{order_group_id}") + response["id"] = order_group_id + return OrderGroupModel.model_validate(response) + + async def trigger_order_group(self, order_group_id: str) -> None: + """Manually trigger an order group, cancelling all orders in it.""" + await self._client.put(f"/portfolio/order_groups/{order_group_id}/trigger", {}) + + async def get_order_groups(self) -> DataFrameList[OrderGroupModel]: + """List all order groups.""" + response = await self._client.get("/portfolio/order_groups") + return DataFrameList( + OrderGroupModel.model_validate(og) + for og in response.get("order_groups", []) + ) + + async def reset_order_group(self, order_group_id: str) -> None: + """Reset matched contract counter for an order group.""" + await self._client.put(f"/portfolio/order_groups/{order_group_id}/reset", {}) + + async def update_order_group_limit(self, order_group_id: str, contracts_limit_fp: str) -> None: + """Update the contracts limit for an order group. + + Args: + order_group_id: ID of the order group. + contracts_limit_fp: New maximum contracts (fixed-point string). + """ + body: dict = {"contracts_limit_fp": contracts_limit_fp} + await self._client.put(f"/portfolio/order_groups/{order_group_id}/limit", body) + + # --- Subaccounts --- + + async def create_subaccount(self) -> SubaccountModel: + """Create a new numbered subaccount.""" + response = await self._client.post("/portfolio/subaccounts", {}) + return SubaccountModel.model_validate(response.get("subaccount", response)) + + async def transfer_between_subaccounts( + self, + from_subaccount_id: str, + to_subaccount_id: str, + amount_dollars: str, + ) -> SubaccountTransferModel: + """Transfer funds between subaccounts. + + Args: + from_subaccount_id: Source subaccount ID. + to_subaccount_id: Destination subaccount ID. + amount_dollars: Amount to transfer (dollar string). + """ + body = { + "from_subaccount_id": from_subaccount_id, + "to_subaccount_id": to_subaccount_id, + "amount_dollars": amount_dollars, + } + response = await self._client.post("/portfolio/subaccounts/transfer", body) + return SubaccountTransferModel.model_validate(response.get("transfer", response)) + + async def get_subaccount_balances(self) -> DataFrameList[SubaccountBalanceModel]: + """Get balances for all subaccounts.""" + response = await self._client.get("/portfolio/subaccounts/balances") + return DataFrameList( + SubaccountBalanceModel.model_validate(b) + for b in response.get("balances", []) + ) + + async def get_subaccount_transfers( + self, + *, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[SubaccountTransferModel]: + """Get transfer history between subaccounts.""" + params = {"limit": limit, "cursor": cursor, **extra_params} + data = await self._client.paginated_get( + "/portfolio/subaccounts/transfers", "transfers", params, fetch_all + ) + return DataFrameList(SubaccountTransferModel.model_validate(t) for t in data) + + # --- Shared validation helpers --- + + @staticmethod + def _validate_tick_size(price: Decimal, price_level_structure: str) -> None: + """Validate that price aligns to the market's tick size. + + Raises ValueError if the price is not on a valid tick boundary. + """ + if price_level_structure == "linear_cent": + # $0.00-$1.00, tick $0.01 + tick = Decimal("0.01") + if price % tick != 0: + raise ValueError( + f"Price {price} is not on a valid tick for linear_cent " + f"(tick size $0.01)" + ) + elif price_level_structure == "deci_cent": + # $0.00-$1.00, tick $0.001 + tick = Decimal("0.001") + if price % tick != 0: + raise ValueError( + f"Price {price} is not on a valid tick for deci_cent " + f"(tick size $0.001)" + ) + elif price_level_structure == "tapered_deci_cent": + # $0.00-$0.10: tick $0.001, $0.10-$0.90: tick $0.01, $0.90-$1.00: tick $0.001 + if price <= Decimal("0.10") or price >= Decimal("0.90"): + tick = Decimal("0.001") + else: + tick = Decimal("0.01") + if price % tick != 0: + raise ValueError( + f"Price {price} is not on a valid tick for tapered_deci_cent " + f"(tick size ${tick} in this price range)" + ) + + @staticmethod + def _validate_fractional(count_fp: str, fractional_enabled: bool) -> None: + """Validate count_fp is whole when fractional trading is disabled.""" + if not fractional_enabled: + d = Decimal(count_fp) + if d != int(d): + raise ValueError( + f"Fractional trading is not enabled for this market. " + f"count_fp must be a whole number, got {count_fp}" + ) + + @staticmethod + def _build_order_data( + ticker, + action: Action, + side: Side, + count_fp: str, + *, + yes_price_dollars=None, + no_price_dollars=None, + client_order_id=None, + time_in_force=None, + post_only=False, + reduce_only=False, + expiration_ts=None, + buy_max_cost_dollars=None, + self_trade_prevention=None, + order_group_id=None, + subaccount=None, + cancel_order_on_pause=None, + price_level_structure=None, + fractional_trading_enabled=None, + ) -> dict: + """Build and validate order data dict. No I/O. + + If price_level_structure is provided, validates tick size alignment. + If fractional_trading_enabled is provided (False), validates count_fp is whole. + """ + if yes_price_dollars is not None and no_price_dollars is not None: + raise ValueError("Specify yes_price_dollars or no_price_dollars, not both") + + if yes_price_dollars is None and no_price_dollars is None: + raise ValueError("Limit orders require yes_price_dollars or no_price_dollars") + + if no_price_dollars is not None: + yes_price_dollars = str(Decimal("1") - Decimal(no_price_dollars)) + + # Validate tick size if market structure is known + if price_level_structure and yes_price_dollars is not None: + AsyncPortfolio._validate_tick_size(Decimal(yes_price_dollars), price_level_structure) + + # Validate fractional trading + if fractional_trading_enabled is not None: + AsyncPortfolio._validate_fractional(count_fp, fractional_trading_enabled) + + ticker_str = ticker.upper() if isinstance(ticker, str) else ticker.ticker + + order_data: dict = { + "ticker": ticker_str, + "action": action.value, + "side": side.value, + "count_fp": count_fp, + "yes_price_dollars": yes_price_dollars, + } + if client_order_id is not None: + order_data["client_order_id"] = client_order_id + if time_in_force is not None: + order_data["time_in_force"] = time_in_force.value + if post_only: + order_data["post_only"] = True + if reduce_only: + order_data["reduce_only"] = True + if expiration_ts is not None: + order_data["expiration_ts"] = expiration_ts + if buy_max_cost_dollars is not None: + order_data["buy_max_cost_dollars"] = buy_max_cost_dollars + if self_trade_prevention is not None: + order_data["self_trade_prevention_type"] = self_trade_prevention.value + if order_group_id is not None: + order_data["order_group_id"] = order_group_id + if subaccount is not None: + order_data["subaccount"] = subaccount + if cancel_order_on_pause is not None: + order_data["cancel_order_on_pause"] = cancel_order_on_pause + return order_data + + @staticmethod + def _build_batch_orders(orders: list[dict]) -> list[dict]: + """Validate and prepare batch orders. No I/O.""" + prepared = [] + for order in orders: + o = dict(order) + + if "yes_price_dollars" in o and "no_price_dollars" in o: + raise ValueError("Specify yes_price_dollars or no_price_dollars, not both") + if "yes_price_dollars" not in o and "no_price_dollars" not in o: + raise ValueError("Limit orders require yes_price_dollars or no_price_dollars") + if "no_price_dollars" in o: + o["yes_price_dollars"] = str(Decimal("1") - Decimal(o.pop("no_price_dollars"))) + # Strip "type" -- Kalshi API no longer accepts it + o.pop("type", None) + prepared.append(o) + return prepared diff --git a/pykalshi/_sync/__init__.py b/pykalshi/_sync/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pykalshi/_sync/api_keys.py b/pykalshi/_sync/api_keys.py new file mode 100644 index 0000000..9fbd506 --- /dev/null +++ b/pykalshi/_sync/api_keys.py @@ -0,0 +1,66 @@ +# AUTO-GENERATED from pykalshi/_async/api_keys.py — do not edit manually. +# Re-run: python scripts/generate_sync.py +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..models import APIKey, GeneratedAPIKey, APILimits + +if TYPE_CHECKING: + from .client import KalshiClient + + +class APIKeys: + """API key management and account limits.""" + + def __init__(self, client: KalshiClient) -> None: + self._client = client + + def list(self) -> list[APIKey]: + """List all API keys for this account.""" + data = self._client.get("/api_keys") + return [APIKey.model_validate(k) for k in data.get("api_keys", [])] + + def create(self, public_key: str, name: str | None = None) -> str: + """Create an API key with a provided RSA public key. + + Args: + public_key: PEM-encoded RSA public key. + name: Optional name for the key. + + Returns: + The API key ID string. + """ + body: dict = {"public_key": public_key} + if name: + body["name"] = name + data = self._client.post("/api_keys", body) + return data["api_key_id"] + + def generate(self, name: str | None = None) -> GeneratedAPIKey: + """Generate a new API key pair (Kalshi creates both keys). + + Returns a GeneratedAPIKey with the private_key field populated. + The private key is only returned ONCE - store it securely. + + Args: + name: Optional name for the key. + """ + body: dict = {} + if name: + body["name"] = name + data = self._client.post("/api_keys/generate", body) + return GeneratedAPIKey.model_validate(data) + + def delete(self, key_id: str) -> None: + """Delete an API key. + + Args: + key_id: The API key ID to delete. + """ + self._client.delete(f"/api_keys/{key_id}") + + def get_limits(self) -> APILimits: + """Get API rate limits for this account.""" + data = self._client.get("/account/limits") + return APILimits.model_validate(data) diff --git a/pykalshi/_sync/client.py b/pykalshi/_sync/client.py new file mode 100644 index 0000000..85223f6 --- /dev/null +++ b/pykalshi/_sync/client.py @@ -0,0 +1,407 @@ +# AUTO-GENERATED from pykalshi/_async/client.py — do not edit manually. +# Re-run: python scripts/generate_sync.py +"""Kalshi API Client.""" + +from __future__ import annotations + +import time +import json +import logging +from functools import cached_property +from typing import Any, TYPE_CHECKING +from urllib.parse import urlencode + +import httpx + +from .._base import _BaseKalshiClient, _RETRYABLE_STATUS_CODES +from .events import Event +from .markets import Market, Series +from .mve import MveCollection +from ..models import MarketModel, EventModel, SeriesModel, TradeModel, CandlestickResponse, MveCollectionModel +from ..dataframe import DataFrameList +from .portfolio import Portfolio +from ..enums import MarketStatus, CandlestickPeriod +from .exchange import Exchange +from .api_keys import APIKeys +from .communications import Communications +from ..exceptions import RateLimitError +from .._utils import normalize_ticker, normalize_tickers + +if TYPE_CHECKING: + from ..feed import Feed + from ..rate_limiter import RateLimiterProtocol + +logger = logging.getLogger(__name__) + + +class KalshiClient(_BaseKalshiClient): + """Authenticated client for the Kalshi Trading API. + + Usage: + with KalshiClient.from_env() as client: + market = client.get_market("TICKER") + balance = client.portfolio.get_balance() + """ + + def __init__( + self, + api_key_id: str | None = None, + private_key_path: str | None = None, + api_base: str | None = None, + demo: bool = False, + timeout: float = 10.0, + max_retries: int = 3, + rate_limiter: RateLimiterProtocol | None = None, + ) -> None: + super().__init__( + api_key_id=api_key_id, + private_key_path=private_key_path, + api_base=api_base, + demo=demo, + timeout=timeout, + max_retries=max_retries, + rate_limiter=rate_limiter, + ) + self._session = httpx.Client() + + def close(self) -> None: + """Close the underlying HTTP connection pool.""" + self._session.close() + + def __enter__(self) -> KalshiClient: + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + # --- HTTP methods --- + + def _request(self, method: str, endpoint: str, **kwargs: Any) -> httpx.Response: + """Execute async HTTP request with retry on transient failures.""" + url = f"{self.api_base}{endpoint}" + + for attempt in range(self.max_retries + 1): + if self.rate_limiter is not None: + wait_time = self.rate_limiter.acquire() + if wait_time > 0: + logger.debug("Rate limiter waited %.3fs", wait_time) + + headers = self._get_headers(method, endpoint) + request_kwargs: dict[str, Any] = {"headers": headers, "timeout": self.timeout} + if "data" in kwargs: + request_kwargs["content"] = kwargs["data"] + try: + response = self._session.request(method, url, **request_kwargs) + except httpx.TimeoutException as e: + if attempt == self.max_retries: + raise + wait = self._compute_backoff(attempt, None) + logger.warning( + "%s %s failed (%s), retry %d/%d in %.1fs", + method, endpoint, type(e).__name__, + attempt + 1, self.max_retries, wait, + ) + time.sleep(wait) + continue + except httpx.ConnectError as e: + if attempt == self.max_retries: + raise + wait = self._compute_backoff(attempt, None) + logger.warning( + "%s %s failed (%s), retry %d/%d in %.1fs", + method, endpoint, type(e).__name__, + attempt + 1, self.max_retries, wait, + ) + time.sleep(wait) + continue + + self._update_rate_limiter(response) + + if response.status_code not in _RETRYABLE_STATUS_CODES: + return response + if attempt == self.max_retries: + if response.status_code == 429: + raise RateLimitError( + 429, "Rate limit exceeded after retries", + method=method, endpoint=endpoint, + ) + return response + + wait = self._compute_backoff(attempt, response.headers.get("Retry-After")) + logger.warning( + "%s %s returned %d, retry %d/%d in %.1fs", + method, endpoint, response.status_code, + attempt + 1, self.max_retries, wait, + ) + time.sleep(wait) + + return response # unreachable, satisfies type checker + + def get(self, endpoint: str) -> dict[str, Any]: + """Make authenticated GET request.""" + logger.debug("GET %s", endpoint) + response = self._request("GET", endpoint) + return self._handle_response(response, method="GET", endpoint=endpoint) + + def paginated_get( + self, + path: str, + response_key: str, + params: dict[str, Any], + fetch_all: bool = False, + ) -> list[dict]: + """Fetch items with automatic cursor-based pagination.""" + params = dict(params) + all_items: list[dict] = [] + while True: + filtered = {k: v for k, v in params.items() if v is not None} + endpoint = f"{path}?{urlencode(filtered)}" if filtered else path + response = self.get(endpoint) + all_items.extend(response.get(response_key, [])) + cursor = response.get("cursor", "") + if not fetch_all or not cursor: + break + params["cursor"] = cursor + return all_items + + def post(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: + """Make authenticated POST request.""" + logger.debug("POST %s", endpoint) + body = json.dumps(data, separators=(",", ":")) + response = self._request("POST", endpoint, data=body) + return self._handle_response( + response, method="POST", endpoint=endpoint, request_body=data + ) + + def put(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: + """Make authenticated PUT request.""" + logger.debug("PUT %s", endpoint) + body = json.dumps(data, separators=(",", ":")) + response = self._request("PUT", endpoint, data=body) + return self._handle_response( + response, method="PUT", endpoint=endpoint, request_body=data + ) + + def delete(self, endpoint: str, body: dict | None = None) -> dict[str, Any]: + """Make authenticated DELETE request.""" + logger.debug("DELETE %s", endpoint) + if body: + data = json.dumps(body, separators=(",", ":")) + response = self._request("DELETE", endpoint, data=data) + else: + response = self._request("DELETE", endpoint) + return self._handle_response(response, method="DELETE", endpoint=endpoint) + + # --- Domain accessors --- + + @cached_property + def portfolio(self) -> Portfolio: + return Portfolio(self) + + @cached_property + def exchange(self) -> Exchange: + return Exchange(self) + + @cached_property + def api_keys(self) -> APIKeys: + return APIKeys(self) + + @cached_property + def communications(self) -> Communications: + return Communications(self) + + def feed(self) -> Feed: + """Create a new async real-time data feed.""" + from ..feed import Feed + return Feed(self) + + # --- Domain query methods --- + + def get_market(self, ticker: str) -> Market: + response = self.get(f"/markets/{ticker.upper()}") + model = MarketModel.model_validate(response["market"]) + return Market(self, model) + + def get_markets( + self, + *, + status: MarketStatus | None = None, + mve_filter: str | None = None, + tickers: list[str] | None = None, + series_ticker: str | None = None, + event_ticker: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[Market]: + params = { + "status": status.value if status is not None else None, + "mve_filter": mve_filter, + "tickers": ",".join(normalize_tickers(tickers)) if tickers else None, + "series_ticker": normalize_ticker(series_ticker), + "event_ticker": normalize_ticker(event_ticker), + "limit": limit, + "cursor": cursor, + **extra_params, + } + data = self.paginated_get("/markets", "markets", params, fetch_all) + return DataFrameList(Market(self, MarketModel.model_validate(m)) for m in data) + + def get_event( + self, + event_ticker: str, + *, + with_nested_markets: bool = False, + ) -> Event: + params = {} + if with_nested_markets: + params["with_nested_markets"] = "true" + endpoint = f"/events/{event_ticker.upper()}" + if params: + endpoint += "?" + urlencode(params) + response = self.get(endpoint) + model = EventModel.model_validate(response["event"]) + return Event(self, model) + + def get_events( + self, + *, + series_ticker: str | None = None, + status: MarketStatus | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[Event]: + params = { + "limit": limit, + "series_ticker": normalize_ticker(series_ticker), + "status": status.value if status is not None else None, + "cursor": cursor, + **extra_params, + } + data = self.paginated_get("/events", "events", params, fetch_all) + return DataFrameList(Event(self, EventModel.model_validate(e)) for e in data) + + def get_series( + self, + series_ticker: str, + *, + include_volume: bool = False, + ) -> Series: + params = {} + if include_volume: + params["include_volume"] = "true" + endpoint = f"/series/{series_ticker.upper()}" + if params: + endpoint += "?" + urlencode(params) + response = self.get(endpoint) + model = SeriesModel.model_validate(response["series"]) + return Series(self, model) + + def get_all_series( + self, + *, + category: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[Series]: + params = {"limit": limit, "category": category, "cursor": cursor, **extra_params} + data = self.paginated_get("/series", "series", params, fetch_all) + return DataFrameList(Series(self, SeriesModel.model_validate(s)) for s in data) + + def get_mve_collection(self, collection_ticker: str) -> MveCollection: + response = self.get(f"/multivariate_event_collections/{collection_ticker}") + model = MveCollectionModel.model_validate(response.get("multivariate_contract", response)) + return MveCollection(self, model) + + def get_mve_collections( + self, + *, + status: str | None = None, + associated_event_ticker: str | None = None, + series_ticker: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + ) -> DataFrameList[MveCollection]: + params = { + "limit": limit, + "status": status, + "associated_event_ticker": normalize_ticker(associated_event_ticker), + "series_ticker": normalize_ticker(series_ticker), + "cursor": cursor, + } + data = self.paginated_get( + "/multivariate_event_collections", "multivariate_contracts", params, fetch_all + ) + return DataFrameList( + MveCollection(self, MveCollectionModel.model_validate(c)) for c in data + ) + + def get_multivariate_events( + self, + *, + series_ticker: str | None = None, + collection_ticker: str | None = None, + with_nested_markets: bool = False, + limit: int = 200, + cursor: str | None = None, + fetch_all: bool = False, + ) -> DataFrameList[Event]: + params: dict = {"limit": limit} + if series_ticker: + params["series_ticker"] = normalize_ticker(series_ticker) + if collection_ticker: + params["collection_ticker"] = collection_ticker + if with_nested_markets: + params["with_nested_markets"] = "true" + if cursor: + params["cursor"] = cursor + + data = self.paginated_get("/events/multivariate", "events", params, fetch_all) + return DataFrameList(Event(self, EventModel.model_validate(e)) for e in data) + + def get_trades( + self, + *, + ticker: str | None = None, + min_ts: int | None = None, + max_ts: int | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[TradeModel]: + params = { + "limit": limit, + "ticker": normalize_ticker(ticker), + "min_ts": min_ts, + "max_ts": max_ts, + "cursor": cursor, + **extra_params, + } + data = self.paginated_get("/markets/trades", "trades", params, fetch_all) + return DataFrameList(TradeModel.model_validate(t) for t in data) + + def get_candlesticks_batch( + self, + tickers: list[str], + start_ts: int, + end_ts: int, + period: CandlestickPeriod = CandlestickPeriod.ONE_HOUR, + ) -> dict[str, CandlestickResponse]: + query = urlencode({ + "market_tickers": ",".join(normalize_tickers(tickers)), + "start_ts": start_ts, + "end_ts": end_ts, + "period_interval": period.value, + }) + response = self.get(f"/markets/candlesticks?{query}") + return { + item["market_ticker"]: CandlestickResponse.model_validate(item) + for item in response.get("markets", []) + } diff --git a/pykalshi/_sync/communications.py b/pykalshi/_sync/communications.py new file mode 100644 index 0000000..b2116e3 --- /dev/null +++ b/pykalshi/_sync/communications.py @@ -0,0 +1,183 @@ +# AUTO-GENERATED from pykalshi/_async/communications.py — do not edit manually. +# Re-run: python scripts/generate_sync.py +from __future__ import annotations + +from typing import TYPE_CHECKING +from urllib.parse import urlencode + +from ..models import RfqModel, QuoteModel +from ..dataframe import DataFrameList + +if TYPE_CHECKING: + from .client import KalshiClient + + +class Communications: + """RFQ (Request for Quote) and quote operations for combo trading. + + Multivariate event combos trade via the RFQ system rather than + standard limit orders. The flow is: + + 1. Create an RFQ broadcasting your intent to trade a combo. + 2. Market makers respond with two-sided quotes. + 3. Accept a quote to execute the trade. + + Usage: + # Create an RFQ for a combo market + rfq = client.communications.create_rfq( + market_ticker="KXMVE-...", + contracts_fp="10.00", + ) + + # List active RFQs + rfqs = client.communications.get_rfqs(status="active") + + # Respond to an RFQ as a market maker + quote = client.communications.create_quote( + rfq_id=rfq.rfq_id, + yes_bid="0.45", + no_bid="0.55", + ) + """ + + def __init__(self, client: KalshiClient) -> None: + self._client = client + + def create_rfq( + self, + market_ticker: str, + *, + contracts_fp: str | None = None, + target_cost_dollars: str | None = None, + rest_remainder: bool = False, + ) -> RfqModel: + """Create a Request for Quote. + + Args: + market_ticker: The combo market ticker to request quotes for. + contracts_fp: Number of contracts to trade (fixed-point string, e.g. "10.00"). + target_cost_dollars: Target cost in dollars (e.g. "10.00"). + Use this OR contracts_fp, not both. + rest_remainder: If True, rest any unfilled portion on the orderbook. + """ + + body: dict = { + "market_ticker": market_ticker.upper(), + "rest_remainder": rest_remainder, + } + if contracts_fp is not None: + body["contracts_fp"] = contracts_fp + if target_cost_dollars is not None: + body["target_cost_dollars"] = target_cost_dollars + + response = self._client.post("/communications/rfqs", body) + return RfqModel.model_validate(response.get("rfq", response)) + + def get_rfqs( + self, + *, + market_ticker: str | None = None, + status: str | None = None, + mve_collection_ticker: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + ) -> DataFrameList[RfqModel]: + """List RFQs. + + Args: + market_ticker: Filter by combo market ticker. + status: Filter by RFQ status (e.g., "active", "expired"). + mve_collection_ticker: Filter by collection ticker. + limit: Maximum results per page (default 100). + cursor: Pagination cursor. + fetch_all: If True, automatically fetch all pages. + """ + params: dict = {"limit": limit} + if market_ticker: + params["market_ticker"] = market_ticker.upper() + if status: + params["status"] = status + if mve_collection_ticker: + params["mve_collection_ticker"] = mve_collection_ticker + if cursor: + params["cursor"] = cursor + + data = self._client.paginated_get("/communications/rfqs", "rfqs", params, fetch_all) + return DataFrameList(RfqModel.model_validate(r) for r in data) + + def get_rfq(self, rfq_id: str) -> RfqModel: + """Get a single RFQ by ID.""" + response = self._client.get(f"/communications/rfqs/{rfq_id}") + return RfqModel.model_validate(response.get("rfq", response)) + + def create_quote( + self, + rfq_id: str, + *, + yes_bid: str, + no_bid: str, + rest_remainder: bool = False, + ) -> QuoteModel: + """Create a quote in response to an RFQ. + + Prices are in FixedPointDollars (e.g., "0.45"). + + Args: + rfq_id: ID of the RFQ to respond to. + yes_bid: Your bid price for the YES side (FixedPointDollars). + no_bid: Your bid price for the NO side (FixedPointDollars). + rest_remainder: If True, rest any unfilled portion on the orderbook. + """ + body: dict = { + "rfq_id": rfq_id, + "yes_bid": yes_bid, + "no_bid": no_bid, + "rest_remainder": rest_remainder, + } + + response = self._client.post("/communications/quotes", body) + return QuoteModel.model_validate(response.get("quote", response)) + + def get_quotes( + self, + *, + creator_user_id: str | None = None, + rfq_creator_user_id: str | None = None, + rfq_id: str | None = None, + market_ticker: str | None = None, + status: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + ) -> DataFrameList[QuoteModel]: + """List quotes. + + The API requires at least one of creator_user_id or rfq_creator_user_id. + + Args: + creator_user_id: Filter by quote creator. Required if rfq_creator_user_id not set. + rfq_creator_user_id: Filter by RFQ creator. Required if creator_user_id not set. + rfq_id: Filter by RFQ ID. + market_ticker: Filter by combo market ticker. + status: Filter by quote status. + limit: Maximum results per page (default 100). + cursor: Pagination cursor. + fetch_all: If True, automatically fetch all pages. + """ + params: dict = {"limit": limit} + if creator_user_id: + params["creator_user_id"] = creator_user_id + if rfq_creator_user_id: + params["rfq_creator_user_id"] = rfq_creator_user_id + if rfq_id: + params["rfq_id"] = rfq_id + if market_ticker: + params["market_ticker"] = market_ticker.upper() + if status: + params["status"] = status + if cursor: + params["cursor"] = cursor + + data = self._client.paginated_get("/communications/quotes", "quotes", params, fetch_all) + return DataFrameList(QuoteModel.model_validate(q) for q in data) diff --git a/pykalshi/_sync/events.py b/pykalshi/_sync/events.py new file mode 100644 index 0000000..c827158 --- /dev/null +++ b/pykalshi/_sync/events.py @@ -0,0 +1,101 @@ +# AUTO-GENERATED from pykalshi/_async/events.py — do not edit manually. +# Re-run: python scripts/generate_sync.py +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..models import EventModel, ForecastPercentileHistory +from ..dataframe import DataFrameList + +if TYPE_CHECKING: + from .client import KalshiClient + from .markets import Market, Series + + +class Event: + """Represents a Kalshi Event. + + An event is a container for related markets (e.g., "Will X happen?" with + multiple outcome markets). + + Key fields are exposed as typed properties for IDE support. + All other EventModel fields are accessible via attribute delegation. + """ + + def __init__(self, client: KalshiClient, data: EventModel) -> None: + self._client = client + self.data = data + + # --- Typed properties for core fields --- + + @property + def event_ticker(self) -> str: + return self.data.event_ticker + + @property + def series_ticker(self) -> str: + return self.data.series_ticker + + @property + def title(self) -> str | None: + return self.data.title + + @property + def category(self) -> str | None: + return self.data.category + + @property + def mutually_exclusive(self) -> bool: + return self.data.mutually_exclusive + + # --- Domain logic --- + + def get_markets(self) -> DataFrameList[Market]: + """Get all markets for this event.""" + return self._client.get_markets(event_ticker=self.data.event_ticker) + + def get_series(self) -> Series: + """Get the parent Series for this event.""" + return self._client.get_series(self.series_ticker) + + def get_forecast_percentile_history( + self, + percentiles: list[int] | None = None, + ) -> ForecastPercentileHistory: + """Get historical forecast data at various percentiles. + + Args: + percentiles: List of percentiles to fetch (e.g., [10, 25, 50, 75, 90]). + If None, returns all available percentiles. + + Returns: + ForecastPercentileHistory with percentile -> history mapping. + """ + endpoint = f"/events/{self.event_ticker}/forecast/percentile_history" + if percentiles: + endpoint += f"?percentiles={','.join(str(p) for p in percentiles)}" + response = self._client.get(endpoint) + return ForecastPercentileHistory.model_validate(response) + + def __getattr__(self, name: str): + return getattr(self.data, name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Event): + return NotImplemented + return self.data.event_ticker == other.data.event_ticker + + def __hash__(self) -> int: + return hash(self.data.event_ticker) + + def __repr__(self) -> str: + parts = [f"" + + def _repr_html_(self) -> str: + from .._repr import event_html + return event_html(self) diff --git a/pykalshi/_sync/exchange.py b/pykalshi/_sync/exchange.py new file mode 100644 index 0000000..a7bb789 --- /dev/null +++ b/pykalshi/_sync/exchange.py @@ -0,0 +1,49 @@ +# AUTO-GENERATED from pykalshi/_async/exchange.py — do not edit manually. +# Re-run: python scripts/generate_sync.py +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ..models import ExchangeStatus, Announcement +from ..exceptions import KalshiAPIError + +if TYPE_CHECKING: + from .client import KalshiClient + + +class Exchange: + """Exchange status, schedule, and announcements.""" + + def __init__(self, client: KalshiClient) -> None: + self._client = client + + def get_status(self) -> ExchangeStatus: + """Get current exchange operational status.""" + try: + data = self._client.get("/exchange/status") + except KalshiAPIError as e: + if e.status_code == 503 and isinstance(e.response_body, dict): + data = e.response_body + else: + raise + return ExchangeStatus.model_validate(data) + + def is_trading(self) -> bool: + """Quick check if trading is currently active.""" + status = self.get_status() + return status.trading_active + + def get_schedule(self) -> dict[str, Any]: + """Get exchange trading schedule (raw format).""" + data = self._client.get("/exchange/schedule") + return data.get("schedule", {}) + + def get_announcements(self) -> list[Announcement]: + """Get exchange-wide announcements.""" + data = self._client.get("/exchange/announcements") + return [Announcement.model_validate(a) for a in data.get("announcements", [])] + + def get_user_data_timestamp(self) -> int: + """Get timestamp of last user data validation (Unix ms).""" + data = self._client.get("/exchange/user_data_timestamp") + return data.get("user_data_timestamp", 0) diff --git a/pykalshi/_sync/markets.py b/pykalshi/_sync/markets.py new file mode 100644 index 0000000..9f8b82a --- /dev/null +++ b/pykalshi/_sync/markets.py @@ -0,0 +1,238 @@ +# AUTO-GENERATED from pykalshi/_async/markets.py — do not edit manually. +# Re-run: python scripts/generate_sync.py +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from ..models import MarketModel, CandlestickResponse, OrderbookResponse, SeriesModel, TradeModel +from ..dataframe import DataFrameList +from ..enums import CandlestickPeriod, MarketStatus + +if TYPE_CHECKING: + from .client import KalshiClient + from .events import Event + +logger = logging.getLogger(__name__) + + +class Market: + """Represents a Kalshi Market. + + Key fields are exposed as typed properties for IDE support. + All other MarketModel fields are accessible via attribute delegation. + """ + + def __init__(self, client: KalshiClient, data: MarketModel) -> None: + self._client = client + self.data = data + + # --- Typed properties for core fields --- + + @property + def ticker(self) -> str: + return self.data.ticker + + @property + def event_ticker(self) -> str | None: + return self.data.event_ticker + + @property + def status(self) -> MarketStatus | None: + return self.data.status + + @property + def title(self) -> str | None: + return self.data.title + + @property + def subtitle(self) -> str | None: + return self.data.subtitle + + @property + def yes_bid_dollars(self) -> str | None: + return self.data.yes_bid_dollars + + @property + def yes_ask_dollars(self) -> str | None: + return self.data.yes_ask_dollars + + @property + def no_bid_dollars(self) -> str | None: + return self.data.no_bid_dollars + + @property + def no_ask_dollars(self) -> str | None: + return self.data.no_ask_dollars + + @property + def last_price_dollars(self) -> str | None: + return self.data.last_price_dollars + + @property + def volume_fp(self) -> str | None: + return self.data.volume_fp + + @property + def volume_24h_fp(self) -> str | None: + return self.data.volume_24h_fp + + @property + def open_interest_fp(self) -> str | None: + return self.data.open_interest_fp + + @property + def liquidity_dollars(self) -> str | None: + return self.data.liquidity_dollars + + @property + def open_time(self) -> str | None: + return self.data.open_time + + @property + def close_time(self) -> str | None: + return self.data.close_time + + @property + def result(self) -> str | None: + return self.data.result + + @property + def series_ticker(self) -> str | None: + return self.data.series_ticker + + def resolve_series_ticker(self) -> str | None: + """Fetch series_ticker from the event API if not present in market data.""" + if self.data.series_ticker is not None: + return self.data.series_ticker + if not self.data.event_ticker: + return None + try: + event_response = self._client.get(f"/events/{self.data.event_ticker}") + return event_response["event"]["series_ticker"] + except Exception as e: + logger.warning( + "Failed to resolve series_ticker for %s: %s", self.data.ticker, e + ) + return None + + def get_orderbook(self, *, depth: int | None = None) -> OrderbookResponse: + """Get the orderbook for this market.""" + endpoint = f"/markets/{self.data.ticker}/orderbook" + if depth: + endpoint += f"?depth={depth}" + response = self._client.get(endpoint) + return OrderbookResponse.model_validate(response) + + def get_candlesticks( + self, + start_ts: int, + end_ts: int, + period: CandlestickPeriod = CandlestickPeriod.ONE_HOUR, + ) -> CandlestickResponse: + """Get candlestick data for this market.""" + series = self.resolve_series_ticker() + if not series: + raise ValueError(f"Market {self.data.ticker} does not have a series_ticker.") + + query = f"start_ts={start_ts}&end_ts={end_ts}&period_interval={period.value}" + endpoint = f"/series/{series}/markets/{self.data.ticker}/candlesticks?{query}" + response = self._client.get(endpoint) + return CandlestickResponse.model_validate(response) + + def get_trades( + self, + min_ts: int | None = None, + max_ts: int | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + ) -> DataFrameList[TradeModel]: + """Get public trade history for this market.""" + return self._client.get_trades( + ticker=self.ticker, + min_ts=min_ts, + max_ts=max_ts, + limit=limit, + cursor=cursor, + fetch_all=fetch_all, + ) + + def get_event(self) -> Event | None: + """Get the parent Event for this market.""" + if not self.event_ticker: + return None + return self._client.get_event(self.event_ticker) + + def __getattr__(self, name: str): + return getattr(self.data, name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Market): + return NotImplemented + return self.data.ticker == other.data.ticker + + def __hash__(self) -> int: + return hash(self.data.ticker) + + def __repr__(self) -> str: + status = self.status.value if self.status else "?" + parts = [f"" + + def _repr_html_(self) -> str: + from .._repr import market_html + return market_html(self) + + +class Series: + """Represents a Kalshi Series (collection of related markets).""" + + def __init__(self, client: KalshiClient, data: SeriesModel) -> None: + self._client = client + self.data = data + + @property + def ticker(self) -> str: + return self.data.ticker + + @property + def title(self) -> str | None: + return self.data.title + + @property + def category(self) -> str | None: + return self.data.category + + def get_markets(self, **kwargs) -> DataFrameList[Market]: + """Get all markets in this series.""" + return self._client.get_markets(series_ticker=self.ticker, **kwargs) + + def get_events(self, **kwargs) -> DataFrameList[Event]: + """Get all events in this series.""" + return self._client.get_events(series_ticker=self.ticker, **kwargs) + + def __getattr__(self, name: str): + return getattr(self.data, name) + + def __repr__(self) -> str: + parts = [f"" + + def _repr_html_(self) -> str: + from .._repr import series_html + return series_html(self) diff --git a/pykalshi/_sync/mve.py b/pykalshi/_sync/mve.py new file mode 100644 index 0000000..f1dcf4c --- /dev/null +++ b/pykalshi/_sync/mve.py @@ -0,0 +1,127 @@ +# AUTO-GENERATED from pykalshi/_async/mve.py — do not edit manually. +# Re-run: python scripts/generate_sync.py +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..models import MveCollectionModel, MveSelectedLeg, EventModel, MarketModel +from ..dataframe import DataFrameList +from ..enums import Side + +if TYPE_CHECKING: + from .client import KalshiClient + from .events import Event + from .markets import Market + + +class MveCollection: + """Represents a multivariate event collection (combo container). + + Collections define which events can be combined into combo markets. + Use create_market() to create a tradeable combo, then trade it via + client.communications (RFQ system). + """ + + def __init__(self, client: KalshiClient, data: MveCollectionModel) -> None: + self._client = client + self.data = data + + @property + def collection_ticker(self) -> str: + return self.data.collection_ticker + + @property + def title(self) -> str | None: + return self.data.title + + @property + def series_ticker(self) -> str | None: + return self.data.series_ticker + + def create_market( + self, + selected_markets: list[dict[str, str]], + ) -> Market: + """Create a combo market in this collection. + + Must be called before trading or looking up a combo. Each entry + specifies a leg of the combo. + + Args: + selected_markets: List of leg dicts, each with keys: + - market_ticker: The market ticker for this leg. + - event_ticker: The event ticker for this leg. + - side: "yes" or "no". + + Returns: + The created combo Market. + + Example: + market = collection.create_market([ + {"market_ticker": "KXABC-A", "event_ticker": "KXABC", "side": "yes"}, + {"market_ticker": "KXDEF-B", "event_ticker": "KXDEF", "side": "yes"}, + ]) + """ + from .markets import Market + + body = {"selected_markets": selected_markets, "with_market_payload": True} + response = self._client.post( + f"/multivariate_event_collections/{self.collection_ticker}", body + ) + model = MarketModel.model_validate(response.get("market", response)) + return Market(self._client, model) + + def lookup_ticker( + self, + selected_markets: list[dict[str, str]], + ) -> dict: + """Look up tickers for a combo market by its leg combination. + + Returns 404 if the combination hasn't been previously created + via create_market(). + + Args: + selected_markets: List of leg dicts (same format as create_market). + + Returns: + Dict with market_ticker and event_ticker for the combo. + """ + body = {"selected_markets": selected_markets} + return self._client.put( + f"/multivariate_event_collections/{self.collection_ticker}/lookup", body + ) + + def get_events(self, *, with_nested_markets: bool = False) -> DataFrameList[Event]: + """Get multivariate events in this collection. + + Args: + with_nested_markets: If True, include markets nested in each event. + """ + return self._client.get_multivariate_events( + collection_ticker=self.collection_ticker, + with_nested_markets=with_nested_markets, + ) + + def __getattr__(self, name: str): + return getattr(self.data, name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MveCollection): + return NotImplemented + return self.data.collection_ticker == other.data.collection_ticker + + def __hash__(self) -> int: + return hash(self.data.collection_ticker) + + def __repr__(self) -> str: + parts = [f"" + + def _repr_html_(self) -> str: + from .._repr import mve_collection_html + return mve_collection_html(self) diff --git a/pykalshi/_sync/orders.py b/pykalshi/_sync/orders.py new file mode 100644 index 0000000..8937548 --- /dev/null +++ b/pykalshi/_sync/orders.py @@ -0,0 +1,190 @@ +# AUTO-GENERATED from pykalshi/_async/orders.py — do not edit manually. +# Re-run: python scripts/generate_sync.py +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +from ..models import OrderModel +from ..enums import OrderStatus, Action, Side, OrderType + +if TYPE_CHECKING: + from .client import KalshiClient + +TERMINAL_STATUSES = frozenset({OrderStatus.CANCELED, OrderStatus.EXECUTED}) + + +class Order: + """Represents a Kalshi order. + + Key fields are exposed as typed properties for IDE support. + All other OrderModel fields are accessible via attribute delegation. + """ + + def __init__(self, client: KalshiClient, data: OrderModel) -> None: + self._client = client + self.data = data + + # --- Typed properties for core fields --- + + @property + def order_id(self) -> str: + return self.data.order_id + + @property + def ticker(self) -> str: + return self.data.ticker + + @property + def status(self) -> OrderStatus: + return self.data.status + + @property + def action(self) -> Action | None: + return self.data.action + + @property + def side(self) -> Side | None: + return self.data.side + + @property + def type(self) -> OrderType | None: + return self.data.type + + @property + def yes_price_dollars(self) -> str | None: + return self.data.yes_price_dollars + + @property + def no_price_dollars(self) -> str | None: + return self.data.no_price_dollars + + @property + def initial_count_fp(self) -> str | None: + return self.data.initial_count_fp + + @property + def fill_count_fp(self) -> str | None: + return self.data.fill_count_fp + + @property + def remaining_count_fp(self) -> str | None: + return self.data.remaining_count_fp + + @property + def created_time(self) -> str | None: + return self.data.created_time + + # --- Domain logic --- + + def cancel(self) -> Order: + """Cancel this order. + + Returns: + Self with updated data (status will be CANCELED). + """ + updated = self._client.portfolio.cancel_order(self.order_id) + self.data = updated.data + return self + + def amend( + self, + *, + count_fp: str | None = None, + yes_price_dollars: str | None = None, + no_price_dollars: str | None = None, + ) -> Order: + """Amend this order's price or count. + + Args: + count_fp: New total contract count (fixed-point string). + yes_price_dollars: New YES price (dollar string). + no_price_dollars: New NO price (dollar string, converted to yes internally). + + Returns: + Self with updated data. + """ + updated = self._client.portfolio.amend_order( + self.order_id, + count_fp=count_fp, + yes_price_dollars=yes_price_dollars, + no_price_dollars=no_price_dollars, + ticker=self.ticker, + action=self.action, + side=self.side, + ) + self.data = updated.data + return self + + def decrease(self, reduce_by_fp: str) -> Order: + """Decrease the remaining count of this order. + + Args: + reduce_by_fp: Number of contracts to reduce by (fixed-point string). + + Returns: + Self with updated data. + """ + updated = self._client.portfolio.decrease_order(self.order_id, reduce_by_fp) + self.data = updated.data + return self + + def refresh(self) -> Order: + """Re-fetch this order's current state from the API. + + Returns: + Self with updated data. + """ + updated = self._client.portfolio.get_order(self.order_id) + self.data = updated.data + return self + + def wait_until_terminal( + self, timeout: float = 30.0, poll_interval: float = 0.5 + ) -> Order: + """Block until order reaches a terminal state. + + Terminal states are: CANCELED, EXECUTED. + + Args: + timeout: Maximum seconds to wait before raising TimeoutError. + poll_interval: Seconds between refresh calls. + + Returns: + Self with updated data. + + Raises: + TimeoutError: If timeout is reached before terminal state. + """ + deadline = time.monotonic() + timeout + while self.status not in TERMINAL_STATUSES: + if time.monotonic() >= deadline: + raise TimeoutError( + f"Order {self.order_id} still {self.status.value} after {timeout}s" + ) + time.sleep(poll_interval) + self.refresh() + return self + + def __getattr__(self, name: str): + return getattr(self.data, name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Order): + return NotImplemented + return self.data.order_id == other.data.order_id + + def __hash__(self) -> int: + return hash(self.data.order_id) + + def __repr__(self) -> str: + action = self.action.value.upper() if self.action else "?" + side = self.side.value.upper() if self.side else "?" + price = self.yes_price_dollars if self.yes_price_dollars is not None else self.no_price_dollars + filled = self.fill_count_fp or "0" + total = self.initial_count_fp or "0" + return f"" + + def _repr_html_(self) -> str: + from .._repr import order_html + return order_html(self) diff --git a/pykalshi/_sync/portfolio.py b/pykalshi/_sync/portfolio.py new file mode 100644 index 0000000..c6db05e --- /dev/null +++ b/pykalshi/_sync/portfolio.py @@ -0,0 +1,649 @@ +# AUTO-GENERATED from pykalshi/_async/portfolio.py — do not edit manually. +# Re-run: python scripts/generate_sync.py +from __future__ import annotations + +from decimal import Decimal +from typing import TYPE_CHECKING +from urllib.parse import urlencode + +from .orders import Order +from ..enums import Action, Side, OrderStatus, TimeInForce, SelfTradePrevention, PositionCountFilter +from ..dataframe import DataFrameList +from .._utils import normalize_ticker, normalize_tickers +from ..models import ( + OrderModel, BalanceModel, PositionModel, FillModel, + SettlementModel, QueuePositionModel, OrderGroupModel, + SubaccountModel, SubaccountBalanceModel, SubaccountTransferModel, +) + +if TYPE_CHECKING: + from .client import KalshiClient + from .markets import Market + + +class Portfolio: + """Authenticated user's portfolio and trading operations.""" + + def __init__(self, client: KalshiClient) -> None: + self._client = client + + def get_balance(self) -> BalanceModel: + """Get portfolio balance. Values are dollar strings.""" + data = self._client.get("/portfolio/balance") + return BalanceModel.model_validate(data) + + def place_order( + self, + ticker: str | Market, + action: Action, + side: Side, + count_fp: str, + *, + yes_price_dollars: str | None = None, + no_price_dollars: str | None = None, + client_order_id: str | None = None, + time_in_force: TimeInForce | None = None, + post_only: bool = False, + reduce_only: bool = False, + expiration_ts: int | None = None, + buy_max_cost_dollars: str | None = None, + self_trade_prevention: SelfTradePrevention | None = None, + order_group_id: str | None = None, + subaccount: int | None = None, + cancel_order_on_pause: bool | None = None, + ) -> Order: + """Place an order on a market. + + Args: + ticker: Market ticker string or Market object. + action: BUY or SELL. + side: YES or NO. + count_fp: Number of contracts (fixed-point string, e.g. "10.00"). + yes_price_dollars: Price as dollar string (e.g. "0.45"). + no_price_dollars: Price as dollar string. Converted to + yes_price_dollars internally (yes = 1.00 - no). + client_order_id: Idempotency key. Resubmitting returns existing order. + time_in_force: GTC (default), IOC (immediate-or-cancel), FOK (fill-or-kill). + post_only: If True, reject order if it would take liquidity. + reduce_only: If True, only reduce existing position, never increase. + expiration_ts: Unix timestamp when order auto-cancels. + buy_max_cost_dollars: Maximum total cost (dollar string). Protects against slippage. + self_trade_prevention: Behavior on self-cross (CANCEL_RESTING or CANCEL_INCOMING). + order_group_id: Link to an order group for OCO/bracket strategies. + subaccount: Subaccount number (0 for primary, 1-32 for subaccounts). + cancel_order_on_pause: If True, cancel order if market is paused. + """ + # Extract market structure for validation when a Market object is passed + pls = None + fte = None + if not isinstance(ticker, str): + pls = getattr(ticker, 'price_level_structure', None) + fte = getattr(ticker, 'fractional_trading_enabled', None) + + order_data = self._build_order_data( + ticker, action, side, count_fp, + yes_price_dollars=yes_price_dollars, no_price_dollars=no_price_dollars, + client_order_id=client_order_id, time_in_force=time_in_force, + post_only=post_only, reduce_only=reduce_only, + expiration_ts=expiration_ts, buy_max_cost_dollars=buy_max_cost_dollars, + self_trade_prevention=self_trade_prevention, + order_group_id=order_group_id, subaccount=subaccount, + cancel_order_on_pause=cancel_order_on_pause, + price_level_structure=pls, + fractional_trading_enabled=fte, + ) + response = self._client.post("/portfolio/orders", order_data) + model = OrderModel.model_validate(response["order"]) + return Order(self._client, model) + + def cancel_order(self, order_id: str, *, subaccount: int | None = None) -> Order: + """Cancel a resting order. + + Args: + order_id: ID of the order to cancel. + subaccount: Subaccount number (0 for primary, 1-32 for subaccounts). + + Returns: + The canceled Order with updated status. + """ + endpoint = f"/portfolio/orders/{order_id}" + if subaccount is not None: + endpoint += f"?subaccount={subaccount}" + response = self._client.delete(endpoint) + model = OrderModel.model_validate(response["order"]) + return Order(self._client, model) + + def amend_order( + self, + order_id: str, + *, + count_fp: str | None = None, + yes_price_dollars: str | None = None, + no_price_dollars: str | None = None, + subaccount: int | None = None, + # Required by API but can be fetched from existing order + ticker: str | None = None, + action: Action | None = None, + side: Side | None = None, + ) -> Order: + """Amend a resting order's price or count. + + Args: + order_id: ID of the order to amend. + count_fp: New total contract count (fixed-point string). + yes_price_dollars: New YES price (dollar string). + no_price_dollars: New NO price (dollar string). Converted internally. + subaccount: Subaccount number (0 for primary, 1-32 for subaccounts). + ticker: Market ticker (fetched from order if not provided). + action: Order action (fetched from order if not provided). + side: Order side (fetched from order if not provided). + """ + if yes_price_dollars is not None and no_price_dollars is not None: + raise ValueError("Specify yes_price_dollars or no_price_dollars, not both") + + if no_price_dollars is not None: + yes_price_dollars = str(Decimal("1") - Decimal(no_price_dollars)) + + ticker = normalize_ticker(ticker) + + # Fetch original order to get required fields if not provided + if ticker is None or action is None or side is None or count_fp is None: + original = self.get_order(order_id) + ticker = ticker or original.ticker + action = action or original.action + side = side or original.side + if count_fp is None: + count_fp = original.remaining_count_fp + + body: dict = { + "ticker": ticker, + "action": action.value if isinstance(action, Action) else action, + "side": side.value if isinstance(side, Side) else side, + "count_fp": count_fp, + } + if yes_price_dollars is not None: + body["yes_price_dollars"] = yes_price_dollars + if subaccount is not None: + body["subaccount"] = subaccount + + if "count_fp" not in body and "yes_price_dollars" not in body: + raise ValueError("Must specify at least one of count_fp, yes_price_dollars, or no_price_dollars") + + response = self._client.post(f"/portfolio/orders/{order_id}/amend", body) + model = OrderModel.model_validate(response["order"]) + return Order(self._client, model) + + def decrease_order(self, order_id: str, reduce_by_fp: str) -> Order: + """Decrease the remaining count of a resting order. + + Args: + order_id: ID of the order to decrease. + reduce_by_fp: Number of contracts to reduce by (fixed-point string). + """ + response = self._client.post( + f"/portfolio/orders/{order_id}/decrease", {"reduce_by_fp": reduce_by_fp} + ) + model = OrderModel.model_validate(response["order"]) + return Order(self._client, model) + + def get_orders( + self, + *, + status: OrderStatus | None = None, + ticker: str | None = None, + event_ticker: str | None = None, + min_ts: int | None = None, + max_ts: int | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[Order]: + """Get list of orders. + + Args: + status: Filter by order status (resting, canceled, executed). + ticker: Filter by market ticker. + event_ticker: Filter by event ticker (supports comma-separated, max 10). + min_ts: Filter orders after this Unix timestamp. + max_ts: Filter orders before this Unix timestamp. + limit: Maximum results per page (default 100, max 200). + cursor: Pagination cursor for fetching next page. + fetch_all: If True, automatically fetch all pages. + **extra_params: Additional API parameters (e.g., subaccount). + """ + params = { + "limit": limit, + "status": status.value if status is not None else None, + "ticker": normalize_ticker(ticker), + "event_ticker": normalize_ticker(event_ticker), + "min_ts": min_ts, + "max_ts": max_ts, + "cursor": cursor, + **extra_params, + } + data = self._client.paginated_get("/portfolio/orders", "orders", params, fetch_all) + return DataFrameList(Order(self._client, OrderModel.model_validate(d)) for d in data) + + def get_order(self, order_id: str) -> Order: + """Get a single order by ID.""" + response = self._client.get(f"/portfolio/orders/{order_id}") + model = OrderModel.model_validate(response["order"]) + return Order(self._client, model) + + def get_positions( + self, + *, + ticker: str | None = None, + event_ticker: str | None = None, + count_filter: PositionCountFilter | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[PositionModel]: + """Get portfolio positions. + + Args: + ticker: Filter by specific market ticker. + event_ticker: Filter by event ticker (supports comma-separated, max 10). + count_filter: Filter positions with non-zero values (POSITION or TOTAL_TRADED). + limit: Maximum positions per page (default 100, max 1000). + cursor: Pagination cursor for fetching next page. + fetch_all: If True, automatically fetch all pages. + **extra_params: Additional API parameters (e.g., subaccount). + """ + params = { + "limit": limit, + "ticker": normalize_ticker(ticker), + "event_ticker": normalize_ticker(event_ticker), + "count_filter": count_filter.value if count_filter is not None else None, + "cursor": cursor, + **extra_params, + } + data = self._client.paginated_get("/portfolio/positions", "market_positions", params, fetch_all) + return DataFrameList(PositionModel.model_validate(p) for p in data) + + def get_fills( + self, + *, + ticker: str | None = None, + order_id: str | None = None, + min_ts: int | None = None, + max_ts: int | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[FillModel]: + """Get trade fills (executed trades). + + Args: + ticker: Filter by market ticker. + order_id: Filter by specific order ID. + min_ts: Minimum timestamp (Unix seconds). + max_ts: Maximum timestamp (Unix seconds). + limit: Maximum fills per page (default 100, max 200). + cursor: Pagination cursor for fetching next page. + fetch_all: If True, automatically fetch all pages. + **extra_params: Additional API parameters (e.g., subaccount). + """ + params = { + "limit": limit, + "ticker": normalize_ticker(ticker), + "order_id": order_id, + "min_ts": min_ts, + "max_ts": max_ts, + "cursor": cursor, + **extra_params, + } + data = self._client.paginated_get("/portfolio/fills", "fills", params, fetch_all) + return DataFrameList(FillModel.model_validate(f) for f in data) + + # --- Batch Operations --- + + def batch_place_orders(self, orders: list[dict]) -> DataFrameList[Order]: + """Place multiple orders atomically. + + Args: + orders: List of order dicts with keys: ticker, action, side, count_fp, + yes_price_dollars/no_price_dollars, and optional advanced params. + + Example: + orders = [ + {"ticker": "KXBTC", "action": "buy", "side": "yes", "count_fp": "10.00", "yes_price_dollars": "0.45"}, + {"ticker": "KXBTC", "action": "buy", "side": "no", "count_fp": "10.00", "no_price_dollars": "0.45"}, + ] + results = portfolio.batch_place_orders(orders) + """ + prepared = self._build_batch_orders(orders) + response = self._client.post("/portfolio/orders/batched", {"orders": prepared}) + result = [] + for item in response.get("orders", []): + order_data = item.get("order") + if order_data is None: + continue + result.append(Order(self._client, OrderModel.model_validate(order_data))) + return DataFrameList(result) + + def batch_cancel_orders(self, order_ids: list[str]) -> DataFrameList[Order]: + """Cancel multiple orders atomically. + + Args: + order_ids: List of order IDs to cancel (max 20). + + Returns: + The canceled Orders with updated status. + """ + orders = [{"order_id": oid} for oid in order_ids] + response = self._client.delete("/portfolio/orders/batched", {"orders": orders}) + result = [] + for item in response.get("orders", []): + order_data = item.get("order") + if order_data is None: + continue + result.append(Order(self._client, OrderModel.model_validate(order_data))) + return DataFrameList(result) + + # --- Queue Position --- + + def get_queue_position(self, order_id: str) -> QueuePositionModel: + """Get queue position for a single resting order.""" + response = self._client.get(f"/portfolio/orders/{order_id}/queue_position") + return QueuePositionModel( + order_id=order_id, + queue_position_fp=response.get("queue_position_fp", "0.00"), + ) + + def get_queue_positions( + self, + *, + market_tickers: list[str] | None = None, + event_ticker: str | None = None, + ) -> DataFrameList[QueuePositionModel]: + """Get queue positions for all resting orders.""" + params: dict = {} + if market_tickers: + params["market_tickers"] = ",".join(normalize_tickers(market_tickers)) + if event_ticker: + params["event_ticker"] = normalize_ticker(event_ticker) + + endpoint = "/portfolio/orders/queue_positions" + if params: + endpoint = f"{endpoint}?{urlencode(params)}" + + response = self._client.get(endpoint) + return DataFrameList( + QueuePositionModel.model_validate(qp) + for qp in response.get("queue_positions", []) + ) + + # --- Settlements --- + + def get_settlements( + self, + *, + ticker: str | None = None, + event_ticker: str | None = None, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[SettlementModel]: + """Get settlement records for resolved positions.""" + params = { + "limit": limit, + "ticker": normalize_ticker(ticker), + "event_ticker": normalize_ticker(event_ticker), + "cursor": cursor, + **extra_params, + } + data = self._client.paginated_get("/portfolio/settlements", "settlements", params, fetch_all) + return DataFrameList(SettlementModel.model_validate(s) for s in data) + + def get_resting_order_value(self) -> str: + """Get total value of all resting orders as dollar string. + + NOTE: This endpoint is FCM-only (institutional accounts). + """ + response = self._client.get("/portfolio/summary/total_resting_order_value") + return response.get("total_resting_order_value_dollars", "0") + + # --- Order Groups (Contract Rate Limiting) --- + + def create_order_group(self, contracts_limit_fp: str) -> OrderGroupModel: + """Create an order group for rate-limiting contract matches. + + Args: + contracts_limit_fp: Maximum contracts (fixed-point string) that can be + matched in a rolling 15-second window. + + Returns: + Created OrderGroupModel. + """ + body: dict = {"contracts_limit_fp": contracts_limit_fp} + response = self._client.post("/portfolio/order_groups/create", body) + return OrderGroupModel.model_validate(response) + + def get_order_group(self, order_group_id: str) -> OrderGroupModel: + """Get an order group by ID.""" + response = self._client.get(f"/portfolio/order_groups/{order_group_id}") + response["id"] = order_group_id + return OrderGroupModel.model_validate(response) + + def trigger_order_group(self, order_group_id: str) -> None: + """Manually trigger an order group, cancelling all orders in it.""" + self._client.put(f"/portfolio/order_groups/{order_group_id}/trigger", {}) + + def get_order_groups(self) -> DataFrameList[OrderGroupModel]: + """List all order groups.""" + response = self._client.get("/portfolio/order_groups") + return DataFrameList( + OrderGroupModel.model_validate(og) + for og in response.get("order_groups", []) + ) + + def reset_order_group(self, order_group_id: str) -> None: + """Reset matched contract counter for an order group.""" + self._client.put(f"/portfolio/order_groups/{order_group_id}/reset", {}) + + def update_order_group_limit(self, order_group_id: str, contracts_limit_fp: str) -> None: + """Update the contracts limit for an order group. + + Args: + order_group_id: ID of the order group. + contracts_limit_fp: New maximum contracts (fixed-point string). + """ + body: dict = {"contracts_limit_fp": contracts_limit_fp} + self._client.put(f"/portfolio/order_groups/{order_group_id}/limit", body) + + # --- Subaccounts --- + + def create_subaccount(self) -> SubaccountModel: + """Create a new numbered subaccount.""" + response = self._client.post("/portfolio/subaccounts", {}) + return SubaccountModel.model_validate(response.get("subaccount", response)) + + def transfer_between_subaccounts( + self, + from_subaccount_id: str, + to_subaccount_id: str, + amount_dollars: str, + ) -> SubaccountTransferModel: + """Transfer funds between subaccounts. + + Args: + from_subaccount_id: Source subaccount ID. + to_subaccount_id: Destination subaccount ID. + amount_dollars: Amount to transfer (dollar string). + """ + body = { + "from_subaccount_id": from_subaccount_id, + "to_subaccount_id": to_subaccount_id, + "amount_dollars": amount_dollars, + } + response = self._client.post("/portfolio/subaccounts/transfer", body) + return SubaccountTransferModel.model_validate(response.get("transfer", response)) + + def get_subaccount_balances(self) -> DataFrameList[SubaccountBalanceModel]: + """Get balances for all subaccounts.""" + response = self._client.get("/portfolio/subaccounts/balances") + return DataFrameList( + SubaccountBalanceModel.model_validate(b) + for b in response.get("balances", []) + ) + + def get_subaccount_transfers( + self, + *, + limit: int = 100, + cursor: str | None = None, + fetch_all: bool = False, + **extra_params, + ) -> DataFrameList[SubaccountTransferModel]: + """Get transfer history between subaccounts.""" + params = {"limit": limit, "cursor": cursor, **extra_params} + data = self._client.paginated_get( + "/portfolio/subaccounts/transfers", "transfers", params, fetch_all + ) + return DataFrameList(SubaccountTransferModel.model_validate(t) for t in data) + + # --- Shared validation helpers --- + + @staticmethod + def _validate_tick_size(price: Decimal, price_level_structure: str) -> None: + """Validate that price aligns to the market's tick size. + + Raises ValueError if the price is not on a valid tick boundary. + """ + if price_level_structure == "linear_cent": + # $0.00-$1.00, tick $0.01 + tick = Decimal("0.01") + if price % tick != 0: + raise ValueError( + f"Price {price} is not on a valid tick for linear_cent " + f"(tick size $0.01)" + ) + elif price_level_structure == "deci_cent": + # $0.00-$1.00, tick $0.001 + tick = Decimal("0.001") + if price % tick != 0: + raise ValueError( + f"Price {price} is not on a valid tick for deci_cent " + f"(tick size $0.001)" + ) + elif price_level_structure == "tapered_deci_cent": + # $0.00-$0.10: tick $0.001, $0.10-$0.90: tick $0.01, $0.90-$1.00: tick $0.001 + if price <= Decimal("0.10") or price >= Decimal("0.90"): + tick = Decimal("0.001") + else: + tick = Decimal("0.01") + if price % tick != 0: + raise ValueError( + f"Price {price} is not on a valid tick for tapered_deci_cent " + f"(tick size ${tick} in this price range)" + ) + + @staticmethod + def _validate_fractional(count_fp: str, fractional_enabled: bool) -> None: + """Validate count_fp is whole when fractional trading is disabled.""" + if not fractional_enabled: + d = Decimal(count_fp) + if d != int(d): + raise ValueError( + f"Fractional trading is not enabled for this market. " + f"count_fp must be a whole number, got {count_fp}" + ) + + @staticmethod + def _build_order_data( + ticker, + action: Action, + side: Side, + count_fp: str, + *, + yes_price_dollars=None, + no_price_dollars=None, + client_order_id=None, + time_in_force=None, + post_only=False, + reduce_only=False, + expiration_ts=None, + buy_max_cost_dollars=None, + self_trade_prevention=None, + order_group_id=None, + subaccount=None, + cancel_order_on_pause=None, + price_level_structure=None, + fractional_trading_enabled=None, + ) -> dict: + """Build and validate order data dict. No I/O. + + If price_level_structure is provided, validates tick size alignment. + If fractional_trading_enabled is provided (False), validates count_fp is whole. + """ + if yes_price_dollars is not None and no_price_dollars is not None: + raise ValueError("Specify yes_price_dollars or no_price_dollars, not both") + + if yes_price_dollars is None and no_price_dollars is None: + raise ValueError("Limit orders require yes_price_dollars or no_price_dollars") + + if no_price_dollars is not None: + yes_price_dollars = str(Decimal("1") - Decimal(no_price_dollars)) + + # Validate tick size if market structure is known + if price_level_structure and yes_price_dollars is not None: + Portfolio._validate_tick_size(Decimal(yes_price_dollars), price_level_structure) + + # Validate fractional trading + if fractional_trading_enabled is not None: + Portfolio._validate_fractional(count_fp, fractional_trading_enabled) + + ticker_str = ticker.upper() if isinstance(ticker, str) else ticker.ticker + + order_data: dict = { + "ticker": ticker_str, + "action": action.value, + "side": side.value, + "count_fp": count_fp, + "yes_price_dollars": yes_price_dollars, + } + if client_order_id is not None: + order_data["client_order_id"] = client_order_id + if time_in_force is not None: + order_data["time_in_force"] = time_in_force.value + if post_only: + order_data["post_only"] = True + if reduce_only: + order_data["reduce_only"] = True + if expiration_ts is not None: + order_data["expiration_ts"] = expiration_ts + if buy_max_cost_dollars is not None: + order_data["buy_max_cost_dollars"] = buy_max_cost_dollars + if self_trade_prevention is not None: + order_data["self_trade_prevention_type"] = self_trade_prevention.value + if order_group_id is not None: + order_data["order_group_id"] = order_group_id + if subaccount is not None: + order_data["subaccount"] = subaccount + if cancel_order_on_pause is not None: + order_data["cancel_order_on_pause"] = cancel_order_on_pause + return order_data + + @staticmethod + def _build_batch_orders(orders: list[dict]) -> list[dict]: + """Validate and prepare batch orders. No I/O.""" + prepared = [] + for order in orders: + o = dict(order) + + if "yes_price_dollars" in o and "no_price_dollars" in o: + raise ValueError("Specify yes_price_dollars or no_price_dollars, not both") + if "yes_price_dollars" not in o and "no_price_dollars" not in o: + raise ValueError("Limit orders require yes_price_dollars or no_price_dollars") + if "no_price_dollars" in o: + o["yes_price_dollars"] = str(Decimal("1") - Decimal(o.pop("no_price_dollars"))) + # Strip "type" -- Kalshi API no longer accepts it + o.pop("type", None) + prepared.append(o) + return prepared diff --git a/pykalshi/aclient.py b/pykalshi/aclient.py index 74b52be..2d56ecd 100644 --- a/pykalshi/aclient.py +++ b/pykalshi/aclient.py @@ -1,405 +1,3 @@ -"""Kalshi API Client — async.""" +from ._async.client import AsyncKalshiClient as AsyncKalshiClient -from __future__ import annotations - -import asyncio -import json -import logging -from functools import cached_property -from typing import Any, TYPE_CHECKING -from urllib.parse import urlencode - -import httpx - -from ._base import _BaseKalshiClient, _RETRYABLE_STATUS_CODES -from .events import AsyncEvent -from .markets import AsyncMarket, AsyncSeries -from .mve import AsyncMveCollection -from .models import MarketModel, EventModel, SeriesModel, TradeModel, CandlestickResponse, MveCollectionModel -from .dataframe import DataFrameList -from .portfolio import AsyncPortfolio -from .enums import MarketStatus, CandlestickPeriod -from .exchange import AsyncExchange -from .api_keys import AsyncAPIKeys -from .communications import AsyncCommunications -from .exceptions import RateLimitError -from ._utils import normalize_ticker, normalize_tickers - -if TYPE_CHECKING: - from .afeed import AsyncFeed - from .rate_limiter import AsyncRateLimiterProtocol - -logger = logging.getLogger(__name__) - - -class AsyncKalshiClient(_BaseKalshiClient): - """Async authenticated client for the Kalshi Trading API. - - Usage: - async with AsyncKalshiClient.from_env() as client: - market = await client.get_market("TICKER") - balance = await client.portfolio.get_balance() - """ - - def __init__( - self, - api_key_id: str | None = None, - private_key_path: str | None = None, - api_base: str | None = None, - demo: bool = False, - timeout: float = 10.0, - max_retries: int = 3, - rate_limiter: AsyncRateLimiterProtocol | None = None, - ) -> None: - super().__init__( - api_key_id=api_key_id, - private_key_path=private_key_path, - api_base=api_base, - demo=demo, - timeout=timeout, - max_retries=max_retries, - rate_limiter=rate_limiter, - ) - self._session = httpx.AsyncClient() - - async def aclose(self) -> None: - """Close the underlying HTTP connection pool.""" - await self._session.aclose() - - async def __aenter__(self) -> AsyncKalshiClient: - return self - - async def __aexit__(self, *args: Any) -> None: - await self.aclose() - - # --- HTTP methods --- - - async def _request(self, method: str, endpoint: str, **kwargs: Any) -> httpx.Response: - """Execute async HTTP request with retry on transient failures.""" - url = f"{self.api_base}{endpoint}" - - for attempt in range(self.max_retries + 1): - if self.rate_limiter is not None: - wait_time = await self.rate_limiter.acquire() - if wait_time > 0: - logger.debug("Rate limiter waited %.3fs", wait_time) - - headers = self._get_headers(method, endpoint) - request_kwargs: dict[str, Any] = {"headers": headers, "timeout": self.timeout} - if "data" in kwargs: - request_kwargs["content"] = kwargs["data"] - try: - response = await self._session.request(method, url, **request_kwargs) - except httpx.TimeoutException as e: - if attempt == self.max_retries: - raise - wait = self._compute_backoff(attempt, None) - logger.warning( - "%s %s failed (%s), retry %d/%d in %.1fs", - method, endpoint, type(e).__name__, - attempt + 1, self.max_retries, wait, - ) - await asyncio.sleep(wait) - continue - except httpx.ConnectError as e: - if attempt == self.max_retries: - raise - wait = self._compute_backoff(attempt, None) - logger.warning( - "%s %s failed (%s), retry %d/%d in %.1fs", - method, endpoint, type(e).__name__, - attempt + 1, self.max_retries, wait, - ) - await asyncio.sleep(wait) - continue - - self._update_rate_limiter(response) - - if response.status_code not in _RETRYABLE_STATUS_CODES: - return response - if attempt == self.max_retries: - if response.status_code == 429: - raise RateLimitError( - 429, "Rate limit exceeded after retries", - method=method, endpoint=endpoint, - ) - return response - - wait = self._compute_backoff(attempt, response.headers.get("Retry-After")) - logger.warning( - "%s %s returned %d, retry %d/%d in %.1fs", - method, endpoint, response.status_code, - attempt + 1, self.max_retries, wait, - ) - await asyncio.sleep(wait) - - return response # unreachable, satisfies type checker - - async def get(self, endpoint: str) -> dict[str, Any]: - """Make authenticated GET request.""" - logger.debug("GET %s", endpoint) - response = await self._request("GET", endpoint) - return self._handle_response(response, method="GET", endpoint=endpoint) - - async def paginated_get( - self, - path: str, - response_key: str, - params: dict[str, Any], - fetch_all: bool = False, - ) -> list[dict]: - """Fetch items with automatic cursor-based pagination.""" - params = dict(params) - all_items: list[dict] = [] - while True: - filtered = {k: v for k, v in params.items() if v is not None} - endpoint = f"{path}?{urlencode(filtered)}" if filtered else path - response = await self.get(endpoint) - all_items.extend(response.get(response_key, [])) - cursor = response.get("cursor", "") - if not fetch_all or not cursor: - break - params["cursor"] = cursor - return all_items - - async def post(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: - """Make authenticated POST request.""" - logger.debug("POST %s", endpoint) - body = json.dumps(data, separators=(",", ":")) - response = await self._request("POST", endpoint, data=body) - return self._handle_response( - response, method="POST", endpoint=endpoint, request_body=data - ) - - async def put(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: - """Make authenticated PUT request.""" - logger.debug("PUT %s", endpoint) - body = json.dumps(data, separators=(",", ":")) - response = await self._request("PUT", endpoint, data=body) - return self._handle_response( - response, method="PUT", endpoint=endpoint, request_body=data - ) - - async def delete(self, endpoint: str, body: dict | None = None) -> dict[str, Any]: - """Make authenticated DELETE request.""" - logger.debug("DELETE %s", endpoint) - if body: - data = json.dumps(body, separators=(",", ":")) - response = await self._request("DELETE", endpoint, data=data) - else: - response = await self._request("DELETE", endpoint) - return self._handle_response(response, method="DELETE", endpoint=endpoint) - - # --- Domain accessors --- - - @cached_property - def portfolio(self) -> AsyncPortfolio: - return AsyncPortfolio(self) - - @cached_property - def exchange(self) -> AsyncExchange: - return AsyncExchange(self) - - @cached_property - def api_keys(self) -> AsyncAPIKeys: - return AsyncAPIKeys(self) - - @cached_property - def communications(self) -> AsyncCommunications: - return AsyncCommunications(self) - - def feed(self) -> AsyncFeed: - """Create a new async real-time data feed.""" - from .afeed import AsyncFeed - return AsyncFeed(self) - - # --- Domain query methods --- - - async def get_market(self, ticker: str) -> AsyncMarket: - response = await self.get(f"/markets/{ticker.upper()}") - model = MarketModel.model_validate(response["market"]) - return AsyncMarket(self, model) - - async def get_markets( - self, - *, - status: MarketStatus | None = None, - mve_filter: str | None = None, - tickers: list[str] | None = None, - series_ticker: str | None = None, - event_ticker: str | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[AsyncMarket]: - params = { - "status": status.value if status is not None else None, - "mve_filter": mve_filter, - "tickers": ",".join(normalize_tickers(tickers)) if tickers else None, - "series_ticker": normalize_ticker(series_ticker), - "event_ticker": normalize_ticker(event_ticker), - "limit": limit, - "cursor": cursor, - **extra_params, - } - data = await self.paginated_get("/markets", "markets", params, fetch_all) - return DataFrameList(AsyncMarket(self, MarketModel.model_validate(m)) for m in data) - - async def get_event( - self, - event_ticker: str, - *, - with_nested_markets: bool = False, - ) -> AsyncEvent: - params = {} - if with_nested_markets: - params["with_nested_markets"] = "true" - endpoint = f"/events/{event_ticker.upper()}" - if params: - endpoint += "?" + urlencode(params) - response = await self.get(endpoint) - model = EventModel.model_validate(response["event"]) - return AsyncEvent(self, model) - - async def get_events( - self, - *, - series_ticker: str | None = None, - status: MarketStatus | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[AsyncEvent]: - params = { - "limit": limit, - "series_ticker": normalize_ticker(series_ticker), - "status": status.value if status is not None else None, - "cursor": cursor, - **extra_params, - } - data = await self.paginated_get("/events", "events", params, fetch_all) - return DataFrameList(AsyncEvent(self, EventModel.model_validate(e)) for e in data) - - async def get_series( - self, - series_ticker: str, - *, - include_volume: bool = False, - ) -> AsyncSeries: - params = {} - if include_volume: - params["include_volume"] = "true" - endpoint = f"/series/{series_ticker.upper()}" - if params: - endpoint += "?" + urlencode(params) - response = await self.get(endpoint) - model = SeriesModel.model_validate(response["series"]) - return AsyncSeries(self, model) - - async def get_all_series( - self, - *, - category: str | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[AsyncSeries]: - params = {"limit": limit, "category": category, "cursor": cursor, **extra_params} - data = await self.paginated_get("/series", "series", params, fetch_all) - return DataFrameList(AsyncSeries(self, SeriesModel.model_validate(s)) for s in data) - - async def get_mve_collection(self, collection_ticker: str) -> AsyncMveCollection: - response = await self.get(f"/multivariate_event_collections/{collection_ticker}") - model = MveCollectionModel.model_validate(response.get("multivariate_contract", response)) - return AsyncMveCollection(self, model) - - async def get_mve_collections( - self, - *, - status: str | None = None, - associated_event_ticker: str | None = None, - series_ticker: str | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - ) -> DataFrameList[AsyncMveCollection]: - params = { - "limit": limit, - "status": status, - "associated_event_ticker": normalize_ticker(associated_event_ticker), - "series_ticker": normalize_ticker(series_ticker), - "cursor": cursor, - } - data = await self.paginated_get( - "/multivariate_event_collections", "multivariate_contracts", params, fetch_all - ) - return DataFrameList( - AsyncMveCollection(self, MveCollectionModel.model_validate(c)) for c in data - ) - - async def get_multivariate_events( - self, - *, - series_ticker: str | None = None, - collection_ticker: str | None = None, - with_nested_markets: bool = False, - limit: int = 200, - cursor: str | None = None, - fetch_all: bool = False, - ) -> DataFrameList[AsyncEvent]: - params: dict = {"limit": limit} - if series_ticker: - params["series_ticker"] = normalize_ticker(series_ticker) - if collection_ticker: - params["collection_ticker"] = collection_ticker - if with_nested_markets: - params["with_nested_markets"] = "true" - if cursor: - params["cursor"] = cursor - - data = await self.paginated_get("/events/multivariate", "events", params, fetch_all) - return DataFrameList(AsyncEvent(self, EventModel.model_validate(e)) for e in data) - - async def get_trades( - self, - *, - ticker: str | None = None, - min_ts: int | None = None, - max_ts: int | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[TradeModel]: - params = { - "limit": limit, - "ticker": normalize_ticker(ticker), - "min_ts": min_ts, - "max_ts": max_ts, - "cursor": cursor, - **extra_params, - } - data = await self.paginated_get("/markets/trades", "trades", params, fetch_all) - return DataFrameList(TradeModel.model_validate(t) for t in data) - - async def get_candlesticks_batch( - self, - tickers: list[str], - start_ts: int, - end_ts: int, - period: CandlestickPeriod = CandlestickPeriod.ONE_HOUR, - ) -> dict[str, CandlestickResponse]: - query = urlencode({ - "market_tickers": ",".join(normalize_tickers(tickers)), - "start_ts": start_ts, - "end_ts": end_ts, - "period_interval": period.value, - }) - response = await self.get(f"/markets/candlesticks?{query}") - return { - item["market_ticker"]: CandlestickResponse.model_validate(item) - for item in response.get("markets", []) - } +__all__ = ["AsyncKalshiClient"] diff --git a/pykalshi/api_keys.py b/pykalshi/api_keys.py index 3f148ee..f125dd9 100644 --- a/pykalshi/api_keys.py +++ b/pykalshi/api_keys.py @@ -1,91 +1,4 @@ -from __future__ import annotations -from typing import TYPE_CHECKING -from .models import APIKey, GeneratedAPIKey, APILimits +from ._sync.api_keys import APIKeys as APIKeys +from ._async.api_keys import AsyncAPIKeys as AsyncAPIKeys -if TYPE_CHECKING: - from .client import KalshiClient - - -class APIKeys: - """API key management and account limits.""" - - def __init__(self, client: KalshiClient) -> None: - self._client = client - - def list(self) -> list[APIKey]: - """List all API keys for this account.""" - data = self._client.get("/api_keys") - return [APIKey.model_validate(k) for k in data.get("api_keys", [])] - - def create(self, public_key: str, name: str | None = None) -> str: - """Create an API key with a provided RSA public key. - - Args: - public_key: PEM-encoded RSA public key. - name: Optional name for the key. - - Returns: - The API key ID string. - """ - body: dict = {"public_key": public_key} - if name: - body["name"] = name - data = self._client.post("/api_keys", body) - return data["api_key_id"] - - def generate(self, name: str | None = None) -> GeneratedAPIKey: - """Generate a new API key pair (Kalshi creates both keys). - - Returns a GeneratedAPIKey with the private_key field populated. - The private key is only returned ONCE - store it securely. - - Args: - name: Optional name for the key. - """ - body: dict = {} - if name: - body["name"] = name - data = self._client.post("/api_keys/generate", body) - return GeneratedAPIKey.model_validate(data) - - def delete(self, key_id: str) -> None: - """Delete an API key. - - Args: - key_id: The API key ID to delete. - """ - self._client.delete(f"/api_keys/{key_id}") - - def get_limits(self) -> APILimits: - """Get API rate limits for this account.""" - data = self._client.get("/account/limits") - return APILimits.model_validate(data) - - -class AsyncAPIKeys(APIKeys): - """Async variant of APIKeys.""" - - async def list(self) -> list[APIKey]: # type: ignore[override] - data = await self._client.get("/api_keys") - return [APIKey.model_validate(k) for k in data.get("api_keys", [])] - - async def create(self, public_key: str, name: str | None = None) -> str: # type: ignore[override] - body: dict = {"public_key": public_key} - if name: - body["name"] = name - data = await self._client.post("/api_keys", body) - return data["api_key_id"] - - async def generate(self, name: str | None = None) -> GeneratedAPIKey: # type: ignore[override] - body: dict = {} - if name: - body["name"] = name - data = await self._client.post("/api_keys/generate", body) - return GeneratedAPIKey.model_validate(data) - - async def delete(self, key_id: str) -> None: # type: ignore[override] - await self._client.delete(f"/api_keys/{key_id}") - - async def get_limits(self) -> APILimits: # type: ignore[override] - data = await self._client.get("/account/limits") - return APILimits.model_validate(data) +__all__ = ["APIKeys", "AsyncAPIKeys"] diff --git a/pykalshi/client.py b/pykalshi/client.py index 9a7c26d..7e6268a 100644 --- a/pykalshi/client.py +++ b/pykalshi/client.py @@ -1,420 +1,3 @@ -"""Kalshi API Client — synchronous.""" +from ._sync.client import KalshiClient as KalshiClient -from __future__ import annotations - -import json -import logging -import time -from functools import cached_property -from typing import Any -from urllib.parse import urlencode - -import httpx - -from ._base import _BaseKalshiClient, _RETRYABLE_STATUS_CODES -from .events import Event -from .markets import Market, Series -from .mve import MveCollection -from .models import MarketModel, EventModel, SeriesModel, TradeModel, CandlestickResponse, MveCollectionModel -from .dataframe import DataFrameList -from .portfolio import Portfolio -from .enums import MarketStatus, CandlestickPeriod -from .feed import Feed -from .exchange import Exchange -from .api_keys import APIKeys -from .communications import Communications -from .rate_limiter import RateLimiterProtocol -from .exceptions import RateLimitError -from ._utils import normalize_ticker, normalize_tickers - -logger = logging.getLogger(__name__) - - -class KalshiClient(_BaseKalshiClient): - """Authenticated client for the Kalshi Trading API. - - Usage: - client = KalshiClient.from_env() # Loads .env file - client = KalshiClient(api_key_id="...", private_key_path="...") - """ - - def __init__( - self, - api_key_id: str | None = None, - private_key_path: str | None = None, - api_base: str | None = None, - demo: bool = False, - timeout: float = 10.0, - max_retries: int = 3, - rate_limiter: RateLimiterProtocol | None = None, - ) -> None: - super().__init__( - api_key_id=api_key_id, - private_key_path=private_key_path, - api_base=api_base, - demo=demo, - timeout=timeout, - max_retries=max_retries, - rate_limiter=rate_limiter, - ) - self._session = httpx.Client() - - def close(self) -> None: - """Close the underlying HTTP connection pool.""" - self._session.close() - - def __enter__(self) -> KalshiClient: - return self - - def __exit__(self, *args: Any) -> None: - self.close() - - # --- HTTP methods --- - - def _request(self, method: str, endpoint: str, **kwargs: Any) -> httpx.Response: - """Execute HTTP request with retry on transient failures.""" - url = f"{self.api_base}{endpoint}" - - for attempt in range(self.max_retries + 1): - if self.rate_limiter is not None: - wait_time = self.rate_limiter.acquire() - if wait_time > 0: - logger.debug("Rate limiter waited %.3fs", wait_time) - - headers = self._get_headers(method, endpoint) - request_kwargs: dict[str, Any] = {"headers": headers, "timeout": self.timeout} - if "data" in kwargs: - request_kwargs["content"] = kwargs["data"] - try: - response = self._session.request(method, url, **request_kwargs) - except httpx.TimeoutException as e: - if attempt == self.max_retries: - raise - wait = self._compute_backoff(attempt, None) - logger.warning( - "%s %s failed (%s), retry %d/%d in %.1fs", - method, endpoint, type(e).__name__, - attempt + 1, self.max_retries, wait, - ) - time.sleep(wait) - continue - except httpx.ConnectError as e: - if attempt == self.max_retries: - raise - wait = self._compute_backoff(attempt, None) - logger.warning( - "%s %s failed (%s), retry %d/%d in %.1fs", - method, endpoint, type(e).__name__, - attempt + 1, self.max_retries, wait, - ) - time.sleep(wait) - continue - - self._update_rate_limiter(response) - - if response.status_code not in _RETRYABLE_STATUS_CODES: - return response - if attempt == self.max_retries: - if response.status_code == 429: - raise RateLimitError( - 429, "Rate limit exceeded after retries", - method=method, endpoint=endpoint, - ) - return response - - wait = self._compute_backoff(attempt, response.headers.get("Retry-After")) - logger.warning( - "%s %s returned %d, retry %d/%d in %.1fs", - method, endpoint, response.status_code, - attempt + 1, self.max_retries, wait, - ) - time.sleep(wait) - - return response # unreachable, satisfies type checker - - def get(self, endpoint: str) -> dict[str, Any]: - """Make authenticated GET request.""" - logger.debug("GET %s", endpoint) - response = self._request("GET", endpoint) - return self._handle_response(response, method="GET", endpoint=endpoint) - - def paginated_get( - self, - path: str, - response_key: str, - params: dict[str, Any], - fetch_all: bool = False, - ) -> list[dict]: - """Fetch items with automatic cursor-based pagination.""" - params = dict(params) - all_items: list[dict] = [] - while True: - filtered = {k: v for k, v in params.items() if v is not None} - endpoint = f"{path}?{urlencode(filtered)}" if filtered else path - response = self.get(endpoint) - all_items.extend(response.get(response_key, [])) - cursor = response.get("cursor", "") - if not fetch_all or not cursor: - break - params["cursor"] = cursor - return all_items - - def post(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: - """Make authenticated POST request.""" - logger.debug("POST %s", endpoint) - body = json.dumps(data, separators=(",", ":")) - response = self._request("POST", endpoint, data=body) - return self._handle_response( - response, method="POST", endpoint=endpoint, request_body=data - ) - - def put(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: - """Make authenticated PUT request.""" - logger.debug("PUT %s", endpoint) - body = json.dumps(data, separators=(",", ":")) - response = self._request("PUT", endpoint, data=body) - return self._handle_response( - response, method="PUT", endpoint=endpoint, request_body=data - ) - - def delete(self, endpoint: str, body: dict | None = None) -> dict[str, Any]: - """Make authenticated DELETE request.""" - logger.debug("DELETE %s", endpoint) - if body: - data = json.dumps(body, separators=(",", ":")) - response = self._request("DELETE", endpoint, data=data) - else: - response = self._request("DELETE", endpoint) - return self._handle_response(response, method="DELETE", endpoint=endpoint) - - # --- Domain accessors --- - - @cached_property - def portfolio(self) -> Portfolio: - """The authenticated user's portfolio.""" - return Portfolio(self) - - @cached_property - def exchange(self) -> Exchange: - """Exchange status, schedule, and announcements.""" - return Exchange(self) - - @cached_property - def api_keys(self) -> APIKeys: - """API key management and rate limits.""" - return APIKeys(self) - - @cached_property - def communications(self) -> Communications: - """RFQ and quote operations for combo (multivariate event) trading.""" - return Communications(self) - - def feed(self) -> Feed: - """Create a new real-time data feed. - - Returns a Feed instance for streaming market data via WebSocket. - Each call creates a new Feed — use a single Feed for all subscriptions. - """ - return Feed(self) - - # --- Domain query methods --- - - def get_market(self, ticker: str) -> Market: - """Get a Market by ticker.""" - response = self.get(f"/markets/{ticker.upper()}") - model = MarketModel.model_validate(response["market"]) - return Market(self, model) - - def get_markets( - self, - *, - status: MarketStatus | None = None, - mve_filter: str | None = None, - tickers: list[str] | None = None, - series_ticker: str | None = None, - event_ticker: str | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[Market]: - """Search for markets.""" - params = { - "status": status.value if status is not None else None, - "mve_filter": mve_filter, - "tickers": ",".join(normalize_tickers(tickers)) if tickers else None, - "series_ticker": normalize_ticker(series_ticker), - "event_ticker": normalize_ticker(event_ticker), - "limit": limit, - "cursor": cursor, - **extra_params, - } - data = self.paginated_get("/markets", "markets", params, fetch_all) - return DataFrameList(Market(self, MarketModel.model_validate(m)) for m in data) - - def get_event( - self, - event_ticker: str, - *, - with_nested_markets: bool = False, - ) -> Event: - """Get an Event by ticker.""" - params = {} - if with_nested_markets: - params["with_nested_markets"] = "true" - endpoint = f"/events/{event_ticker.upper()}" - if params: - endpoint += "?" + urlencode(params) - response = self.get(endpoint) - model = EventModel.model_validate(response["event"]) - return Event(self, model) - - def get_events( - self, - *, - series_ticker: str | None = None, - status: MarketStatus | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[Event]: - """Search for events.""" - params = { - "limit": limit, - "series_ticker": normalize_ticker(series_ticker), - "status": status.value if status is not None else None, - "cursor": cursor, - **extra_params, - } - data = self.paginated_get("/events", "events", params, fetch_all) - return DataFrameList(Event(self, EventModel.model_validate(e)) for e in data) - - def get_series( - self, - series_ticker: str, - *, - include_volume: bool = False, - ) -> Series: - """Get a Series by ticker.""" - params = {} - if include_volume: - params["include_volume"] = "true" - endpoint = f"/series/{series_ticker.upper()}" - if params: - endpoint += "?" + urlencode(params) - response = self.get(endpoint) - model = SeriesModel.model_validate(response["series"]) - return Series(self, model) - - def get_all_series( - self, - *, - category: str | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[Series]: - """List all series.""" - params = {"limit": limit, "category": category, "cursor": cursor, **extra_params} - data = self.paginated_get("/series", "series", params, fetch_all) - return DataFrameList(Series(self, SeriesModel.model_validate(s)) for s in data) - - def get_mve_collection(self, collection_ticker: str) -> MveCollection: - """Get a multivariate event collection by ticker.""" - response = self.get(f"/multivariate_event_collections/{collection_ticker}") - model = MveCollectionModel.model_validate(response.get("multivariate_contract", response)) - return MveCollection(self, model) - - def get_mve_collections( - self, - *, - status: str | None = None, - associated_event_ticker: str | None = None, - series_ticker: str | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - ) -> DataFrameList[MveCollection]: - """List multivariate event collections.""" - params = { - "limit": limit, - "status": status, - "associated_event_ticker": normalize_ticker(associated_event_ticker), - "series_ticker": normalize_ticker(series_ticker), - "cursor": cursor, - } - data = self.paginated_get( - "/multivariate_event_collections", "multivariate_contracts", params, fetch_all - ) - return DataFrameList( - MveCollection(self, MveCollectionModel.model_validate(c)) for c in data - ) - - def get_multivariate_events( - self, - *, - series_ticker: str | None = None, - collection_ticker: str | None = None, - with_nested_markets: bool = False, - limit: int = 200, - cursor: str | None = None, - fetch_all: bool = False, - ) -> DataFrameList[Event]: - """Get multivariate (combo) events.""" - params: dict = {"limit": limit} - if series_ticker: - params["series_ticker"] = normalize_ticker(series_ticker) - if collection_ticker: - params["collection_ticker"] = collection_ticker - if with_nested_markets: - params["with_nested_markets"] = "true" - if cursor: - params["cursor"] = cursor - - data = self.paginated_get("/events/multivariate", "events", params, fetch_all) - return DataFrameList(Event(self, EventModel.model_validate(e)) for e in data) - - def get_trades( - self, - *, - ticker: str | None = None, - min_ts: int | None = None, - max_ts: int | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[TradeModel]: - """Get public trade history.""" - params = { - "limit": limit, - "ticker": normalize_ticker(ticker), - "min_ts": min_ts, - "max_ts": max_ts, - "cursor": cursor, - **extra_params, - } - data = self.paginated_get("/markets/trades", "trades", params, fetch_all) - return DataFrameList(TradeModel.model_validate(t) for t in data) - - def get_candlesticks_batch( - self, - tickers: list[str], - start_ts: int, - end_ts: int, - period: CandlestickPeriod = CandlestickPeriod.ONE_HOUR, - ) -> dict[str, CandlestickResponse]: - """Batch fetch candlesticks for multiple markets (up to 100 tickers).""" - query = urlencode({ - "market_tickers": ",".join(normalize_tickers(tickers)), - "start_ts": start_ts, - "end_ts": end_ts, - "period_interval": period.value, - }) - response = self.get(f"/markets/candlesticks?{query}") - return { - item["market_ticker"]: CandlestickResponse.model_validate(item) - for item in response.get("markets", []) - } +__all__ = ["KalshiClient"] diff --git a/pykalshi/communications.py b/pykalshi/communications.py index faedfe3..290812f 100644 --- a/pykalshi/communications.py +++ b/pykalshi/communications.py @@ -1,246 +1,4 @@ -from __future__ import annotations -from typing import TYPE_CHECKING -from urllib.parse import urlencode -from .models import RfqModel, QuoteModel -from .dataframe import DataFrameList +from ._sync.communications import Communications as Communications +from ._async.communications import AsyncCommunications as AsyncCommunications -if TYPE_CHECKING: - from .client import KalshiClient - - -class Communications: - """RFQ (Request for Quote) and quote operations for combo trading. - - Multivariate event combos trade via the RFQ system rather than - standard limit orders. The flow is: - - 1. Create an RFQ broadcasting your intent to trade a combo. - 2. Market makers respond with two-sided quotes. - 3. Accept a quote to execute the trade. - - Usage: - # Create an RFQ for a combo market - rfq = client.communications.create_rfq( - market_ticker="KXMVE-...", - contracts_fp="10.00", - ) - - # List active RFQs - rfqs = client.communications.get_rfqs(status="active") - - # Respond to an RFQ as a market maker - quote = client.communications.create_quote( - rfq_id=rfq.rfq_id, - yes_bid="0.45", - no_bid="0.55", - ) - """ - - def __init__(self, client: KalshiClient) -> None: - self._client = client - - def create_rfq( - self, - market_ticker: str, - *, - contracts_fp: str | None = None, - target_cost_dollars: str | None = None, - rest_remainder: bool = False, - ) -> RfqModel: - """Create a Request for Quote. - - Args: - market_ticker: The combo market ticker to request quotes for. - contracts_fp: Number of contracts to trade (fixed-point string, e.g. "10.00"). - target_cost_dollars: Target cost in dollars (e.g. "10.00"). - Use this OR contracts_fp, not both. - rest_remainder: If True, rest any unfilled portion on the orderbook. - """ - - body: dict = { - "market_ticker": market_ticker.upper(), - "rest_remainder": rest_remainder, - } - if contracts_fp is not None: - body["contracts_fp"] = contracts_fp - if target_cost_dollars is not None: - body["target_cost_dollars"] = target_cost_dollars - - response = self._client.post("/communications/rfqs", body) - return RfqModel.model_validate(response.get("rfq", response)) - - def get_rfqs( - self, - *, - market_ticker: str | None = None, - status: str | None = None, - mve_collection_ticker: str | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - ) -> DataFrameList[RfqModel]: - """List RFQs. - - Args: - market_ticker: Filter by combo market ticker. - status: Filter by RFQ status (e.g., "active", "expired"). - mve_collection_ticker: Filter by collection ticker. - limit: Maximum results per page (default 100). - cursor: Pagination cursor. - fetch_all: If True, automatically fetch all pages. - """ - params: dict = {"limit": limit} - if market_ticker: - params["market_ticker"] = market_ticker.upper() - if status: - params["status"] = status - if mve_collection_ticker: - params["mve_collection_ticker"] = mve_collection_ticker - if cursor: - params["cursor"] = cursor - - data = self._client.paginated_get("/communications/rfqs", "rfqs", params, fetch_all) - return DataFrameList(RfqModel.model_validate(r) for r in data) - - def get_rfq(self, rfq_id: str) -> RfqModel: - """Get a single RFQ by ID.""" - response = self._client.get(f"/communications/rfqs/{rfq_id}") - return RfqModel.model_validate(response.get("rfq", response)) - - def create_quote( - self, - rfq_id: str, - *, - yes_bid: str, - no_bid: str, - rest_remainder: bool = False, - ) -> QuoteModel: - """Create a quote in response to an RFQ. - - Prices are in FixedPointDollars (e.g., "0.45"). - - Args: - rfq_id: ID of the RFQ to respond to. - yes_bid: Your bid price for the YES side (FixedPointDollars). - no_bid: Your bid price for the NO side (FixedPointDollars). - rest_remainder: If True, rest any unfilled portion on the orderbook. - """ - body: dict = { - "rfq_id": rfq_id, - "yes_bid": yes_bid, - "no_bid": no_bid, - "rest_remainder": rest_remainder, - } - - response = self._client.post("/communications/quotes", body) - return QuoteModel.model_validate(response.get("quote", response)) - - def get_quotes( - self, - *, - creator_user_id: str | None = None, - rfq_creator_user_id: str | None = None, - rfq_id: str | None = None, - market_ticker: str | None = None, - status: str | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - ) -> DataFrameList[QuoteModel]: - """List quotes. - - The API requires at least one of creator_user_id or rfq_creator_user_id. - - Args: - creator_user_id: Filter by quote creator. Required if rfq_creator_user_id not set. - rfq_creator_user_id: Filter by RFQ creator. Required if creator_user_id not set. - rfq_id: Filter by RFQ ID. - market_ticker: Filter by combo market ticker. - status: Filter by quote status. - limit: Maximum results per page (default 100). - cursor: Pagination cursor. - fetch_all: If True, automatically fetch all pages. - """ - params: dict = {"limit": limit} - if creator_user_id: - params["creator_user_id"] = creator_user_id - if rfq_creator_user_id: - params["rfq_creator_user_id"] = rfq_creator_user_id - if rfq_id: - params["rfq_id"] = rfq_id - if market_ticker: - params["market_ticker"] = market_ticker.upper() - if status: - params["status"] = status - if cursor: - params["cursor"] = cursor - - data = self._client.paginated_get("/communications/quotes", "quotes", params, fetch_all) - return DataFrameList(QuoteModel.model_validate(q) for q in data) - - -class AsyncCommunications(Communications): - """Async variant of Communications.""" - - async def create_rfq(self, market_ticker, *, contracts_fp=None, # type: ignore[override] - target_cost_dollars=None, rest_remainder=False) -> RfqModel: - - body: dict = { - "market_ticker": market_ticker.upper(), - "rest_remainder": rest_remainder, - } - if contracts_fp is not None: - body["contracts_fp"] = contracts_fp - if target_cost_dollars is not None: - body["target_cost_dollars"] = target_cost_dollars - response = await self._client.post("/communications/rfqs", body) - return RfqModel.model_validate(response.get("rfq", response)) - - async def get_rfqs(self, *, market_ticker=None, status=None, # type: ignore[override] - mve_collection_ticker=None, limit=100, cursor=None, - fetch_all=False) -> DataFrameList[RfqModel]: - params: dict = {"limit": limit} - if market_ticker: - params["market_ticker"] = market_ticker.upper() - if status: - params["status"] = status - if mve_collection_ticker: - params["mve_collection_ticker"] = mve_collection_ticker - if cursor: - params["cursor"] = cursor - data = await self._client.paginated_get("/communications/rfqs", "rfqs", params, fetch_all) - return DataFrameList(RfqModel.model_validate(r) for r in data) - - async def get_rfq(self, rfq_id: str) -> RfqModel: # type: ignore[override] - response = await self._client.get(f"/communications/rfqs/{rfq_id}") - return RfqModel.model_validate(response.get("rfq", response)) - - async def create_quote(self, rfq_id, *, yes_bid, no_bid, # type: ignore[override] - rest_remainder=False) -> QuoteModel: - body: dict = { - "rfq_id": rfq_id, - "yes_bid": yes_bid, - "no_bid": no_bid, - "rest_remainder": rest_remainder, - } - response = await self._client.post("/communications/quotes", body) - return QuoteModel.model_validate(response.get("quote", response)) - - async def get_quotes(self, *, creator_user_id=None, rfq_creator_user_id=None, # type: ignore[override] - rfq_id=None, market_ticker=None, status=None, - limit=100, cursor=None, fetch_all=False) -> DataFrameList[QuoteModel]: - params: dict = {"limit": limit} - if creator_user_id: - params["creator_user_id"] = creator_user_id - if rfq_creator_user_id: - params["rfq_creator_user_id"] = rfq_creator_user_id - if rfq_id: - params["rfq_id"] = rfq_id - if market_ticker: - params["market_ticker"] = market_ticker.upper() - if status: - params["status"] = status - if cursor: - params["cursor"] = cursor - data = await self._client.paginated_get("/communications/quotes", "quotes", params, fetch_all) - return DataFrameList(QuoteModel.model_validate(q) for q in data) +__all__ = ["Communications", "AsyncCommunications"] diff --git a/pykalshi/events.py b/pykalshi/events.py index 4de45c5..eefd1d8 100644 --- a/pykalshi/events.py +++ b/pykalshi/events.py @@ -1,116 +1,4 @@ -from __future__ import annotations -from typing import TYPE_CHECKING -from .models import EventModel, ForecastPercentileHistory -from .dataframe import DataFrameList +from ._sync.events import Event as Event +from ._async.events import AsyncEvent as AsyncEvent -if TYPE_CHECKING: - from .client import KalshiClient - from .markets import Market, Series, AsyncMarket, AsyncSeries - - -class Event: - """Represents a Kalshi Event. - - An event is a container for related markets (e.g., "Will X happen?" with - multiple outcome markets). - - Key fields are exposed as typed properties for IDE support. - All other EventModel fields are accessible via attribute delegation. - """ - - def __init__(self, client: KalshiClient, data: EventModel) -> None: - self._client = client - self.data = data - - # --- Typed properties for core fields --- - - @property - def event_ticker(self) -> str: - return self.data.event_ticker - - @property - def series_ticker(self) -> str: - return self.data.series_ticker - - @property - def title(self) -> str | None: - return self.data.title - - @property - def category(self) -> str | None: - return self.data.category - - @property - def mutually_exclusive(self) -> bool: - return self.data.mutually_exclusive - - # --- Domain logic --- - - def get_markets(self) -> DataFrameList[Market]: - """Get all markets for this event.""" - return self._client.get_markets(event_ticker=self.data.event_ticker) - - def get_series(self) -> Series: - """Get the parent Series for this event.""" - return self._client.get_series(self.series_ticker) - - def get_forecast_percentile_history( - self, - percentiles: list[int] | None = None, - ) -> ForecastPercentileHistory: - """Get historical forecast data at various percentiles. - - Args: - percentiles: List of percentiles to fetch (e.g., [10, 25, 50, 75, 90]). - If None, returns all available percentiles. - - Returns: - ForecastPercentileHistory with percentile -> history mapping. - """ - endpoint = f"/events/{self.event_ticker}/forecast/percentile_history" - if percentiles: - endpoint += f"?percentiles={','.join(str(p) for p in percentiles)}" - response = self._client.get(endpoint) - return ForecastPercentileHistory.model_validate(response) - - def __getattr__(self, name: str): - return getattr(self.data, name) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Event): - return NotImplemented - return self.data.event_ticker == other.data.event_ticker - - def __hash__(self) -> int: - return hash(self.data.event_ticker) - - def __repr__(self) -> str: - parts = [f"" - - def _repr_html_(self) -> str: - from ._repr import event_html - return event_html(self) - - -class AsyncEvent(Event): - """Async variant of Event. Inherits all properties and non-I/O methods.""" - - async def get_markets(self) -> DataFrameList[AsyncMarket]: # type: ignore[override] - return await self._client.get_markets(event_ticker=self.data.event_ticker) - - async def get_series(self) -> AsyncSeries: # type: ignore[override] - return await self._client.get_series(self.series_ticker) - - async def get_forecast_percentile_history( # type: ignore[override] - self, percentiles: list[int] | None = None, - ) -> ForecastPercentileHistory: - endpoint = f"/events/{self.event_ticker}/forecast/percentile_history" - if percentiles: - endpoint += f"?percentiles={','.join(str(p) for p in percentiles)}" - response = await self._client.get(endpoint) - return ForecastPercentileHistory.model_validate(response) +__all__ = ["Event", "AsyncEvent"] diff --git a/pykalshi/exchange.py b/pykalshi/exchange.py index ef9956c..3bb00e3 100644 --- a/pykalshi/exchange.py +++ b/pykalshi/exchange.py @@ -1,74 +1,4 @@ -from __future__ import annotations -from typing import TYPE_CHECKING, Any -from .models import ExchangeStatus, Announcement -from .exceptions import KalshiAPIError +from ._sync.exchange import Exchange as Exchange +from ._async.exchange import AsyncExchange as AsyncExchange -if TYPE_CHECKING: - from .client import KalshiClient - - -class Exchange: - """Exchange status, schedule, and announcements.""" - - def __init__(self, client: KalshiClient) -> None: - self._client = client - - def get_status(self) -> ExchangeStatus: - """Get current exchange operational status.""" - try: - data = self._client.get("/exchange/status") - except KalshiAPIError as e: - if e.status_code == 503 and isinstance(e.response_body, dict): - data = e.response_body - else: - raise - return ExchangeStatus.model_validate(data) - - def is_trading(self) -> bool: - """Quick check if trading is currently active.""" - return self.get_status().trading_active - - def get_schedule(self) -> dict[str, Any]: - """Get exchange trading schedule (raw format).""" - data = self._client.get("/exchange/schedule") - return data.get("schedule", {}) - - def get_announcements(self) -> list[Announcement]: - """Get exchange-wide announcements.""" - data = self._client.get("/exchange/announcements") - return [Announcement.model_validate(a) for a in data.get("announcements", [])] - - def get_user_data_timestamp(self) -> int: - """Get timestamp of last user data validation (Unix ms).""" - data = self._client.get("/exchange/user_data_timestamp") - return data.get("user_data_timestamp", 0) - - -class AsyncExchange(Exchange): - """Async variant of Exchange.""" - - async def get_status(self) -> ExchangeStatus: # type: ignore[override] - try: - data = await self._client.get("/exchange/status") - except KalshiAPIError as e: - if e.status_code == 503 and isinstance(e.response_body, dict): - data = e.response_body - else: - raise - return ExchangeStatus.model_validate(data) - - async def is_trading(self) -> bool: # type: ignore[override] - status = await self.get_status() - return status.trading_active - - async def get_schedule(self) -> dict[str, Any]: # type: ignore[override] - data = await self._client.get("/exchange/schedule") - return data.get("schedule", {}) - - async def get_announcements(self) -> list[Announcement]: # type: ignore[override] - data = await self._client.get("/exchange/announcements") - return [Announcement.model_validate(a) for a in data.get("announcements", [])] - - async def get_user_data_timestamp(self) -> int: # type: ignore[override] - data = await self._client.get("/exchange/user_data_timestamp") - return data.get("user_data_timestamp", 0) +__all__ = ["Exchange", "AsyncExchange"] diff --git a/pykalshi/markets.py b/pykalshi/markets.py index 442e0aa..a060a59 100644 --- a/pykalshi/markets.py +++ b/pykalshi/markets.py @@ -1,294 +1,4 @@ -from __future__ import annotations +from ._sync.markets import Market as Market, Series as Series +from ._async.markets import AsyncMarket as AsyncMarket, AsyncSeries as AsyncSeries -import logging -from typing import TYPE_CHECKING - -from .models import MarketModel, CandlestickResponse, OrderbookResponse, SeriesModel, TradeModel -from .dataframe import DataFrameList -from .enums import CandlestickPeriod, MarketStatus - -if TYPE_CHECKING: - from .client import KalshiClient - from .events import Event, AsyncEvent - -logger = logging.getLogger(__name__) - - -class Market: - """Represents a Kalshi Market. - - Key fields are exposed as typed properties for IDE support. - All other MarketModel fields are accessible via attribute delegation. - """ - - def __init__(self, client: KalshiClient, data: MarketModel) -> None: - self._client = client - self.data = data - - # --- Typed properties for core fields --- - - @property - def ticker(self) -> str: - return self.data.ticker - - @property - def event_ticker(self) -> str | None: - return self.data.event_ticker - - @property - def status(self) -> MarketStatus | None: - return self.data.status - - @property - def title(self) -> str | None: - return self.data.title - - @property - def subtitle(self) -> str | None: - return self.data.subtitle - - @property - def yes_bid_dollars(self) -> str | None: - return self.data.yes_bid_dollars - - @property - def yes_ask_dollars(self) -> str | None: - return self.data.yes_ask_dollars - - @property - def no_bid_dollars(self) -> str | None: - return self.data.no_bid_dollars - - @property - def no_ask_dollars(self) -> str | None: - return self.data.no_ask_dollars - - @property - def last_price_dollars(self) -> str | None: - return self.data.last_price_dollars - - @property - def volume_fp(self) -> str | None: - return self.data.volume_fp - - @property - def volume_24h_fp(self) -> str | None: - return self.data.volume_24h_fp - - @property - def open_interest_fp(self) -> str | None: - return self.data.open_interest_fp - - @property - def liquidity_dollars(self) -> str | None: - return self.data.liquidity_dollars - - @property - def open_time(self) -> str | None: - return self.data.open_time - - @property - def close_time(self) -> str | None: - return self.data.close_time - - @property - def result(self) -> str | None: - return self.data.result - - @property - def series_ticker(self) -> str | None: - return self.data.series_ticker - - def resolve_series_ticker(self) -> str | None: - """Fetch series_ticker from the event API if not present in market data.""" - if self.data.series_ticker is not None: - return self.data.series_ticker - if not self.data.event_ticker: - return None - try: - event_response = self._client.get(f"/events/{self.data.event_ticker}") - return event_response["event"]["series_ticker"] - except Exception as e: - logger.warning( - "Failed to resolve series_ticker for %s: %s", self.data.ticker, e - ) - return None - - def get_orderbook(self, *, depth: int | None = None) -> OrderbookResponse: - """Get the orderbook for this market.""" - endpoint = f"/markets/{self.data.ticker}/orderbook" - if depth: - endpoint += f"?depth={depth}" - response = self._client.get(endpoint) - return OrderbookResponse.model_validate(response) - - def get_candlesticks( - self, - start_ts: int, - end_ts: int, - period: CandlestickPeriod = CandlestickPeriod.ONE_HOUR, - ) -> CandlestickResponse: - """Get candlestick data for this market.""" - series = self.resolve_series_ticker() - if not series: - raise ValueError(f"Market {self.data.ticker} does not have a series_ticker.") - - query = f"start_ts={start_ts}&end_ts={end_ts}&period_interval={period.value}" - endpoint = f"/series/{series}/markets/{self.data.ticker}/candlesticks?{query}" - response = self._client.get(endpoint) - return CandlestickResponse.model_validate(response) - - def get_trades( - self, - min_ts: int | None = None, - max_ts: int | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - ) -> DataFrameList[TradeModel]: - """Get public trade history for this market.""" - return self._client.get_trades( - ticker=self.ticker, - min_ts=min_ts, - max_ts=max_ts, - limit=limit, - cursor=cursor, - fetch_all=fetch_all, - ) - - def get_event(self) -> Event | None: - """Get the parent Event for this market.""" - if not self.event_ticker: - return None - return self._client.get_event(self.event_ticker) - - def __getattr__(self, name: str): - return getattr(self.data, name) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Market): - return NotImplemented - return self.data.ticker == other.data.ticker - - def __hash__(self) -> int: - return hash(self.data.ticker) - - def __repr__(self) -> str: - status = self.status.value if self.status else "?" - parts = [f"" - - def _repr_html_(self) -> str: - from ._repr import market_html - return market_html(self) - - -class Series: - """Represents a Kalshi Series (collection of related markets).""" - - def __init__(self, client: KalshiClient, data: SeriesModel) -> None: - self._client = client - self.data = data - - @property - def ticker(self) -> str: - return self.data.ticker - - @property - def title(self) -> str | None: - return self.data.title - - @property - def category(self) -> str | None: - return self.data.category - - def get_markets(self, **kwargs) -> DataFrameList[Market]: - """Get all markets in this series.""" - return self._client.get_markets(series_ticker=self.ticker, **kwargs) - - def get_events(self, **kwargs) -> DataFrameList[Event]: - """Get all events in this series.""" - return self._client.get_events(series_ticker=self.ticker, **kwargs) - - def __getattr__(self, name: str): - return getattr(self.data, name) - - def __repr__(self) -> str: - parts = [f"" - - def _repr_html_(self) -> str: - from ._repr import series_html - return series_html(self) - - -class AsyncMarket(Market): - """Async variant of Market. Inherits all properties and non-I/O methods.""" - - async def resolve_series_ticker(self) -> str | None: # type: ignore[override] - if self.data.series_ticker is not None: - return self.data.series_ticker - if not self.data.event_ticker: - return None - try: - event_response = await self._client.get(f"/events/{self.data.event_ticker}") - return event_response["event"]["series_ticker"] - except Exception as e: - logger.warning("Failed to resolve series_ticker for %s: %s", self.data.ticker, e) - return None - - async def get_orderbook(self, *, depth: int | None = None) -> OrderbookResponse: # type: ignore[override] - endpoint = f"/markets/{self.data.ticker}/orderbook" - if depth: - endpoint += f"?depth={depth}" - response = await self._client.get(endpoint) - return OrderbookResponse.model_validate(response) - - async def get_candlesticks( # type: ignore[override] - self, start_ts: int, end_ts: int, - period: CandlestickPeriod = CandlestickPeriod.ONE_HOUR, - ) -> CandlestickResponse: - series = await self.resolve_series_ticker() - if not series: - raise ValueError(f"Market {self.data.ticker} does not have a series_ticker.") - query = f"start_ts={start_ts}&end_ts={end_ts}&period_interval={period.value}" - endpoint = f"/series/{series}/markets/{self.data.ticker}/candlesticks?{query}" - response = await self._client.get(endpoint) - return CandlestickResponse.model_validate(response) - - async def get_trades( # type: ignore[override] - self, min_ts=None, max_ts=None, limit=100, cursor=None, fetch_all=False, - ) -> DataFrameList[TradeModel]: - return await self._client.get_trades( - ticker=self.ticker, min_ts=min_ts, max_ts=max_ts, - limit=limit, cursor=cursor, fetch_all=fetch_all, - ) - - async def get_event(self) -> AsyncEvent | None: # type: ignore[override] - if not self.event_ticker: - return None - return await self._client.get_event(self.event_ticker) - - -class AsyncSeries(Series): - """Async variant of Series. Inherits all properties and non-I/O methods.""" - - async def get_markets(self, **kwargs) -> DataFrameList[AsyncMarket]: # type: ignore[override] - return await self._client.get_markets(series_ticker=self.ticker, **kwargs) - - async def get_events(self, **kwargs) -> DataFrameList[AsyncEvent]: # type: ignore[override] - return await self._client.get_events(series_ticker=self.ticker, **kwargs) +__all__ = ["Market", "Series", "AsyncMarket", "AsyncSeries"] diff --git a/pykalshi/mve.py b/pykalshi/mve.py index acbaf7b..4733c3c 100644 --- a/pykalshi/mve.py +++ b/pykalshi/mve.py @@ -1,148 +1,4 @@ -from __future__ import annotations -from typing import TYPE_CHECKING -from .models import MveCollectionModel, MveSelectedLeg, EventModel, MarketModel -from .dataframe import DataFrameList -from .enums import Side +from ._sync.mve import MveCollection as MveCollection +from ._async.mve import AsyncMveCollection as AsyncMveCollection -if TYPE_CHECKING: - from .client import KalshiClient - from .events import Event, AsyncEvent - from .markets import Market, AsyncMarket - - -class MveCollection: - """Represents a multivariate event collection (combo container). - - Collections define which events can be combined into combo markets. - Use create_market() to create a tradeable combo, then trade it via - client.communications (RFQ system). - """ - - def __init__(self, client: KalshiClient, data: MveCollectionModel) -> None: - self._client = client - self.data = data - - @property - def collection_ticker(self) -> str: - return self.data.collection_ticker - - @property - def title(self) -> str | None: - return self.data.title - - @property - def series_ticker(self) -> str | None: - return self.data.series_ticker - - def create_market( - self, - selected_markets: list[dict[str, str]], - ) -> Market: - """Create a combo market in this collection. - - Must be called before trading or looking up a combo. Each entry - specifies a leg of the combo. - - Args: - selected_markets: List of leg dicts, each with keys: - - market_ticker: The market ticker for this leg. - - event_ticker: The event ticker for this leg. - - side: "yes" or "no". - - Returns: - The created combo Market. - - Example: - market = collection.create_market([ - {"market_ticker": "KXABC-A", "event_ticker": "KXABC", "side": "yes"}, - {"market_ticker": "KXDEF-B", "event_ticker": "KXDEF", "side": "yes"}, - ]) - """ - from .markets import Market - - body = {"selected_markets": selected_markets, "with_market_payload": True} - response = self._client.post( - f"/multivariate_event_collections/{self.collection_ticker}", body - ) - model = MarketModel.model_validate(response.get("market", response)) - return Market(self._client, model) - - def lookup_ticker( - self, - selected_markets: list[dict[str, str]], - ) -> dict: - """Look up tickers for a combo market by its leg combination. - - Returns 404 if the combination hasn't been previously created - via create_market(). - - Args: - selected_markets: List of leg dicts (same format as create_market). - - Returns: - Dict with market_ticker and event_ticker for the combo. - """ - body = {"selected_markets": selected_markets} - return self._client.put( - f"/multivariate_event_collections/{self.collection_ticker}/lookup", body - ) - - def get_events(self, *, with_nested_markets: bool = False) -> DataFrameList[Event]: - """Get multivariate events in this collection. - - Args: - with_nested_markets: If True, include markets nested in each event. - """ - return self._client.get_multivariate_events( - collection_ticker=self.collection_ticker, - with_nested_markets=with_nested_markets, - ) - - def __getattr__(self, name: str): - return getattr(self.data, name) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, MveCollection): - return NotImplemented - return self.data.collection_ticker == other.data.collection_ticker - - def __hash__(self) -> int: - return hash(self.data.collection_ticker) - - def __repr__(self) -> str: - parts = [f"" - - def _repr_html_(self) -> str: - from ._repr import mve_collection_html - return mve_collection_html(self) - - -class AsyncMveCollection(MveCollection): - """Async variant of MveCollection.""" - - async def create_market(self, selected_markets: list[dict[str, str]]) -> AsyncMarket: # type: ignore[override] - from .markets import AsyncMarket - body = {"selected_markets": selected_markets, "with_market_payload": True} - response = await self._client.post( - f"/multivariate_event_collections/{self.collection_ticker}", body - ) - model = MarketModel.model_validate(response.get("market", response)) - return AsyncMarket(self._client, model) - - async def lookup_ticker(self, selected_markets: list[dict[str, str]]) -> dict: # type: ignore[override] - body = {"selected_markets": selected_markets} - return await self._client.put( - f"/multivariate_event_collections/{self.collection_ticker}/lookup", body - ) - - async def get_events(self, *, with_nested_markets: bool = False) -> DataFrameList[AsyncEvent]: # type: ignore[override] - return await self._client.get_multivariate_events( - collection_ticker=self.collection_ticker, - with_nested_markets=with_nested_markets, - ) +__all__ = ["MveCollection", "AsyncMveCollection"] diff --git a/pykalshi/orders.py b/pykalshi/orders.py index da20b09..330d61d 100644 --- a/pykalshi/orders.py +++ b/pykalshi/orders.py @@ -1,238 +1,4 @@ -from __future__ import annotations -import time -from typing import TYPE_CHECKING -from .models import OrderModel -from .enums import OrderStatus, Action, Side, OrderType +from ._sync.orders import Order as Order +from ._async.orders import AsyncOrder as AsyncOrder -if TYPE_CHECKING: - from .client import KalshiClient - -TERMINAL_STATUSES = frozenset({OrderStatus.CANCELED, OrderStatus.EXECUTED}) - - -class Order: - """Represents a Kalshi order. - - Key fields are exposed as typed properties for IDE support. - All other OrderModel fields are accessible via attribute delegation. - """ - - def __init__(self, client: KalshiClient, data: OrderModel) -> None: - self._client = client - self.data = data - - # --- Typed properties for core fields --- - - @property - def order_id(self) -> str: - return self.data.order_id - - @property - def ticker(self) -> str: - return self.data.ticker - - @property - def status(self) -> OrderStatus: - return self.data.status - - @property - def action(self) -> Action | None: - return self.data.action - - @property - def side(self) -> Side | None: - return self.data.side - - @property - def type(self) -> OrderType | None: - return self.data.type - - @property - def yes_price_dollars(self) -> str | None: - return self.data.yes_price_dollars - - @property - def no_price_dollars(self) -> str | None: - return self.data.no_price_dollars - - @property - def initial_count_fp(self) -> str | None: - return self.data.initial_count_fp - - @property - def fill_count_fp(self) -> str | None: - return self.data.fill_count_fp - - @property - def remaining_count_fp(self) -> str | None: - return self.data.remaining_count_fp - - @property - def created_time(self) -> str | None: - return self.data.created_time - - # --- Domain logic --- - - def cancel(self) -> Order: - """Cancel this order. - - Returns: - Self with updated data (status will be CANCELED). - """ - updated = self._client.portfolio.cancel_order(self.order_id) - self.data = updated.data - return self - - def amend( - self, - *, - count_fp: str | None = None, - yes_price_dollars: str | None = None, - no_price_dollars: str | None = None, - ) -> Order: - """Amend this order's price or count. - - Args: - count_fp: New total contract count (fixed-point string). - yes_price_dollars: New YES price (dollar string). - no_price_dollars: New NO price (dollar string, converted to yes internally). - - Returns: - Self with updated data. - """ - updated = self._client.portfolio.amend_order( - self.order_id, - count_fp=count_fp, - yes_price_dollars=yes_price_dollars, - no_price_dollars=no_price_dollars, - ticker=self.ticker, - action=self.action, - side=self.side, - ) - self.data = updated.data - return self - - def decrease(self, reduce_by_fp: str) -> Order: - """Decrease the remaining count of this order. - - Args: - reduce_by_fp: Number of contracts to reduce by (fixed-point string). - - Returns: - Self with updated data. - """ - updated = self._client.portfolio.decrease_order(self.order_id, reduce_by_fp) - self.data = updated.data - return self - - def refresh(self) -> Order: - """Re-fetch this order's current state from the API. - - Returns: - Self with updated data. - """ - updated = self._client.portfolio.get_order(self.order_id) - self.data = updated.data - return self - - def wait_until_terminal( - self, timeout: float = 30.0, poll_interval: float = 0.5 - ) -> Order: - """Block until order reaches a terminal state. - - Terminal states are: CANCELED, EXECUTED. - - Args: - timeout: Maximum seconds to wait before raising TimeoutError. - poll_interval: Seconds between refresh calls. - - Returns: - Self with updated data. - - Raises: - TimeoutError: If timeout is reached before terminal state. - """ - deadline = time.monotonic() + timeout - while self.status not in TERMINAL_STATUSES: - if time.monotonic() >= deadline: - raise TimeoutError( - f"Order {self.order_id} still {self.status.value} after {timeout}s" - ) - time.sleep(poll_interval) - self.refresh() - return self - - def __getattr__(self, name: str): - return getattr(self.data, name) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Order): - return NotImplemented - return self.data.order_id == other.data.order_id - - def __hash__(self) -> int: - return hash(self.data.order_id) - - def __repr__(self) -> str: - action = self.action.value.upper() if self.action else "?" - side = self.side.value.upper() if self.side else "?" - price = self.yes_price_dollars if self.yes_price_dollars is not None else self.no_price_dollars - filled = self.fill_count_fp or "0" - total = self.initial_count_fp or "0" - return f"" - - def _repr_html_(self) -> str: - from ._repr import order_html - return order_html(self) - - -class AsyncOrder(Order): - """Async variant of Order. Inherits all properties and non-I/O methods.""" - - async def cancel(self) -> AsyncOrder: # type: ignore[override] - updated = await self._client.portfolio.cancel_order(self.order_id) - self.data = updated.data - return self - - async def amend( # type: ignore[override] - self, - *, - count_fp: str | None = None, - yes_price_dollars: str | None = None, - no_price_dollars: str | None = None, - ) -> AsyncOrder: - updated = await self._client.portfolio.amend_order( - self.order_id, - count_fp=count_fp, - yes_price_dollars=yes_price_dollars, - no_price_dollars=no_price_dollars, - ticker=self.ticker, - action=self.action, - side=self.side, - ) - self.data = updated.data - return self - - async def decrease(self, reduce_by_fp: str) -> AsyncOrder: # type: ignore[override] - updated = await self._client.portfolio.decrease_order(self.order_id, reduce_by_fp) - self.data = updated.data - return self - - async def refresh(self) -> AsyncOrder: # type: ignore[override] - updated = await self._client.portfolio.get_order(self.order_id) - self.data = updated.data - return self - - async def wait_until_terminal( # type: ignore[override] - self, timeout: float = 30.0, poll_interval: float = 0.5 - ) -> AsyncOrder: - import asyncio - deadline = time.monotonic() + timeout - while self.status not in TERMINAL_STATUSES: - if time.monotonic() >= deadline: - raise TimeoutError( - f"Order {self.order_id} still {self.status.value} after {timeout}s" - ) - await asyncio.sleep(poll_interval) - await self.refresh() - return self +__all__ = ["Order", "AsyncOrder"] diff --git a/pykalshi/portfolio.py b/pykalshi/portfolio.py index 13bc78d..317e208 100644 --- a/pykalshi/portfolio.py +++ b/pykalshi/portfolio.py @@ -1,863 +1,4 @@ -from __future__ import annotations -from decimal import Decimal -from typing import TYPE_CHECKING -from urllib.parse import urlencode -from .orders import Order, AsyncOrder -from .enums import Action, Side, OrderStatus, TimeInForce, SelfTradePrevention, PositionCountFilter -from .dataframe import DataFrameList -from ._utils import normalize_ticker, normalize_tickers -from .models import ( - OrderModel, BalanceModel, PositionModel, FillModel, - SettlementModel, QueuePositionModel, OrderGroupModel, - SubaccountModel, SubaccountBalanceModel, SubaccountTransferModel, -) +from ._sync.portfolio import Portfolio as Portfolio +from ._async.portfolio import AsyncPortfolio as AsyncPortfolio -if TYPE_CHECKING: - from .client import KalshiClient - from .markets import Market - - -class Portfolio: - """Authenticated user's portfolio and trading operations.""" - - def __init__(self, client: KalshiClient) -> None: - self._client = client - - def get_balance(self) -> BalanceModel: - """Get portfolio balance. Values are dollar strings.""" - data = self._client.get("/portfolio/balance") - return BalanceModel.model_validate(data) - - def place_order( - self, - ticker: str | Market, - action: Action, - side: Side, - count_fp: str, - *, - yes_price_dollars: str | None = None, - no_price_dollars: str | None = None, - client_order_id: str | None = None, - time_in_force: TimeInForce | None = None, - post_only: bool = False, - reduce_only: bool = False, - expiration_ts: int | None = None, - buy_max_cost_dollars: str | None = None, - self_trade_prevention: SelfTradePrevention | None = None, - order_group_id: str | None = None, - subaccount: int | None = None, - cancel_order_on_pause: bool | None = None, - ) -> Order: - """Place an order on a market. - - Args: - ticker: Market ticker string or Market object. - action: BUY or SELL. - side: YES or NO. - count_fp: Number of contracts (fixed-point string, e.g. "10.00"). - yes_price_dollars: Price as dollar string (e.g. "0.45"). - no_price_dollars: Price as dollar string. Converted to - yes_price_dollars internally (yes = 1.00 - no). - client_order_id: Idempotency key. Resubmitting returns existing order. - time_in_force: GTC (default), IOC (immediate-or-cancel), FOK (fill-or-kill). - post_only: If True, reject order if it would take liquidity. - reduce_only: If True, only reduce existing position, never increase. - expiration_ts: Unix timestamp when order auto-cancels. - buy_max_cost_dollars: Maximum total cost (dollar string). Protects against slippage. - self_trade_prevention: Behavior on self-cross (CANCEL_RESTING or CANCEL_INCOMING). - order_group_id: Link to an order group for OCO/bracket strategies. - subaccount: Subaccount number (0 for primary, 1-32 for subaccounts). - cancel_order_on_pause: If True, cancel order if market is paused. - """ - # Extract market structure for validation when a Market object is passed - pls = None - fte = None - if not isinstance(ticker, str): - pls = getattr(ticker, 'price_level_structure', None) - fte = getattr(ticker, 'fractional_trading_enabled', None) - - order_data = self._build_order_data( - ticker, action, side, count_fp, - yes_price_dollars=yes_price_dollars, no_price_dollars=no_price_dollars, - client_order_id=client_order_id, time_in_force=time_in_force, - post_only=post_only, reduce_only=reduce_only, - expiration_ts=expiration_ts, buy_max_cost_dollars=buy_max_cost_dollars, - self_trade_prevention=self_trade_prevention, - order_group_id=order_group_id, subaccount=subaccount, - cancel_order_on_pause=cancel_order_on_pause, - price_level_structure=pls, - fractional_trading_enabled=fte, - ) - response = self._client.post("/portfolio/orders", order_data) - model = OrderModel.model_validate(response["order"]) - return Order(self._client, model) - - def cancel_order(self, order_id: str, *, subaccount: int | None = None) -> Order: - """Cancel a resting order. - - Args: - order_id: ID of the order to cancel. - subaccount: Subaccount number (0 for primary, 1-32 for subaccounts). - - Returns: - The canceled Order with updated status. - """ - endpoint = f"/portfolio/orders/{order_id}" - if subaccount is not None: - endpoint += f"?subaccount={subaccount}" - response = self._client.delete(endpoint) - model = OrderModel.model_validate(response["order"]) - return Order(self._client, model) - - def amend_order( - self, - order_id: str, - *, - count_fp: str | None = None, - yes_price_dollars: str | None = None, - no_price_dollars: str | None = None, - subaccount: int | None = None, - # Required by API but can be fetched from existing order - ticker: str | None = None, - action: Action | None = None, - side: Side | None = None, - ) -> Order: - """Amend a resting order's price or count. - - Args: - order_id: ID of the order to amend. - count_fp: New total contract count (fixed-point string). - yes_price_dollars: New YES price (dollar string). - no_price_dollars: New NO price (dollar string). Converted internally. - subaccount: Subaccount number (0 for primary, 1-32 for subaccounts). - ticker: Market ticker (fetched from order if not provided). - action: Order action (fetched from order if not provided). - side: Order side (fetched from order if not provided). - """ - if yes_price_dollars is not None and no_price_dollars is not None: - raise ValueError("Specify yes_price_dollars or no_price_dollars, not both") - - if no_price_dollars is not None: - yes_price_dollars = str(Decimal("1") - Decimal(no_price_dollars)) - - ticker = normalize_ticker(ticker) - - # Fetch original order to get required fields if not provided - if ticker is None or action is None or side is None or count_fp is None: - original = self.get_order(order_id) - ticker = ticker or original.ticker - action = action or original.action - side = side or original.side - if count_fp is None: - count_fp = original.remaining_count_fp - - body: dict = { - "ticker": ticker, - "action": action.value if isinstance(action, Action) else action, - "side": side.value if isinstance(side, Side) else side, - "count_fp": count_fp, - } - if yes_price_dollars is not None: - body["yes_price_dollars"] = yes_price_dollars - if subaccount is not None: - body["subaccount"] = subaccount - - if "count_fp" not in body and "yes_price_dollars" not in body: - raise ValueError("Must specify at least one of count_fp, yes_price_dollars, or no_price_dollars") - - response = self._client.post(f"/portfolio/orders/{order_id}/amend", body) - model = OrderModel.model_validate(response["order"]) - return Order(self._client, model) - - def decrease_order(self, order_id: str, reduce_by_fp: str) -> Order: - """Decrease the remaining count of a resting order. - - Args: - order_id: ID of the order to decrease. - reduce_by_fp: Number of contracts to reduce by (fixed-point string). - """ - response = self._client.post( - f"/portfolio/orders/{order_id}/decrease", {"reduce_by_fp": reduce_by_fp} - ) - model = OrderModel.model_validate(response["order"]) - return Order(self._client, model) - - def get_orders( - self, - *, - status: OrderStatus | None = None, - ticker: str | None = None, - event_ticker: str | None = None, - min_ts: int | None = None, - max_ts: int | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[Order]: - """Get list of orders. - - Args: - status: Filter by order status (resting, canceled, executed). - ticker: Filter by market ticker. - event_ticker: Filter by event ticker (supports comma-separated, max 10). - min_ts: Filter orders after this Unix timestamp. - max_ts: Filter orders before this Unix timestamp. - limit: Maximum results per page (default 100, max 200). - cursor: Pagination cursor for fetching next page. - fetch_all: If True, automatically fetch all pages. - **extra_params: Additional API parameters (e.g., subaccount). - """ - params = { - "limit": limit, - "status": status.value if status is not None else None, - "ticker": normalize_ticker(ticker), - "event_ticker": normalize_ticker(event_ticker), - "min_ts": min_ts, - "max_ts": max_ts, - "cursor": cursor, - **extra_params, - } - data = self._client.paginated_get("/portfolio/orders", "orders", params, fetch_all) - return DataFrameList(Order(self._client, OrderModel.model_validate(d)) for d in data) - - def get_order(self, order_id: str) -> Order: - """Get a single order by ID.""" - response = self._client.get(f"/portfolio/orders/{order_id}") - model = OrderModel.model_validate(response["order"]) - return Order(self._client, model) - - def get_positions( - self, - *, - ticker: str | None = None, - event_ticker: str | None = None, - count_filter: PositionCountFilter | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[PositionModel]: - """Get portfolio positions. - - Args: - ticker: Filter by specific market ticker. - event_ticker: Filter by event ticker (supports comma-separated, max 10). - count_filter: Filter positions with non-zero values (POSITION or TOTAL_TRADED). - limit: Maximum positions per page (default 100, max 1000). - cursor: Pagination cursor for fetching next page. - fetch_all: If True, automatically fetch all pages. - **extra_params: Additional API parameters (e.g., subaccount). - """ - params = { - "limit": limit, - "ticker": normalize_ticker(ticker), - "event_ticker": normalize_ticker(event_ticker), - "count_filter": count_filter.value if count_filter is not None else None, - "cursor": cursor, - **extra_params, - } - data = self._client.paginated_get("/portfolio/positions", "market_positions", params, fetch_all) - return DataFrameList(PositionModel.model_validate(p) for p in data) - - def get_fills( - self, - *, - ticker: str | None = None, - order_id: str | None = None, - min_ts: int | None = None, - max_ts: int | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[FillModel]: - """Get trade fills (executed trades). - - Args: - ticker: Filter by market ticker. - order_id: Filter by specific order ID. - min_ts: Minimum timestamp (Unix seconds). - max_ts: Maximum timestamp (Unix seconds). - limit: Maximum fills per page (default 100, max 200). - cursor: Pagination cursor for fetching next page. - fetch_all: If True, automatically fetch all pages. - **extra_params: Additional API parameters (e.g., subaccount). - """ - params = { - "limit": limit, - "ticker": normalize_ticker(ticker), - "order_id": order_id, - "min_ts": min_ts, - "max_ts": max_ts, - "cursor": cursor, - **extra_params, - } - data = self._client.paginated_get("/portfolio/fills", "fills", params, fetch_all) - return DataFrameList(FillModel.model_validate(f) for f in data) - - # --- Batch Operations --- - - def batch_place_orders(self, orders: list[dict]) -> DataFrameList[Order]: - """Place multiple orders atomically. - - Args: - orders: List of order dicts with keys: ticker, action, side, count_fp, - yes_price_dollars/no_price_dollars, and optional advanced params. - - Example: - orders = [ - {"ticker": "KXBTC", "action": "buy", "side": "yes", "count_fp": "10.00", "yes_price_dollars": "0.45"}, - {"ticker": "KXBTC", "action": "buy", "side": "no", "count_fp": "10.00", "no_price_dollars": "0.45"}, - ] - results = portfolio.batch_place_orders(orders) - """ - prepared = self._build_batch_orders(orders) - response = self._client.post("/portfolio/orders/batched", {"orders": prepared}) - result = [] - for item in response.get("orders", []): - order_data = item.get("order") - if order_data is None: - continue - result.append(Order(self._client, OrderModel.model_validate(order_data))) - return DataFrameList(result) - - def batch_cancel_orders(self, order_ids: list[str]) -> DataFrameList[Order]: - """Cancel multiple orders atomically. - - Args: - order_ids: List of order IDs to cancel (max 20). - - Returns: - The canceled Orders with updated status. - """ - orders = [{"order_id": oid} for oid in order_ids] - response = self._client.delete("/portfolio/orders/batched", {"orders": orders}) - result = [] - for item in response.get("orders", []): - order_data = item.get("order") - if order_data is None: - continue - result.append(Order(self._client, OrderModel.model_validate(order_data))) - return DataFrameList(result) - - # --- Queue Position --- - - def get_queue_position(self, order_id: str) -> QueuePositionModel: - """Get queue position for a single resting order.""" - response = self._client.get(f"/portfolio/orders/{order_id}/queue_position") - return QueuePositionModel( - order_id=order_id, - queue_position_fp=response.get("queue_position_fp", "0.00"), - ) - - def get_queue_positions( - self, - *, - market_tickers: list[str] | None = None, - event_ticker: str | None = None, - ) -> DataFrameList[QueuePositionModel]: - """Get queue positions for all resting orders.""" - params: dict = {} - if market_tickers: - params["market_tickers"] = ",".join(normalize_tickers(market_tickers)) - if event_ticker: - params["event_ticker"] = normalize_ticker(event_ticker) - - endpoint = "/portfolio/orders/queue_positions" - if params: - endpoint = f"{endpoint}?{urlencode(params)}" - - response = self._client.get(endpoint) - return DataFrameList( - QueuePositionModel.model_validate(qp) - for qp in response.get("queue_positions", []) - ) - - # --- Settlements --- - - def get_settlements( - self, - *, - ticker: str | None = None, - event_ticker: str | None = None, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[SettlementModel]: - """Get settlement records for resolved positions.""" - params = { - "limit": limit, - "ticker": normalize_ticker(ticker), - "event_ticker": normalize_ticker(event_ticker), - "cursor": cursor, - **extra_params, - } - data = self._client.paginated_get("/portfolio/settlements", "settlements", params, fetch_all) - return DataFrameList(SettlementModel.model_validate(s) for s in data) - - def get_resting_order_value(self) -> str: - """Get total value of all resting orders as dollar string. - - NOTE: This endpoint is FCM-only (institutional accounts). - """ - response = self._client.get("/portfolio/summary/total_resting_order_value") - return response.get("total_resting_order_value_dollars", "0") - - # --- Order Groups (Contract Rate Limiting) --- - - def create_order_group(self, contracts_limit_fp: str) -> OrderGroupModel: - """Create an order group for rate-limiting contract matches. - - Args: - contracts_limit_fp: Maximum contracts (fixed-point string) that can be - matched in a rolling 15-second window. - - Returns: - Created OrderGroupModel. - """ - body: dict = {"contracts_limit_fp": contracts_limit_fp} - response = self._client.post("/portfolio/order_groups/create", body) - return OrderGroupModel.model_validate(response) - - def get_order_group(self, order_group_id: str) -> OrderGroupModel: - """Get an order group by ID.""" - response = self._client.get(f"/portfolio/order_groups/{order_group_id}") - response["id"] = order_group_id - return OrderGroupModel.model_validate(response) - - def trigger_order_group(self, order_group_id: str) -> None: - """Manually trigger an order group, cancelling all orders in it.""" - self._client.put(f"/portfolio/order_groups/{order_group_id}/trigger", {}) - - def get_order_groups(self) -> DataFrameList[OrderGroupModel]: - """List all order groups.""" - response = self._client.get("/portfolio/order_groups") - return DataFrameList( - OrderGroupModel.model_validate(og) - for og in response.get("order_groups", []) - ) - - def reset_order_group(self, order_group_id: str) -> None: - """Reset matched contract counter for an order group.""" - self._client.put(f"/portfolio/order_groups/{order_group_id}/reset", {}) - - def update_order_group_limit(self, order_group_id: str, contracts_limit_fp: str) -> None: - """Update the contracts limit for an order group. - - Args: - order_group_id: ID of the order group. - contracts_limit_fp: New maximum contracts (fixed-point string). - """ - body: dict = {"contracts_limit_fp": contracts_limit_fp} - self._client.put(f"/portfolio/order_groups/{order_group_id}/limit", body) - - # --- Subaccounts --- - - def create_subaccount(self) -> SubaccountModel: - """Create a new numbered subaccount.""" - response = self._client.post("/portfolio/subaccounts", {}) - return SubaccountModel.model_validate(response.get("subaccount", response)) - - def transfer_between_subaccounts( - self, - from_subaccount_id: str, - to_subaccount_id: str, - amount_dollars: str, - ) -> SubaccountTransferModel: - """Transfer funds between subaccounts. - - Args: - from_subaccount_id: Source subaccount ID. - to_subaccount_id: Destination subaccount ID. - amount_dollars: Amount to transfer (dollar string). - """ - body = { - "from_subaccount_id": from_subaccount_id, - "to_subaccount_id": to_subaccount_id, - "amount_dollars": amount_dollars, - } - response = self._client.post("/portfolio/subaccounts/transfer", body) - return SubaccountTransferModel.model_validate(response.get("transfer", response)) - - def get_subaccount_balances(self) -> DataFrameList[SubaccountBalanceModel]: - """Get balances for all subaccounts.""" - response = self._client.get("/portfolio/subaccounts/balances") - return DataFrameList( - SubaccountBalanceModel.model_validate(b) - for b in response.get("balances", []) - ) - - def get_subaccount_transfers( - self, - *, - limit: int = 100, - cursor: str | None = None, - fetch_all: bool = False, - **extra_params, - ) -> DataFrameList[SubaccountTransferModel]: - """Get transfer history between subaccounts.""" - params = {"limit": limit, "cursor": cursor, **extra_params} - data = self._client.paginated_get( - "/portfolio/subaccounts/transfers", "transfers", params, fetch_all - ) - return DataFrameList(SubaccountTransferModel.model_validate(t) for t in data) - - # --- Shared validation helpers --- - - @staticmethod - def _validate_tick_size(price: Decimal, price_level_structure: str) -> None: - """Validate that price aligns to the market's tick size. - - Raises ValueError if the price is not on a valid tick boundary. - """ - if price_level_structure == "linear_cent": - # $0.00–$1.00, tick $0.01 - tick = Decimal("0.01") - if price % tick != 0: - raise ValueError( - f"Price {price} is not on a valid tick for linear_cent " - f"(tick size $0.01)" - ) - elif price_level_structure == "deci_cent": - # $0.00–$1.00, tick $0.001 - tick = Decimal("0.001") - if price % tick != 0: - raise ValueError( - f"Price {price} is not on a valid tick for deci_cent " - f"(tick size $0.001)" - ) - elif price_level_structure == "tapered_deci_cent": - # $0.00–$0.10: tick $0.001, $0.10–$0.90: tick $0.01, $0.90–$1.00: tick $0.001 - if price <= Decimal("0.10") or price >= Decimal("0.90"): - tick = Decimal("0.001") - else: - tick = Decimal("0.01") - if price % tick != 0: - raise ValueError( - f"Price {price} is not on a valid tick for tapered_deci_cent " - f"(tick size ${tick} in this price range)" - ) - - @staticmethod - def _validate_fractional(count_fp: str, fractional_enabled: bool) -> None: - """Validate count_fp is whole when fractional trading is disabled.""" - if not fractional_enabled: - d = Decimal(count_fp) - if d != int(d): - raise ValueError( - f"Fractional trading is not enabled for this market. " - f"count_fp must be a whole number, got {count_fp}" - ) - - @staticmethod - def _build_order_data( - ticker, - action: Action, - side: Side, - count_fp: str, - *, - yes_price_dollars=None, - no_price_dollars=None, - client_order_id=None, - time_in_force=None, - post_only=False, - reduce_only=False, - expiration_ts=None, - buy_max_cost_dollars=None, - self_trade_prevention=None, - order_group_id=None, - subaccount=None, - cancel_order_on_pause=None, - price_level_structure=None, - fractional_trading_enabled=None, - ) -> dict: - """Build and validate order data dict. No I/O. - - If price_level_structure is provided, validates tick size alignment. - If fractional_trading_enabled is provided (False), validates count_fp is whole. - """ - if yes_price_dollars is not None and no_price_dollars is not None: - raise ValueError("Specify yes_price_dollars or no_price_dollars, not both") - - if yes_price_dollars is None and no_price_dollars is None: - raise ValueError("Limit orders require yes_price_dollars or no_price_dollars") - - if no_price_dollars is not None: - yes_price_dollars = str(Decimal("1") - Decimal(no_price_dollars)) - - # Validate tick size if market structure is known - if price_level_structure and yes_price_dollars is not None: - Portfolio._validate_tick_size(Decimal(yes_price_dollars), price_level_structure) - - # Validate fractional trading - if fractional_trading_enabled is not None: - Portfolio._validate_fractional(count_fp, fractional_trading_enabled) - - ticker_str = ticker.upper() if isinstance(ticker, str) else ticker.ticker - - order_data: dict = { - "ticker": ticker_str, - "action": action.value, - "side": side.value, - "count_fp": count_fp, - "yes_price_dollars": yes_price_dollars, - } - if client_order_id is not None: - order_data["client_order_id"] = client_order_id - if time_in_force is not None: - order_data["time_in_force"] = time_in_force.value - if post_only: - order_data["post_only"] = True - if reduce_only: - order_data["reduce_only"] = True - if expiration_ts is not None: - order_data["expiration_ts"] = expiration_ts - if buy_max_cost_dollars is not None: - order_data["buy_max_cost_dollars"] = buy_max_cost_dollars - if self_trade_prevention is not None: - order_data["self_trade_prevention_type"] = self_trade_prevention.value - if order_group_id is not None: - order_data["order_group_id"] = order_group_id - if subaccount is not None: - order_data["subaccount"] = subaccount - if cancel_order_on_pause is not None: - order_data["cancel_order_on_pause"] = cancel_order_on_pause - return order_data - - @staticmethod - def _build_batch_orders(orders: list[dict]) -> list[dict]: - """Validate and prepare batch orders. No I/O.""" - prepared = [] - for order in orders: - o = dict(order) - - if "yes_price_dollars" in o and "no_price_dollars" in o: - raise ValueError("Specify yes_price_dollars or no_price_dollars, not both") - if "yes_price_dollars" not in o and "no_price_dollars" not in o: - raise ValueError("Limit orders require yes_price_dollars or no_price_dollars") - if "no_price_dollars" in o: - o["yes_price_dollars"] = str(Decimal("1") - Decimal(o.pop("no_price_dollars"))) - # Strip "type" — Kalshi API no longer accepts it - o.pop("type", None) - prepared.append(o) - return prepared - - -class AsyncPortfolio(Portfolio): - """Async variant of Portfolio. Inherits validation helpers.""" - - async def get_balance(self) -> BalanceModel: # type: ignore[override] - data = await self._client.get("/portfolio/balance") - return BalanceModel.model_validate(data) - - async def place_order( # type: ignore[override] - self, - ticker, - action: Action, - side: Side, - count_fp: str | None = None, - **kwargs, - ) -> AsyncOrder: - order_data = self._build_order_data( - ticker, action, side, count_fp, **kwargs - ) - response = await self._client.post("/portfolio/orders", order_data) - model = OrderModel.model_validate(response["order"]) - return AsyncOrder(self._client, model) - - async def cancel_order(self, order_id: str, *, subaccount: int | None = None) -> AsyncOrder: # type: ignore[override] - endpoint = f"/portfolio/orders/{order_id}" - if subaccount is not None: - endpoint += f"?subaccount={subaccount}" - response = await self._client.delete(endpoint) - model = OrderModel.model_validate(response["order"]) - return AsyncOrder(self._client, model) - - async def amend_order( # type: ignore[override] - self, - order_id: str, - *, - count_fp: str | None = None, - yes_price_dollars: str | None = None, - no_price_dollars: str | None = None, - subaccount: int | None = None, - ticker: str | None = None, - action: Action | None = None, - side: Side | None = None, - ) -> AsyncOrder: - if yes_price_dollars is not None and no_price_dollars is not None: - raise ValueError("Specify yes_price_dollars or no_price_dollars, not both") - if no_price_dollars is not None: - yes_price_dollars = str(Decimal("1") - Decimal(no_price_dollars)) - ticker = normalize_ticker(ticker) - - if ticker is None or action is None or side is None or count_fp is None: - original = await self.get_order(order_id) - ticker = ticker or original.ticker - action = action or original.action - side = side or original.side - if count_fp is None: - count_fp = original.remaining_count_fp - - body: dict = { - "ticker": ticker, - "action": action.value if isinstance(action, Action) else action, - "side": side.value if isinstance(side, Side) else side, - "count_fp": count_fp, - } - if yes_price_dollars is not None: - body["yes_price_dollars"] = yes_price_dollars - if subaccount is not None: - body["subaccount"] = subaccount - if "count_fp" not in body and "yes_price_dollars" not in body: - raise ValueError("Must specify at least one of count_fp, yes_price_dollars, or no_price_dollars") - - response = await self._client.post(f"/portfolio/orders/{order_id}/amend", body) - model = OrderModel.model_validate(response["order"]) - return AsyncOrder(self._client, model) - - async def decrease_order(self, order_id: str, reduce_by_fp: str) -> AsyncOrder: # type: ignore[override] - response = await self._client.post( - f"/portfolio/orders/{order_id}/decrease", {"reduce_by_fp": reduce_by_fp} - ) - model = OrderModel.model_validate(response["order"]) - return AsyncOrder(self._client, model) - - async def get_orders(self, *, status=None, ticker=None, event_ticker=None, # type: ignore[override] - min_ts=None, max_ts=None, limit=100, cursor=None, - fetch_all=False, **extra_params) -> DataFrameList[AsyncOrder]: - params = { - "limit": limit, - "status": status.value if status is not None else None, - "ticker": normalize_ticker(ticker), - "event_ticker": normalize_ticker(event_ticker), - "min_ts": min_ts, "max_ts": max_ts, "cursor": cursor, - **extra_params, - } - data = await self._client.paginated_get("/portfolio/orders", "orders", params, fetch_all) - return DataFrameList(AsyncOrder(self._client, OrderModel.model_validate(d)) for d in data) - - async def get_order(self, order_id: str) -> AsyncOrder: # type: ignore[override] - response = await self._client.get(f"/portfolio/orders/{order_id}") - model = OrderModel.model_validate(response["order"]) - return AsyncOrder(self._client, model) - - async def get_positions(self, *, ticker=None, event_ticker=None, # type: ignore[override] - count_filter=None, limit=100, cursor=None, - fetch_all=False, **extra_params) -> DataFrameList[PositionModel]: - params = { - "limit": limit, - "ticker": normalize_ticker(ticker), - "event_ticker": normalize_ticker(event_ticker), - "count_filter": count_filter.value if count_filter is not None else None, - "cursor": cursor, **extra_params, - } - data = await self._client.paginated_get("/portfolio/positions", "market_positions", params, fetch_all) - return DataFrameList(PositionModel.model_validate(p) for p in data) - - async def get_fills(self, *, ticker=None, order_id=None, # type: ignore[override] - min_ts=None, max_ts=None, limit=100, cursor=None, - fetch_all=False, **extra_params) -> DataFrameList[FillModel]: - params = { - "limit": limit, "ticker": normalize_ticker(ticker), - "order_id": order_id, "min_ts": min_ts, "max_ts": max_ts, - "cursor": cursor, **extra_params, - } - data = await self._client.paginated_get("/portfolio/fills", "fills", params, fetch_all) - return DataFrameList(FillModel.model_validate(f) for f in data) - - async def batch_place_orders(self, orders: list[dict]) -> DataFrameList[AsyncOrder]: # type: ignore[override] - prepared = self._build_batch_orders(orders) - response = await self._client.post("/portfolio/orders/batched", {"orders": prepared}) - result = [] - for item in response.get("orders", []): - order_data = item.get("order") - if order_data is None: - continue - result.append(AsyncOrder(self._client, OrderModel.model_validate(order_data))) - return DataFrameList(result) - - async def batch_cancel_orders(self, order_ids: list[str]) -> DataFrameList[AsyncOrder]: # type: ignore[override] - orders = [{"order_id": oid} for oid in order_ids] - response = await self._client.delete("/portfolio/orders/batched", {"orders": orders}) - result = [] - for item in response.get("orders", []): - order_data = item.get("order") - if order_data is None: - continue - result.append(AsyncOrder(self._client, OrderModel.model_validate(order_data))) - return DataFrameList(result) - - async def get_queue_position(self, order_id: str) -> QueuePositionModel: # type: ignore[override] - response = await self._client.get(f"/portfolio/orders/{order_id}/queue_position") - return QueuePositionModel(order_id=order_id, queue_position_fp=response.get("queue_position_fp", "0.00")) - - async def get_queue_positions(self, *, market_tickers=None, event_ticker=None) -> DataFrameList[QueuePositionModel]: # type: ignore[override] - params: dict = {} - if market_tickers: - params["market_tickers"] = ",".join(normalize_tickers(market_tickers)) - if event_ticker: - params["event_ticker"] = normalize_ticker(event_ticker) - endpoint = "/portfolio/orders/queue_positions" - if params: - endpoint = f"{endpoint}?{urlencode(params)}" - response = await self._client.get(endpoint) - return DataFrameList(QueuePositionModel.model_validate(qp) for qp in response.get("queue_positions", [])) - - async def get_settlements(self, *, ticker=None, event_ticker=None, # type: ignore[override] - limit=100, cursor=None, fetch_all=False, - **extra_params) -> DataFrameList[SettlementModel]: - params = { - "limit": limit, "ticker": normalize_ticker(ticker), - "event_ticker": normalize_ticker(event_ticker), - "cursor": cursor, **extra_params, - } - data = await self._client.paginated_get("/portfolio/settlements", "settlements", params, fetch_all) - return DataFrameList(SettlementModel.model_validate(s) for s in data) - - async def get_resting_order_value(self) -> str: # type: ignore[override] - response = await self._client.get("/portfolio/summary/total_resting_order_value") - return response.get("total_resting_order_value_dollars", "0") - - async def create_order_group(self, contracts_limit_fp: str) -> OrderGroupModel: # type: ignore[override] - response = await self._client.post("/portfolio/order_groups/create", {"contracts_limit_fp": contracts_limit_fp}) - return OrderGroupModel.model_validate(response) - - async def get_order_group(self, order_group_id: str) -> OrderGroupModel: # type: ignore[override] - response = await self._client.get(f"/portfolio/order_groups/{order_group_id}") - response["id"] = order_group_id - return OrderGroupModel.model_validate(response) - - async def trigger_order_group(self, order_group_id: str) -> None: # type: ignore[override] - await self._client.put(f"/portfolio/order_groups/{order_group_id}/trigger", {}) - - async def get_order_groups(self) -> DataFrameList[OrderGroupModel]: # type: ignore[override] - response = await self._client.get("/portfolio/order_groups") - return DataFrameList(OrderGroupModel.model_validate(og) for og in response.get("order_groups", [])) - - async def reset_order_group(self, order_group_id: str) -> None: # type: ignore[override] - await self._client.put(f"/portfolio/order_groups/{order_group_id}/reset", {}) - - async def update_order_group_limit(self, order_group_id: str, contracts_limit_fp: str) -> None: # type: ignore[override] - await self._client.put(f"/portfolio/order_groups/{order_group_id}/limit", {"contracts_limit_fp": contracts_limit_fp}) - - async def create_subaccount(self) -> SubaccountModel: # type: ignore[override] - response = await self._client.post("/portfolio/subaccounts", {}) - return SubaccountModel.model_validate(response.get("subaccount", response)) - - async def transfer_between_subaccounts(self, from_subaccount_id: str, to_subaccount_id: str, amount_dollars: str) -> SubaccountTransferModel: # type: ignore[override] - body = {"from_subaccount_id": from_subaccount_id, "to_subaccount_id": to_subaccount_id, "amount_dollars": amount_dollars} - response = await self._client.post("/portfolio/subaccounts/transfer", body) - return SubaccountTransferModel.model_validate(response.get("transfer", response)) - - async def get_subaccount_balances(self) -> DataFrameList[SubaccountBalanceModel]: # type: ignore[override] - response = await self._client.get("/portfolio/subaccounts/balances") - return DataFrameList(SubaccountBalanceModel.model_validate(b) for b in response.get("balances", [])) - - async def get_subaccount_transfers(self, *, limit=100, cursor=None, # type: ignore[override] - fetch_all=False, **extra_params) -> DataFrameList[SubaccountTransferModel]: - params = {"limit": limit, "cursor": cursor, **extra_params} - data = await self._client.paginated_get("/portfolio/subaccounts/transfers", "transfers", params, fetch_all) - return DataFrameList(SubaccountTransferModel.model_validate(t) for t in data) +__all__ = ["Portfolio", "AsyncPortfolio"] diff --git a/pyproject.toml b/pyproject.toml index bb69672..3fd6d8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pykalshi" -version = "0.4.0" +version = "0.5.0" description = "A typed Python client for the Kalshi prediction markets API with WebSocket streaming, automatic retries, and ergonomic interfaces" readme = "README.md" license = "MIT" @@ -75,8 +75,8 @@ Issues = "https://github.com/ArshKA/pykalshi/issues" requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" -[tool.setuptools] -packages = ["pykalshi"] +[tool.setuptools.packages.find] +include = ["pykalshi*"] [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/scripts/generate_sync.py b/scripts/generate_sync.py new file mode 100644 index 0000000..50b40bb --- /dev/null +++ b/scripts/generate_sync.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python +"""Generate synchronous API from async source files. + +Usage: + python scripts/generate_sync.py [--write|--check] + +Default mode is --write. +""" + +import re +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parent.parent +ASYNC_DIR = REPO_ROOT / "pykalshi" / "_async" +SYNC_DIR = REPO_ROOT / "pykalshi" / "_sync" + +GENERATED_MODULES = [ + "client.py", + "portfolio.py", + "orders.py", + "markets.py", + "events.py", + "exchange.py", + "api_keys.py", + "communications.py", + "mve.py", +] + +# Class renames: AsyncX -> X (applied with word boundaries) +CLASS_RENAMES = { + "AsyncKalshiClient": "KalshiClient", + "AsyncPortfolio": "Portfolio", + "AsyncOrder": "Order", + "AsyncMarket": "Market", + "AsyncSeries": "Series", + "AsyncEvent": "Event", + "AsyncExchange": "Exchange", + "AsyncAPIKeys": "APIKeys", + "AsyncCommunications": "Communications", + "AsyncMveCollection": "MveCollection", + "AsyncRateLimiterProtocol": "RateLimiterProtocol", + "AsyncFeed": "Feed", +} + +HEADER = ( + "# AUTO-GENERATED from pykalshi/_async/{filename} — do not edit manually.\n" + "# Re-run: python scripts/generate_sync.py\n" +) + + +def transform(source: str, filename: str) -> str: + """Apply all sync transforms to async source code.""" + lines = source.split("\n") + result_lines = [] + + for line in lines: + # Remove # type: ignore[override] + line = re.sub(r"\s*#\s*type:\s*ignore\[override\]", "", line) + + # Keyword transforms + line = re.sub(r"\basync\s+def\b", "def", line) + line = re.sub(r"\basync\s+with\b", "with", line) + line = re.sub(r"\basync\s+for\b", "for", line) + line = re.sub(r"\bawait\s+", "", line) + + # stdlib swaps + line = re.sub(r"\basyncio\.sleep\b", "time.sleep", line) + + # httpx swaps + line = re.sub(r"\bhttpx\.AsyncClient\b", "httpx.Client", line) + + # Protocol method renames + line = re.sub(r"\b__aenter__\b", "__enter__", line) + line = re.sub(r"\b__aexit__\b", "__exit__", line) + line = re.sub(r"\baclose\b", "close", line) + + # Special-case import (must come BEFORE class renames) + line = re.sub( + r"from \.\.afeed import AsyncFeed", "from ..feed import Feed", line + ) + + # Class renames + for async_name, sync_name in CLASS_RENAMES.items(): + line = re.sub(rf"\b{async_name}\b", sync_name, line) + + # import asyncio -> import time + if re.match(r"^import asyncio\s*$", line): + line = "import time" + + result_lines.append(line) + + result = "\n".join(result_lines) + + # Deduplicate consecutive 'import time' lines + result = re.sub( + r"^(import time\n)(?=import time\n)", "", result, flags=re.MULTILINE + ) + + # Add header + header = HEADER.format(filename=filename) + result = header + result + + return result + + +def validate_allowlist() -> bool: + """Check that no unexpected .py files exist in _async/.""" + if not ASYNC_DIR.exists(): + return True + + async_files = {f.name for f in ASYNC_DIR.glob("*.py") if f.name != "__init__.py"} + allowed = set(GENERATED_MODULES) + unexpected = async_files - allowed + + if unexpected: + print(f"ERROR: Unexpected files in _async/: {unexpected}", file=sys.stderr) + print( + "Add them to GENERATED_MODULES in scripts/generate_sync.py", file=sys.stderr + ) + return False + return True + + +def main() -> int: + mode = "--write" + if len(sys.argv) > 1: + mode = sys.argv[1] + + if mode not in ("--write", "--check"): + print(f"Usage: {sys.argv[0]} [--write|--check]", file=sys.stderr) + return 1 + + if not validate_allowlist(): + return 1 + + # Determine which modules exist in _async/ + present = [] + for module in GENERATED_MODULES: + src = ASYNC_DIR / module + if src.exists(): + present.append(module) + else: + print(f"skipped: {module} (not yet migrated)") + + if not present: + print("No async modules found to generate.") + return 0 + + errors = [] + + for module in present: + src = ASYNC_DIR / module + source = src.read_text() + generated = transform(source, module) + dest = SYNC_DIR / module + + if mode == "--write": + SYNC_DIR.mkdir(parents=True, exist_ok=True) + dest.write_text(generated) + print(f"wrote: _sync/{module}") + else: # --check + if not dest.exists(): + errors.append( + f"_sync/{module} does not exist (expected from _async/{module})" + ) + continue + existing = dest.read_text() + if existing != generated: + errors.append(f"_sync/{module} is out of date") + + # Check for stale files in _sync/ + if mode == "--check" and SYNC_DIR.exists(): + sync_files = {f.name for f in SYNC_DIR.glob("*.py") if f.name != "__init__.py"} + expected = set(present) + stale = sync_files - expected + for s in stale: + errors.append(f"_sync/{s} is stale (no corresponding _async/ source)") + + if errors: + for e in errors: + print(f"ERROR: {e}", file=sys.stderr) + return 1 + + if mode == "--check": + print("OK: _sync/ is up to date") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_codegen.py b/tests/test_codegen.py new file mode 100644 index 0000000..3ff212d --- /dev/null +++ b/tests/test_codegen.py @@ -0,0 +1,19 @@ +import subprocess +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parent.parent + + +def test_sync_generation_is_up_to_date(): + """Assert _sync/ matches what the generator would produce.""" + result = subprocess.run( + [sys.executable, str(REPO_ROOT / "scripts" / "generate_sync.py"), "--check"], + capture_output=True, + text=True, + cwd=str(REPO_ROOT), + ) + assert result.returncode == 0, ( + f"_sync/ is out of date. Run: python scripts/generate_sync.py --write\n" + f"{result.stdout}\n{result.stderr}" + ) diff --git a/tests/test_portfolio.py b/tests/test_portfolio.py index 14033dd..f712d1b 100644 --- a/tests/test_portfolio.py +++ b/tests/test_portfolio.py @@ -380,8 +380,8 @@ def test_order_wait_until_terminal_executed(client, mock_response, mocker): order = Order(client, initial_model) # Mock time to avoid actual sleeping - mocker.patch("pykalshi.orders.time.sleep") - mock_monotonic = mocker.patch("pykalshi.orders.time.monotonic") + mocker.patch("pykalshi._sync.orders.time.sleep") + mock_monotonic = mocker.patch("pykalshi._sync.orders.time.monotonic") mock_monotonic.side_effect = [0.0, 0.5, 1.0] # start, check, check # First refresh: still resting, second refresh: executed @@ -409,7 +409,7 @@ def test_order_wait_until_terminal_already_terminal(client, mock_response, mocke ) order = Order(client, initial_model) - mock_sleep = mocker.patch("pykalshi.orders.time.sleep") + mock_sleep = mocker.patch("pykalshi._sync.orders.time.sleep") result = order.wait_until_terminal(timeout=5.0) @@ -431,8 +431,8 @@ def test_order_wait_until_terminal_timeout(client, mock_response, mocker): ) order = Order(client, initial_model) - mocker.patch("pykalshi.orders.time.sleep") - mock_monotonic = mocker.patch("pykalshi.orders.time.monotonic") + mocker.patch("pykalshi._sync.orders.time.sleep") + mock_monotonic = mocker.patch("pykalshi._sync.orders.time.monotonic") # Simulate time passing: start at 0, then jump past deadline mock_monotonic.side_effect = [0.0, 0.5, 2.1] # start, first check (ok), second check (past deadline)