diff --git a/alphaswarm/services/exchanges/__init__.py b/alphaswarm/services/exchanges/__init__.py index 80be77b5..b3f4d100 100644 --- a/alphaswarm/services/exchanges/__init__.py +++ b/alphaswarm/services/exchanges/__init__.py @@ -1,5 +1,5 @@ from .factory import DEXFactory -from .base import DEXClient, SwapResult +from .base import DEXClient, SwapResult, QuoteResult from .uniswap import UniswapClientBase -__all__ = ["DEXFactory", "DEXClient", "SwapResult", "UniswapClientBase"] +__all__ = ["DEXFactory", "DEXClient", "SwapResult", "QuoteResult", "UniswapClientBase"] diff --git a/alphaswarm/services/exchanges/base.py b/alphaswarm/services/exchanges/base.py index c327661e..21ed592d 100644 --- a/alphaswarm/services/exchanges/base.py +++ b/alphaswarm/services/exchanges/base.py @@ -3,11 +3,24 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from decimal import Decimal -from typing import List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Generic, List, Optional, Tuple, Type, TypeGuard, TypeVar, Union from alphaswarm.config import ChainConfig, Config, TokenInfo from hexbytes import HexBytes +T = TypeVar("T", bound="DEXClient") +TQuote = TypeVar("TQuote") + + +@dataclass +class QuoteResult(Generic[TQuote]): + quote: TQuote + + token_in: TokenInfo + token_out: TokenInfo + amount_in: Decimal + amount_out: Decimal + @dataclass class SwapResult: @@ -66,16 +79,14 @@ def __repr__(self) -> str: return f"Slippage(bps={self.bps})" -T = TypeVar("T", bound="DEXClient") - - -class DEXClient(ABC): +class DEXClient(Generic[TQuote], ABC): """Base class for DEX clients""" @abstractmethod - def __init__(self, chain_config: ChainConfig) -> None: + def __init__(self, chain_config: ChainConfig, quote_type: Type[TQuote]) -> None: """Initialize the DEX client with configuration""" self._chain_config = chain_config + self._quote_type = quote_type @property def chain(self) -> str: @@ -86,7 +97,7 @@ def chain_config(self) -> ChainConfig: return self._chain_config @abstractmethod - def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal: + def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal) -> QuoteResult[TQuote]: """Get price/conversion rate for the pair of tokens. The price is returned in terms of token_out/token_in (how much token out per token in). @@ -94,6 +105,7 @@ def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal: Args: token_out (TokenInfo): The token to be bought (going out from the pool) token_in (TokenInfo): The token to be sold (going into the pool) + amount_in (Decimal): The amount of the token to be sold Example: eth_token = TokenInfo(address="0x...", decimals=18, symbol="ETH", chain="ethereum") @@ -106,17 +118,13 @@ def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal: @abstractmethod def swap( self, - token_out: TokenInfo, - token_in: TokenInfo, - amount_in: Decimal, + quote: QuoteResult[TQuote], slippage_bps: int = 100, ) -> SwapResult: """Execute a token swap on the DEX Args: - token_out (TokenInfo): The token to be bought (going out from the pool) - token_in (TokenInfo): The token to be sold (going into the pool) - amount_in: Amount of token_in to spend + quote (TokenPrice): The quote to execute slippage_bps: Maximum allowed slippage in basis points (1 bp = 0.01%) Returns: @@ -155,3 +163,10 @@ def from_config(cls: Type[T], config: Config, chain: str) -> T: An instance of the DEX client """ pass + + def raise_if_not_quote(self, value: Any) -> None: + if self.is_quote(value): + raise TypeError(f"Expected {self._quote_type} but got {type(value)}") + + def is_quote(self, value: Any) -> TypeGuard[QuoteResult[TQuote]]: + return isinstance(value, QuoteResult) and isinstance(value.quote, self._quote_type) diff --git a/alphaswarm/services/exchanges/jupiter/jupiter.py b/alphaswarm/services/exchanges/jupiter/jupiter.py index b812b70f..f8eef053 100644 --- a/alphaswarm/services/exchanges/jupiter/jupiter.py +++ b/alphaswarm/services/exchanges/jupiter/jupiter.py @@ -2,32 +2,55 @@ import logging from decimal import Decimal -from typing import Any, Dict, List, Tuple +from typing import Annotated, Any, Dict, List, Optional, Tuple from urllib.parse import urlencode import requests from alphaswarm.config import ChainConfig, Config, JupiterSettings, JupiterVenue, TokenInfo from alphaswarm.services import ApiException -from alphaswarm.services.exchanges.base import DEXClient, SwapResult -from pydantic import Field +from alphaswarm.services.exchanges.base import DEXClient, QuoteResult, SwapResult +from pydantic import BaseModel, Field from pydantic.dataclasses import dataclass logger = logging.getLogger(__name__) +class SwapInfo(BaseModel): + amm_key: Annotated[str, Field(alias="ammKey")] + label: Annotated[Optional[str], Field(alias="label", default=None)] + input_mint: Annotated[str, Field(alias="inputMint")] + output_mint: Annotated[str, Field(alias="outputMint")] + in_amount: Annotated[str, Field(alias="inAmount")] + out_amount: Annotated[str, Field(alias="outAmount")] + fee_amount: Annotated[str, Field(alias="feeAmount")] + fee_mint: Annotated[str, Field(alias="feeMint")] + + def to_dict(self) -> Dict[str, Any]: + return self.model_dump(by_alias=True) + + @dataclass -class QuoteResponse: +class RoutePlan: + swap_info: Annotated[SwapInfo, Field(alias="swapInfo")] + percent: int + + +@dataclass +class JupiterQuote: # TODO capture more fields if needed - out_amount: Decimal = Field(alias="outAmount") - route_plan: List[Dict[str, Any]] = Field(alias="routePlan") + out_amount: Annotated[Decimal, Field(alias="outAmount")] + route_plan: Annotated[List[RoutePlan], Field(alias="routePlan")] + + def route_plan_to_string(self) -> str: + return "/".join([route.swap_info.amm_key for route in self.route_plan]) -class JupiterClient(DEXClient): +class JupiterClient(DEXClient[JupiterQuote]): """Client for Jupiter DEX on Solana""" def __init__(self, chain_config: ChainConfig, venue_config: JupiterVenue, settings: JupiterSettings) -> None: self._validate_chain(chain_config.chain) - super().__init__(chain_config) + super().__init__(chain_config, JupiterQuote) self._settings = settings self._venue_config = venue_config logger.info(f"Initialized JupiterClient on chain '{self.chain}'") @@ -38,26 +61,26 @@ def _validate_chain(self, chain: str) -> None: def swap( self, - token_out: TokenInfo, - token_in: TokenInfo, - amount_in: Decimal, + quote: QuoteResult[JupiterQuote], slippage_bps: int = 100, ) -> SwapResult: raise NotImplementedError("Jupiter swap functionality is not yet implemented") - def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal: + def get_token_price( + self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal + ) -> QuoteResult[JupiterQuote]: # Verify tokens are on Solana if not token_out.chain == self.chain or not token_in.chain == self.chain: raise ValueError(f"Jupiter only supports Solana tokens. Got {token_out.chain} and {token_in.chain}") - logger.debug(f"Getting price for {token_out.symbol}/{token_in.symbol} on {token_out.chain} using Jupiter") + logger.debug(f"Getting amount_out for {token_out.symbol}/{token_in.symbol} on {token_out.chain} using Jupiter") # Prepare query parameters params = { "inputMint": token_in.address, "outputMint": token_out.address, "swapMode": "ExactIn", - "amount": str(token_in.convert_to_wei(Decimal(1))), # Get price spending exactly 1 token_in + "amount": str(token_in.convert_to_wei(amount_in)), "slippageBps": self._settings.slippage_bps, } @@ -68,19 +91,25 @@ def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal: raise ApiException(response) result = response.json() - quote = QuoteResponse(**result) + quote = JupiterQuote(**result) - # Calculate price (token_out per token_in) - amount_out = quote.out_amount - price = token_out.convert_from_wei(amount_out) + # Calculate amount_out (token_out per token_in) + raw_out = quote.out_amount + amount_out = token_out.convert_from_wei(raw_out) # Log quote details logger.debug("Quote successful:") - logger.debug(f"- Input: 1 {token_in.symbol}") - logger.debug(f"- Output: {amount_out} {token_out.symbol} lamports") - logger.debug(f"- Price: {price} {token_out.symbol}/{token_in.symbol}") + logger.debug(f"- Input: {amount_in} {token_in.symbol}") + logger.debug(f"- Output: {amount_out} {token_out.symbol}") + logger.debug(f"- Ratio: {amount_out/amount_in} {token_out.symbol}/{token_in.symbol}") logger.debug(f"- Route: {quote.route_plan}") - return price + return QuoteResult( + quote=quote, + token_in=token_in, + token_out=token_out, + amount_in=amount_in, + amount_out=amount_out, + ) def get_markets_for_tokens(self, tokens: List[TokenInfo]) -> List[Tuple[TokenInfo, TokenInfo]]: """Get list of valid trading pairs between the provided tokens. diff --git a/alphaswarm/services/exchanges/uniswap/constants_v2.py b/alphaswarm/services/exchanges/uniswap/constants_v2.py index 49a65e6b..8f3ccb3f 100644 --- a/alphaswarm/services/exchanges/uniswap/constants_v2.py +++ b/alphaswarm/services/exchanges/uniswap/constants_v2.py @@ -90,6 +90,10 @@ "factory": "0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f", "router": "0x7a250d5630B4cF539739dF2C5dAcb4c659F2488D", }, + "ethereum_sepolia": { + "factory": "0xF62c03E08ada871A0bEb309762E260a7a6a880E6", + "router": "0xeE567Fe1712Faf6149d80dA1E6934E354124CfE3", + }, "base": { "factory": "0x8909Dc15e40173Ff4699343b6eB8132c65e18eC6", "router": "0x4752ba5dbc23f44d87826276bf6fd6b1c372ad24", diff --git a/alphaswarm/services/exchanges/uniswap/uniswap_client_base.py b/alphaswarm/services/exchanges/uniswap/uniswap_client_base.py index 4b3d97a1..1e4941bf 100644 --- a/alphaswarm/services/exchanges/uniswap/uniswap_client_base.py +++ b/alphaswarm/services/exchanges/uniswap/uniswap_client_base.py @@ -5,17 +5,23 @@ from alphaswarm.config import ChainConfig, TokenInfo from alphaswarm.services.chains.evm import ERC20Contract, EVMClient, EVMSigner -from alphaswarm.services.exchanges.base import DEXClient, SwapResult +from alphaswarm.services.exchanges.base import DEXClient, QuoteResult, SwapResult from eth_typing import ChecksumAddress, HexAddress +from pydantic.dataclasses import dataclass from web3.types import TxReceipt # Set up logger logger = logging.getLogger(__name__) -class UniswapClientBase(DEXClient): +@dataclass +class UniswapQuote: + pool_address: ChecksumAddress + + +class UniswapClientBase(DEXClient[UniswapQuote]): def __init__(self, chain_config: ChainConfig, version: str) -> None: - super().__init__(chain_config) + super().__init__(chain_config, UniswapQuote) self.version = version self._evm_client = EVMClient(chain_config) self._router = self._get_router() @@ -41,12 +47,16 @@ def _get_factory(self) -> ChecksumAddress: @abstractmethod def _swap( - self, *, token_out: TokenInfo, token_in: TokenInfo, address: str, wei_in: int, slippage_bps: int + self, + quote: QuoteResult[UniswapQuote], + slippage_bps: int, ) -> List[TxReceipt]: pass @abstractmethod - def _get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal: + def _get_token_price( + self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal + ) -> QuoteResult[UniswapQuote]: pass @abstractmethod @@ -95,17 +105,18 @@ def _get_final_swap_amount_received( def swap( self, - token_out: TokenInfo, - token_in: TokenInfo, - amount_in: Decimal, + quote: QuoteResult[UniswapQuote], slippage_bps: int = 100, ) -> SwapResult: - logger.info(f"Initiating token swap for {token_in.symbol} to {token_out.symbol}") - logger.info(f"Wallet address: {self.wallet_address}") - # Create contract instances + token_out = quote.token_out token_out_contract = ERC20Contract(self._evm_client, token_out.checksum_address) + token_in = quote.token_in token_in_contract = ERC20Contract(self._evm_client, token_in.checksum_address) + amount_in = quote.amount_in + + logger.info(f"Initiating token swap for {token_in.symbol} to {token_out.symbol}") + logger.info(f"Wallet address: {self.wallet_address}") # Gas balance gas_balance = self._evm_client.get_native_balance(self.wallet_address) @@ -118,7 +129,6 @@ def swap( logger.info(f"Balance of {token_out.symbol}: {out_balance:,.8f}") logger.info(f"Balance of {token_in.symbol}: {in_balance:,.8f}") logger.info(f"ETH balance for gas: {eth_balance:,.6f}") - wei_in = token_in.convert_to_wei(amount_in) if in_balance < amount_in: raise ValueError( @@ -130,10 +140,7 @@ def swap( # 2) swap (various functions) receipts = self._swap( - token_out=token_out, - token_in=token_in, - address=self.wallet_address, - wei_in=wei_in, + quote=quote, slippage_bps=slippage_bps, ) @@ -166,12 +173,14 @@ def _approve_token_spending(self, token: TokenInfo, raw_amount: int) -> TxReceip tx_receipt = token_contract.approve(self.get_signer(), self._router, raw_amount) return tx_receipt - def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal: + def get_token_price( + self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal + ) -> QuoteResult[UniswapQuote]: logger.debug( f"Getting price for {token_out.symbol}/{token_in.symbol} on {self.chain} using Uniswap {self.version}" ) - return self._get_token_price(token_out=token_out, token_in=token_in) + return self._get_token_price(token_out=token_out, token_in=token_in, amount_in=amount_in) def get_markets_for_tokens(self, tokens: List[TokenInfo]) -> List[Tuple[TokenInfo, TokenInfo]]: """Get list of valid trading pairs between the provided tokens. diff --git a/alphaswarm/services/exchanges/uniswap/uniswap_client_v2.py b/alphaswarm/services/exchanges/uniswap/uniswap_client_v2.py index adabfad1..b5fb5149 100644 --- a/alphaswarm/services/exchanges/uniswap/uniswap_client_v2.py +++ b/alphaswarm/services/exchanges/uniswap/uniswap_client_v2.py @@ -6,14 +6,14 @@ from alphaswarm.config import ChainConfig, Config, TokenInfo from alphaswarm.services.chains.evm import ZERO_ADDRESS -from alphaswarm.services.exchanges.base import Slippage +from alphaswarm.services.exchanges.base import QuoteResult, Slippage from alphaswarm.services.exchanges.uniswap.constants_v2 import ( UNISWAP_V2_DEPLOYMENTS, UNISWAP_V2_FACTORY_ABI, UNISWAP_V2_ROUTER_ABI, UNISWAP_V2_VERSION, ) -from alphaswarm.services.exchanges.uniswap.uniswap_client_base import UniswapClientBase +from alphaswarm.services.exchanges.uniswap.uniswap_client_base import UniswapClientBase, UniswapQuote from eth_defi.uniswap_v2.pair import fetch_pair_details from eth_typing import ChecksumAddress from web3.types import TxReceipt @@ -33,25 +33,21 @@ def _get_factory(self) -> ChecksumAddress: return self._evm_client.to_checksum_address(UNISWAP_V2_DEPLOYMENTS[self.chain]["factory"]) def _swap( - self, token_out: TokenInfo, token_in: TokenInfo, address: str, wei_in: int, slippage_bps: int + self, + quote: QuoteResult[UniswapQuote], + slippage_bps: int, ) -> List[TxReceipt]: """Execute a swap on Uniswap V2.""" # Handle token approval and get fresh nonce - approval_receipt = self._approve_token_spending(token_in, wei_in) - - # Get price from V2 pair to calculate minimum output - price = self._get_token_price(token_out=token_out, token_in=token_in) - if not price: - raise ValueError(f"No V2 price found for {token_out.symbol}/{token_in.symbol}") + token_in = quote.token_in + token_out = quote.token_out + wei_in = token_in.convert_to_wei(quote.amount_in) - # Calculate expected output - input_amount_decimal = token_in.convert_from_wei(wei_in) - expected_output_decimal = input_amount_decimal * price - logger.info(f"Expected output: {expected_output_decimal} {token_out.symbol}") + approval_receipt = self._approve_token_spending(token_in, wei_in) # Convert expected output to raw integer and apply slippage slippage = Slippage(slippage_bps) - min_output_raw = slippage.calculate_minimum_amount(token_out.convert_to_wei(expected_output_decimal)) + min_output_raw = slippage.calculate_minimum_amount(token_out.convert_to_wei(quote.amount_out)) logger.info(f"Minimum output with {slippage} slippage (raw): {min_output_raw}") # Build swap path @@ -65,7 +61,7 @@ def _swap( wei_in, # amount in min_output_raw, # minimum amount out path, # swap path - address, # recipient + self.wallet_address, # recipient deadline, # deadline ) @@ -73,7 +69,9 @@ def _swap( swap_receipt = self._evm_client.process(swap, self.get_signer()) return [approval_receipt, swap_receipt] - def _get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal: + def _get_token_price( + self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal + ) -> QuoteResult[UniswapQuote]: # Create factory contract instance factory_contract = self._web3.eth.contract(address=self._factory, abi=UNISWAP_V2_FACTORY_ABI) @@ -86,10 +84,21 @@ def _get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal # Get V2 pair details - if reverse false, mid_price = token1_amount / token0_amount # token0 of the pair has the lowest address. Reverse if needed + price = self._get_price_from_pool(pair_address=pair_address, token_out=token_out, token_in=token_in) + quote = UniswapQuote( + pool_address=pair_address, + ) + + return QuoteResult( + quote=quote, token_in=token_in, token_out=token_out, amount_in=amount_in, amount_out=price * amount_in + ) + + def _get_price_from_pool( + self, *, pair_address: ChecksumAddress, token_out: TokenInfo, token_in: TokenInfo + ) -> Decimal: reverse = token_out.checksum_address.lower() < token_in.checksum_address.lower() pair = fetch_pair_details(self._web3, pair_address, reverse_token_order=reverse) price = pair.get_current_mid_price() - return price def _get_markets_for_tokens(self, tokens: List[TokenInfo]) -> List[Tuple[TokenInfo, TokenInfo]]: diff --git a/alphaswarm/services/exchanges/uniswap/uniswap_client_v3.py b/alphaswarm/services/exchanges/uniswap/uniswap_client_v3.py index 207cd03b..9efa2427 100644 --- a/alphaswarm/services/exchanges/uniswap/uniswap_client_v3.py +++ b/alphaswarm/services/exchanges/uniswap/uniswap_client_v3.py @@ -6,7 +6,7 @@ from alphaswarm.config import ChainConfig, Config, TokenInfo, UniswapV3Settings from alphaswarm.services.chains.evm import ZERO_ADDRESS, EVMClient, EVMContract, EVMSigner -from alphaswarm.services.exchanges.base import Slippage +from alphaswarm.services.exchanges.base import QuoteResult, Slippage from alphaswarm.services.exchanges.uniswap.constants_v3 import ( UNISWAP_V3_DEPLOYMENTS, UNISWAP_V3_FACTORY_ABI, @@ -14,7 +14,7 @@ UNISWAP_V3_ROUTER_ABI, UNISWAP_V3_VERSION, ) -from alphaswarm.services.exchanges.uniswap.uniswap_client_base import UniswapClientBase +from alphaswarm.services.exchanges.uniswap.uniswap_client_base import UniswapClientBase, UniswapQuote from eth_defi.uniswap_v3.pool import PoolDetails, fetch_pool_details from eth_defi.uniswap_v3.price import get_onchain_price from eth_typing import ChecksumAddress, HexAddress @@ -138,30 +138,24 @@ def _get_factory(self) -> ChecksumAddress: return self._evm_client.to_checksum_address(UNISWAP_V3_DEPLOYMENTS[self.chain]["factory"]) def _swap( - self, token_out: TokenInfo, token_in: TokenInfo, address: str, wei_in: int, slippage_bps: int + self, + quote: QuoteResult[UniswapQuote], + slippage_bps: int, ) -> List[TxReceipt]: """Execute a swap on Uniswap V3.""" # Handle token approval and get fresh nonce + + token_in = quote.token_in + token_out = quote.token_out + wei_in = token_in.convert_to_wei(quote.amount_in) approval_receipt = self._approve_token_spending(token_in, wei_in) # Build a swap transaction - pool = self._get_pool(token_out, token_in) + pool = self._get_pool_by_address(quote.quote.pool_address) logger.info(f"Using Uniswap V3 pool at address: {pool.address} (raw fee tier: {pool.raw_fee})") - # Get the on-chain price from the pool and reverse if necessary - price = self._get_token_price_from_pool(token_out, pool) - logger.info(f"Pool raw price: {price} ({token_out.symbol} per {token_in.symbol})") - - # Convert to decimal for calculations - amount_in_decimal = token_in.convert_from_wei(wei_in) - logger.info(f"Actual input amount: {amount_in_decimal} {token_in.symbol}") - - # Calculate expected output - expected_output_decimal = amount_in_decimal * price - logger.info(f"Expected output: {expected_output_decimal} {token_out.symbol}") - # Convert expected output to raw integer - raw_output = token_out.convert_to_wei(expected_output_decimal) + raw_output = token_out.convert_to_wei(quote.amount_out) logger.info(f"Expected output amount (raw): {raw_output}") # Calculate price impact @@ -192,7 +186,7 @@ def _swap( token_in=token_in.checksum_address, token_out=token_out.checksum_address, fee=pool.raw_fee, - recipient=self._evm_client.to_checksum_address(address), + recipient=self.wallet_address, deadline=int(self._evm_client.get_block_latest()["timestamp"] + 300), amount_in=wei_in, amount_out_minimum=min_output_raw, @@ -205,9 +199,18 @@ def _swap( return [approval_receipt, swap_receipt] - def _get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal: + def _get_token_price( + self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal + ) -> QuoteResult[UniswapQuote]: pool = self._get_pool(token_out, token_in) - return self._get_token_price_from_pool(token_out, pool) + price = self._get_token_price_from_pool(token_out, pool) + return QuoteResult( + token_in=token_in, + token_out=token_out, + amount_in=amount_in, + amount_out=price * amount_in, # TODO: substract fees? + quote=UniswapQuote(pool_address=pool.address), + ) @staticmethod def _get_token_price_from_pool(token_out: TokenInfo, pool: PoolContract) -> Decimal: diff --git a/alphaswarm/tools/exchanges/execute_token_swap_tool.py b/alphaswarm/tools/exchanges/execute_token_swap_tool.py index 181147ed..97b36304 100644 --- a/alphaswarm/tools/exchanges/execute_token_swap_tool.py +++ b/alphaswarm/tools/exchanges/execute_token_swap_tool.py @@ -1,9 +1,9 @@ import logging -from decimal import Decimal from typing import Any from alphaswarm.config import Config from alphaswarm.services.exchanges import DEXFactory, SwapResult +from alphaswarm.tools.exchanges.get_token_price_tool import TokenQuote from smolagents import Tool logger = logging.getLogger(__name__) @@ -13,30 +13,14 @@ class ExecuteTokenSwapTool(Tool): """Tool for executing token swaps on supported DEXes.""" name = "execute_token_swap" - description = "Execute a token swap on a supported DEX (Uniswap V2/V3 on Ethereum and Base chains)." + description = ( + "Execute a token swap on a supported DEX (Uniswap V2/V3 on Ethereum and Base chains). " + f"Returns a {SwapResult.__name__} details of the transaction." + ) inputs = { - "token_out": { - "type": "string", - "description": "The address of the token being bought (out from the pool)", - }, - "token_in": { - "type": "string", - "description": "The address of the token being sold (in the pool)", - }, - "amount_in": {"type": "number", "description": "The amount token_in to be sold", "required": True}, - "chain": { - "type": "string", - "description": "The chain to execute the swap on", - "enum": ["solana", "base", "ethereum", "ethereum_sepolia"], - "nullable": True, - "default": "ethereum", - }, - "dex_type": { - "type": "string", - "description": "The DEX type to use", - "enum": ["uniswap_v2", "uniswap_v3", "jupiter"], - "nullable": True, - "default": "uniswap_v3", + "quote": { + "type": "object", + "description": f"A {TokenQuote.__name__} previously generated", }, "slippage_bps": { "type": "integer", @@ -53,31 +37,20 @@ def __init__(self, config: Config, *args: Any, **kwargs: Any) -> None: def forward( self, *, - token_out: str, - token_in: str, - amount_in: Decimal, - chain: str = "ethereum", - dex_type: str = "uniswap_v3", + quote: TokenQuote, slippage_bps: int = 100, ) -> SwapResult: """Execute a token swap.""" # Create DEX client - dex_client = DEXFactory.create(dex_name=dex_type, config=self.config, chain=chain) - - # Get wallet address and private key from chain config - chain_config = self.config.get_chain_config(chain) - token_in_info = chain_config.get_token_info_by_address(token_in) - token_out_info = chain_config.get_token_info_by_address(token_out) + dex_client = DEXFactory.create(dex_name=quote.dex, config=self.config, chain=quote.chain) - # Log token details + inner = quote.quote logger.info( - f"Swapping {amount_in} {token_in_info.symbol} ({token_in_info.address}) for {token_out_info.symbol} ({token_out_info.address}) on {chain}" + f"Swapping {inner.amount_in} {inner.token_in.symbol} ({inner.token_in.address}) for {inner.token_out.symbol} ({inner.token_out.address}) on {quote.chain}" ) # Execute swap return dex_client.swap( - token_out=token_out_info, - token_in=token_in_info, - amount_in=amount_in, + quote=quote.quote, slippage_bps=slippage_bps, ) diff --git a/alphaswarm/tools/exchanges/get_token_price_tool.py b/alphaswarm/tools/exchanges/get_token_price_tool.py index 529461fa..d94be677 100644 --- a/alphaswarm/tools/exchanges/get_token_price_tool.py +++ b/alphaswarm/tools/exchanges/get_token_price_tool.py @@ -1,10 +1,12 @@ import logging from datetime import UTC, datetime from decimal import Decimal -from typing import List, Optional +from typing import List, Optional, Union from alphaswarm.config import Config -from alphaswarm.services.exchanges import DEXFactory +from alphaswarm.services.exchanges import DEXFactory, QuoteResult +from alphaswarm.services.exchanges.jupiter.jupiter import JupiterQuote +from alphaswarm.services.exchanges.uniswap.uniswap_client_base import UniswapQuote from pydantic.dataclasses import dataclass from smolagents import Tool @@ -12,25 +14,24 @@ @dataclass -class TokenPrice: - price: Decimal - source: str +class TokenQuote: + datetime: str + dex: str + chain: str + quote: QuoteResult[Union[UniswapQuote, JupiterQuote]] @dataclass class TokenPriceResult: - token_out: str - token_in: str - timestamp: str - prices: List[TokenPrice] + quotes: List[TokenQuote] class GetTokenPriceTool(Tool): name = "get_token_price" description = ( "Get the current price of a token pair from available DEXes. " - "Returns a TokenPriceResult object, which contains list of prices. " - "Result prices are expressed in amount of token_out per token_in. " + f"Returns a {TokenPriceResult.__name__} object containing a list of {TokenQuote.__name__} objects." + "Examples: 'Get the price of 1 ETH in USDC on ethereum', 'Get the price of 1 GIGA in SOL on solana'" ) inputs = { "token_out": { @@ -41,6 +42,7 @@ class GetTokenPriceTool(Tool): "type": "string", "description": "The address of the token we want to sell", }, + "amount_in": {"type": "string", "description": "The amount token_in to be sold, in Token", "required": True}, "chain": { "type": "string", "description": "Blockchain to use. Must be 'solana' for Solana tokens, 'base' for Base tokens, " @@ -64,6 +66,7 @@ def forward( self, token_out: str, token_in: str, + amount_in: str, chain: str, dex_type: Optional[str] = None, ) -> TokenPriceResult: @@ -79,13 +82,15 @@ def forward( # Get prices from all available venues venues = self.config.get_trading_venues_for_chain(chain) if dex_type is None else [dex_type] - prices = [] + prices: List[TokenQuote] = [] for venue in venues: try: dex = DEXFactory.create(dex_name=venue, config=self.config, chain=chain) - price = dex.get_token_price(token_out_info, token_in_info) - prices.append(TokenPrice(price=price, source=venue)) + price = dex.get_token_price(token_out_info, token_in_info, amount_in=Decimal(amount_in)) + timestamp = datetime.now(UTC).strftime("%Y-%m-%d %H:%M UTC") + + prices.append(TokenQuote(dex=venue, chain=chain, quote=price, datetime=timestamp)) except Exception: logger.exception(f"Error getting price from {venue}") @@ -94,9 +99,8 @@ def forward( raise RuntimeError(f"No valid prices found for {token_out}/{token_in}") # Get current timestamp - timestamp = datetime.now(UTC).strftime("%Y-%m-%d %H:%M UTC") # If we have multiple prices, return them all - result = TokenPriceResult(token_out=token_out, token_in=token_in, timestamp=timestamp, prices=prices) + result = TokenPriceResult(quotes=prices) logger.debug(f"Returning result: {result}") return result diff --git a/tests/integration/services/exchanges/jupiter/test_jupiter.py b/tests/integration/services/exchanges/jupiter/test_jupiter.py index e60eeaeb..d6342b30 100644 --- a/tests/integration/services/exchanges/jupiter/test_jupiter.py +++ b/tests/integration/services/exchanges/jupiter/test_jupiter.py @@ -1,3 +1,5 @@ +from decimal import Decimal + from alphaswarm.config import Config from alphaswarm.services.exchanges.jupiter.jupiter import JupiterClient @@ -12,5 +14,5 @@ def test_get_token_price(default_config: Config) -> None: giga = tokens_config["GIGA"] sol = tokens_config["SOL"] - price = client.get_token_price(giga, sol) - assert price > 1000, "A Sol is worth many thousands of GIGA." + quote = client.get_token_price(token_out=giga, token_in=sol, amount_in=Decimal(1)) + assert 10000 > quote.amount_out > 1000, "A Sol is worth many thousands of GIGA." diff --git a/tests/integration/services/exchanges/uniswap/test_uniswap_client_v2.py b/tests/integration/services/exchanges/uniswap/test_uniswap_client_v2.py index 9b08a04f..bce721da 100644 --- a/tests/integration/services/exchanges/uniswap/test_uniswap_client_v2.py +++ b/tests/integration/services/exchanges/uniswap/test_uniswap_client_v2.py @@ -1,3 +1,7 @@ +from decimal import Decimal + +import pytest + from alphaswarm.config import Config from alphaswarm.services.exchanges import DEXFactory from alphaswarm.services.exchanges.uniswap import UniswapClientV2 @@ -27,3 +31,30 @@ def test_get_markets_for_tokens_v2(default_config: Config) -> None: assert {base_token.symbol, quote_token.symbol} == {"USDC", "WETH"} assert base_token.chain == chain assert quote_token.chain == chain + + +@pytest.fixture +def client(default_config: Config, chain: str) -> UniswapClientV2: + return UniswapClientV2.from_config(default_config, chain) + + +chains = [ + "ethereum", + "ethereum_sepolia", + "base", + # "base_sepolia", +] + + +@pytest.mark.skip("Need a funded wallet.") +@pytest.mark.parametrize("chain", chains) +def test_swap_eth_sepolia(client: UniswapClientV2, chain: str) -> None: + usdc = client.chain_config.get_token_info("USDC") + weth = client.chain_config.get_token_info("WETH") + + quote = client.get_token_price(token_out=usdc, token_in=weth, amount_in=Decimal("0.0001")) + assert quote.amount_out > quote.amount_in + + result = client.swap(quote) + print(result) + assert result.amount_out == pytest.approx(quote.amount_out, rel=Decimal("0.05")) diff --git a/tests/integration/services/exchanges/uniswap/test_uniswap_client_v3.py b/tests/integration/services/exchanges/uniswap/test_uniswap_client_v3.py index f54a8e8b..7adb0d54 100644 --- a/tests/integration/services/exchanges/uniswap/test_uniswap_client_v3.py +++ b/tests/integration/services/exchanges/uniswap/test_uniswap_client_v3.py @@ -20,21 +20,6 @@ def eth_client(default_config: Config) -> UniswapClientV3: return UniswapClientV3(chain_config=chain_config, settings=default_config.get_venue_settings_uniswap_v3()) -@pytest.fixture -def eth_sepolia_client(default_config: Config) -> UniswapClientV3: - chain_config = default_config.get_chain_config(chain="ethereum_sepolia") - return UniswapClientV3(chain_config=chain_config, settings=default_config.get_venue_settings_uniswap_v3()) - - -def test_get_price(base_client: UniswapClientV3) -> None: - usdc = base_client.chain_config.get_token_info("USDC") - weth = base_client.chain_config.get_token_info("WETH") - usdc_per_weth = base_client.get_token_price(token_out=usdc, token_in=weth) - - print(f"1 {weth.symbol} is {usdc_per_weth} {usdc.symbol}") - assert usdc_per_weth > 1000, "A WETH is worth many thousands of USDC" - - def test_quote_from_pool(base_client: UniswapClientV3) -> None: pool = base_client._get_pool_by_address(BASE_WETH_USDC_005) usdc: TokenInfo = base_client.chain_config.get_token_info("USDC") @@ -90,17 +75,40 @@ def test_get_markets_for_tokens(eth_client: UniswapClientV3) -> None: assert quote_token.chain == eth_client.chain -@pytest.mark.skip("Needs a wallet with USDC to perform the swap to WETH. Run manually") -def test_swap_eth_sepolia(eth_sepolia_client: UniswapClientV3) -> None: - usdc = eth_sepolia_client.chain_config.get_token_info("USDC") - weth = eth_sepolia_client.chain_config.get_token_info("WETH") +@pytest.fixture +def client(default_config: Config, chain: str) -> UniswapClientV3: + return UniswapClientV3.from_config(default_config, chain) + + +chains = [ + "ethereum", + "ethereum_sepolia", + "base", + # "base_sepolia", +] + + +@pytest.mark.parametrize("chain", chains) +def test_quote_weth_to_usdc(client: UniswapClientV3, chain: str) -> None: + usdc = client.chain_config.get_token_info("USDC") + weth = client.chain_config.get_token_info("WETH") + quote = client.get_token_price(token_out=usdc, token_in=weth, amount_in=Decimal("0.01")) + print(quote) + assert 10_000 > quote.amount_out > 10 + - pool = eth_sepolia_client._get_pool(usdc, weth) - print(f"find pool {pool.address}") +@pytest.mark.skip("Need a funded wallet.") +@pytest.mark.parametrize("chain", chains) +def test_swap_weth_to_usdc(client: UniswapClientV3, chain: str) -> None: + usdc = client.chain_config.get_token_info("USDC") + weth = client.chain_config.get_token_info("WETH") + amount_in = Decimal("0.0001") - quote = eth_sepolia_client._get_token_price_from_pool(token_out=weth, pool=pool) - print(f"1 {usdc.symbol} is {quote} {weth.symbol}") + quote = client.get_token_price(token_out=usdc, token_in=weth, amount_in=amount_in) + assert quote.amount_out > amount_in, "1 USDC is worth a fraction of WETH" - # Buy X Weth for 1 USDC - result = eth_sepolia_client.swap(token_out=weth, token_in=usdc, amount_in=Decimal(100)) + result = client.swap(quote) print(result) + assert result.success + assert result.amount_in == amount_in + assert result.amount_out == pytest.approx(quote.amount_out, rel=Decimal("0.05")) diff --git a/tests/integration/tools/exchanges/test_execute_token_swap_tool.py b/tests/integration/tools/exchanges/test_execute_token_swap_tool.py index 558f8fd2..fe90431a 100644 --- a/tests/integration/tools/exchanges/test_execute_token_swap_tool.py +++ b/tests/integration/tools/exchanges/test_execute_token_swap_tool.py @@ -4,7 +4,12 @@ from alphaswarm.config import Config from alphaswarm.services.chains import EVMClient -from alphaswarm.tools.exchanges import ExecuteTokenSwapTool +from alphaswarm.tools.exchanges import ExecuteTokenSwapTool, GetTokenPriceTool + + +@pytest.fixture +def token_quote_tool(default_config: Config) -> GetTokenPriceTool: + return GetTokenPriceTool(default_config) @pytest.fixture @@ -18,10 +23,22 @@ def sepolia_client(default_config: Config) -> EVMClient: @pytest.mark.skip("Requires a founded wallet. Run manually") -def test_token_swap_tool(token_swap_tool: ExecuteTokenSwapTool, sepolia_client: EVMClient) -> None: +def test_token_swap_tool( + token_quote_tool: GetTokenPriceTool, token_swap_tool: ExecuteTokenSwapTool, sepolia_client: EVMClient +) -> None: weth = sepolia_client.get_token_info_by_name("WETH") usdc = sepolia_client.get_token_info_by_name("USDC") - result = token_swap_tool.forward( - token_out=weth.address, token_in=usdc.address, amount_in=Decimal(1), chain="ethereum_sepolia" + amount_in = Decimal(10) + + quotes = token_quote_tool.forward( + token_out=weth.address, + token_in=usdc.address, + amount_in=str(amount_in), + chain=sepolia_client.chain, + dex_type="uniswap_v3", ) + assert len(quotes.quotes) == 1 + result = token_swap_tool.forward(quote=quotes.quotes[0]) print(result) + assert result.success + assert result.amount_out < amount_in diff --git a/tests/integration/tools/exchanges/test_get_token_price_tool.py b/tests/integration/tools/exchanges/test_get_token_price_tool.py index 577b5c60..730aac3f 100644 --- a/tests/integration/tools/exchanges/test_get_token_price_tool.py +++ b/tests/integration/tools/exchanges/test_get_token_price_tool.py @@ -7,27 +7,29 @@ @pytest.mark.parametrize( - "dex,chain,token_out,token_in,ratio", + "dex,chain,token_out,token_in,min_out,max_out", [ - ("jupiter", "solana", "GIGA", "SOL", 1000), - ("uniswap_v3", "base", "VIRTUAL", "WETH", 1000), - ("uniswap_v3", "ethereum_sepolia", "USDC", "WETH", 100), - ("uniswap_v3", "ethereum", "USDC", "WETH", 100), - ("uniswap_v2", "ethereum", "USDC", "WETH", 100), - (None, "ethereum", "USDC", "WETH", 100), + ("jupiter", "solana", "GIGA", "SOL", 1_000, 10_000), + ("uniswap_v3", "base", "VIRTUAL", "WETH", 1_000, 10_000), + ("uniswap_v3", "ethereum_sepolia", "USDC", "WETH", 10_000, 1_000_000), + ("uniswap_v3", "ethereum", "USDC", "WETH", 100, 10_000), + ("uniswap_v2", "ethereum", "USDC", "WETH", 100, 10_000), + (None, "ethereum", "USDC", "WETH", 100, 10_000), ], ) def test_get_token_price_tool( - dex: Optional[str], chain: str, token_out: str, token_in: str, ratio: int, default_config: Config + dex: Optional[str], chain: str, token_out: str, token_in: str, min_out: int, max_out: int, default_config: Config ) -> None: config = default_config tool = GetTokenPriceTool(config) - chaing_config = config.get_chain_config(chain) - token_info_out = chaing_config.get_token_info(token_out) - token_info_in = chaing_config.get_token_info(token_in) - result = tool.forward(token_out=token_info_out.address, token_in=token_info_in.address, dex_type=dex, chain=chain) + chain_config = config.get_chain_config(chain) + token_info_out = chain_config.get_token_info(token_out) + token_info_in = chain_config.get_token_info(token_in) + result = tool.forward( + token_out=token_info_out.address, token_in=token_info_in.address, amount_in="1", dex_type=dex, chain=chain + ) - assert len(result.prices) > 0, "at least one price is expected" - item = result.prices[0] - assert item.price > ratio, f"1 {token_in} is > {ratio} ({token_out}), got {item.price}" + assert len(result.quotes) > 0, "at least one price is expected" + item = result.quotes[0] + assert min_out < item.quote.amount_out < max_out diff --git a/tests/unit/services/exchanges/factory.py b/tests/unit/services/exchanges/factory.py index 5e362a64..307c4fbe 100644 --- a/tests/unit/services/exchanges/factory.py +++ b/tests/unit/services/exchanges/factory.py @@ -5,27 +5,29 @@ import pytest from alphaswarm.config import Config, TokenInfo, ChainConfig -from alphaswarm.services.exchanges import DEXClient, DEXFactory, SwapResult +from alphaswarm.services.exchanges import DEXClient, DEXFactory, QuoteResult, SwapResult -class MockDex(DEXClient): +class MockDex(DEXClient[str]): @classmethod def from_config(cls, config: Config, chain: str) -> MockDex: return MockDex(chain_config=config.get_chain_config(chain)) def swap( - self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal, slippage_bps: int = 100 + self, + quote: QuoteResult[str], + slippage_bps: int = 100, ) -> SwapResult: raise NotImplementedError("For test only") def get_markets_for_tokens(self, tokens: List[TokenInfo]) -> List[Tuple[TokenInfo, TokenInfo]]: raise NotImplementedError("For test only") - def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo) -> Decimal: + def get_token_price(self, token_out: TokenInfo, token_in: TokenInfo, amount_in: Decimal) -> QuoteResult[str]: raise NotImplementedError("For test only") def __init__(self, chain_config: ChainConfig) -> None: - super().__init__(chain_config=chain_config) + super().__init__(chain_config=chain_config, quote_type=str) def test_register(default_config: Config) -> None: