Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions alphaswarm/services/exchanges/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .factory import DEXFactory
from .base import DEXClient, SwapResult
from .base import DEXClient, SwapResult, QuoteResult
from .uniswap import UniswapClientBase

__all__ = ["DEXFactory", "DEXClient", "SwapResult", "UniswapClientBase"]
__all__ = ["DEXFactory", "DEXClient", "SwapResult", "QuoteResult", "UniswapClientBase"]
41 changes: 28 additions & 13 deletions alphaswarm/services/exchanges/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,24 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from decimal import Decimal
from typing import List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Generic, List, Optional, Tuple, Type, TypeGuard, TypeVar, Union

from alphaswarm.config import ChainConfig, Config, TokenInfo
from hexbytes import HexBytes

T = TypeVar("T", bound="DEXClient")
TQuote = TypeVar("TQuote")


@dataclass
class QuoteResult(Generic[TQuote]):
quote: TQuote

token_in: TokenInfo
token_out: TokenInfo
amount_in: Decimal
amount_out: Decimal


@dataclass
class SwapResult:
Expand Down Expand Up @@ -66,16 +79,14 @@ def __repr__(self) -> str:
return f"Slippage(bps={self.bps})"


T = TypeVar("T", bound="DEXClient")


class DEXClient(ABC):
class DEXClient(Generic[TQuote], ABC):
"""Base class for DEX clients"""

@abstractmethod
def __init__(self, chain_config: ChainConfig) -> None:
def __init__(self, chain_config: ChainConfig, quote_type: Type[TQuote]) -> None:
"""Initialize the DEX client with configuration"""
self._chain_config = chain_config
self._quote_type = quote_type

@property
def chain(self) -> str:
Expand All @@ -86,14 +97,15 @@ def chain_config(self) -> ChainConfig:
return self._chain_config

@abstractmethod
def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal:
def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal) -> QuoteResult[TQuote]:
"""Get price/conversion rate for the pair of tokens.

The price is returned in terms of token_out/token_in (how much token out per token in).

Args:
token_out (TokenInfo): The token to be bought (going out from the pool)
token_in (TokenInfo): The token to be sold (going into the pool)
amount_in (Decimal): The amount of the token to be sold

Example:
eth_token = TokenInfo(address="0x...", decimals=18, symbol="ETH", chain="ethereum")
Expand All @@ -106,17 +118,13 @@ def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal:
@abstractmethod
def swap(
self,
token_out: TokenInfo,
token_in: TokenInfo,
amount_in: Decimal,
quote: QuoteResult[TQuote],
slippage_bps: int = 100,
) -> SwapResult:
"""Execute a token swap on the DEX

Args:
token_out (TokenInfo): The token to be bought (going out from the pool)
token_in (TokenInfo): The token to be sold (going into the pool)
amount_in: Amount of token_in to spend
quote (TokenPrice): The quote to execute
slippage_bps: Maximum allowed slippage in basis points (1 bp = 0.01%)

Returns:
Expand Down Expand Up @@ -155,3 +163,10 @@ def from_config(cls: Type[T], config: Config, chain: str) -> T:
An instance of the DEX client
"""
pass

def raise_if_not_quote(self, value: Any) -> None:
if self.is_quote(value):
raise TypeError(f"Expected {self._quote_type} but got {type(value)}")

def is_quote(self, value: Any) -> TypeGuard[QuoteResult[TQuote]]:
return isinstance(value, QuoteResult) and isinstance(value.quote, self._quote_type)
73 changes: 51 additions & 22 deletions alphaswarm/services/exchanges/jupiter/jupiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,55 @@

import logging
from decimal import Decimal
from typing import Any, Dict, List, Tuple
from typing import Annotated, Any, Dict, List, Optional, Tuple
from urllib.parse import urlencode

import requests
from alphaswarm.config import ChainConfig, Config, JupiterSettings, JupiterVenue, TokenInfo
from alphaswarm.services import ApiException
from alphaswarm.services.exchanges.base import DEXClient, SwapResult
from pydantic import Field
from alphaswarm.services.exchanges.base import DEXClient, QuoteResult, SwapResult
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass

logger = logging.getLogger(__name__)


class SwapInfo(BaseModel):
amm_key: Annotated[str, Field(alias="ammKey")]
label: Annotated[Optional[str], Field(alias="label", default=None)]
input_mint: Annotated[str, Field(alias="inputMint")]
output_mint: Annotated[str, Field(alias="outputMint")]
in_amount: Annotated[str, Field(alias="inAmount")]
out_amount: Annotated[str, Field(alias="outAmount")]
fee_amount: Annotated[str, Field(alias="feeAmount")]
fee_mint: Annotated[str, Field(alias="feeMint")]

def to_dict(self) -> Dict[str, Any]:
return self.model_dump(by_alias=True)


@dataclass
class QuoteResponse:
class RoutePlan:
swap_info: Annotated[SwapInfo, Field(alias="swapInfo")]
percent: int


@dataclass
class JupiterQuote:
# TODO capture more fields if needed
out_amount: Decimal = Field(alias="outAmount")
route_plan: List[Dict[str, Any]] = Field(alias="routePlan")
out_amount: Annotated[Decimal, Field(alias="outAmount")]
route_plan: Annotated[List[RoutePlan], Field(alias="routePlan")]

def route_plan_to_string(self) -> str:
return "/".join([route.swap_info.amm_key for route in self.route_plan])


class JupiterClient(DEXClient):
class JupiterClient(DEXClient[JupiterQuote]):
"""Client for Jupiter DEX on Solana"""

def __init__(self, chain_config: ChainConfig, venue_config: JupiterVenue, settings: JupiterSettings) -> None:
self._validate_chain(chain_config.chain)
super().__init__(chain_config)
super().__init__(chain_config, JupiterQuote)
self._settings = settings
self._venue_config = venue_config
logger.info(f"Initialized JupiterClient on chain '{self.chain}'")
Expand All @@ -38,26 +61,26 @@ def _validate_chain(self, chain: str) -> None:

def swap(
self,
token_out: TokenInfo,
token_in: TokenInfo,
amount_in: Decimal,
quote: QuoteResult[JupiterQuote],
slippage_bps: int = 100,
) -> SwapResult:
raise NotImplementedError("Jupiter swap functionality is not yet implemented")

def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal:
def get_token_price(
self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal
) -> QuoteResult[JupiterQuote]:
# Verify tokens are on Solana
if not token_out.chain == self.chain or not token_in.chain == self.chain:
raise ValueError(f"Jupiter only supports Solana tokens. Got {token_out.chain} and {token_in.chain}")

logger.debug(f"Getting price for {token_out.symbol}/{token_in.symbol} on {token_out.chain} using Jupiter")
logger.debug(f"Getting amount_out for {token_out.symbol}/{token_in.symbol} on {token_out.chain} using Jupiter")

# Prepare query parameters
params = {
"inputMint": token_in.address,
"outputMint": token_out.address,
"swapMode": "ExactIn",
"amount": str(token_in.convert_to_wei(Decimal(1))), # Get price spending exactly 1 token_in
"amount": str(token_in.convert_to_wei(amount_in)),
"slippageBps": self._settings.slippage_bps,
}

Expand All @@ -68,19 +91,25 @@ def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal:
raise ApiException(response)

result = response.json()
quote = QuoteResponse(**result)
quote = JupiterQuote(**result)

# Calculate price (token_out per token_in)
amount_out = quote.out_amount
price = token_out.convert_from_wei(amount_out)
# Calculate amount_out (token_out per token_in)
raw_out = quote.out_amount
amount_out = token_out.convert_from_wei(raw_out)
# Log quote details
logger.debug("Quote successful:")
logger.debug(f"- Input: 1 {token_in.symbol}")
logger.debug(f"- Output: {amount_out} {token_out.symbol} lamports")
logger.debug(f"- Price: {price} {token_out.symbol}/{token_in.symbol}")
logger.debug(f"- Input: {amount_in} {token_in.symbol}")
logger.debug(f"- Output: {amount_out} {token_out.symbol}")
logger.debug(f"- Ratio: {amount_out/amount_in} {token_out.symbol}/{token_in.symbol}")
logger.debug(f"- Route: {quote.route_plan}")

return price
return QuoteResult(
quote=quote,
token_in=token_in,
token_out=token_out,
amount_in=amount_in,
amount_out=amount_out,
)

def get_markets_for_tokens(self, tokens: List[TokenInfo]) -> List[Tuple[TokenInfo, TokenInfo]]:
"""Get list of valid trading pairs between the provided tokens.
Expand Down
4 changes: 4 additions & 0 deletions alphaswarm/services/exchanges/uniswap/constants_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@
"factory": "0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f",
"router": "0x7a250d5630B4cF539739dF2C5dAcb4c659F2488D",
},
"ethereum_sepolia": {
"factory": "0xF62c03E08ada871A0bEb309762E260a7a6a880E6",
"router": "0xeE567Fe1712Faf6149d80dA1E6934E354124CfE3",
},
"base": {
"factory": "0x8909Dc15e40173Ff4699343b6eB8132c65e18eC6",
"router": "0x4752ba5dbc23f44d87826276bf6fd6b1c372ad24",
Expand Down
45 changes: 27 additions & 18 deletions alphaswarm/services/exchanges/uniswap/uniswap_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,23 @@

from alphaswarm.config import ChainConfig, TokenInfo
from alphaswarm.services.chains.evm import ERC20Contract, EVMClient, EVMSigner
from alphaswarm.services.exchanges.base import DEXClient, SwapResult
from alphaswarm.services.exchanges.base import DEXClient, QuoteResult, SwapResult
from eth_typing import ChecksumAddress, HexAddress
from pydantic.dataclasses import dataclass
from web3.types import TxReceipt

# Set up logger
logger = logging.getLogger(__name__)


class UniswapClientBase(DEXClient):
@dataclass
class UniswapQuote:
pool_address: ChecksumAddress


class UniswapClientBase(DEXClient[UniswapQuote]):
def __init__(self, chain_config: ChainConfig, version: str) -> None:
super().__init__(chain_config)
super().__init__(chain_config, UniswapQuote)
self.version = version
self._evm_client = EVMClient(chain_config)
self._router = self._get_router()
Expand All @@ -41,12 +47,16 @@ def _get_factory(self) -> ChecksumAddress:

@abstractmethod
def _swap(
self, *, token_out: TokenInfo, token_in: TokenInfo, address: str, wei_in: int, slippage_bps: int
self,
quote: QuoteResult[UniswapQuote],
slippage_bps: int,
) -> List[TxReceipt]:
pass

@abstractmethod
def _get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal:
def _get_token_price(
self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal
) -> QuoteResult[UniswapQuote]:
pass

@abstractmethod
Expand Down Expand Up @@ -95,17 +105,18 @@ def _get_final_swap_amount_received(

def swap(
self,
token_out: TokenInfo,
token_in: TokenInfo,
amount_in: Decimal,
quote: QuoteResult[UniswapQuote],
slippage_bps: int = 100,
) -> SwapResult:
logger.info(f"Initiating token swap for {token_in.symbol} to {token_out.symbol}")
logger.info(f"Wallet address: {self.wallet_address}")

# Create contract instances
token_out = quote.token_out
token_out_contract = ERC20Contract(self._evm_client, token_out.checksum_address)
token_in = quote.token_in
token_in_contract = ERC20Contract(self._evm_client, token_in.checksum_address)
amount_in = quote.amount_in

logger.info(f"Initiating token swap for {token_in.symbol} to {token_out.symbol}")
logger.info(f"Wallet address: {self.wallet_address}")

# Gas balance
gas_balance = self._evm_client.get_native_balance(self.wallet_address)
Expand All @@ -118,7 +129,6 @@ def swap(
logger.info(f"Balance of {token_out.symbol}: {out_balance:,.8f}")
logger.info(f"Balance of {token_in.symbol}: {in_balance:,.8f}")
logger.info(f"ETH balance for gas: {eth_balance:,.6f}")
wei_in = token_in.convert_to_wei(amount_in)

if in_balance < amount_in:
raise ValueError(
Expand All @@ -130,10 +140,7 @@ def swap(
# 2) swap (various functions)

receipts = self._swap(
token_out=token_out,
token_in=token_in,
address=self.wallet_address,
wei_in=wei_in,
quote=quote,
slippage_bps=slippage_bps,
)

Expand Down Expand Up @@ -166,12 +173,14 @@ def _approve_token_spending(self, token: TokenInfo, raw_amount: int) -> TxReceip
tx_receipt = token_contract.approve(self.get_signer(), self._router, raw_amount)
return tx_receipt

def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal:
def get_token_price(
self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal
) -> QuoteResult[UniswapQuote]:
logger.debug(
f"Getting price for {token_out.symbol}/{token_in.symbol} on {self.chain} using Uniswap {self.version}"
)

return self._get_token_price(token_out=token_out, token_in=token_in)
return self._get_token_price(token_out=token_out, token_in=token_in, amount_in=amount_in)

def get_markets_for_tokens(self, tokens: List[TokenInfo]) -> List[Tuple[TokenInfo, TokenInfo]]:
"""Get list of valid trading pairs between the provided tokens.
Expand Down
Loading