Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 46 additions & 28 deletions src/FeeTracker.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pragma solidity 0.8.28;

import {AWKErrors} from "./agentwalletkit/AWKErrors.sol";
import {AccessControl} from "@openzeppelin/contracts/access/AccessControl.sol";
import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol";

error InvalidFeeRate();

Expand All @@ -28,17 +29,24 @@ error InvalidFeeRate();
* All recording functions use msg.sender, so callers can only affect their own accounting.
*/
contract YieldSeekerFeeTracker is AccessControl {
using SafeCast for uint256;

uint256 public constant MAX_FEE_RATE_BPS = 5000;
uint256 public constant ASSET_EXCHANGE_RATE_PRECISION = 1e18;

struct VaultPosition {
uint128 costBasis;
uint128 shares;
}

uint256 public feeRateBps;
address public feeCollector;

mapping(address wallet => uint256) public agentFeesCharged;
mapping(address wallet => uint256) public agentFeesPaid;

// Position tracking
mapping(address wallet => mapping(address vault => uint256)) public agentVaultCostBasis;
mapping(address wallet => mapping(address vault => uint256)) public agentVaultShares;
// Position tracking (packed into single slot per wallet+vault)
mapping(address wallet => mapping(address vault => VaultPosition)) internal _agentVaultPositions;
mapping(address wallet => mapping(address token => uint256)) public agentYieldTokenFeesOwed;

event YieldRecorded(address indexed wallet, uint256 yield, uint256 fee);
Expand Down Expand Up @@ -99,8 +107,17 @@ contract YieldSeekerFeeTracker is AccessControl {
* @return shares The vault shares held
*/
function getAgentVaultPosition(address wallet, address vault) external view returns (uint256 costBasis, uint256 shares) {
costBasis = agentVaultCostBasis[wallet][vault];
shares = agentVaultShares[wallet][vault];
VaultPosition storage pos = _agentVaultPositions[wallet][vault];
costBasis = pos.costBasis;
shares = pos.shares;
}

function agentVaultCostBasis(address wallet, address vault) external view returns (uint256) {
return _agentVaultPositions[wallet][vault].costBasis;
}

function agentVaultShares(address wallet, address vault) external view returns (uint256) {
return _agentVaultPositions[wallet][vault].shares;
}

/**
Expand All @@ -122,8 +139,9 @@ contract YieldSeekerFeeTracker is AccessControl {
* @param sharesReceived The amount of shares received
*/
function recordAgentVaultShareDeposit(address vault, uint256 assetsDeposited, uint256 sharesReceived) external {
agentVaultCostBasis[msg.sender][vault] += assetsDeposited;
agentVaultShares[msg.sender][vault] += sharesReceived;
VaultPosition storage pos = _agentVaultPositions[msg.sender][vault];
pos.costBasis = (uint256(pos.costBasis) + assetsDeposited).toUint128();
pos.shares = (uint256(pos.shares) + sharesReceived).toUint128();
}

function _chargeFeesOnProfit(address wallet, uint256 profit) internal {
Expand All @@ -141,8 +159,9 @@ contract YieldSeekerFeeTracker is AccessControl {
function recordAgentVaultShareWithdraw(address vault, uint256 sharesSpent, uint256 assetsReceived) external {
if (sharesSpent == 0) return;
address wallet = msg.sender;
uint256 totalShares = agentVaultShares[wallet][vault];
uint256 totalCostBasis = agentVaultCostBasis[wallet][vault];
VaultPosition storage pos = _agentVaultPositions[wallet][vault];
uint256 totalShares = pos.shares;
uint256 totalCostBasis = pos.costBasis;
uint256 vaultTokenFeesOwed = agentYieldTokenFeesOwed[wallet][vault];
uint256 feeInBaseAsset = 0;
if (vaultTokenFeesOwed > 0) {
Expand All @@ -153,28 +172,27 @@ contract YieldSeekerFeeTracker is AccessControl {
agentFeesCharged[wallet] += feeInBaseAsset;
emit YieldRecorded(wallet, feeInBaseAsset, feeInBaseAsset);
}
uint256 netAssets = assetsReceived - feeInBaseAsset;
if (sharesSpent > totalShares) {
// Withdrawing more shares than deposits tracked - treat as full withdrawal
if (totalCostBasis > 0 && totalShares > 0) {
uint256 depositSharesValue = (assetsReceived * totalShares) / sharesSpent;
uint256 depositSharesValue = (netAssets * totalShares) / sharesSpent;
if (depositSharesValue > totalCostBasis) {
uint256 profit = depositSharesValue - totalCostBasis;
_chargeFeesOnProfit(wallet, profit);
}
}
agentVaultCostBasis[wallet][vault] = 0;
agentVaultShares[wallet][vault] = 0;
return;
pos.costBasis = 0;
pos.shares = 0;
} else if (totalShares > 0) {
// Normal withdrawal within tracked deposits
uint256 proportionalCost = (totalCostBasis * sharesSpent) / totalShares;
uint256 netAssets = assetsReceived - feeInBaseAsset;
if (netAssets > proportionalCost) {
uint256 profit = netAssets - proportionalCost;
_chargeFeesOnProfit(wallet, profit);
}
agentVaultCostBasis[wallet][vault] = totalCostBasis - proportionalCost;
agentVaultShares[wallet][vault] = totalShares - sharesSpent;
pos.costBasis = (totalCostBasis - proportionalCost).toUint128();
pos.shares = (totalShares - sharesSpent).toUint128();
}
}

Expand All @@ -183,15 +201,16 @@ contract YieldSeekerFeeTracker is AccessControl {
* @param vault The vault address
* @param assetsReceived The amount of base assets received from the withdrawal
* @param totalVaultBalanceBefore The total vault balance (in base asset terms) before withdrawal
* @dev Uses actual vault balance to compute proportional cost basis, avoiding virtual share conversion.
* For rebasing tokens (Aave, CompoundV3), totalVaultBalanceBefore is the token balance.
* For exchange-rate tokens (CompoundV2), totalVaultBalanceBefore is the underlying value.
* @param vaultTokenToBaseAssetRate The exchange rate from vault tokens to base asset (18-decimal fixed point).
* For rebasing tokens (Aave, CompoundV3), this should be 1e18 (1:1 with underlying).
* For exchange-rate tokens (CompoundV2), this should be the cToken exchange rate.
*/
function recordAgentVaultAssetWithdraw(address vault, uint256 assetsReceived, uint256 totalVaultBalanceBefore) external {
function recordAgentVaultAssetWithdraw(address vault, uint256 assetsReceived, uint256 totalVaultBalanceBefore, uint256 vaultTokenToBaseAssetRate) external {
if (assetsReceived == 0 || totalVaultBalanceBefore == 0) return;
address wallet = msg.sender;
uint256 totalCostBasis = agentVaultCostBasis[wallet][vault];
uint256 totalShares = agentVaultShares[wallet][vault];
VaultPosition storage pos = _agentVaultPositions[wallet][vault];
uint256 totalCostBasis = pos.costBasis;
uint256 totalShares = pos.shares;
uint256 vaultTokenFeesOwed = agentYieldTokenFeesOwed[wallet][vault];
uint256 feeInBaseAsset = 0;
if (vaultTokenFeesOwed > 0) {
Expand All @@ -202,10 +221,9 @@ contract YieldSeekerFeeTracker is AccessControl {
feeTokenSettled = (vaultTokenFeesOwed * assetsReceived) / totalVaultBalanceBefore;
}
agentYieldTokenFeesOwed[wallet][vault] = vaultTokenFeesOwed - feeTokenSettled;
if (totalShares > 0) {
feeInBaseAsset = (feeTokenSettled * totalVaultBalanceBefore) / totalShares;
} else {
feeInBaseAsset = feeTokenSettled;
feeInBaseAsset = (feeTokenSettled * vaultTokenToBaseAssetRate) / ASSET_EXCHANGE_RATE_PRECISION;
if (feeInBaseAsset > assetsReceived) {
feeInBaseAsset = assetsReceived;
}
agentFeesCharged[wallet] += feeInBaseAsset;
emit YieldRecorded(wallet, feeInBaseAsset, feeInBaseAsset);
Expand All @@ -224,8 +242,8 @@ contract YieldSeekerFeeTracker is AccessControl {
uint256 profit = netAssets - proportionalCost;
_chargeFeesOnProfit(wallet, profit);
}
agentVaultCostBasis[wallet][vault] = totalCostBasis - proportionalCost;
agentVaultShares[wallet][vault] = totalShares - proportionalShares;
pos.costBasis = (totalCostBasis - proportionalCost).toUint128();
pos.shares = (totalShares - proportionalShares).toUint128();
}

/**
Expand Down
9 changes: 5 additions & 4 deletions src/adapters/AaveV3Adapter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ contract YieldSeekerAaveV3Adapter is AWKAaveV3Adapter, YieldSeekerAdapter {
/**
* @notice Internal deposit implementation with validation and fee tracking
*/
function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares) {
function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares, uint256 assetsDeposited) {
address asset = _getVaultAsset(vault);
_requireBaseAsset(asset);
shares = super._depositInternal(vault, amount);
_feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: amount, sharesReceived: shares});
(shares, assetsDeposited) = super._depositInternal(vault, amount);
_feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: assetsDeposited, sharesReceived: shares});
}

/**
Expand All @@ -43,8 +43,9 @@ contract YieldSeekerAaveV3Adapter is AWKAaveV3Adapter, YieldSeekerAdapter {
function _withdrawInternal(address vault, uint256 shares) internal override returns (uint256 assets) {
address asset = _getVaultAsset(vault);
_requireBaseAsset(asset);
// aTokens rebase 1:1 with underlying, so balanceOf is already in base asset terms
uint256 totalVaultBalanceBefore = IAaveAToken(vault).balanceOf(address(this));
assets = super._withdrawInternal(vault, shares);
_feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore});
_feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore, vaultTokenToBaseAssetRate: _feeTracker().ASSET_EXCHANGE_RATE_PRECISION()});
}
}
17 changes: 11 additions & 6 deletions src/adapters/CompoundV2Adapter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pragma solidity 0.8.28;

import {AWKCompoundV2Adapter, ICToken} from "../agentwalletkit/adapters/AWKCompoundV2Adapter.sol";
import {YieldSeekerAdapter} from "./Adapter.sol";
import {IERC20Metadata} from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol";

/**
* @title YieldSeekerCompoundV2Adapter
Expand All @@ -30,11 +31,11 @@ contract YieldSeekerCompoundV2Adapter is AWKCompoundV2Adapter, YieldSeekerAdapte
/**
* @notice Internal deposit implementation with validation and fee tracking
*/
function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares) {
function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares, uint256 assetsDeposited) {
address asset = _getVaultAsset(vault);
_requireBaseAsset(asset);
shares = super._depositInternal(vault, amount);
_feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: amount, sharesReceived: shares});
(shares, assetsDeposited) = super._depositInternal(vault, amount);
_feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: assetsDeposited, sharesReceived: shares});
}

/**
Expand All @@ -44,9 +45,13 @@ contract YieldSeekerCompoundV2Adapter is AWKCompoundV2Adapter, YieldSeekerAdapte
address asset = _getVaultAsset(vault);
_requireBaseAsset(asset);
uint256 cTokenBalance = ICToken(vault).balanceOf(address(this));
uint256 exchangeRate = ICToken(vault).exchangeRateStored();
uint256 totalVaultBalanceBefore = (cTokenBalance * exchangeRate) / 1e18;
uint256 exchangeRate = ICToken(vault).exchangeRateCurrent();
uint256 compoundExchangeRateScale = 10 ** (18 + uint256(IERC20Metadata(asset).decimals()) - uint256(ICToken(vault).decimals()));
// cTokens don't rebase, so convert cToken balance to base asset terms via exchange rate
uint256 totalVaultBalanceBefore = (cTokenBalance * exchangeRate) / compoundExchangeRateScale;
// Normalize exchange rate from Compound's scale to FeeTracker's 1e18 precision
uint256 normalizedRate = (_feeTracker().ASSET_EXCHANGE_RATE_PRECISION() * exchangeRate) / compoundExchangeRateScale;
assets = super._withdrawInternal(vault, shares);
_feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore});
_feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore, vaultTokenToBaseAssetRate: normalizedRate});
}
}
9 changes: 5 additions & 4 deletions src/adapters/CompoundV3Adapter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ contract YieldSeekerCompoundV3Adapter is AWKCompoundV3Adapter, YieldSeekerAdapte
/**
* @notice Internal deposit implementation with validation and fee tracking
*/
function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares) {
function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares, uint256 assetsDeposited) {
address asset = _getVaultAsset(vault);
_requireBaseAsset(asset);
shares = super._depositInternal(vault, amount);
_feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: amount, sharesReceived: shares});
(shares, assetsDeposited) = super._depositInternal(vault, amount);
_feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: assetsDeposited, sharesReceived: shares});
}

/**
Expand All @@ -43,8 +43,9 @@ contract YieldSeekerCompoundV3Adapter is AWKCompoundV3Adapter, YieldSeekerAdapte
function _withdrawInternal(address vault, uint256 shares) internal override returns (uint256 assets) {
address asset = _getVaultAsset(vault);
_requireBaseAsset(asset);
// Compound V3 rebases 1:1 with underlying, so balanceOf is already in base asset terms
uint256 totalVaultBalanceBefore = ICompoundV3Comet(vault).balanceOf(address(this));
assets = super._withdrawInternal(vault, shares);
_feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore});
_feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore, vaultTokenToBaseAssetRate: _feeTracker().ASSET_EXCHANGE_RATE_PRECISION()});
}
}
6 changes: 3 additions & 3 deletions src/adapters/ERC4626Adapter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ contract YieldSeekerERC4626Adapter is AWKERC4626Adapter, YieldSeekerAdapter {
* @notice Internal deposit implementation with validation and fee tracking
* @dev Overrides AWK logic to add pre-check and post-fee-tracking
*/
function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares) {
function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares, uint256 assetsDeposited) {
address asset = IERC4626(vault).asset();
_requireBaseAsset(asset);
shares = super._depositInternal(vault, amount);
_feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: amount, sharesReceived: shares});
(shares, assetsDeposited) = super._depositInternal(vault, amount);
_feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: assetsDeposited, sharesReceived: shares});
}

/**
Expand Down
6 changes: 4 additions & 2 deletions src/agentwalletkit/adapters/AWKAaveV3Adapter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,18 @@ abstract contract AWKAaveV3Adapter is AWKBaseVaultAdapter {
* @dev Runs in wallet context via delegatecall. The amount parameter is the underlying asset amount.
* For Aave, shares received equals amount deposited (1:1 rebasing).
*/
function _depositInternal(address vault, uint256 amount) internal virtual override returns (uint256 shares) {
function _depositInternal(address vault, uint256 amount) internal virtual override returns (uint256 shares, uint256 assetsDeposited) {
if (amount == 0) revert AWKErrors.ZeroAmount();
address asset = IAaveAToken(vault).UNDERLYING_ASSET_ADDRESS();
address pool = IAaveAToken(vault).POOL();
uint256 baseAssetBalanceBefore = IERC20(asset).balanceOf(address(this));
uint256 balanceBefore = IAaveAToken(vault).balanceOf(address(this));
IERC20(asset).forceApprove(pool, amount);
IAaveV3Pool(pool).supply({asset: asset, amount: amount, onBehalfOf: address(this), referralCode: 0});
uint256 balanceAfter = IAaveAToken(vault).balanceOf(address(this));
shares = balanceAfter - balanceBefore;
emit Deposited(address(this), vault, amount, shares);
assetsDeposited = baseAssetBalanceBefore - IERC20(asset).balanceOf(address(this));
emit Deposited(address(this), vault, assetsDeposited, shares);
}

/**
Expand Down
Loading