diff --git a/src/FeeTracker.sol b/src/FeeTracker.sol index c0a90f3..a52a5e7 100644 --- a/src/FeeTracker.sol +++ b/src/FeeTracker.sol @@ -18,6 +18,7 @@ pragma solidity 0.8.28; import {AWKErrors} from "./agentwalletkit/AWKErrors.sol"; import {AccessControl} from "@openzeppelin/contracts/access/AccessControl.sol"; +import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol"; error InvalidFeeRate(); @@ -28,7 +29,15 @@ error InvalidFeeRate(); * All recording functions use msg.sender, so callers can only affect their own accounting. */ contract YieldSeekerFeeTracker is AccessControl { + using SafeCast for uint256; + uint256 public constant MAX_FEE_RATE_BPS = 5000; + uint256 public constant ASSET_EXCHANGE_RATE_PRECISION = 1e18; + + struct VaultPosition { + uint128 costBasis; + uint128 shares; + } uint256 public feeRateBps; address public feeCollector; @@ -36,9 +45,8 @@ contract YieldSeekerFeeTracker is AccessControl { mapping(address wallet => uint256) public agentFeesCharged; mapping(address wallet => uint256) public agentFeesPaid; - // Position tracking - mapping(address wallet => mapping(address vault => uint256)) public agentVaultCostBasis; - mapping(address wallet => mapping(address vault => uint256)) public agentVaultShares; + // Position tracking (packed into single slot per wallet+vault) + mapping(address wallet => mapping(address vault => VaultPosition)) internal _agentVaultPositions; mapping(address wallet => mapping(address token => uint256)) public agentYieldTokenFeesOwed; event YieldRecorded(address indexed wallet, uint256 yield, uint256 fee); @@ -99,8 +107,17 @@ contract YieldSeekerFeeTracker is AccessControl { * @return shares The vault shares held */ function getAgentVaultPosition(address wallet, address vault) external view returns (uint256 costBasis, uint256 shares) { - costBasis = agentVaultCostBasis[wallet][vault]; - shares = agentVaultShares[wallet][vault]; + VaultPosition storage pos = _agentVaultPositions[wallet][vault]; + costBasis = pos.costBasis; + shares = pos.shares; + } + + function agentVaultCostBasis(address wallet, address vault) external view returns (uint256) { + return _agentVaultPositions[wallet][vault].costBasis; + } + + function agentVaultShares(address wallet, address vault) external view returns (uint256) { + return _agentVaultPositions[wallet][vault].shares; } /** @@ -122,8 +139,9 @@ contract YieldSeekerFeeTracker is AccessControl { * @param sharesReceived The amount of shares received */ function recordAgentVaultShareDeposit(address vault, uint256 assetsDeposited, uint256 sharesReceived) external { - agentVaultCostBasis[msg.sender][vault] += assetsDeposited; - agentVaultShares[msg.sender][vault] += sharesReceived; + VaultPosition storage pos = _agentVaultPositions[msg.sender][vault]; + pos.costBasis = (uint256(pos.costBasis) + assetsDeposited).toUint128(); + pos.shares = (uint256(pos.shares) + sharesReceived).toUint128(); } function _chargeFeesOnProfit(address wallet, uint256 profit) internal { @@ -141,8 +159,9 @@ contract YieldSeekerFeeTracker is AccessControl { function recordAgentVaultShareWithdraw(address vault, uint256 sharesSpent, uint256 assetsReceived) external { if (sharesSpent == 0) return; address wallet = msg.sender; - uint256 totalShares = agentVaultShares[wallet][vault]; - uint256 totalCostBasis = agentVaultCostBasis[wallet][vault]; + VaultPosition storage pos = _agentVaultPositions[wallet][vault]; + uint256 totalShares = pos.shares; + uint256 totalCostBasis = pos.costBasis; uint256 vaultTokenFeesOwed = agentYieldTokenFeesOwed[wallet][vault]; uint256 feeInBaseAsset = 0; if (vaultTokenFeesOwed > 0) { @@ -153,28 +172,27 @@ contract YieldSeekerFeeTracker is AccessControl { agentFeesCharged[wallet] += feeInBaseAsset; emit YieldRecorded(wallet, feeInBaseAsset, feeInBaseAsset); } + uint256 netAssets = assetsReceived - feeInBaseAsset; if (sharesSpent > totalShares) { // Withdrawing more shares than deposits tracked - treat as full withdrawal if (totalCostBasis > 0 && totalShares > 0) { - uint256 depositSharesValue = (assetsReceived * totalShares) / sharesSpent; + uint256 depositSharesValue = (netAssets * totalShares) / sharesSpent; if (depositSharesValue > totalCostBasis) { uint256 profit = depositSharesValue - totalCostBasis; _chargeFeesOnProfit(wallet, profit); } } - agentVaultCostBasis[wallet][vault] = 0; - agentVaultShares[wallet][vault] = 0; - return; + pos.costBasis = 0; + pos.shares = 0; } else if (totalShares > 0) { // Normal withdrawal within tracked deposits uint256 proportionalCost = (totalCostBasis * sharesSpent) / totalShares; - uint256 netAssets = assetsReceived - feeInBaseAsset; if (netAssets > proportionalCost) { uint256 profit = netAssets - proportionalCost; _chargeFeesOnProfit(wallet, profit); } - agentVaultCostBasis[wallet][vault] = totalCostBasis - proportionalCost; - agentVaultShares[wallet][vault] = totalShares - sharesSpent; + pos.costBasis = (totalCostBasis - proportionalCost).toUint128(); + pos.shares = (totalShares - sharesSpent).toUint128(); } } @@ -183,15 +201,16 @@ contract YieldSeekerFeeTracker is AccessControl { * @param vault The vault address * @param assetsReceived The amount of base assets received from the withdrawal * @param totalVaultBalanceBefore The total vault balance (in base asset terms) before withdrawal - * @dev Uses actual vault balance to compute proportional cost basis, avoiding virtual share conversion. - * For rebasing tokens (Aave, CompoundV3), totalVaultBalanceBefore is the token balance. - * For exchange-rate tokens (CompoundV2), totalVaultBalanceBefore is the underlying value. + * @param vaultTokenToBaseAssetRate The exchange rate from vault tokens to base asset (18-decimal fixed point). + * For rebasing tokens (Aave, CompoundV3), this should be 1e18 (1:1 with underlying). + * For exchange-rate tokens (CompoundV2), this should be the cToken exchange rate. */ - function recordAgentVaultAssetWithdraw(address vault, uint256 assetsReceived, uint256 totalVaultBalanceBefore) external { + function recordAgentVaultAssetWithdraw(address vault, uint256 assetsReceived, uint256 totalVaultBalanceBefore, uint256 vaultTokenToBaseAssetRate) external { if (assetsReceived == 0 || totalVaultBalanceBefore == 0) return; address wallet = msg.sender; - uint256 totalCostBasis = agentVaultCostBasis[wallet][vault]; - uint256 totalShares = agentVaultShares[wallet][vault]; + VaultPosition storage pos = _agentVaultPositions[wallet][vault]; + uint256 totalCostBasis = pos.costBasis; + uint256 totalShares = pos.shares; uint256 vaultTokenFeesOwed = agentYieldTokenFeesOwed[wallet][vault]; uint256 feeInBaseAsset = 0; if (vaultTokenFeesOwed > 0) { @@ -202,10 +221,9 @@ contract YieldSeekerFeeTracker is AccessControl { feeTokenSettled = (vaultTokenFeesOwed * assetsReceived) / totalVaultBalanceBefore; } agentYieldTokenFeesOwed[wallet][vault] = vaultTokenFeesOwed - feeTokenSettled; - if (totalShares > 0) { - feeInBaseAsset = (feeTokenSettled * totalVaultBalanceBefore) / totalShares; - } else { - feeInBaseAsset = feeTokenSettled; + feeInBaseAsset = (feeTokenSettled * vaultTokenToBaseAssetRate) / ASSET_EXCHANGE_RATE_PRECISION; + if (feeInBaseAsset > assetsReceived) { + feeInBaseAsset = assetsReceived; } agentFeesCharged[wallet] += feeInBaseAsset; emit YieldRecorded(wallet, feeInBaseAsset, feeInBaseAsset); @@ -224,8 +242,8 @@ contract YieldSeekerFeeTracker is AccessControl { uint256 profit = netAssets - proportionalCost; _chargeFeesOnProfit(wallet, profit); } - agentVaultCostBasis[wallet][vault] = totalCostBasis - proportionalCost; - agentVaultShares[wallet][vault] = totalShares - proportionalShares; + pos.costBasis = (totalCostBasis - proportionalCost).toUint128(); + pos.shares = (totalShares - proportionalShares).toUint128(); } /** diff --git a/src/adapters/AaveV3Adapter.sol b/src/adapters/AaveV3Adapter.sol index a8f97ee..1882288 100644 --- a/src/adapters/AaveV3Adapter.sol +++ b/src/adapters/AaveV3Adapter.sol @@ -30,11 +30,11 @@ contract YieldSeekerAaveV3Adapter is AWKAaveV3Adapter, YieldSeekerAdapter { /** * @notice Internal deposit implementation with validation and fee tracking */ - function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares) { + function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares, uint256 assetsDeposited) { address asset = _getVaultAsset(vault); _requireBaseAsset(asset); - shares = super._depositInternal(vault, amount); - _feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: amount, sharesReceived: shares}); + (shares, assetsDeposited) = super._depositInternal(vault, amount); + _feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: assetsDeposited, sharesReceived: shares}); } /** @@ -43,8 +43,9 @@ contract YieldSeekerAaveV3Adapter is AWKAaveV3Adapter, YieldSeekerAdapter { function _withdrawInternal(address vault, uint256 shares) internal override returns (uint256 assets) { address asset = _getVaultAsset(vault); _requireBaseAsset(asset); + // aTokens rebase 1:1 with underlying, so balanceOf is already in base asset terms uint256 totalVaultBalanceBefore = IAaveAToken(vault).balanceOf(address(this)); assets = super._withdrawInternal(vault, shares); - _feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore}); + _feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore, vaultTokenToBaseAssetRate: _feeTracker().ASSET_EXCHANGE_RATE_PRECISION()}); } } diff --git a/src/adapters/CompoundV2Adapter.sol b/src/adapters/CompoundV2Adapter.sol index 49815e8..1c90b6c 100644 --- a/src/adapters/CompoundV2Adapter.sol +++ b/src/adapters/CompoundV2Adapter.sol @@ -18,6 +18,7 @@ pragma solidity 0.8.28; import {AWKCompoundV2Adapter, ICToken} from "../agentwalletkit/adapters/AWKCompoundV2Adapter.sol"; import {YieldSeekerAdapter} from "./Adapter.sol"; +import {IERC20Metadata} from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol"; /** * @title YieldSeekerCompoundV2Adapter @@ -30,11 +31,11 @@ contract YieldSeekerCompoundV2Adapter is AWKCompoundV2Adapter, YieldSeekerAdapte /** * @notice Internal deposit implementation with validation and fee tracking */ - function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares) { + function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares, uint256 assetsDeposited) { address asset = _getVaultAsset(vault); _requireBaseAsset(asset); - shares = super._depositInternal(vault, amount); - _feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: amount, sharesReceived: shares}); + (shares, assetsDeposited) = super._depositInternal(vault, amount); + _feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: assetsDeposited, sharesReceived: shares}); } /** @@ -44,9 +45,13 @@ contract YieldSeekerCompoundV2Adapter is AWKCompoundV2Adapter, YieldSeekerAdapte address asset = _getVaultAsset(vault); _requireBaseAsset(asset); uint256 cTokenBalance = ICToken(vault).balanceOf(address(this)); - uint256 exchangeRate = ICToken(vault).exchangeRateStored(); - uint256 totalVaultBalanceBefore = (cTokenBalance * exchangeRate) / 1e18; + uint256 exchangeRate = ICToken(vault).exchangeRateCurrent(); + uint256 compoundExchangeRateScale = 10 ** (18 + uint256(IERC20Metadata(asset).decimals()) - uint256(ICToken(vault).decimals())); + // cTokens don't rebase, so convert cToken balance to base asset terms via exchange rate + uint256 totalVaultBalanceBefore = (cTokenBalance * exchangeRate) / compoundExchangeRateScale; + // Normalize exchange rate from Compound's scale to FeeTracker's 1e18 precision + uint256 normalizedRate = (_feeTracker().ASSET_EXCHANGE_RATE_PRECISION() * exchangeRate) / compoundExchangeRateScale; assets = super._withdrawInternal(vault, shares); - _feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore}); + _feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore, vaultTokenToBaseAssetRate: normalizedRate}); } } diff --git a/src/adapters/CompoundV3Adapter.sol b/src/adapters/CompoundV3Adapter.sol index f89e1e0..017cc69 100644 --- a/src/adapters/CompoundV3Adapter.sol +++ b/src/adapters/CompoundV3Adapter.sol @@ -30,11 +30,11 @@ contract YieldSeekerCompoundV3Adapter is AWKCompoundV3Adapter, YieldSeekerAdapte /** * @notice Internal deposit implementation with validation and fee tracking */ - function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares) { + function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares, uint256 assetsDeposited) { address asset = _getVaultAsset(vault); _requireBaseAsset(asset); - shares = super._depositInternal(vault, amount); - _feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: amount, sharesReceived: shares}); + (shares, assetsDeposited) = super._depositInternal(vault, amount); + _feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: assetsDeposited, sharesReceived: shares}); } /** @@ -43,8 +43,9 @@ contract YieldSeekerCompoundV3Adapter is AWKCompoundV3Adapter, YieldSeekerAdapte function _withdrawInternal(address vault, uint256 shares) internal override returns (uint256 assets) { address asset = _getVaultAsset(vault); _requireBaseAsset(asset); + // Compound V3 rebases 1:1 with underlying, so balanceOf is already in base asset terms uint256 totalVaultBalanceBefore = ICompoundV3Comet(vault).balanceOf(address(this)); assets = super._withdrawInternal(vault, shares); - _feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore}); + _feeTracker().recordAgentVaultAssetWithdraw({vault: vault, assetsReceived: assets, totalVaultBalanceBefore: totalVaultBalanceBefore, vaultTokenToBaseAssetRate: _feeTracker().ASSET_EXCHANGE_RATE_PRECISION()}); } } diff --git a/src/adapters/ERC4626Adapter.sol b/src/adapters/ERC4626Adapter.sol index 4e815a2..33f7f52 100644 --- a/src/adapters/ERC4626Adapter.sol +++ b/src/adapters/ERC4626Adapter.sol @@ -31,11 +31,11 @@ contract YieldSeekerERC4626Adapter is AWKERC4626Adapter, YieldSeekerAdapter { * @notice Internal deposit implementation with validation and fee tracking * @dev Overrides AWK logic to add pre-check and post-fee-tracking */ - function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares) { + function _depositInternal(address vault, uint256 amount) internal override returns (uint256 shares, uint256 assetsDeposited) { address asset = IERC4626(vault).asset(); _requireBaseAsset(asset); - shares = super._depositInternal(vault, amount); - _feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: amount, sharesReceived: shares}); + (shares, assetsDeposited) = super._depositInternal(vault, amount); + _feeTracker().recordAgentVaultShareDeposit({vault: vault, assetsDeposited: assetsDeposited, sharesReceived: shares}); } /** diff --git a/src/agentwalletkit/adapters/AWKAaveV3Adapter.sol b/src/agentwalletkit/adapters/AWKAaveV3Adapter.sol index 9dcadec..4ca22ab 100644 --- a/src/agentwalletkit/adapters/AWKAaveV3Adapter.sol +++ b/src/agentwalletkit/adapters/AWKAaveV3Adapter.sol @@ -60,16 +60,18 @@ abstract contract AWKAaveV3Adapter is AWKBaseVaultAdapter { * @dev Runs in wallet context via delegatecall. The amount parameter is the underlying asset amount. * For Aave, shares received equals amount deposited (1:1 rebasing). */ - function _depositInternal(address vault, uint256 amount) internal virtual override returns (uint256 shares) { + function _depositInternal(address vault, uint256 amount) internal virtual override returns (uint256 shares, uint256 assetsDeposited) { if (amount == 0) revert AWKErrors.ZeroAmount(); address asset = IAaveAToken(vault).UNDERLYING_ASSET_ADDRESS(); address pool = IAaveAToken(vault).POOL(); + uint256 baseAssetBalanceBefore = IERC20(asset).balanceOf(address(this)); uint256 balanceBefore = IAaveAToken(vault).balanceOf(address(this)); IERC20(asset).forceApprove(pool, amount); IAaveV3Pool(pool).supply({asset: asset, amount: amount, onBehalfOf: address(this), referralCode: 0}); uint256 balanceAfter = IAaveAToken(vault).balanceOf(address(this)); shares = balanceAfter - balanceBefore; - emit Deposited(address(this), vault, amount, shares); + assetsDeposited = baseAssetBalanceBefore - IERC20(asset).balanceOf(address(this)); + emit Deposited(address(this), vault, assetsDeposited, shares); } /** diff --git a/src/agentwalletkit/adapters/AWKBaseVaultAdapter.sol b/src/agentwalletkit/adapters/AWKBaseVaultAdapter.sol index 334263d..6d8c527 100644 --- a/src/agentwalletkit/adapters/AWKBaseVaultAdapter.sol +++ b/src/agentwalletkit/adapters/AWKBaseVaultAdapter.sol @@ -34,9 +34,10 @@ abstract contract AWKBaseVaultAdapter is AWKAdapter { * @notice Deposit assets into a vault (public interface, should not be called directly) * @param amount The amount of assets to deposit * @return shares The amount of vault shares received + * @return assetsDeposited The actual amount of base asset deposited * @dev This is a placeholder function signature. Actual execution happens via execute() -> _depositInternal() */ - function deposit(uint256 amount) external pure returns (uint256 shares) { + function deposit(uint256 amount) external pure returns (uint256 shares, uint256 assetsDeposited) { revert AWKErrors.DirectCallForbidden(); } @@ -44,9 +45,10 @@ abstract contract AWKBaseVaultAdapter is AWKAdapter { * @notice Deposit a percentage of base asset balance into a vault (public interface, should not be called directly) * @param percentageBps The percentage in basis points (10000 = 100%) * @return shares The amount of vault shares received + * @return assetsDeposited The actual amount of base asset deposited * @dev This is a placeholder function signature. Actual execution happens via execute() -> _depositPercentageInternal() */ - function depositPercentage(uint256 percentageBps) external pure returns (uint256 shares) { + function depositPercentage(uint256 percentageBps) external pure returns (uint256 shares, uint256 assetsDeposited) { revert AWKErrors.DirectCallForbidden(); } @@ -78,14 +80,14 @@ abstract contract AWKBaseVaultAdapter is AWKAdapter { bytes4 selector = bytes4(data[:4]); if (selector == this.deposit.selector) { uint256 amount = abi.decode(data[4:], (uint256)); - uint256 shares = _depositInternal(target, amount); - return abi.encode(shares); + (uint256 shares, uint256 assetsDeposited) = _depositInternal(target, amount); + return abi.encode(shares, assetsDeposited); } if (selector == this.depositPercentage.selector) { uint256 percentageBps = abi.decode(data[4:], (uint256)); address asset = _getVaultAsset(target); - uint256 shares = _depositPercentageInternal(target, percentageBps, IERC20(asset)); - return abi.encode(shares); + (uint256 shares, uint256 assetsDeposited) = _depositPercentageInternal(target, percentageBps, IERC20(asset)); + return abi.encode(shares, assetsDeposited); } if (selector == this.withdraw.selector) { uint256 shares = abi.decode(data[4:], (uint256)); @@ -104,7 +106,7 @@ abstract contract AWKBaseVaultAdapter is AWKAdapter { * @return shares The amount of vault shares received * @dev Must be implemented by concrete vault adapters. Hooks are called automatically. */ - function _depositInternal(address vault, uint256 amount) internal virtual returns (uint256 shares); + function _depositInternal(address vault, uint256 amount) internal virtual returns (uint256 shares, uint256 assetsDeposited); /** * @notice Internal deposit percentage implementation @@ -112,13 +114,14 @@ abstract contract AWKBaseVaultAdapter is AWKAdapter { * @param percentageBps The percentage in basis points (10000 = 100%) * @param baseAsset The base asset token * @return shares The amount of vault shares received + * @return assetsDeposited The actual amount of base asset deposited * @dev Calculates amount based on balance and calls _depositInternal */ - function _depositPercentageInternal(address vault, uint256 percentageBps, IERC20 baseAsset) internal returns (uint256 shares) { + function _depositPercentageInternal(address vault, uint256 percentageBps, IERC20 baseAsset) internal returns (uint256 shares, uint256 assetsDeposited) { if (percentageBps == 0 || percentageBps > 1e4) revert InvalidPercentage(percentageBps); uint256 balance = baseAsset.balanceOf(address(this)); uint256 amount = (balance * percentageBps) / 1e4; - return _depositInternal(vault, amount); + (shares, assetsDeposited) = _depositInternal(vault, amount); } /** diff --git a/src/agentwalletkit/adapters/AWKCompoundV2Adapter.sol b/src/agentwalletkit/adapters/AWKCompoundV2Adapter.sol index b5373d6..9586985 100644 --- a/src/agentwalletkit/adapters/AWKCompoundV2Adapter.sol +++ b/src/agentwalletkit/adapters/AWKCompoundV2Adapter.sol @@ -25,10 +25,12 @@ import {SafeERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol */ interface ICToken { function underlying() external view returns (address); + function decimals() external view returns (uint8); function mint(uint256 mintAmount) external returns (uint256); function redeemUnderlying(uint256 redeemAmount) external returns (uint256); function balanceOf(address account) external view returns (uint256); function exchangeRateStored() external view returns (uint256); + function exchangeRateCurrent() external returns (uint256); } /** @@ -53,16 +55,18 @@ abstract contract AWKCompoundV2Adapter is AWKBaseVaultAdapter { * @dev Runs in wallet context via delegatecall. The amount parameter is the underlying asset amount. * Returns the cTokens received as shares. */ - function _depositInternal(address vault, uint256 amount) internal virtual override returns (uint256 shares) { + function _depositInternal(address vault, uint256 amount) internal virtual override returns (uint256 shares, uint256 assetsDeposited) { if (amount == 0) revert AWKErrors.ZeroAmount(); address asset = ICToken(vault).underlying(); + uint256 baseAssetBalanceBefore = IERC20(asset).balanceOf(address(this)); uint256 balanceBefore = ICToken(vault).balanceOf(address(this)); IERC20(asset).forceApprove(vault, amount); uint256 mintResult = ICToken(vault).mint(amount); require(mintResult == 0, "AWKCompoundV2Adapter: mint failed"); uint256 balanceAfter = ICToken(vault).balanceOf(address(this)); shares = balanceAfter - balanceBefore; - emit Deposited(address(this), vault, amount, shares); + assetsDeposited = baseAssetBalanceBefore - IERC20(asset).balanceOf(address(this)); + emit Deposited(address(this), vault, assetsDeposited, shares); } /** diff --git a/src/agentwalletkit/adapters/AWKCompoundV3Adapter.sol b/src/agentwalletkit/adapters/AWKCompoundV3Adapter.sol index 1eadd9c..6ecbcc3 100644 --- a/src/agentwalletkit/adapters/AWKCompoundV3Adapter.sol +++ b/src/agentwalletkit/adapters/AWKCompoundV3Adapter.sol @@ -54,15 +54,17 @@ abstract contract AWKCompoundV3Adapter is AWKBaseVaultAdapter { * @dev Runs in wallet context via delegatecall. The amount parameter is the base token amount. * Returns the change in balance as shares (though Compound V3 uses rebasing, not actual shares). */ - function _depositInternal(address vault, uint256 amount) internal virtual override returns (uint256 shares) { + function _depositInternal(address vault, uint256 amount) internal virtual override returns (uint256 shares, uint256 assetsDeposited) { if (amount == 0) revert AWKErrors.ZeroAmount(); address asset = ICompoundV3Comet(vault).baseToken(); + uint256 baseAssetBalanceBefore = IERC20(asset).balanceOf(address(this)); uint256 balanceBefore = ICompoundV3Comet(vault).balanceOf(address(this)); IERC20(asset).forceApprove(vault, amount); ICompoundV3Comet(vault).supply({asset: asset, amount: amount}); uint256 balanceAfter = ICompoundV3Comet(vault).balanceOf(address(this)); shares = balanceAfter - balanceBefore; - emit Deposited(address(this), vault, amount, shares); + assetsDeposited = baseAssetBalanceBefore - IERC20(asset).balanceOf(address(this)); + emit Deposited(address(this), vault, assetsDeposited, shares); } /** diff --git a/src/agentwalletkit/adapters/AWKERC4626Adapter.sol b/src/agentwalletkit/adapters/AWKERC4626Adapter.sol index b2e287c..60a39d6 100644 --- a/src/agentwalletkit/adapters/AWKERC4626Adapter.sol +++ b/src/agentwalletkit/adapters/AWKERC4626Adapter.sol @@ -49,12 +49,14 @@ abstract contract AWKERC4626Adapter is AWKBaseVaultAdapter { * @notice Internal deposit implementation * @dev Runs in wallet context via delegatecall */ - function _depositInternal(address vault, uint256 amount) internal virtual override returns (uint256 shares) { + function _depositInternal(address vault, uint256 amount) internal virtual override returns (uint256 shares, uint256 assetsDeposited) { if (amount == 0) revert AWKErrors.ZeroAmount(); address asset = IERC4626(vault).asset(); + uint256 baseAssetBalanceBefore = IERC20(asset).balanceOf(address(this)); IERC20(asset).forceApprove(vault, amount); shares = IERC4626(vault).deposit({assets: amount, receiver: address(this)}); - emit Deposited(address(this), vault, amount, shares); + assetsDeposited = baseAssetBalanceBefore - IERC20(asset).balanceOf(address(this)); + emit Deposited(address(this), vault, assetsDeposited, shares); } /** diff --git a/test/mocks/MockCompoundV2.sol b/test/mocks/MockCompoundV2.sol index ac33fc0..f21d0e1 100644 --- a/test/mocks/MockCompoundV2.sol +++ b/test/mocks/MockCompoundV2.sol @@ -10,10 +10,12 @@ import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; contract MockCToken is ERC20 { IERC20 private immutable _UNDERLYING; uint256 private _exchangeRateStored; - uint256 private constant EXCHANGE_RATE_SCALE = 1e18; + uint256 private immutable EXCHANGE_RATE_SCALE; constructor(address underlying_, string memory name_, string memory symbol_) ERC20(name_, symbol_) { _UNDERLYING = IERC20(underlying_); + uint8 underlyingDecimals = ERC20(underlying_).decimals(); + EXCHANGE_RATE_SCALE = 10 ** (18 + uint256(underlyingDecimals) - 8); _exchangeRateStored = EXCHANGE_RATE_SCALE; // 1:1 initially } diff --git a/test/mocks/MockFeeTracker.sol b/test/mocks/MockFeeTracker.sol index 3ff1ebc..dbc5bc4 100644 --- a/test/mocks/MockFeeTracker.sol +++ b/test/mocks/MockFeeTracker.sol @@ -4,10 +4,13 @@ pragma solidity 0.8.28; import {InvalidFeeRate} from "../../src/FeeTracker.sol"; import {AWKErrors} from "../../src/agentwalletkit/AWKErrors.sol"; import {AccessControl} from "@openzeppelin/contracts/access/AccessControl.sol"; +import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol"; /// @title MockFeeTracker /// @notice Mock implementation of FeeTracker for isolated unit testing contract MockFeeTracker is AccessControl { + using SafeCast for uint256; + bytes32 public constant ADMIN_ROLE = keccak256("ADMIN_ROLE"); uint256 public constant BASIS_POINTS = 10000; @@ -18,9 +21,13 @@ contract MockFeeTracker is AccessControl { mapping(address => uint256) private _feesOwed; + struct VaultPosition { + uint128 costBasis; + uint128 shares; + } + // Position tracking - mapping(address wallet => mapping(address vault => uint256)) public agentVaultCostBasis; - mapping(address wallet => mapping(address vault => uint256)) public agentVaultShares; + mapping(address wallet => mapping(address vault => VaultPosition)) internal _agentVaultPositions; mapping(address wallet => mapping(address token => uint256)) public agentYieldTokenFeesOwed; event FeeConfigUpdated(uint256 indexed feeRate, address indexed collector); @@ -79,32 +86,39 @@ contract MockFeeTracker is AccessControl { // ============ Position Tracking ============ function recordAgentVaultShareDeposit(address wallet, address vault, uint256 assetsDeposited, uint256 sharesReceived) external onlyRole(ADMIN_ROLE) { - agentVaultCostBasis[wallet][vault] += assetsDeposited; - agentVaultShares[wallet][vault] += sharesReceived; + VaultPosition storage pos = _agentVaultPositions[wallet][vault]; + pos.costBasis = (uint256(pos.costBasis) + assetsDeposited).toUint128(); + pos.shares = (uint256(pos.shares) + sharesReceived).toUint128(); } function recordAgentVaultShareWithdraw(address wallet, address vault, uint256 sharesSpent, uint256 assetsReceived) external onlyRole(ADMIN_ROLE) { - uint256 totalShares = agentVaultShares[wallet][vault]; - uint256 totalCostBasis = agentVaultCostBasis[wallet][vault]; - + VaultPosition storage pos = _agentVaultPositions[wallet][vault]; + uint256 totalShares = pos.shares; + uint256 totalCostBasis = pos.costBasis; if (totalShares == 0) return; - uint256 proportionalCost = (totalCostBasis * sharesSpent) / totalShares; - if (assetsReceived > proportionalCost) { uint256 profit = assetsReceived - proportionalCost; uint256 fee = (profit * _feeRate) / BASIS_POINTS; _feesOwed[wallet] += fee; emit YieldRecorded(wallet, profit, fee); } - - agentVaultCostBasis[wallet][vault] = totalCostBasis - proportionalCost; - agentVaultShares[wallet][vault] = totalShares - sharesSpent; + pos.costBasis = (totalCostBasis - proportionalCost).toUint128(); + pos.shares = (totalShares - sharesSpent).toUint128(); } function getAgentVaultPosition(address wallet, address vault) external view returns (uint256 costBasis, uint256 shares) { - costBasis = agentVaultCostBasis[wallet][vault]; - shares = agentVaultShares[wallet][vault]; + VaultPosition storage pos = _agentVaultPositions[wallet][vault]; + costBasis = pos.costBasis; + shares = pos.shares; + } + + function agentVaultCostBasis(address wallet, address vault) external view returns (uint256) { + return _agentVaultPositions[wallet][vault].costBasis; + } + + function agentVaultShares(address wallet, address vault) external view returns (uint256) { + return _agentVaultPositions[wallet][vault].shares; } function recordAgentYieldTokenEarned(address wallet, address token, uint256 amount) external onlyRole(ADMIN_ROLE) { diff --git a/test/unit/FeeTracker.t.sol b/test/unit/FeeTracker.t.sol index 05e1e26..a1ef8f6 100644 --- a/test/unit/FeeTracker.t.sol +++ b/test/unit/FeeTracker.t.sol @@ -757,14 +757,140 @@ contract YieldSeekerFeeTrackerTest is Test { uint256 feesAfter = feeTracker.getFeesOwed(agent1); uint256 totalFees = feesAfter - feesBefore; - // New calculation: only charge profit on deposit shares portion // feeInBaseAsset = (13.2 * 0.2) / 12 = 0.22 USDC (fee on yield token) - // depositSharesValue = (13.2 * 10) / 12 = 11.0 USDC - // profit on deposit = 11.0 - 10.0 = 1.0 USDC - // profit fee = 1.0 * 10% = 0.1 USDC - // Total = 0.22 + 0.1 = 0.32 USDC - uint256 expectedTotal = 0.32e6; + // depositSharesValue = ((13.2 - 0.22) * 10) / 12 = 10.8166... USDC + // profit on deposit = 10.8166 - 10.0 = 0.8166 USDC + // profit fee = 0.8166 * 10% = 0.08166 USDC + // Total = 0.22 + 0.08166 = 0.30166 USDC + uint256 expectedTotal = 0.30166e6; assertApproxEqAbs(totalFees, expectedTotal, 1e3, "Total fees should match calculation"); } + + // ============ Audit: Over-withdrawal branch double-counts vault token fees ============ + + function test_VaultShareWithdraw_OverWithdrawalBranch_DoubleCounts_VaultTokenFee() public { + // This test demonstrates the double-counting bug in the over-withdrawal branch + // of recordAgentVaultShareWithdraw (sharesSpent > totalShares). + // + // Setup: + // - Deposit: 1000 USDC -> 1000 shares (cost basis = 1000 USDC) + // - Yield token earned: 200 shares -> 20 shares fee owed (10%) + // - Withdraw: 1200 shares (> 1000 tracked) for 1320 USDC (10% appreciation) + // + // Over-withdrawal branch triggers because 1200 > 1000 tracked shares. + // + // Step 1 (vault token fee conversion): + // feeTokenSwapped = min(1200, 20) = 20 + // feeInBaseAsset = (1320 * 20) / 1200 = 22 USDC + // agentFeesCharged += 22 + // + // Step 2 (profit fee - THE BUG): + // BUGGY: depositSharesValue = (1320 * 1000) / 1200 = 1100 + // CORRECT: depositSharesValue = ((1320 - 22) * 1000) / 1200 = 1081.666... + // + // BUGGY profit = 1100 - 1000 = 100 -> fee = 10 + // CORRECT profit = 1081 - 1000 = 81 -> fee = 8 (approximately) + // + // The buggy path charges ~32 USDC total, correct path charges ~30 USDC. + // The difference (~2 USDC) is the double-counted portion. + + address vault = makeAddr("vault"); + + // Deposit 1000 USDC -> 1000 shares + vm.prank(agent1); + feeTracker.recordAgentVaultShareDeposit(vault, 1000e6, 1000e18); + + // Earn 200 shares of yield -> 20 shares fee owed at 10% + vm.prank(agent1); + feeTracker.recordAgentYieldTokenEarned(vault, 200e18); + assertEq(feeTracker.getAgentYieldTokenFeesOwed(agent1, vault), 20e18); + + uint256 feesBefore = feeTracker.getFeesOwed(agent1); + + // Withdraw all 1200 shares for 1320 USDC (triggers over-withdrawal branch) + vm.prank(agent1); + feeTracker.recordAgentVaultShareWithdraw(vault, 1200e18, 1320e6); + + uint256 totalFeesCharged = feeTracker.getFeesOwed(agent1) - feesBefore; + + // Calculate what the CORRECT fees should be: + // Block 1: vault token fee = (1320e6 * 20e18) / 1200e18 = 22e6 + uint256 feeInBaseAsset = (1320e6 * 20e18) / 1200e18; + assertEq(feeInBaseAsset, 22e6, "vault token fee should be 22 USDC"); + + // Block 2 (correct): depositSharesValue using net assets + uint256 correctDepositSharesValue = ((1320e6 - feeInBaseAsset) * 1000e18) / 1200e18; + uint256 correctProfit = correctDepositSharesValue - 1000e6; // ~81.666e6 + uint256 correctProfitFee = (correctProfit * 1000) / 10000; + uint256 correctTotalFees = feeInBaseAsset + correctProfitFee; + + // The contract should charge the CORRECT fees, not the inflated buggy amount. + // This assertion will FAIL until the bug is fixed (line 178 in FeeTracker.sol + // should use `assetsReceived - feeInBaseAsset` instead of `assetsReceived`). + assertEq(totalFeesCharged, correctTotalFees, "Fees should not double-count vault token fee in over-withdrawal branch"); + } + + // ============ Audit Fix: Safety cap on feeInBaseAsset (Issue 2) ============ + + function test_VaultAssetWithdraw_FeeInBaseAsset_CappedAtAssetsReceived() public { + address vault = makeAddr("vault"); + // Deposit 50 USDC → 50 shares + vm.prank(agent1); + feeTracker.recordAgentVaultShareDeposit(vault, 50e6, 50e6); + // Record massive reward: 1000 tokens → 100 token fee owed + vm.prank(agent1); + feeTracker.recordAgentYieldTokenEarned(vault, 1000e6); + uint256 feeOwed = feeTracker.getAgentYieldTokenFeesOwed(agent1, vault); + assertEq(feeOwed, 100e6, "Should owe 100 tokens in fees"); + // Withdraw 200 USDC from a vault with totalBalance = 1050 + // Without the cap, the old buggy formula would compute a fee > 200 and underflow + // With the fix, the fee is capped and this should not revert + uint256 rebasingRate = feeTracker.ASSET_EXCHANGE_RATE_PRECISION(); + vm.prank(agent1); + feeTracker.recordAgentVaultAssetWithdraw(vault, 200e6, 1050e6, rebasingRate); + // Verify fee was capped at assetsReceived (200e6) + uint256 feesCharged = feeTracker.agentFeesCharged(agent1); + assertTrue(feesCharged <= 200e6, "Fee should be capped at assets received"); + assertTrue(feesCharged > 0, "Fee should be non-zero"); + } + + function test_VaultAssetWithdraw_RebasingRate_CorrectFee() public { + address vault = makeAddr("vault"); + vm.prank(agent1); + feeTracker.recordAgentVaultShareDeposit(vault, 100e6, 100e6); + vm.prank(agent1); + feeTracker.recordAgentYieldTokenEarned(vault, 10e6); + // Withdraw 50 from totalVaultBalance=110, rate=1e18 (rebasing) + uint256 rebasingRate = feeTracker.ASSET_EXCHANGE_RATE_PRECISION(); + vm.prank(agent1); + feeTracker.recordAgentVaultAssetWithdraw(vault, 50e6, 110e6, rebasingRate); + uint256 expectedFeeTokenSettled = uint256(1e6) * uint256(50e6) / uint256(110e6); + // With 1e18 rate, feeInBaseAsset = feeTokenSettled (1:1) + uint256 feesCharged = feeTracker.agentFeesCharged(agent1); + // The vaultToken fee portion should be exactly feeTokenSettled + // Plus potential profit fee on the remaining netAssets + assertTrue(feesCharged >= expectedFeeTokenSettled, "Fee should include at least the token fee portion"); + } + + function test_VaultAssetWithdraw_ExchangeRate_CorrectConversion() public { + address vault = makeAddr("vault"); + // Simulate CompoundV2: deposit 1000 USDC → 1000 cTokens at 1e18 rate + vm.prank(agent1); + feeTracker.recordAgentVaultShareDeposit(vault, 1000e6, 1000e6); + // Record token fee: 10 cTokens owed + vm.prank(agent1); + feeTracker.recordAgentYieldTokenEarned(vault, 100e6); + uint256 feeOwed = feeTracker.getAgentYieldTokenFeesOwed(agent1, vault); + assertEq(feeOwed, 10e6); + // Exchange rate = 1.1e18 (10% appreciation) + // Withdraw 550 USDC from total 1100 USDC balance + uint256 exchangeRate = 1.1e18; + vm.prank(agent1); + feeTracker.recordAgentVaultAssetWithdraw(vault, 550e6, 1100e6, exchangeRate); + // feeTokenSettled = (10e6 * 550e6) / 1100e6 = 5e6 + // feeInBaseAsset = (5e6 * 1.1e18) / 1e18 = 5.5e6 + uint256 feesCharged = feeTracker.agentFeesCharged(agent1); + assertTrue(feesCharged >= 5.5e6, "Should apply exchange rate for non-rebasing tokens"); + } } diff --git a/test/unit/adapters/AaveV3Adapter.t.sol b/test/unit/adapters/AaveV3Adapter.t.sol index 2c743ca..4a643bf 100644 --- a/test/unit/adapters/AaveV3Adapter.t.sol +++ b/test/unit/adapters/AaveV3Adapter.t.sol @@ -155,4 +155,82 @@ contract AaveV3AdapterTest is Test { assertEq(costBasis, depositAmount - proportionalCost, "Cost basis should be reduced proportionally"); assertEq(shares, depositAmount - proportionalCost, "Shares should be reduced proportionally"); } + + // ============ Audit Fix: Rebasing fee conversion uses 1:1 rate (Issue 1) ============ + + function test_RebasingFeeConversion_NotInflated() public { + uint256 depositAmount = 100e6; + wallet.executeAdapter(address(adapter), address(aToken), abi.encodeWithSelector(adapter.deposit.selector, depositAmount)); + vm.prank(address(wallet)); + feeTracker.recordAgentYieldTokenEarned(address(aToken), 10e6); + uint256 feeOwed = feeTracker.getAgentYieldTokenFeesOwed(address(wallet), address(aToken)); + assertEq(feeOwed, 1e6, "Should owe 1 aToken in fees (10% of 10)"); + aToken.addYield(address(wallet), 10e6); + baseAsset.mint(address(aToken), 10e6); + assertEq(aToken.balanceOf(address(wallet)), 110e6); + uint256 feesBefore = feeTracker.agentFeesCharged(address(wallet)); + wallet.executeAdapter(address(adapter), address(aToken), abi.encodeWithSelector(adapter.withdraw.selector, uint256(50e6))); + uint256 feesAfter = feeTracker.agentFeesCharged(address(wallet)); + uint256 feesCharged = feesAfter - feesBefore; + // feeTokenSettled = (1e6 * 50e6) / 110e6 = 454545 + // With 1:1 rate: feeInBaseAsset = 454545 (correct) + // Old buggy formula: 454545 * 110e6 / 100e6 = 500000 (inflated!) + uint256 expectedFeeTokenSettled = uint256(1e6) * uint256(50e6) / uint256(110e6); + uint256 proportionalCost = (depositAmount * uint256(50e6)) / uint256(110e6); + uint256 netAssets = uint256(50e6) - expectedFeeTokenSettled; + uint256 expectedProfitFee = netAssets > proportionalCost ? ((netAssets - proportionalCost) * 1000) / 10_000 : 0; + uint256 expectedTotalFees = expectedFeeTokenSettled + expectedProfitFee; + assertEq(feesCharged, expectedTotalFees, "Fees should not be inflated for rebasing tokens"); + uint256 oldInflatedFee = (expectedFeeTokenSettled * uint256(110e6)) / uint256(100e6); + assertTrue(expectedFeeTokenSettled < oldInflatedFee, "Fee should be less than the old inflated calculation"); + } + + // ============ Audit Fix: No underflow DoS on large rewards (Issue 2) ============ + + function test_LargeReward_NoUnderflowDoS() public { + uint256 depositAmount = 50e6; + wallet.executeAdapter(address(adapter), address(aToken), abi.encodeWithSelector(adapter.deposit.selector, depositAmount)); + vm.prank(address(wallet)); + feeTracker.recordAgentYieldTokenEarned(address(aToken), 1000e6); + aToken.addYield(address(wallet), 1000e6); + baseAsset.mint(address(aToken), 1000e6); + wallet.executeAdapter(address(adapter), address(aToken), abi.encodeWithSelector(adapter.withdraw.selector, uint256(200e6))); + uint256 feesCharged = feeTracker.agentFeesCharged(address(wallet)); + assertTrue(feesCharged > 0, "Fees should be charged"); + assertTrue(feesCharged <= 200e6, "Fee should not exceed withdrawal amount"); + } + + function test_FeeCapAtAssetsReceived() public { + uint256 depositAmount = 10e6; + wallet.executeAdapter(address(adapter), address(aToken), abi.encodeWithSelector(adapter.deposit.selector, depositAmount)); + vm.prank(address(wallet)); + feeTracker.recordAgentYieldTokenEarned(address(aToken), 10_000e6); + aToken.addYield(address(wallet), 10_000e6); + baseAsset.mint(address(aToken), 10_000e6); + wallet.executeAdapter(address(adapter), address(aToken), abi.encodeWithSelector(adapter.withdraw.selector, uint256(5e6))); + uint256 feesCharged = feeTracker.agentFeesCharged(address(wallet)); + assertTrue(feesCharged <= 5e6, "Fee should be capped at assets received"); + } + + // ============ Audit Fix: Full lifecycle with rewards ============ + + function test_FullLifecycle_WithRewards() public { + wallet.executeAdapter(address(adapter), address(aToken), abi.encodeWithSelector(adapter.deposit.selector, uint256(1_000e6))); + aToken.addYield(address(wallet), 50e6); + baseAsset.mint(address(aToken), 50e6); + vm.prank(address(wallet)); + feeTracker.recordAgentYieldTokenEarned(address(aToken), 20e6); + aToken.addYield(address(wallet), 20e6); + baseAsset.mint(address(aToken), 20e6); + assertEq(aToken.balanceOf(address(wallet)), 1070e6); + wallet.executeAdapter(address(adapter), address(aToken), abi.encodeWithSelector(adapter.withdraw.selector, uint256(500e6))); + uint256 feesAfterPartial = feeTracker.agentFeesCharged(address(wallet)); + assertTrue(feesAfterPartial > 0, "Should have charged fees on partial withdraw"); + uint256 remaining = aToken.balanceOf(address(wallet)); + wallet.executeAdapter(address(adapter), address(aToken), abi.encodeWithSelector(adapter.withdraw.selector, remaining)); + (uint256 costBasis, uint256 shares) = feeTracker.getAgentVaultPosition(address(wallet), address(aToken)); + assertEq(costBasis, 0, "Cost basis should be zero after full withdraw"); + assertEq(shares, 0, "Shares should be zero after full withdraw"); + assertEq(feeTracker.getAgentYieldTokenFeesOwed(address(wallet), address(aToken)), 0, "All token fees should be settled"); + } } diff --git a/test/unit/adapters/AdapterFlows.t.sol b/test/unit/adapters/AdapterFlows.t.sol index f9a9771..4f971f1 100644 --- a/test/unit/adapters/AdapterFlows.t.sol +++ b/test/unit/adapters/AdapterFlows.t.sol @@ -164,10 +164,9 @@ contract AdapterFlowsTest is Test { _withdraw(7e6); // vaultTokenFee settlement: 7/110 of 1e6 fee owed (in vault token units) uint256 totalBalanceBefore = 110e6; - uint256 totalShares = 100e6; uint256 feeTokenSettled = (1e6 * 7e6) / totalBalanceBefore; - // Convert vault token fee to base asset using exchange rate (totalVaultBalance / totalShares) - uint256 feeInBaseAsset = (feeTokenSettled * totalBalanceBefore) / totalShares; + // For rebasing tokens (Aave), 1 vault token = 1 underlying, so rate = 1e18 + uint256 feeInBaseAsset = feeTokenSettled; // proportionalCost = (100e6 * 7e6) / 110e6 uint256 proportionalCost = (100e6 * 7e6) / totalBalanceBefore; // netAssets = 7e6 - feeInBaseAsset, profit = netAssets - proportionalCost diff --git a/test/unit/adapters/CompoundV2Adapter.t.sol b/test/unit/adapters/CompoundV2Adapter.t.sol index c7db9ab..df06275 100644 --- a/test/unit/adapters/CompoundV2Adapter.t.sol +++ b/test/unit/adapters/CompoundV2Adapter.t.sol @@ -160,4 +160,31 @@ contract CompoundV2AdapterTest is Test { assertEq(costBasis, depositAmount - proportionalCost, "Cost basis should be reduced proportionally"); assertEq(shares, depositAmount - proportionalCost, "Shares should be reduced proportionally"); } + + // ============ Audit Fix: Uses exchangeRateCurrent (Issue 3) ============ + + function test_ExchangeRateCurrent_CorrectFees() public { + uint256 depositAmount = 1_000e6; + wallet.executeAdapter(address(adapter), address(cToken), abi.encodeWithSelector(adapter.deposit.selector, depositAmount)); + cToken.addYield(1000); + baseAsset.mint(address(cToken), 100e6); + uint256 fullBalance = (depositAmount * 11000) / 10000; + wallet.executeAdapter(address(adapter), address(cToken), abi.encodeWithSelector(adapter.withdraw.selector, fullBalance)); + uint256 profit = fullBalance - depositAmount; + uint256 expectedFee = (profit * 1000) / 10_000; + assertEq(feeTracker.agentFeesCharged(address(wallet)), expectedFee, "Should charge correct fee with current exchange rate"); + } + + function test_CompoundV2_WithVaultTokenFees_UsesExchangeRate() public { + uint256 depositAmount = 1_000e6; + wallet.executeAdapter(address(adapter), address(cToken), abi.encodeWithSelector(adapter.deposit.selector, depositAmount)); + cToken.addYield(1000); + baseAsset.mint(address(cToken), 100e6); + vm.prank(address(wallet)); + feeTracker.recordAgentYieldTokenEarned(address(cToken), 50e6); + uint256 feesBefore = feeTracker.agentFeesCharged(address(wallet)); + wallet.executeAdapter(address(adapter), address(cToken), abi.encodeWithSelector(adapter.withdraw.selector, uint256(550e6))); + uint256 feesAfter = feeTracker.agentFeesCharged(address(wallet)); + assertTrue(feesAfter > feesBefore, "CompoundV2 should charge fees using the exchange rate"); + } } diff --git a/test/unit/adapters/CompoundV3Adapter.t.sol b/test/unit/adapters/CompoundV3Adapter.t.sol index 1dfc1f0..b1aed93 100644 --- a/test/unit/adapters/CompoundV3Adapter.t.sol +++ b/test/unit/adapters/CompoundV3Adapter.t.sol @@ -152,4 +152,74 @@ contract CompoundV3AdapterTest is Test { assertEq(costBasis, depositAmount - proportionalCost, "Cost basis should be reduced proportionally"); assertEq(shares, depositAmount - proportionalCost, "Shares should be reduced proportionally"); } + + // ============ Audit Fix: Rebasing fee conversion uses 1:1 rate (Issue 1) ============ + + function test_RebasingFeeConversion_NotInflated() public { + uint256 depositAmount = 100e6; + wallet.executeAdapter(address(adapter), address(comet), abi.encodeWithSelector(adapter.deposit.selector, depositAmount)); + vm.prank(address(wallet)); + feeTracker.recordAgentYieldTokenEarned(address(comet), 10e6); + comet.addYield(address(wallet), 10e6); + baseAsset.mint(address(comet), 10e6); + uint256 feesBefore = feeTracker.agentFeesCharged(address(wallet)); + wallet.executeAdapter(address(adapter), address(comet), abi.encodeWithSelector(adapter.withdraw.selector, uint256(50e6))); + uint256 feesAfter = feeTracker.agentFeesCharged(address(wallet)); + uint256 feesCharged = feesAfter - feesBefore; + uint256 expectedFeeTokenSettled = uint256(1e6) * uint256(50e6) / uint256(110e6); + uint256 proportionalCost = (depositAmount * uint256(50e6)) / uint256(110e6); + uint256 netAssets = uint256(50e6) - expectedFeeTokenSettled; + uint256 expectedProfitFee = netAssets > proportionalCost ? ((netAssets - proportionalCost) * 1000) / 10_000 : 0; + uint256 expectedTotalFees = expectedFeeTokenSettled + expectedProfitFee; + assertEq(feesCharged, expectedTotalFees, "CompoundV3 fees should not be inflated for rebasing tokens"); + } + + // ============ Audit Fix: Deposit records actual amount, not type(uint256).max (Issue 4) ============ + + function test_DepositRecordsAssetsDeposited() public { + uint256 depositAmount = 500e6; + wallet.executeAdapter(address(adapter), address(comet), abi.encodeWithSelector(adapter.deposit.selector, depositAmount)); + (uint256 costBasis, uint256 shares) = feeTracker.getAgentVaultPosition(address(wallet), address(comet)); + assertEq(costBasis, depositAmount, "Cost basis should be actual deposited amount"); + assertEq(shares, depositAmount, "Shares should be actual deposited amount"); + comet.addYield(address(wallet), 50e6); + baseAsset.mint(address(comet), 50e6); + wallet.executeAdapter(address(adapter), address(comet), abi.encodeWithSelector(adapter.withdraw.selector, uint256(250e6))); + (uint256 costBasisAfter, uint256 sharesAfter) = feeTracker.getAgentVaultPosition(address(wallet), address(comet)); + assertTrue(costBasisAfter < costBasis, "Cost basis should decrease after partial withdrawal"); + assertTrue(sharesAfter < shares, "Shares should decrease after partial withdrawal"); + } + + function test_MultiplePartialWithdraws_NoOverflow() public { + uint256 depositAmount = 1_000e6; + wallet.executeAdapter(address(adapter), address(comet), abi.encodeWithSelector(adapter.deposit.selector, depositAmount)); + comet.addYield(address(wallet), 100e6); + baseAsset.mint(address(comet), 100e6); + wallet.executeAdapter(address(adapter), address(comet), abi.encodeWithSelector(adapter.withdraw.selector, uint256(300e6))); + wallet.executeAdapter(address(adapter), address(comet), abi.encodeWithSelector(adapter.withdraw.selector, uint256(300e6))); + wallet.executeAdapter(address(adapter), address(comet), abi.encodeWithSelector(adapter.withdraw.selector, uint256(300e6))); + (uint256 costBasis, uint256 shares) = feeTracker.getAgentVaultPosition(address(wallet), address(comet)); + assertTrue(costBasis < depositAmount, "Cost basis should be reduced"); + assertTrue(shares < depositAmount, "Shares should be reduced"); + } + + // ============ Audit Fix: Full lifecycle ============ + + function test_FullLifecycle_CorrectFees() public { + wallet.executeAdapter(address(adapter), address(comet), abi.encodeWithSelector(adapter.deposit.selector, uint256(1_000e6))); + (uint256 costBasis, uint256 shares) = feeTracker.getAgentVaultPosition(address(wallet), address(comet)); + assertEq(costBasis, 1_000e6, "Cost basis should be actual amount"); + assertEq(shares, 1_000e6, "Shares should be actual amount"); + comet.addYield(address(wallet), 100e6); + baseAsset.mint(address(comet), 100e6); + wallet.executeAdapter(address(adapter), address(comet), abi.encodeWithSelector(adapter.withdraw.selector, uint256(550e6))); + uint256 remaining = comet.balanceOf(address(wallet)); + wallet.executeAdapter(address(adapter), address(comet), abi.encodeWithSelector(adapter.withdraw.selector, remaining)); + (uint256 costBasisAfter, uint256 sharesAfter) = feeTracker.getAgentVaultPosition(address(wallet), address(comet)); + assertEq(costBasisAfter, 0); + assertEq(sharesAfter, 0); + uint256 totalFees = feeTracker.agentFeesCharged(address(wallet)); + uint256 expectedFee = (100e6 * 1000) / 10_000; + assertEq(totalFees, expectedFee, "Total fees should equal 10% of 100 USDC yield"); + } } diff --git a/test/unit/adapters/ERC4626Adapter.t.sol b/test/unit/adapters/ERC4626Adapter.t.sol index c3d459f..56d4a1a 100644 --- a/test/unit/adapters/ERC4626Adapter.t.sol +++ b/test/unit/adapters/ERC4626Adapter.t.sol @@ -8,7 +8,7 @@ import {AWKErrors} from "../../../src/agentwalletkit/AWKErrors.sol"; import {MockERC20} from "../../mocks/MockERC20.sol"; import {MockERC4626} from "../../mocks/MockERC4626.sol"; import {AdapterWalletHarness} from "./AdapterHarness.t.sol"; -import {Test} from "forge-std/Test.sol"; +import {Test, console} from "forge-std/Test.sol"; contract ERC4626AdapterTest is Test { YieldSeekerERC4626Adapter adapter; @@ -102,4 +102,60 @@ contract ERC4626AdapterTest is Test { uint256 assets = _decodeUint(result); assertGt(assets, 1_000e6); } + + function test_DepositGasOverhead() public { + MockERC4626 directVault = new MockERC4626(address(baseAsset), "Direct", "dUSDC"); + address directUser = address(0xDEAD); + baseAsset.mint(directUser, 1_000_000e6); + vm.prank(directUser); + baseAsset.approve(address(directVault), type(uint256).max); + uint256 depositAmount = 10_000e6; + // 1) Direct vault deposit (baseline) + vm.prank(directUser); + uint256 g0 = gasleft(); + directVault.deposit(depositAmount, directUser); + uint256 directGas = g0 - gasleft(); + // 2) Full adapter path + uint256 g1 = gasleft(); + wallet.executeAdapter(address(adapter), address(vault), abi.encodeWithSelector(adapter.deposit.selector, depositAmount)); + uint256 adapterGas = g1 - gasleft(); + // 3) Delegatecall overhead (empty call) + uint256 g2 = gasleft(); + (bool ok,) = address(adapter).delegatecall(abi.encodeWithSignature("nonexistent()")); + uint256 delegatecallGas = g2 - gasleft(); + ok; // silence warning + // 4) _getVaultAsset (vault.asset()) + uint256 g3 = gasleft(); + vault.asset(); + uint256 getAssetGas = g3 - gasleft(); + // 5) _requireBaseAsset (reads baseAsset from wallet then compares) + uint256 g4 = gasleft(); + wallet.baseAsset(); + uint256 readBaseAssetGas = g4 - gasleft(); + // 6) Balance-delta: 2x balanceOf on base asset + uint256 g5 = gasleft(); + baseAsset.balanceOf(address(wallet)); + baseAsset.balanceOf(address(wallet)); + uint256 twoBalanceOfGas = g5 - gasleft(); + // 7) forceApprove (IERC20.approve) + uint256 g6 = gasleft(); + baseAsset.approve(address(vault), depositAmount); + uint256 approveGas = g6 - gasleft(); + // 8) FeeTracker.recordAgentVaultShareDeposit (2 SSTORE) + uint256 g7 = gasleft(); + feeTracker.recordAgentVaultShareDeposit(address(vault), depositAmount, depositAmount); + uint256 feeTrackerGas = g7 - gasleft(); + // Summary + console.log("=== ERC4626 Deposit Gas Breakdown ==="); + console.log("Direct vault deposit (baseline): ", directGas); + console.log("Full adapter path: ", adapterGas); + console.log("Overhead: ", adapterGas - directGas); + console.log("--- Component costs ---"); + console.log("Delegatecall dispatch: ", delegatecallGas); + console.log("vault.asset(): ", getAssetGas); + console.log("Read wallet.baseAsset(): ", readBaseAssetGas); + console.log("2x balanceOf (balance-delta): ", twoBalanceOfGas); + console.log("ERC20 approve: ", approveGas); + console.log("FeeTracker record deposit: ", feeTrackerGas); + } }